diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go index f2573f504..bd587c7b7 100644 --- a/cmd/derper/derper.go +++ b/cmd/derper/derper.go @@ -19,6 +19,7 @@ import ( "math" "net" "net/http" + "net/netip" "os" "path/filepath" "regexp" @@ -356,7 +357,8 @@ func serverSTUNListener(ctx context.Context, pc *net.UDPConn) { } else { stunIPv6.Add(1) } - res := stun.Response(txid, ua.IP, uint16(ua.Port)) + addr, _ := netip.AddrFromSlice(ua.IP) + res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(ua.Port))) _, err = pc.WriteTo(res, ua) if err != nil { stunWriteError.Add(1) diff --git a/cmd/derpprobe/derpprobe.go b/cmd/derpprobe/derpprobe.go index 07660f8c7..9175f001a 100644 --- a/cmd/derpprobe/derpprobe.go +++ b/cmd/derpprobe/derpprobe.go @@ -360,7 +360,7 @@ func probeUDP(ctx context.Context, dm *tailcfg.DERPMap, n *tailcfg.DERPNode) (la time.Sleep(100 * time.Millisecond) continue } - txBack, _, _, err := stun.ParseResponse(buf[:n]) + txBack, _, err := stun.ParseResponse(buf[:n]) if err != nil { return 0, fmt.Errorf("parsing STUN response from %v: %v", ip, err) } diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index ec9919f7f..49cdcf514 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -261,7 +261,7 @@ func (c *Client) ReceiveSTUNPacket(pkt []byte, src netip.AddrPort) { return } - tx, addr, port, err := stun.ParseResponse(pkt) + tx, addrPort, err := stun.ParseResponse(pkt) if err != nil { if _, err := stun.ParseBindingRequest(pkt); err == nil { // This was probably our own netcheck hairpin @@ -279,10 +279,7 @@ func (c *Client) ReceiveSTUNPacket(pkt []byte, src netip.AddrPort) { } rs.mu.Unlock() if ok { - ta := net.TCPAddr{IP: addr, Port: int(port)} - if ipp := netaddr.Unmap(ta.AddrPort()); ipp.IsValid() { - onDone(ipp) - } + onDone(addrPort) } } diff --git a/net/stun/stun.go b/net/stun/stun.go index e6077c306..834b9e8c9 100644 --- a/net/stun/stun.go +++ b/net/stun/stun.go @@ -11,6 +11,7 @@ import ( "errors" "hash/crc32" "net" + "net/netip" ) const ( @@ -151,20 +152,18 @@ func foreachAttr(b []byte, fn func(attrType uint16, a []byte) error) error { } // Response generates a binding response. -func Response(txID TxID, ip net.IP, port uint16) []byte { - if ip4 := ip.To4(); ip4 != nil { - ip = ip4 - } +func Response(txID TxID, addrPort netip.AddrPort) []byte { + addr := addrPort.Addr() + var fam byte - switch len(ip) { - case net.IPv4len: + if addr.Is4() { fam = 1 - case net.IPv6len: + } else if addr.Is6() { fam = 2 - default: + } else { return nil } - attrsLen := 8 + len(ip) + attrsLen := 8 + addr.BitLen()/8 b := make([]byte, 0, headerLen+attrsLen) // Header @@ -175,12 +174,13 @@ func Response(txID TxID, ip net.IP, port uint16) []byte { // Attributes (well, one) b = appendU16(b, attrXorMappedAddress) - b = appendU16(b, uint16(4+len(ip))) + b = appendU16(b, uint16(4+addr.BitLen()/8)) b = append(b, 0, // unused byte fam) - b = appendU16(b, port^0x2112) // first half of magicCookie - for i, o := range []byte(ip) { + b = appendU16(b, addrPort.Port()^0x2112) // first half of magicCookie + ipa := addr.As16() + for i, o := range ipa[16-addr.BitLen()/8:] { if i < 4 { b = append(b, o^magicCookie[i]) } else { @@ -192,25 +192,23 @@ func Response(txID TxID, ip net.IP, port uint16) []byte { // ParseResponse parses a successful binding response STUN packet. // The IP address is extracted from the XOR-MAPPED-ADDRESS attribute. -// The returned addr slice is owned by the caller and does not alias b. -func ParseResponse(b []byte) (tID TxID, addr []byte, port uint16, err error) { +func ParseResponse(b []byte) (tID TxID, addr netip.AddrPort, err error) { if !Is(b) { - return tID, nil, 0, ErrNotSTUN + return tID, netip.AddrPort{}, ErrNotSTUN } copy(tID[:], b[8:8+len(tID)]) if b[0] != 0x01 || b[1] != 0x01 { - return tID, nil, 0, ErrNotSuccessResponse + return tID, netip.AddrPort{}, ErrNotSuccessResponse } attrsLen := int(binary.BigEndian.Uint16(b[2:4])) b = b[headerLen:] // remove STUN header if attrsLen > len(b) { - return tID, nil, 0, ErrMalformedAttrs + return tID, netip.AddrPort{}, ErrMalformedAttrs } else if len(b) > attrsLen { b = b[:attrsLen] // trim trailing packet bytes } - var addr6, fallbackAddr, fallbackAddr6 []byte - var port6, fallbackPort, fallbackPort6 uint16 + var addr6, fallbackAddr, fallbackAddr6 netip.AddrPort // Read through the attributes. // The the addr+port reported by XOR-MAPPED-ADDRESS @@ -225,9 +223,9 @@ func ParseResponse(b []byte) (tID TxID, addr []byte, port uint16, err error) { return err } if len(a) == 16 { - addr6, port6 = a, p + addr6 = netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)([]byte(a))), p) } else { - addr, port = a, p + addr = netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)([]byte(a))), p) } case attrMappedAddress: a, p, err := mappedAddress(attr) @@ -235,30 +233,30 @@ func ParseResponse(b []byte) (tID TxID, addr []byte, port uint16, err error) { return ErrMalformedAttrs } if len(a) == 16 { - fallbackAddr6, fallbackPort6 = a, p + fallbackAddr6 = netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)([]byte(a))), p) } else { - fallbackAddr, fallbackPort = a, p + fallbackAddr = netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)([]byte(a))), p) } } return nil }); err != nil { - return TxID{}, nil, 0, err + return TxID{}, netip.AddrPort{}, err } - if addr != nil { - return tID, addr, port, nil + if addr.IsValid() { + return tID, addr, nil } - if fallbackAddr != nil { - return tID, append([]byte{}, fallbackAddr...), fallbackPort, nil + if fallbackAddr.IsValid() { + return tID, fallbackAddr, nil } - if addr6 != nil { - return tID, addr6, port6, nil + if addr6.IsValid() { + return tID, addr6, nil } - if fallbackAddr6 != nil { - return tID, append([]byte{}, fallbackAddr6...), fallbackPort6, nil + if fallbackAddr6.IsValid() { + return tID, fallbackAddr6, nil } - return tID, nil, 0, ErrMalformedAttrs + return tID, netip.AddrPort{}, ErrMalformedAttrs } func xorMappedAddress(tID TxID, b []byte) (addr []byte, port uint16, err error) { diff --git a/net/stun/stun_fuzzer.go b/net/stun/stun_fuzzer.go index a12a03ad4..42a818cd8 100644 --- a/net/stun/stun_fuzzer.go +++ b/net/stun/stun_fuzzer.go @@ -7,7 +7,7 @@ package stun func FuzzStunParser(data []byte) int { - _, _, _, _ = ParseResponse(data) + _, _, _ = ParseResponse(data) _, _ = ParseBindingRequest(data) return 1 diff --git a/net/stun/stun_test.go b/net/stun/stun_test.go index 81b5a8244..4e3f78a0b 100644 --- a/net/stun/stun_test.go +++ b/net/stun/stun_test.go @@ -7,7 +7,7 @@ package stun_test import ( "bytes" "fmt" - "net" + "net/netip" "testing" "tailscale.com/net/stun" @@ -25,7 +25,7 @@ var responseTests = []struct { name string data []byte wantTID []byte - wantAddr []byte + wantAddr netip.Addr wantPort uint16 }{ { @@ -40,7 +40,7 @@ var responseTests = []struct { 0x23, 0x60, 0xb1, 0x1e, 0x3e, 0xc6, 0x8f, 0xfa, 0x93, 0xe0, 0x80, 0x07, }, - wantAddr: []byte{72, 69, 33, 45}, + wantAddr: netip.AddrFrom4([4]byte{72, 69, 33, 45}), wantPort: 59028, }, { @@ -55,7 +55,7 @@ var responseTests = []struct { 0xf9, 0xf1, 0x21, 0xcb, 0xde, 0x7d, 0x7c, 0x75, 0x92, 0x3c, 0xe2, 0x71, }, - wantAddr: []byte{72, 69, 33, 45}, + wantAddr: netip.AddrFrom4([4]byte{72, 69, 33, 45}), wantPort: 59029, }, { @@ -77,7 +77,7 @@ var responseTests = []struct { 0x48, 0x2e, 0xb6, 0x47, 0x15, 0xe8, 0xb2, 0x8e, 0xae, 0xad, 0x64, 0x44, }, - wantAddr: []byte{72, 69, 33, 45}, + wantAddr: netip.AddrFrom4([4]byte{72, 69, 33, 45}), wantPort: 58539, }, { @@ -95,7 +95,7 @@ var responseTests = []struct { 0x7e, 0x57, 0x96, 0x68, 0x29, 0xf4, 0x44, 0x60, 0x9d, 0x1d, 0xea, 0xa6, }, - wantAddr: []byte{72, 69, 33, 45}, + wantAddr: netip.AddrFrom4([4]byte{72, 69, 33, 45}), wantPort: 59859, }, { @@ -114,7 +114,7 @@ var responseTests = []struct { 0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c, 0x4f, 0x3e, 0x30, 0x8e, }, - wantAddr: []byte{127, 0, 0, 1}, + wantAddr: netip.AddrFrom4([4]byte{127, 0, 0, 1}), wantPort: 61300, }, { @@ -137,7 +137,7 @@ var responseTests = []struct { 6, 245, 102, 133, 210, 138, 243, 230, 156, 227, 65, 226, }, - wantAddr: net.ParseIP("2602:d1:b4cf:c100:38b2:31ff:feef:96f6"), + wantAddr: netip.MustParseAddr("2602:d1:b4cf:c100:38b2:31ff:feef:96f6"), wantPort: 37070, }, @@ -156,7 +156,7 @@ var responseTests = []struct { 0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c, 0x4f, 0x3e, 0x30, 0x8e, }, - wantAddr: []byte{127, 0, 0, 1}, + wantAddr: netip.AddrFrom4([4]byte{127, 0, 0, 1}), wantPort: 61300, }, { @@ -172,7 +172,7 @@ var responseTests = []struct { 0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c, 0x4f, 0x3e, 0x30, 0x8e, }, - wantAddr: []byte{127, 0, 0, 1}, + wantAddr: netip.AddrFrom4([4]byte{127, 0, 0, 1}), wantPort: 61300, }, } @@ -180,7 +180,7 @@ var responseTests = []struct { func TestParseResponse(t *testing.T) { subtest := func(t *testing.T, i int) { test := responseTests[i] - tID, addr, port, err := stun.ParseResponse(test.data) + tID, addrPort, err := stun.ParseResponse(test.data) if err != nil { t.Fatal(err) } @@ -188,11 +188,11 @@ func TestParseResponse(t *testing.T) { if !bytes.Equal(tID[:], test.wantTID) { t.Errorf("tid=%v, want %v", tID[:], test.wantTID) } - if !bytes.Equal(addr, test.wantAddr) { - t.Errorf("addr=%v (%v), want %v", addr, net.IP(addr), test.wantAddr) + if addrPort.Addr().Compare(test.wantAddr) != 0 { + t.Errorf("addr=%v, want %v", addrPort.Addr(), test.wantAddr) } - if port != test.wantPort { - t.Errorf("port=%d, want %d", port, test.wantPort) + if addrPort.Port() != test.wantPort { + t.Errorf("port=%d, want %d", addrPort.Port(), test.wantPort) } } for i, test := range responseTests { @@ -249,17 +249,17 @@ func TestResponse(t *testing.T) { } tests := []struct { tx stun.TxID - ip net.IP + addr netip.Addr port uint16 }{ - {tx: txN(1), ip: net.ParseIP("1.2.3.4").To4(), port: 254}, - {tx: txN(2), ip: net.ParseIP("1.2.3.4").To4(), port: 257}, - {tx: txN(3), ip: net.ParseIP("1::4"), port: 254}, - {tx: txN(4), ip: net.ParseIP("1::4"), port: 257}, + {tx: txN(1), addr: netip.MustParseAddr("1.2.3.4"), port: 254}, + {tx: txN(2), addr: netip.MustParseAddr("1.2.3.4"), port: 257}, + {tx: txN(3), addr: netip.MustParseAddr("1::4"), port: 254}, + {tx: txN(4), addr: netip.MustParseAddr("1::4"), port: 257}, } for _, tt := range tests { - res := stun.Response(tt.tx, tt.ip, tt.port) - tx2, ip2, port2, err := stun.ParseResponse(res) + res := stun.Response(tt.tx, netip.AddrPortFrom(tt.addr, tt.port)) + tx2, addr2, err := stun.ParseResponse(res) if err != nil { t.Errorf("TX %x: error: %v", tt.tx, err) continue @@ -267,11 +267,11 @@ func TestResponse(t *testing.T) { if tt.tx != tx2 { t.Errorf("TX %x: got TxID = %v", tt.tx, tx2) } - if !bytes.Equal([]byte(tt.ip), ip2) { - t.Errorf("TX %x: ip = %v (%v); want %v (%v)", tt.tx, ip2, net.IP(ip2), []byte(tt.ip), tt.ip) + if tt.addr.Compare(addr2.Addr()) != 0 { + t.Errorf("TX %x: addr = %v; want %v", tt.tx, addr2.Addr(), tt.addr) } - if tt.port != port2 { - t.Errorf("TX %x: port = %v; want %v", tt.tx, port2, tt.port) + if tt.port != addr2.Port() { + t.Errorf("TX %x: port = %v; want %v", tt.tx, addr2.Port(), tt.port) } } } diff --git a/net/stun/stuntest/stuntest.go b/net/stun/stuntest/stuntest.go index 466d0c7e3..55ad146b4 100644 --- a/net/stun/stuntest/stuntest.go +++ b/net/stun/stuntest/stuntest.go @@ -84,7 +84,8 @@ func runSTUN(t testing.TB, pc net.PacketConn, stats *stunStats, done chan<- stru } stats.mu.Unlock() - res := stun.Response(txid, ua.IP, uint16(ua.Port)) + nia, _ := netip.AddrFromSlice(ua.IP) + res := stun.Response(txid, netip.AddrPortFrom(nia, uint16(ua.Port))) if _, err := pc.WriteTo(res, addr); err != nil { t.Logf("STUN server write failed: %v", err) }