wgengine/magicsock: unify on netaddr types in addrSet

addrSet maintained duplicate lists of netaddr.IPPorts and net.UDPAddrs.
Unify to use the netaddr type only.

This makes (*Conn).ReceiveIPvN a bit uglier,
but that'll be cleaned up in a subsequent commit.

This is preparatory work to remove an allocation from ReceiveIPv4.

Co-authored-by: Sonia Appasamy <sonia@tailscale.com>
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
pull/1329/head
Josh Bleecher Snyder 3 years ago committed by Josh Bleecher Snyder
parent 4cd9218351
commit 0c673c1344

@ -53,7 +53,6 @@ func (c *Conn) createLegacyEndpointLocked(pk key.Public, addrs string) (conn.End
return nil, fmt.Errorf("bogus address %q", ep)
}
a.ipPorts = append(a.ipPorts, ipp)
a.addrs = append(a.addrs, *ipp.UDPAddr())
}
}
@ -84,14 +83,14 @@ func (c *Conn) createLegacyEndpointLocked(pk key.Public, addrs string) (conn.End
return a, nil
}
func (c *Conn) findLegacyEndpointLocked(ipp netaddr.IPPort, addr *net.UDPAddr, packet []byte) conn.Endpoint {
func (c *Conn) findLegacyEndpointLocked(ipp netaddr.IPPort, packet []byte) conn.Endpoint {
if c.disableLegacy {
return nil
}
// Pre-disco: look up their addrSet.
if as, ok := c.addrsByUDP[ipp]; ok {
as.updateDst(addr)
as.updateDst(ipp)
return as
}
@ -100,7 +99,7 @@ func (c *Conn) findLegacyEndpointLocked(ipp netaddr.IPPort, addr *net.UDPAddr, p
// know. If this is a handshake packet, we can try to identify the
// peer in question.
if as := c.peerFromPacketLocked(packet); as != nil {
as.updateDst(addr)
as.updateDst(ipp)
return as
}
@ -268,14 +267,6 @@ func (as *addrSet) appendDests(dsts []netaddr.IPPort, b []byte) (_ []netaddr.IPP
as.lastSend = now
// Some internal invariant checks.
if len(as.addrs) != len(as.ipPorts) {
panic(fmt.Sprintf("lena %d != leni %d", len(as.addrs), len(as.ipPorts)))
}
if n1, n2 := as.roamAddr != nil, as.roamAddrStd != nil; n1 != n2 {
panic(fmt.Sprintf("roamnil %v != roamstdnil %v", n1, n2))
}
// Spray logic.
//
// After exchanging a handshake with a peer, we send some outbound
@ -320,8 +311,8 @@ func (as *addrSet) appendDests(dsts []netaddr.IPPort, b []byte) (_ []netaddr.IPP
// roamAddr should be special like this.
dsts = append(dsts, *as.roamAddr)
case as.curAddr != -1:
if as.curAddr >= len(as.addrs) {
as.Logf("[unexpected] magicsock bug: as.curAddr >= len(as.addrs): %d >= %d", as.curAddr, len(as.addrs))
if as.curAddr >= len(as.ipPorts) {
as.Logf("[unexpected] magicsock bug: as.curAddr >= len(as.ipPorts): %d >= %d", as.curAddr, len(as.ipPorts))
break
}
// No roaming addr, but we've seen packets from a known peer
@ -352,15 +343,14 @@ func (as *addrSet) appendDests(dsts []netaddr.IPPort, b []byte) (_ []netaddr.IPP
type addrSet struct {
publicKey key.Public // peer public key used for DERP communication
// addrs is an ordered priority list provided by wgengine,
// ipPorts is an ordered priority list provided by wgengine,
// sorted from expensive+slow+reliable at the begnining to
// fast+cheap at the end. More concretely, it's typically:
//
// [DERP fakeip:node, Global IP:port, LAN ip:port]
//
// But there could be multiple or none of each.
addrs []net.UDPAddr
ipPorts []netaddr.IPPort // same as addrs, in different form
ipPorts []netaddr.IPPort
// clock, if non-nil, is used in tests instead of time.Now.
clock func() time.Time
@ -376,8 +366,7 @@ type addrSet struct {
// this should hopefully never be used (or at least used
// rarely) in the case that all the components of Tailscale
// are correctly learning/sharing the network map details.
roamAddr *netaddr.IPPort
roamAddrStd *net.UDPAddr
roamAddr *netaddr.IPPort
// curAddr is an index into addrs of the highest-priority
// address a valid packet has been received from so far.
@ -400,9 +389,9 @@ type addrSet struct {
// derpID returns this addrSet's home DERP node, or 0 if none is found.
func (as *addrSet) derpID() int {
for _, ua := range as.addrs {
if ua.IP.Equal(derpMagicIP) {
return ua.Port
for _, ua := range as.ipPorts {
if ua.IP == derpMagicIPAddr {
return int(ua.Port)
}
}
return 0
@ -424,7 +413,7 @@ func (a *addrSet) dst() netaddr.IPPort {
if a.roamAddr != nil {
return *a.roamAddr
}
if len(a.addrs) == 0 {
if len(a.ipPorts) == 0 {
return noAddr
}
i := a.curAddr
@ -439,7 +428,7 @@ func (a *addrSet) DstToBytes() []byte {
}
func (a *addrSet) DstToString() string {
var addrs []string
for _, addr := range a.addrs {
for _, addr := range a.ipPorts {
addrs = append(addrs, addr.String())
}
@ -459,8 +448,8 @@ func (a *addrSet) ClearSrc() {}
// updateDst records receipt of a packet from new. This is used to
// potentially update the transmit address used for this addrSet.
func (a *addrSet) updateDst(new *net.UDPAddr) error {
if new.IP.Equal(derpMagicIP) {
func (a *addrSet) updateDst(new netaddr.IPPort) error {
if new.IP == derpMagicIPAddr {
// Never consider DERP addresses as a viable candidate for
// either curAddr or roamAddr. It's only ever a last resort
// choice, never a preferred choice.
@ -471,25 +460,20 @@ func (a *addrSet) updateDst(new *net.UDPAddr) error {
a.mu.Lock()
defer a.mu.Unlock()
if a.roamAddrStd != nil && equalUDPAddr(new, a.roamAddrStd) {
if a.roamAddr != nil && new == *a.roamAddr {
// Packet from the current roaming address, no logging.
// This is a hot path for established connections.
return nil
}
if a.roamAddr == nil && a.curAddr >= 0 && equalUDPAddr(new, &a.addrs[a.curAddr]) {
if a.roamAddr == nil && a.curAddr >= 0 && new == a.ipPorts[a.curAddr] {
// Packet from current-priority address, no logging.
// This is a hot path for established connections.
return nil
}
newa, ok := netaddr.FromStdAddr(new.IP, new.Port, new.Zone)
if !ok {
return nil
}
index := -1
for i := range a.addrs {
if equalUDPAddr(new, &a.addrs[i]) {
for i := range a.ipPorts {
if new == a.ipPorts[i] {
index = i
break
}
@ -499,7 +483,7 @@ func (a *addrSet) updateDst(new *net.UDPAddr) error {
pk := publicKey.ShortString()
old := "<none>"
if a.curAddr >= 0 {
old = a.addrs[a.curAddr].String()
old = a.ipPorts[a.curAddr].String()
}
switch {
@ -509,18 +493,16 @@ func (a *addrSet) updateDst(new *net.UDPAddr) error {
} else {
a.Logf("magicsock: rx %s from roaming address %s, replaces roaming address %s", pk, new, a.roamAddr)
}
a.roamAddr = &newa
a.roamAddrStd = new
a.roamAddr = &new
case a.roamAddr != nil:
a.Logf("magicsock: rx %s from known %s (%d), replaces roaming address %s", pk, new, index, a.roamAddr)
a.roamAddr = nil
a.roamAddrStd = nil
a.curAddr = index
a.loggedLogPriMask = 0
case a.curAddr == -1:
a.Logf("magicsock: rx %s from %s (%d/%d), set as new priority", pk, new, index, len(a.addrs))
a.Logf("magicsock: rx %s from %s (%d/%d), set as new priority", pk, new, index, len(a.ipPorts))
a.curAddr = index
a.loggedLogPriMask = 0
@ -531,7 +513,7 @@ func (a *addrSet) updateDst(new *net.UDPAddr) error {
}
default: // index > a.curAddr
a.Logf("magicsock: rx %s from %s (%d/%d), replaces old priority %s", pk, new, index, len(a.addrs), old)
a.Logf("magicsock: rx %s from %s (%d/%d), replaces old priority %s", pk, new, index, len(a.ipPorts), old)
a.curAddr = index
a.loggedLogPriMask = 0
}
@ -539,10 +521,6 @@ func (a *addrSet) updateDst(new *net.UDPAddr) error {
return nil
}
func equalUDPAddr(x, y *net.UDPAddr) bool {
return x.Port == y.Port && x.IP.Equal(y.IP)
}
func (a *addrSet) String() string {
a.mu.Lock()
defer a.mu.Unlock()
@ -551,9 +529,9 @@ func (a *addrSet) String() string {
buf.WriteByte('[')
if a.roamAddr != nil {
buf.WriteString("roam:")
sbPrintAddr(buf, *a.roamAddrStd)
sbPrintAddr(buf, *a.roamAddr)
}
for i, addr := range a.addrs {
for i, addr := range a.ipPorts {
if i > 0 || a.roamAddr != nil {
buf.WriteString(", ")
}
@ -572,8 +550,8 @@ func (as *addrSet) populatePeerStatus(ps *ipnstate.PeerStatus) {
defer as.mu.Unlock()
ps.LastWrite = as.lastSend
for i, ua := range as.addrs {
if ua.IP.Equal(derpMagicIP) {
for i, ua := range as.ipPorts {
if ua.IP == derpMagicIPAddr {
continue
}
uaStr := ua.String()
@ -583,7 +561,7 @@ func (as *addrSet) populatePeerStatus(ps *ipnstate.PeerStatus) {
}
}
if as.roamAddr != nil {
ps.CurAddr = udpAddrDebugString(*as.roamAddrStd)
ps.CurAddr = ippDebugString(*as.roamAddr)
}
}

@ -348,8 +348,7 @@ func (c *Conn) addDerpPeerRoute(peer key.Public, derpID int, dc *derphttp.Client
// Mnemonic: 3.3.40 are numbers above the keys D, E, R, P.
const DerpMagicIP = "127.3.3.40"
var derpMagicIP = net.ParseIP(DerpMagicIP).To4()
var derpMagicIPAddr = netaddr.IPv4(127, 3, 3, 40)
var derpMagicIPAddr = netaddr.MustParseIP(DerpMagicIP)
// activeDerp contains fields for an active DERP connection.
type activeDerp struct {
@ -1539,7 +1538,6 @@ func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan
// findEndpoint maps from a UDP address to a WireGuard endpoint, for
// ReceiveIPv4/ReceiveIPv6.
// The provided addr and ipp must match.
//
// TODO(bradfitz): add a fast path that returns nil here for normal
// wireguard-go transport packets; wireguard-go only uses this
@ -1547,7 +1545,7 @@ func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan
// Endpoint to find the UDPAddr to return to wireguard anyway, so no
// benefit unless we can, say, always return the same fake UDPAddr for
// all packets.
func (c *Conn) findEndpoint(ipp netaddr.IPPort, addr *net.UDPAddr, packet []byte) conn.Endpoint {
func (c *Conn) findEndpoint(ipp netaddr.IPPort, packet []byte) conn.Endpoint {
c.mu.Lock()
defer c.mu.Unlock()
@ -1559,10 +1557,7 @@ func (c *Conn) findEndpoint(ipp netaddr.IPPort, addr *net.UDPAddr, packet []byte
}
}
if addr == nil {
addr = ipp.UDPAddr()
}
return c.findLegacyEndpointLocked(ipp, addr, packet)
return c.findLegacyEndpointLocked(ipp, packet)
}
// aLongTimeAgo is a non-zero time, far in the past, used for
@ -1590,7 +1585,12 @@ func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, error) {
if err != nil {
return 0, nil, err
}
if ep, ok := c.receiveIP(b[:n], pAddr.(*net.UDPAddr), &c.ippEndpoint6); ok {
udpAddr := pAddr.(*net.UDPAddr)
ipp, ok := netaddr.FromStdAddr(udpAddr.IP, udpAddr.Port, udpAddr.Zone)
if !ok {
continue
}
if ep, ok := c.receiveIP(b[:n], ipp, &c.ippEndpoint6); ok {
return n, ep, nil
}
}
@ -1604,13 +1604,16 @@ func (c *Conn) derpPacketArrived() bool {
// In Tailscale's case, that packet might also arrive via DERP. A DERP packet arrival
// aborts the pconn4 read deadline to make it fail.
func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
var pAddr net.Addr
var addr net.Addr
var pAddr *net.UDPAddr
var ipp netaddr.IPPort
var ippOK bool
for {
// Drain DERP queues before reading new UDP packets.
if c.derpPacketArrived() {
goto ReadDERP
}
n, pAddr, err = c.pconn4.ReadFrom(b)
n, addr, err = c.pconn4.ReadFrom(b)
if err != nil {
// If the pconn4 read failed, the likely reason is a DERP reader received
// a packet and interrupted us.
@ -1622,7 +1625,12 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
}
return 0, nil, err
}
if ep, ok := c.receiveIP(b[:n], pAddr.(*net.UDPAddr), &c.ippEndpoint4); ok {
pAddr, _ = addr.(*net.UDPAddr)
ipp, ippOK = netaddr.FromStdAddr(pAddr.IP, pAddr.Port, pAddr.Zone)
if !ippOK {
continue
}
if ep, ok := c.receiveIP(b[:n], ipp, &c.ippEndpoint4); ok {
return n, ep, nil
} else {
continue
@ -1640,11 +1648,7 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
//
// ok is whether this read should be reported up to wireguard-go (our
// caller).
func (c *Conn) receiveIP(b []byte, ua *net.UDPAddr, cache *ippEndpointCache) (ep conn.Endpoint, ok bool) {
ipp, ok := netaddr.FromStdAddr(ua.IP, ua.Port, ua.Zone)
if !ok {
return nil, false
}
func (c *Conn) receiveIP(b []byte, ipp netaddr.IPPort, cache *ippEndpointCache) (ep conn.Endpoint, ok bool) {
if stun.Is(b) {
c.stunReceiveFunc.Load().(func([]byte, netaddr.IPPort))(b, ipp)
return nil, false
@ -1662,7 +1666,7 @@ func (c *Conn) receiveIP(b []byte, ua *net.UDPAddr, cache *ippEndpointCache) (ep
if cache.ipp == ipp && cache.de != nil && cache.gen == cache.de.numStopAndReset() {
ep = cache.de
} else {
ep = c.findEndpoint(ipp, ua, b)
ep = c.findEndpoint(ipp, b)
if ep == nil {
return nil, false
}
@ -1759,7 +1763,7 @@ func (c *Conn) receiveIPv4DERP(b []byte) (n int, ep conn.Endpoint, err error) {
} else {
key := wgkey.Key(dm.src)
c.logf("magicsock: DERP packet from unknown key: %s", key.ShortString())
ep = c.findEndpoint(ipp, nil, b[:n])
ep = c.findEndpoint(ipp, b[:n])
if ep == nil {
return 0, nil, errLoopAgain
}
@ -2833,8 +2837,8 @@ func peerShort(k key.Public) string {
return k2.ShortString()
}
func sbPrintAddr(sb *strings.Builder, a net.UDPAddr) {
is6 := a.IP.To4() == nil
func sbPrintAddr(sb *strings.Builder, a netaddr.IPPort) {
is6 := a.IP.Is6()
if is6 {
sb.WriteByte('[')
}
@ -2931,8 +2935,8 @@ func (c *Conn) UpdateStatus(sb *ipnstate.StatusBuilder) {
})
}
func udpAddrDebugString(ua net.UDPAddr) string {
if ua.IP.Equal(derpMagicIP) {
func ippDebugString(ua netaddr.IPPort) string {
if ua.IP == derpMagicIPAddr {
return fmt.Sprintf("derp-%d", ua.Port)
}
return ua.String()

@ -398,18 +398,6 @@ func pickPort(t testing.TB) uint16 {
return uint16(conn.LocalAddr().(*net.UDPAddr).Port)
}
func TestDerpIPConstant(t *testing.T) {
tstest.PanicOnLog()
tstest.ResourceCheck(t)
if DerpMagicIP != derpMagicIP.String() {
t.Errorf("str %q != IP %v", DerpMagicIP, derpMagicIP)
}
if len(derpMagicIP) != 4 {
t.Errorf("derpMagicIP is len %d; want 4", len(derpMagicIP))
}
}
func TestPickDERPFallback(t *testing.T) {
tstest.PanicOnLog()
tstest.ResourceCheck(t)
@ -452,7 +440,7 @@ func TestPickDERPFallback(t *testing.T) {
// But move if peers are elsewhere.
const otherNode = 789
c.addrsByKey = map[key.Public]*addrSet{
key.Public{1}: &addrSet{addrs: []net.UDPAddr{{IP: derpMagicIP, Port: otherNode}}},
key.Public{1}: &addrSet{ipPorts: []netaddr.IPPort{{IP: derpMagicIPAddr, Port: otherNode}}},
}
if got := c.pickDERPFallback(); got != otherNode {
t.Errorf("didn't join peers: got %v; want %v", got, someNode)
@ -1156,20 +1144,13 @@ func TestAddrSet(t *testing.T) {
tstest.ResourceCheck(t)
mustIPPortPtr := func(s string) *netaddr.IPPort {
t.Helper()
ipp, err := netaddr.ParseIPPort(s)
if err != nil {
t.Fatal(err)
}
ipp := netaddr.MustParseIPPort(s)
return &ipp
}
mustUDPAddr := func(s string) *net.UDPAddr {
return mustIPPortPtr(s).UDPAddr()
}
udpAddrs := func(ss ...string) (ret []net.UDPAddr) {
ipps := func(ss ...string) (ret []netaddr.IPPort) {
t.Helper()
for _, s := range ss {
ret = append(ret, *mustUDPAddr(s))
ret = append(ret, netaddr.MustParseIPPort(s))
}
return ret
}
@ -1201,7 +1182,7 @@ func TestAddrSet(t *testing.T) {
// updateDst, if set, does an UpdateDst call and
// b+want are ignored.
updateDst *net.UDPAddr
updateDst *netaddr.IPPort
b []byte
want string // comma-separated
@ -1215,7 +1196,7 @@ func TestAddrSet(t *testing.T) {
{
name: "reg_packet_no_curaddr",
as: &addrSet{
addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"),
ipPorts: ipps("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"),
curAddr: -1, // unknown
roamAddr: nil,
},
@ -1226,7 +1207,7 @@ func TestAddrSet(t *testing.T) {
{
name: "reg_packet_have_curaddr",
as: &addrSet{
addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"),
ipPorts: ipps("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"),
curAddr: 1, // global IP
roamAddr: nil,
},
@ -1237,36 +1218,36 @@ func TestAddrSet(t *testing.T) {
{
name: "reg_packet_have_roamaddr",
as: &addrSet{
addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"),
ipPorts: ipps("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"),
curAddr: 2, // should be ignored
roamAddr: mustIPPortPtr("5.6.7.8:123"),
},
steps: []step{
{b: regPacket, want: "5.6.7.8:123"},
{updateDst: mustUDPAddr("10.0.0.1:123")}, // no more roaming
{updateDst: mustIPPortPtr("10.0.0.1:123")}, // no more roaming
{b: regPacket, want: "10.0.0.1:123"},
},
},
{
name: "start_roaming",
as: &addrSet{
addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"),
ipPorts: ipps("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"),
curAddr: 2,
},
steps: []step{
{b: regPacket, want: "10.0.0.1:123"},
{updateDst: mustUDPAddr("4.5.6.7:123")},
{updateDst: mustIPPortPtr("4.5.6.7:123")},
{b: regPacket, want: "4.5.6.7:123"},
{updateDst: mustUDPAddr("5.6.7.8:123")},
{updateDst: mustIPPortPtr("5.6.7.8:123")},
{b: regPacket, want: "5.6.7.8:123"},
{updateDst: mustUDPAddr("123.45.67.89:123")}, // end roaming
{updateDst: mustIPPortPtr("123.45.67.89:123")}, // end roaming
{b: regPacket, want: "123.45.67.89:123"},
},
},
{
name: "spray_packet",
as: &addrSet{
addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"),
ipPorts: ipps("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"),
curAddr: 2, // should be ignored
roamAddr: mustIPPortPtr("5.6.7.8:123"),
},
@ -1275,19 +1256,19 @@ func TestAddrSet(t *testing.T) {
{advance: 300 * time.Millisecond, b: regPacket, want: "127.3.3.40:1,123.45.67.89:123,10.0.0.1:123,5.6.7.8:123"},
{advance: 300 * time.Millisecond, b: regPacket, want: "127.3.3.40:1,123.45.67.89:123,10.0.0.1:123,5.6.7.8:123"},
{advance: 3, b: regPacket, want: "5.6.7.8:123"},
{advance: 2 * time.Millisecond, updateDst: mustUDPAddr("10.0.0.1:123")},
{advance: 2 * time.Millisecond, updateDst: mustIPPortPtr("10.0.0.1:123")},
{advance: 3, b: regPacket, want: "10.0.0.1:123"},
},
},
{
name: "low_pri",
as: &addrSet{
addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"),
ipPorts: ipps("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"),
curAddr: 2,
},
steps: []step{
{updateDst: mustUDPAddr("123.45.67.89:123")},
{updateDst: mustUDPAddr("123.45.67.89:123")},
{updateDst: mustIPPortPtr("123.45.67.89:123")},
{updateDst: mustIPPortPtr("123.45.67.89:123")},
},
logCheck: func(t *testing.T, logged []byte) {
if n := bytes.Count(logged, []byte(", keeping current ")); n != 1 {
@ -1306,12 +1287,11 @@ func TestAddrSet(t *testing.T) {
t.Logf(format, args...)
}
tt.as.clock = func() time.Time { return faket }
initAddrSet(tt.as)
for i, st := range tt.steps {
faket = faket.Add(st.advance)
if st.updateDst != nil {
if err := tt.as.updateDst(st.updateDst); err != nil {
if err := tt.as.updateDst(*st.updateDst); err != nil {
t.Fatal(err)
}
continue
@ -1328,23 +1308,6 @@ func TestAddrSet(t *testing.T) {
}
}
// initAddrSet initializes fields in the provided incomplete addrSet
// to satisfying invariants within magicsock.
func initAddrSet(as *addrSet) {
if as.roamAddr != nil && as.roamAddrStd == nil {
as.roamAddrStd = as.roamAddr.UDPAddr()
}
if len(as.ipPorts) == 0 {
for _, ua := range as.addrs {
ipp, ok := netaddr.FromStdAddr(ua.IP, ua.Port, ua.Zone)
if !ok {
panic(fmt.Sprintf("bogus UDPAddr %+v", ua))
}
as.ipPorts = append(as.ipPorts, ipp)
}
}
}
func TestDiscoMessage(t *testing.T) {
c := newConn()
c.logf = t.Logf

Loading…
Cancel
Save