diff --git a/control/controlclient/sign_supported.go b/control/controlclient/sign_supported.go index 63f362a57..3c51126a2 100644 --- a/control/controlclient/sign_supported.go +++ b/control/controlclient/sign_supported.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "sync" + "time" "github.com/tailscale/certstore" "tailscale.com/tailcfg" @@ -73,23 +74,46 @@ func isSubjectInChain(subject string, chain []*x509.Certificate) bool { return false } -func selectIdentityFromSlice(subject string, ids []certstore.Identity) (certstore.Identity, []*x509.Certificate) { +func selectIdentityFromSlice(subject string, ids []certstore.Identity, now time.Time) (certstore.Identity, []*x509.Certificate) { + var bestCandidate struct { + id certstore.Identity + chain []*x509.Certificate + } + for _, id := range ids { chain, err := id.CertificateChain() if err != nil { continue } + if len(chain) < 1 { + continue + } + if !isSupportedCertificate(chain[0]) { continue } - if isSubjectInChain(subject, chain) { - return id, chain + if now.Before(chain[0].NotBefore) || now.After(chain[0].NotAfter) { + // Certificate is not valid at this time + continue } + + if !isSubjectInChain(subject, chain) { + continue + } + + // Select the most recently issued certificate. If there is a tie, pick + // one arbitrarily. + if len(bestCandidate.chain) > 0 && bestCandidate.chain[0].NotBefore.After(chain[0].NotBefore) { + continue + } + + bestCandidate.id = id + bestCandidate.chain = chain } - return nil, nil + return bestCandidate.id, bestCandidate.chain } // findIdentity locates an identity from the Windows or Darwin certificate @@ -105,7 +129,7 @@ func findIdentity(subject string, st certstore.Store) (certstore.Identity, []*x5 return nil, nil, err } - selected, chain := selectIdentityFromSlice(subject, ids) + selected, chain := selectIdentityFromSlice(subject, ids, time.Now()) for _, id := range ids { if id != selected { diff --git a/control/controlclient/sign_supported_test.go b/control/controlclient/sign_supported_test.go new file mode 100644 index 000000000..c196cbb44 --- /dev/null +++ b/control/controlclient/sign_supported_test.go @@ -0,0 +1,238 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows && cgo +// +build windows,cgo + +package controlclient + +import ( + "crypto" + "crypto/x509" + "crypto/x509/pkix" + "errors" + "reflect" + "testing" + "time" + + "github.com/tailscale/certstore" +) + +const ( + testRootCommonName = "testroot" + testRootSubject = "CN=testroot" +) + +type testIdentity struct { + chain []*x509.Certificate +} + +func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate { + return []*x509.Certificate{ + { + NotBefore: notBefore, + NotAfter: notAfter, + PublicKeyAlgorithm: x509.RSA, + }, + { + Subject: pkix.Name{ + CommonName: rootCommonName, + }, + PublicKeyAlgorithm: x509.RSA, + }, + } +} + +func (t *testIdentity) Certificate() (*x509.Certificate, error) { + return t.chain[0], nil +} + +func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) { + return t.chain, nil +} + +func (t *testIdentity) Signer() (crypto.Signer, error) { + return nil, errors.New("not implemented") +} + +func (t *testIdentity) Delete() error { + return errors.New("not implemented") +} + +func (t *testIdentity) Close() {} + +func TestSelectIdentityFromSlice(t *testing.T) { + var times []time.Time + for _, ts := range []string{ + "2000-01-01T00:00:00Z", + "2001-01-01T00:00:00Z", + "2002-01-01T00:00:00Z", + "2003-01-01T00:00:00Z", + } { + tm, err := time.Parse(time.RFC3339, ts) + if err != nil { + t.Fatal(err) + } + times = append(times, tm) + } + + tests := []struct { + name string + subject string + ids []certstore.Identity + now time.Time + // wantIndex is an index into ids, or -1 for nil. + wantIndex int + }{ + { + name: "single unexpired identity", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[2]), + }, + }, + now: times[1], + wantIndex: 0, + }, + { + name: "single expired identity", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[1]), + }, + }, + now: times[2], + wantIndex: -1, + }, + { + name: "unrelated ids", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain("something", times[0], times[2]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[2]), + }, + &testIdentity{ + chain: makeChain("else", times[0], times[2]), + }, + }, + now: times[1], + wantIndex: 1, + }, + { + name: "expired with unrelated ids", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain("something", times[0], times[3]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[1]), + }, + &testIdentity{ + chain: makeChain("else", times[0], times[3]), + }, + }, + now: times[2], + wantIndex: -1, + }, + { + name: "one expired", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[1]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[1], times[3]), + }, + }, + now: times[2], + wantIndex: 1, + }, + { + name: "two certs both unexpired", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[3]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[1], times[3]), + }, + }, + now: times[2], + wantIndex: 1, + }, + { + name: "two unexpired one expired", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[3]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[1], times[3]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[1]), + }, + }, + now: times[2], + wantIndex: 1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotId, gotChain := selectIdentityFromSlice(tt.subject, tt.ids, tt.now) + + if gotId == nil && gotChain != nil { + t.Error("id is nil: got non-nil chain, want nil chain") + return + } + if gotId != nil && gotChain == nil { + t.Error("id is not nil: got nil chain, want non-nil chain") + return + } + if tt.wantIndex == -1 { + if gotId != nil { + t.Error("got non-nil id, want nil id") + } + return + } + if gotId == nil { + t.Error("got nil id, want non-nil id") + return + } + if gotId != tt.ids[tt.wantIndex] { + found := -1 + for i := range tt.ids { + if tt.ids[i] == gotId { + found = i + break + } + } + if found == -1 { + t.Errorf("got unknown id, want id at index %v", tt.wantIndex) + } else { + t.Errorf("got id at index %v, want id at index %v", found, tt.wantIndex) + } + } + + tid, ok := tt.ids[tt.wantIndex].(*testIdentity) + if !ok { + t.Error("got non-testIdentity, want testIdentity") + return + } + + if !reflect.DeepEqual(tid.chain, gotChain) { + t.Errorf("got unknown chain, want chain from id at index %v", tt.wantIndex) + } + }) + } +}