Move refactoring steps from #54635 to own PR. (#54690)

pull/54990/head
Felix Fontein 6 years ago committed by John R Barker
parent 97ac03f613
commit e079758b31

@ -0,0 +1,2 @@
minor_changes:
- "openssl_certificate - the messages of the ``assertonly`` provider with respect to private key and CSR checking are now more precise."

@ -533,6 +533,7 @@ backup_file:
from random import randint from random import randint
import abc
import datetime import datetime
import os import os
import traceback import traceback
@ -1103,10 +1104,28 @@ class OwnCACertificate(Certificate):
return result return result
class AssertOnlyCertificateCryptography(Certificate): def compare_sets(subset, superset, equality=False):
"""Validate the supplied cert, using the cryptography backend""" if equality:
def __init__(self, module): return set(subset) == set(superset)
super(AssertOnlyCertificateCryptography, self).__init__(module, 'cryptography') else:
return all(x in superset for x in subset)
def compare_dicts(subset, superset, equality=False):
if equality:
return subset == superset
else:
return all(superset.get(x) == v for x, v in subset.items())
NO_EXTENSION = 'no extension'
class AssertOnlyCertificateBase(Certificate):
def __init__(self, module, backend):
super(AssertOnlyCertificateBase, self).__init__(module, backend)
self.signature_algorithms = module.params['signature_algorithms'] self.signature_algorithms = module.params['signature_algorithms']
if module.params['subject']: if module.params['subject']:
self.subject = crypto_utils.parse_name_field(module.params['subject']) self.subject = crypto_utils.parse_name_field(module.params['subject'])
@ -1120,226 +1139,256 @@ class AssertOnlyCertificateCryptography(Certificate):
self.issuer_strict = module.params['issuer_strict'] self.issuer_strict = module.params['issuer_strict']
self.has_expired = module.params['has_expired'] self.has_expired = module.params['has_expired']
self.version = module.params['version'] self.version = module.params['version']
self.keyUsage = module.params['key_usage'] self.key_usage = module.params['key_usage']
self.keyUsage_strict = module.params['key_usage_strict'] self.key_usage_strict = module.params['key_usage_strict']
self.extendedKeyUsage = module.params['extended_key_usage'] self.extended_key_usage = module.params['extended_key_usage']
self.extendedKeyUsage_strict = module.params['extended_key_usage_strict'] self.extended_key_usage_strict = module.params['extended_key_usage_strict']
self.subjectAltName = module.params['subject_alt_name'] self.subject_alt_name = module.params['subject_alt_name']
self.subjectAltName_strict = module.params['subject_alt_name_strict'] self.subject_alt_name_strict = module.params['subject_alt_name_strict']
self.notBefore = module.params['not_before'], self.not_before = module.params['not_before']
self.notAfter = module.params['not_after'], self.not_after = module.params['not_after']
self.valid_at = module.params['valid_at'], self.valid_at = module.params['valid_at']
self.invalid_at = module.params['invalid_at'], self.invalid_at = module.params['invalid_at']
self.valid_in = module.params['valid_in'], self.valid_in = module.params['valid_in']
self.message = [] if self.valid_in and not self.valid_in.startswith("+") and not self.valid_in.startswith("-"):
try:
def assertonly(self): int(self.valid_in)
except ValueError:
module.fail_json(msg='The supplied value for "valid_in" (%s) is not an integer or a valid timespec' % self.valid_in)
self.valid_in = "+" + self.valid_in + "s"
# Load objects
self.cert = crypto_utils.load_certificate(self.path, backend=self.backend) self.cert = crypto_utils.load_certificate(self.path, backend=self.backend)
if self.privatekey_path is not None:
try:
self.privatekey = crypto_utils.load_privatekey(
self.privatekey_path,
self.privatekey_passphrase,
backend=self.backend
)
except crypto_utils.OpenSSLBadPassphraseError as exc:
raise CertificateError(exc)
if self.csr_path is not None:
self.csr = crypto_utils.load_certificate_request(self.csr_path, backend=self.backend)
def _validate_signature_algorithms(): @abc.abstractmethod
if self.signature_algorithms: def _validate_privatekey(self):
if self.cert.signature_algorithm_oid._name not in self.signature_algorithms: pass
self.message.append(
'Invalid signature algorithm (got %s, expected one of %s)' % @abc.abstractmethod
(self.cert.signature_algorithm_oid._name, self.signature_algorithms) def _validate_csr_signature(self):
pass
@abc.abstractmethod
def _validate_csr_subject(self):
pass
@abc.abstractmethod
def _validate_csr_extensions(self):
pass
@abc.abstractmethod
def _validate_signature_algorithms(self):
pass
@abc.abstractmethod
def _validate_subject(self):
pass
@abc.abstractmethod
def _validate_issuer(self):
pass
@abc.abstractmethod
def _validate_has_expired(self):
pass
@abc.abstractmethod
def _validate_version(self):
pass
@abc.abstractmethod
def _validate_key_usage(self):
pass
@abc.abstractmethod
def _validate_extended_key_usage(self):
pass
@abc.abstractmethod
def _validate_subject_alt_name(self):
pass
@abc.abstractmethod
def _validate_not_before(self):
pass
@abc.abstractmethod
def _validate_not_after(self):
pass
@abc.abstractmethod
def _validate_valid_at(self):
pass
@abc.abstractmethod
def _validate_invalid_at(self):
pass
@abc.abstractmethod
def _validate_valid_in(self):
pass
def assertonly(self, module):
messages = []
if self.privatekey_path is not None:
if not self._validate_privatekey():
messages.append(
'Certificate %s and private key %s do not match' %
(self.path, self.privatekey_path)
) )
def _validate_subject(): if self.csr_path is not None:
if self.subject: if not self._validate_csr_signature():
expected_subject = Name([NameAttribute(oid=crypto_utils.cryptography_get_name_oid(sub[0]), value=to_text(sub[1])) messages.append(
for sub in self.subject]) 'Certificate %s and CSR %s do not match: private key mismatch' %
cert_subject = self.cert.subject (self.path, self.csr_path)
if (not self.subject_strict and not all(x in cert_subject for x in expected_subject)) or \ )
(self.subject_strict and not set(expected_subject) == set(cert_subject)): if not self._validate_csr_subject():
self.message.append( messages.append(
'Invalid subject component (got %s, expected all of %s to be present)' % 'Certificate %s and CSR %s do not match: subject mismatch' %
(cert_subject, expected_subject) (self.path, self.csr_path)
)
if not self._validate_csr_extensions():
messages.append(
'Certificate %s and CSR %s do not match: extensions mismatch' %
(self.path, self.csr_path)
) )
def _validate_issuer(): if self.signature_algorithms is not None:
if self.issuer: wrong_alg = self._validate_signature_algorithms()
expected_issuer = Name([NameAttribute(oid=crypto_utils.cryptography_get_name_oid(iss[0]), value=to_text(iss[1])) if wrong_alg:
for iss in self.issuer]) messages.append(
cert_issuer = self.cert.issuer 'Invalid signature algorithm (got %s, expected one of %s)' %
if (not self.issuer_strict and not all(x in cert_issuer for x in expected_issuer)) or \ (wrong_alg, self.signature_algorithms)
(self.issuer_strict and not set(expected_issuer) == set(cert_issuer)):
self.message.append(
'Invalid issuer component (got %s, expected all of %s to be present)' % (cert_issuer, self.issuer)
) )
def _validate_has_expired(): if self.subject is not None:
cert_not_after = self.cert.not_valid_after failure = self._validate_subject()
cert_expired = cert_not_after < datetime.datetime.utcnow() if failure:
dummy, cert_subject = failure
messages.append(
'Invalid subject component (got %s, expected all of %s to be present)' %
(cert_subject, self.subject)
)
if self.has_expired != cert_expired: if self.issuer is not None:
self.message.append( failure = self._validate_issuer()
'Certificate expiration check failed (certificate expiration is %s, expected %s)' % (cert_expired, self.has_expired) if failure:
dummy, cert_issuer = failure
messages.append(
'Invalid issuer component (got %s, expected all of %s to be present)' % (cert_issuer, self.issuer)
) )
def _validate_version(): if self.has_expired is not None:
# FIXME cert_expired = self._validate_has_expired()
if self.version: if cert_expired != self.has_expired:
expected_version = x509.Version(int(self.version) - 1) messages.append(
if expected_version != self.cert.version: 'Certificate expiration check failed (certificate expiration is %s, expected %s)' %
self.message.append( (cert_expired, self.has_expired)
'Invalid certificate version number (got %s, expected %s)' % (self.cert.version, self.version)
) )
def _validate_keyUsage(): if self.version is not None:
if self.keyUsage: cert_version = self._validate_version()
try: if cert_version != self.version:
current_keyusage = self.cert.extensions.get_extension_for_class(x509.KeyUsage).value messages.append(
expected_keyusage = x509.KeyUsage(**crypto_utils.cryptography_parse_key_usage_params(self.keyUsage)) 'Invalid certificate version number (got %s, expected %s)' %
test_keyusage = dict( (cert_version, self.version)
digital_signature=current_keyusage.digital_signature,
content_commitment=current_keyusage.content_commitment,
key_encipherment=current_keyusage.key_encipherment,
data_encipherment=current_keyusage.data_encipherment,
key_agreement=current_keyusage.key_agreement,
key_cert_sign=current_keyusage.key_cert_sign,
crl_sign=current_keyusage.crl_sign,
) )
if test_keyusage['key_agreement']:
test_keyusage.update(dict(
encipher_only=current_keyusage.encipher_only,
decipher_only=current_keyusage.decipher_only
))
else:
test_keyusage.update(dict(
encipher_only=False,
decipher_only=False
))
key_usages = crypto_utils.cryptography_parse_key_usage_params(self.keyUsage) if self.key_usage is not None:
if (not self.keyUsage_strict and not all(key_usages[x] == test_keyusage[x] for x in key_usages)) or \ failure = self._validate_key_usage()
(self.keyUsage_strict and current_keyusage != expected_keyusage): if failure == NO_EXTENSION:
self.message.append( messages.append('Found no keyUsage extension')
elif failure:
dummy, cert_key_usage = failure
messages.append(
'Invalid keyUsage components (got %s, expected all of %s to be present)' % 'Invalid keyUsage components (got %s, expected all of %s to be present)' %
([x for x in test_keyusage if x is True], [x for x in self.keyUsage if x is True]) (cert_key_usage, self.key_usage)
) )
except cryptography.x509.ExtensionNotFound: if self.extended_key_usage is not None:
self.message.append('Found no keyUsage extension') failure = self._validate_extended_key_usage()
if failure == NO_EXTENSION:
def _validate_extendedKeyUsage(): messages.append('Found no extendedKeyUsage extension')
if self.extendedKeyUsage: elif failure:
try: dummy, ext_cert_key_usage = failure
current_ext_keyusage = self.cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value messages.append(
usages = [crypto_utils.cryptography_get_ext_keyusage(usage) for usage in self.extendedKeyUsage] 'Invalid extendedKeyUsage component (got %s, expected all of %s to be present)' % (ext_cert_key_usage, self.extended_key_usage)
expected_ext_keyusage = x509.ExtendedKeyUsage(usages)
if (not self.extendedKeyUsage_strict and not all(x in current_ext_keyusage for x in expected_ext_keyusage)) or \
(self.extendedKeyUsage_strict and not current_ext_keyusage == expected_ext_keyusage):
self.message.append(
'Invalid extendedKeyUsage component (got %s, expected all of %s to be present)' % ([xku.value for xku in current_ext_keyusage],
[exku.value for exku in expected_ext_keyusage])
) )
except cryptography.x509.ExtensionNotFound: if self.subject_alt_name is not None:
self.message.append('Found no extendedKeyUsage extension') failure = self._validate_subject_alt_name()
if failure == NO_EXTENSION:
def _validate_subjectAltName(): messages.append('Found no subjectAltName extension')
if self.subjectAltName: elif failure:
try: dummy, cert_san = failure
current_san = self.cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value messages.append(
expected_san = [crypto_utils.cryptography_get_name(san) for san in self.subjectAltName]
if (not self.subjectAltName_strict and not all(x in current_san for x in expected_san)) or \
(self.subjectAltName_strict and not set(current_san) == set(expected_san)):
self.message.append(
'Invalid subjectAltName component (got %s, expected all of %s to be present)' % 'Invalid subjectAltName component (got %s, expected all of %s to be present)' %
(current_san, self.subjectAltName) (cert_san, self.subject_alt_name)
) )
except cryptography.x509.ExtensionNotFound:
self.message.append('Found no subjectAltName extension') if self.not_before is not None:
cert_not_valid_before = self._validate_not_before()
def _validate_notBefore(): if cert_not_valid_before != self.get_relative_time_option(self.not_before, 'not_before'):
if self.notBefore[0]: messages.append(
# try: 'Invalid not_before component (got %s, expected %s to be present)' %
if self.cert.not_valid_before != self.get_relative_time_option(self.notBefore[0], 'not_before'): (cert_not_valid_before, self.not_before)
self.message.append(
'Invalid notBefore component (got %s, expected %s to be present)' % (self.cert.not_valid_before, self.notBefore)
) )
# except AttributeError:
# self.message.append(str(self.notBefore)) if self.not_after is not None:
cert_not_valid_after = self._validate_not_after()
def _validate_notAfter(): if cert_not_valid_after != self.get_relative_time_option(self.not_after, 'not_after'):
if self.notAfter[0]: messages.append(
if self.cert.not_valid_after != self.get_relative_time_option(self.notAfter[0], 'not_after'): 'Invalid not_after component (got %s, expected %s to be present)' %
self.message.append( (cert_not_valid_after, self.not_after)
'Invalid notAfter component (got %s, expected %s to be present)' % (self.cert.not_valid_after, self.notAfter)
) )
def _validate_valid_at(): if self.valid_at is not None:
if self.valid_at[0]: not_before, valid_at, not_after = self._validate_valid_at()
rt = self.get_relative_time_option(self.valid_at[0], 'valid_at') if not (not_before <= valid_at <= not_after):
if not (self.cert.not_valid_before <= rt <= self.cert.not_valid_after): messages.append(
self.message.append( 'Certificate is not valid for the specified date (%s) - not_before: %s - not_after: %s' %
'Certificate is not valid for the specified date (%s) - notBefore: %s - notAfter: %s' % (self.valid_at, (self.valid_at, not_before, not_after)
self.cert.not_valid_before,
self.cert.not_valid_after)
) )
def _validate_invalid_at(): if self.invalid_at is not None:
if self.invalid_at[0]: not_before, invalid_at, not_after = self._validate_invalid_at()
if (self.get_relative_time_option(self.invalid_at[0], 'invalid_at') <= self.cert.not_valid_before) \ if (invalid_at <= not_before) or (invalid_at >= not_after):
or (self.get_relative_time_option(self.invalid_at, 'invalid_at') >= self.cert.not_valid_after): messages.append(
self.message.append( 'Certificate is not invalid for the specified date (%s) - not_before: %s - not_after: %s' %
'Certificate is not invalid for the specified date (%s) - notBefore: %s - notAfter: %s' % (self.invalid_at, (self.invalid_at, not_before, not_after)
self.cert.not_valid_before,
self.cert.not_valid_after)
) )
def _validate_valid_in(): if self.valid_in is not None:
if self.valid_in[0]: not_before, valid_in, not_after = self._validate_valid_in()
if not self.valid_in[0].startswith("+") and not self.valid_in[0].startswith("-"): if not not_before <= valid_in <= not_after:
try: messages.append(
int(self.valid_in[0]) 'Certificate is not valid in %s from now (that would be %s) - not_before: %s - not_after: %s' %
except ValueError: (self.valid_in, valid_in, not_before, not_after)
raise CertificateError( )
'The supplied value for "valid_in" (%s) is not an integer or a valid timespec' % self.valid_in) return messages
self.valid_in = "+" + self.valid_in + "s"
valid_in_date = self.get_relative_time_option(self.valid_in[0], "valid_in")
if not self.cert.not_valid_before <= valid_in_date <= self.cert.not_valid_after:
self.message.append(
'Certificate is not valid in %s from now (that would be %s) - notBefore: %s - notAfter: %s'
% (self.valid_in, valid_in_date,
self.cert.not_valid_before,
self.cert.not_valid_after))
for validation in ['signature_algorithms', 'subject', 'issuer',
'has_expired', 'version', 'keyUsage',
'extendedKeyUsage', 'subjectAltName',
'notBefore', 'notAfter', 'valid_at', 'valid_in', 'invalid_at']:
f_name = locals()['_validate_%s' % validation]
f_name()
def generate(self, module): def generate(self, module):
"""Don't generate anything - only assert""" """Don't generate anything - only assert"""
messages = self.assertonly(module)
self.assertonly() if messages:
module.fail_json(msg=' | '.join(messages))
try:
if self.privatekey_path and \
not super(AssertOnlyCertificateCryptography, self).check(module, perms_required=False):
self.message.append(
'Certificate %s and private key %s do not match' % (self.path, self.privatekey_path)
)
except CertificateError as e:
self.message.append(
'Error while reading private key %s: %s' % (self.privatekey_path, str(e))
)
if len(self.message):
module.fail_json(msg=' | '.join(self.message))
def check(self, module, perms_required=False): def check(self, module, perms_required=False):
"""Ensure the resource is in its desired state.""" """Ensure the resource is in its desired state."""
messages = self.assertonly(module)
parent_check = super(AssertOnlyCertificateCryptography, self).check(module, perms_required) return len(messages) == 0
self.assertonly()
assertonly_check = not len(self.message)
self.message = []
return parent_check and assertonly_check
def dump(self, check_mode=False): def dump(self, check_mode=False):
result = { result = {
@ -1351,45 +1400,150 @@ class AssertOnlyCertificateCryptography(Certificate):
return result return result
class AssertOnlyCertificate(Certificate): class AssertOnlyCertificateCryptography(AssertOnlyCertificateBase):
"""Validate the supplied cert, using the cryptography backend"""
def __init__(self, module):
super(AssertOnlyCertificateCryptography, self).__init__(module, 'cryptography')
def _validate_privatekey(self):
return self.cert.public_key().public_numbers() == self.privatekey.public_key().public_numbers()
def _validate_csr_signature(self):
if not self.csr.is_signature_valid:
return False
if self.csr.public_key().public_numbers() != self.cert.public_key().public_numbers():
return False
def _validate_csr_subject(self):
if self.csr.subject != self.cert.subject:
return False
def _validate_csr_extensions(self):
cert_exts = self.cert.extensions
csr_exts = self.csr.extensions
if len(cert_exts) != len(csr_exts):
return False
for cert_ext in cert_exts:
try:
csr_ext = csr_exts.get_extension_for_oid(cert_ext.oid)
if cert_ext != csr_ext:
return False
except cryptography.x509.ExtensionNotFound as dummy:
return False
return True
def _validate_signature_algorithms(self):
if self.cert.signature_algorithm_oid._name not in self.signature_algorithms:
return self.cert.signature_algorithm_oid._name
def _validate_subject(self):
expected_subject = Name([NameAttribute(oid=crypto_utils.cryptography_get_name_oid(sub[0]), value=to_text(sub[1]))
for sub in self.subject])
cert_subject = self.cert.subject
if not compare_sets(expected_subject, cert_subject, self.subject_strict):
return expected_subject, cert_subject
def _validate_issuer(self):
expected_issuer = Name([NameAttribute(oid=crypto_utils.cryptography_get_name_oid(iss[0]), value=to_text(iss[1]))
for iss in self.issuer])
cert_issuer = self.cert.issuer
if not compare_sets(expected_issuer, cert_issuer, self.issuer_strict):
return self.issuer, cert_issuer
def _validate_has_expired(self):
cert_not_after = self.cert.not_valid_after
cert_expired = cert_not_after < datetime.datetime.utcnow()
return cert_expired
def _validate_version(self):
if self.cert.version == x509.Version.v1:
return 1
if self.cert.version == x509.Version.v3:
return 3
return "unknown"
def _validate_key_usage(self):
try:
current_key_usage = self.cert.extensions.get_extension_for_class(x509.KeyUsage).value
test_key_usage = dict(
digital_signature=current_key_usage.digital_signature,
content_commitment=current_key_usage.content_commitment,
key_encipherment=current_key_usage.key_encipherment,
data_encipherment=current_key_usage.data_encipherment,
key_agreement=current_key_usage.key_agreement,
key_cert_sign=current_key_usage.key_cert_sign,
crl_sign=current_key_usage.crl_sign,
encipher_only=False,
decipher_only=False
)
if test_key_usage['key_agreement']:
test_key_usage.update(dict(
encipher_only=current_key_usage.encipher_only,
decipher_only=current_key_usage.decipher_only
))
key_usages = crypto_utils.cryptography_parse_key_usage_params(self.key_usage)
if not compare_dicts(key_usages, test_key_usage, self.key_usage_strict):
return self.key_usage, [x for x in test_key_usage if x is True]
except cryptography.x509.ExtensionNotFound:
# This is only bad if the user specified a non-empty list
if self.key_usage:
return NO_EXTENSION
def _validate_extended_key_usage(self):
try:
current_ext_keyusage = self.cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value
usages = [crypto_utils.cryptography_get_ext_keyusage(usage) for usage in self.extended_key_usage]
expected_ext_keyusage = x509.ExtendedKeyUsage(usages)
if not compare_sets(expected_ext_keyusage, current_ext_keyusage, self.extended_key_usage_strict):
return [eku.value for eku in expected_ext_keyusage], [eku.value for eku in current_ext_keyusage]
except cryptography.x509.ExtensionNotFound:
# This is only bad if the user specified a non-empty list
if self.extended_key_usage:
return NO_EXTENSION
def _validate_subject_alt_name(self):
try:
current_san = self.cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
expected_san = [crypto_utils.cryptography_get_name(san) for san in self.subject_alt_name]
if not compare_sets(expected_san, current_san, self.subject_alt_name_strict):
return self.subject_alt_name, current_san
except cryptography.x509.ExtensionNotFound:
# This is only bad if the user specified a non-empty list
if self.subject_alt_name:
return NO_EXTENSION
def _validate_not_before(self):
return self.cert.not_valid_before
def _validate_not_after(self):
return self.cert.not_valid_after
def _validate_valid_at(self):
rt = self.get_relative_time_option(self.valid_at, 'valid_at')
return self.cert.not_valid_before, rt, self.cert.not_valid_after
def _validate_invalid_at(self):
rt = self.get_relative_time_option(self.valid_at, 'valid_at')
return self.cert.not_valid_before, rt, self.cert.not_valid_after
def _validate_valid_in(self):
valid_in_date = self.get_relative_time_option(self.valid_in, "valid_in")
return self.cert.not_valid_before, valid_in_date, self.cert.not_valid_after
class AssertOnlyCertificate(AssertOnlyCertificateBase):
"""validate the supplied certificate.""" """validate the supplied certificate."""
def __init__(self, module): def __init__(self, module):
super(AssertOnlyCertificate, self).__init__(module, 'pyopenssl') super(AssertOnlyCertificate, self).__init__(module, 'pyopenssl')
self.signature_algorithms = module.params['signature_algorithms']
if module.params['subject']:
self.subject = crypto_utils.parse_name_field(module.params['subject'])
else:
self.subject = []
self.subject_strict = module.params['subject_strict']
if module.params['issuer']:
self.issuer = crypto_utils.parse_name_field(module.params['issuer'])
else:
self.issuer = []
self.issuer_strict = module.params['issuer_strict']
self.has_expired = module.params['has_expired']
self.version = module.params['version']
self.keyUsage = module.params['key_usage']
self.keyUsage_strict = module.params['key_usage_strict']
self.extendedKeyUsage = module.params['extended_key_usage']
self.extendedKeyUsage_strict = module.params['extended_key_usage_strict']
self.subjectAltName = module.params['subject_alt_name']
self.subjectAltName_strict = module.params['subject_alt_name_strict']
self.notBefore = module.params['not_before']
self.notAfter = module.params['not_after']
self.valid_at = module.params['valid_at']
self.invalid_at = module.params['invalid_at']
self.valid_in = module.params['valid_in']
self.message = []
self._sanitize_inputs()
def _sanitize_inputs(self):
"""Ensure inputs are properly sanitized before comparison."""
for param in ['signature_algorithms', 'keyUsage', 'extendedKeyUsage',
'subjectAltName', 'subject', 'issuer', 'notBefore',
'notAfter', 'valid_at', 'invalid_at']:
# Ensure inputs are properly sanitized before comparison.
for param in ['signature_algorithms', 'key_usage', 'extended_key_usage',
'subject_alt_name', 'subject', 'issuer', 'not_before',
'not_after', 'valid_at', 'invalid_at']:
attr = getattr(self, param) attr = getattr(self, param)
if isinstance(attr, list) and attr: if isinstance(attr, list) and attr:
if isinstance(attr[0], str): if isinstance(attr[0], str):
@ -1403,40 +1557,57 @@ class AssertOnlyCertificate(Certificate):
elif isinstance(attr, str): elif isinstance(attr, str):
setattr(self, param, to_bytes(attr)) setattr(self, param, to_bytes(attr))
def assertonly(self): def _validate_privatekey(self):
ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_2_METHOD)
ctx.use_privatekey(self.privatekey)
ctx.use_certificate(self.cert)
try:
ctx.check_privatekey()
return True
except OpenSSL.SSL.Error:
return False
self.cert = crypto_utils.load_certificate(self.path) def _validate_csr_signature(self):
try:
self.csr.verify(self.cert.get_pubkey())
except OpenSSL.crypto.Error:
return False
def _validate_csr_subject(self):
if self.csr.get_subject() != self.cert.get_subject():
return False
def _validate_csr_extensions(self):
csr_extensions = self.csr.get_extensions()
cert_extension_count = self.cert.get_extension_count()
if len(csr_extensions) != cert_extension_count:
return False
for extension_number in range(0, cert_extension_count):
cert_extension = self.cert.get_extension(extension_number)
csr_extension = filter(lambda extension: extension.get_short_name() == cert_extension.get_short_name(), csr_extensions)
if cert_extension.get_data() != list(csr_extension)[0].get_data():
return False
return True
def _validate_signature_algorithms(): def _validate_signature_algorithms(self):
if self.signature_algorithms:
if self.cert.get_signature_algorithm() not in self.signature_algorithms: if self.cert.get_signature_algorithm() not in self.signature_algorithms:
self.message.append( return self.cert.get_signature_algorithm()
'Invalid signature algorithm (got %s, expected one of %s)' % (self.cert.get_signature_algorithm(), self.signature_algorithms)
)
def _validate_subject(): def _validate_subject(self):
if self.subject:
expected_subject = [(OpenSSL._util.lib.OBJ_txt2nid(sub[0]), sub[1]) for sub in self.subject] expected_subject = [(OpenSSL._util.lib.OBJ_txt2nid(sub[0]), sub[1]) for sub in self.subject]
cert_subject = self.cert.get_subject().get_components() cert_subject = self.cert.get_subject().get_components()
current_subject = [(OpenSSL._util.lib.OBJ_txt2nid(sub[0]), sub[1]) for sub in cert_subject] current_subject = [(OpenSSL._util.lib.OBJ_txt2nid(sub[0]), sub[1]) for sub in cert_subject]
if (not self.subject_strict and not all(x in current_subject for x in expected_subject)) or \ if not compare_sets(expected_subject, current_subject, self.subject_strict):
(self.subject_strict and not set(expected_subject) == set(current_subject)): return expected_subject, current_subject
self.message.append(
'Invalid subject component (got %s, expected all of %s to be present)' % (cert_subject, self.subject)
)
def _validate_issuer(): def _validate_issuer(self):
if self.issuer:
expected_issuer = [(OpenSSL._util.lib.OBJ_txt2nid(iss[0]), iss[1]) for iss in self.issuer] expected_issuer = [(OpenSSL._util.lib.OBJ_txt2nid(iss[0]), iss[1]) for iss in self.issuer]
cert_issuer = self.cert.get_issuer().get_components() cert_issuer = self.cert.get_issuer().get_components()
current_issuer = [(OpenSSL._util.lib.OBJ_txt2nid(iss[0]), iss[1]) for iss in cert_issuer] current_issuer = [(OpenSSL._util.lib.OBJ_txt2nid(iss[0]), iss[1]) for iss in cert_issuer]
if (not self.issuer_strict and not all(x in current_issuer for x in expected_issuer)) or \ if not compare_sets(expected_issuer, current_issuer, self.issuer_strict):
(self.issuer_strict and not set(expected_issuer) == set(current_issuer)): return self.issuer, cert_issuer
self.message.append(
'Invalid issuer component (got %s, expected all of %s to be present)' % (cert_issuer, self.issuer)
)
def _validate_has_expired(): def _validate_has_expired(self):
# The following 3 lines are the same as the current PyOpenSSL code for cert.has_expired(). # The following 3 lines are the same as the current PyOpenSSL code for cert.has_expired().
# Older version of PyOpenSSL have a buggy implementation, # Older version of PyOpenSSL have a buggy implementation,
# to avoid issues with those we added the code from a more recent release here. # to avoid issues with those we added the code from a more recent release here.
@ -1444,60 +1615,46 @@ class AssertOnlyCertificate(Certificate):
time_string = to_native(self.cert.get_notAfter()) time_string = to_native(self.cert.get_notAfter())
not_after = datetime.datetime.strptime(time_string, "%Y%m%d%H%M%SZ") not_after = datetime.datetime.strptime(time_string, "%Y%m%d%H%M%SZ")
cert_expired = not_after < datetime.datetime.utcnow() cert_expired = not_after < datetime.datetime.utcnow()
return cert_expired
if self.has_expired != cert_expired: def _validate_version(self):
self.message.append(
'Certificate expiration check failed (certificate expiration is %s, expected %s)' % (cert_expired, self.has_expired)
)
def _validate_version():
if self.version:
# Version numbers in certs are off by one: # Version numbers in certs are off by one:
# v1: 0, v2: 1, v3: 2 ... # v1: 0, v2: 1, v3: 2 ...
if self.version != self.cert.get_version() + 1: return self.cert.get_version() + 1
self.message.append(
'Invalid certificate version number (got %s, expected %s)' % (self.cert.get_version() + 1, self.version)
)
def _validate_keyUsage(): def _validate_key_usage(self):
if self.keyUsage:
found = False found = False
for extension_idx in range(0, self.cert.get_extension_count()): for extension_idx in range(0, self.cert.get_extension_count()):
extension = self.cert.get_extension(extension_idx) extension = self.cert.get_extension(extension_idx)
if extension.get_short_name() == b'keyUsage': if extension.get_short_name() == b'keyUsage':
found = True found = True
keyUsage = [OpenSSL._util.lib.OBJ_txt2nid(keyUsage) for keyUsage in self.keyUsage] key_usage = [OpenSSL._util.lib.OBJ_txt2nid(key_usage) for key_usage in self.key_usage]
current_ku = [OpenSSL._util.lib.OBJ_txt2nid(usage.strip()) for usage in current_ku = [OpenSSL._util.lib.OBJ_txt2nid(usage.strip()) for usage in
to_bytes(extension, errors='surrogate_or_strict').split(b',')] to_bytes(extension, errors='surrogate_or_strict').split(b',')]
if (not self.keyUsage_strict and not all(x in current_ku for x in keyUsage)) or \ if not compare_sets(key_usage, current_ku, self.key_usage_strict):
(self.keyUsage_strict and not set(keyUsage) == set(current_ku)): return self.key_usage, str(extension).split(', ')
self.message.append(
'Invalid keyUsage component (got %s, expected all of %s to be present)' % (str(extension).split(', '), self.keyUsage)
)
if not found: if not found:
self.message.append('Found no keyUsage extension') # This is only bad if the user specified a non-empty list
if self.key_usage:
return NO_EXTENSION
def _validate_extendedKeyUsage(): def _validate_extended_key_usage(self):
if self.extendedKeyUsage:
found = False found = False
for extension_idx in range(0, self.cert.get_extension_count()): for extension_idx in range(0, self.cert.get_extension_count()):
extension = self.cert.get_extension(extension_idx) extension = self.cert.get_extension(extension_idx)
if extension.get_short_name() == b'extendedKeyUsage': if extension.get_short_name() == b'extendedKeyUsage':
found = True found = True
extKeyUsage = [OpenSSL._util.lib.OBJ_txt2nid(keyUsage) for keyUsage in self.extendedKeyUsage] extKeyUsage = [OpenSSL._util.lib.OBJ_txt2nid(keyUsage) for keyUsage in self.extended_key_usage]
current_xku = [OpenSSL._util.lib.OBJ_txt2nid(usage.strip()) for usage in current_xku = [OpenSSL._util.lib.OBJ_txt2nid(usage.strip()) for usage in
to_bytes(extension, errors='surrogate_or_strict').split(b',')] to_bytes(extension, errors='surrogate_or_strict').split(b',')]
if (not self.extendedKeyUsage_strict and not all(x in current_xku for x in extKeyUsage)) or \ if not compare_sets(extKeyUsage, current_xku, self.extended_key_usage_strict):
(self.extendedKeyUsage_strict and not set(extKeyUsage) == set(current_xku)): return self.extended_key_usage, str(extension).split(', ')
self.message.append(
'Invalid extendedKeyUsage component (got %s, expected all of %s to be present)' % (str(extension).split(', '),
self.extendedKeyUsage)
)
if not found: if not found:
self.message.append('Found no extendedKeyUsage extension') # This is only bad if the user specified a non-empty list
if self.extended_key_usage:
return NO_EXTENSION
def _validate_subjectAltName(): def _validate_subject_alt_name(self):
if self.subjectAltName:
found = False found = False
for extension_idx in range(0, self.cert.get_extension_count()): for extension_idx in range(0, self.cert.get_extension_count()):
extension = self.cert.get_extension(extension_idx) extension = self.cert.get_extension(extension_idx)
@ -1505,111 +1662,29 @@ class AssertOnlyCertificate(Certificate):
found = True found = True
l_altnames = [altname.replace(b'IP Address', b'IP') for altname in l_altnames = [altname.replace(b'IP Address', b'IP') for altname in
to_bytes(extension, errors='surrogate_or_strict').split(b', ')] to_bytes(extension, errors='surrogate_or_strict').split(b', ')]
if (not self.subjectAltName_strict and not all(x in l_altnames for x in self.subjectAltName)) or \ if not compare_sets(self.subject_alt_name, l_altnames, self.subject_alt_name_strict):
(self.subjectAltName_strict and not set(self.subjectAltName) == set(l_altnames)): return self.subject_alt_name, l_altnames
self.message.append(
'Invalid subjectAltName component (got %s, expected all of %s to be present)' % (l_altnames, self.subjectAltName)
)
if not found: if not found:
self.message.append('Found no subjectAltName extension') # This is only bad if the user specified a non-empty list
if self.subject_alt_name:
return NO_EXTENSION
def _validate_notBefore(): def _validate_not_before(self):
if self.notBefore: return self.cert.get_notBefore()
if self.cert.get_notBefore() != self.notBefore:
self.message.append(
'Invalid notBefore component (got %s, expected %s to be present)' % (self.cert.get_notBefore(), self.notBefore)
)
def _validate_notAfter(): def _validate_not_after(self):
if self.notAfter: return self.cert.get_notAfter()
if self.cert.get_notAfter() != self.notAfter:
self.message.append(
'Invalid notAfter component (got %s, expected %s to be present)' % (self.cert.get_notAfter(), self.notAfter)
)
def _validate_valid_at(): def _validate_valid_at(self):
if self.valid_at: return self.cert.get_notBefore(), self.valid_at, self.cert.get_notAfter()
if not (self.cert.get_notBefore() <= self.valid_at <= self.cert.get_notAfter()):
self.message.append(
'Certificate is not valid for the specified date (%s) - notBefore: %s - notAfter: %s' % (self.valid_at,
self.cert.get_notBefore(),
self.cert.get_notAfter())
)
def _validate_invalid_at(): def _validate_invalid_at(self):
if self.invalid_at: return self.cert.get_notBefore(), self.valid_at, self.cert.get_notAfter()
if not (self.invalid_at <= self.cert.get_notBefore() or self.invalid_at >= self.cert.get_notAfter()):
self.message.append(
'Certificate is not invalid for the specified date (%s) - notBefore: %s - notAfter: %s' % (self.invalid_at,
self.cert.get_notBefore(),
self.cert.get_notAfter())
)
def _validate_valid_in(): def _validate_valid_in(self):
if self.valid_in:
if not self.valid_in.startswith("+") and not self.valid_in.startswith("-"):
try:
int(self.valid_in)
except ValueError:
raise CertificateError(
'The supplied value for "valid_in" (%s) is not an integer or a valid timespec' % self.valid_in)
self.valid_in = "+" + self.valid_in + "s"
valid_in_asn1 = self.get_relative_time_option(self.valid_in, "valid_in") valid_in_asn1 = self.get_relative_time_option(self.valid_in, "valid_in")
valid_in_date = to_bytes(valid_in_asn1, errors='surrogate_or_strict') valid_in_date = to_bytes(valid_in_asn1, errors='surrogate_or_strict')
if not (self.cert.get_notBefore() <= valid_in_date <= self.cert.get_notAfter()): return self.cert.get_notBefore(), valid_in_date, self.cert.get_notAfter()
self.message.append(
'Certificate is not valid in %s from now (that would be %s) - notBefore: %s - notAfter: %s'
% (self.valid_in, valid_in_date,
self.cert.get_notBefore(),
self.cert.get_notAfter()))
for validation in ['signature_algorithms', 'subject', 'issuer',
'has_expired', 'version', 'keyUsage',
'extendedKeyUsage', 'subjectAltName',
'notBefore', 'notAfter', 'valid_at',
'invalid_at', 'valid_in']:
f_name = locals()['_validate_%s' % validation]
f_name()
def generate(self, module):
"""Don't generate anything - assertonly"""
self.assertonly()
try:
if self.privatekey_path and \
not super(AssertOnlyCertificate, self).check(module, perms_required=False):
self.message.append(
'Certificate %s and private key %s do not match' % (self.path, self.privatekey_path)
)
except CertificateError as e:
self.message.append(
'Error while reading private key %s: %s' % (self.privatekey_path, str(e))
)
if len(self.message):
module.fail_json(msg=' | '.join(self.message))
def check(self, module, perms_required=True):
"""Ensure the resource is in its desired state."""
parent_check = super(AssertOnlyCertificate, self).check(module, perms_required)
self.assertonly()
assertonly_check = not len(self.message)
self.message = []
return parent_check and assertonly_check
def dump(self, check_mode=False):
result = {
'changed': self.changed,
'filename': self.path,
'privatekey': self.privatekey_path,
'csr': self.csr_path,
}
return result
class AcmeCertificate(Certificate): class AcmeCertificate(Certificate):

Loading…
Cancel
Save