@ -12,9 +12,15 @@ import (
// CertPool is a set of certificates.
// CertPool is a set of certificates.
type CertPool struct {
type CertPool struct {
bySubjectKeyId map [ string ] [ ] int
bySubjectKeyId map [ string ] [ ] int // cert.SubjectKeyId => getCert index
byName map [ string ] [ ] int
byName map [ string ] [ ] int // cert.RawSubject => getCert index
certs [ ] * Certificate
// getCert contains funcs that return the certificates.
getCert [ ] func ( ) ( * Certificate , error )
// rawSubjects is each cert's RawSubject field.
// Its indexes correspond to the getCert indexes.
rawSubjects [ ] [ ] byte
}
}
// NewCertPool returns a new, empty CertPool.
// NewCertPool returns a new, empty CertPool.
@ -25,11 +31,26 @@ func NewCertPool() *CertPool {
}
}
}
}
// len returns the number of certs in the set.
// A nil set is a valid empty set.
func ( s * CertPool ) len ( ) int {
if s == nil {
return 0
}
return len ( s . getCert )
}
// cert returns cert index n in s.
func ( s * CertPool ) cert ( n int ) ( * Certificate , error ) {
return s . getCert [ n ] ( )
}
func ( s * CertPool ) copy ( ) * CertPool {
func ( s * CertPool ) copy ( ) * CertPool {
p := & CertPool {
p := & CertPool {
bySubjectKeyId : make ( map [ string ] [ ] int , len ( s . bySubjectKeyId ) ) ,
bySubjectKeyId : make ( map [ string ] [ ] int , len ( s . bySubjectKeyId ) ) ,
byName : make ( map [ string ] [ ] int , len ( s . byName ) ) ,
byName : make ( map [ string ] [ ] int , len ( s . byName ) ) ,
certs : make ( [ ] * Certificate , len ( s . certs ) ) ,
getCert : make ( [ ] func ( ) ( * Certificate , error ) , len ( s . getCert ) ) ,
rawSubjects : make ( [ ] [ ] byte , len ( s . rawSubjects ) ) ,
}
}
for k , v := range s . bySubjectKeyId {
for k , v := range s . bySubjectKeyId {
indexes := make ( [ ] int , len ( v ) )
indexes := make ( [ ] int , len ( v ) )
@ -41,7 +62,8 @@ func (s *CertPool) copy() *CertPool {
copy ( indexes , v )
copy ( indexes , v )
p . byName [ k ] = indexes
p . byName [ k ] = indexes
}
}
copy ( p . certs , s . certs )
copy ( p . getCert , s . getCert )
copy ( p . rawSubjects , s . rawSubjects )
return p
return p
}
}
@ -82,19 +104,22 @@ func (s *CertPool) findPotentialParents(cert *Certificate) []int {
return candidates
return candidates
}
}
func ( s * CertPool ) contains ( cert * Certificate ) bool {
func ( s * CertPool ) contains ( cert * Certificate ) ( bool , error ) {
if s == nil {
if s == nil {
return false
return false , nil
}
}
candidates := s . byName [ string ( cert . RawSubject ) ]
candidates := s . byName [ string ( cert . RawSubject ) ]
for _ , c := range candidates {
for _ , i := range candidates {
if s . certs [ c ] . Equal ( cert ) {
c , err := s . cert ( i )
return true
if err != nil {
return false , err
}
if c . Equal ( cert ) {
return true , nil
}
}
}
}
return false
return false , nil
}
}
// AddCert adds a certificate to a pool.
// AddCert adds a certificate to a pool.
@ -102,21 +127,47 @@ func (s *CertPool) AddCert(cert *Certificate) {
if cert == nil {
if cert == nil {
panic ( "adding nil Certificate to CertPool" )
panic ( "adding nil Certificate to CertPool" )
}
}
err := s . AddCertFunc ( string ( cert . RawSubject ) , string ( cert . SubjectKeyId ) , func ( ) ( * Certificate , error ) {
return cert , nil
} )
if err != nil {
panic ( err . Error ( ) )
}
}
// AddCertFunc adds metadata about a certificate to a pool, along with
// a func to fetch that certificate later when needed.
//
// The rawSubject is Certificate.RawSubject and must be non-empty.
// The subjectKeyID is Certificate.SubjectKeyId and may be empty.
// The getCert func may be called 0 or more times.
func ( s * CertPool ) AddCertFunc ( rawSubject , subjectKeyID string , getCert func ( ) ( * Certificate , error ) ) error {
if getCert == nil {
panic ( "getCert can't be nil" )
}
// Check that the certificate isn't being added twice.
// Check that the certificate isn't being added twice.
if s . contains ( cert ) {
if len ( s . byName [ rawSubject ] ) > 0 {
return
c , err := getCert ( )
if err != nil {
return err
}
if dup , err := s . contains ( c ) ; dup {
return nil
} else if err != nil {
return err
}
}
}
n := len ( s . certs )
n := len ( s . getCert )
s . certs = append ( s . certs , cert )
s . getCert = append ( s . getCert , getC ert)
if len ( cert . SubjectKeyId ) > 0 {
if subjectKeyID != "" {
keyId := string ( cert . SubjectKeyId )
s . bySubjectKeyId [ subjectKeyID ] = append ( s . bySubjectKeyId [ subjectKeyID ] , n )
s . bySubjectKeyId [ keyId ] = append ( s . bySubjectKeyId [ keyId ] , n )
}
}
name := string ( cert . RawSubject )
s . byName [ rawSubject ] = append ( s . byName [ rawSubject ] , n )
s . byName [ name ] = append ( s . byName [ name ] , n )
s . rawSubjects = append ( s . rawSubjects , [ ] byte ( rawSubject ) )
return nil
}
}
// AppendCertsFromPEM attempts to parse a series of PEM encoded certificates.
// AppendCertsFromPEM attempts to parse a series of PEM encoded certificates.
@ -151,9 +202,9 @@ func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) {
// Subjects returns a list of the DER-encoded subjects of
// Subjects returns a list of the DER-encoded subjects of
// all of the certificates in the pool.
// all of the certificates in the pool.
func ( s * CertPool ) Subjects ( ) [ ] [ ] byte {
func ( s * CertPool ) Subjects ( ) [ ] [ ] byte {
res := make ( [ ] [ ] byte , len ( s . certs ) )
res := make ( [ ] [ ] byte , s . len ( ) )
for i , c := range s . ce rts {
for i , s := range s . rawSubjec ts {
res [ i ] = c. RawSubject
res [ i ] = s
}
}
return res
return res
}
}