diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 14d60cc15..72651bd41 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -28,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" "inet.af/netaddr" + "tailscale.com/control/controlclient" "tailscale.com/net/packet" "tailscale.com/types/logger" "tailscale.com/wgengine" @@ -62,7 +63,46 @@ func Impl(logf logger.Logf, tundev *tstun.TUN, e wgengine.Engine, mc *magicsock. log.Fatal(err) } - ipstack.AddAddress(nicID, ipv4.ProtocolNumber, tcpip.Address(net.ParseIP("100.96.188.101").To4())) + e.AddNetworkMapCallback(func(nm *controlclient.NetworkMap) { + oldIPs := make(map[tcpip.Address]bool) + for _, ip := range ipstack.AllAddresses()[nicID] { + oldIPs[ip.AddressWithPrefix.Address] = true + } + newIPs := make(map[tcpip.Address]bool) + for _, ip := range nm.Addresses { + newIPs[tcpip.Address(ip.IPNet().IP)] = true + } + + ipsToBeAdded := make(map[tcpip.Address]bool) + for ip := range newIPs { + if !oldIPs[ip] { + ipsToBeAdded[ip] = true + } + } + ipsToBeRemoved := make(map[tcpip.Address]bool) + for ip := range oldIPs { + if !newIPs[ip] { + ipsToBeRemoved[ip] = true + } + } + + for ip := range ipsToBeRemoved { + err := ipstack.RemoveAddress(nicID, ip) + if err != nil { + logf("netstack: could not deregister IP %s: %v", ip, err) + } else { + logf("netstack: deregistered IP %s", ip) + } + } + for ip := range ipsToBeAdded { + err := ipstack.AddAddress(nicID, ipv4.ProtocolNumber, ip) + if err != nil { + logf("netstack: could not register IP %s: %v", ip, err) + } else { + logf("netstack: registered IP %s", ip) + } + } + }) // Add 0.0.0.0/0 default route. subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 4)), tcpip.AddressMask(strings.Repeat("\x00", 4))) diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 2b77b4570..d15579080 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -111,15 +111,16 @@ type userspaceEngine struct { sentActivityAt map[netaddr.IP]*int64 // value is atomic int64 of unixtime destIPActivityFuncs map[netaddr.IP]func() - mu sync.Mutex // guards following; see lock order comment below - closing bool // Close was called (even if we're still closing) - statusCallback StatusCallback - linkChangeCallback func(major bool, newState *interfaces.State) - peerSequence []wgkey.Key - endpoints []string - pingers map[wgkey.Key]*pinger // legacy pingers for pre-discovery peers - linkState *interfaces.State - pendOpen map[flowtrack.Tuple]*pendingOpenFlow // see pendopen.go + mu sync.Mutex // guards following; see lock order comment below + closing bool // Close was called (even if we're still closing) + statusCallback StatusCallback + linkChangeCallback func(major bool, newState *interfaces.State) + peerSequence []wgkey.Key + endpoints []string + pingers map[wgkey.Key]*pinger // legacy pingers for pre-discovery peers + linkState *interfaces.State + pendOpen map[flowtrack.Tuple]*pendingOpenFlow // see pendopen.go + networkMapCallbacks map[*someHandle]NetworkMapCallback // Lock ordering: magicsock.Conn.mu, wgLock, then mu. } @@ -1290,6 +1291,21 @@ func (e *userspaceEngine) SetLinkChangeCallback(cb func(major bool, newState *in } } +func (e *userspaceEngine) AddNetworkMapCallback(cb NetworkMapCallback) func() { + e.mu.Lock() + defer e.mu.Unlock() + if e.networkMapCallbacks == nil { + e.networkMapCallbacks = make(map[*someHandle]NetworkMapCallback) + } + h := new(someHandle) + e.networkMapCallbacks[h] = cb + return func() { + e.mu.Lock() + defer e.mu.Unlock() + delete(e.networkMapCallbacks, h) + } +} + func getLinkState() (*interfaces.State, error) { s, err := interfaces.GetState() if s != nil { @@ -1308,6 +1324,15 @@ func (e *userspaceEngine) SetDERPMap(dm *tailcfg.DERPMap) { func (e *userspaceEngine) SetNetworkMap(nm *controlclient.NetworkMap) { e.magicConn.SetNetworkMap(nm) + e.mu.Lock() + callbacks := make([]NetworkMapCallback, 0, 4) + for _, fn := range e.networkMapCallbacks { + callbacks = append(callbacks, fn) + } + e.mu.Unlock() + for _, fn := range callbacks { + fn(nm) + } } func (e *userspaceEngine) DiscoPublicKey() tailcfg.DiscoKey { diff --git a/wgengine/watchdog.go b/wgengine/watchdog.go index 6a8f2b698..ee0fc3045 100644 --- a/wgengine/watchdog.go +++ b/wgengine/watchdog.go @@ -110,6 +110,11 @@ func (e *watchdogEngine) SetDERPMap(m *tailcfg.DERPMap) { func (e *watchdogEngine) SetNetworkMap(nm *controlclient.NetworkMap) { e.watchdog("SetNetworkMap", func() { e.wrap.SetNetworkMap(nm) }) } +func (e *watchdogEngine) AddNetworkMapCallback(callback NetworkMapCallback) func() { + var fn func() + e.watchdog("AddNetworkMapCallback", func() { fn = e.wrap.AddNetworkMapCallback(callback) }) + return func() { e.watchdog("RemoveNetworkMapCallback", fn) } +} func (e *watchdogEngine) DiscoPublicKey() (k tailcfg.DiscoKey) { e.watchdog("DiscoPublicKey", func() { k = e.wrap.DiscoPublicKey() }) return k diff --git a/wgengine/wgengine.go b/wgengine/wgengine.go index 57c98c6ea..3139dc1f2 100644 --- a/wgengine/wgengine.go +++ b/wgengine/wgengine.go @@ -49,6 +49,15 @@ type StatusCallback func(*Status, error) // NetInfoCallback is the type used by Engine.SetNetInfoCallback. type NetInfoCallback func(*tailcfg.NetInfo) +// NetworkMapCallback is the type used by callbacks that hook +// into network map updates. +type NetworkMapCallback func(*controlclient.NetworkMap) + +// someHandle is allocated so its pointer address acts as a unique +// map key handle. (It needs to have non-zero size for Go to guarantee +// the pointer is unique.) +type someHandle struct{ _ byte } + // ErrNoChanges is returned by Engine.Reconfig if no changes were made. var ErrNoChanges = errors.New("no changes made to Engine config") @@ -114,6 +123,12 @@ type Engine interface { // The network map should only be read from. SetNetworkMap(*controlclient.NetworkMap) + // AddNetworkMapCallback adds a function to a list of callbacks + // that are called when the network map updates. It returns a + // function that when called would remove the function from the + // list of callbacks. + AddNetworkMapCallback(NetworkMapCallback) (removeCallback func()) + // SetNetInfoCallback sets the function to call when a // new NetInfo summary is available. SetNetInfoCallback(NetInfoCallback)