control/controlbase: don't enforce a max protocol version at handshake time.

Doing so makes development unpleasant, because we have to first break the
client by bumping to a version the control server rejects, then upgrade
the control server to make it accept the new version.

This strict rejection at handshake time is only necessary if we want to
blocklist some vulnerable protocol versions in the future. So, switch
to a default-permissive stance: until we have such a version that we
have to eagerly block early, we'll accept whatever version the client
presents, and leave it to the user of controlbase.Conn to make decisions
based on that version.

Noise still enforces that the client and server *agree* on what protocol
version is being used, and the control server still has the option to
finish the handshake and then hang up with an in-noise error, rather
than abort at the handshake level.

Updates #3488

Signed-off-by: David Anderson <danderson@tailscale.com>
pull/4386/head
David Anderson 2 years ago committed by Dave Anderson
parent c6ac29bcc4
commit f570372b4d

@ -206,7 +206,7 @@ func TestConnStd(t *testing.T) {
serverErr := make(chan error, 1)
go func() {
var err error
c2, err = Server(context.Background(), s2, controlKey, testProtocolVersion, nil)
c2, err = Server(context.Background(), s2, controlKey, nil)
serverErr <- err
}()
c1, err = Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion)
@ -398,7 +398,7 @@ func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn)
)
go func() {
var err error
server, err = Server(context.Background(), serverConn, controlKey, testProtocolVersion, nil)
server, err = Server(context.Background(), serverConn, controlKey, nil)
serverErr <- err
}()

@ -193,19 +193,13 @@ func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricSta
// Server initiates a control server handshake, returning the resulting
// control connection.
//
// maxSupportedVersion is the highest handshake version the server is
// willing to handshake with. The server will handshake with any
// version from 0 to maxSupportedVersion inclusive, the caller should
// inspect conn.Version() to determine what version of the handshake
// was executed.
//
// optionalInit can be the client's initial handshake message as
// returned by ClientDeferred, or nil in which case the initial
// message is read from conn.
//
// The context deadline, if any, covers the entire handshaking
// process.
func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, maxSupportedVersion uint16, optionalInit []byte) (*Conn, error) {
func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) {
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("setting conn deadline: %w", err)
@ -245,15 +239,11 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, m
} else if _, err := io.ReadFull(conn, init.Header()); err != nil {
return nil, err
}
// Currently, these versions exclusively indicate what the upper
// RPC protocol understands, the Noise handshake is exactly the
// same in all versions. If that ever changes, this check will
// need to become more complex to handle different kinds of
// handshake.
if init.Version() > maxSupportedVersion {
return nil, sendErr("unsupported handshake version")
}
// Just a rename to make it more obvious what the value is
// Just a rename to make it more obvious what the value is. In the
// current implementation we don't need to block any protocol
// versions at this layer, it's safe to let the handshake proceed
// and then let the caller make decisions based on the agreed-upon
// protocol version.
clientVersion := init.Version()
if init.Type() != msgTypeInitiation {
return nil, sendErr("unexpected handshake message type")

@ -26,7 +26,7 @@ func TestHandshake(t *testing.T) {
)
go func() {
var err error
server, err = Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
server, err = Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err
}()
@ -78,7 +78,7 @@ func TestNoReuse(t *testing.T) {
)
go func() {
var err error
server, err = Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
server, err = Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err
}()
@ -172,7 +172,7 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1)
)
go func() {
_, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
_, err := Server(context.Background(), serverConn, serverKey, nil)
// If the server failed, we have to close the Conn to
// unblock the client.
if err != nil {
@ -200,7 +200,7 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1)
)
go func() {
_, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
_, err := Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err
}()
@ -225,7 +225,7 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1)
)
go func() {
server, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
server, err := Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err
_, err = io.WriteString(server, strings.Repeat("a", 14))
serverErr <- err
@ -266,7 +266,7 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1)
)
go func() {
server, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
server, err := Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err
var bs [100]byte
// The server needs a timeout if the tampering is hitting the length header.

@ -29,7 +29,7 @@ func TestInteropClient(t *testing.T) {
)
go func() {
server, err := Server(context.Background(), s2, controlKey, testProtocolVersion, nil)
server, err := Server(context.Background(), s2, controlKey, nil)
serverErr <- err
if err != nil {
return

@ -107,7 +107,7 @@ func testControlHTTP(t *testing.T, proxy proxy) {
const testProtocolVersion = 1
sch := make(chan serverResult, 1)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := AcceptHTTP(context.Background(), w, r, server, testProtocolVersion)
conn, err := AcceptHTTP(context.Background(), w, r, server)
if err != nil {
log.Print(err)
}

@ -21,7 +21,7 @@ import (
//
// AcceptHTTP always writes an HTTP response to w. The caller must not
// attempt their own response after calling AcceptHTTP.
func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, maxSupportedVersion uint16) (*controlbase.Conn, error) {
func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate) (*controlbase.Conn, error) {
next := r.Header.Get("Upgrade")
if next == "" {
http.Error(w, "missing next protocol", http.StatusBadRequest)
@ -63,7 +63,7 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri
}
conn = netutil.NewDrainBufConn(conn, brw.Reader)
nc, err := controlbase.Server(ctx, conn, private, maxSupportedVersion, init)
nc, err := controlbase.Server(ctx, conn, private, init)
if err != nil {
conn.Close()
return nil, fmt.Errorf("noise handshake failed: %w", err)

Loading…
Cancel
Save