diff --git a/ipn/ipnlocal/peerapi.go b/ipn/ipnlocal/peerapi.go index b720ce08b..36770d7ab 100644 --- a/ipn/ipnlocal/peerapi.go +++ b/ipn/ipnlocal/peerapi.go @@ -411,7 +411,7 @@ func (s *peerAPIServer) listen(ip netaddr.IP, ifState *interfaces.State) (ln net } // Make a best effort to pick a deterministic port number for - // the ip The lower three bytes are the same for IPv4 and IPv6 + // the ip. The lower three bytes are the same for IPv4 and IPv6 // Tailscale addresses (at least currently), so we'll usually // get the same port number on both address families for // dev/debugging purposes, which is nice. But it's not so @@ -507,7 +507,7 @@ func (pln *peerAPIListener) ServeConn(src netaddr.IPPort, c net.Conn) { if addH2C != nil { addH2C(httpServer) } - go httpServer.Serve(netutil.NewOneConnListenerFrom(c, pln.ln)) + go httpServer.Serve(netutil.NewOneConnListener(c, pln.ln.Addr())) } // peerAPIHandler serves the Peer API for a source specific client. diff --git a/ipn/ipnserver/server.go b/ipn/ipnserver/server.go index 522fead66..805c00c3a 100644 --- a/ipn/ipnserver/server.go +++ b/ipn/ipnserver/server.go @@ -312,7 +312,7 @@ func (s *Server) serveConn(ctx context.Context, c net.Conn, logf logger.Logf) { ErrorLog: logger.StdLogger(logf), Handler: s.localhostHandler(ci), } - httpServer.Serve(netutil.NewOneConnListener(&protoSwitchConn{s: s, br: br, Conn: c})) + httpServer.Serve(netutil.NewOneConnListener(&protoSwitchConn{s: s, br: br, Conn: c}, nil)) return } diff --git a/net/netutil/netutil.go b/net/netutil/netutil.go index 49a2aec64..12d61122e 100644 --- a/net/netutil/netutil.go +++ b/net/netutil/netutil.go @@ -8,36 +8,51 @@ package netutil import ( "io" "net" + "sync" ) -// NewOneConnListener returns a net.Listener that returns c on its first -// Accept and EOF thereafter. The Listener's Addr is a dummy address. -func NewOneConnListener(c net.Conn) net.Listener { - return NewOneConnListenerFrom(c, dummyListener{}) -} - -// NewOneConnListenerFrom returns a net.Listener wrapping ln where -// its Accept returns c on the first call and io.EOF thereafter. -func NewOneConnListenerFrom(c net.Conn, ln net.Listener) net.Listener { - return &oneConnListener{c, ln} +// NewOneConnListener returns a net.Listener that returns c on its +// first Accept and EOF thereafter. +// +// The returned Listener's Addr method returns addr if non-nil. If nil, +// Addr returns a non-nil dummy address instead. +func NewOneConnListener(c net.Conn, addr net.Addr) net.Listener { + if addr == nil { + addr = dummyAddr("one-conn-listener") + } + return &oneConnListener{ + addr: addr, + conn: c, + } } type oneConnListener struct { + addr net.Addr + + mu sync.Mutex conn net.Conn - net.Listener } -func (l *oneConnListener) Accept() (c net.Conn, err error) { - c = l.conn +func (ln *oneConnListener) Accept() (c net.Conn, err error) { + ln.mu.Lock() + defer ln.mu.Unlock() + c = ln.conn if c == nil { err = io.EOF return } err = nil - l.conn = nil + ln.conn = nil return } +func (ln *oneConnListener) Addr() net.Addr { return ln.addr } + +func (ln *oneConnListener) Close() error { + ln.Accept() // guarantee future call returns io.EOF + return nil +} + type dummyListener struct{} func (dummyListener) Close() error { return nil } diff --git a/net/netutil/netutil_test.go b/net/netutil/netutil_test.go new file mode 100644 index 000000000..5deb0d94d --- /dev/null +++ b/net/netutil/netutil_test.go @@ -0,0 +1,54 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netutil + +import ( + "io" + "net" + "testing" +) + +type conn struct { + net.Conn +} + +func TestOneConnListener(t *testing.T) { + c1 := new(conn) + a1 := dummyAddr("a1") + + // Two Accepts + ln := NewOneConnListener(c1, a1) + if got := ln.Addr(); got != a1 { + t.Errorf("Addr = %#v; want %#v", got, a1) + } + c, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + if c != c1 { + t.Fatalf("didn't get c1; got %p", c) + } + c, err = ln.Accept() + if err != io.EOF { + t.Errorf("got %v; want EOF", err) + } + if c != nil { + t.Errorf("unexpected non-nil Conn") + } + + // Close before Accept + ln = NewOneConnListener(c1, a1) + ln.Close() + _, err = ln.Accept() + if err != io.EOF { + t.Fatalf("got %v; want EOF", err) + } + + // Implicit addr + ln = NewOneConnListener(c1, nil) + if ln.Addr() == nil { + t.Errorf("nil Addr") + } +}