diff --git a/net/udprelay/server.go b/net/udprelay/server.go index e7ca24960..a3df60143 100644 --- a/net/udprelay/server.go +++ b/net/udprelay/server.go @@ -15,6 +15,7 @@ import ( "fmt" "net" "net/netip" + "runtime" "slices" "strconv" "sync" @@ -66,10 +67,10 @@ type Server struct { bindLifetime time.Duration steadyStateLifetime time.Duration bus *eventbus.Bus - uc4 batching.Conn // always non-nil - uc4Port uint16 // always nonzero - uc6 batching.Conn // may be nil if IPv6 bind fails during initialization - uc6Port uint16 // may be zero if IPv6 bind fails during initialization + uc4 []batching.Conn // length is always nonzero + uc4Port uint16 // always nonzero + uc6 []batching.Conn // length may be zero if udp6 bind fails + uc6Port uint16 // zero if len(uc6) is zero, otherwise nonzero closeOnce sync.Once wg sync.WaitGroup closeCh chan struct{} @@ -337,37 +338,51 @@ func NewServer(logf logger.Logf, port uint16, onlyStaticAddrPorts bool) (s *Serv Logf: logger.WithPrefix(logf, "netcheck: "), SendPacket: func(b []byte, addrPort netip.AddrPort) (int, error) { if addrPort.Addr().Is4() { - return s.uc4.WriteToUDPAddrPort(b, addrPort) - } else if s.uc6 != nil { - return s.uc6.WriteToUDPAddrPort(b, addrPort) + return s.uc4[0].WriteToUDPAddrPort(b, addrPort) + } else if len(s.uc6) > 0 { + return s.uc6[0].WriteToUDPAddrPort(b, addrPort) } else { return 0, errors.New("IPv6 socket is not bound") } }, } - err = s.listenOn(port) + err = s.bindSockets(port) if err != nil { return nil, err } + s.startPacketReaders() if !s.onlyStaticAddrPorts { s.wg.Add(1) go s.addrDiscoveryLoop() } - s.wg.Add(1) - go s.packetReadLoop(s.uc4, s.uc6, true) - if s.uc6 != nil { - s.wg.Add(1) - go s.packetReadLoop(s.uc6, s.uc4, false) - } s.wg.Add(1) go s.endpointGCLoop() return s, nil } +func (s *Server) startPacketReaders() { + for i, uc := range s.uc4 { + var other batching.Conn + if len(s.uc6) > 0 { + other = s.uc6[min(len(s.uc6)-1, i)] + } + s.wg.Add(1) + go s.packetReadLoop(uc, other, true) + } + for i, uc := range s.uc6 { + var other batching.Conn + if len(s.uc4) > 0 { + other = s.uc4[min(len(s.uc4)-1, i)] + } + s.wg.Add(1) + go s.packetReadLoop(uc, other, false) + } +} + func (s *Server) addrDiscoveryLoop() { defer s.wg.Done() @@ -514,70 +529,108 @@ func trySetUDPSocketOptions(pconn nettype.PacketConn, logf logger.Logf) { } } -// listenOn binds an IPv4 and IPv6 socket to port. We consider it successful if -// we manage to bind the IPv4 socket. +// bindSockets binds udp4 and udp6 sockets to desiredPort. We consider it +// successful if we manage to bind at least one udp4 socket. Multiple sockets +// may be bound per address family, e.g. SO_REUSEPORT, depending on platform. // -// The requested port may be zero, in which case port selection is left up to -// the host networking stack. We make no attempt to bind a consistent port -// across IPv4 and IPv6 if the requested port is zero. +// desiredPort may be zero, in which case port selection is left up to the host +// networking stack. We make no attempt to bind a consistent port between udp4 +// and udp6 if the requested port is zero, but a consistent port is used +// across multiple sockets within a given address family if SO_REUSEPORT is +// supported. // // TODO: make these "re-bindable" in similar fashion to magicsock as a means to // deal with EDR software closing them. http://go/corp/30118. We could re-use // [magicsock.RebindingConn], which would also remove the need for // [singlePacketConn], as [magicsock.RebindingConn] also handles fallback to // single packet syscall operations. -func (s *Server) listenOn(port uint16) error { +func (s *Server) bindSockets(desiredPort uint16) error { + // maxSocketsPerAF is a conservative starting point, but is somewhat + // arbitrary. + const maxSocketsPerAF = 16 + listenConfig := &net.ListenConfig{ + Control: listenControl, + } for _, network := range []string{"udp4", "udp6"} { - uc, err := net.ListenUDP(network, &net.UDPAddr{Port: int(port)}) - if err != nil { + SocketsLoop: + for i := 0; i < maxSocketsPerAF && i < runtime.NumCPU(); i++ { + if i > 0 { + // Use a consistent port per address family if the user-supplied + // port was zero, and we are binding multiple sockets. + if network == "udp4" { + desiredPort = s.uc4Port + } else { + desiredPort = s.uc6Port + } + } + uc, boundPort, err := s.bindSocketTo(listenConfig, network, desiredPort) + if err != nil { + switch { + case i == 0 && network == "udp4": + // At least one udp4 socket is required. + return err + case i == 0 && network == "udp6": + // A udp6 socket is not required. + s.logf("ignoring IPv6 bind failure: %v", err) + break SocketsLoop + default: // i > 0 + // Reusable sockets are not required. + s.logf("ignoring reusable (index=%d network=%v) socket bind failure: %v", i, network, err) + break SocketsLoop + } + } + pc := batching.TryUpgradeToConn(uc, network, batching.IdealBatchSize) + bc, ok := pc.(batching.Conn) + if !ok { + bc = &singlePacketConn{uc} + } if network == "udp4" { - return err + s.uc4 = append(s.uc4, bc) + s.uc4Port = boundPort } else { - s.logf("ignoring IPv6 bind failure: %v", err) - break + s.uc6 = append(s.uc6, bc) + s.uc6Port = boundPort } - } - trySetUDPSocketOptions(uc, s.logf) - // TODO: set IP_PKTINFO sockopt - _, boundPortStr, err := net.SplitHostPort(uc.LocalAddr().String()) - if err != nil { - uc.Close() - if s.uc4 != nil { - s.uc4.Close() - } - return err - } - portUint, err := strconv.ParseUint(boundPortStr, 10, 16) - if err != nil { - uc.Close() - if s.uc4 != nil { - s.uc4.Close() + if !isReusableSocket(uc) { + break } - return err - } - pc := batching.TryUpgradeToConn(uc, network, batching.IdealBatchSize) - bc, ok := pc.(batching.Conn) - if !ok { - bc = &singlePacketConn{uc} - } - if network == "udp4" { - s.uc4 = bc - s.uc4Port = uint16(portUint) - } else { - s.uc6 = bc - s.uc6Port = uint16(portUint) } - s.logf("listening on %s:%d", network, portUint) + } + s.logf("listening on udp4:%d sockets=%d", s.uc4Port, len(s.uc4)) + if len(s.uc6) > 0 { + s.logf("listening on udp6:%d sockets=%d", s.uc6Port, len(s.uc6)) } return nil } +func (s *Server) bindSocketTo(listenConfig *net.ListenConfig, network string, port uint16) (*net.UDPConn, uint16, error) { + lis, err := listenConfig.ListenPacket(context.Background(), network, fmt.Sprintf(":%d", port)) + if err != nil { + return nil, 0, err + } + uc := lis.(*net.UDPConn) + trySetUDPSocketOptions(uc, s.logf) + _, boundPortStr, err := net.SplitHostPort(uc.LocalAddr().String()) + if err != nil { + uc.Close() + return nil, 0, err + } + portUint, err := strconv.ParseUint(boundPortStr, 10, 16) + if err != nil { + uc.Close() + return nil, 0, err + } + return uc, uint16(portUint), nil +} + // Close closes the server. func (s *Server) Close() error { s.closeOnce.Do(func() { - s.uc4.Close() - if s.uc6 != nil { - s.uc6.Close() + for _, uc4 := range s.uc4 { + uc4.Close() + } + for _, uc6 := range s.uc6 { + uc6.Close() } close(s.closeCh) s.wg.Wait() diff --git a/net/udprelay/server_linux.go b/net/udprelay/server_linux.go new file mode 100644 index 000000000..009ec8cc8 --- /dev/null +++ b/net/udprelay/server_linux.go @@ -0,0 +1,35 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package udprelay + +import ( + "net" + "syscall" + + "golang.org/x/sys/unix" +) + +func listenControl(_ string, _ string, c syscall.RawConn) error { + c.Control(func(fd uintptr) { + unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) + }) + return nil +} + +func isReusableSocket(uc *net.UDPConn) bool { + rc, err := uc.SyscallConn() + if err != nil { + return false + } + var reusable bool + rc.Control(func(fd uintptr) { + val, err := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT) + if err == nil && val == 1 { + reusable = true + } + }) + return reusable +} diff --git a/net/udprelay/server_notlinux.go b/net/udprelay/server_notlinux.go new file mode 100644 index 000000000..042a6dd68 --- /dev/null +++ b/net/udprelay/server_notlinux.go @@ -0,0 +1,19 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package udprelay + +import ( + "net" + "syscall" +) + +func listenControl(_ string, _ string, _ syscall.RawConn) error { + return nil +} + +func isReusableSocket(*net.UDPConn) bool { + return false +}