diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index cf1726e90..4942d164d 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -440,7 +440,9 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo } // Default filter blocks everything and logs nothing, until Start() is called. - b.setFilter(filter.NewAllowNone(logf, &netipx.IPSet{})) + noneFilter := filter.NewAllowNone(logf, &netipx.IPSet{}) + b.setFilter(noneFilter) + b.e.SetJailedFilter(noneFilter) b.setTCPPortsIntercepted(nil) @@ -1935,7 +1937,9 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P if !haveNetmap { b.logf("[v1] netmap packet filter: (not ready yet)") - b.setFilter(filter.NewAllowNone(b.logf, logNets)) + noneFilter := filter.NewAllowNone(b.logf, logNets) + b.setFilter(noneFilter) + b.e.SetJailedFilter(noneFilter) return } @@ -1947,6 +1951,9 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P b.logf("[v1] netmap packet filter: %v filters", len(packetFilter)) b.setFilter(filter.New(packetFilter, localNets, logNets, oldFilter, b.logf)) } + // The filter for a jailed node is the exact same as a ShieldsUp filter. + oldJailedFilter := b.e.GetJailedFilter() + b.e.SetJailedFilter(filter.NewShieldsUpFilter(localNets, logNets, oldJailedFilter, b.logf)) if b.sshServer != nil { go b.sshServer.OnPolicyChange() diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index ade2e6be7..88647b48b 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -153,6 +153,9 @@ type Wrapper struct { filter atomic.Pointer[filter.Filter] // filterFlags control the verbosity of logging packet drops/accepts. filterFlags filter.RunFlags + // jailedFilter is the packet filter for jailed nodes. + // Can be nil, which means drop all packets. + jailedFilter atomic.Pointer[filter.Filter] // PreFilterPacketInboundFromWireGuard is the inbound filter function that runs before the main filter // and therefore sees the packets that may be later dropped by it. @@ -572,6 +575,11 @@ type peerConfig struct { // masqueraded for that address family. dstMasqAddr4 netip.Addr dstMasqAddr6 netip.Addr + + // jailed is whether this peer is "jailed" (i.e. is restricted from being + // able to initiate connections to this node). This is the case for shared + // nodes. + jailed bool } func (c *peerConfigTable) String() string { @@ -597,7 +605,8 @@ func (c *peerConfig) String() string { var b strings.Builder b.WriteString("peerConfig{") fmt.Fprintf(&b, "dstMasqAddr4: %v, ", c.dstMasqAddr4) - fmt.Fprintf(&b, "dstMasqAddr6: %v}", c.dstMasqAddr6) + fmt.Fprintf(&b, "dstMasqAddr6: %v, ", c.dstMasqAddr6) + fmt.Fprintf(&b, "jailed: %v}", c.jailed) return b.String() } @@ -735,10 +744,13 @@ func peerConfigTableFromWGConfig(wcfg *wgcfg.Config) *peerConfigTable { continue } + const peerIsJailed = false // TODO: implement jailed peers + // Use the same peer configuration for each address of the peer. pc := &peerConfig{ dstMasqAddr4: addrToUse4, dstMasqAddr6: addrToUse6, + jailed: peerIsJailed, } // Insert an entry into our routing table for each allowed IP. @@ -753,6 +765,28 @@ func peerConfigTableFromWGConfig(wcfg *wgcfg.Config) *peerConfigTable { return ret } +func (pc *peerConfigTable) inboundPacketIsJailed(p *packet.Parsed) bool { + if pc == nil { + return false + } + c, ok := pc.byIP.Get(p.Src.Addr()) + if !ok { + return false + } + return c.jailed +} + +func (pc *peerConfigTable) outboundPacketIsJailed(p *packet.Parsed) bool { + if pc == nil { + return false + } + c, ok := pc.byIP.Get(p.Dst.Addr()) + if !ok { + return false + } + return c.jailed +} + // SetNetMap is called when a new NetworkMap is received. func (t *Wrapper) SetWGConfig(wcfg *wgcfg.Config) { cfg := peerConfigTableFromWGConfig(wcfg) @@ -812,7 +846,14 @@ func (t *Wrapper) filterPacketOutboundToWireGuard(p *packet.Parsed, pc *peerConf } } - filt := t.filter.Load() + // If the outbound packet is to a jailed peer, use our jailed peer + // packet filter. + var filt *filter.Filter + if pc.outboundPacketIsJailed(p) { + filt = t.jailedFilter.Load() + } else { + filt = t.filter.Load() + } if filt == nil { return filter.Drop } @@ -993,7 +1034,12 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook ca } } - filt := t.filter.Load() + var filt *filter.Filter + if pc.inboundPacketIsJailed(p) { + filt = t.jailedFilter.Load() + } else { + filt = t.filter.Load() + } if filt == nil { return filter.Drop } @@ -1098,6 +1144,14 @@ func (t *Wrapper) SetFilter(filt *filter.Filter) { t.filter.Store(filt) } +func (t *Wrapper) GetJailedFilter() *filter.Filter { + return t.jailedFilter.Load() +} + +func (t *Wrapper) SetJailedFilter(filt *filter.Filter) { + t.jailedFilter.Store(filt) +} + // InjectInboundPacketBuffer makes the Wrapper device behave as if a packet // with the given contents was received from the network. // It takes ownership of one reference count on the packet. The injected diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 4da7f2b5c..f1f0a08ba 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -1037,6 +1037,14 @@ func (e *userspaceEngine) SetFilter(filt *filter.Filter) { e.tundev.SetFilter(filt) } +func (e *userspaceEngine) GetJailedFilter() *filter.Filter { + return e.tundev.GetJailedFilter() +} + +func (e *userspaceEngine) SetJailedFilter(filt *filter.Filter) { + e.tundev.SetJailedFilter(filt) +} + func (e *userspaceEngine) SetStatusCallback(cb StatusCallback) { e.mu.Lock() defer e.mu.Unlock() diff --git a/wgengine/watchdog.go b/wgengine/watchdog.go index 9ff342bcd..232591f5e 100644 --- a/wgengine/watchdog.go +++ b/wgengine/watchdog.go @@ -129,6 +129,12 @@ func (e *watchdogEngine) GetFilter() *filter.Filter { func (e *watchdogEngine) SetFilter(filt *filter.Filter) { e.watchdog("SetFilter", func() { e.wrap.SetFilter(filt) }) } +func (e *watchdogEngine) GetJailedFilter() *filter.Filter { + return e.wrap.GetJailedFilter() +} +func (e *watchdogEngine) SetJailedFilter(filt *filter.Filter) { + e.watchdog("SetJailedFilter", func() { e.wrap.SetJailedFilter(filt) }) +} func (e *watchdogEngine) SetStatusCallback(cb StatusCallback) { e.watchdog("SetStatusCallback", func() { e.wrap.SetStatusCallback(cb) }) } diff --git a/wgengine/wgengine.go b/wgengine/wgengine.go index a27db96d9..3bc575794 100644 --- a/wgengine/wgengine.go +++ b/wgengine/wgengine.go @@ -78,6 +78,13 @@ type Engine interface { // SetFilter updates the packet filter. SetFilter(*filter.Filter) + // GetJailedFilter returns the current packet filter for jailed nodes, + // if any. + GetJailedFilter() *filter.Filter + + // SetJailedFilter updates the packet filter for jailed nodes. + SetJailedFilter(*filter.Filter) + // SetStatusCallback sets the function to call when the // WireGuard status changes. SetStatusCallback(StatusCallback)