diff --git a/tempfork/x509/cert_pool.go b/tempfork/x509/cert_pool.go index 6c9203457..42b886f35 100644 --- a/tempfork/x509/cert_pool.go +++ b/tempfork/x509/cert_pool.go @@ -5,15 +5,26 @@ package x509 import ( + "crypto/sha256" "encoding/pem" "errors" "runtime" + "sync" ) +type sum224 [sha256.Size224]byte + // CertPool is a set of certificates. type CertPool struct { - bySubjectKeyId map[string][]int // cert.SubjectKeyId => getCert index - byName map[string][]int // cert.RawSubject => getCert index + bySubjectKeyId map[string][]int // cert.SubjectKeyId => getCert index(es) + byName map[string][]int // cert.RawSubject => getCert index(es) + + // haveSum maps from sum224(cert.Raw) to true. It's used only + // for AddCert duplicate detection, to avoid CertPool.contains + // calls in the AddCert path (because the contains method can + // call getCert and otherwise negate savings from lazy getCert + // funcs). + haveSum map[sum224]bool // getCert contains funcs that return the certificates. getCert []func() (*Certificate, error) @@ -28,6 +39,7 @@ func NewCertPool() *CertPool { return &CertPool{ bySubjectKeyId: make(map[string][]int), byName: make(map[string][]int), + haveSum: make(map[sum224]bool), } } @@ -49,6 +61,7 @@ func (s *CertPool) copy() *CertPool { p := &CertPool{ bySubjectKeyId: make(map[string][]int, len(s.bySubjectKeyId)), byName: make(map[string][]int, len(s.byName)), + haveSum: make(map[sum224]bool, len(s.haveSum)), getCert: make([]func() (*Certificate, error), len(s.getCert)), rawSubjects: make([][]byte, len(s.rawSubjects)), } @@ -62,6 +75,9 @@ func (s *CertPool) copy() *CertPool { copy(indexes, v) p.byName[k] = indexes } + for k := range s.haveSum { + p.haveSum[k] = true + } copy(p.getCert, s.getCert) copy(p.rawSubjects, s.rawSubjects) return p @@ -127,7 +143,7 @@ func (s *CertPool) AddCert(cert *Certificate) { if cert == nil { panic("adding nil Certificate to CertPool") } - err := s.AddCertFunc(string(cert.RawSubject), string(cert.SubjectKeyId), func() (*Certificate, error) { + err := s.AddCertFunc(sha256.Sum224(cert.Raw), string(cert.RawSubject), string(cert.SubjectKeyId), func() (*Certificate, error) { return cert, nil }) if err != nil { @@ -141,23 +157,16 @@ func (s *CertPool) AddCert(cert *Certificate) { // The rawSubject is Certificate.RawSubject and must be non-empty. // The subjectKeyID is Certificate.SubjectKeyId and may be empty. // The getCert func may be called 0 or more times. -func (s *CertPool) AddCertFunc(rawSubject, subjectKeyID string, getCert func() (*Certificate, error)) error { +func (s *CertPool) AddCertFunc(rawSum224 sum224, rawSubject, subjectKeyID string, getCert func() (*Certificate, error)) error { if getCert == nil { panic("getCert can't be nil") } // Check that the certificate isn't being added twice. - if len(s.byName[rawSubject]) > 0 { - c, err := getCert() - if err != nil { - return err - } - if dup, err := s.contains(c); dup { - return nil - } else if err != nil { - return err - } + if s.haveSum[rawSum224] { + return nil } + s.haveSum[rawSum224] = true n := len(s.getCert) s.getCert = append(s.getCert, getCert) @@ -187,16 +196,26 @@ func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) { continue } - cert, err := ParseCertificate(block.Bytes) + certBytes := block.Bytes + cert, err := ParseCertificate(certBytes) if err != nil { continue } - - s.AddCert(cert) + var lazyCert struct { + sync.Once + v *Certificate + } + s.AddCertFunc(sha256.Sum224(cert.Raw), string(cert.RawSubject), string(cert.SubjectKeyId), func() (*Certificate, error) { + lazyCert.Do(func() { + // This can't fail, as the same bytes already parsed above. + lazyCert.v, _ = ParseCertificate(certBytes) + certBytes = nil + }) + return lazyCert.v, nil + }) ok = true } - - return + return ok } // Subjects returns a list of the DER-encoded subjects of