tstest/natlab: refactor, expose a Packet type.

HandlePacket and Inject now receive/take Packets. This is a handy
container for the packet, and the attached Trace method can be used
to print traces from custom packet handlers that integrate nicely
with natlab's internal traces.

Signed-off-by: David Anderson <danderson@tailscale.com>
reviewable/pr546/r1
David Anderson 4 years ago
parent 5eedbcedd1
commit b3d65ba943

@ -43,7 +43,7 @@ func (f *Firewall) timeNow() time.Time {
return time.Now()
}
func (f *Firewall) HandlePacket(p []byte, inIf *Interface, dst, src netaddr.IPPort) PacketVerdict {
func (f *Firewall) HandlePacket(p *Packet, inIf *Interface) PacketVerdict {
f.mu.Lock()
defer f.mu.Unlock()
if f.seen == nil {
@ -52,25 +52,25 @@ func (f *Firewall) HandlePacket(p []byte, inIf *Interface, dst, src netaddr.IPPo
if inIf == f.TrustedInterface {
sess := session{
src: src,
dst: dst,
src: p.Src,
dst: p.Dst,
}
f.seen[sess] = f.timeNow().Add(f.SessionTimeout)
trace(p, "mach=%s iface=%s src=%s dst=%s firewall out ok", inIf.Machine().Name, inIf.name, src, dst)
p.Trace("firewall out ok")
return Continue
} else {
// reverse src and dst because the session table is from the
// POV of outbound packets.
sess := session{
src: dst,
dst: src,
src: p.Dst,
dst: p.Src,
}
now := f.timeNow()
if now.After(f.seen[sess]) {
trace(p, "mach=%s iface=%s src=%s dst=%s firewall drop", inIf.Machine().Name, inIf.name, src, dst)
p.Trace("firewall drop")
return Drop
}
trace(p, "mach=%s iface=%s src=%s dst=%s firewall in ok", inIf.Machine().Name, inIf.name, src, dst)
p.Trace("firewall in ok")
return Continue
}
}

@ -30,21 +30,49 @@ import (
var traceOn, _ = strconv.ParseBool(os.Getenv("NATLAB_TRACE"))
func trace(p []byte, msg string, args ...interface{}) {
// Packet represents a UDP packet flowing through the virtual network.
type Packet struct {
Src, Dst netaddr.IPPort
Payload []byte
// Prefix set by various internal methods of natlab, to locate
// where in the network a trace occured.
locator string
}
// Clone returns a copy of p that shares nothing with p.
func (p *Packet) Clone() *Packet {
return &Packet{
Src: p.Src,
Dst: p.Dst,
Payload: append([]byte(nil), p.Payload...),
locator: p.locator,
}
}
// short returns a short identifier for a packet payload,
// suitable for printing trace information.
func (p *Packet) short() string {
s := sha256.Sum256(p.Payload)
payload := base64.RawStdEncoding.EncodeToString(s[:])[:2]
s = sha256.Sum256([]byte(p.Src.String() + "_" + p.Dst.String()))
tuple := base64.RawStdEncoding.EncodeToString(s[:])[:2]
return fmt.Sprintf("%s/%s", payload, tuple)
}
func (p *Packet) Trace(msg string, args ...interface{}) {
if !traceOn {
return
}
id := packetShort(p)
as := []interface{}{id}
as = append(as, args...)
fmt.Fprintf(os.Stderr, "[%s] "+msg+"\n", as...)
allArgs := []interface{}{p.short(), p.locator, p.Src, p.Dst}
allArgs = append(allArgs, args...)
fmt.Fprintf(os.Stderr, "[%s]%s src=%s dst=%s "+msg+"\n", allArgs...)
}
// packetShort returns a short identifier for a packet payload,
// suitable for pritning trace information.
func packetShort(p []byte) string {
s := sha256.Sum256(p)
return base64.RawStdEncoding.EncodeToString(s[:])[:4]
func (p *Packet) setLocator(msg string, args ...interface{}) {
p.locator = fmt.Sprintf(" "+msg, args...)
}
func mustPrefix(s string) netaddr.IPPrefix {
@ -79,6 +107,9 @@ type Network struct {
func (n *Network) SetDefaultGateway(gwIf *Interface) {
n.mu.Lock()
defer n.mu.Unlock()
if gwIf.net != n {
panic(fmt.Sprintf("can't set if=%s as net=%s's default gw, if not connected to net", gwIf.name, gwIf.net.Name))
}
n.defaultGW = gwIf
}
@ -139,24 +170,25 @@ func addOne(a *[16]byte, index int) {
}
}
func (n *Network) write(p []byte, dst, src netaddr.IPPort) (num int, err error) {
func (n *Network) write(p *Packet) (num int, err error) {
p.setLocator("net=%s", n.Name)
n.mu.Lock()
defer n.mu.Unlock()
iface, ok := n.machine[dst.IP]
iface, ok := n.machine[p.Dst.IP]
if !ok {
if n.defaultGW == nil {
trace(p, "net=%s dropped, no route to %v", n.Name, dst.IP)
return len(p), nil
p.Trace("no route to %v", p.Dst.IP)
return len(p.Payload), nil
}
iface = n.defaultGW
}
// Pretend it went across the network. Make a copy so nobody
// can later mess with caller's memory.
trace(p, "net=%s src=%v dst=%v -> mach=%s iface=%s", n.Name, src, dst, iface.machine.Name, iface.name)
pcopy := append([]byte(nil), p...)
go iface.machine.deliverIncomingPacket(pcopy, iface, dst, src)
return len(p), nil
p.Trace("-> mach=%s if=%s", iface.machine.Name, iface.name)
go iface.machine.deliverIncomingPacket(p, iface)
return len(p.Payload), nil
}
type Interface struct {
@ -235,7 +267,7 @@ func (v PacketVerdict) String() string {
}
// A PacketHandler is a function that can process packets.
type PacketHandler func(p []byte, inIf *Interface, dst, src netaddr.IPPort) PacketVerdict
type PacketHandler func(p *Packet, inIf *Interface) PacketVerdict
// A Machine is a representation of an operating system's network
// stack. It has a network routing table and can have multiple
@ -250,8 +282,9 @@ type Machine struct {
// every packet this Machine receives. Returns a verdict for how
// the packet should continue to be handled (or not).
//
// This can be used to implement things like stateful firewalls
// and NAT boxes.
// The packet provided to HandlePacket can safely be mutated and
// Inject()ed if desired. This can be used to implement things
// like stateful firewalls and NAT boxes.
HandlePacket PacketHandler
mu sync.Mutex
@ -264,18 +297,22 @@ type Machine struct {
// Inject transmits p from src to dst, without the need for a local socket.
// It's useful for implementing e.g. NAT boxes that need to mangle IPs.
func (m *Machine) Inject(p []byte, dst, src netaddr.IPPort) error {
trace(p, "mach=%s src=%s dst=%s packet injected", m.Name, src, dst)
_, err := m.writePacket(p, dst, src)
func (m *Machine) Inject(p *Packet) error {
p = p.Clone()
p.setLocator("mach=%s", m.Name)
p.Trace("Machine.Inject")
_, err := m.writePacket(p)
return err
}
func (m *Machine) deliverIncomingPacket(p []byte, iface *Interface, dst, src netaddr.IPPort) {
func (m *Machine) deliverIncomingPacket(p *Packet, iface *Interface) {
p.setLocator("mach=%s if=%s", m.Name, iface.name)
// TODO: can't hold lock while handling packet. This is safe as
// long as you set HandlePacket before traffic starts flowing.
if m.HandlePacket != nil {
verdict := m.HandlePacket(p, iface, dst, src)
trace(p, "mach=%s src=%v dst=%v packethandler verdict=%s", m.Name, src, dst, verdict)
p.Trace("Machine.HandlePacket")
verdict := m.HandlePacket(p.Clone(), iface)
p.Trace("Machine.HandlePacket verdict=%s", verdict)
if verdict == Drop {
// Custom packet handler ate the packet, we're done.
return
@ -286,13 +323,13 @@ func (m *Machine) deliverIncomingPacket(p []byte, iface *Interface, dst, src net
defer m.mu.Unlock()
conns := m.conns4
if dst.IP.Is6() {
if p.Dst.IP.Is6() {
conns = m.conns6
}
possibleDsts := []netaddr.IPPort{
dst,
netaddr.IPPort{IP: v6unspec, Port: dst.Port},
netaddr.IPPort{IP: v4unspec, Port: dst.Port},
p.Dst,
netaddr.IPPort{IP: v6unspec, Port: p.Dst.Port},
netaddr.IPPort{IP: v4unspec, Port: p.Dst.Port},
}
for _, dest := range possibleDsts {
c, ok := conns[dest]
@ -300,15 +337,15 @@ func (m *Machine) deliverIncomingPacket(p []byte, iface *Interface, dst, src net
continue
}
select {
case c.in <- incomingPacket{src: src, p: p}:
trace(p, "mach=%s src=%v dst=%v queued to conn", m.Name, src, dst)
case c.in <- p:
p.Trace("queued to conn")
default:
trace(p, "mach=%s src=%v dst=%v dropped, queue overflow", m.Name, src, dst)
p.Trace("dropped, queue overflow")
// Queue overflow. Just drop it.
}
return
}
trace(p, "mach=%s src=%v dst=%v dropped, no listening conn", m.Name, src, dst)
p.Trace("dropped, no listening conn")
}
func unspecOf(ip netaddr.IP) netaddr.IP {
@ -378,38 +415,43 @@ var (
v6unspec = netaddr.IPv6Unspecified()
)
func (m *Machine) writePacket(p []byte, dst, src netaddr.IPPort) (n int, err error) {
iface, err := m.interfaceForIP(dst.IP)
func (m *Machine) writePacket(p *Packet) (n int, err error) {
p.setLocator("mach=%s", m.Name)
iface, err := m.interfaceForIP(p.Dst.IP)
if err != nil {
trace(p, "%v", err)
p.Trace("%v", err)
return 0, err
}
origSrcIP := src.IP
origSrcIP := p.Src.IP
switch {
case src.IP == v4unspec:
src.IP = iface.V4()
case src.IP == v6unspec:
case p.Src.IP == v4unspec:
p.Trace("assigning srcIP=%s", iface.V4())
p.Src.IP = iface.V4()
case p.Src.IP == v6unspec:
// v6unspec in Go means "any src, but match address families"
if dst.IP.Is6() {
src.IP = iface.V6()
} else if dst.IP.Is4() {
src.IP = iface.V4()
if p.Dst.IP.Is6() {
p.Trace("assigning srcIP=%s", iface.V6())
p.Src.IP = iface.V6()
} else if p.Dst.IP.Is4() {
p.Trace("assigning srcIP=%s", iface.V4())
p.Src.IP = iface.V4()
}
default:
if !iface.Contains(src.IP) {
err := fmt.Errorf("can't send to %v with src %v on interface %v", dst.IP, src.IP, iface)
trace(p, "%v", err)
if !iface.Contains(p.Src.IP) {
err := fmt.Errorf("can't send to %v with src %v on interface %v", p.Dst.IP, p.Src.IP, iface)
p.Trace("%v", err)
return 0, err
}
}
if src.IP.IsZero() {
if p.Src.IP.IsZero() {
err := fmt.Errorf("no matching address for address family for %v", origSrcIP)
trace(p, "%v", err)
p.Trace("%v", err)
return 0, err
}
trace(p, "mach=%s src=%s dst=%s -> net=%s", m.Name, src, dst, iface.net.Name)
return iface.net.write(p, dst, src)
p.Trace("-> net=%s if=%s", iface.net.Name, iface)
return iface.net.write(p)
}
func (m *Machine) interfaceForIP(ip netaddr.IP) (*Interface, error) {
@ -552,7 +594,7 @@ func (m *Machine) ListenPacket(ctx context.Context, network, address string) (ne
m: m,
fam: fam,
ipp: ipp,
in: make(chan incomingPacket, 100), // arbitrary
in: make(chan *Packet, 100), // arbitrary
}
switch c.fam {
case 0:
@ -585,12 +627,7 @@ type conn struct {
closed bool
readDeadline time.Time
activeReads map[*activeRead]bool
in chan incomingPacket
}
type incomingPacket struct {
p []byte
src netaddr.IPPort
in chan *Packet
}
type activeRead struct {
@ -669,9 +706,9 @@ func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
select {
case pkt := <-c.in:
n = copy(p, pkt.p)
trace(pkt.p, "mach=%s src=%s PacketConn.ReadFrom", c.m.Name, pkt.src)
return n, pkt.src.UDPAddr(), nil
n = copy(p, pkt.Payload)
pkt.Trace("PacketConn.ReadFrom")
return n, pkt.Src.UDPAddr(), nil
case <-ctx.Done():
return 0, nil, context.DeadlineExceeded
}
@ -682,7 +719,14 @@ func (c *conn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if err != nil {
return 0, fmt.Errorf("bogus addr %T %q", addr, addr.String())
}
return c.m.writePacket(p, ipp, c.ipp)
pkt := &Packet{
Src: c.ipp,
Dst: ipp,
Payload: append([]byte(nil), p...),
}
pkt.setLocator("mach=%s", c.m.Name)
pkt.Trace("PacketConn.WriteTo")
return c.m.writePacket(pkt)
}
func (c *conn) SetDeadline(t time.Time) error {

@ -175,15 +175,17 @@ func TestPacketHandler(t *testing.T) {
// port remappings or any other things that NATs usually to. But
// it works as a demonstrator for a single client behind the NAT,
// where the NAT box itself doesn't also make PacketConns.
nat.HandlePacket = func(p []byte, iface *Interface, dst, src netaddr.IPPort) PacketVerdict {
nat.HandlePacket = func(p *Packet, iface *Interface) PacketVerdict {
switch {
case dst.IP.Is6():
case p.Dst.IP.Is6():
return Continue // no NAT for ipv6
case iface == ifNATLAN && src.IP == ifClient.V4():
nat.Inject(p, dst, netaddr.IPPort{IP: ifNATWAN.V4(), Port: src.Port})
case iface == ifNATLAN && p.Src.IP == ifClient.V4():
p.Src.IP = ifNATWAN.V4()
nat.Inject(p)
return Drop
case iface == ifNATWAN && dst.IP == ifNATWAN.V4():
nat.Inject(p, netaddr.IPPort{IP: ifClient.V4(), Port: dst.Port}, src)
case iface == ifNATWAN && p.Dst.IP == ifNATWAN.V4():
p.Dst.IP = ifClient.V4()
nat.Inject(p)
return Drop
default:
return Continue
@ -257,7 +259,12 @@ func TestFirewall(t *testing.T) {
for _, test := range tests {
clock.Advance(time.Second)
got := f.HandlePacket(nil, test.iface, test.dst, test.src)
p := &Packet{
Src: test.src,
Dst: test.dst,
Payload: []byte{},
}
got := f.HandlePacket(p, test.iface)
if got != test.want {
t.Errorf("iface=%s src=%s dst=%s got %v, want %v", test.iface.name, test.src, test.dst, got, test.want)
}

Loading…
Cancel
Save