tsnet: avoid deadlock on close

tsnet.Server.Close was calling listener.Close with the server mutex
held, but the listener close method tries to grab that mutex, resulting
in a deadlock.

Co-authored-by: David Crawshaw <crawshaw@tailscale.com>
Signed-off-by: Maisem Ali <maisem@tailscale.com>
pull/7562/head
Maisem Ali 2 years ago committed by Maisem Ali
parent 2b892ad6e7
commit b4d3e2928b

@ -118,6 +118,7 @@ type Server struct {
mu sync.Mutex mu sync.Mutex
listeners map[listenKey]*listener listeners map[listenKey]*listener
dialer *tsdial.Dialer dialer *tsdial.Dialer
closed bool
} }
// Dial connects to the address on the tailnet. // Dial connects to the address on the tailnet.
@ -303,6 +304,11 @@ func (s *Server) Up(ctx context.Context) (*ipnstate.Status, error) {
// //
// It must not be called before or concurrently with Start. // It must not be called before or concurrently with Start.
func (s *Server) Close() error { func (s *Server) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.closed {
return fmt.Errorf("tsnet: %w", net.ErrClosed)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
var wg sync.WaitGroup var wg sync.WaitGroup
@ -350,14 +356,12 @@ func (s *Server) Close() error {
s.loopbackListener.Close() s.loopbackListener.Close()
} }
s.mu.Lock()
defer s.mu.Unlock()
for _, ln := range s.listeners { for _, ln := range s.listeners {
ln.Close() ln.closeLocked()
} }
s.listeners = nil
wg.Wait() wg.Wait()
s.closed = true
return nil return nil
} }
@ -1017,10 +1021,11 @@ type listenKey struct {
} }
type listener struct { type listener struct {
s *Server s *Server
keys []listenKey keys []listenKey
addr string addr string
conn chan net.Conn conn chan net.Conn
closed bool // guarded by s.mu
} }
func (ln *listener) Accept() (net.Conn, error) { func (ln *listener) Accept() (net.Conn, error) {
@ -1032,15 +1037,26 @@ func (ln *listener) Accept() (net.Conn, error) {
} }
func (ln *listener) Addr() net.Addr { return addr{ln} } func (ln *listener) Addr() net.Addr { return addr{ln} }
func (ln *listener) Close() error { func (ln *listener) Close() error {
ln.s.mu.Lock() ln.s.mu.Lock()
defer ln.s.mu.Unlock() defer ln.s.mu.Unlock()
return ln.closeLocked()
}
// closeLocked closes the listener.
// It must be called with ln.s.mu held.
func (ln *listener) closeLocked() error {
if ln.closed {
return fmt.Errorf("tsnet: %w", net.ErrClosed)
}
for _, key := range ln.keys { for _, key := range ln.keys {
if v, ok := ln.s.listeners[key]; ok && v == ln { if v, ok := ln.s.listeners[key]; ok && v == ln {
delete(ln.s.listeners, key) delete(ln.s.listeners, key)
} }
} }
close(ln.conn) close(ln.conn)
ln.closed = true
return nil return nil
} }

@ -9,6 +9,7 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/netip" "net/netip"
@ -344,3 +345,26 @@ func TestTailscaleIPs(t *testing.T) {
sIp4, upIp4, sIp6, upIp6) sIp4, upIp4, sIp6, upIp6)
} }
} }
// TestListenerCleanup is a regression test to verify that s.Close doesn't
// deadlock if a listener is still open.
func TestListenerCleanup(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
controlURL := startControl(t)
s1, _ := startServer(t, ctx, controlURL, "s1")
ln, err := s1.Listen("tcp", ":8081")
if err != nil {
t.Fatal(err)
}
if err := s1.Close(); err != nil {
t.Fatal(err)
}
if err := ln.Close(); !errors.Is(err, net.ErrClosed) {
t.Fatalf("second ln.Close error: %v, want net.ErrClosed", err)
}
}

Loading…
Cancel
Save