net/netcheck, wgengine/magicsock: replace more UDPAddr with netaddr.IPPort

pull/514/head^2
Brad Fitzpatrick 4 years ago
parent 9070aacdee
commit 2d6e84e19e

@ -127,7 +127,7 @@ func (c *Client) vlogf(format string, a ...interface{}) {
// handleHairSTUN reports whether pkt (from src) was our magic hairpin // handleHairSTUN reports whether pkt (from src) was our magic hairpin
// probe packet that we sent to ourselves. // probe packet that we sent to ourselves.
func (c *Client) handleHairSTUNLocked(pkt []byte, src *net.UDPAddr) bool { func (c *Client) handleHairSTUNLocked(pkt []byte, src netaddr.IPPort) bool {
rs := c.curState rs := c.curState
if rs == nil { if rs == nil {
return false return false
@ -150,11 +150,7 @@ func (c *Client) MakeNextReportFull() {
c.mu.Unlock() c.mu.Unlock()
} }
func (c *Client) ReceiveSTUNPacket(pkt []byte, src *net.UDPAddr) { func (c *Client) ReceiveSTUNPacket(pkt []byte, src netaddr.IPPort) {
if src == nil || src.IP == nil {
panic("bogus src")
}
c.mu.Lock() c.mu.Lock()
if c.handleHairSTUNLocked(pkt, src) { if c.handleHairSTUNLocked(pkt, src) {
c.mu.Unlock() c.mu.Unlock()
@ -421,7 +417,9 @@ func (c *Client) readPackets(ctx context.Context, pc net.PacketConn) {
if !stun.Is(pkt) { if !stun.Is(pkt) {
continue continue
} }
c.ReceiveSTUNPacket(pkt, ua) if ipp, ok := netaddr.FromStdAddr(ua.IP, ua.Port, ua.Zone); ok {
c.ReceiveSTUNPacket(pkt, ipp)
}
} }
} }
@ -429,7 +427,7 @@ func (c *Client) readPackets(ctx context.Context, pc net.PacketConn) {
type reportState struct { type reportState struct {
c *Client c *Client
hairTX stun.TxID hairTX stun.TxID
gotHairSTUN chan *net.UDPAddr gotHairSTUN chan netaddr.IPPort
hairTimeout chan struct{} // closed on timeout hairTimeout chan struct{} // closed on timeout
pc4 STUNConn pc4 STUNConn
pc6 STUNConn pc6 STUNConn
@ -638,7 +636,7 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap) (*Report, e
report: newReport(), report: newReport(),
inFlight: map[stun.TxID]func(netaddr.IPPort){}, inFlight: map[stun.TxID]func(netaddr.IPPort){},
hairTX: stun.NewTxID(), // random payload hairTX: stun.NewTxID(), // random payload
gotHairSTUN: make(chan *net.UDPAddr, 1), gotHairSTUN: make(chan netaddr.IPPort, 1),
hairTimeout: make(chan struct{}), hairTimeout: make(chan struct{}),
stopProbeCh: make(chan struct{}, 1), stopProbeCh: make(chan struct{}, 1),
} }

@ -16,6 +16,7 @@ import (
"testing" "testing"
"time" "time"
"inet.af/netaddr"
"tailscale.com/net/interfaces" "tailscale.com/net/interfaces"
"tailscale.com/net/stun" "tailscale.com/net/stun"
"tailscale.com/net/stun/stuntest" "tailscale.com/net/stun/stuntest"
@ -27,14 +28,14 @@ func TestHairpinSTUN(t *testing.T) {
c := &Client{ c := &Client{
curState: &reportState{ curState: &reportState{
hairTX: tx, hairTX: tx,
gotHairSTUN: make(chan *net.UDPAddr, 1), gotHairSTUN: make(chan netaddr.IPPort, 1),
}, },
} }
req := stun.Request(tx) req := stun.Request(tx)
if !stun.Is(req) { if !stun.Is(req) {
t.Fatal("expected STUN message") t.Fatal("expected STUN message")
} }
if !c.handleHairSTUNLocked(req, nil) { if !c.handleHairSTUNLocked(req, netaddr.IPPort{}) {
t.Fatal("expected true") t.Fatal("expected true")
} }
select { select {

@ -310,7 +310,7 @@ func (c *Conn) donec() <-chan struct{} { return c.connCtx.Done() }
// ignoreSTUNPackets sets a STUN packet processing func that does nothing. // ignoreSTUNPackets sets a STUN packet processing func that does nothing.
func (c *Conn) ignoreSTUNPackets() { func (c *Conn) ignoreSTUNPackets() {
c.stunReceiveFunc.Store(func([]byte, *net.UDPAddr) {}) c.stunReceiveFunc.Store(func([]byte, netaddr.IPPort) {})
} }
// c.mu must NOT be held. // c.mu must NOT be held.
@ -1198,11 +1198,15 @@ func (c *Conn) awaitUDP4(b []byte) {
return return
} }
addr := pAddr.(*net.UDPAddr) addr := pAddr.(*net.UDPAddr)
ipp, ok := netaddr.FromStdAddr(addr.IP, addr.Port, addr.Zone)
if !ok {
continue
}
if stun.Is(b[:n]) { if stun.Is(b[:n]) {
c.stunReceiveFunc.Load().(func([]byte, *net.UDPAddr))(b[:n], addr) c.stunReceiveFunc.Load().(func([]byte, netaddr.IPPort))(b[:n], ipp)
continue continue
} }
if c.handleDiscoMessage(b[:n], addr) { if c.handleDiscoMessage(b[:n], ipp) {
continue continue
} }
@ -1276,7 +1280,7 @@ Top:
} }
addr := netaddr.IPPort{IP: derpMagicIPAddr, Port: uint16(regionID)} addr := netaddr.IPPort{IP: derpMagicIPAddr, Port: uint16(regionID)}
if c.handleDiscoMessage(b[:n], addr.UDPAddr()) { if c.handleDiscoMessage(b[:n], addr) {
goto Top goto Top
} }
@ -1334,11 +1338,15 @@ func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, *net.UDPAddr, error) {
return 0, nil, nil, err return 0, nil, nil, err
} }
addr := pAddr.(*net.UDPAddr) addr := pAddr.(*net.UDPAddr)
ipp, ok := netaddr.FromStdAddr(addr.IP, addr.Port, addr.Zone)
if !ok {
continue
}
if stun.Is(b[:n]) { if stun.Is(b[:n]) {
c.stunReceiveFunc.Load().(func([]byte, *net.UDPAddr))(b[:n], addr) c.stunReceiveFunc.Load().(func([]byte, netaddr.IPPort))(b[:n], ipp)
continue continue
} }
if c.handleDiscoMessage(b[:n], addr) { if c.handleDiscoMessage(b[:n], ipp) {
continue continue
} }
@ -1359,7 +1367,7 @@ func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, *net.UDPAddr, error) {
// //
// For messages received over DERP, the addr will be derpMagicIP (with // For messages received over DERP, the addr will be derpMagicIP (with
// port being the region) // port being the region)
func (c *Conn) handleDiscoMessage(msg []byte, src *net.UDPAddr) bool { func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool {
const magic = "TS💬" const magic = "TS💬"
const nonceLen = 24 const nonceLen = 24
const headerLen = len(magic) + len(tailcfg.DiscoKey{}) + nonceLen const headerLen = len(magic) + len(tailcfg.DiscoKey{}) + nonceLen
@ -1369,11 +1377,6 @@ func (c *Conn) handleDiscoMessage(msg []byte, src *net.UDPAddr) bool {
var sender tailcfg.DiscoKey var sender tailcfg.DiscoKey
copy(sender[:], msg[len(magic):]) copy(sender[:], msg[len(magic):])
srca, ok := netaddr.FromStdAddr(src.IP, src.Port, src.Zone)
if !ok {
return false
}
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -1421,11 +1424,11 @@ func (c *Conn) handleDiscoMessage(msg []byte, src *net.UDPAddr) bool {
switch dm := dm.(type) { switch dm := dm.(type) {
case *disco.Ping: case *disco.Ping:
c.handlePingLocked(dm, senderNode, sender, srca) c.handlePingLocked(dm, senderNode, sender, src)
case *disco.Pong: case *disco.Pong:
c.handlePongLocked(dm, senderNode, sender, srca) c.handlePongLocked(dm, senderNode, sender, src)
case disco.CallMeMaybe: case disco.CallMeMaybe:
if srca.IP != derpMagicIPAddr { if src.IP != derpMagicIPAddr {
// CallMeMaybe messages should only come via DERP. // CallMeMaybe messages should only come via DERP.
c.logf("[unexpected] CallMeMaybe packets should only come via DERP") c.logf("[unexpected] CallMeMaybe packets should only come via DERP")
return true return true

@ -873,7 +873,7 @@ func TestDiscoMessage(t *testing.T) {
pkt = append(pkt, nonce[:]...) pkt = append(pkt, nonce[:]...)
pkt = box.Seal(pkt, []byte(payload), &nonce, c.discoPrivate.Public().B32(), peer1Priv.B32()) pkt = box.Seal(pkt, []byte(payload), &nonce, c.discoPrivate.Public().B32(), peer1Priv.B32())
got := c.handleDiscoMessage(pkt, &net.UDPAddr{IP: net.ParseIP("1.2.3.4")}) got := c.handleDiscoMessage(pkt, netaddr.IPPort{})
if !got { if !got {
t.Error("failed to open it") t.Error("failed to open it")
} }

Loading…
Cancel
Save