diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index afeeb8323..45651699f 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -1581,15 +1581,10 @@ func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, error) { return 0, nil, syscall.EAFNOSUPPORT } for { - n, pAddr, err := c.pconn6.ReadFrom(b) + n, ipp, err := c.pconn6.ReadFromNetaddr(b) if err != nil { return 0, nil, err } - udpAddr := pAddr.(*net.UDPAddr) - ipp, ok := netaddr.FromStdAddr(udpAddr.IP, udpAddr.Port, udpAddr.Zone) - if !ok { - continue - } if ep, ok := c.receiveIP(b[:n], ipp, &c.ippEndpoint6); ok { return n, ep, nil } @@ -1604,16 +1599,13 @@ func (c *Conn) derpPacketArrived() bool { // In Tailscale's case, that packet might also arrive via DERP. A DERP packet arrival // aborts the pconn4 read deadline to make it fail. func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) { - var addr net.Addr - var pAddr *net.UDPAddr var ipp netaddr.IPPort - var ippOK bool for { // Drain DERP queues before reading new UDP packets. if c.derpPacketArrived() { goto ReadDERP } - n, addr, err = c.pconn4.ReadFrom(b) + n, ipp, err = c.pconn4.ReadFromNetaddr(b) if err != nil { // If the pconn4 read failed, the likely reason is a DERP reader received // a packet and interrupted us. @@ -1625,11 +1617,6 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) { } return 0, nil, err } - pAddr, _ = addr.(*net.UDPAddr) - ipp, ippOK = netaddr.FromStdAddr(pAddr.IP, pAddr.Port, pAddr.Zone) - if !ippOK { - continue - } if ep, ok := c.receiveIP(b[:n], ipp, &c.ippEndpoint4); ok { return n, ep, nil } else { @@ -2743,6 +2730,8 @@ func (c *RebindingUDPConn) Reset(pconn net.PacketConn) { } } +// ReadFromNetaddr reads a packet from c into b. +// It returns the number of bytes copied and the source address. func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { for { c.mu.Lock() @@ -2763,6 +2752,57 @@ func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { } } +// ReadFromNetaddr reads a packet from c into b. +// It returns the number of bytes copied and the return address. +// It is identical to c.ReadFrom, except that it returns a netaddr.IPPort instead of a net.Addr. +// ReadFromNetaddr is designed to work with specific underlying connection types. +// If c's underlying connection returns a non-*net.UPDAddr return address, ReadFromNetaddr will return an error. +// ReadFromNetaddr exists because it removes an allocation per read, +// when c's underlying connection is a net.UDPConn. +func (c *RebindingUDPConn) ReadFromNetaddr(b []byte) (n int, ipp netaddr.IPPort, err error) { + for { + c.mu.Lock() + pconn := c.pconn + c.mu.Unlock() + + // Optimization: Treat *net.UDPConn specially. + // ReadFromUDP gets partially inlined, avoiding allocating a *net.UDPAddr, + // as long as pAddr itself doesn't escape. + // The non-*net.UDPConn case works, but it allocates. + var pAddr *net.UDPAddr + if udpConn, ok := pconn.(*net.UDPConn); ok { + n, pAddr, err = udpConn.ReadFromUDP(b) + } else { + var addr net.Addr + n, addr, err = pconn.ReadFrom(b) + var ok2 bool + pAddr, ok2 = addr.(*net.UDPAddr) + if !ok2 { + return 0, netaddr.IPPort{}, fmt.Errorf("RebindingUDPConn.ReadFromNetaddr: underlying connection returned address of type %T, want *netaddr.UDPAddr", addr) + } + } + + if err != nil { + c.mu.Lock() + pconn2 := c.pconn + c.mu.Unlock() + + if pconn != pconn2 { + continue + } + } else { + // Convert pAddr to a netaddr.IPPort. + // This prevents pAddr from escaping. + var ok bool + ipp, ok = netaddr.FromStdAddr(pAddr.IP, pAddr.Port, pAddr.Zone) + if !ok { + return 0, netaddr.IPPort{}, errors.New("netaddr.FromStdAddr failed") + } + } + return n, ipp, err + } +} + func (c *RebindingUDPConn) LocalAddr() *net.UDPAddr { c.mu.Lock() defer c.mu.Unlock() diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index ea00fe22e..4890ea70a 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -1512,16 +1512,16 @@ func addTestEndpoint(conn *Conn, sendConn net.PacketConn) (tailcfg.NodeKey, tail return nodeKey, discoKey } -func BenchmarkReceiveFrom(b *testing.B) { - conn := newNonLegacyTestConn(b) - defer conn.Close() +func setUpReceiveFrom(tb testing.TB) (roundTrip func()) { + conn := newNonLegacyTestConn(tb) + tb.Cleanup(func() { conn.Close() }) conn.logf = logger.Discard sendConn, err := net.ListenPacket("udp4", "127.0.0.1:0") if err != nil { - b.Fatal(err) + tb.Fatal(err) } - defer sendConn.Close() + tb.Cleanup(func() { sendConn.Close() }) addTestEndpoint(conn, sendConn) @@ -1530,21 +1530,89 @@ func BenchmarkReceiveFrom(b *testing.B) { for i := range sendBuf { sendBuf[i] = 'x' } - buf := make([]byte, 2<<10) - for i := 0; i < b.N; i++ { + return func() { if _, err := sendConn.WriteTo(sendBuf, dstAddr); err != nil { - b.Fatalf("WriteTo: %v", err) + tb.Fatalf("WriteTo: %v", err) } n, ep, err := conn.ReceiveIPv4(buf) if err != nil { - b.Fatal(err) + tb.Fatal(err) } _ = n _ = ep } } +// goMajorVersion reports the major Go version and whether it is a Tailscale fork. +// If parsing fails, goMajorVersion returns 0, false. +func goMajorVersion(s string) (version int, isTS bool) { + if !strings.HasPrefix(s, "go1.") { + return 0, false + } + mm := s[len("go1."):] + var major, rest string + for _, sep := range []string{".", "rc", "beta"} { + i := strings.Index(mm, sep) + if i > 0 { + major, rest = mm[:i], mm[i:] + break + } + } + if major == "" { + major = mm + } + n, err := strconv.Atoi(major) + if err != nil { + return 0, false + } + return n, strings.Contains(rest, "ts") +} + +func TestGoMajorVersion(t *testing.T) { + tests := []struct { + version string + wantN int + wantTS bool + }{ + {"go1.15.8", 15, false}, + {"go1.16rc1", 16, false}, + {"go1.16rc1", 16, false}, + {"go1.15.5-ts3bd89195a3", 15, true}, + {"go1.15", 15, false}, + } + + for _, tt := range tests { + n, ts := goMajorVersion(tt.version) + if tt.wantN != n || tt.wantTS != ts { + t.Errorf("goMajorVersion(%s) = %v, %v, want %v, %v", tt.version, n, ts, tt.wantN, tt.wantTS) + } + } +} + +func TestReceiveFromAllocs(t *testing.T) { + // Go 1.16 and before: allow 3 allocs. + // Go Tailscale fork, Go 1.17+: only allow 2 allocs. + major, ts := goMajorVersion(runtime.Version()) + maxAllocs := 3 + if major >= 17 || ts { + maxAllocs = 2 + } + t.Logf("allowing %d allocs for Go version %q", maxAllocs, runtime.Version()) + roundTrip := setUpReceiveFrom(t) + avg := int(testing.AllocsPerRun(100, roundTrip)) + if avg > maxAllocs { + t.Fatalf("expected %d allocs in ReceiveFrom, got %v", maxAllocs, avg) + } +} + +func BenchmarkReceiveFrom(b *testing.B) { + roundTrip := setUpReceiveFrom(b) + for i := 0; i < b.N; i++ { + roundTrip() + } +} + func BenchmarkReceiveFrom_Native(b *testing.B) { recvConn, err := net.ListenPacket("udp4", "127.0.0.1:0") if err != nil {