@ -50,21 +50,23 @@ func protocolVersionPrologue(version uint16) []byte {
return strconv . AppendUint ( ret , uint64 ( version ) , 10 )
}
// Client initiates a control client handshake, returning the resulting
// control connection.
// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn
// is assumed to have already sent the client>server handshake
// initiation message.
type HandshakeContinuation func ( context . Context , net . Conn ) ( * Conn , error )
// ClientDeferred initiates a control client handshake, returning the
// initial message to send to the server and a continuation to
// finalize the handshake.
//
// The context deadline, if any, covers the entire handshaking
// process. Any preexisting Conn deadline is removed.
func Client ( ctx context . Context , conn net . Conn , machineKey key . MachinePrivate , controlKey key . MachinePublic ) ( * Conn , error ) {
if deadline , ok := ctx . Deadline ( ) ; ok {
if err := conn . SetDeadline ( deadline ) ; err != nil {
return nil , fmt . Errorf ( "setting conn deadline: %w" , err )
}
defer func ( ) {
conn . SetDeadline ( time . Time { } )
} ( )
}
// ClientDeferred is split in this way for RTT reduction: we run this
// protocol after negotiating a protocol switch from HTTP/HTTPS. If we
// completely serialized the negotiation followed by the handshake,
// we'd pay an extra RTT to transmit the handshake initiation after
// protocol switching. By splitting the handshake into an initial
// message and a continuation, we can embed the handshake initiation
// into the HTTP protocol switching request and avoid a bit of delay.
func ClientDeferred ( machineKey key . MachinePrivate , controlKey key . MachinePublic ) ( initialHandshake [ ] byte , continueHandshake HandshakeContinuation , err error ) {
var s symmetricState
s . Initialize ( )
@ -83,18 +85,53 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c
s . MixHash ( machineEphemeralPub . UntypedBytes ( ) )
cipher , err := s . MixDH ( machineEphemeral , controlKey )
if err != nil {
return nil , fmt . Errorf ( "computing es: %w" , err )
return nil , nil , fmt . Errorf ( "computing es: %w" , err )
}
machineKeyPub := machineKey . Public ( )
s . EncryptAndHash ( cipher , init . MachinePub ( ) , machineKeyPub . UntypedBytes ( ) )
cipher , err = s . MixDH ( machineKey , controlKey )
if err != nil {
return nil , fmt . Errorf ( "computing ss: %w" , err )
return nil , nil , fmt . Errorf ( "computing ss: %w" , err )
}
s . EncryptAndHash ( cipher , init . Tag ( ) , nil ) // empty message payload
if _ , err := conn . Write ( init [ : ] ) ; err != nil {
return nil , fmt . Errorf ( "writing initiation: %w" , err )
cont := func ( ctx context . Context , conn net . Conn ) ( * Conn , error ) {
return continueClientHandshake ( ctx , conn , & s , machineKey , machineEphemeral , controlKey )
}
return init [ : ] , cont , nil
}
// Client wraps ClientDeferred and immediately invokes the returned
// continuation with conn.
//
// This is a helper for when you don't need the fancy
// continuation-style handshake, and just want to synchronously
// upgrade a net.Conn to a secure transport.
func Client ( ctx context . Context , conn net . Conn , machineKey key . MachinePrivate , controlKey key . MachinePublic ) ( * Conn , error ) {
init , cont , err := ClientDeferred ( machineKey , controlKey )
if err != nil {
return nil , err
}
if _ , err := conn . Write ( init ) ; err != nil {
return nil , err
}
return cont ( ctx , conn )
}
func continueClientHandshake ( ctx context . Context , conn net . Conn , s * symmetricState , machineKey , machineEphemeral key . MachinePrivate , controlKey key . MachinePublic ) ( * Conn , error ) {
// No matter what, this function can only run once per s. Ensure
// attempted reuse causes a panic.
defer func ( ) {
s . finished = true
} ( )
if deadline , ok := ctx . Deadline ( ) ; ok {
if err := conn . SetDeadline ( deadline ) ; err != nil {
return nil , fmt . Errorf ( "setting conn deadline: %w" , err )
}
defer func ( ) {
conn . SetDeadline ( time . Time { } )
} ( )
}
// Read in the payload and look for errors/protocol violations from the server.
@ -122,10 +159,10 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c
// <- e, ee, se
controlEphemeralPub := key . MachinePublicFromRaw32 ( mem . B ( resp . EphemeralPub ( ) ) )
s . MixHash ( controlEphemeralPub . UntypedBytes ( ) )
if _ , err = s . MixDH ( machineEphemeral , controlEphemeralPub ) ; err != nil {
if _ , err : = s . MixDH ( machineEphemeral , controlEphemeralPub ) ; err != nil {
return nil , fmt . Errorf ( "computing ee: %w" , err )
}
cipher , err = s . MixDH ( machineKey , controlEphemeralPub )
cipher , err : = s . MixDH ( machineKey , controlEphemeralPub )
if err != nil {
return nil , fmt . Errorf ( "computing se: %w" , err )
}
@ -156,9 +193,13 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c
// Server initiates a control server handshake, returning the resulting
// control connection.
//
// optionalInit can be the client's initial handshake message as
// returned by ClientDeferred, or nil in which case the initial
// message is read from conn.
//
// The context deadline, if any, covers the entire handshaking
// process.
func Server ( ctx context . Context , conn net . Conn , controlKey key . MachinePrivate ) ( * Conn , error ) {
func Server ( ctx context . Context , conn net . Conn , controlKey key . MachinePrivate , optionalInit [ ] byte ) ( * Conn , error ) {
if deadline , ok := ctx . Deadline ( ) ; ok {
if err := conn . SetDeadline ( deadline ) ; err != nil {
return nil , fmt . Errorf ( "setting conn deadline: %w" , err )
@ -190,7 +231,12 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) (
s . Initialize ( )
var init initiationMessage
if _ , err := io . ReadFull ( conn , init . Header ( ) ) ; err != nil {
if optionalInit != nil {
if len ( optionalInit ) != len ( init ) {
return nil , sendErr ( "wrong handshake initiation size" )
}
copy ( init [ : ] , optionalInit )
} else if _ , err := io . ReadFull ( conn , init . Header ( ) ) ; err != nil {
return nil , err
}
if init . Version ( ) != protocolVersion {
@ -202,8 +248,11 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) (
if init . Length ( ) != len ( init . Payload ( ) ) {
return nil , sendErr ( "wrong handshake initiation length" )
}
if _ , err := io . ReadFull ( conn , init . Payload ( ) ) ; err != nil {
return nil , err
// if optionalInit was provided, we have the payload already.
if optionalInit == nil {
if _ , err := io . ReadFull ( conn , init . Payload ( ) ) ; err != nil {
return nil , err
}
}
// prologue. Can only do this once we at least think the client is