net/wsconn: accept a remote addr string and plumb it through

This makes wsconn.Conns somewhat present reasonably when they are
the client of an http.Request, rather than just put a placeholder
in that field.

Updates tailscale/corp#13777

Signed-off-by: David Anderson <danderson@tailscale.com>
pull/9151/head
David Anderson 1 year ago committed by Dave Anderson
parent e952564b59
commit 8b492b4121

@ -50,7 +50,7 @@ func addWebSocketSupport(s *derp.Server, base http.Handler) http.Handler {
return return
} }
counterWebSocketAccepts.Add(1) counterWebSocketAccepts.Add(1)
wc := wsconn.NetConn(r.Context(), c, websocket.MessageBinary) wc := wsconn.NetConn(r.Context(), c, websocket.MessageBinary, r.RemoteAddr)
brw := bufio.NewReadWriter(bufio.NewReader(wc), bufio.NewWriter(wc)) brw := bufio.NewReadWriter(bufio.NewReader(wc), bufio.NewWriter(wc))
s.Accept(r.Context(), wc, brw, r.RemoteAddr) s.Accept(r.Context(), wc, brw, r.RemoteAddr)
}) })

@ -51,7 +51,7 @@ func (d *Dialer) Dial(ctx context.Context) (*ClientConn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
netConn := wsconn.NetConn(context.Background(), wsConn, websocket.MessageBinary) netConn := wsconn.NetConn(context.Background(), wsConn, websocket.MessageBinary, wsURL.String())
cbConn, err := cont(ctx, netConn) cbConn, err := cont(ctx, netConn)
if err != nil { if err != nil {
netConn.Close() netConn.Close()

@ -146,7 +146,7 @@ func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request
return nil, fmt.Errorf("decoding base64 handshake parameter: %v", err) return nil, fmt.Errorf("decoding base64 handshake parameter: %v", err)
} }
conn := wsconn.NetConn(ctx, c, websocket.MessageBinary) conn := wsconn.NetConn(ctx, c, websocket.MessageBinary, r.RemoteAddr)
nc, err := controlbase.Server(ctx, conn, private, init) nc, err := controlbase.Server(ctx, conn, private, init)
if err != nil { if err != nil {
conn.Close() conn.Close()

@ -27,6 +27,6 @@ func dialWebsocket(ctx context.Context, urlStr string) (net.Conn, error) {
return nil, err return nil, err
} }
log.Printf("websocket: connected to %v", urlStr) log.Printf("websocket: connected to %v", urlStr)
netConn := wsconn.NetConn(context.Background(), c, websocket.MessageBinary) netConn := wsconn.NetConn(context.Background(), c, websocket.MessageBinary, urlStr)
return netConn, nil return netConn, nil
} }

@ -48,10 +48,18 @@ import (
// //
// A received StatusNormalClosure or StatusGoingAway close frame will be translated to // A received StatusNormalClosure or StatusGoingAway close frame will be translated to
// io.EOF when reading. // io.EOF when reading.
func NetConn(ctx context.Context, c *websocket.Conn, msgType websocket.MessageType) net.Conn { //
// The given remoteAddr will be the value of the returned conn's
// RemoteAddr().String(). For best compatibility with consumers of
// conns, the string should be an ip:port if available, but in the
// absence of that it can be any string that describes the remote
// endpoint, or the empty string to makes RemoteAddr() return a place
// holder value.
func NetConn(ctx context.Context, c *websocket.Conn, msgType websocket.MessageType, remoteAddr string) net.Conn {
nc := &netConn{ nc := &netConn{
c: c, c: c,
msgType: msgType, msgType: msgType,
remoteAddr: remoteAddr,
} }
var writeCancel context.CancelFunc var writeCancel context.CancelFunc
@ -84,6 +92,7 @@ func NetConn(ctx context.Context, c *websocket.Conn, msgType websocket.MessageTy
type netConn struct { type netConn struct {
c *websocket.Conn c *websocket.Conn
msgType websocket.MessageType msgType websocket.MessageType
remoteAddr string
writeTimer *time.Timer writeTimer *time.Timer
writeContext context.Context writeContext context.Context
@ -167,6 +176,7 @@ func (c *netConn) Read(p []byte) (int, error) {
} }
type websocketAddr struct { type websocketAddr struct {
addr string
} }
func (a websocketAddr) Network() string { func (a websocketAddr) Network() string {
@ -174,15 +184,18 @@ func (a websocketAddr) Network() string {
} }
func (a websocketAddr) String() string { func (a websocketAddr) String() string {
if a.addr != "" {
return a.addr
}
return "websocket/unknown-addr" return "websocket/unknown-addr"
} }
func (c *netConn) RemoteAddr() net.Addr { func (c *netConn) RemoteAddr() net.Addr {
return websocketAddr{} return websocketAddr{c.remoteAddr}
} }
func (c *netConn) LocalAddr() net.Addr { func (c *netConn) LocalAddr() net.Addr {
return websocketAddr{} return websocketAddr{""}
} }
func (c *netConn) SetDeadline(t time.Time) error { func (c *netConn) SetDeadline(t time.Time) error {

Loading…
Cancel
Save