net/stun: convert to use net/netip.AddrPort

Convert ParseResponse and Response to use netip.AddrPort instead of
net.IP and separate port.

Fixes #5281

Signed-off-by: Kris Brandow <kris.brandow@gmail.com>
pull/5379/head
Kris Brandow 2 years ago
parent c3270af52b
commit 8f38afbf8e

@ -19,6 +19,7 @@ import (
"math" "math"
"net" "net"
"net/http" "net/http"
"net/netip"
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
@ -356,7 +357,8 @@ func serverSTUNListener(ctx context.Context, pc *net.UDPConn) {
} else { } else {
stunIPv6.Add(1) stunIPv6.Add(1)
} }
res := stun.Response(txid, ua.IP, uint16(ua.Port)) addr, _ := netip.AddrFromSlice(ua.IP)
res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(ua.Port)))
_, err = pc.WriteTo(res, ua) _, err = pc.WriteTo(res, ua)
if err != nil { if err != nil {
stunWriteError.Add(1) stunWriteError.Add(1)

@ -360,7 +360,7 @@ func probeUDP(ctx context.Context, dm *tailcfg.DERPMap, n *tailcfg.DERPNode) (la
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
continue continue
} }
txBack, _, _, err := stun.ParseResponse(buf[:n]) txBack, _, err := stun.ParseResponse(buf[:n])
if err != nil { if err != nil {
return 0, fmt.Errorf("parsing STUN response from %v: %v", ip, err) return 0, fmt.Errorf("parsing STUN response from %v: %v", ip, err)
} }

@ -261,7 +261,7 @@ func (c *Client) ReceiveSTUNPacket(pkt []byte, src netip.AddrPort) {
return return
} }
tx, addr, port, err := stun.ParseResponse(pkt) tx, addrPort, err := stun.ParseResponse(pkt)
if err != nil { if err != nil {
if _, err := stun.ParseBindingRequest(pkt); err == nil { if _, err := stun.ParseBindingRequest(pkt); err == nil {
// This was probably our own netcheck hairpin // This was probably our own netcheck hairpin
@ -279,10 +279,7 @@ func (c *Client) ReceiveSTUNPacket(pkt []byte, src netip.AddrPort) {
} }
rs.mu.Unlock() rs.mu.Unlock()
if ok { if ok {
ta := net.TCPAddr{IP: addr, Port: int(port)} onDone(addrPort)
if ipp := netaddr.Unmap(ta.AddrPort()); ipp.IsValid() {
onDone(ipp)
}
} }
} }

@ -11,6 +11,7 @@ import (
"errors" "errors"
"hash/crc32" "hash/crc32"
"net" "net"
"net/netip"
) )
const ( const (
@ -151,20 +152,18 @@ func foreachAttr(b []byte, fn func(attrType uint16, a []byte) error) error {
} }
// Response generates a binding response. // Response generates a binding response.
func Response(txID TxID, ip net.IP, port uint16) []byte { func Response(txID TxID, addrPort netip.AddrPort) []byte {
if ip4 := ip.To4(); ip4 != nil { addr := addrPort.Addr()
ip = ip4
}
var fam byte var fam byte
switch len(ip) { if addr.Is4() {
case net.IPv4len:
fam = 1 fam = 1
case net.IPv6len: } else if addr.Is6() {
fam = 2 fam = 2
default: } else {
return nil return nil
} }
attrsLen := 8 + len(ip) attrsLen := 8 + addr.BitLen()/8
b := make([]byte, 0, headerLen+attrsLen) b := make([]byte, 0, headerLen+attrsLen)
// Header // Header
@ -175,12 +174,13 @@ func Response(txID TxID, ip net.IP, port uint16) []byte {
// Attributes (well, one) // Attributes (well, one)
b = appendU16(b, attrXorMappedAddress) b = appendU16(b, attrXorMappedAddress)
b = appendU16(b, uint16(4+len(ip))) b = appendU16(b, uint16(4+addr.BitLen()/8))
b = append(b, b = append(b,
0, // unused byte 0, // unused byte
fam) fam)
b = appendU16(b, port^0x2112) // first half of magicCookie b = appendU16(b, addrPort.Port()^0x2112) // first half of magicCookie
for i, o := range []byte(ip) { ipa := addr.As16()
for i, o := range ipa[16-addr.BitLen()/8:] {
if i < 4 { if i < 4 {
b = append(b, o^magicCookie[i]) b = append(b, o^magicCookie[i])
} else { } else {
@ -192,25 +192,23 @@ func Response(txID TxID, ip net.IP, port uint16) []byte {
// ParseResponse parses a successful binding response STUN packet. // ParseResponse parses a successful binding response STUN packet.
// The IP address is extracted from the XOR-MAPPED-ADDRESS attribute. // The IP address is extracted from the XOR-MAPPED-ADDRESS attribute.
// The returned addr slice is owned by the caller and does not alias b. func ParseResponse(b []byte) (tID TxID, addr netip.AddrPort, err error) {
func ParseResponse(b []byte) (tID TxID, addr []byte, port uint16, err error) {
if !Is(b) { if !Is(b) {
return tID, nil, 0, ErrNotSTUN return tID, netip.AddrPort{}, ErrNotSTUN
} }
copy(tID[:], b[8:8+len(tID)]) copy(tID[:], b[8:8+len(tID)])
if b[0] != 0x01 || b[1] != 0x01 { if b[0] != 0x01 || b[1] != 0x01 {
return tID, nil, 0, ErrNotSuccessResponse return tID, netip.AddrPort{}, ErrNotSuccessResponse
} }
attrsLen := int(binary.BigEndian.Uint16(b[2:4])) attrsLen := int(binary.BigEndian.Uint16(b[2:4]))
b = b[headerLen:] // remove STUN header b = b[headerLen:] // remove STUN header
if attrsLen > len(b) { if attrsLen > len(b) {
return tID, nil, 0, ErrMalformedAttrs return tID, netip.AddrPort{}, ErrMalformedAttrs
} else if len(b) > attrsLen { } else if len(b) > attrsLen {
b = b[:attrsLen] // trim trailing packet bytes b = b[:attrsLen] // trim trailing packet bytes
} }
var addr6, fallbackAddr, fallbackAddr6 []byte var addr6, fallbackAddr, fallbackAddr6 netip.AddrPort
var port6, fallbackPort, fallbackPort6 uint16
// Read through the attributes. // Read through the attributes.
// The the addr+port reported by XOR-MAPPED-ADDRESS // The the addr+port reported by XOR-MAPPED-ADDRESS
@ -225,9 +223,9 @@ func ParseResponse(b []byte) (tID TxID, addr []byte, port uint16, err error) {
return err return err
} }
if len(a) == 16 { if len(a) == 16 {
addr6, port6 = a, p addr6 = netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)([]byte(a))), p)
} else { } else {
addr, port = a, p addr = netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)([]byte(a))), p)
} }
case attrMappedAddress: case attrMappedAddress:
a, p, err := mappedAddress(attr) a, p, err := mappedAddress(attr)
@ -235,30 +233,30 @@ func ParseResponse(b []byte) (tID TxID, addr []byte, port uint16, err error) {
return ErrMalformedAttrs return ErrMalformedAttrs
} }
if len(a) == 16 { if len(a) == 16 {
fallbackAddr6, fallbackPort6 = a, p fallbackAddr6 = netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)([]byte(a))), p)
} else { } else {
fallbackAddr, fallbackPort = a, p fallbackAddr = netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)([]byte(a))), p)
} }
} }
return nil return nil
}); err != nil { }); err != nil {
return TxID{}, nil, 0, err return TxID{}, netip.AddrPort{}, err
} }
if addr != nil { if addr.IsValid() {
return tID, addr, port, nil return tID, addr, nil
} }
if fallbackAddr != nil { if fallbackAddr.IsValid() {
return tID, append([]byte{}, fallbackAddr...), fallbackPort, nil return tID, fallbackAddr, nil
} }
if addr6 != nil { if addr6.IsValid() {
return tID, addr6, port6, nil return tID, addr6, nil
} }
if fallbackAddr6 != nil { if fallbackAddr6.IsValid() {
return tID, append([]byte{}, fallbackAddr6...), fallbackPort6, nil return tID, fallbackAddr6, nil
} }
return tID, nil, 0, ErrMalformedAttrs return tID, netip.AddrPort{}, ErrMalformedAttrs
} }
func xorMappedAddress(tID TxID, b []byte) (addr []byte, port uint16, err error) { func xorMappedAddress(tID TxID, b []byte) (addr []byte, port uint16, err error) {

@ -7,7 +7,7 @@
package stun package stun
func FuzzStunParser(data []byte) int { func FuzzStunParser(data []byte) int {
_, _, _, _ = ParseResponse(data) _, _, _ = ParseResponse(data)
_, _ = ParseBindingRequest(data) _, _ = ParseBindingRequest(data)
return 1 return 1

@ -7,7 +7,7 @@ package stun_test
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"net" "net/netip"
"testing" "testing"
"tailscale.com/net/stun" "tailscale.com/net/stun"
@ -25,7 +25,7 @@ var responseTests = []struct {
name string name string
data []byte data []byte
wantTID []byte wantTID []byte
wantAddr []byte wantAddr netip.Addr
wantPort uint16 wantPort uint16
}{ }{
{ {
@ -40,7 +40,7 @@ var responseTests = []struct {
0x23, 0x60, 0xb1, 0x1e, 0x3e, 0xc6, 0x8f, 0xfa, 0x23, 0x60, 0xb1, 0x1e, 0x3e, 0xc6, 0x8f, 0xfa,
0x93, 0xe0, 0x80, 0x07, 0x93, 0xe0, 0x80, 0x07,
}, },
wantAddr: []byte{72, 69, 33, 45}, wantAddr: netip.AddrFrom4([4]byte{72, 69, 33, 45}),
wantPort: 59028, wantPort: 59028,
}, },
{ {
@ -55,7 +55,7 @@ var responseTests = []struct {
0xf9, 0xf1, 0x21, 0xcb, 0xde, 0x7d, 0x7c, 0x75, 0xf9, 0xf1, 0x21, 0xcb, 0xde, 0x7d, 0x7c, 0x75,
0x92, 0x3c, 0xe2, 0x71, 0x92, 0x3c, 0xe2, 0x71,
}, },
wantAddr: []byte{72, 69, 33, 45}, wantAddr: netip.AddrFrom4([4]byte{72, 69, 33, 45}),
wantPort: 59029, wantPort: 59029,
}, },
{ {
@ -77,7 +77,7 @@ var responseTests = []struct {
0x48, 0x2e, 0xb6, 0x47, 0x15, 0xe8, 0xb2, 0x8e, 0x48, 0x2e, 0xb6, 0x47, 0x15, 0xe8, 0xb2, 0x8e,
0xae, 0xad, 0x64, 0x44, 0xae, 0xad, 0x64, 0x44,
}, },
wantAddr: []byte{72, 69, 33, 45}, wantAddr: netip.AddrFrom4([4]byte{72, 69, 33, 45}),
wantPort: 58539, wantPort: 58539,
}, },
{ {
@ -95,7 +95,7 @@ var responseTests = []struct {
0x7e, 0x57, 0x96, 0x68, 0x29, 0xf4, 0x44, 0x60, 0x7e, 0x57, 0x96, 0x68, 0x29, 0xf4, 0x44, 0x60,
0x9d, 0x1d, 0xea, 0xa6, 0x9d, 0x1d, 0xea, 0xa6,
}, },
wantAddr: []byte{72, 69, 33, 45}, wantAddr: netip.AddrFrom4([4]byte{72, 69, 33, 45}),
wantPort: 59859, wantPort: 59859,
}, },
{ {
@ -114,7 +114,7 @@ var responseTests = []struct {
0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c, 0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c,
0x4f, 0x3e, 0x30, 0x8e, 0x4f, 0x3e, 0x30, 0x8e,
}, },
wantAddr: []byte{127, 0, 0, 1}, wantAddr: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
wantPort: 61300, wantPort: 61300,
}, },
{ {
@ -137,7 +137,7 @@ var responseTests = []struct {
6, 245, 102, 133, 210, 138, 243, 230, 156, 227, 6, 245, 102, 133, 210, 138, 243, 230, 156, 227,
65, 226, 65, 226,
}, },
wantAddr: net.ParseIP("2602:d1:b4cf:c100:38b2:31ff:feef:96f6"), wantAddr: netip.MustParseAddr("2602:d1:b4cf:c100:38b2:31ff:feef:96f6"),
wantPort: 37070, wantPort: 37070,
}, },
@ -156,7 +156,7 @@ var responseTests = []struct {
0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c, 0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c,
0x4f, 0x3e, 0x30, 0x8e, 0x4f, 0x3e, 0x30, 0x8e,
}, },
wantAddr: []byte{127, 0, 0, 1}, wantAddr: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
wantPort: 61300, wantPort: 61300,
}, },
{ {
@ -172,7 +172,7 @@ var responseTests = []struct {
0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c, 0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c,
0x4f, 0x3e, 0x30, 0x8e, 0x4f, 0x3e, 0x30, 0x8e,
}, },
wantAddr: []byte{127, 0, 0, 1}, wantAddr: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
wantPort: 61300, wantPort: 61300,
}, },
} }
@ -180,7 +180,7 @@ var responseTests = []struct {
func TestParseResponse(t *testing.T) { func TestParseResponse(t *testing.T) {
subtest := func(t *testing.T, i int) { subtest := func(t *testing.T, i int) {
test := responseTests[i] test := responseTests[i]
tID, addr, port, err := stun.ParseResponse(test.data) tID, addrPort, err := stun.ParseResponse(test.data)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -188,11 +188,11 @@ func TestParseResponse(t *testing.T) {
if !bytes.Equal(tID[:], test.wantTID) { if !bytes.Equal(tID[:], test.wantTID) {
t.Errorf("tid=%v, want %v", tID[:], test.wantTID) t.Errorf("tid=%v, want %v", tID[:], test.wantTID)
} }
if !bytes.Equal(addr, test.wantAddr) { if addrPort.Addr().Compare(test.wantAddr) != 0 {
t.Errorf("addr=%v (%v), want %v", addr, net.IP(addr), test.wantAddr) t.Errorf("addr=%v, want %v", addrPort.Addr(), test.wantAddr)
} }
if port != test.wantPort { if addrPort.Port() != test.wantPort {
t.Errorf("port=%d, want %d", port, test.wantPort) t.Errorf("port=%d, want %d", addrPort.Port(), test.wantPort)
} }
} }
for i, test := range responseTests { for i, test := range responseTests {
@ -249,17 +249,17 @@ func TestResponse(t *testing.T) {
} }
tests := []struct { tests := []struct {
tx stun.TxID tx stun.TxID
ip net.IP addr netip.Addr
port uint16 port uint16
}{ }{
{tx: txN(1), ip: net.ParseIP("1.2.3.4").To4(), port: 254}, {tx: txN(1), addr: netip.MustParseAddr("1.2.3.4"), port: 254},
{tx: txN(2), ip: net.ParseIP("1.2.3.4").To4(), port: 257}, {tx: txN(2), addr: netip.MustParseAddr("1.2.3.4"), port: 257},
{tx: txN(3), ip: net.ParseIP("1::4"), port: 254}, {tx: txN(3), addr: netip.MustParseAddr("1::4"), port: 254},
{tx: txN(4), ip: net.ParseIP("1::4"), port: 257}, {tx: txN(4), addr: netip.MustParseAddr("1::4"), port: 257},
} }
for _, tt := range tests { for _, tt := range tests {
res := stun.Response(tt.tx, tt.ip, tt.port) res := stun.Response(tt.tx, netip.AddrPortFrom(tt.addr, tt.port))
tx2, ip2, port2, err := stun.ParseResponse(res) tx2, addr2, err := stun.ParseResponse(res)
if err != nil { if err != nil {
t.Errorf("TX %x: error: %v", tt.tx, err) t.Errorf("TX %x: error: %v", tt.tx, err)
continue continue
@ -267,11 +267,11 @@ func TestResponse(t *testing.T) {
if tt.tx != tx2 { if tt.tx != tx2 {
t.Errorf("TX %x: got TxID = %v", tt.tx, tx2) t.Errorf("TX %x: got TxID = %v", tt.tx, tx2)
} }
if !bytes.Equal([]byte(tt.ip), ip2) { if tt.addr.Compare(addr2.Addr()) != 0 {
t.Errorf("TX %x: ip = %v (%v); want %v (%v)", tt.tx, ip2, net.IP(ip2), []byte(tt.ip), tt.ip) t.Errorf("TX %x: addr = %v; want %v", tt.tx, addr2.Addr(), tt.addr)
} }
if tt.port != port2 { if tt.port != addr2.Port() {
t.Errorf("TX %x: port = %v; want %v", tt.tx, port2, tt.port) t.Errorf("TX %x: port = %v; want %v", tt.tx, addr2.Port(), tt.port)
} }
} }
} }

@ -84,7 +84,8 @@ func runSTUN(t testing.TB, pc net.PacketConn, stats *stunStats, done chan<- stru
} }
stats.mu.Unlock() stats.mu.Unlock()
res := stun.Response(txid, ua.IP, uint16(ua.Port)) nia, _ := netip.AddrFromSlice(ua.IP)
res := stun.Response(txid, netip.AddrPortFrom(nia, uint16(ua.Port)))
if _, err := pc.WriteTo(res, addr); err != nil { if _, err := pc.WriteTo(res, addr); err != nil {
t.Logf("STUN server write failed: %v", err) t.Logf("STUN server write failed: %v", err)
} }

Loading…
Cancel
Save