diff --git a/tempfork/x509/cert_pool.go b/tempfork/x509/cert_pool.go index 3e1e5fb8c..6c9203457 100644 --- a/tempfork/x509/cert_pool.go +++ b/tempfork/x509/cert_pool.go @@ -12,9 +12,15 @@ import ( // CertPool is a set of certificates. type CertPool struct { - bySubjectKeyId map[string][]int - byName map[string][]int - certs []*Certificate + bySubjectKeyId map[string][]int // cert.SubjectKeyId => getCert index + byName map[string][]int // cert.RawSubject => getCert index + + // getCert contains funcs that return the certificates. + getCert []func() (*Certificate, error) + + // rawSubjects is each cert's RawSubject field. + // Its indexes correspond to the getCert indexes. + rawSubjects [][]byte } // NewCertPool returns a new, empty CertPool. @@ -25,11 +31,26 @@ func NewCertPool() *CertPool { } } +// len returns the number of certs in the set. +// A nil set is a valid empty set. +func (s *CertPool) len() int { + if s == nil { + return 0 + } + return len(s.getCert) +} + +// cert returns cert index n in s. +func (s *CertPool) cert(n int) (*Certificate, error) { + return s.getCert[n]() +} + func (s *CertPool) copy() *CertPool { p := &CertPool{ bySubjectKeyId: make(map[string][]int, len(s.bySubjectKeyId)), byName: make(map[string][]int, len(s.byName)), - certs: make([]*Certificate, len(s.certs)), + getCert: make([]func() (*Certificate, error), len(s.getCert)), + rawSubjects: make([][]byte, len(s.rawSubjects)), } for k, v := range s.bySubjectKeyId { indexes := make([]int, len(v)) @@ -41,7 +62,8 @@ func (s *CertPool) copy() *CertPool { copy(indexes, v) p.byName[k] = indexes } - copy(p.certs, s.certs) + copy(p.getCert, s.getCert) + copy(p.rawSubjects, s.rawSubjects) return p } @@ -82,19 +104,22 @@ func (s *CertPool) findPotentialParents(cert *Certificate) []int { return candidates } -func (s *CertPool) contains(cert *Certificate) bool { +func (s *CertPool) contains(cert *Certificate) (bool, error) { if s == nil { - return false + return false, nil } - candidates := s.byName[string(cert.RawSubject)] - for _, c := range candidates { - if s.certs[c].Equal(cert) { - return true + for _, i := range candidates { + c, err := s.cert(i) + if err != nil { + return false, err + } + if c.Equal(cert) { + return true, nil } } - return false + return false, nil } // AddCert adds a certificate to a pool. @@ -102,21 +127,47 @@ func (s *CertPool) AddCert(cert *Certificate) { if cert == nil { panic("adding nil Certificate to CertPool") } + err := s.AddCertFunc(string(cert.RawSubject), string(cert.SubjectKeyId), func() (*Certificate, error) { + return cert, nil + }) + if err != nil { + panic(err.Error()) + } +} + +// AddCertFunc adds metadata about a certificate to a pool, along with +// a func to fetch that certificate later when needed. +// +// The rawSubject is Certificate.RawSubject and must be non-empty. +// The subjectKeyID is Certificate.SubjectKeyId and may be empty. +// The getCert func may be called 0 or more times. +func (s *CertPool) AddCertFunc(rawSubject, subjectKeyID string, getCert func() (*Certificate, error)) error { + if getCert == nil { + panic("getCert can't be nil") + } // Check that the certificate isn't being added twice. - if s.contains(cert) { - return + if len(s.byName[rawSubject]) > 0 { + c, err := getCert() + if err != nil { + return err + } + if dup, err := s.contains(c); dup { + return nil + } else if err != nil { + return err + } } - n := len(s.certs) - s.certs = append(s.certs, cert) + n := len(s.getCert) + s.getCert = append(s.getCert, getCert) - if len(cert.SubjectKeyId) > 0 { - keyId := string(cert.SubjectKeyId) - s.bySubjectKeyId[keyId] = append(s.bySubjectKeyId[keyId], n) + if subjectKeyID != "" { + s.bySubjectKeyId[subjectKeyID] = append(s.bySubjectKeyId[subjectKeyID], n) } - name := string(cert.RawSubject) - s.byName[name] = append(s.byName[name], n) + s.byName[rawSubject] = append(s.byName[rawSubject], n) + s.rawSubjects = append(s.rawSubjects, []byte(rawSubject)) + return nil } // AppendCertsFromPEM attempts to parse a series of PEM encoded certificates. @@ -151,9 +202,9 @@ func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) { // Subjects returns a list of the DER-encoded subjects of // all of the certificates in the pool. func (s *CertPool) Subjects() [][]byte { - res := make([][]byte, len(s.certs)) - for i, c := range s.certs { - res[i] = c.RawSubject + res := make([][]byte, s.len()) + for i, s := range s.rawSubjects { + res[i] = s } return res } diff --git a/tempfork/x509/name_constraints_test.go b/tempfork/x509/name_constraints_test.go index 5469e28de..de92552cc 100644 --- a/tempfork/x509/name_constraints_test.go +++ b/tempfork/x509/name_constraints_test.go @@ -1993,7 +1993,7 @@ func TestConstraintCases(t *testing.T) { pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}) return buf.String() } - t.Errorf("#%d: root:\n%s", i, certAsPEM(rootPool.certs[0])) + t.Errorf("#%d: root:\n%s", i, certAsPEM(rootPool.mustCert(0))) t.Errorf("#%d: leaf:\n%s", i, certAsPEM(leafCert)) } @@ -2019,10 +2019,18 @@ func writePEMsToTempFile(certs []*Certificate) *os.File { return file } +func allCerts(p *CertPool) []*Certificate { + all := make([]*Certificate, p.len()) + for i := range all { + all[i] = p.mustCert(i) + } + return all +} + func testChainAgainstOpenSSL(leaf *Certificate, intermediates, roots *CertPool) (string, error) { args := []string{"verify", "-no_check_time"} - rootsFile := writePEMsToTempFile(roots.certs) + rootsFile := writePEMsToTempFile(allCerts(roots)) if debugOpenSSLFailure { println("roots file:", rootsFile.Name()) } else { @@ -2030,8 +2038,8 @@ func testChainAgainstOpenSSL(leaf *Certificate, intermediates, roots *CertPool) } args = append(args, "-CAfile", rootsFile.Name()) - if len(intermediates.certs) > 0 { - intermediatesFile := writePEMsToTempFile(intermediates.certs) + if intermediates.len() > 0 { + intermediatesFile := writePEMsToTempFile(allCerts(intermediates)) if debugOpenSSLFailure { println("intermediates file:", intermediatesFile.Name()) } else { diff --git a/tempfork/x509/root_unix.go b/tempfork/x509/root_unix.go index 1be4058ba..0fce0a1d7 100644 --- a/tempfork/x509/root_unix.go +++ b/tempfork/x509/root_unix.go @@ -84,7 +84,7 @@ func loadSystemRoots() (*CertPool, error) { } } - if len(roots.certs) > 0 || firstErr == nil { + if roots.len() > 0 || firstErr == nil { return roots, nil } diff --git a/tempfork/x509/root_unix_test.go b/tempfork/x509/root_unix_test.go index 5a27d639b..cbb48eddb 100644 --- a/tempfork/x509/root_unix_test.go +++ b/tempfork/x509/root_unix_test.go @@ -113,15 +113,15 @@ func TestEnvVars(t *testing.T) { // Verify that the returned certs match, otherwise report where the mismatch is. for i, cn := range tc.cns { - if i >= len(r.certs) { + if i >= r.len() { t.Errorf("missing cert %v @ %v", cn, i) - } else if r.certs[i].Subject.CommonName != cn { - fmt.Printf("%#v\n", r.certs[0].Subject) - t.Errorf("unexpected cert common name %q, want %q", r.certs[i].Subject.CommonName, cn) + } else if r.mustCert(i).Subject.CommonName != cn { + fmt.Printf("%#v\n", r.mustCert(0).Subject) + t.Errorf("unexpected cert common name %q, want %q", r.mustCert(i).Subject.CommonName, cn) } } - if len(r.certs) > len(tc.cns) { - t.Errorf("got %v certs, which is more than %v wanted", len(r.certs), len(tc.cns)) + if r.len() > len(tc.cns) { + t.Errorf("got %v certs, which is more than %v wanted", r.len(), len(tc.cns)) } }) } @@ -197,6 +197,10 @@ func TestLoadSystemCertsLoadColonSeparatedDirs(t *testing.T) { strCertPool := func(p *CertPool) string { return string(bytes.Join(p.Subjects(), []byte("\n"))) } + + zeroPoolFuncs(gotPool) + zeroPoolFuncs(wantPool) + if !reflect.DeepEqual(gotPool, wantPool) { g, w := strCertPool(gotPool), strCertPool(wantPool) t.Fatalf("Mismatched certPools\nGot:\n%s\n\nWant:\n%s", g, w) diff --git a/tempfork/x509/verify.go b/tempfork/x509/verify.go index 358fca470..23bce95b7 100644 --- a/tempfork/x509/verify.go +++ b/tempfork/x509/verify.go @@ -737,11 +737,13 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e if len(c.Raw) == 0 { return nil, errNotParsed } - if opts.Intermediates != nil { - for _, intermediate := range opts.Intermediates.certs { - if len(intermediate.Raw) == 0 { - return nil, errNotParsed - } + for i := 0; i < opts.Intermediates.len(); i++ { + c, err := opts.Intermediates.cert(i) + if err != nil { + return nil, fmt.Errorf("crypto/x509: error fetching cert: %w", err) + } + if len(c.Raw) == 0 { + return nil, errNotParsed } } @@ -770,8 +772,10 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e } var candidateChains [][]*Certificate - if opts.Roots.contains(c) { + if inRoots, err := opts.Roots.contains(c); inRoots { candidateChains = append(candidateChains, []*Certificate{c}) + } else if err != nil { + return nil, err } else { if candidateChains, err = c.buildChains(nil, []*Certificate{c}, nil, &opts); err != nil { return nil, err @@ -868,10 +872,18 @@ func (c *Certificate) buildChains(cache map[*Certificate][][]*Certificate, curre } for _, rootNum := range opts.Roots.findPotentialParents(c) { - considerCandidate(rootCertificate, opts.Roots.certs[rootNum]) + c, err := opts.Roots.cert(rootNum) + if err != nil { + return nil, fmt.Errorf("crypto/x509: error fetching cert: %w", err) + } + considerCandidate(rootCertificate, c) } for _, intermediateNum := range opts.Intermediates.findPotentialParents(c) { - considerCandidate(intermediateCertificate, opts.Intermediates.certs[intermediateNum]) + c, err := opts.Intermediates.cert(intermediateNum) + if err != nil { + return nil, fmt.Errorf("crypto/x509: error fetching cert: %w", err) + } + considerCandidate(intermediateCertificate, c) } if len(chains) > 0 { diff --git a/tempfork/x509/x509_test.go b/tempfork/x509/x509_test.go index c980f4073..bd297858e 100644 --- a/tempfork/x509/x509_test.go +++ b/tempfork/x509/x509_test.go @@ -1983,6 +1983,8 @@ func TestSystemCertPool(t *testing.T) { if err != nil { t.Fatal(err) } + zeroPoolFuncs(a) + zeroPoolFuncs(b) if !reflect.DeepEqual(a, b) { t.Fatal("two calls to SystemCertPool had different results") } @@ -2644,3 +2646,19 @@ func TestCreateRevocationList(t *testing.T) { }) } } + +func (s *CertPool) mustCert(n int) *Certificate { + c, err := s.getCert[n]() + if err != nil { + panic(err.Error()) + } + return c +} + +// zeroPoolFuncs zeros out funcs in p so two pools can be compared +// with reflect.DeepEqual. +func zeroPoolFuncs(p *CertPool) { + for i := range p.getCert { + p.getCert[i] = nil + } +}