net/stun: convert to use net/netip.AddrPort

Convert ParseResponse and Response to use netip.AddrPort instead of
net.IP and separate port.

Fixes #5281

Signed-off-by: Kris Brandow <kris.brandow@gmail.com>
pull/5379/head
Kris Brandow 2 years ago
parent c3270af52b
commit 8f38afbf8e

@ -19,6 +19,7 @@ import (
"math"
"net"
"net/http"
"net/netip"
"os"
"path/filepath"
"regexp"
@ -356,7 +357,8 @@ func serverSTUNListener(ctx context.Context, pc *net.UDPConn) {
} else {
stunIPv6.Add(1)
}
res := stun.Response(txid, ua.IP, uint16(ua.Port))
addr, _ := netip.AddrFromSlice(ua.IP)
res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(ua.Port)))
_, err = pc.WriteTo(res, ua)
if err != nil {
stunWriteError.Add(1)

@ -360,7 +360,7 @@ func probeUDP(ctx context.Context, dm *tailcfg.DERPMap, n *tailcfg.DERPNode) (la
time.Sleep(100 * time.Millisecond)
continue
}
txBack, _, _, err := stun.ParseResponse(buf[:n])
txBack, _, err := stun.ParseResponse(buf[:n])
if err != nil {
return 0, fmt.Errorf("parsing STUN response from %v: %v", ip, err)
}

@ -261,7 +261,7 @@ func (c *Client) ReceiveSTUNPacket(pkt []byte, src netip.AddrPort) {
return
}
tx, addr, port, err := stun.ParseResponse(pkt)
tx, addrPort, err := stun.ParseResponse(pkt)
if err != nil {
if _, err := stun.ParseBindingRequest(pkt); err == nil {
// This was probably our own netcheck hairpin
@ -279,10 +279,7 @@ func (c *Client) ReceiveSTUNPacket(pkt []byte, src netip.AddrPort) {
}
rs.mu.Unlock()
if ok {
ta := net.TCPAddr{IP: addr, Port: int(port)}
if ipp := netaddr.Unmap(ta.AddrPort()); ipp.IsValid() {
onDone(ipp)
}
onDone(addrPort)
}
}

@ -11,6 +11,7 @@ import (
"errors"
"hash/crc32"
"net"
"net/netip"
)
const (
@ -151,20 +152,18 @@ func foreachAttr(b []byte, fn func(attrType uint16, a []byte) error) error {
}
// Response generates a binding response.
func Response(txID TxID, ip net.IP, port uint16) []byte {
if ip4 := ip.To4(); ip4 != nil {
ip = ip4
}
func Response(txID TxID, addrPort netip.AddrPort) []byte {
addr := addrPort.Addr()
var fam byte
switch len(ip) {
case net.IPv4len:
if addr.Is4() {
fam = 1
case net.IPv6len:
} else if addr.Is6() {
fam = 2
default:
} else {
return nil
}
attrsLen := 8 + len(ip)
attrsLen := 8 + addr.BitLen()/8
b := make([]byte, 0, headerLen+attrsLen)
// Header
@ -175,12 +174,13 @@ func Response(txID TxID, ip net.IP, port uint16) []byte {
// Attributes (well, one)
b = appendU16(b, attrXorMappedAddress)
b = appendU16(b, uint16(4+len(ip)))
b = appendU16(b, uint16(4+addr.BitLen()/8))
b = append(b,
0, // unused byte
fam)
b = appendU16(b, port^0x2112) // first half of magicCookie
for i, o := range []byte(ip) {
b = appendU16(b, addrPort.Port()^0x2112) // first half of magicCookie
ipa := addr.As16()
for i, o := range ipa[16-addr.BitLen()/8:] {
if i < 4 {
b = append(b, o^magicCookie[i])
} else {
@ -192,25 +192,23 @@ func Response(txID TxID, ip net.IP, port uint16) []byte {
// ParseResponse parses a successful binding response STUN packet.
// The IP address is extracted from the XOR-MAPPED-ADDRESS attribute.
// The returned addr slice is owned by the caller and does not alias b.
func ParseResponse(b []byte) (tID TxID, addr []byte, port uint16, err error) {
func ParseResponse(b []byte) (tID TxID, addr netip.AddrPort, err error) {
if !Is(b) {
return tID, nil, 0, ErrNotSTUN
return tID, netip.AddrPort{}, ErrNotSTUN
}
copy(tID[:], b[8:8+len(tID)])
if b[0] != 0x01 || b[1] != 0x01 {
return tID, nil, 0, ErrNotSuccessResponse
return tID, netip.AddrPort{}, ErrNotSuccessResponse
}
attrsLen := int(binary.BigEndian.Uint16(b[2:4]))
b = b[headerLen:] // remove STUN header
if attrsLen > len(b) {
return tID, nil, 0, ErrMalformedAttrs
return tID, netip.AddrPort{}, ErrMalformedAttrs
} else if len(b) > attrsLen {
b = b[:attrsLen] // trim trailing packet bytes
}
var addr6, fallbackAddr, fallbackAddr6 []byte
var port6, fallbackPort, fallbackPort6 uint16
var addr6, fallbackAddr, fallbackAddr6 netip.AddrPort
// Read through the attributes.
// The the addr+port reported by XOR-MAPPED-ADDRESS
@ -225,9 +223,9 @@ func ParseResponse(b []byte) (tID TxID, addr []byte, port uint16, err error) {
return err
}
if len(a) == 16 {
addr6, port6 = a, p
addr6 = netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)([]byte(a))), p)
} else {
addr, port = a, p
addr = netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)([]byte(a))), p)
}
case attrMappedAddress:
a, p, err := mappedAddress(attr)
@ -235,30 +233,30 @@ func ParseResponse(b []byte) (tID TxID, addr []byte, port uint16, err error) {
return ErrMalformedAttrs
}
if len(a) == 16 {
fallbackAddr6, fallbackPort6 = a, p
fallbackAddr6 = netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)([]byte(a))), p)
} else {
fallbackAddr, fallbackPort = a, p
fallbackAddr = netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)([]byte(a))), p)
}
}
return nil
}); err != nil {
return TxID{}, nil, 0, err
return TxID{}, netip.AddrPort{}, err
}
if addr != nil {
return tID, addr, port, nil
if addr.IsValid() {
return tID, addr, nil
}
if fallbackAddr != nil {
return tID, append([]byte{}, fallbackAddr...), fallbackPort, nil
if fallbackAddr.IsValid() {
return tID, fallbackAddr, nil
}
if addr6 != nil {
return tID, addr6, port6, nil
if addr6.IsValid() {
return tID, addr6, nil
}
if fallbackAddr6 != nil {
return tID, append([]byte{}, fallbackAddr6...), fallbackPort6, nil
if fallbackAddr6.IsValid() {
return tID, fallbackAddr6, nil
}
return tID, nil, 0, ErrMalformedAttrs
return tID, netip.AddrPort{}, ErrMalformedAttrs
}
func xorMappedAddress(tID TxID, b []byte) (addr []byte, port uint16, err error) {

@ -7,7 +7,7 @@
package stun
func FuzzStunParser(data []byte) int {
_, _, _, _ = ParseResponse(data)
_, _, _ = ParseResponse(data)
_, _ = ParseBindingRequest(data)
return 1

@ -7,7 +7,7 @@ package stun_test
import (
"bytes"
"fmt"
"net"
"net/netip"
"testing"
"tailscale.com/net/stun"
@ -25,7 +25,7 @@ var responseTests = []struct {
name string
data []byte
wantTID []byte
wantAddr []byte
wantAddr netip.Addr
wantPort uint16
}{
{
@ -40,7 +40,7 @@ var responseTests = []struct {
0x23, 0x60, 0xb1, 0x1e, 0x3e, 0xc6, 0x8f, 0xfa,
0x93, 0xe0, 0x80, 0x07,
},
wantAddr: []byte{72, 69, 33, 45},
wantAddr: netip.AddrFrom4([4]byte{72, 69, 33, 45}),
wantPort: 59028,
},
{
@ -55,7 +55,7 @@ var responseTests = []struct {
0xf9, 0xf1, 0x21, 0xcb, 0xde, 0x7d, 0x7c, 0x75,
0x92, 0x3c, 0xe2, 0x71,
},
wantAddr: []byte{72, 69, 33, 45},
wantAddr: netip.AddrFrom4([4]byte{72, 69, 33, 45}),
wantPort: 59029,
},
{
@ -77,7 +77,7 @@ var responseTests = []struct {
0x48, 0x2e, 0xb6, 0x47, 0x15, 0xe8, 0xb2, 0x8e,
0xae, 0xad, 0x64, 0x44,
},
wantAddr: []byte{72, 69, 33, 45},
wantAddr: netip.AddrFrom4([4]byte{72, 69, 33, 45}),
wantPort: 58539,
},
{
@ -95,7 +95,7 @@ var responseTests = []struct {
0x7e, 0x57, 0x96, 0x68, 0x29, 0xf4, 0x44, 0x60,
0x9d, 0x1d, 0xea, 0xa6,
},
wantAddr: []byte{72, 69, 33, 45},
wantAddr: netip.AddrFrom4([4]byte{72, 69, 33, 45}),
wantPort: 59859,
},
{
@ -114,7 +114,7 @@ var responseTests = []struct {
0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c,
0x4f, 0x3e, 0x30, 0x8e,
},
wantAddr: []byte{127, 0, 0, 1},
wantAddr: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
wantPort: 61300,
},
{
@ -137,7 +137,7 @@ var responseTests = []struct {
6, 245, 102, 133, 210, 138, 243, 230, 156, 227,
65, 226,
},
wantAddr: net.ParseIP("2602:d1:b4cf:c100:38b2:31ff:feef:96f6"),
wantAddr: netip.MustParseAddr("2602:d1:b4cf:c100:38b2:31ff:feef:96f6"),
wantPort: 37070,
},
@ -156,7 +156,7 @@ var responseTests = []struct {
0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c,
0x4f, 0x3e, 0x30, 0x8e,
},
wantAddr: []byte{127, 0, 0, 1},
wantAddr: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
wantPort: 61300,
},
{
@ -172,7 +172,7 @@ var responseTests = []struct {
0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c,
0x4f, 0x3e, 0x30, 0x8e,
},
wantAddr: []byte{127, 0, 0, 1},
wantAddr: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
wantPort: 61300,
},
}
@ -180,7 +180,7 @@ var responseTests = []struct {
func TestParseResponse(t *testing.T) {
subtest := func(t *testing.T, i int) {
test := responseTests[i]
tID, addr, port, err := stun.ParseResponse(test.data)
tID, addrPort, err := stun.ParseResponse(test.data)
if err != nil {
t.Fatal(err)
}
@ -188,11 +188,11 @@ func TestParseResponse(t *testing.T) {
if !bytes.Equal(tID[:], test.wantTID) {
t.Errorf("tid=%v, want %v", tID[:], test.wantTID)
}
if !bytes.Equal(addr, test.wantAddr) {
t.Errorf("addr=%v (%v), want %v", addr, net.IP(addr), test.wantAddr)
if addrPort.Addr().Compare(test.wantAddr) != 0 {
t.Errorf("addr=%v, want %v", addrPort.Addr(), test.wantAddr)
}
if port != test.wantPort {
t.Errorf("port=%d, want %d", port, test.wantPort)
if addrPort.Port() != test.wantPort {
t.Errorf("port=%d, want %d", addrPort.Port(), test.wantPort)
}
}
for i, test := range responseTests {
@ -249,17 +249,17 @@ func TestResponse(t *testing.T) {
}
tests := []struct {
tx stun.TxID
ip net.IP
addr netip.Addr
port uint16
}{
{tx: txN(1), ip: net.ParseIP("1.2.3.4").To4(), port: 254},
{tx: txN(2), ip: net.ParseIP("1.2.3.4").To4(), port: 257},
{tx: txN(3), ip: net.ParseIP("1::4"), port: 254},
{tx: txN(4), ip: net.ParseIP("1::4"), port: 257},
{tx: txN(1), addr: netip.MustParseAddr("1.2.3.4"), port: 254},
{tx: txN(2), addr: netip.MustParseAddr("1.2.3.4"), port: 257},
{tx: txN(3), addr: netip.MustParseAddr("1::4"), port: 254},
{tx: txN(4), addr: netip.MustParseAddr("1::4"), port: 257},
}
for _, tt := range tests {
res := stun.Response(tt.tx, tt.ip, tt.port)
tx2, ip2, port2, err := stun.ParseResponse(res)
res := stun.Response(tt.tx, netip.AddrPortFrom(tt.addr, tt.port))
tx2, addr2, err := stun.ParseResponse(res)
if err != nil {
t.Errorf("TX %x: error: %v", tt.tx, err)
continue
@ -267,11 +267,11 @@ func TestResponse(t *testing.T) {
if tt.tx != tx2 {
t.Errorf("TX %x: got TxID = %v", tt.tx, tx2)
}
if !bytes.Equal([]byte(tt.ip), ip2) {
t.Errorf("TX %x: ip = %v (%v); want %v (%v)", tt.tx, ip2, net.IP(ip2), []byte(tt.ip), tt.ip)
if tt.addr.Compare(addr2.Addr()) != 0 {
t.Errorf("TX %x: addr = %v; want %v", tt.tx, addr2.Addr(), tt.addr)
}
if tt.port != port2 {
t.Errorf("TX %x: port = %v; want %v", tt.tx, port2, tt.port)
if tt.port != addr2.Port() {
t.Errorf("TX %x: port = %v; want %v", tt.tx, addr2.Port(), tt.port)
}
}
}

@ -84,7 +84,8 @@ func runSTUN(t testing.TB, pc net.PacketConn, stats *stunStats, done chan<- stru
}
stats.mu.Unlock()
res := stun.Response(txid, ua.IP, uint16(ua.Port))
nia, _ := netip.AddrFromSlice(ua.IP)
res := stun.Response(txid, netip.AddrPortFrom(nia, uint16(ua.Port)))
if _, err := pc.WriteTo(res, addr); err != nil {
t.Logf("STUN server write failed: %v", err)
}

Loading…
Cancel
Save