|
|
|
|
@ -106,21 +106,19 @@ class SshAgentFailure(RuntimeError):
|
|
|
|
|
# NOTE: Classes below somewhat represent "Data Type Representations Used in the SSH Protocols"
|
|
|
|
|
# as specified by RFC4251
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@t.runtime_checkable
|
|
|
|
|
class SupportsToBlob(t.Protocol):
|
|
|
|
|
def to_blob(self) -> bytes:
|
|
|
|
|
...
|
|
|
|
|
def to_blob(self) -> bytes: ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@t.runtime_checkable
|
|
|
|
|
class SupportsFromBlob(t.Protocol):
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
|
|
|
|
|
...
|
|
|
|
|
def from_blob(cls, blob: memoryview | bytes) -> t.Self: ...
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def consume_from_blob(cls, blob: memoryview | bytes) -> tuple[t.Self, memoryview | bytes]:
|
|
|
|
|
...
|
|
|
|
|
def consume_from_blob(cls, blob: memoryview | bytes) -> tuple[t.Self, memoryview | bytes]: ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _split_blob(blob: memoryview | bytes, length: int) -> tuple[memoryview | bytes, memoryview | bytes]:
|
|
|
|
|
@ -304,10 +302,12 @@ class PrivateKeyMsg(Msg):
|
|
|
|
|
return EcdsaPrivateKeyMsg(
|
|
|
|
|
getattr(KeyAlgo, f'ECDSA{key_size}'),
|
|
|
|
|
unicode_string(f'nistp{key_size}'),
|
|
|
|
|
binary_string(private_key.public_key().public_bytes(
|
|
|
|
|
binary_string(
|
|
|
|
|
private_key.public_key().public_bytes(
|
|
|
|
|
encoding=serialization.Encoding.X962,
|
|
|
|
|
format=serialization.PublicFormat.UncompressedPoint
|
|
|
|
|
)),
|
|
|
|
|
format=serialization.PublicFormat.UncompressedPoint,
|
|
|
|
|
)
|
|
|
|
|
),
|
|
|
|
|
mpint(ecdsa_pn.private_value),
|
|
|
|
|
)
|
|
|
|
|
case Ed25519PrivateKey():
|
|
|
|
|
@ -318,7 +318,7 @@ class PrivateKeyMsg(Msg):
|
|
|
|
|
private_bytes = private_key.private_bytes(
|
|
|
|
|
encoding=serialization.Encoding.Raw,
|
|
|
|
|
format=serialization.PrivateFormat.Raw,
|
|
|
|
|
encryption_algorithm=serialization.NoEncryption()
|
|
|
|
|
encryption_algorithm=serialization.NoEncryption(),
|
|
|
|
|
)
|
|
|
|
|
return Ed25519PrivateKeyMsg(
|
|
|
|
|
KeyAlgo.ED25519,
|
|
|
|
|
@ -376,14 +376,14 @@ class Ed25519PrivateKeyMsg(PrivateKeyMsg):
|
|
|
|
|
@dataclasses.dataclass
|
|
|
|
|
class PublicKeyMsg(Msg):
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_dataclass(
|
|
|
|
|
type: KeyAlgo
|
|
|
|
|
) -> type[t.Union[
|
|
|
|
|
def get_dataclass(type: KeyAlgo) -> type[
|
|
|
|
|
t.Union[
|
|
|
|
|
RSAPublicKeyMsg,
|
|
|
|
|
EcdsaPublicKeyMsg,
|
|
|
|
|
Ed25519PublicKeyMsg,
|
|
|
|
|
DSAPublicKeyMsg
|
|
|
|
|
]]:
|
|
|
|
|
DSAPublicKeyMsg,
|
|
|
|
|
]
|
|
|
|
|
]:
|
|
|
|
|
match type:
|
|
|
|
|
case KeyAlgo.RSA:
|
|
|
|
|
return RSAPublicKeyMsg
|
|
|
|
|
@ -401,29 +401,14 @@ class PublicKeyMsg(Msg):
|
|
|
|
|
type: KeyAlgo = self.type
|
|
|
|
|
match type:
|
|
|
|
|
case KeyAlgo.RSA:
|
|
|
|
|
return RSAPublicNumbers(
|
|
|
|
|
self.e,
|
|
|
|
|
self.n
|
|
|
|
|
).public_key()
|
|
|
|
|
return RSAPublicNumbers(self.e, self.n).public_key()
|
|
|
|
|
case KeyAlgo.ECDSA256 | KeyAlgo.ECDSA384 | KeyAlgo.ECDSA521:
|
|
|
|
|
curve = _ECDSA_KEY_TYPE[KeyAlgo(type)]
|
|
|
|
|
return EllipticCurvePublicKey.from_encoded_point(
|
|
|
|
|
curve(),
|
|
|
|
|
self.Q
|
|
|
|
|
)
|
|
|
|
|
return EllipticCurvePublicKey.from_encoded_point(curve(), self.Q)
|
|
|
|
|
case KeyAlgo.ED25519:
|
|
|
|
|
return Ed25519PublicKey.from_public_bytes(
|
|
|
|
|
self.enc_a
|
|
|
|
|
)
|
|
|
|
|
return Ed25519PublicKey.from_public_bytes(self.enc_a)
|
|
|
|
|
case KeyAlgo.DSA:
|
|
|
|
|
return DSAPublicNumbers(
|
|
|
|
|
self.y,
|
|
|
|
|
DSAParameterNumbers(
|
|
|
|
|
self.p,
|
|
|
|
|
self.q,
|
|
|
|
|
self.g
|
|
|
|
|
)
|
|
|
|
|
).public_key()
|
|
|
|
|
return DSAPublicNumbers(self.y, DSAParameterNumbers(self.p, self.q, self.g)).public_key()
|
|
|
|
|
case _:
|
|
|
|
|
raise NotImplementedError(type)
|
|
|
|
|
|
|
|
|
|
@ -437,32 +422,32 @@ class PublicKeyMsg(Msg):
|
|
|
|
|
mpint(dsa_pn.parameter_numbers.p),
|
|
|
|
|
mpint(dsa_pn.parameter_numbers.q),
|
|
|
|
|
mpint(dsa_pn.parameter_numbers.g),
|
|
|
|
|
mpint(dsa_pn.y)
|
|
|
|
|
mpint(dsa_pn.y),
|
|
|
|
|
)
|
|
|
|
|
case EllipticCurvePublicKey():
|
|
|
|
|
return EcdsaPublicKeyMsg(
|
|
|
|
|
getattr(KeyAlgo, f'ECDSA{public_key.curve.key_size}'),
|
|
|
|
|
unicode_string(f'nistp{public_key.curve.key_size}'),
|
|
|
|
|
binary_string(public_key.public_bytes(
|
|
|
|
|
binary_string(
|
|
|
|
|
public_key.public_bytes(
|
|
|
|
|
encoding=serialization.Encoding.X962,
|
|
|
|
|
format=serialization.PublicFormat.UncompressedPoint
|
|
|
|
|
))
|
|
|
|
|
format=serialization.PublicFormat.UncompressedPoint,
|
|
|
|
|
)
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
case Ed25519PublicKey():
|
|
|
|
|
return Ed25519PublicKeyMsg(
|
|
|
|
|
KeyAlgo.ED25519,
|
|
|
|
|
binary_string(public_key.public_bytes(
|
|
|
|
|
binary_string(
|
|
|
|
|
public_key.public_bytes(
|
|
|
|
|
encoding=serialization.Encoding.Raw,
|
|
|
|
|
format=serialization.PublicFormat.Raw,
|
|
|
|
|
))
|
|
|
|
|
)
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
case RSAPublicKey():
|
|
|
|
|
rsa_pn: RSAPublicNumbers = public_key.public_numbers()
|
|
|
|
|
return RSAPublicKeyMsg(
|
|
|
|
|
KeyAlgo.RSA,
|
|
|
|
|
mpint(rsa_pn.e),
|
|
|
|
|
mpint(rsa_pn.n)
|
|
|
|
|
)
|
|
|
|
|
return RSAPublicKeyMsg(KeyAlgo.RSA, mpint(rsa_pn.e), mpint(rsa_pn.n))
|
|
|
|
|
case _:
|
|
|
|
|
raise NotImplementedError(public_key)
|
|
|
|
|
|
|
|
|
|
@ -473,10 +458,7 @@ class PublicKeyMsg(Msg):
|
|
|
|
|
msg.comments = unicode_string('')
|
|
|
|
|
k = msg.to_blob()
|
|
|
|
|
digest.update(k)
|
|
|
|
|
return binascii.b2a_base64(
|
|
|
|
|
digest.digest(),
|
|
|
|
|
newline=False
|
|
|
|
|
).rstrip(b'=').decode('utf-8')
|
|
|
|
|
return binascii.b2a_base64(digest.digest(), newline=False).rstrip(b'=').decode('utf-8')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass(order=True, slots=True)
|
|
|
|
|
@ -519,9 +501,7 @@ class KeyList(Msg):
|
|
|
|
|
|
|
|
|
|
def __post_init__(self) -> None:
|
|
|
|
|
if self.nkeys != len(self.keys):
|
|
|
|
|
raise SshAgentFailure(
|
|
|
|
|
"agent: invalid number of keys received for identities list"
|
|
|
|
|
)
|
|
|
|
|
raise SshAgentFailure("agent: invalid number of keys received for identities list")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass(order=True, slots=True)
|
|
|
|
|
@ -535,8 +515,7 @@ class PublicKeyMsgList(Msg):
|
|
|
|
|
return len(self.keys)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
|
|
|
|
|
...
|
|
|
|
|
def from_blob(cls, blob: memoryview | bytes) -> t.Self: ...
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def consume_from_blob(cls, blob: memoryview | bytes) -> tuple[t.Self, memoryview | bytes]:
|
|
|
|
|
@ -546,22 +525,16 @@ class PublicKeyMsgList(Msg):
|
|
|
|
|
key_blob, key_blob_length, comment_blob = cls._consume_field(blob)
|
|
|
|
|
|
|
|
|
|
peek_key_algo, _length, _blob = cls._consume_field(key_blob)
|
|
|
|
|
pub_key_msg_cls = PublicKeyMsg.get_dataclass(
|
|
|
|
|
KeyAlgo(bytes(peek_key_algo).decode('utf-8'))
|
|
|
|
|
)
|
|
|
|
|
pub_key_msg_cls = PublicKeyMsg.get_dataclass(KeyAlgo(bytes(peek_key_algo).decode('utf-8')))
|
|
|
|
|
|
|
|
|
|
_fv, comment_blob_length, blob = cls._consume_field(comment_blob)
|
|
|
|
|
key_plus_comment = (
|
|
|
|
|
prev_blob[4: (4 + key_blob_length) + (4 + comment_blob_length)]
|
|
|
|
|
)
|
|
|
|
|
key_plus_comment = prev_blob[4 : (4 + key_blob_length) + (4 + comment_blob_length)]
|
|
|
|
|
|
|
|
|
|
args.append(pub_key_msg_cls.from_blob(key_plus_comment))
|
|
|
|
|
return cls(args), b""
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _consume_field(
|
|
|
|
|
blob: memoryview | bytes
|
|
|
|
|
) -> tuple[memoryview | bytes, uint32, memoryview | bytes]:
|
|
|
|
|
def _consume_field(blob: memoryview | bytes) -> tuple[memoryview | bytes, uint32, memoryview | bytes]:
|
|
|
|
|
length = uint32.from_blob(blob[:4])
|
|
|
|
|
blob = blob[4:]
|
|
|
|
|
data, rest = _split_blob(blob, length)
|
|
|
|
|
@ -584,7 +557,7 @@ class SshAgentClient:
|
|
|
|
|
self,
|
|
|
|
|
exc_type: type[BaseException] | None,
|
|
|
|
|
exc_value: BaseException | None,
|
|
|
|
|
traceback: types.TracebackType | None
|
|
|
|
|
traceback: types.TracebackType | None,
|
|
|
|
|
) -> None:
|
|
|
|
|
self.close()
|
|
|
|
|
|
|
|
|
|
@ -598,16 +571,11 @@ class SshAgentClient:
|
|
|
|
|
return resp
|
|
|
|
|
|
|
|
|
|
def remove_all(self) -> None:
|
|
|
|
|
self.send(
|
|
|
|
|
ProtocolMsgNumbers.SSH_AGENTC_REMOVE_ALL_IDENTITIES.to_blob()
|
|
|
|
|
)
|
|
|
|
|
self.send(ProtocolMsgNumbers.SSH_AGENTC_REMOVE_ALL_IDENTITIES.to_blob())
|
|
|
|
|
|
|
|
|
|
def remove(self, public_key: CryptoPublicKey) -> None:
|
|
|
|
|
key_blob = PublicKeyMsg.from_public_key(public_key).to_blob()
|
|
|
|
|
self.send(
|
|
|
|
|
ProtocolMsgNumbers.SSH_AGENTC_REMOVE_IDENTITY.to_blob() +
|
|
|
|
|
uint32(len(key_blob)).to_blob() + key_blob
|
|
|
|
|
)
|
|
|
|
|
self.send(ProtocolMsgNumbers.SSH_AGENTC_REMOVE_IDENTITY.to_blob() + uint32(len(key_blob)).to_blob() + key_blob)
|
|
|
|
|
|
|
|
|
|
def add(
|
|
|
|
|
self,
|
|
|
|
|
@ -619,13 +587,9 @@ class SshAgentClient:
|
|
|
|
|
key_msg = PrivateKeyMsg.from_private_key(private_key)
|
|
|
|
|
key_msg.comments = unicode_string(comments or '')
|
|
|
|
|
if lifetime:
|
|
|
|
|
key_msg.constraints += constraints(
|
|
|
|
|
[ProtocolMsgNumbers.SSH_AGENT_CONSTRAIN_LIFETIME]
|
|
|
|
|
).to_blob() + uint32(lifetime).to_blob()
|
|
|
|
|
key_msg.constraints += constraints([ProtocolMsgNumbers.SSH_AGENT_CONSTRAIN_LIFETIME]).to_blob() + uint32(lifetime).to_blob()
|
|
|
|
|
if confirm:
|
|
|
|
|
key_msg.constraints += constraints(
|
|
|
|
|
[ProtocolMsgNumbers.SSH_AGENT_CONSTRAIN_CONFIRM]
|
|
|
|
|
).to_blob()
|
|
|
|
|
key_msg.constraints += constraints([ProtocolMsgNumbers.SSH_AGENT_CONSTRAIN_CONFIRM]).to_blob()
|
|
|
|
|
|
|
|
|
|
if key_msg.constraints:
|
|
|
|
|
msg = ProtocolMsgNumbers.SSH_AGENTC_ADD_ID_CONSTRAINED.to_blob()
|
|
|
|
|
@ -638,9 +602,7 @@ class SshAgentClient:
|
|
|
|
|
req = ProtocolMsgNumbers.SSH_AGENTC_REQUEST_IDENTITIES.to_blob()
|
|
|
|
|
r = memoryview(bytearray(self.send(req)))
|
|
|
|
|
if r[0] != ProtocolMsgNumbers.SSH_AGENT_IDENTITIES_ANSWER:
|
|
|
|
|
raise SshAgentFailure(
|
|
|
|
|
'agent: non-identities answer received for identities list'
|
|
|
|
|
)
|
|
|
|
|
raise SshAgentFailure('agent: non-identities answer received for identities list')
|
|
|
|
|
return KeyList.from_blob(r[1:])
|
|
|
|
|
|
|
|
|
|
def __contains__(self, public_key: CryptoPublicKey) -> bool:
|
|
|
|
|
@ -649,7 +611,7 @@ class SshAgentClient:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@functools.cache
|
|
|
|
|
def _key_data_into_crypto_objects(key_data: bytes, passphrase: bytes | None) -> tuple[CryptoPrivateKey, CryptoPublicKey, str]:
|
|
|
|
|
def key_data_into_crypto_objects(key_data: bytes, passphrase: bytes | None) -> tuple[CryptoPrivateKey, CryptoPublicKey, str]:
|
|
|
|
|
private_key = serialization.ssh.load_ssh_private_key(key_data, passphrase)
|
|
|
|
|
public_key = private_key.public_key()
|
|
|
|
|
fingerprint = PublicKeyMsg.from_public_key(public_key).fingerprint
|