diff --git a/control/controlbase/conn_test.go b/control/controlbase/conn_test.go index f25ae2850..827d5b9a1 100644 --- a/control/controlbase/conn_test.go +++ b/control/controlbase/conn_test.go @@ -26,6 +26,8 @@ import ( "tailscale.com/types/key" ) +const testProtocolVersion = 1 + func TestMessageSize(t *testing.T) { // This test is a regression guard against someone looking at // maxCiphertextSize, going "huh, we could be more efficient if it @@ -204,10 +206,10 @@ func TestConnStd(t *testing.T) { serverErr := make(chan error, 1) go func() { var err error - c2, err = Server(context.Background(), s2, controlKey, nil) + c2, err = Server(context.Background(), s2, controlKey, testProtocolVersion, nil) serverErr <- err }() - c1, err = Client(context.Background(), s1, machineKey, controlKey.Public()) + c1, err = Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion) if err != nil { s1.Close() s2.Close() @@ -396,11 +398,11 @@ func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn) ) go func() { var err error - server, err = Server(context.Background(), serverConn, controlKey, nil) + server, err = Server(context.Background(), serverConn, controlKey, testProtocolVersion, nil) serverErr <- err }() - client, err := Client(context.Background(), clientConn, machineKey, controlKey.Public()) + client, err := Client(context.Background(), clientConn, machineKey, controlKey.Public(), testProtocolVersion) if err != nil { t.Fatalf("client connection failed: %v", err) } diff --git a/control/controlbase/handshake.go b/control/controlbase/handshake.go index 393576ee8..0fb2859b6 100644 --- a/control/controlbase/handshake.go +++ b/control/controlbase/handshake.go @@ -32,7 +32,7 @@ const ( protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s" // protocolVersion is the version of the control protocol that // Client will use when initiating a handshake. - protocolVersion uint16 = 1 + //protocolVersion uint16 = 1 // protocolVersionPrefix is the name portion of the protocol // name+version string that gets mixed into the handshake as a // prologue. @@ -66,7 +66,7 @@ type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error) // protocol switching. By splitting the handshake into an initial // message and a continuation, we can embed the handshake initiation // into the HTTP protocol switching request and avoid a bit of delay. -func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) { +func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) { var s symmetricState s.Initialize() @@ -78,7 +78,7 @@ func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic) s.MixHash(controlKey.UntypedBytes()) // -> e, es, s, ss - init := mkInitiationMessage() + init := mkInitiationMessage(protocolVersion) machineEphemeral := key.NewMachine() machineEphemeralPub := machineEphemeral.Public() copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes()) @@ -96,7 +96,7 @@ func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic) s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload cont := func(ctx context.Context, conn net.Conn) (*Conn, error) { - return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey) + return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion) } return init[:], cont, nil } @@ -107,8 +107,8 @@ func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic) // This is a helper for when you don't need the fancy // continuation-style handshake, and just want to synchronously // upgrade a net.Conn to a secure transport. -func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) { - init, cont, err := ClientDeferred(machineKey, controlKey) +func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { + init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion) if err != nil { return nil, err } @@ -118,7 +118,7 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c return cont(ctx, conn) } -func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) { +func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { // No matter what, this function can only run once per s. Ensure // attempted reuse causes a panic. defer func() { @@ -193,13 +193,19 @@ func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricSta // Server initiates a control server handshake, returning the resulting // control connection. // +// maxSupportedVersion is the highest handshake version the server is +// willing to handshake with. The server will handshake with any +// version from 0 to maxSupportedVersion inclusive, the caller should +// inspect conn.Version() to determine what version of the handshake +// was executed. +// // optionalInit can be the client's initial handshake message as // returned by ClientDeferred, or nil in which case the initial // message is read from conn. // // The context deadline, if any, covers the entire handshaking // process. -func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) { +func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, maxSupportedVersion uint16, optionalInit []byte) (*Conn, error) { if deadline, ok := ctx.Deadline(); ok { if err := conn.SetDeadline(deadline); err != nil { return nil, fmt.Errorf("setting conn deadline: %w", err) @@ -239,9 +245,16 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, o } else if _, err := io.ReadFull(conn, init.Header()); err != nil { return nil, err } - if init.Version() != protocolVersion { - return nil, sendErr("unsupported protocol version") + // Currently, these versions exclusively indicate what the upper + // RPC protocol understands, the Noise handshake is exactly the + // same in all versions. If that ever changes, this check will + // need to become more complex to handle different kinds of + // handshake. + if init.Version() > maxSupportedVersion { + return nil, sendErr("unsupported handshake version") } + // Just a rename to make it more obvious what the value is + clientVersion := init.Version() if init.Type() != msgTypeInitiation { return nil, sendErr("unexpected handshake message type") } @@ -257,7 +270,7 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, o // prologue. Can only do this once we at least think the client is // handshaking using a supported version. - s.MixHash(protocolVersionPrologue(protocolVersion)) + s.MixHash(protocolVersionPrologue(clientVersion)) // <- s // ... @@ -310,7 +323,7 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, o c := &Conn{ conn: conn, - version: protocolVersion, + version: clientVersion, peer: machineKey, handshakeHash: s.h, tx: txState{ diff --git a/control/controlbase/handshake_test.go b/control/controlbase/handshake_test.go index 9cdc6f5f2..ce28f12e8 100644 --- a/control/controlbase/handshake_test.go +++ b/control/controlbase/handshake_test.go @@ -26,11 +26,11 @@ func TestHandshake(t *testing.T) { ) go func() { var err error - server, err = Server(context.Background(), serverConn, serverKey, nil) + server, err = Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) serverErr <- err }() - client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public()) + client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion) if err != nil { t.Fatalf("client connection failed: %v", err) } @@ -42,8 +42,8 @@ func TestHandshake(t *testing.T) { t.Fatal("client and server disagree on handshake hash") } - if client.ProtocolVersion() != int(protocolVersion) { - t.Fatalf("client reporting wrong protocol version %d, want %d", client.ProtocolVersion(), protocolVersion) + if client.ProtocolVersion() != int(testProtocolVersion) { + t.Fatalf("client reporting wrong protocol version %d, want %d", client.ProtocolVersion(), testProtocolVersion) } if client.ProtocolVersion() != server.ProtocolVersion() { t.Fatalf("peers disagree on protocol version, client=%d server=%d", client.ProtocolVersion(), server.ProtocolVersion()) @@ -78,11 +78,11 @@ func TestNoReuse(t *testing.T) { ) go func() { var err error - server, err = Server(context.Background(), serverConn, serverKey, nil) + server, err = Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) serverErr <- err }() - client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public()) + client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion) if err != nil { t.Fatalf("client connection failed: %v", err) } @@ -172,7 +172,7 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - _, err := Server(context.Background(), serverConn, serverKey, nil) + _, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) // If the server failed, we have to close the Conn to // unblock the client. if err != nil { @@ -181,7 +181,7 @@ func TestTampering(t *testing.T) { serverErr <- err }() - _, err := Client(context.Background(), clientConn, clientKey, serverKey.Public()) + _, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion) if err == nil { t.Fatal("client connection succeeded despite tampering") } @@ -200,11 +200,11 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - _, err := Server(context.Background(), serverConn, serverKey, nil) + _, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) serverErr <- err }() - _, err := Client(context.Background(), clientConn, clientKey, serverKey.Public()) + _, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion) if err == nil { t.Fatal("client connection succeeded despite tampering") } @@ -225,13 +225,13 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - server, err := Server(context.Background(), serverConn, serverKey, nil) + server, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) serverErr <- err _, err = io.WriteString(server, strings.Repeat("a", 14)) serverErr <- err }() - client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public()) + client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion) if err != nil { t.Fatalf("client handshake failed: %v", err) } @@ -266,7 +266,7 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - server, err := Server(context.Background(), serverConn, serverKey, nil) + server, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) serverErr <- err var bs [100]byte // The server needs a timeout if the tampering is hitting the length header. @@ -281,7 +281,7 @@ func TestTampering(t *testing.T) { } }() - client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public()) + client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion) if err != nil { t.Fatalf("client handshake failed: %v", err) } diff --git a/control/controlbase/interop_test.go b/control/controlbase/interop_test.go index 3417639fe..b7e7d15e8 100644 --- a/control/controlbase/interop_test.go +++ b/control/controlbase/interop_test.go @@ -29,7 +29,7 @@ func TestInteropClient(t *testing.T) { ) go func() { - server, err := Server(context.Background(), s2, controlKey, nil) + server, err := Server(context.Background(), s2, controlKey, testProtocolVersion, nil) serverErr <- err if err != nil { return @@ -77,7 +77,7 @@ func TestInteropServer(t *testing.T) { ) go func() { - client, err := Client(context.Background(), s1, machineKey, controlKey.Public()) + client, err := Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion) clientErr <- err if err != nil { return @@ -121,11 +121,11 @@ func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey copy(mk.public_key[:], machineKey.Public().UntypedBytes()) var peerKey [32]byte copy(peerKey[:], controlKey.UntypedBytes()) - session := InitSession(true, protocolVersionPrologue(protocolVersion), mk, peerKey) + session := InitSession(true, protocolVersionPrologue(testProtocolVersion), mk, peerKey) _, msg1 := SendMessage(&session, nil) var hdr [initiationHeaderLen]byte - binary.BigEndian.PutUint16(hdr[:2], protocolVersion) + binary.BigEndian.PutUint16(hdr[:2], testProtocolVersion) hdr[2] = msgTypeInitiation binary.BigEndian.PutUint16(hdr[3:5], 96) if _, err := conn.Write(hdr[:]); err != nil { @@ -193,7 +193,7 @@ func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachi var mk keypair copy(mk.private_key[:], controlKey.UntypedBytes()) copy(mk.public_key[:], controlKey.Public().UntypedBytes()) - session := InitSession(false, protocolVersionPrologue(protocolVersion), mk, [32]byte{}) + session := InitSession(false, protocolVersionPrologue(testProtocolVersion), mk, [32]byte{}) var buf [1024]byte if _, err := io.ReadFull(conn, buf[:101]); err != nil { diff --git a/control/controlbase/messages.go b/control/controlbase/messages.go index 2a64d4585..62d42c2f2 100644 --- a/control/controlbase/messages.go +++ b/control/controlbase/messages.go @@ -39,9 +39,9 @@ const ( // 16b: message tag (authenticates the whole message) type initiationMessage [101]byte -func mkInitiationMessage() initiationMessage { +func mkInitiationMessage(protocolVersion uint16) initiationMessage { var ret initiationMessage - binary.BigEndian.PutUint16(ret[:2], uint16(protocolVersion)) + binary.BigEndian.PutUint16(ret[:2], protocolVersion) ret[2] = msgTypeInitiation binary.BigEndian.PutUint16(ret[3:5], uint16(len(ret.Payload()))) return ret diff --git a/control/controlclient/noise.go b/control/controlclient/noise.go index 8fb828f15..1a5f4a73f 100644 --- a/control/controlclient/noise.go +++ b/control/controlclient/noise.go @@ -8,6 +8,7 @@ import ( "context" "crypto/tls" "fmt" + "math" "net" "net/http" "net/url" @@ -17,6 +18,7 @@ import ( "golang.org/x/net/http2" "tailscale.com/control/controlbase" "tailscale.com/control/controlhttp" + "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/util/multierr" ) @@ -146,7 +148,12 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - conn, err := controlhttp.Dial(ctx, nc.serverHost, nc.privKey, nc.serverPubKey) + if tailcfg.CurrentCapabilityVersion > math.MaxUint16 { + // Panic, because a test should have started failing several + // thousand version numbers before getting to this point. + panic("capability version is too high to fit in the wire protocol") + } + conn, err := controlhttp.Dial(ctx, nc.serverHost, nc.privKey, nc.serverPubKey, uint16(tailcfg.CurrentCapabilityVersion)) if err != nil { return nil, err } diff --git a/control/controlclient/noise_test.go b/control/controlclient/noise_test.go new file mode 100644 index 000000000..a97af4045 --- /dev/null +++ b/control/controlclient/noise_test.go @@ -0,0 +1,28 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package controlclient + +import ( + "math" + "testing" + + "tailscale.com/tailcfg" +) + +// maxAllowedNoiseVersion is the highest we expect the Tailscale +// capability version to ever get. It's a value close to 2^16, but +// with enough leeway that we get a very early warning that it's time +// to rework the wire protocol to allow larger versions, while still +// giving us headroom to bump this test and fix the build. +// +// Code elsewhere in the client will panic() if the tailcfg capability +// version exceeds 16 bits, so take a failure of this test seriously. +const maxAllowedNoiseVersion = math.MaxUint16 - 5000 + +func TestNoiseVersion(t *testing.T) { + if tailcfg.CurrentCapabilityVersion > maxAllowedNoiseVersion { + t.Fatalf("tailcfg.CurrentCapabilityVersion is %d, want <=%d", tailcfg.CurrentCapabilityVersion, maxAllowedNoiseVersion) + } +} diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index cf97acb23..5a92d50f8 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -65,7 +65,7 @@ const ( // // The provided ctx is only used for the initial connection, until // Dial returns. It does not affect the connection once established. -func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, controlKey key.MachinePublic) (*controlbase.Conn, error) { +func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*controlbase.Conn, error) { host, port, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -77,6 +77,7 @@ func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, contr httpsPort: "443", machineKey: machineKey, controlKey: controlKey, + version: protocolVersion, proxyFunc: tshttpproxy.ProxyFromEnvironment, } return a.dial() @@ -89,6 +90,7 @@ type dialParams struct { httpsPort string machineKey key.MachinePrivate controlKey key.MachinePublic + version uint16 proxyFunc func(*http.Request) (*url.URL, error) // or nil // For tests only @@ -96,7 +98,7 @@ type dialParams struct { } func (a *dialParams) dial() (*controlbase.Conn, error) { - init, cont, err := controlbase.ClientDeferred(a.machineKey, a.controlKey) + init, cont, err := controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version) if err != nil { return nil, err } @@ -120,7 +122,7 @@ func (a *dialParams) dial() (*controlbase.Conn, error) { // being difficult and see if we can get through over HTTPS. u.Scheme = "https" u.Host = net.JoinHostPort(a.host, a.httpsPort) - init, cont, err = controlbase.ClientDeferred(a.machineKey, a.controlKey) + init, cont, err = controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version) if err != nil { return nil, err } diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index 799eb1b19..c4b8ddc36 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -104,9 +104,10 @@ func TestControlHTTP(t *testing.T) { func testControlHTTP(t *testing.T, proxy proxy) { client, server := key.NewMachine(), key.NewMachine() + const testProtocolVersion = 1 sch := make(chan serverResult, 1) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := AcceptHTTP(context.Background(), w, r, server) + conn, err := AcceptHTTP(context.Background(), w, r, server, testProtocolVersion) if err != nil { log.Print(err) } @@ -152,6 +153,7 @@ func testControlHTTP(t *testing.T, proxy proxy) { httpsPort: strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port), machineKey: client, controlKey: server.Public(), + version: testProtocolVersion, insecureTLS: true, } diff --git a/control/controlhttp/server.go b/control/controlhttp/server.go index 0e38da860..8d7073ffe 100644 --- a/control/controlhttp/server.go +++ b/control/controlhttp/server.go @@ -21,7 +21,7 @@ import ( // // AcceptHTTP always writes an HTTP response to w. The caller must not // attempt their own response after calling AcceptHTTP. -func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate) (*controlbase.Conn, error) { +func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, maxSupportedVersion uint16) (*controlbase.Conn, error) { next := r.Header.Get("Upgrade") if next == "" { http.Error(w, "missing next protocol", http.StatusBadRequest) @@ -63,7 +63,7 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri } conn = netutil.NewDrainBufConn(conn, brw.Reader) - nc, err := controlbase.Server(ctx, conn, private, init) + nc, err := controlbase.Server(ctx, conn, private, maxSupportedVersion, init) if err != nil { conn.Close() return nil, fmt.Errorf("noise handshake failed: %w", err)