From bb94561c96a4a8f4089acadbdda58aa8f0e83067 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 22 Feb 2022 13:29:17 -0800 Subject: [PATCH] net/netutil: fix regression where peerapi would get closed after 1st req I introduced a bug in 8fe503057da26 when unifying oneConnListener implementations. The NewOneConnListenerFrom API was easy to misuse (its Close method closes the underlying Listener), and we did (via http.Serve, which closes the listener after use, which meant we were close the peerapi's listener, even though we only wanted its Addr) Instead, combine those two constructors into one and pass in the Addr explicitly, without delegating through to any Listener. Change-Id: I061d7e5f842e0cada416e7b2dd62100d4f987125 Signed-off-by: Brad Fitzpatrick --- ipn/ipnlocal/peerapi.go | 4 +-- ipn/ipnserver/server.go | 2 +- net/netutil/netutil.go | 43 +++++++++++++++++++---------- net/netutil/netutil_test.go | 54 +++++++++++++++++++++++++++++++++++++ 4 files changed, 86 insertions(+), 17 deletions(-) create mode 100644 net/netutil/netutil_test.go diff --git a/ipn/ipnlocal/peerapi.go b/ipn/ipnlocal/peerapi.go index b720ce08b..36770d7ab 100644 --- a/ipn/ipnlocal/peerapi.go +++ b/ipn/ipnlocal/peerapi.go @@ -411,7 +411,7 @@ func (s *peerAPIServer) listen(ip netaddr.IP, ifState *interfaces.State) (ln net } // Make a best effort to pick a deterministic port number for - // the ip The lower three bytes are the same for IPv4 and IPv6 + // the ip. The lower three bytes are the same for IPv4 and IPv6 // Tailscale addresses (at least currently), so we'll usually // get the same port number on both address families for // dev/debugging purposes, which is nice. But it's not so @@ -507,7 +507,7 @@ func (pln *peerAPIListener) ServeConn(src netaddr.IPPort, c net.Conn) { if addH2C != nil { addH2C(httpServer) } - go httpServer.Serve(netutil.NewOneConnListenerFrom(c, pln.ln)) + go httpServer.Serve(netutil.NewOneConnListener(c, pln.ln.Addr())) } // peerAPIHandler serves the Peer API for a source specific client. diff --git a/ipn/ipnserver/server.go b/ipn/ipnserver/server.go index 522fead66..805c00c3a 100644 --- a/ipn/ipnserver/server.go +++ b/ipn/ipnserver/server.go @@ -312,7 +312,7 @@ func (s *Server) serveConn(ctx context.Context, c net.Conn, logf logger.Logf) { ErrorLog: logger.StdLogger(logf), Handler: s.localhostHandler(ci), } - httpServer.Serve(netutil.NewOneConnListener(&protoSwitchConn{s: s, br: br, Conn: c})) + httpServer.Serve(netutil.NewOneConnListener(&protoSwitchConn{s: s, br: br, Conn: c}, nil)) return } diff --git a/net/netutil/netutil.go b/net/netutil/netutil.go index 49a2aec64..12d61122e 100644 --- a/net/netutil/netutil.go +++ b/net/netutil/netutil.go @@ -8,36 +8,51 @@ package netutil import ( "io" "net" + "sync" ) -// NewOneConnListener returns a net.Listener that returns c on its first -// Accept and EOF thereafter. The Listener's Addr is a dummy address. -func NewOneConnListener(c net.Conn) net.Listener { - return NewOneConnListenerFrom(c, dummyListener{}) -} - -// NewOneConnListenerFrom returns a net.Listener wrapping ln where -// its Accept returns c on the first call and io.EOF thereafter. -func NewOneConnListenerFrom(c net.Conn, ln net.Listener) net.Listener { - return &oneConnListener{c, ln} +// NewOneConnListener returns a net.Listener that returns c on its +// first Accept and EOF thereafter. +// +// The returned Listener's Addr method returns addr if non-nil. If nil, +// Addr returns a non-nil dummy address instead. +func NewOneConnListener(c net.Conn, addr net.Addr) net.Listener { + if addr == nil { + addr = dummyAddr("one-conn-listener") + } + return &oneConnListener{ + addr: addr, + conn: c, + } } type oneConnListener struct { + addr net.Addr + + mu sync.Mutex conn net.Conn - net.Listener } -func (l *oneConnListener) Accept() (c net.Conn, err error) { - c = l.conn +func (ln *oneConnListener) Accept() (c net.Conn, err error) { + ln.mu.Lock() + defer ln.mu.Unlock() + c = ln.conn if c == nil { err = io.EOF return } err = nil - l.conn = nil + ln.conn = nil return } +func (ln *oneConnListener) Addr() net.Addr { return ln.addr } + +func (ln *oneConnListener) Close() error { + ln.Accept() // guarantee future call returns io.EOF + return nil +} + type dummyListener struct{} func (dummyListener) Close() error { return nil } diff --git a/net/netutil/netutil_test.go b/net/netutil/netutil_test.go new file mode 100644 index 000000000..5deb0d94d --- /dev/null +++ b/net/netutil/netutil_test.go @@ -0,0 +1,54 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netutil + +import ( + "io" + "net" + "testing" +) + +type conn struct { + net.Conn +} + +func TestOneConnListener(t *testing.T) { + c1 := new(conn) + a1 := dummyAddr("a1") + + // Two Accepts + ln := NewOneConnListener(c1, a1) + if got := ln.Addr(); got != a1 { + t.Errorf("Addr = %#v; want %#v", got, a1) + } + c, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + if c != c1 { + t.Fatalf("didn't get c1; got %p", c) + } + c, err = ln.Accept() + if err != io.EOF { + t.Errorf("got %v; want EOF", err) + } + if c != nil { + t.Errorf("unexpected non-nil Conn") + } + + // Close before Accept + ln = NewOneConnListener(c1, a1) + ln.Close() + _, err = ln.Accept() + if err != io.EOF { + t.Fatalf("got %v; want EOF", err) + } + + // Implicit addr + ln = NewOneConnListener(c1, nil) + if ln.Addr() == nil { + t.Errorf("nil Addr") + } +}