diff --git a/cmd/tailscale/cli/netcheck.go b/cmd/tailscale/cli/netcheck.go index 90a5815e8..106e085e0 100644 --- a/cmd/tailscale/cli/netcheck.go +++ b/cmd/tailscale/cli/netcheck.go @@ -50,7 +50,7 @@ var netcheckArgs struct { func runNetcheck(ctx context.Context, args []string) error { c := &netcheck.Client{ UDPBindAddr: os.Getenv("TS_DEBUG_NETCHECK_UDP_BIND"), - PortMapper: portmapper.NewClient(logger.WithPrefix(log.Printf, "portmap: ")), + PortMapper: portmapper.NewClient(logger.WithPrefix(log.Printf, "portmap: "), nil), } if netcheckArgs.verbose { c.Logf = logger.WithPrefix(log.Printf, "netcheck: ") diff --git a/net/portmapper/portmapper.go b/net/portmapper/portmapper.go index 3424b075f..0d7d80bac 100644 --- a/net/portmapper/portmapper.go +++ b/net/portmapper/portmapper.go @@ -44,9 +44,15 @@ const trustServiceStillAvailableDuration = 10 * time.Minute type Client struct { logf logger.Logf ipAndGateway func() (gw, ip netaddr.IP, ok bool) + onChange func() // or nil mu sync.Mutex // guards following, and all fields thereof + // runningCreate is whether we're currently working on creating + // a port mapping (whether GetCachedMappingOrStartCreatingOne kicked + // off a createMapping goroutine). + runningCreate bool + lastMyIP netaddr.IP lastGW netaddr.IP closed bool @@ -68,18 +74,19 @@ type Client struct { func (c *Client) HaveMapping() bool { c.mu.Lock() defer c.mu.Unlock() - return c.pmpMapping != nil && c.pmpMapping.useUntil.After(time.Now()) + return c.pmpMapping != nil && c.pmpMapping.goodUntil.After(time.Now()) } // pmpMapping is an already-created PMP mapping. // // All fields are immutable once created. type pmpMapping struct { - gw netaddr.IP - external netaddr.IPPort - internal netaddr.IPPort - useUntil time.Time // the mapping's lifetime minus renewal interval - epoch uint32 + gw netaddr.IP + external netaddr.IPPort + internal netaddr.IPPort + renewAfter time.Time // the time at which we want to renew the mapping + goodUntil time.Time // the mapping's total lifetime + epoch uint32 } // externalValid reports whether m.external is valid, with both its IP and Port populated. @@ -99,10 +106,15 @@ func (m *pmpMapping) release() { } // NewClient returns a new portmapping client. -func NewClient(logf logger.Logf) *Client { +// +// The optional onChange argument specifies a func to run in a new +// goroutine whenever the port mapping status has changed. If nil, +// it doesn't make a callback. +func NewClient(logf logger.Logf, onChange func()) *Client { return &Client{ logf: logf, ipAndGateway: interfaces.LikelyHomeRouterIP, + onChange: onChange, } } @@ -221,8 +233,7 @@ func closeCloserOnContextDone(ctx context.Context, c io.Closer) (stop func()) { return func() { close(stopWaitDone) } } -// NoMappingError is returned by CreateOrGetMapping when no NAT -// mapping could be returned. +// NoMappingError is returned when no NAT mapping could be done. type NoMappingError struct { err error } @@ -241,12 +252,62 @@ var ( ErrGatewayNotFound = errors.New("failed to look up gateway address") ) -// CreateOrGetMapping either creates a new mapping or returns a cached +// GetCachedMappingOrStartCreatingOne quickly returns with our current cached portmapping, if any. +// If there's not one, it starts up a background goroutine to create one. +// If the background goroutine ends up creating one, the onChange hook registered with the +// NewClient constructor (if any) will fire. +func (c *Client) GetCachedMappingOrStartCreatingOne() (external netaddr.IPPort, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + + // Do we have an existing mapping that's valid? + now := time.Now() + if m := c.pmpMapping; m != nil { + if now.Before(m.goodUntil) { + if now.After(m.renewAfter) { + c.maybeStartMappingLocked() + } + return m.external, true + } + } + + c.maybeStartMappingLocked() + return netaddr.IPPort{}, false +} + +// maybeStartMappingLocked starts a createMapping goroutine up, if one isn't already running. +// +// c.mu must be held. +func (c *Client) maybeStartMappingLocked() { + if !c.runningCreate { + c.runningCreate = true + go c.createMapping() + } +} + +func (c *Client) createMapping() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + defer func() { + c.mu.Lock() + defer c.mu.Unlock() + c.runningCreate = false + }() + + if _, err := c.createOrGetMapping(ctx); err == nil && c.onChange != nil { + go c.onChange() + } else if err != nil && !IsNoMappingError(err) { + c.logf("createOrGetMapping: %v", err) + } +} + +// createOrGetMapping either creates a new mapping or returns a cached // valid one. // // If no mapping is available, the error will be of type // NoMappingError; see IsNoMappingError. -func (c *Client) CreateOrGetMapping(ctx context.Context) (external netaddr.IPPort, err error) { +func (c *Client) createOrGetMapping(ctx context.Context) (external netaddr.IPPort, err error) { gw, myIP, ok := c.gatewayAndSelfIP() if !ok { return netaddr.IPPort{}, NoMappingError{ErrGatewayNotFound} @@ -266,7 +327,7 @@ func (c *Client) CreateOrGetMapping(ctx context.Context) (external netaddr.IPPor // Do we have an existing mapping that's valid? now := time.Now() if m := c.pmpMapping; m != nil { - if now.Before(m.useUntil) { + if now.Before(m.renewAfter) { defer c.mu.Unlock() return m.external, nil } @@ -342,8 +403,9 @@ func (c *Client) CreateOrGetMapping(ctx context.Context) (external netaddr.IPPor if pres.OpCode == pmpOpReply|pmpOpMapUDP { m.external = m.external.WithPort(pres.ExternalPort) d := time.Duration(pres.MappingValidSeconds) * time.Second - d /= 2 // renew in half the time - m.useUntil = time.Now().Add(d) + now := time.Now() + m.goodUntil = now.Add(d) + m.renewAfter = now.Add(d / 2) // renew in half the time m.epoch = pres.SecondsSinceEpoch } } diff --git a/net/portmapper/portmapper_test.go b/net/portmapper/portmapper_test.go index 13673dec4..837a16e8f 100644 --- a/net/portmapper/portmapper_test.go +++ b/net/portmapper/portmapper_test.go @@ -16,13 +16,13 @@ func TestCreateOrGetMapping(t *testing.T) { if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v { t.Skip("skipping test without HIT_NETWORK=1") } - c := NewClient(t.Logf) + c := NewClient(t.Logf, nil) c.SetLocalPort(1234) for i := 0; i < 2; i++ { if i > 0 { time.Sleep(100 * time.Millisecond) } - ext, err := c.CreateOrGetMapping(context.Background()) + ext, err := c.createOrGetMapping(context.Background()) t.Logf("Got: %v, %v", ext, err) } } @@ -31,7 +31,7 @@ func TestClientProbe(t *testing.T) { if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v { t.Skip("skipping test without HIT_NETWORK=1") } - c := NewClient(t.Logf) + c := NewClient(t.Logf, nil) for i := 0; i < 2; i++ { if i > 0 { time.Sleep(100 * time.Millisecond) @@ -45,10 +45,10 @@ func TestClientProbeThenMap(t *testing.T) { if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v { t.Skip("skipping test without HIT_NETWORK=1") } - c := NewClient(t.Logf) + c := NewClient(t.Logf, nil) c.SetLocalPort(1234) res, err := c.Probe(context.Background()) t.Logf("Probe: %+v, %v", res, err) - ext, err := c.CreateOrGetMapping(context.Background()) - t.Logf("CreateOrGetMapping: %v, %v", ext, err) + ext, err := c.createOrGetMapping(context.Background()) + t.Logf("createOrGetMapping: %v, %v", ext, err) } diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index a2c200be2..2117f725c 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -486,7 +486,7 @@ func NewConn(opts Options) (*Conn, error) { c.noteRecvActivity = opts.NoteRecvActivity c.simulatedNetwork = opts.SimulatedNetwork c.disableLegacy = opts.DisableLegacyNetworking - c.portMapper = portmapper.NewClient(logger.WithPrefix(c.logf, "portmapper: ")) + c.portMapper = portmapper.NewClient(logger.WithPrefix(c.logf, "portmapper: "), c.onPortMapChanged) if opts.LinkMonitor != nil { c.portMapper.SetGatewayLookupFunc(opts.LinkMonitor.GatewayAndSelfIP) } @@ -979,6 +979,8 @@ func (c *Conn) goDerpConnect(node int) { // // c.mu must NOT be held. func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, error) { + portmapExt, havePortmap := c.portMapper.GetCachedMappingOrStartCreatingOne() + nr, err := c.updateNetInfo(ctx) if err != nil { c.logf("magicsock.Conn.determineEndpoints: updateNetInfo: %v", err) @@ -1002,11 +1004,13 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro } } - if ext, err := c.portMapper.CreateOrGetMapping(ctx); err == nil { - addAddr(ext, tailcfg.EndpointPortmapped) + // If we didn't have a portmap earlier, maybe it's done by now. + if !havePortmap { + portmapExt, havePortmap = c.portMapper.GetCachedMappingOrStartCreatingOne() + } + if havePortmap { + addAddr(portmapExt, tailcfg.EndpointPortmapped) c.setNetInfoHavePortMap() - } else if !portmapper.IsNoMappingError(err) { - c.logf("portmapper: %v", err) } if nr.GlobalV4 != "" { @@ -2563,6 +2567,8 @@ func (c *Conn) shouldDoPeriodicReSTUNLocked() bool { return true } +func (c *Conn) onPortMapChanged() { c.ReSTUN("portmap-changed") } + // ReSTUN triggers an address discovery. // The provided why string is for debug logging only. func (c *Conn) ReSTUN(why string) {