From cbc89830c4414883b21658329299e713e0d959aa Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 11 Nov 2022 17:55:14 -0800 Subject: [PATCH] tsnet: be stricter about arguments to Server.Listen Fixes #6201 Change-Id: I14b2b8ce9bee838344a3fad4f305c78ab775f72e Signed-off-by: Brad Fitzpatrick --- tsnet/tsnet.go | 25 ++++++++++++++++--------- tsnet/tsnet_test.go | 31 ++++++++++++++++++++++++++++++- 2 files changed, 46 insertions(+), 10 deletions(-) diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 68e67a76f..6ad856eb4 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -13,6 +13,7 @@ import ( "fmt" "io" "log" + "math" "net" "net/http" "net/netip" @@ -38,6 +39,7 @@ import ( "tailscale.com/net/tsdial" "tailscale.com/smallzstd" "tailscale.com/types/logger" + "tailscale.com/util/mak" "tailscale.com/wgengine" "tailscale.com/wgengine/monitor" "tailscale.com/wgengine/netstack" @@ -423,7 +425,7 @@ func (s *Server) printAuthURLLoop() { func (s *Server) forwardTCP(c net.Conn, port uint16) { s.mu.Lock() - ln, ok := s.listeners[listenKey{"tcp", "", fmt.Sprint(port)}] + ln, ok := s.listeners[listenKey{"tcp", "", port}] s.mu.Unlock() if !ok { c.Close() @@ -500,16 +502,24 @@ func (s *Server) APIClient() (*tailscale.Client, error) { // Listen announces only on the Tailscale network. // It will start the server if it has not been started yet. func (s *Server) Listen(network, addr string) (net.Listener, error) { - host, port, err := net.SplitHostPort(addr) + switch network { + case "", "tcp", "tcp4", "tcp6": + default: + return nil, errors.New("unsupported network type") + } + host, portStr, err := net.SplitHostPort(addr) if err != nil { return nil, fmt.Errorf("tsnet: %w", err) } - + port, err := net.LookupPort(network, portStr) + if err != nil || port < 0 || port > math.MaxUint16 { + return nil, fmt.Errorf("invalid port: %w", err) + } if err := s.Start(); err != nil { return nil, err } - key := listenKey{network, host, port} + key := listenKey{network, host, uint16(port)} ln := &listener{ s: s, key: key, @@ -518,14 +528,11 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) { conn: make(chan net.Conn), } s.mu.Lock() - if s.listeners == nil { - s.listeners = map[listenKey]*listener{} - } if _, ok := s.listeners[key]; ok { s.mu.Unlock() return nil, fmt.Errorf("tsnet: listener already open for %s, %s", network, addr) } - s.listeners[key] = ln + mak.Set(&s.listeners, key, ln) s.mu.Unlock() return ln, nil } @@ -533,7 +540,7 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) { type listenKey struct { network string host string - port string + port uint16 } type listener struct { diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 59f7ce206..35110d9ab 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -4,7 +4,10 @@ package tsnet -import "testing" +import ( + "errors" + "testing" +) // TestListener_Server ensures that the listener type always keeps the Server // method, which is used by some external applications to identify a tsnet.Listener @@ -16,3 +19,29 @@ func TestListener_Server(t *testing.T) { t.Errorf("listener.Server() returned %v, want %v", ln.Server(), s) } } + +func TestListenerPort(t *testing.T) { + errNone := errors.New("sentinel start error") + + tests := []struct { + network string + addr string + wantErr bool + }{ + {"tcp", ":80", false}, + {"foo", ":80", true}, + {"tcp", ":http", false}, // built-in name to Go; doesn't require cgo, /etc/services + {"tcp", ":https", false}, // built-in name to Go; doesn't require cgo, /etc/services + {"tcp", ":gibberishsdlkfj", true}, + {"tcp", ":%!d(string=80)", true}, // issue 6201 + } + for _, tt := range tests { + s := &Server{} + s.initOnce.Do(func() { s.initErr = errNone }) + _, err := s.Listen(tt.network, tt.addr) + gotErr := err != nil && err != errNone + if gotErr != tt.wantErr { + t.Errorf("Listen(%q, %q) error = %v, want %v", tt.network, tt.addr, gotErr, tt.wantErr) + } + } +}