tstest/natlab: refactor, expose a Packet type.

HandlePacket and Inject now receive/take Packets. This is a handy
container for the packet, and the attached Trace method can be used
to print traces from custom packet handlers that integrate nicely
with natlab's internal traces.

Signed-off-by: David Anderson <danderson@tailscale.com>
reviewable/pr546/r1
David Anderson 4 years ago
parent 5eedbcedd1
commit b3d65ba943

@ -43,7 +43,7 @@ func (f *Firewall) timeNow() time.Time {
return time.Now() return time.Now()
} }
func (f *Firewall) HandlePacket(p []byte, inIf *Interface, dst, src netaddr.IPPort) PacketVerdict { func (f *Firewall) HandlePacket(p *Packet, inIf *Interface) PacketVerdict {
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock() defer f.mu.Unlock()
if f.seen == nil { if f.seen == nil {
@ -52,25 +52,25 @@ func (f *Firewall) HandlePacket(p []byte, inIf *Interface, dst, src netaddr.IPPo
if inIf == f.TrustedInterface { if inIf == f.TrustedInterface {
sess := session{ sess := session{
src: src, src: p.Src,
dst: dst, dst: p.Dst,
} }
f.seen[sess] = f.timeNow().Add(f.SessionTimeout) f.seen[sess] = f.timeNow().Add(f.SessionTimeout)
trace(p, "mach=%s iface=%s src=%s dst=%s firewall out ok", inIf.Machine().Name, inIf.name, src, dst) p.Trace("firewall out ok")
return Continue return Continue
} else { } else {
// reverse src and dst because the session table is from the // reverse src and dst because the session table is from the
// POV of outbound packets. // POV of outbound packets.
sess := session{ sess := session{
src: dst, src: p.Dst,
dst: src, dst: p.Src,
} }
now := f.timeNow() now := f.timeNow()
if now.After(f.seen[sess]) { if now.After(f.seen[sess]) {
trace(p, "mach=%s iface=%s src=%s dst=%s firewall drop", inIf.Machine().Name, inIf.name, src, dst) p.Trace("firewall drop")
return Drop return Drop
} }
trace(p, "mach=%s iface=%s src=%s dst=%s firewall in ok", inIf.Machine().Name, inIf.name, src, dst) p.Trace("firewall in ok")
return Continue return Continue
} }
} }

@ -30,21 +30,49 @@ import (
var traceOn, _ = strconv.ParseBool(os.Getenv("NATLAB_TRACE")) var traceOn, _ = strconv.ParseBool(os.Getenv("NATLAB_TRACE"))
func trace(p []byte, msg string, args ...interface{}) { // Packet represents a UDP packet flowing through the virtual network.
type Packet struct {
Src, Dst netaddr.IPPort
Payload []byte
// Prefix set by various internal methods of natlab, to locate
// where in the network a trace occured.
locator string
}
// Clone returns a copy of p that shares nothing with p.
func (p *Packet) Clone() *Packet {
return &Packet{
Src: p.Src,
Dst: p.Dst,
Payload: append([]byte(nil), p.Payload...),
locator: p.locator,
}
}
// short returns a short identifier for a packet payload,
// suitable for printing trace information.
func (p *Packet) short() string {
s := sha256.Sum256(p.Payload)
payload := base64.RawStdEncoding.EncodeToString(s[:])[:2]
s = sha256.Sum256([]byte(p.Src.String() + "_" + p.Dst.String()))
tuple := base64.RawStdEncoding.EncodeToString(s[:])[:2]
return fmt.Sprintf("%s/%s", payload, tuple)
}
func (p *Packet) Trace(msg string, args ...interface{}) {
if !traceOn { if !traceOn {
return return
} }
id := packetShort(p) allArgs := []interface{}{p.short(), p.locator, p.Src, p.Dst}
as := []interface{}{id} allArgs = append(allArgs, args...)
as = append(as, args...) fmt.Fprintf(os.Stderr, "[%s]%s src=%s dst=%s "+msg+"\n", allArgs...)
fmt.Fprintf(os.Stderr, "[%s] "+msg+"\n", as...)
} }
// packetShort returns a short identifier for a packet payload, func (p *Packet) setLocator(msg string, args ...interface{}) {
// suitable for pritning trace information. p.locator = fmt.Sprintf(" "+msg, args...)
func packetShort(p []byte) string {
s := sha256.Sum256(p)
return base64.RawStdEncoding.EncodeToString(s[:])[:4]
} }
func mustPrefix(s string) netaddr.IPPrefix { func mustPrefix(s string) netaddr.IPPrefix {
@ -79,6 +107,9 @@ type Network struct {
func (n *Network) SetDefaultGateway(gwIf *Interface) { func (n *Network) SetDefaultGateway(gwIf *Interface) {
n.mu.Lock() n.mu.Lock()
defer n.mu.Unlock() defer n.mu.Unlock()
if gwIf.net != n {
panic(fmt.Sprintf("can't set if=%s as net=%s's default gw, if not connected to net", gwIf.name, gwIf.net.Name))
}
n.defaultGW = gwIf n.defaultGW = gwIf
} }
@ -139,24 +170,25 @@ func addOne(a *[16]byte, index int) {
} }
} }
func (n *Network) write(p []byte, dst, src netaddr.IPPort) (num int, err error) { func (n *Network) write(p *Packet) (num int, err error) {
p.setLocator("net=%s", n.Name)
n.mu.Lock() n.mu.Lock()
defer n.mu.Unlock() defer n.mu.Unlock()
iface, ok := n.machine[dst.IP] iface, ok := n.machine[p.Dst.IP]
if !ok { if !ok {
if n.defaultGW == nil { if n.defaultGW == nil {
trace(p, "net=%s dropped, no route to %v", n.Name, dst.IP) p.Trace("no route to %v", p.Dst.IP)
return len(p), nil return len(p.Payload), nil
} }
iface = n.defaultGW iface = n.defaultGW
} }
// Pretend it went across the network. Make a copy so nobody // Pretend it went across the network. Make a copy so nobody
// can later mess with caller's memory. // can later mess with caller's memory.
trace(p, "net=%s src=%v dst=%v -> mach=%s iface=%s", n.Name, src, dst, iface.machine.Name, iface.name) p.Trace("-> mach=%s if=%s", iface.machine.Name, iface.name)
pcopy := append([]byte(nil), p...) go iface.machine.deliverIncomingPacket(p, iface)
go iface.machine.deliverIncomingPacket(pcopy, iface, dst, src) return len(p.Payload), nil
return len(p), nil
} }
type Interface struct { type Interface struct {
@ -235,7 +267,7 @@ func (v PacketVerdict) String() string {
} }
// A PacketHandler is a function that can process packets. // A PacketHandler is a function that can process packets.
type PacketHandler func(p []byte, inIf *Interface, dst, src netaddr.IPPort) PacketVerdict type PacketHandler func(p *Packet, inIf *Interface) PacketVerdict
// A Machine is a representation of an operating system's network // A Machine is a representation of an operating system's network
// stack. It has a network routing table and can have multiple // stack. It has a network routing table and can have multiple
@ -250,8 +282,9 @@ type Machine struct {
// every packet this Machine receives. Returns a verdict for how // every packet this Machine receives. Returns a verdict for how
// the packet should continue to be handled (or not). // the packet should continue to be handled (or not).
// //
// This can be used to implement things like stateful firewalls // The packet provided to HandlePacket can safely be mutated and
// and NAT boxes. // Inject()ed if desired. This can be used to implement things
// like stateful firewalls and NAT boxes.
HandlePacket PacketHandler HandlePacket PacketHandler
mu sync.Mutex mu sync.Mutex
@ -264,18 +297,22 @@ type Machine struct {
// Inject transmits p from src to dst, without the need for a local socket. // Inject transmits p from src to dst, without the need for a local socket.
// It's useful for implementing e.g. NAT boxes that need to mangle IPs. // It's useful for implementing e.g. NAT boxes that need to mangle IPs.
func (m *Machine) Inject(p []byte, dst, src netaddr.IPPort) error { func (m *Machine) Inject(p *Packet) error {
trace(p, "mach=%s src=%s dst=%s packet injected", m.Name, src, dst) p = p.Clone()
_, err := m.writePacket(p, dst, src) p.setLocator("mach=%s", m.Name)
p.Trace("Machine.Inject")
_, err := m.writePacket(p)
return err return err
} }
func (m *Machine) deliverIncomingPacket(p []byte, iface *Interface, dst, src netaddr.IPPort) { func (m *Machine) deliverIncomingPacket(p *Packet, iface *Interface) {
p.setLocator("mach=%s if=%s", m.Name, iface.name)
// TODO: can't hold lock while handling packet. This is safe as // TODO: can't hold lock while handling packet. This is safe as
// long as you set HandlePacket before traffic starts flowing. // long as you set HandlePacket before traffic starts flowing.
if m.HandlePacket != nil { if m.HandlePacket != nil {
verdict := m.HandlePacket(p, iface, dst, src) p.Trace("Machine.HandlePacket")
trace(p, "mach=%s src=%v dst=%v packethandler verdict=%s", m.Name, src, dst, verdict) verdict := m.HandlePacket(p.Clone(), iface)
p.Trace("Machine.HandlePacket verdict=%s", verdict)
if verdict == Drop { if verdict == Drop {
// Custom packet handler ate the packet, we're done. // Custom packet handler ate the packet, we're done.
return return
@ -286,13 +323,13 @@ func (m *Machine) deliverIncomingPacket(p []byte, iface *Interface, dst, src net
defer m.mu.Unlock() defer m.mu.Unlock()
conns := m.conns4 conns := m.conns4
if dst.IP.Is6() { if p.Dst.IP.Is6() {
conns = m.conns6 conns = m.conns6
} }
possibleDsts := []netaddr.IPPort{ possibleDsts := []netaddr.IPPort{
dst, p.Dst,
netaddr.IPPort{IP: v6unspec, Port: dst.Port}, netaddr.IPPort{IP: v6unspec, Port: p.Dst.Port},
netaddr.IPPort{IP: v4unspec, Port: dst.Port}, netaddr.IPPort{IP: v4unspec, Port: p.Dst.Port},
} }
for _, dest := range possibleDsts { for _, dest := range possibleDsts {
c, ok := conns[dest] c, ok := conns[dest]
@ -300,15 +337,15 @@ func (m *Machine) deliverIncomingPacket(p []byte, iface *Interface, dst, src net
continue continue
} }
select { select {
case c.in <- incomingPacket{src: src, p: p}: case c.in <- p:
trace(p, "mach=%s src=%v dst=%v queued to conn", m.Name, src, dst) p.Trace("queued to conn")
default: default:
trace(p, "mach=%s src=%v dst=%v dropped, queue overflow", m.Name, src, dst) p.Trace("dropped, queue overflow")
// Queue overflow. Just drop it. // Queue overflow. Just drop it.
} }
return return
} }
trace(p, "mach=%s src=%v dst=%v dropped, no listening conn", m.Name, src, dst) p.Trace("dropped, no listening conn")
} }
func unspecOf(ip netaddr.IP) netaddr.IP { func unspecOf(ip netaddr.IP) netaddr.IP {
@ -378,38 +415,43 @@ var (
v6unspec = netaddr.IPv6Unspecified() v6unspec = netaddr.IPv6Unspecified()
) )
func (m *Machine) writePacket(p []byte, dst, src netaddr.IPPort) (n int, err error) { func (m *Machine) writePacket(p *Packet) (n int, err error) {
iface, err := m.interfaceForIP(dst.IP) p.setLocator("mach=%s", m.Name)
iface, err := m.interfaceForIP(p.Dst.IP)
if err != nil { if err != nil {
trace(p, "%v", err) p.Trace("%v", err)
return 0, err return 0, err
} }
origSrcIP := src.IP origSrcIP := p.Src.IP
switch { switch {
case src.IP == v4unspec: case p.Src.IP == v4unspec:
src.IP = iface.V4() p.Trace("assigning srcIP=%s", iface.V4())
case src.IP == v6unspec: p.Src.IP = iface.V4()
case p.Src.IP == v6unspec:
// v6unspec in Go means "any src, but match address families" // v6unspec in Go means "any src, but match address families"
if dst.IP.Is6() { if p.Dst.IP.Is6() {
src.IP = iface.V6() p.Trace("assigning srcIP=%s", iface.V6())
} else if dst.IP.Is4() { p.Src.IP = iface.V6()
src.IP = iface.V4() } else if p.Dst.IP.Is4() {
p.Trace("assigning srcIP=%s", iface.V4())
p.Src.IP = iface.V4()
} }
default: default:
if !iface.Contains(src.IP) { if !iface.Contains(p.Src.IP) {
err := fmt.Errorf("can't send to %v with src %v on interface %v", dst.IP, src.IP, iface) err := fmt.Errorf("can't send to %v with src %v on interface %v", p.Dst.IP, p.Src.IP, iface)
trace(p, "%v", err) p.Trace("%v", err)
return 0, err return 0, err
} }
} }
if src.IP.IsZero() { if p.Src.IP.IsZero() {
err := fmt.Errorf("no matching address for address family for %v", origSrcIP) err := fmt.Errorf("no matching address for address family for %v", origSrcIP)
trace(p, "%v", err) p.Trace("%v", err)
return 0, err return 0, err
} }
trace(p, "mach=%s src=%s dst=%s -> net=%s", m.Name, src, dst, iface.net.Name) p.Trace("-> net=%s if=%s", iface.net.Name, iface)
return iface.net.write(p, dst, src) return iface.net.write(p)
} }
func (m *Machine) interfaceForIP(ip netaddr.IP) (*Interface, error) { func (m *Machine) interfaceForIP(ip netaddr.IP) (*Interface, error) {
@ -552,7 +594,7 @@ func (m *Machine) ListenPacket(ctx context.Context, network, address string) (ne
m: m, m: m,
fam: fam, fam: fam,
ipp: ipp, ipp: ipp,
in: make(chan incomingPacket, 100), // arbitrary in: make(chan *Packet, 100), // arbitrary
} }
switch c.fam { switch c.fam {
case 0: case 0:
@ -585,12 +627,7 @@ type conn struct {
closed bool closed bool
readDeadline time.Time readDeadline time.Time
activeReads map[*activeRead]bool activeReads map[*activeRead]bool
in chan incomingPacket in chan *Packet
}
type incomingPacket struct {
p []byte
src netaddr.IPPort
} }
type activeRead struct { type activeRead struct {
@ -669,9 +706,9 @@ func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
select { select {
case pkt := <-c.in: case pkt := <-c.in:
n = copy(p, pkt.p) n = copy(p, pkt.Payload)
trace(pkt.p, "mach=%s src=%s PacketConn.ReadFrom", c.m.Name, pkt.src) pkt.Trace("PacketConn.ReadFrom")
return n, pkt.src.UDPAddr(), nil return n, pkt.Src.UDPAddr(), nil
case <-ctx.Done(): case <-ctx.Done():
return 0, nil, context.DeadlineExceeded return 0, nil, context.DeadlineExceeded
} }
@ -682,7 +719,14 @@ func (c *conn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if err != nil { if err != nil {
return 0, fmt.Errorf("bogus addr %T %q", addr, addr.String()) return 0, fmt.Errorf("bogus addr %T %q", addr, addr.String())
} }
return c.m.writePacket(p, ipp, c.ipp) pkt := &Packet{
Src: c.ipp,
Dst: ipp,
Payload: append([]byte(nil), p...),
}
pkt.setLocator("mach=%s", c.m.Name)
pkt.Trace("PacketConn.WriteTo")
return c.m.writePacket(pkt)
} }
func (c *conn) SetDeadline(t time.Time) error { func (c *conn) SetDeadline(t time.Time) error {

@ -175,15 +175,17 @@ func TestPacketHandler(t *testing.T) {
// port remappings or any other things that NATs usually to. But // port remappings or any other things that NATs usually to. But
// it works as a demonstrator for a single client behind the NAT, // it works as a demonstrator for a single client behind the NAT,
// where the NAT box itself doesn't also make PacketConns. // where the NAT box itself doesn't also make PacketConns.
nat.HandlePacket = func(p []byte, iface *Interface, dst, src netaddr.IPPort) PacketVerdict { nat.HandlePacket = func(p *Packet, iface *Interface) PacketVerdict {
switch { switch {
case dst.IP.Is6(): case p.Dst.IP.Is6():
return Continue // no NAT for ipv6 return Continue // no NAT for ipv6
case iface == ifNATLAN && src.IP == ifClient.V4(): case iface == ifNATLAN && p.Src.IP == ifClient.V4():
nat.Inject(p, dst, netaddr.IPPort{IP: ifNATWAN.V4(), Port: src.Port}) p.Src.IP = ifNATWAN.V4()
nat.Inject(p)
return Drop return Drop
case iface == ifNATWAN && dst.IP == ifNATWAN.V4(): case iface == ifNATWAN && p.Dst.IP == ifNATWAN.V4():
nat.Inject(p, netaddr.IPPort{IP: ifClient.V4(), Port: dst.Port}, src) p.Dst.IP = ifClient.V4()
nat.Inject(p)
return Drop return Drop
default: default:
return Continue return Continue
@ -257,7 +259,12 @@ func TestFirewall(t *testing.T) {
for _, test := range tests { for _, test := range tests {
clock.Advance(time.Second) clock.Advance(time.Second)
got := f.HandlePacket(nil, test.iface, test.dst, test.src) p := &Packet{
Src: test.src,
Dst: test.dst,
Payload: []byte{},
}
got := f.HandlePacket(p, test.iface)
if got != test.want { if got != test.want {
t.Errorf("iface=%s src=%s dst=%s got %v, want %v", test.iface.name, test.src, test.dst, got, test.want) t.Errorf("iface=%s src=%s dst=%s got %v, want %v", test.iface.name, test.src, test.dst, got, test.want)
} }

Loading…
Cancel
Save