wgengine/magicsock: make portmapping async

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/2390/head
Brad Fitzpatrick 3 years ago committed by Brad Fitzpatrick
parent afbd35482d
commit 92077ae78c

@ -50,7 +50,7 @@ var netcheckArgs struct {
func runNetcheck(ctx context.Context, args []string) error { func runNetcheck(ctx context.Context, args []string) error {
c := &netcheck.Client{ c := &netcheck.Client{
UDPBindAddr: os.Getenv("TS_DEBUG_NETCHECK_UDP_BIND"), UDPBindAddr: os.Getenv("TS_DEBUG_NETCHECK_UDP_BIND"),
PortMapper: portmapper.NewClient(logger.WithPrefix(log.Printf, "portmap: ")), PortMapper: portmapper.NewClient(logger.WithPrefix(log.Printf, "portmap: "), nil),
} }
if netcheckArgs.verbose { if netcheckArgs.verbose {
c.Logf = logger.WithPrefix(log.Printf, "netcheck: ") c.Logf = logger.WithPrefix(log.Printf, "netcheck: ")

@ -44,9 +44,15 @@ const trustServiceStillAvailableDuration = 10 * time.Minute
type Client struct { type Client struct {
logf logger.Logf logf logger.Logf
ipAndGateway func() (gw, ip netaddr.IP, ok bool) ipAndGateway func() (gw, ip netaddr.IP, ok bool)
onChange func() // or nil
mu sync.Mutex // guards following, and all fields thereof mu sync.Mutex // guards following, and all fields thereof
// runningCreate is whether we're currently working on creating
// a port mapping (whether GetCachedMappingOrStartCreatingOne kicked
// off a createMapping goroutine).
runningCreate bool
lastMyIP netaddr.IP lastMyIP netaddr.IP
lastGW netaddr.IP lastGW netaddr.IP
closed bool closed bool
@ -68,18 +74,19 @@ type Client struct {
func (c *Client) HaveMapping() bool { func (c *Client) HaveMapping() bool {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
return c.pmpMapping != nil && c.pmpMapping.useUntil.After(time.Now()) return c.pmpMapping != nil && c.pmpMapping.goodUntil.After(time.Now())
} }
// pmpMapping is an already-created PMP mapping. // pmpMapping is an already-created PMP mapping.
// //
// All fields are immutable once created. // All fields are immutable once created.
type pmpMapping struct { type pmpMapping struct {
gw netaddr.IP gw netaddr.IP
external netaddr.IPPort external netaddr.IPPort
internal netaddr.IPPort internal netaddr.IPPort
useUntil time.Time // the mapping's lifetime minus renewal interval renewAfter time.Time // the time at which we want to renew the mapping
epoch uint32 goodUntil time.Time // the mapping's total lifetime
epoch uint32
} }
// externalValid reports whether m.external is valid, with both its IP and Port populated. // externalValid reports whether m.external is valid, with both its IP and Port populated.
@ -99,10 +106,15 @@ func (m *pmpMapping) release() {
} }
// NewClient returns a new portmapping client. // NewClient returns a new portmapping client.
func NewClient(logf logger.Logf) *Client { //
// The optional onChange argument specifies a func to run in a new
// goroutine whenever the port mapping status has changed. If nil,
// it doesn't make a callback.
func NewClient(logf logger.Logf, onChange func()) *Client {
return &Client{ return &Client{
logf: logf, logf: logf,
ipAndGateway: interfaces.LikelyHomeRouterIP, ipAndGateway: interfaces.LikelyHomeRouterIP,
onChange: onChange,
} }
} }
@ -221,8 +233,7 @@ func closeCloserOnContextDone(ctx context.Context, c io.Closer) (stop func()) {
return func() { close(stopWaitDone) } return func() { close(stopWaitDone) }
} }
// NoMappingError is returned by CreateOrGetMapping when no NAT // NoMappingError is returned when no NAT mapping could be done.
// mapping could be returned.
type NoMappingError struct { type NoMappingError struct {
err error err error
} }
@ -241,12 +252,62 @@ var (
ErrGatewayNotFound = errors.New("failed to look up gateway address") ErrGatewayNotFound = errors.New("failed to look up gateway address")
) )
// CreateOrGetMapping either creates a new mapping or returns a cached // GetCachedMappingOrStartCreatingOne quickly returns with our current cached portmapping, if any.
// If there's not one, it starts up a background goroutine to create one.
// If the background goroutine ends up creating one, the onChange hook registered with the
// NewClient constructor (if any) will fire.
func (c *Client) GetCachedMappingOrStartCreatingOne() (external netaddr.IPPort, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
// Do we have an existing mapping that's valid?
now := time.Now()
if m := c.pmpMapping; m != nil {
if now.Before(m.goodUntil) {
if now.After(m.renewAfter) {
c.maybeStartMappingLocked()
}
return m.external, true
}
}
c.maybeStartMappingLocked()
return netaddr.IPPort{}, false
}
// maybeStartMappingLocked starts a createMapping goroutine up, if one isn't already running.
//
// c.mu must be held.
func (c *Client) maybeStartMappingLocked() {
if !c.runningCreate {
c.runningCreate = true
go c.createMapping()
}
}
func (c *Client) createMapping() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
defer func() {
c.mu.Lock()
defer c.mu.Unlock()
c.runningCreate = false
}()
if _, err := c.createOrGetMapping(ctx); err == nil && c.onChange != nil {
go c.onChange()
} else if err != nil && !IsNoMappingError(err) {
c.logf("createOrGetMapping: %v", err)
}
}
// createOrGetMapping either creates a new mapping or returns a cached
// valid one. // valid one.
// //
// If no mapping is available, the error will be of type // If no mapping is available, the error will be of type
// NoMappingError; see IsNoMappingError. // NoMappingError; see IsNoMappingError.
func (c *Client) CreateOrGetMapping(ctx context.Context) (external netaddr.IPPort, err error) { func (c *Client) createOrGetMapping(ctx context.Context) (external netaddr.IPPort, err error) {
gw, myIP, ok := c.gatewayAndSelfIP() gw, myIP, ok := c.gatewayAndSelfIP()
if !ok { if !ok {
return netaddr.IPPort{}, NoMappingError{ErrGatewayNotFound} return netaddr.IPPort{}, NoMappingError{ErrGatewayNotFound}
@ -266,7 +327,7 @@ func (c *Client) CreateOrGetMapping(ctx context.Context) (external netaddr.IPPor
// Do we have an existing mapping that's valid? // Do we have an existing mapping that's valid?
now := time.Now() now := time.Now()
if m := c.pmpMapping; m != nil { if m := c.pmpMapping; m != nil {
if now.Before(m.useUntil) { if now.Before(m.renewAfter) {
defer c.mu.Unlock() defer c.mu.Unlock()
return m.external, nil return m.external, nil
} }
@ -342,8 +403,9 @@ func (c *Client) CreateOrGetMapping(ctx context.Context) (external netaddr.IPPor
if pres.OpCode == pmpOpReply|pmpOpMapUDP { if pres.OpCode == pmpOpReply|pmpOpMapUDP {
m.external = m.external.WithPort(pres.ExternalPort) m.external = m.external.WithPort(pres.ExternalPort)
d := time.Duration(pres.MappingValidSeconds) * time.Second d := time.Duration(pres.MappingValidSeconds) * time.Second
d /= 2 // renew in half the time now := time.Now()
m.useUntil = time.Now().Add(d) m.goodUntil = now.Add(d)
m.renewAfter = now.Add(d / 2) // renew in half the time
m.epoch = pres.SecondsSinceEpoch m.epoch = pres.SecondsSinceEpoch
} }
} }

@ -16,13 +16,13 @@ func TestCreateOrGetMapping(t *testing.T) {
if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v { if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v {
t.Skip("skipping test without HIT_NETWORK=1") t.Skip("skipping test without HIT_NETWORK=1")
} }
c := NewClient(t.Logf) c := NewClient(t.Logf, nil)
c.SetLocalPort(1234) c.SetLocalPort(1234)
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
if i > 0 { if i > 0 {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
ext, err := c.CreateOrGetMapping(context.Background()) ext, err := c.createOrGetMapping(context.Background())
t.Logf("Got: %v, %v", ext, err) t.Logf("Got: %v, %v", ext, err)
} }
} }
@ -31,7 +31,7 @@ func TestClientProbe(t *testing.T) {
if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v { if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v {
t.Skip("skipping test without HIT_NETWORK=1") t.Skip("skipping test without HIT_NETWORK=1")
} }
c := NewClient(t.Logf) c := NewClient(t.Logf, nil)
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
if i > 0 { if i > 0 {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
@ -45,10 +45,10 @@ func TestClientProbeThenMap(t *testing.T) {
if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v { if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v {
t.Skip("skipping test without HIT_NETWORK=1") t.Skip("skipping test without HIT_NETWORK=1")
} }
c := NewClient(t.Logf) c := NewClient(t.Logf, nil)
c.SetLocalPort(1234) c.SetLocalPort(1234)
res, err := c.Probe(context.Background()) res, err := c.Probe(context.Background())
t.Logf("Probe: %+v, %v", res, err) t.Logf("Probe: %+v, %v", res, err)
ext, err := c.CreateOrGetMapping(context.Background()) ext, err := c.createOrGetMapping(context.Background())
t.Logf("CreateOrGetMapping: %v, %v", ext, err) t.Logf("createOrGetMapping: %v, %v", ext, err)
} }

@ -486,7 +486,7 @@ func NewConn(opts Options) (*Conn, error) {
c.noteRecvActivity = opts.NoteRecvActivity c.noteRecvActivity = opts.NoteRecvActivity
c.simulatedNetwork = opts.SimulatedNetwork c.simulatedNetwork = opts.SimulatedNetwork
c.disableLegacy = opts.DisableLegacyNetworking c.disableLegacy = opts.DisableLegacyNetworking
c.portMapper = portmapper.NewClient(logger.WithPrefix(c.logf, "portmapper: ")) c.portMapper = portmapper.NewClient(logger.WithPrefix(c.logf, "portmapper: "), c.onPortMapChanged)
if opts.LinkMonitor != nil { if opts.LinkMonitor != nil {
c.portMapper.SetGatewayLookupFunc(opts.LinkMonitor.GatewayAndSelfIP) c.portMapper.SetGatewayLookupFunc(opts.LinkMonitor.GatewayAndSelfIP)
} }
@ -979,6 +979,8 @@ func (c *Conn) goDerpConnect(node int) {
// //
// c.mu must NOT be held. // c.mu must NOT be held.
func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, error) { func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, error) {
portmapExt, havePortmap := c.portMapper.GetCachedMappingOrStartCreatingOne()
nr, err := c.updateNetInfo(ctx) nr, err := c.updateNetInfo(ctx)
if err != nil { if err != nil {
c.logf("magicsock.Conn.determineEndpoints: updateNetInfo: %v", err) c.logf("magicsock.Conn.determineEndpoints: updateNetInfo: %v", err)
@ -1002,11 +1004,13 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro
} }
} }
if ext, err := c.portMapper.CreateOrGetMapping(ctx); err == nil { // If we didn't have a portmap earlier, maybe it's done by now.
addAddr(ext, tailcfg.EndpointPortmapped) if !havePortmap {
portmapExt, havePortmap = c.portMapper.GetCachedMappingOrStartCreatingOne()
}
if havePortmap {
addAddr(portmapExt, tailcfg.EndpointPortmapped)
c.setNetInfoHavePortMap() c.setNetInfoHavePortMap()
} else if !portmapper.IsNoMappingError(err) {
c.logf("portmapper: %v", err)
} }
if nr.GlobalV4 != "" { if nr.GlobalV4 != "" {
@ -2563,6 +2567,8 @@ func (c *Conn) shouldDoPeriodicReSTUNLocked() bool {
return true return true
} }
func (c *Conn) onPortMapChanged() { c.ReSTUN("portmap-changed") }
// ReSTUN triggers an address discovery. // ReSTUN triggers an address discovery.
// The provided why string is for debug logging only. // The provided why string is for debug logging only.
func (c *Conn) ReSTUN(why string) { func (c *Conn) ReSTUN(why string) {

Loading…
Cancel
Save