wgengine/magicsock: start handling disco message, use netaddr.IPPort more

Updates #483
pull/514/head^2
Brad Fitzpatrick 4 years ago
parent 790ef2bc5f
commit e96f22e560

@ -35,6 +35,7 @@ import (
"tailscale.com/control/controlclient" "tailscale.com/control/controlclient"
"tailscale.com/derp" "tailscale.com/derp"
"tailscale.com/derp/derphttp" "tailscale.com/derp/derphttp"
"tailscale.com/disco"
"tailscale.com/ipn/ipnstate" "tailscale.com/ipn/ipnstate"
"tailscale.com/net/dnscache" "tailscale.com/net/dnscache"
"tailscale.com/net/interfaces" "tailscale.com/net/interfaces"
@ -553,7 +554,7 @@ func (c *Conn) goDerpConnect(node int) {
if node == 0 { if node == 0 {
return return
} }
go c.derpWriteChanOfAddr(&net.UDPAddr{IP: derpMagicIP, Port: node}, key.Public{}) go c.derpWriteChanOfAddr(netaddr.IPPort{IP: derpMagicIPAddr, Port: uint16(node)}, key.Public{})
} }
// determineEndpoints returns the machine's endpoint addresses. It // determineEndpoints returns the machine's endpoint addresses. It
@ -659,17 +660,25 @@ var logPacketDests, _ = strconv.ParseBool(os.Getenv("DEBUG_LOG_PACKET_DESTS"))
const sprayPeriod = 3 * time.Second const sprayPeriod = 3 * time.Second
// appendDests appends to dsts the destinations that b should be // appendDests appends to dsts the destinations that b should be
// written to in order to reach as. Some of the returned UDPAddrs may // written to in order to reach as. Some of the returned IPPorts may
// be fake addrs representing DERP servers. // be fake addrs representing DERP servers.
// //
// It also returns as's current roamAddr, if any. // It also returns as's current roamAddr, if any.
func (as *AddrSet) appendDests(dsts []*net.UDPAddr, b []byte) (_ []*net.UDPAddr, roamAddr *net.UDPAddr) { func (as *AddrSet) appendDests(dsts []netaddr.IPPort, b []byte) (_ []netaddr.IPPort, roamAddr netaddr.IPPort) {
spray := shouldSprayPacket(b) // true for handshakes spray := shouldSprayPacket(b) // true for handshakes
now := as.timeNow() now := as.timeNow()
as.mu.Lock() as.mu.Lock()
defer as.mu.Unlock() defer as.mu.Unlock()
// Some internal invariant checks.
if len(as.addrs) != len(as.ipPorts) {
panic(fmt.Sprintf("lena %d != leni %d", len(as.addrs), len(as.ipPorts)))
}
if n1, n2 := as.roamAddr != nil, as.roamAddrStd != nil; n1 != n2 {
panic(fmt.Sprintf("roamnil %v != roamstdnil %v", n1, n2))
}
// Spray logic. // Spray logic.
// //
// After exchanging a handshake with a peer, we send some outbound // After exchanging a handshake with a peer, we send some outbound
@ -702,17 +711,17 @@ func (as *AddrSet) appendDests(dsts []*net.UDPAddr, b []byte) (_ []*net.UDPAddr,
switch { switch {
case spray: case spray:
// This packet is being sprayed to all addresses. // This packet is being sprayed to all addresses.
for i := range as.addrs { for i := range as.ipPorts {
dsts = append(dsts, &as.addrs[i]) dsts = append(dsts, as.ipPorts[i])
} }
if as.roamAddr != nil { if as.roamAddr != nil {
dsts = append(dsts, as.roamAddr) dsts = append(dsts, *as.roamAddr)
} }
case as.roamAddr != nil: case as.roamAddr != nil:
// We have a roaming address, prefer it over other addrs. // We have a roaming address, prefer it over other addrs.
// TODO(danderson): this is not correct, there's no reason // TODO(danderson): this is not correct, there's no reason
// roamAddr should be special like this. // roamAddr should be special like this.
dsts = append(dsts, as.roamAddr) dsts = append(dsts, *as.roamAddr)
case as.curAddr != -1: case as.curAddr != -1:
if as.curAddr >= len(as.addrs) { if as.curAddr >= len(as.addrs) {
as.Logf("[unexpected] magicsock bug: as.curAddr >= len(as.addrs): %d >= %d", as.curAddr, len(as.addrs)) as.Logf("[unexpected] magicsock bug: as.curAddr >= len(as.addrs): %d >= %d", as.curAddr, len(as.addrs))
@ -720,20 +729,23 @@ func (as *AddrSet) appendDests(dsts []*net.UDPAddr, b []byte) (_ []*net.UDPAddr,
} }
// No roaming addr, but we've seen packets from a known peer // No roaming addr, but we've seen packets from a known peer
// addr, so keep using that one. // addr, so keep using that one.
dsts = append(dsts, &as.addrs[as.curAddr]) dsts = append(dsts, as.ipPorts[as.curAddr])
default: default:
// We know nothing about how to reach this peer, and we're not // We know nothing about how to reach this peer, and we're not
// spraying. Use the first address in the array, which will // spraying. Use the first address in the array, which will
// usually be a DERP address that guarantees connectivity. // usually be a DERP address that guarantees connectivity.
if len(as.addrs) > 0 { if len(as.ipPorts) > 0 {
dsts = append(dsts, &as.addrs[0]) dsts = append(dsts, as.ipPorts[0])
} }
} }
if logPacketDests { if logPacketDests {
as.Logf("spray=%v; roam=%v; dests=%v", spray, as.roamAddr, dsts) as.Logf("spray=%v; roam=%v; dests=%v", spray, as.roamAddr, dsts)
} }
return dsts, as.roamAddr if as.roamAddr != nil {
roamAddr = *as.roamAddr
}
return dsts, roamAddr
} }
var errNoDestinations = errors.New("magicsock: no destinations") var errNoDestinations = errors.New("magicsock: no destinations")
@ -751,12 +763,12 @@ func (c *Conn) Send(b []byte, ep conn.Endpoint) error {
c.logf("magicsock: [unexpected] DERP BUG: attempting to send packet to DERP address %v", addr) c.logf("magicsock: [unexpected] DERP BUG: attempting to send packet to DERP address %v", addr)
return nil return nil
} }
return c.sendUDP(addr, b) return c.sendUDPStd(addr, b)
case *AddrSet: case *AddrSet:
as = v as = v
} }
var addrBuf [8]*net.UDPAddr var addrBuf [8]netaddr.IPPort
dsts, roamAddr := as.appendDests(addrBuf[:0], b) dsts, roamAddr := as.appendDests(addrBuf[:0], b)
if len(dsts) == 0 { if len(dsts) == 0 {
@ -788,30 +800,39 @@ var errConnClosed = errors.New("Conn closed")
var errDropDerpPacket = errors.New("too many DERP packets queued; dropping") var errDropDerpPacket = errors.New("too many DERP packets queued; dropping")
// sendUDP sends UDP packet b to addr. // sendUDP sends UDP packet b to ipp.
func (c *Conn) sendUDP(addr *net.UDPAddr, b []byte) error { func (c *Conn) sendUDP(ipp netaddr.IPPort, b []byte) error {
if addr.IP.To4() != nil { addr := ipp.UDPAddr() // TOOD(bradfitz): add alloc-free netaddr.WriteTo helper
_, err := c.pconn4.WriteTo(b, addr) return c.sendUDPStd(addr, b)
}
func (c *Conn) sendUDPStd(addr *net.UDPAddr, b []byte) (err error) {
switch {
case addr.IP.To4() != nil:
_, err = c.pconn4.WriteTo(b, addr)
if err != nil && c.noV4.Get() { if err != nil && c.noV4.Get() {
return nil return nil
} }
return err case len(addr.IP) == net.IPv6len:
if c.pconn6 == nil {
// ignore IPv6 dest if we don't have an IPv6 address.
return nil
} }
if c.pconn6 != nil { _, err = c.pconn6.WriteTo(b, addr)
_, err := c.pconn6.WriteTo(b, addr)
if err != nil && c.noV6.Get() { if err != nil && c.noV6.Get() {
return nil return nil
} }
return err default:
return errors.New("bogus sendUDPStd addr type")
} }
return nil // ignore IPv6 dest if we don't have an IPv6 address. return err
} }
// sendAddr sends packet b to addr, which is either a real UDP address // sendAddr sends packet b to addr, which is either a real UDP address
// or a fake UDP address representing a DERP server (see derpmap.go). // or a fake UDP address representing a DERP server (see derpmap.go).
// The provided public key identifies the recipient. // The provided public key identifies the recipient.
func (c *Conn) sendAddr(addr *net.UDPAddr, pubKey key.Public, b []byte) error { func (c *Conn) sendAddr(addr netaddr.IPPort, pubKey key.Public, b []byte) error {
if !addr.IP.Equal(derpMagicIP) { if addr.IP != derpMagicIPAddr {
return c.sendUDP(addr, b) return c.sendUDP(addr, b)
} }
@ -857,11 +878,11 @@ var debugUseDerpRoute, _ = strconv.ParseBool(os.Getenv("TS_DEBUG_ENABLE_DERP_ROU
// //
// If peer is non-zero, it can be used to find an active reverse // If peer is non-zero, it can be used to find an active reverse
// path, without using addr. // path, without using addr.
func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- derpWriteRequest { func (c *Conn) derpWriteChanOfAddr(addr netaddr.IPPort, peer key.Public) chan<- derpWriteRequest {
if !addr.IP.Equal(derpMagicIP) { if addr.IP != derpMagicIPAddr {
return nil return nil
} }
regionID := addr.Port regionID := int(addr.Port)
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -964,7 +985,7 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de
} }
go c.runDerpReader(ctx, addr, dc, wg, startGate) go c.runDerpReader(ctx, addr, dc, wg, startGate)
go c.runDerpWriter(ctx, addr, dc, ch, wg, startGate) go c.runDerpWriter(ctx, dc, ch, wg, startGate)
return ad.writeCh return ad.writeCh
} }
@ -1012,7 +1033,7 @@ func (c *Conn) setPeerLastDerpLocked(peer key.Public, regionID, homeID int) {
// get at the packet contents they need to call copyBuf to copy it // get at the packet contents they need to call copyBuf to copy it
// out, which also releases the buffer. // out, which also releases the buffer.
type derpReadResult struct { type derpReadResult struct {
derpAddr *net.UDPAddr regionID int
n int // length of data received n int // length of data received
src key.Public // may be zero until server deployment if v2+ src key.Public // may be zero until server deployment if v2+
// copyBuf is called to copy the data to dst. It returns how // copyBuf is called to copy the data to dst. It returns how
@ -1025,7 +1046,7 @@ var logDerpVerbose, _ = strconv.ParseBool(os.Getenv("DEBUG_DERP_VERBOSE"))
// runDerpReader runs in a goroutine for the life of a DERP // runDerpReader runs in a goroutine for the life of a DERP
// connection, handling received packets. // connection, handling received packets.
func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc *derphttp.Client, wg *syncs.WaitGroupChan, startGate <-chan struct{}) { func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr netaddr.IPPort, dc *derphttp.Client, wg *syncs.WaitGroupChan, startGate <-chan struct{}) {
defer wg.Decr() defer wg.Decr()
defer dc.Close() defer dc.Close()
@ -1036,8 +1057,8 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc
} }
didCopy := make(chan struct{}, 1) didCopy := make(chan struct{}, 1)
regionID := int(derpFakeAddr.Port)
res := derpReadResult{derpAddr: derpFakeAddr} res := derpReadResult{regionID: regionID}
var pkt derp.ReceivedPacket var pkt derp.ReceivedPacket
res.copyBuf = func(dst []byte) int { res.copyBuf = func(dst []byte) int {
n := copy(dst, pkt.Data) n := copy(dst, pkt.Data)
@ -1058,7 +1079,7 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc
// Forget that all these peers have routes. // Forget that all these peers have routes.
for peer := range peerPresent { for peer := range peerPresent {
delete(peerPresent, peer) delete(peerPresent, peer)
c.removeDerpPeerRoute(peer, derpFakeAddr.Port, dc) c.removeDerpPeerRoute(peer, regionID, dc)
} }
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -1066,7 +1087,7 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc
default: default:
} }
c.ReSTUN("derp-close") c.ReSTUN("derp-close")
c.logf("magicsock: [%p] derp.Recv(derp-%d): %v", dc, derpFakeAddr.Port, err) c.logf("magicsock: [%p] derp.Recv(derp-%d): %v", dc, regionID, err)
time.Sleep(250 * time.Millisecond) time.Sleep(250 * time.Millisecond)
continue continue
} }
@ -1076,13 +1097,13 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc
res.n = len(m.Data) res.n = len(m.Data)
res.src = m.Source res.src = m.Source
if logDerpVerbose { if logDerpVerbose {
c.logf("magicsock: got derp-%v packet: %q", derpFakeAddr, m.Data) c.logf("magicsock: got derp-%v packet: %q", regionID, m.Data)
} }
// If this is a new sender we hadn't seen before, remember it and // If this is a new sender we hadn't seen before, remember it and
// register a route for this peer. // register a route for this peer.
if _, ok := peerPresent[m.Source]; !ok { if _, ok := peerPresent[m.Source]; !ok {
peerPresent[m.Source] = true peerPresent[m.Source] = true
c.addDerpPeerRoute(m.Source, derpFakeAddr.Port, dc) c.addDerpPeerRoute(m.Source, regionID, dc)
} }
default: default:
// Ignore. // Ignore.
@ -1099,14 +1120,14 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc
} }
type derpWriteRequest struct { type derpWriteRequest struct {
addr *net.UDPAddr addr netaddr.IPPort
pubKey key.Public pubKey key.Public
b []byte // copied; ownership passed to receiver b []byte // copied; ownership passed to receiver
} }
// runDerpWriter runs in a goroutine for the life of a DERP // runDerpWriter runs in a goroutine for the life of a DERP
// connection, handling received packets. // connection, handling received packets.
func (c *Conn) runDerpWriter(ctx context.Context, derpFakeAddr *net.UDPAddr, dc *derphttp.Client, ch <-chan derpWriteRequest, wg *syncs.WaitGroupChan, startGate <-chan struct{}) { func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan derpWriteRequest, wg *syncs.WaitGroupChan, startGate <-chan struct{}) {
defer wg.Decr() defer wg.Decr()
select { select {
case <-startGate: case <-startGate:
@ -1185,7 +1206,6 @@ func (c *Conn) awaitUDP4(b []byte) {
continue continue
} }
addr.IP = addr.IP.To4()
select { select {
case c.udpRecvCh <- udpReadResult{n: n, addr: addr}: case c.udpRecvCh <- udpReadResult{n: n, addr: addr}:
case <-c.donec(): case <-c.donec():
@ -1200,12 +1220,13 @@ func (c *Conn) awaitUDP4(b []byte) {
// per peer. // per peer.
func wgRecvAddr(e conn.Endpoint, addr *net.UDPAddr) *net.UDPAddr { func wgRecvAddr(e conn.Endpoint, addr *net.UDPAddr) *net.UDPAddr {
if de, ok := e.(*discoEndpoint); ok { if de, ok := e.(*discoEndpoint); ok {
return de.fakeWGAddr return de.fakeWGAddrStd
} }
return addr return addr
} }
func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr, err error) { func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr, err error) {
Top:
// First, process any buffered packet from earlier. // First, process any buffered packet from earlier.
if addr := c.bufferedIPv4From; addr != nil { if addr := c.bufferedIPv4From; addr != nil {
c.bufferedIPv4From = nil c.bufferedIPv4From = nil
@ -1245,7 +1266,8 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr
case <-c.donec(): case <-c.donec():
return 0, nil, nil, errors.New("Conn closed") return 0, nil, nil, errors.New("Conn closed")
} }
n, addr = dm.n, dm.derpAddr var regionID int
n, regionID = dm.n, dm.regionID
ncopy := dm.copyBuf(b) ncopy := dm.copyBuf(b)
if ncopy != n { if ncopy != n {
err = fmt.Errorf("received DERP packet of length %d that's too big for WireGuard ReceiveIPv4 buf size %d", n, ncopy) err = fmt.Errorf("received DERP packet of length %d that's too big for WireGuard ReceiveIPv4 buf size %d", n, ncopy)
@ -1253,6 +1275,11 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr
return 0, nil, nil, err return 0, nil, nil, err
} }
addr := netaddr.IPPort{IP: derpMagicIPAddr, Port: uint16(regionID)}
if c.handleDiscoMessage(b[:n], addr.UDPAddr()) {
goto Top
}
c.mu.Lock() c.mu.Lock()
if dk, ok := c.discoOfNode[tailcfg.NodeKey(dm.src)]; ok { if dk, ok := c.discoOfNode[tailcfg.NodeKey(dm.src)]; ok {
discoEp = c.endpointOfDisco[dk] discoEp = c.endpointOfDisco[dk]
@ -1311,6 +1338,10 @@ func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, *net.UDPAddr, error) {
c.stunReceiveFunc.Load().(func([]byte, *net.UDPAddr))(b[:n], addr) c.stunReceiveFunc.Load().(func([]byte, *net.UDPAddr))(b[:n], addr)
continue continue
} }
if c.handleDiscoMessage(b[:n], addr) {
continue
}
ep := c.findEndpoint(addr) ep := c.findEndpoint(addr)
return n, ep, wgRecvAddr(ep, addr), nil return n, ep, wgRecvAddr(ep, addr), nil
} }
@ -1324,8 +1355,11 @@ func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, *net.UDPAddr, error) {
// * magic [6]byte // * magic [6]byte
// * senderDiscoPubKey [32]byte // * senderDiscoPubKey [32]byte
// * nonce [24]byte // * nonce [24]byte
// * naclbox of payload // * naclbox of payload (see tailscale.com/disco package for inner payload format)
func (c *Conn) handleDiscoMessage(msg []byte, addr *net.UDPAddr) bool { //
// For messages received over DERP, the addr will be derpMagicIP (with
// port being the region)
func (c *Conn) handleDiscoMessage(msg []byte, src *net.UDPAddr) bool {
const magic = "TS💬" const magic = "TS💬"
const nonceLen = 24 const nonceLen = 24
const headerLen = len(magic) + len(tailcfg.DiscoKey{}) + nonceLen const headerLen = len(magic) + len(tailcfg.DiscoKey{}) + nonceLen
@ -1335,6 +1369,11 @@ func (c *Conn) handleDiscoMessage(msg []byte, addr *net.UDPAddr) bool {
var sender tailcfg.DiscoKey var sender tailcfg.DiscoKey
copy(sender[:], msg[len(magic):]) copy(sender[:], msg[len(magic):])
srca, ok := netaddr.FromStdAddr(src.IP, src.Port, src.Zone)
if !ok {
return false
}
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -1361,10 +1400,57 @@ func (c *Conn) handleDiscoMessage(msg []byte, addr *net.UDPAddr) bool {
return false return false
} }
c.logf("magicsock: got disco message from %s: %x (%q)", senderNode.Key.ShortString(), payload, payload) dm, err := disco.Parse(payload)
if err != nil {
// Couldn't parse it, but it was inside a correctly
// signed box, so just ignore it, assuming it's from a
// newer version of Tailscale that we don't
// understand. Not even worth logging about, lest it
// be too spammy for old clients.
return true
}
switch dm := dm.(type) {
case *disco.Ping:
c.handlePingLocked(dm, senderNode, sender, srca)
case *disco.Pong:
c.handlePongLocked(dm, senderNode, sender, srca)
case disco.CallMeMaybe:
if srca.IP != derpMagicIPAddr {
// CallMeMaybe messages should only come via DERP.
return false
}
c.handleCallMeMaybeLocked(senderNode, sender)
}
return true return true
} }
func (c *Conn) handlePongLocked(m *disco.Pong, n *tailcfg.Node, dk tailcfg.DiscoKey, from netaddr.IPPort) {
c.logf("magicsock: disco: got pong from %s, tx=%x, disco=%x, src=%v (they saw %v)", n.Key.ShortString(), m.TxID, dk[:8], from, m.Src)
// TODO: implement
}
func (c *Conn) handlePingLocked(m *disco.Ping, n *tailcfg.Node, dk tailcfg.DiscoKey, from netaddr.IPPort) {
c.logf("magicsock: disco: got ping from %s, tx=%x, disco=%x, src=%v", n.Key.ShortString(), m.TxID, dk[:8], from)
// TODO: implement
reply := &disco.Pong{
TxID: m.TxID,
Src: from,
}
go c.sendAddr(from, key.Public(n.Key), reply.AppendMarshal(nil))
}
// handleCallMeMaybeLocked is called when a discovery message arrives
// via DERP for us to send to a peer. The contract for use of this
// message is that the peer has already sent to us via UDP, so their
// stateful firewall should be open. Now we can Ping back and make it
// through.
func (c *Conn) handleCallMeMaybeLocked(n *tailcfg.Node, dk tailcfg.DiscoKey) {
c.logf("magicsock: disco: got call-me-maybe packet from %s (disco=%x)", n.Key.ShortString, dk[:8])
// TODO: implement
}
func (c *Conn) sharedDiscoKeyLocked(k tailcfg.DiscoKey) *[32]byte { func (c *Conn) sharedDiscoKeyLocked(k tailcfg.DiscoKey) *[32]byte {
if v, ok := c.sharedDiscoKey[k]; ok { if v, ok := c.sharedDiscoKey[k]; ok {
return v return v
@ -1865,7 +1951,8 @@ type AddrSet struct {
// this should hopefully never be used (or at least used // this should hopefully never be used (or at least used
// rarely) in the case that all the components of Tailscale // rarely) in the case that all the components of Tailscale
// are correctly learning/sharing the network map details. // are correctly learning/sharing the network map details.
roamAddr *net.UDPAddr roamAddr *netaddr.IPPort
roamAddrStd *net.UDPAddr
// curAddr is an index into addrs of the highest-priority // curAddr is an index into addrs of the highest-priority
// address a valid packet has been received from so far. // address a valid packet has been received from so far.
@ -1903,17 +1990,14 @@ func (as *AddrSet) timeNow() time.Time {
return time.Now() return time.Now()
} }
var noAddr = &net.UDPAddr{ var noAddr, _ = netaddr.FromStdAddr(net.ParseIP("127.127.127.127"), 127, "")
IP: net.ParseIP("127.127.127.127"),
Port: 127,
}
func (a *AddrSet) dst() *net.UDPAddr { func (a *AddrSet) dst() netaddr.IPPort {
a.mu.Lock() a.mu.Lock()
defer a.mu.Unlock() defer a.mu.Unlock()
if a.roamAddr != nil { if a.roamAddr != nil {
return a.roamAddr return *a.roamAddr
} }
if len(a.addrs) == 0 { if len(a.addrs) == 0 {
return noAddr return noAddr
@ -1922,7 +2006,7 @@ func (a *AddrSet) dst() *net.UDPAddr {
if i == -1 { if i == -1 {
i = 0 i = 0
} }
return &a.addrs[i] return a.ipPorts[i]
} }
// packUDPAddr packs a UDPAddr in the form wanted by WireGuard. // packUDPAddr packs a UDPAddr in the form wanted by WireGuard.
@ -1938,15 +2022,30 @@ func packUDPAddr(ua *net.UDPAddr) []byte {
return b return b
} }
// packIPPort packs an IPPort into the form wanted by WireGuard.
func packIPPort(ua netaddr.IPPort) []byte {
ip := ua.IP.Unmap()
a := ip.As16()
ipb := a[:]
if ip.Is4() {
ipb = ipb[12:]
}
b := make([]byte, 0, len(ipb)+2)
b = append(b, ipb...)
b = append(b, byte(ua.Port))
b = append(b, byte(ua.Port>>8))
return b
}
func (a *AddrSet) DstToBytes() []byte { func (a *AddrSet) DstToBytes() []byte {
return packUDPAddr(a.dst()) return packIPPort(a.dst())
} }
func (a *AddrSet) DstToString() string { func (a *AddrSet) DstToString() string {
dst := a.dst() dst := a.dst()
return dst.String() return dst.String()
} }
func (a *AddrSet) DstIP() net.IP { func (a *AddrSet) DstIP() net.IP {
return a.dst().IP return a.dst().IP.IPAddr().IP // TODO: add netaddr accessor to cut an alloc here?
} }
func (a *AddrSet) SrcIP() net.IP { return nil } func (a *AddrSet) SrcIP() net.IP { return nil }
func (a *AddrSet) SrcToString() string { return "" } func (a *AddrSet) SrcToString() string { return "" }
@ -1964,7 +2063,7 @@ func (a *AddrSet) UpdateDst(new *net.UDPAddr) error {
a.mu.Lock() a.mu.Lock()
defer a.mu.Unlock() defer a.mu.Unlock()
if a.roamAddr != nil && equalUDPAddr(new, a.roamAddr) { if a.roamAddrStd != nil && equalUDPAddr(new, a.roamAddrStd) {
// Packet from the current roaming address, no logging. // Packet from the current roaming address, no logging.
// This is a hot path for established connections. // This is a hot path for established connections.
return nil return nil
@ -1975,6 +2074,11 @@ func (a *AddrSet) UpdateDst(new *net.UDPAddr) error {
return nil return nil
} }
newa, ok := netaddr.FromStdAddr(new.IP, new.Port, new.Zone)
if !ok {
return nil
}
index := -1 index := -1
for i := range a.addrs { for i := range a.addrs {
if equalUDPAddr(new, &a.addrs[i]) { if equalUDPAddr(new, &a.addrs[i]) {
@ -1997,11 +2101,13 @@ func (a *AddrSet) UpdateDst(new *net.UDPAddr) error {
} else { } else {
a.Logf("magicsock: rx %s from roaming address %s, replaces roaming address %s", pk, new, a.roamAddr) a.Logf("magicsock: rx %s from roaming address %s, replaces roaming address %s", pk, new, a.roamAddr)
} }
a.roamAddr = new a.roamAddr = &newa
a.roamAddrStd = new
case a.roamAddr != nil: case a.roamAddr != nil:
a.Logf("magicsock: rx %s from known %s (%d), replaces roaming address %s", pk, new, index, a.roamAddr) a.Logf("magicsock: rx %s from known %s (%d), replaces roaming address %s", pk, new, index, a.roamAddr)
a.roamAddr = nil a.roamAddr = nil
a.roamAddrStd = nil
a.curAddr = index a.curAddr = index
a.loggedLogPriMask = 0 a.loggedLogPriMask = 0
@ -2037,7 +2143,7 @@ func (a *AddrSet) String() string {
buf.WriteByte('[') buf.WriteByte('[')
if a.roamAddr != nil { if a.roamAddr != nil {
buf.WriteString("roam:") buf.WriteString("roam:")
sbPrintAddr(buf, *a.roamAddr) sbPrintAddr(buf, *a.roamAddrStd)
} }
for i, addr := range a.addrs { for i, addr := range a.addrs {
if i > 0 || a.roamAddr != nil { if i > 0 || a.roamAddr != nil {
@ -2325,7 +2431,7 @@ func (c *Conn) UpdateStatus(sb *ipnstate.StatusBuilder) {
} }
} }
if as.roamAddr != nil { if as.roamAddr != nil {
ps.CurAddr = udpAddrDebugString(*as.roamAddr) ps.CurAddr = udpAddrDebugString(*as.roamAddrStd)
} }
sb.AddPeer(k, ps) sb.AddPeer(k, ps)
} }
@ -2349,11 +2455,12 @@ type discoEndpoint struct {
c *Conn c *Conn
publicKey key.Public // peer public key (for WireGuard + DERP) publicKey key.Public // peer public key (for WireGuard + DERP)
discoKey tailcfg.DiscoKey // for discovery mesages discoKey tailcfg.DiscoKey // for discovery mesages
fakeWGAddr *net.UDPAddr // the UDPAddr we tell wireguard-go we're using fakeWGAddr netaddr.IPPort // the UDP address we tell wireguard-go we're using
fakeWGAddrStd *net.UDPAddr // the *net.UDPAddr form of fakeWGAddr
wgEndpointHostPort string // string from CreateEndpoint: "<hex-discovery-key>.disco.tailscale:12345" wgEndpointHostPort string // string from CreateEndpoint: "<hex-discovery-key>.disco.tailscale:12345"
mu sync.Mutex // Lock ordering: Conn.mu, then discoEndpoint.mu mu sync.Mutex // Lock ordering: Conn.mu, then discoEndpoint.mu
derpAddr *net.UDPAddr derpAddr netaddr.IPPort
} }
// initFakeUDPAddr populates fakeWGAddr with a globally unique fake UDPAddr. // initFakeUDPAddr populates fakeWGAddr with a globally unique fake UDPAddr.
@ -2364,11 +2471,11 @@ func (de *discoEndpoint) initFakeUDPAddr() {
addr[0] = 0xfd addr[0] = 0xfd
addr[1] = 0x00 addr[1] = 0x00
binary.BigEndian.PutUint64(addr[2:], uint64(reflect.ValueOf(de).Pointer())) binary.BigEndian.PutUint64(addr[2:], uint64(reflect.ValueOf(de).Pointer()))
ipp := netaddr.IPPort{ de.fakeWGAddr = netaddr.IPPort{
IP: netaddr.IPFrom16(addr), IP: netaddr.IPFrom16(addr),
Port: 12345, Port: 12345,
} }
de.fakeWGAddr = ipp.UDPAddr() de.fakeWGAddrStd = de.fakeWGAddr.UDPAddr()
} }
func (de *discoEndpoint) Addrs() []wgcfg.Endpoint { func (de *discoEndpoint) Addrs() []wgcfg.Endpoint {
@ -2391,7 +2498,7 @@ func (de *discoEndpoint) SrcToString() string { panic("unused") } // unused by w
func (de *discoEndpoint) SrcIP() net.IP { panic("unused") } // unused by wireguard-go func (de *discoEndpoint) SrcIP() net.IP { panic("unused") } // unused by wireguard-go
func (de *discoEndpoint) DstToString() string { return de.wgEndpointHostPort } func (de *discoEndpoint) DstToString() string { return de.wgEndpointHostPort }
func (de *discoEndpoint) DstIP() net.IP { panic("unused") } func (de *discoEndpoint) DstIP() net.IP { panic("unused") }
func (de *discoEndpoint) DstToBytes() []byte { return de.fakeWGAddr.IP[:] } func (de *discoEndpoint) DstToBytes() []byte { return packIPPort(de.fakeWGAddr) }
func (de *discoEndpoint) UpdateDst(addr *net.UDPAddr) error { func (de *discoEndpoint) UpdateDst(addr *net.UDPAddr) error {
// This is called ~per packet (and requiring a mutex acquisition inside wireguard-go). // This is called ~per packet (and requiring a mutex acquisition inside wireguard-go).
// TODO(bradfitz): make that cheaper and/or remove it. We don't need it. // TODO(bradfitz): make that cheaper and/or remove it. We don't need it.
@ -2406,7 +2513,7 @@ func (de *discoEndpoint) send(b []byte) error {
derpAddr := de.derpAddr derpAddr := de.derpAddr
de.mu.Unlock() de.mu.Unlock()
if derpAddr == nil { if derpAddr.Port == 0 {
return errors.New("no DERP addr") return errors.New("no DERP addr")
} }
return de.c.sendAddr(derpAddr, de.publicKey, b) return de.c.sendAddr(derpAddr, de.publicKey, b)
@ -2421,12 +2528,9 @@ func (de *discoEndpoint) updateFromNode(n *tailcfg.Node) {
defer de.mu.Unlock() defer de.mu.Unlock()
if n.DERP == "" { if n.DERP == "" {
de.derpAddr = nil de.derpAddr = netaddr.IPPort{}
} else { } else {
// TODO: add ParseIPPort to netaddr package; only safe to use ResolveUDPAddr de.derpAddr, _ = netaddr.ParseIPPort(n.DERP)
// here because we know no DNS lookups are involved
ua, _ := net.ResolveUDPAddr("udp", n.DERP)
de.derpAddr = ua
} }
// TODO: parse all the endpoints, not just DERP // TODO: parse all the endpoints, not just DERP

@ -24,6 +24,7 @@ import (
"github.com/tailscale/wireguard-go/tun/tuntest" "github.com/tailscale/wireguard-go/tun/tuntest"
"github.com/tailscale/wireguard-go/wgcfg" "github.com/tailscale/wireguard-go/wgcfg"
"golang.org/x/crypto/nacl/box" "golang.org/x/crypto/nacl/box"
"inet.af/netaddr"
"tailscale.com/derp" "tailscale.com/derp"
"tailscale.com/derp/derphttp" "tailscale.com/derp/derphttp"
"tailscale.com/derp/derpmap" "tailscale.com/derp/derpmap"
@ -662,17 +663,16 @@ func TestAddrSet(t *testing.T) {
rc := tstest.NewResourceCheck() rc := tstest.NewResourceCheck()
defer rc.Assert(t) defer rc.Assert(t)
// This gets reassigned inside every test, so that the connections mustIPPortPtr := func(s string) *netaddr.IPPort {
// all log using the "current" t.Logf function. Sigh.
logf, setT := makeNestable(t)
mustUDPAddr := func(s string) *net.UDPAddr {
t.Helper() t.Helper()
ua, err := net.ResolveUDPAddr("udp", s) ipp, err := netaddr.ParseIPPort(s)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
return ua return &ipp
}
mustUDPAddr := func(s string) *net.UDPAddr {
return mustIPPortPtr(s).UDPAddr()
} }
udpAddrs := func(ss ...string) (ret []net.UDPAddr) { udpAddrs := func(ss ...string) (ret []net.UDPAddr) {
t.Helper() t.Helper()
@ -681,7 +681,7 @@ func TestAddrSet(t *testing.T) {
} }
return ret return ret
} }
joinUDPs := func(in []*net.UDPAddr) string { joinUDPs := func(in []netaddr.IPPort) string {
var sb strings.Builder var sb strings.Builder
for i, ua := range in { for i, ua := range in {
if i > 0 { if i > 0 {
@ -747,7 +747,7 @@ func TestAddrSet(t *testing.T) {
as: &AddrSet{ as: &AddrSet{
addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"),
curAddr: 2, // should be ignored curAddr: 2, // should be ignored
roamAddr: mustUDPAddr("5.6.7.8:123"), roamAddr: mustIPPortPtr("5.6.7.8:123"),
}, },
steps: []step{ steps: []step{
{b: regPacket, want: "5.6.7.8:123"}, {b: regPacket, want: "5.6.7.8:123"},
@ -776,7 +776,7 @@ func TestAddrSet(t *testing.T) {
as: &AddrSet{ as: &AddrSet{
addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"),
curAddr: 2, // should be ignored curAddr: 2, // should be ignored
roamAddr: mustUDPAddr("5.6.7.8:123"), roamAddr: mustIPPortPtr("5.6.7.8:123"),
}, },
steps: []step{ steps: []step{
{b: sprayPacket, want: "127.3.3.40:1,123.45.67.89:123,10.0.0.1:123,5.6.7.8:123"}, {b: sprayPacket, want: "127.3.3.40:1,123.45.67.89:123,10.0.0.1:123,5.6.7.8:123"},
@ -804,18 +804,17 @@ func TestAddrSet(t *testing.T) {
}, },
}, },
} }
outerT := t
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
setT(t)
defer setT(outerT)
faket := time.Unix(0, 0) faket := time.Unix(0, 0)
var logBuf bytes.Buffer var logBuf bytes.Buffer
tt.as.Logf = func(format string, args ...interface{}) { tt.as.Logf = func(format string, args ...interface{}) {
fmt.Fprintf(&logBuf, format, args...) fmt.Fprintf(&logBuf, format, args...)
logf(format, args...) t.Logf(format, args...)
} }
tt.as.clock = func() time.Time { return faket } tt.as.clock = func() time.Time { return faket }
initAddrSet(tt.as)
for i, st := range tt.steps { for i, st := range tt.steps {
faket = faket.Add(st.advance) faket = faket.Add(st.advance)
@ -837,6 +836,23 @@ func TestAddrSet(t *testing.T) {
} }
} }
// initAddrSet initializes fields in the provided incomplete AddrSet
// to satisfying invariants within magicsock.
func initAddrSet(as *AddrSet) {
if as.roamAddr != nil && as.roamAddrStd == nil {
as.roamAddrStd = as.roamAddr.UDPAddr()
}
if len(as.ipPorts) == 0 {
for _, ua := range as.addrs {
ipp, ok := netaddr.FromStdAddr(ua.IP, ua.Port, ua.Zone)
if !ok {
panic(fmt.Sprintf("bogus UDPAddr %+v", ua))
}
as.ipPorts = append(as.ipPorts, ipp)
}
}
}
func TestDiscoMessage(t *testing.T) { func TestDiscoMessage(t *testing.T) {
peer1Priv := key.NewPrivate() peer1Priv := key.NewPrivate()
peer1Pub := peer1Priv.Public() peer1Pub := peer1Priv.Public()
@ -857,7 +873,7 @@ func TestDiscoMessage(t *testing.T) {
pkt = append(pkt, nonce[:]...) pkt = append(pkt, nonce[:]...)
pkt = box.Seal(pkt, []byte(payload), &nonce, c.discoPrivate.Public().B32(), peer1Priv.B32()) pkt = box.Seal(pkt, []byte(payload), &nonce, c.discoPrivate.Public().B32(), peer1Priv.B32())
got := c.handleDiscoMessage(pkt, &net.UDPAddr{}) got := c.handleDiscoMessage(pkt, &net.UDPAddr{IP: net.ParseIP("1.2.3.4")})
if !got { if !got {
t.Error("failed to open it") t.Error("failed to open it")
} }

Loading…
Cancel
Save