diff --git a/net/portmapper/pcp.go b/net/portmapper/pcp.go index 0066ba225..d0752734e 100644 --- a/net/portmapper/pcp.go +++ b/net/portmapper/pcp.go @@ -54,8 +54,7 @@ type pcpMapping struct { renewAfter time.Time goodUntil time.Time - // TODO should this also contain an epoch? - // Doesn't seem to be used elsewhere, but can use it for validation at some point. + epoch uint32 } func (p *pcpMapping) MappingType() string { return "pcp" } @@ -140,6 +139,7 @@ func parsePCPMapResponse(resp []byte) (*pcpMapping, error) { external: external, renewAfter: now.Add(lifetime / 2), goodUntil: now.Add(lifetime), + epoch: res.Epoch, } return mapping, nil diff --git a/net/portmapper/portmapper.go b/net/portmapper/portmapper.go index f6e73dda3..54557287d 100644 --- a/net/portmapper/portmapper.go +++ b/net/portmapper/portmapper.go @@ -90,11 +90,14 @@ type Client struct { lastProbe time.Time + // The following PMP fields are populated during Probe pmpPubIP netip.Addr // non-zero if known pmpPubIPTime time.Time // time pmpPubIP last verified pmpLastEpoch uint32 - pcpSawTime time.Time // time we last saw PCP was available + // The following PCP fields are populated during Probe + pcpSawTime time.Time // time we last saw PCP was available + pcpLastEpoch uint32 uPnPSawTime time.Time // time we last saw UPnP was available uPnPMetas []uPnPDiscoResponse // UPnP UDP discovery responses @@ -324,9 +327,14 @@ func (c *Client) invalidateMappingsLocked(releaseOld bool) { } c.mapping = nil } + c.pmpPubIP = netip.Addr{} c.pmpPubIPTime = time.Time{} + c.pmpLastEpoch = 0 + c.pcpSawTime = time.Time{} + c.pcpLastEpoch = 0 + c.uPnPSawTime = time.Time{} c.uPnPMetas = nil } @@ -988,7 +996,9 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) { if pres.OpCode == pcpOpReply|pcpOpAnnounce { pcpHeard = true c.mu.Lock() + c.maybeInvalidatePCPMappingLocked(pres.Epoch) // must be before we write to c.pcp* c.pcpSawTime = time.Now() + c.pcpLastEpoch = pres.Epoch c.mu.Unlock() switch pres.ResultCode { case pcpCodeOK: @@ -1026,6 +1036,7 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) { c.logf("[v1] Got PMP response; IP: %v, epoch: %v", pres.PublicAddr, pres.SecondsSinceEpoch) res.PMP = true c.mu.Lock() + c.maybeInvalidatePMPMappingLocked(pres.SecondsSinceEpoch) // must be before we write to c.pmp* c.pmpPubIP = pres.PublicAddr c.pmpPubIPTime = time.Now() c.pmpLastEpoch = pres.SecondsSinceEpoch @@ -1051,6 +1062,57 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) { } } +func (c *Client) maybeInvalidatePMPMappingLocked(epoch uint32) { + if epoch == 0 || c.mapping == nil { + return + } + m, ok := c.mapping.(*pmpMapping) + if !ok { + return + } + + if epoch >= m.epoch { + // Epoch increased, which is fine. + // + // TODO: we should more closely follow RFC6887 § 8.5 which also + // requires us to check the current time and the time that this + // epoch was received at. + return + } + + // Epoch decreased, so invalidate the mapping and clear PMP fields. + c.logf("invalidating PMP mappings since returned epoch %d < stored epoch %d", epoch, m.epoch) + c.mapping = nil + c.pmpPubIP = netip.Addr{} + c.pmpPubIPTime = time.Time{} + c.pmpLastEpoch = 0 +} + +func (c *Client) maybeInvalidatePCPMappingLocked(epoch uint32) { + if epoch == 0 || c.mapping == nil { + return + } + m, ok := c.mapping.(*pcpMapping) + if !ok { + return + } + + if epoch >= m.epoch { + // Epoch increased, which is fine. + // + // TODO: we should more closely follow RFC6887 § 8.5 which also + // requires us to check the current time and the time that this + // epoch was received at. + return + } + + // Epoch decreased, so invalidate the mapping and clear PCP fields. + c.logf("invalidating PCP mappings since returned epoch %d < stored epoch %d", epoch, m.epoch) + c.mapping = nil + c.pcpSawTime = time.Time{} + c.pcpLastEpoch = 0 +} + var pmpReqExternalAddrPacket = []byte{pmpVersion, pmpOpMapPublicAddr} // 0, 0 const (