diff --git a/tstest/natlab/firewall.go b/tstest/natlab/firewall.go index 4cb83849b..06af2663e 100644 --- a/tstest/natlab/firewall.go +++ b/tstest/natlab/firewall.go @@ -101,32 +101,57 @@ func (f *Firewall) timeNow() time.Time { return time.Now() } -// HandlePacket implements the PacketHandler type. -func (f *Firewall) HandlePacket(p *Packet, inIf *Interface) PacketVerdict { - f.mu.Lock() - defer f.mu.Unlock() +func (f *Firewall) init() { if f.seen == nil { f.seen = map[fwKey]time.Time{} } - if f.SessionTimeout == 0 { - f.SessionTimeout = 30 * time.Second +} + +func (f *Firewall) HandleOut(p *Packet, oif *Interface) *Packet { + f.mu.Lock() + defer f.mu.Unlock() + f.init() + + k := f.Type.key(p.Src, p.Dst) + f.seen[k] = f.timeNow().Add(f.sessionTimeoutLocked()) + p.Trace("firewall out ok") + return p +} + +func (f *Firewall) HandleIn(p *Packet, iif *Interface) *Packet { + f.mu.Lock() + defer f.mu.Unlock() + f.init() + + // reverse src and dst because the session table is from the POV + // of outbound packets. + k := f.Type.key(p.Dst, p.Src) + now := f.timeNow() + if now.After(f.seen[k]) { + p.Trace("firewall drop") + return nil } + p.Trace("firewall in ok") + return p +} - if inIf == f.TrustedInterface || inIf == nil { - k := f.Type.key(p.Src, p.Dst) - f.seen[k] = f.timeNow().Add(f.SessionTimeout) - p.Trace("firewall out ok") - return Continue - } else { - // reverse src and dst because the session table is from the - // POV of outbound packets. - k := f.Type.key(p.Dst, p.Src) - now := f.timeNow() - if now.After(f.seen[k]) { - p.Trace("firewall drop") - return Drop - } - p.Trace("firewall in ok") - return Continue +func (f *Firewall) HandleForward(p *Packet, iif *Interface, oif *Interface) *Packet { + if iif == f.TrustedInterface { + // Treat just like a locally originated packet + return f.HandleOut(p, oif) + } + if oif != f.TrustedInterface { + // Not a possible return packet from our trusted interface, drop. + p.Trace("firewall drop, unexpected oif") + return nil + } + // Otherwise, a session must exist, same as HandleIn. + return f.HandleIn(p, iif) +} + +func (f *Firewall) sessionTimeoutLocked() time.Duration { + if f.SessionTimeout == 0 { + return DefaultSessionTimeout } + return f.SessionTimeout } diff --git a/tstest/natlab/nat.go b/tstest/natlab/nat.go index da31a5dcb..a4fbb9ee1 100644 --- a/tstest/natlab/nat.go +++ b/tstest/natlab/nat.go @@ -99,11 +99,6 @@ type SNAT44 struct { // nil, time.Now is used. TimeNow func() time.Time - // inject, if not nil, will be invoked instead of Machine.Inject - // to inject NATed packets into the network. It is used for tests - // only. - inject func(*Packet) error - mu sync.Mutex byLAN map[natKey]*mapping // lookup by outbound packet tuple byWAN map[netaddr.IPPort]*mapping // lookup by wan ip:port only @@ -131,87 +126,105 @@ func (n *SNAT44) initLocked() { if n.ExternalInterface.Machine() != n.Machine { panic(fmt.Sprintf("NAT given interface %s that is not part of given machine %s", n.ExternalInterface, n.Machine.Name)) } - if n.inject == nil { - n.inject = n.Machine.Inject +} + +func (n *SNAT44) HandleOut(p *Packet, oif *Interface) *Packet { + // NATs don't affect locally originated packets. + if n.Firewall != nil { + return n.Firewall.HandleOut(p, oif) } + return p } -func (n *SNAT44) HandlePacket(p *Packet, inIf *Interface) PacketVerdict { +func (n *SNAT44) HandleIn(p *Packet, iif *Interface) *Packet { + if iif != n.ExternalInterface { + // NAT can't apply, defer to firewall. + if n.Firewall != nil { + return n.Firewall.HandleIn(p, iif) + } + return p + } + n.mu.Lock() defer n.mu.Unlock() n.initLocked() - if inIf == n.ExternalInterface { - return n.processInboundLocked(p, inIf) - } else { - return n.processOutboundLocked(p, inIf) - } -} - -func (n *SNAT44) processInboundLocked(p *Packet, inIf *Interface) PacketVerdict { - // TODO: packets to local addrs should fall through to local - // socket processing. now := n.timeNow() mapping := n.byWAN[p.Dst] if mapping == nil || now.After(mapping.deadline) { - p.Trace("nat drop, no mapping/expired mapping") - return Drop - } - p.Dst = mapping.lanSrc - - if n.Firewall != nil { - if verdict := n.Firewall(p.Clone(), inIf); verdict == Drop { - return Drop + // NAT didn't hit, defer to firewall or allow in for local + // socket handling. + if n.Firewall != nil { + return n.Firewall.HandleIn(p, iif) } + return p } - if err := n.inject(p); err != nil { - p.Trace("inject failed: %v", err) - } - return Drop + p.Dst = mapping.lanSrc + p.Trace("dnat to %v", p.Dst) + // Don't process firewall here. We mutated the packet such that + // it's no longer destined locally, so we'll get reinvoked as + // HandleForward and need to process the altered packet there. + return p } -func (n *SNAT44) processOutboundLocked(p *Packet, inIf *Interface) PacketVerdict { - if n.Firewall != nil { - if verdict := n.Firewall(p, inIf); verdict == Drop { - return Drop +func (n *SNAT44) HandleForward(p *Packet, iif, oif *Interface) *Packet { + switch { + case oif == n.ExternalInterface: + if p.Src.IP == oif.V4() { + // Packet already NATed and is just retraversing Forward, + // don't touch it again. + return p } - } - if inIf == nil { - // Technically, we don't need to process the outbound firewall - // for NATed packets, but our current packet processing API - // doesn't give us that granularity: we'll see both locally - // originated PacketConn traffic and NATed traffic as inIf == - // nil, and we need to apply the firewall to locally - // originated traffic. This may create some useless state - // entries in the firewall, but until we implement a much more - // elaborate packet processing pipeline that can distinguish - // local vs. forwarded traffic, this is the best we have. - return Continue - } - k := n.Type.key(p.Src, p.Dst) - now := n.timeNow() - m := n.byLAN[k] - if m == nil || now.After(m.deadline) { - pc, wanAddr := n.allocateMappedPort() - m = &mapping{ - lanSrc: p.Src, - lanDst: p.Dst, - wanSrc: wanAddr, - pc: pc, + if n.Firewall != nil { + p2 := n.Firewall.HandleForward(p, iif, oif) + if p2 == nil { + // firewall dropped, done + return nil + } + if !p.Equivalent(p2) { + // firewall mutated packet? Weird, but okay. + return p2 + } } - n.byLAN[k] = m - n.byWAN[wanAddr] = m - } - m.deadline = now.Add(n.mappingTimeout()) - p.Src = m.wanSrc - p.Trace("snat from %v", p.Src) - if err := n.inject(p); err != nil { - p.Trace("inject failed: %v", err) + n.mu.Lock() + defer n.mu.Unlock() + n.initLocked() + + k := n.Type.key(p.Src, p.Dst) + now := n.timeNow() + m := n.byLAN[k] + if m == nil || now.After(m.deadline) { + pc, wanAddr := n.allocateMappedPort() + m = &mapping{ + lanSrc: p.Src, + lanDst: p.Dst, + wanSrc: wanAddr, + pc: pc, + } + n.byLAN[k] = m + n.byWAN[wanAddr] = m + } + m.deadline = now.Add(n.mappingTimeout()) + p.Src = m.wanSrc + p.Trace("snat from %v", p.Src) + return p + case iif == n.ExternalInterface: + // Packet was already un-NAT-ed, we just need to either + // firewall it or let it through. + if n.Firewall != nil { + return n.Firewall.HandleForward(p, iif, oif) + } + return p + default: + // No NAT applies, invoke firewall or drop. + if n.Firewall != nil { + return n.Firewall.HandleForward(p, iif, oif) + } + return nil } - return Drop } func (n *SNAT44) allocateMappedPort() (net.PacketConn, netaddr.IPPort) { diff --git a/tstest/natlab/natlab.go b/tstest/natlab/natlab.go index fe9ffedfa..230039564 100644 --- a/tstest/natlab/natlab.go +++ b/tstest/natlab/natlab.go @@ -12,6 +12,7 @@ package natlab import ( + "bytes" "context" "crypto/sha256" "encoding/base64" @@ -40,6 +41,12 @@ type Packet struct { locator string } +// Equivalent returns true if Src, Dst and Payload are the same in p +// and p2. +func (p *Packet) Equivalent(p2 *Packet) bool { + return p.Src == p2.Src && p.Dst == p2.Dst && bytes.Equal(p.Payload, p2.Payload) +} + // Clone returns a copy of p that shares nothing with p. func (p *Packet) Clone() *Packet { return &Packet{ @@ -266,8 +273,41 @@ func (v PacketVerdict) String() string { } } -// A PacketHandler is a function that can process packets. -type PacketHandler func(p *Packet, inIf *Interface) PacketVerdict +// A PacketHandler can look at packets arriving at, departing, and +// transiting a Machine, and filter or mutate them. +// +// Each method is invoked with a Packet that natlab would like to keep +// processing. Handlers can return that same Packet to allow +// processing to continue; nil to drop the Packet; or a different +// Packet that should be processed instead of the original. +// +// Packets passed to handlers share no state with anything else, and +// are therefore safe to mutate. It's safe to return the original +// packet mutated in-place, or a brand new packet initialized from +// scratch. +// +// Packets mutated by a PacketHandler are processed anew by the +// associated Machine, as if the packet had always been the mutated +// one. For example, if HandleForward is invoked with a Packet, and +// the handler changes the destination IP address to one of the +// Machine's own IPs, the Machine restarts delivery, but this time +// going to a local PacketConn (which in turn will invoke HandleIn, +// since the packet is now destined for local delivery). +type PacketHandler interface { + // HandleIn processes a packet arriving on iif, whose destination + // is an IP address owned by the attached Machine. If p is + // returned unmodified, the Machine will go on to deliver the + // Packet to the appropriate listening PacketConn, if one exists. + HandleIn(p *Packet, iif *Interface) *Packet + // HandleOut processes a packet about to depart on oif from a + // local PacketConn. If p is returned unmodified, the Machine will + // transmit the Packet on oif. + HandleOut(p *Packet, oif *Interface) *Packet + // HandleForward is called when the Machine wants to forward a + // packet from iif to oif. If p is returned unmodified, the + // Machine will transmit the packet on oif. + HandleForward(p *Packet, iif, oif *Interface) *Packet +} // A Machine is a representation of an operating system's network // stack. It has a network routing table and can have multiple @@ -278,19 +318,14 @@ type Machine struct { // not be globally unique. Name string - // HandlePacket, if not nil, is a function that gets invoked for - // every packet this Machine receives, and every packet sent by a - // local PacketConn. Returns a verdict for how the packet should - // continue to be handled (or not). - // - // HandlePacket's interface parameter is the interface on which - // the packet was received, or nil for a packet sent by a local - // PacketConn or Inject call. + // PacketHandler, if not nil, is a PacketHandler implementation + // that inspects all packets arriving, departing, or transiting + // the Machine. See the definition of the PacketHandler interface + // for semantics. // - // The packet provided to HandlePacket can safely be mutated and - // Inject()ed if desired. This can be used to implement things - // like stateful firewalls and NAT boxes. - HandlePacket PacketHandler + // If PacketHandler is nil, the machine allows all inbound + // traffic, all outbound traffic, and drops forwarded packets. + PacketHandler PacketHandler mu sync.Mutex interfaces []*Interface @@ -300,26 +335,42 @@ type Machine struct { conns6 map[netaddr.IPPort]*conn // conns that want IPv6 packets } -// Inject transmits p from src to dst, without the need for a local socket. -// It's useful for implementing e.g. NAT boxes that need to mangle IPs. -func (m *Machine) Inject(p *Packet) error { - p = p.Clone() - p.setLocator("mach=%s", m.Name) - p.Trace("Machine.Inject") - _, err := m.writePacket(p) - return err +func (m *Machine) isLocalIP(ip netaddr.IP) bool { + m.mu.Lock() + defer m.mu.Unlock() + for _, intf := range m.interfaces { + for _, iip := range intf.ips { + if ip == iip { + return true + } + } + } + return false } func (m *Machine) deliverIncomingPacket(p *Packet, iface *Interface) { p.setLocator("mach=%s if=%s", m.Name, iface.name) + + if m.isLocalIP(p.Dst.IP) { + m.deliverLocalPacket(p, iface) + } else { + m.forwardPacket(p, iface) + } +} + +func (m *Machine) deliverLocalPacket(p *Packet, iface *Interface) { // TODO: can't hold lock while handling packet. This is safe as // long as you set HandlePacket before traffic starts flowing. - if m.HandlePacket != nil { - p.Trace("Machine.HandlePacket") - verdict := m.HandlePacket(p.Clone(), iface) - p.Trace("Machine.HandlePacket verdict=%s", verdict) - if verdict == Drop { - // Custom packet handler ate the packet, we're done. + if m.PacketHandler != nil { + p2 := m.PacketHandler.HandleIn(p.Clone(), iface) + if p2 == nil { + // Packet dropped, nothing left to do. + return + } + if !p.Equivalent(p2) { + // Restart delivery, this packet might be a forward packet + // now. + m.deliverIncomingPacket(p2, iface) return } } @@ -353,6 +404,35 @@ func (m *Machine) deliverIncomingPacket(p *Packet, iface *Interface) { p.Trace("dropped, no listening conn") } +func (m *Machine) forwardPacket(p *Packet, iif *Interface) { + oif, err := m.interfaceForIP(p.Dst.IP) + if err != nil { + p.Trace("%v", err) + return + } + + if m.PacketHandler == nil { + // Forwarding not allowed by default + p.Trace("drop, forwarding not allowed") + return + } + p2 := m.PacketHandler.HandleForward(p.Clone(), iif, oif) + if p2 == nil { + p.Trace("drop") + // Packet dropped, done. + return + } + if !p.Equivalent(p2) { + // Packet changed, restart delivery. + p2.Trace("PacketHandler mutated packet") + m.deliverIncomingPacket(p2, iif) + return + } + + p.Trace("-> net=%s oif=%s", oif.net.Name, oif) + oif.net.write(p) +} + func unspecOf(ip netaddr.IP) netaddr.IP { if ip.Is4() { return v4unspec @@ -455,13 +535,17 @@ func (m *Machine) writePacket(p *Packet) (n int, err error) { return 0, err } - if m.HandlePacket != nil { - p.Trace("Machine.HandlePacket") - verdict := m.HandlePacket(p.Clone(), nil) - p.Trace("Machine.HandlePacket verdict=%s", verdict) - if verdict == Drop { + if m.PacketHandler != nil { + p2 := m.PacketHandler.HandleOut(p.Clone(), iface) + if p2 == nil { + // Packet dropped, done. return len(p.Payload), nil } + if !p.Equivalent(p2) { + // Restart transmission, src may have changed weirdly + m.writePacket(p2) + return + } } p.Trace("-> net=%s if=%s", iface.net.Name, iface) diff --git a/tstest/natlab/natlab_test.go b/tstest/natlab/natlab_test.go index 2368fb2aa..1119e804d 100644 --- a/tstest/natlab/natlab_test.go +++ b/tstest/natlab/natlab_test.go @@ -148,6 +148,38 @@ func TestMultiNetwork(t *testing.T) { } } +type trivialNAT struct { + clientIP netaddr.IP + lanIf, wanIf *Interface +} + +func (n *trivialNAT) HandleIn(p *Packet, iface *Interface) *Packet { + if iface == n.wanIf && p.Dst.IP == n.wanIf.V4() { + p.Dst.IP = n.clientIP + } + return p +} + +func (n trivialNAT) HandleOut(p *Packet, iface *Interface) *Packet { + return p +} + +func (n *trivialNAT) HandleForward(p *Packet, iif, oif *Interface) *Packet { + // Outbound from LAN -> apply NAT, continue + if iif == n.lanIf && oif == n.wanIf { + if p.Src.IP == n.clientIP { + p.Src.IP = n.wanIf.V4() + } + return p + } + // Return traffic to LAN, allow if right dst. + if iif == n.wanIf && oif == n.lanIf && p.Dst.IP == n.clientIP { + return p + } + // Else drop. + return nil +} + func TestPacketHandler(t *testing.T) { lan := &Network{ Name: "lan", @@ -167,29 +199,10 @@ func TestPacketHandler(t *testing.T) { lan.SetDefaultGateway(ifNATLAN) - // This HandlePacket implements a basic (some might say "broken") - // 1:1 NAT, where client's IP gets replaced with the NAT's WAN IP, - // and vice versa. - // - // This NAT is not suitable for actual use, since it doesn't do - // port remappings or any other things that NATs usually to. But - // it works as a demonstrator for a single client behind the NAT, - // where the NAT box itself doesn't also make PacketConns. - nat.HandlePacket = func(p *Packet, iface *Interface) PacketVerdict { - switch { - case p.Dst.IP.Is6(): - return Continue // no NAT for ipv6 - case iface == ifNATLAN && p.Src.IP == ifClient.V4(): - p.Src.IP = ifNATWAN.V4() - nat.Inject(p) - return Drop - case iface == ifNATWAN && p.Dst.IP == ifNATWAN.V4(): - p.Dst.IP = ifClient.V4() - nat.Inject(p) - return Drop - default: - return Continue - } + nat.PacketHandler = &trivialNAT{ + clientIP: ifClient.V4(), + lanIf: ifNATLAN, + wanIf: ifNATWAN, } ctx := context.Background() @@ -246,17 +259,17 @@ func TestFirewall(t *testing.T) { } testFirewall(t, f, []fwTest{ // client -> A authorizes A -> client - {trust, client, serverA, Continue}, - {untrust, serverA, client, Continue}, - {untrust, serverA, client, Continue}, + {trust, untrust, client, serverA, true}, + {untrust, trust, serverA, client, true}, + {untrust, trust, serverA, client, true}, // B1 -> client fails until client -> B1 - {untrust, serverB1, client, Drop}, - {trust, client, serverB1, Continue}, - {untrust, serverB1, client, Continue}, + {untrust, trust, serverB1, client, false}, + {trust, untrust, client, serverB1, true}, + {untrust, trust, serverB1, client, true}, // B2 -> client still fails - {untrust, serverB2, client, Drop}, + {untrust, trust, serverB2, client, false}, }) }) t.Run("ip_dependent", func(t *testing.T) { @@ -267,17 +280,17 @@ func TestFirewall(t *testing.T) { } testFirewall(t, f, []fwTest{ // client -> A authorizes A -> client - {trust, client, serverA, Continue}, - {untrust, serverA, client, Continue}, - {untrust, serverA, client, Continue}, + {trust, untrust, client, serverA, true}, + {untrust, trust, serverA, client, true}, + {untrust, trust, serverA, client, true}, // B1 -> client fails until client -> B1 - {untrust, serverB1, client, Drop}, - {trust, client, serverB1, Continue}, - {untrust, serverB1, client, Continue}, + {untrust, trust, serverB1, client, false}, + {trust, untrust, client, serverB1, true}, + {untrust, trust, serverB1, client, true}, // B2 -> client also works now - {untrust, serverB2, client, Continue}, + {untrust, trust, serverB2, client, true}, }) }) t.Run("endpoint_independent", func(t *testing.T) { @@ -288,23 +301,23 @@ func TestFirewall(t *testing.T) { } testFirewall(t, f, []fwTest{ // client -> A authorizes A -> client - {trust, client, serverA, Continue}, - {untrust, serverA, client, Continue}, - {untrust, serverA, client, Continue}, + {trust, untrust, client, serverA, true}, + {untrust, trust, serverA, client, true}, + {untrust, trust, serverA, client, true}, // B1 -> client also works - {untrust, serverB1, client, Continue}, + {untrust, trust, serverB1, client, true}, // B2 -> client also works - {untrust, serverB2, client, Continue}, + {untrust, trust, serverB2, client, true}, }) }) } type fwTest struct { - iface *Interface + iif, oif *Interface src, dst netaddr.IPPort - want PacketVerdict + ok bool } func testFirewall(t *testing.T, f *Firewall, tests []fwTest) { @@ -318,9 +331,10 @@ func testFirewall(t *testing.T, f *Firewall, tests []fwTest) { Dst: test.dst, Payload: []byte{}, } - got := f.HandlePacket(p, test.iface) - if got != test.want { - t.Errorf("iface=%s src=%s dst=%s got %v, want %v", test.iface.name, test.src, test.dst, got, test.want) + got := f.HandleForward(p, test.iif, test.oif) + gotOK := got != nil + if gotOK != test.ok { + t.Errorf("iif=%s oif=%s src=%s dst=%s got ok=%v, want ok=%v", test.iif, test.oif, test.src, test.dst, gotOK, test.ok) } } } @@ -344,14 +358,13 @@ func TestNAT(t *testing.T) { lanIf := m.Attach("lan", lan) t.Run("endpoint_independent_mapping", func(t *testing.T) { - fw := &Firewall{ - TrustedInterface: lanIf, - } n := &SNAT44{ Machine: m, ExternalInterface: wanIf, Type: EndpointIndependentNAT, - Firewall: fw.HandlePacket, + Firewall: &Firewall{ + TrustedInterface: lanIf, + }, } testNAT(t, n, lanIf, wanIf, []natTest{ { @@ -373,14 +386,13 @@ func TestNAT(t *testing.T) { }) t.Run("address_dependent_mapping", func(t *testing.T) { - fw := &Firewall{ - TrustedInterface: lanIf, - } n := &SNAT44{ Machine: m, ExternalInterface: wanIf, Type: AddressDependentNAT, - Firewall: fw.HandlePacket, + Firewall: &Firewall{ + TrustedInterface: lanIf, + }, } testNAT(t, n, lanIf, wanIf, []natTest{ { @@ -407,14 +419,13 @@ func TestNAT(t *testing.T) { }) t.Run("address_and_port_dependent_mapping", func(t *testing.T) { - fw := &Firewall{ - TrustedInterface: lanIf, - } n := &SNAT44{ Machine: m, ExternalInterface: wanIf, Type: AddressAndPortDependentNAT, - Firewall: fw.HandlePacket, + Firewall: &Firewall{ + TrustedInterface: lanIf, + }, } testNAT(t, n, lanIf, wanIf, []natTest{ { @@ -448,16 +459,7 @@ type natTest struct { func testNAT(t *testing.T, n *SNAT44, lanIf, wanIf *Interface, tests []natTest) { clock := &tstest.Clock{} - injected := make(chan *Packet, 100) // arbitrary n.TimeNow = clock.Now - n.inject = func(p *Packet) error { - select { - case injected <- p: - default: - panic("inject overflow") - } - return nil - } mappings := map[netaddr.IPPort]bool{} for _, test := range tests { @@ -467,25 +469,18 @@ func testNAT(t *testing.T, n *SNAT44, lanIf, wanIf *Interface, tests []natTest) Dst: test.dst, Payload: []byte("foo"), } - gotVerdict := n.HandlePacket(p.Clone(), lanIf) - if gotVerdict != Drop { - t.Errorf("p.HandlePacket(%v) = %v, want Drop", p, gotVerdict) - } - - var gotPacket *Packet - - select { - default: - t.Errorf("p.HandlePacket(%v) didn't inject expected packet", p) - case gotPacket = <-injected: + gotPacket := n.HandleForward(p.Clone(), lanIf, wanIf) + if gotPacket == nil { + t.Errorf("n.HandleForward(%v) dropped packet", p) + continue } if gotPacket.Dst != p.Dst { - t.Errorf("p.HandlePacket(%v) mutated dest ip:port, got %v", p, gotPacket.Dst) + t.Errorf("n.HandleForward(%v) mutated dest ip:port, got %v", p, gotPacket.Dst) } gotNewMapping := !mappings[gotPacket.Src] if gotNewMapping != test.wantNewMapping { - t.Errorf("p.HandlePacket(%v) mapping was new=%v, want %v", p, gotNewMapping, test.wantNewMapping) + t.Errorf("n.HandleForward(%v) mapping was new=%v, want %v", p, gotNewMapping, test.wantNewMapping) } mappings[gotPacket.Src] = true @@ -497,16 +492,11 @@ func testNAT(t *testing.T, n *SNAT44, lanIf, wanIf *Interface, tests []natTest) Dst: gotPacket.Src, Payload: []byte("bar"), } - gotVerdict = n.HandlePacket(p2.Clone(), wanIf) - if gotVerdict != Drop { - t.Errorf("p.HandlePacket(%v) = %v, want Drop", p, gotVerdict) - } + gotPacket2 := n.HandleIn(p2.Clone(), wanIf) - var gotPacket2 *Packet - select { - default: - t.Errorf("p.HandlePacket(%v) didn't inject expected packet", p) - case gotPacket2 = <-injected: + if gotPacket2 == nil { + t.Errorf("return packet was dropped") + continue } if gotPacket2.Src != test.dst { diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 1f7814cbf..d8193d17d 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -371,15 +371,13 @@ func TestTwoDevicePing(t *testing.T) { t.Run("facing firewalls", func(t *testing.T) { mstun := &natlab.Machine{Name: "stun"} - f1 := &natlab.Firewall{} - f2 := &natlab.Firewall{} m1 := &natlab.Machine{ - Name: "m1", - HandlePacket: f1.HandlePacket, + Name: "m1", + PacketHandler: &natlab.Firewall{}, } m2 := &natlab.Machine{ - Name: "m2", - HandlePacket: f2.HandlePacket, + Name: "m2", + PacketHandler: &natlab.Firewall{}, } inet := natlab.NewInternet() sif := mstun.Attach("eth0", inet)