diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index 20cc285c9..4de2f682d 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -47,6 +47,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa tailscale.com/net/tlsdial from tailscale.com/derp/derphttp tailscale.com/net/tsaddr from tailscale.com/ipn+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/derp/derphttp+ + tailscale.com/net/wsconn from tailscale.com/cmd/derper+ tailscale.com/paths from tailscale.com/client/tailscale tailscale.com/safesocket from tailscale.com/client/tailscale tailscale.com/syncs from tailscale.com/cmd/derper+ diff --git a/cmd/derper/websocket.go b/cmd/derper/websocket.go index 32258ae87..68b0ce940 100644 --- a/cmd/derper/websocket.go +++ b/cmd/derper/websocket.go @@ -13,6 +13,7 @@ import ( "nhooyr.io/websocket" "tailscale.com/derp" + "tailscale.com/net/wsconn" ) var counterWebSocketAccepts = expvar.NewInt("derp_websocket_accepts") @@ -50,7 +51,7 @@ func addWebSocketSupport(s *derp.Server, base http.Handler) http.Handler { return } counterWebSocketAccepts.Add(1) - wc := websocket.NetConn(r.Context(), c, websocket.MessageBinary) + wc := wsconn.NetConn(r.Context(), c, websocket.MessageBinary) brw := bufio.NewReadWriter(bufio.NewReader(wc), bufio.NewWriter(wc)) s.Accept(r.Context(), wc, brw, r.RemoteAddr) }) diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index 4dfd77fac..63012be74 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -70,6 +70,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/net/tlsdial from tailscale.com/derp/derphttp+ tailscale.com/net/tsaddr from tailscale.com/net/interfaces+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/derp/derphttp+ + tailscale.com/net/wsconn from tailscale.com/control/controlhttp+ tailscale.com/paths from tailscale.com/cmd/tailscale/cli+ tailscale.com/safesocket from tailscale.com/cmd/tailscale/cli+ tailscale.com/syncs from tailscale.com/net/netcheck+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index ac80fc49b..d9f7bd5e6 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -241,6 +241,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de 💣 tailscale.com/net/tshttpproxy from tailscale.com/control/controlclient+ tailscale.com/net/tstun from tailscale.com/net/dns+ tailscale.com/net/tunstats from tailscale.com/net/tstun+ + tailscale.com/net/wsconn from tailscale.com/control/controlhttp+ tailscale.com/paths from tailscale.com/ipn/ipnlocal+ tailscale.com/portlist from tailscale.com/ipn/ipnlocal tailscale.com/safesocket from tailscale.com/client/tailscale+ diff --git a/control/controlhttp/client_js.go b/control/controlhttp/client_js.go index 5bc2cda76..feb660cfa 100644 --- a/control/controlhttp/client_js.go +++ b/control/controlhttp/client_js.go @@ -13,6 +13,7 @@ import ( "nhooyr.io/websocket" "tailscale.com/control/controlbase" + "tailscale.com/net/wsconn" ) // Variant of Dial that tunnels the request over WebSockets, since we cannot do @@ -51,7 +52,7 @@ func (d *Dialer) Dial(ctx context.Context) (*ClientConn, error) { if err != nil { return nil, err } - netConn := websocket.NetConn(context.Background(), wsConn, websocket.MessageBinary) + netConn := wsconn.NetConn(context.Background(), wsConn, websocket.MessageBinary) cbConn, err := cont(ctx, netConn) if err != nil { netConn.Close() diff --git a/control/controlhttp/server.go b/control/controlhttp/server.go index 23a8cf8ff..748da2527 100644 --- a/control/controlhttp/server.go +++ b/control/controlhttp/server.go @@ -14,6 +14,7 @@ import ( "nhooyr.io/websocket" "tailscale.com/control/controlbase" "tailscale.com/net/netutil" + "tailscale.com/net/wsconn" "tailscale.com/types/key" ) @@ -118,7 +119,7 @@ func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request return nil, fmt.Errorf("decoding base64 handshake parameter: %v", err) } - conn := websocket.NetConn(ctx, c, websocket.MessageBinary) + conn := wsconn.NetConn(ctx, c, websocket.MessageBinary) nc, err := controlbase.Server(ctx, conn, private, init) if err != nil { conn.Close() diff --git a/derp/derphttp/websocket.go b/derp/derphttp/websocket.go index 110c5fec5..e2c343dfe 100644 --- a/derp/derphttp/websocket.go +++ b/derp/derphttp/websocket.go @@ -13,6 +13,7 @@ import ( "net" "nhooyr.io/websocket" + "tailscale.com/net/wsconn" ) func init() { @@ -28,6 +29,6 @@ func dialWebsocket(ctx context.Context, urlStr string) (net.Conn, error) { return nil, err } log.Printf("websocket: connected to %v", urlStr) - netConn := websocket.NetConn(context.Background(), c, websocket.MessageBinary) + netConn := wsconn.NetConn(context.Background(), c, websocket.MessageBinary) return netConn, nil } diff --git a/net/wsconn/wsconn.go b/net/wsconn/wsconn.go new file mode 100644 index 000000000..e846aad51 --- /dev/null +++ b/net/wsconn/wsconn.go @@ -0,0 +1,213 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package wsconn contains an adapter type that turns +// a websocket connection into a net.Conn. It a temporary fork of the +// netconn.go file from the nhooyr.io/websocket package while we wait for +// https://github.com/nhooyr/websocket/pull/350 to be merged. +package wsconn + +import ( + "context" + "fmt" + "io" + "math" + "net" + "os" + "sync" + "sync/atomic" + "time" + + "nhooyr.io/websocket" +) + +// NetConn converts a *websocket.Conn into a net.Conn. +// +// It's for tunneling arbitrary protocols over WebSockets. +// Few users of the library will need this but it's tricky to implement +// correctly and so provided in the library. +// See https://github.com/nhooyr/websocket/issues/100. +// +// Every Write to the net.Conn will correspond to a message write of +// the given type on *websocket.Conn. +// +// The passed ctx bounds the lifetime of the net.Conn. If cancelled, +// all reads and writes on the net.Conn will be cancelled. +// +// If a message is read that is not of the correct type, the connection +// will be closed with StatusUnsupportedData and an error will be returned. +// +// Close will close the *websocket.Conn with StatusNormalClosure. +// +// When a deadline is hit, the connection will be closed. This is +// different from most net.Conn implementations where only the +// reading/writing goroutines are interrupted but the connection is kept alive. +// +// The Addr methods will return a mock net.Addr that returns "websocket" for Network +// and "websocket/unknown-addr" for String. +// +// A received StatusNormalClosure or StatusGoingAway close frame will be translated to +// io.EOF when reading. +func NetConn(ctx context.Context, c *websocket.Conn, msgType websocket.MessageType) net.Conn { + nc := &netConn{ + c: c, + msgType: msgType, + } + + var writeCancel context.CancelFunc + nc.writeContext, writeCancel = context.WithCancel(ctx) + nc.writeTimer = time.AfterFunc(math.MaxInt64, func() { + nc.afterWriteDeadline.Store(true) + if nc.writing.Load() { + writeCancel() + } + }) + if !nc.writeTimer.Stop() { + <-nc.writeTimer.C + } + + var readCancel context.CancelFunc + nc.readContext, readCancel = context.WithCancel(ctx) + nc.readTimer = time.AfterFunc(math.MaxInt64, func() { + nc.afterReadDeadline.Store(true) + if nc.reading.Load() { + readCancel() + } + }) + if !nc.readTimer.Stop() { + <-nc.readTimer.C + } + + return nc +} + +type netConn struct { + c *websocket.Conn + msgType websocket.MessageType + + writeTimer *time.Timer + writeContext context.Context + writing atomic.Bool + afterWriteDeadline atomic.Bool + + readTimer *time.Timer + readContext context.Context + reading atomic.Bool + afterReadDeadline atomic.Bool + + readMu sync.Mutex + eofed bool + reader io.Reader +} + +var _ net.Conn = &netConn{} + +func (c *netConn) Close() error { + return c.c.Close(websocket.StatusNormalClosure, "") +} + +func (c *netConn) Write(p []byte) (int, error) { + if c.afterWriteDeadline.Load() { + return 0, os.ErrDeadlineExceeded + } + + if swapped := c.writing.CompareAndSwap(false, true); !swapped { + panic("Concurrent writes not allowed") + } + defer c.writing.Store(false) + + err := c.c.Write(c.writeContext, c.msgType, p) + if err != nil { + return 0, err + } + + return len(p), nil +} + +func (c *netConn) Read(p []byte) (int, error) { + if c.afterReadDeadline.Load() { + return 0, os.ErrDeadlineExceeded + } + + c.readMu.Lock() + defer c.readMu.Unlock() + if swapped := c.reading.CompareAndSwap(false, true); !swapped { + panic("Concurrent reads not allowed") + } + defer c.reading.Store(false) + + if c.eofed { + return 0, io.EOF + } + + if c.reader == nil { + typ, r, err := c.c.Reader(c.readContext) + if err != nil { + switch websocket.CloseStatus(err) { + case websocket.StatusNormalClosure, websocket.StatusGoingAway: + c.eofed = true + return 0, io.EOF + } + return 0, err + } + if typ != c.msgType { + err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ) + c.c.Close(websocket.StatusUnsupportedData, err.Error()) + return 0, err + } + c.reader = r + } + + n, err := c.reader.Read(p) + if err == io.EOF { + c.reader = nil + err = nil + } + return n, err +} + +type websocketAddr struct { +} + +func (a websocketAddr) Network() string { + return "websocket" +} + +func (a websocketAddr) String() string { + return "websocket/unknown-addr" +} + +func (c *netConn) RemoteAddr() net.Addr { + return websocketAddr{} +} + +func (c *netConn) LocalAddr() net.Addr { + return websocketAddr{} +} + +func (c *netConn) SetDeadline(t time.Time) error { + c.SetWriteDeadline(t) + c.SetReadDeadline(t) + return nil +} + +func (c *netConn) SetWriteDeadline(t time.Time) error { + if t.IsZero() { + c.writeTimer.Stop() + } else { + c.writeTimer.Reset(time.Until(t)) + } + c.afterWriteDeadline.Store(false) + return nil +} + +func (c *netConn) SetReadDeadline(t time.Time) error { + if t.IsZero() { + c.readTimer.Stop() + } else { + c.readTimer.Reset(time.Until(t)) + } + c.afterReadDeadline.Store(false) + return nil +}