// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. //go:build windows && cgo package controlclient import ( "crypto" "crypto/x509" "crypto/x509/pkix" "errors" "reflect" "testing" "time" "github.com/tailscale/certstore" ) const ( testRootCommonName = "testroot" testRootSubject = "CN=testroot" ) type testIdentity struct { chain []*x509.Certificate } func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate { return []*x509.Certificate{ { NotBefore: notBefore, NotAfter: notAfter, PublicKeyAlgorithm: x509.RSA, }, { Subject: pkix.Name{ CommonName: rootCommonName, }, PublicKeyAlgorithm: x509.RSA, }, } } func (t *testIdentity) Certificate() (*x509.Certificate, error) { return t.chain[0], nil } func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) { return t.chain, nil } func (t *testIdentity) Signer() (crypto.Signer, error) { return nil, errors.New("not implemented") } func (t *testIdentity) Delete() error { return errors.New("not implemented") } func (t *testIdentity) Close() {} func TestSelectIdentityFromSlice(t *testing.T) { var times []time.Time for _, ts := range []string{ "2000-01-01T00:00:00Z", "2001-01-01T00:00:00Z", "2002-01-01T00:00:00Z", "2003-01-01T00:00:00Z", } { tm, err := time.Parse(time.RFC3339, ts) if err != nil { t.Fatal(err) } times = append(times, tm) } tests := []struct { name string subject string ids []certstore.Identity now time.Time // wantIndex is an index into ids, or -1 for nil. wantIndex int }{ { name: "single unexpired identity", subject: testRootSubject, ids: []certstore.Identity{ &testIdentity{ chain: makeChain(testRootCommonName, times[0], times[2]), }, }, now: times[1], wantIndex: 0, }, { name: "single expired identity", subject: testRootSubject, ids: []certstore.Identity{ &testIdentity{ chain: makeChain(testRootCommonName, times[0], times[1]), }, }, now: times[2], wantIndex: -1, }, { name: "unrelated ids", subject: testRootSubject, ids: []certstore.Identity{ &testIdentity{ chain: makeChain("something", times[0], times[2]), }, &testIdentity{ chain: makeChain(testRootCommonName, times[0], times[2]), }, &testIdentity{ chain: makeChain("else", times[0], times[2]), }, }, now: times[1], wantIndex: 1, }, { name: "expired with unrelated ids", subject: testRootSubject, ids: []certstore.Identity{ &testIdentity{ chain: makeChain("something", times[0], times[3]), }, &testIdentity{ chain: makeChain(testRootCommonName, times[0], times[1]), }, &testIdentity{ chain: makeChain("else", times[0], times[3]), }, }, now: times[2], wantIndex: -1, }, { name: "one expired", subject: testRootSubject, ids: []certstore.Identity{ &testIdentity{ chain: makeChain(testRootCommonName, times[0], times[1]), }, &testIdentity{ chain: makeChain(testRootCommonName, times[1], times[3]), }, }, now: times[2], wantIndex: 1, }, { name: "two certs both unexpired", subject: testRootSubject, ids: []certstore.Identity{ &testIdentity{ chain: makeChain(testRootCommonName, times[0], times[3]), }, &testIdentity{ chain: makeChain(testRootCommonName, times[1], times[3]), }, }, now: times[2], wantIndex: 1, }, { name: "two unexpired one expired", subject: testRootSubject, ids: []certstore.Identity{ &testIdentity{ chain: makeChain(testRootCommonName, times[0], times[3]), }, &testIdentity{ chain: makeChain(testRootCommonName, times[1], times[3]), }, &testIdentity{ chain: makeChain(testRootCommonName, times[0], times[1]), }, }, now: times[2], wantIndex: 1, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { gotId, gotChain := selectIdentityFromSlice(tt.subject, tt.ids, tt.now) if gotId == nil && gotChain != nil { t.Error("id is nil: got non-nil chain, want nil chain") return } if gotId != nil && gotChain == nil { t.Error("id is not nil: got nil chain, want non-nil chain") return } if tt.wantIndex == -1 { if gotId != nil { t.Error("got non-nil id, want nil id") } return } if gotId == nil { t.Error("got nil id, want non-nil id") return } if gotId != tt.ids[tt.wantIndex] { found := -1 for i := range tt.ids { if tt.ids[i] == gotId { found = i break } } if found == -1 { t.Errorf("got unknown id, want id at index %v", tt.wantIndex) } else { t.Errorf("got id at index %v, want id at index %v", found, tt.wantIndex) } } tid, ok := tt.ids[tt.wantIndex].(*testIdentity) if !ok { t.Error("got non-testIdentity, want testIdentity") return } if !reflect.DeepEqual(tid.chain, gotChain) { t.Errorf("got unknown chain, want chain from id at index %v", tt.wantIndex) } }) } }