diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index b5bfe9496..9916604e1 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -139,6 +139,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/util/pidowner from tailscale.com/ipn/ipnserver tailscale.com/util/racebuild from tailscale.com/logpolicy tailscale.com/util/systemd from tailscale.com/control/controlclient+ + tailscale.com/util/uniq from tailscale.com/wgengine/magicsock tailscale.com/util/winutil from tailscale.com/logpolicy+ tailscale.com/version from tailscale.com/cmd/tailscaled+ tailscale.com/version/distro from tailscale.com/control/controlclient+ diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index e1eff1b24..e6e722ef2 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -52,6 +52,7 @@ import ( "tailscale.com/types/netmap" "tailscale.com/types/nettype" "tailscale.com/types/wgkey" + "tailscale.com/util/uniq" "tailscale.com/version" "tailscale.com/wgengine/monitor" "tailscale.com/wgengine/wgcfg" @@ -2585,11 +2586,11 @@ func (c *Conn) ReSTUN(why string) { } func (c *Conn) initialBind() error { - if err := c.bind1(&c.pconn4, "udp4"); err != nil { - return err + if err := c.bindSocket(&c.pconn4, "udp4"); err != nil { + return fmt.Errorf("magicsock: initialBind IPv4 failed: %w", err) } c.portMapper.SetLocalPort(c.LocalPort()) - if err := c.bind1(&c.pconn6, "udp6"); err != nil { + if err := c.bindSocket(&c.pconn6, "udp6"); err != nil { c.logf("magicsock: ignoring IPv6 bind failure: %v", err) } return nil @@ -2605,66 +2606,82 @@ func (c *Conn) listenPacket(network, host string, port uint16) (net.PacketConn, return netns.Listener().ListenPacket(ctx, network, addr) } -func (c *Conn) bind1(ruc **RebindingUDPConn, which string) error { +// bindSocket initializes rucPtr if necessary and binds a UDP socket to it. +// Network indicates the UDP socket type; it must be "udp4" or "udp6". +// If rucPtr had an existing UDP socket bound, it closes that socket. +// The caller is responsible for informing the portMapper of any changes. +func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string) error { host := "" if inTest() && !c.simulatedNetwork { host = "127.0.0.1" - if which == "udp6" { + if network == "udp6" { host = "::1" } } - pc, err := c.listenPacket(which, host, c.port) - if err != nil { - c.logf("magicsock: bind(%s/%v): %v", which, c.port, err) - return fmt.Errorf("magicsock: bind: %s/%d: %v", which, c.port, err) - } - if *ruc == nil { - *ruc = new(RebindingUDPConn) - } - (*ruc).Reset(pc) - return nil -} -// Rebind closes and re-binds the UDP sockets. -// It should be followed by a call to ReSTUN. -func (c *Conn) Rebind() { - host := "" - if inTest() && !c.simulatedNetwork { - host = "127.0.0.1" + if *rucPtr == nil { + *rucPtr = new(RebindingUDPConn) } + ruc := *rucPtr + + // Hold the ruc lock the entire time, so that the close+bind is atomic + // from the perspective of ruc receive functions. + ruc.mu.Lock() + defer ruc.mu.Unlock() + // Build a list of preferred ports. + // Best is the port that the user requested. + // Second best is the port that is currently in use. + // If those fail, fall back to 0. + var ports []uint16 if c.port != 0 { - c.pconn4.mu.Lock() - oldPort := c.pconn4.localAddrLocked().Port - if err := c.pconn4.pconn.Close(); err != nil { - c.logf("magicsock: link change close failed: %v", err) - } - packetConn, err := c.listenPacket("udp4", host, c.port) - if err != nil { - c.logf("magicsock: link change unable to bind fixed port %d: %v, falling back to random port", c.port, err) - packetConn, err = c.listenPacket("udp4", host, 0) - if err != nil { - c.logf("magicsock: link change failed to bind random port: %v", err) - c.pconn4.mu.Unlock() - return - } - newPort := packetConn.LocalAddr().(*net.UDPAddr).Port - c.logf("magicsock: link change rebound port: from %v to %v (failed to get %v)", oldPort, newPort, c.port) - } else { - c.logf("magicsock: link change rebound port from %d to %d", oldPort, c.port) + ports = append(ports, c.port) + } + if ruc.pconn != nil { + curPort := uint16(ruc.localAddrLocked().Port) + ports = append(ports, curPort) + } + ports = append(ports, 0) + // Remove duplicates. (All duplicates are consecutive.) + uniq.ModifySlice(&ports, func(i, j int) bool { return ports[i] == ports[j] }) + + var pconn net.PacketConn + for _, port := range ports { + // Close the existing conn, in case it is sitting on the port we want. + err := ruc.closeLocked() + if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, errNilPConn) { + c.logf("magicsock: bindSocket %v close failed: %v", network, err) } - c.pconn4.pconn = packetConn - c.pconn4.mu.Unlock() - } else { - c.logf("magicsock: link change, binding new port") - packetConn, err := c.listenPacket("udp4", host, 0) + // Open a new one with the desired port. + pconn, err = c.listenPacket(network, host, port) if err != nil { - c.logf("magicsock: link change failed to bind new port: %v", err) - return + c.logf("magicsock: unable to bind %v port %d: %v", network, port, err) + continue } - c.pconn4.Reset(packetConn) + // Success. + ruc.pconn = pconn + return nil + } + + // Failed to bind, including on port 0 (!). + // Set pconn to a dummy conn whose reads block until closed. + // This keeps the receive funcs alive for a future in which + // we get a link change and we can try binding again. + ruc.pconn = newBlockForeverConn() + return fmt.Errorf("failed to bind any ports (tried %v)", ports) +} + +// Rebind closes and re-binds the UDP sockets. +// It should be followed by a call to ReSTUN. +func (c *Conn) Rebind() { + if err := c.bindSocket(&c.pconn4, "udp4"); err != nil { + c.logf("magicsock: Rebind IPv4 failed: %w", err) + return } c.portMapper.SetLocalPort(c.LocalPort()) + if err := c.bindSocket(&c.pconn6, "udp6"); err != nil { + c.logf("magicsock: Rebind ignoring IPv6 bind failure: %v", err) + } c.mu.Lock() c.closeAllDerpLocked("rebind") @@ -2764,17 +2781,6 @@ func (c *RebindingUDPConn) currentConn() net.PacketConn { return c.pconn } -func (c *RebindingUDPConn) Reset(pconn net.PacketConn) { - c.mu.Lock() - old := c.pconn - c.pconn = pconn - c.mu.Unlock() - - if old != nil { - old.Close() - } -} - // ReadFrom reads a packet from c into b. // It returns the number of bytes copied and the source address. func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { @@ -2844,9 +2850,20 @@ func (c *RebindingUDPConn) localAddrLocked() *net.UDPAddr { return c.pconn.LocalAddr().(*net.UDPAddr) } +// errNilPConn is returned by RebindingUDPConn.Close when there is no current pconn. +// It is for internal use only and should not be returned to users. +var errNilPConn = errors.New("nil pconn") + func (c *RebindingUDPConn) Close() error { c.mu.Lock() defer c.mu.Unlock() + return c.closeLocked() +} + +func (c *RebindingUDPConn) closeLocked() error { + if c.pconn == nil { + return errNilPConn + } return c.pconn.Close() } @@ -2890,6 +2907,52 @@ func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { } } +func newBlockForeverConn() *blockForeverConn { + c := new(blockForeverConn) + c.cond = sync.NewCond(&c.mu) + return c +} + +// blockForeverConn is a net.PacketConn whose reads block until it is closed. +type blockForeverConn struct { + mu sync.Mutex + cond *sync.Cond + closed bool +} + +func (c *blockForeverConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + c.mu.Lock() + for !c.closed { + c.cond.Wait() + } + c.mu.Unlock() + return 0, nil, net.ErrClosed +} + +func (c *blockForeverConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + // Silently drop writes. + return len(p), nil +} + +func (c *blockForeverConn) LocalAddr() net.Addr { + // Return a *net.UDPAddr because lots of code assumes that it will. + return new(net.UDPAddr) +} + +func (c *blockForeverConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return net.ErrClosed + } + c.closed = true + return nil +} + +func (c *blockForeverConn) SetDeadline(t time.Time) error { return errors.New("unimplemented") } +func (c *blockForeverConn) SetReadDeadline(t time.Time) error { return errors.New("unimplemented") } +func (c *blockForeverConn) SetWriteDeadline(t time.Time) error { return errors.New("unimplemented") } + // simpleDur rounds d such that it stringifies to something short. func simpleDur(d time.Duration) time.Duration { if d < time.Second {