wgengine/magicsock: actually use AF_PACKET socket for raw disco

Previously, despite what the commit said, we were using a raw IP socket
that was *not* an AF_PACKET socket, and thus was subject to the host
firewall rules. Switch to using a real AF_PACKET socket to actually get
the functionality we want.

Updates #13140

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: If657daeeda9ab8d967e75a4f049c66e2bca54b78
pull/13351/head
Andrew Dunham 3 months ago
parent eb2fa16fcc
commit 1c972bc7cb

@ -171,7 +171,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
L 💣 github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+ L 💣 github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+
L github.com/mdlayher/netlink/nltest from github.com/google/nftables L github.com/mdlayher/netlink/nltest from github.com/google/nftables
L github.com/mdlayher/sdnotify from tailscale.com/util/systemd L github.com/mdlayher/sdnotify from tailscale.com/util/systemd
L 💣 github.com/mdlayher/socket from github.com/mdlayher/netlink L 💣 github.com/mdlayher/socket from github.com/mdlayher/netlink+
github.com/miekg/dns from tailscale.com/net/dns/recursive github.com/miekg/dns from tailscale.com/net/dns/recursive
💣 github.com/mitchellh/go-ps from tailscale.com/safesocket 💣 github.com/mitchellh/go-ps from tailscale.com/safesocket
github.com/modern-go/concurrent from github.com/json-iterator/go github.com/modern-go/concurrent from github.com/json-iterator/go

@ -139,7 +139,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
L 💣 github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+ L 💣 github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+
L github.com/mdlayher/netlink/nltest from github.com/google/nftables L github.com/mdlayher/netlink/nltest from github.com/google/nftables
L github.com/mdlayher/sdnotify from tailscale.com/util/systemd L github.com/mdlayher/sdnotify from tailscale.com/util/systemd
L 💣 github.com/mdlayher/socket from github.com/mdlayher/netlink L 💣 github.com/mdlayher/socket from github.com/mdlayher/netlink+
github.com/miekg/dns from tailscale.com/net/dns/recursive github.com/miekg/dns from tailscale.com/net/dns/recursive
💣 github.com/mitchellh/go-ps from tailscale.com/safesocket 💣 github.com/mitchellh/go-ps from tailscale.com/safesocket
L github.com/pierrec/lz4/v4 from github.com/u-root/uio/uio L github.com/pierrec/lz4/v4 from github.com/u-root/uio/uio

@ -393,6 +393,11 @@ func (q *Parsed) Buffer() []byte {
// Payload returns the payload of the IP subprotocol section. // Payload returns the payload of the IP subprotocol section.
// This is a read-only view; that is, q retains the ownership of the buffer. // This is a read-only view; that is, q retains the ownership of the buffer.
func (q *Parsed) Payload() []byte { func (q *Parsed) Payload() []byte {
// If the packet is truncated, return nothing instead of crashing.
if q.length > len(q.b) || q.dataofs > len(q.b) {
return nil
}
return q.b[q.dataofs:q.length] return q.b[q.dataofs:q.length]
} }

@ -5,28 +5,37 @@ package magicsock
import ( import (
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip" "net/netip"
"strings"
"syscall" "syscall"
"time" "time"
"unsafe"
"github.com/mdlayher/socket"
"golang.org/x/net/bpf" "golang.org/x/net/bpf"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.org/x/sys/cpu"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"tailscale.com/disco"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/net/netns" "tailscale.com/net/netns"
"tailscale.com/types/ipproto"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/types/nettype" "tailscale.com/types/nettype"
) )
const ( const (
udpHeaderSize = 8 udpHeaderSize = 8
ipv6FragmentHeaderSize = 8
// discoMinHeaderSize is the minimum size of the disco header in bytes.
discoMinHeaderSize = len(disco.Magic) + 32 /* key length */ + disco.NonceLen
) )
// Enable/disable using raw sockets to receive disco traffic. // Enable/disable using raw sockets to receive disco traffic.
@ -38,8 +47,17 @@ var debugRawDiscoReads = envknob.RegisterBool("TS_DEBUG_RAW_DISCO")
// These are our BPF filters that we use for testing packets. // These are our BPF filters that we use for testing packets.
var ( var (
magicsockFilterV4 = []bpf.Instruction{ magicsockFilterV4 = []bpf.Instruction{
// For raw UDPv4 sockets, BPF receives the entire IP packet to // For raw sockets (with ETH_P_IP set), the BPF program
// inspect. // receives the entire IPv4 packet, but not the Ethernet
// header.
// Double-check that this is a UDP packet; we shouldn't be
// seeing anything else given how we create our AF_PACKET
// socket, but an extra check here is cheap, and matches the
// check that we do in the IPv6 path.
bpf.LoadAbsolute{Off: 9, Size: 1},
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(ipproto.UDP), SkipTrue: 1, SkipFalse: 0},
bpf.RetConstant{Val: 0x0},
// Disco packets are so small they should never get // Disco packets are so small they should never get
// fragmented, and we don't want to handle reassembly. // fragmented, and we don't want to handle reassembly.
@ -53,6 +71,25 @@ var (
// Load IP header length into X register. // Load IP header length into X register.
bpf.LoadMemShift{Off: 0}, bpf.LoadMemShift{Off: 0},
// Verify that we have a packet that's big enough to (possibly)
// contain a disco packet.
//
// The length of an IPv4 disco packet is composed of:
// - 8 bytes for the UDP header
// - N bytes for the disco packet header
//
// bpf will implicitly return 0 ("skip") if attempting an
// out-of-bounds load, so we can check the length of the packet
// loading a byte from that offset here. We subtract 1 byte
// from the offset to ensure that we accept a packet that's
// exactly the minimum size.
//
// We use LoadIndirect; since we loaded the start of the packet's
// payload into the X register, above, we don't need to add
// ipv4.HeaderLen to the offset (and this properly handles IPv4
// extensions).
bpf.LoadIndirect{Off: uint32(udpHeaderSize + discoMinHeaderSize - 1), Size: 1},
// Get the first 4 bytes of the UDP packet, compare with our magic number // Get the first 4 bytes of the UDP packet, compare with our magic number
bpf.LoadIndirect{Off: udpHeaderSize, Size: 4}, bpf.LoadIndirect{Off: udpHeaderSize, Size: 4},
bpf.JumpIf{Cond: bpf.JumpEqual, Val: discoMagic1, SkipTrue: 0, SkipFalse: 3}, bpf.JumpIf{Cond: bpf.JumpEqual, Val: discoMagic1, SkipTrue: 0, SkipFalse: 3},
@ -82,25 +119,24 @@ var (
// and thus we'd rather be conservative here and possibly not receive // and thus we'd rather be conservative here and possibly not receive
// disco packets rather than slow down the system. // disco packets rather than slow down the system.
magicsockFilterV6 = []bpf.Instruction{ magicsockFilterV6 = []bpf.Instruction{
// For raw UDPv6 sockets, BPF receives _only_ the UDP header onwards, not an entire IP packet. // Do a bounds check to ensure we have enough space for a disco
// // packet; see the comment in the IPv4 BPF program for more
// https://stackoverflow.com/questions/24514333/using-bpf-with-sock-dgram-on-linux-machine // details.
// https://blog.cloudflare.com/epbf_sockets_hop_distance/ bpf.LoadAbsolute{Off: uint32(ipv6.HeaderLen + udpHeaderSize + discoMinHeaderSize - 1), Size: 1},
//
// This is especially confusing because this *isn't* true for // Verify that the 'next header' value of the IPv6 packet is
// IPv4; see the following code from the 'ping' utility that // UDP, which is what we're expecting; if it's anything else
// corroborates this: // (including extension headers), we skip the packet.
// bpf.LoadAbsolute{Off: 6, Size: 1},
// https://github.com/iputils/iputils/blob/1ab5fa/ping/ping.c#L1667-L1676 bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(ipproto.UDP), SkipTrue: 0, SkipFalse: 5},
// https://github.com/iputils/iputils/blob/1ab5fa/ping/ping6_common.c#L933-L941
// Compare with our magic number. Start by loading and // Compare with our magic number. Start by loading and
// comparing the first 4 bytes of the UDP payload. // comparing the first 4 bytes of the UDP payload.
bpf.LoadAbsolute{Off: udpHeaderSize, Size: 4}, bpf.LoadAbsolute{Off: ipv6.HeaderLen + udpHeaderSize, Size: 4},
bpf.JumpIf{Cond: bpf.JumpEqual, Val: discoMagic1, SkipTrue: 0, SkipFalse: 3}, bpf.JumpIf{Cond: bpf.JumpEqual, Val: discoMagic1, SkipTrue: 0, SkipFalse: 3},
// Compare the next 2 bytes // Compare the next 2 bytes
bpf.LoadAbsolute{Off: udpHeaderSize + 4, Size: 2}, bpf.LoadAbsolute{Off: ipv6.HeaderLen + udpHeaderSize + 4, Size: 2},
bpf.JumpIf{Cond: bpf.JumpEqual, Val: discoMagic2, SkipTrue: 0, SkipFalse: 1}, bpf.JumpIf{Cond: bpf.JumpEqual, Val: discoMagic2, SkipTrue: 0, SkipFalse: 1},
// Accept the whole packet // Accept the whole packet
@ -140,21 +176,24 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
} }
var ( var (
network string udpnet string
addr string addr string
testAddr string proto int
testAddr netip.AddrPort
prog []bpf.Instruction prog []bpf.Instruction
) )
switch family { switch family {
case "ip4": case "ip4":
network = "ip4:17" udpnet = "udp4"
addr = "0.0.0.0" addr = "0.0.0.0"
testAddr = "127.0.0.1:1" proto = ethernetProtoIPv4()
testAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 1)
prog = magicsockFilterV4 prog = magicsockFilterV4
case "ip6": case "ip6":
network = "ip6:17" udpnet = "udp6"
addr = "::" addr = "::"
testAddr = "[::1]:1" proto = ethernetProtoIPv6()
testAddr = netip.AddrPortFrom(netip.IPv6Loopback(), 1)
prog = magicsockFilterV6 prog = magicsockFilterV6
default: default:
return nil, fmt.Errorf("unsupported address family %q", family) return nil, fmt.Errorf("unsupported address family %q", family)
@ -165,72 +204,214 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
return nil, fmt.Errorf("assembling filter: %w", err) return nil, fmt.Errorf("assembling filter: %w", err)
} }
pc, err := net.ListenPacket(network, addr) sock, err := socket.Socket(
unix.AF_PACKET,
unix.SOCK_DGRAM,
proto,
"afpacket",
nil, // no config
)
if err != nil { if err != nil {
return nil, fmt.Errorf("creating packet conn: %w", err) return nil, fmt.Errorf("creating AF_PACKET socket: %w", err)
} }
if err := setBPF(pc, asm); err != nil { if err := sock.SetBPF(asm); err != nil {
pc.Close() sock.Close()
return nil, fmt.Errorf("installing BPF filter: %w", err) return nil, fmt.Errorf("installing BPF filter: %w", err)
} }
// If all the above succeeds, we should be ready to receive. Just // If all the above succeeds, we should be ready to receive. Just
// out of paranoia, check that we do receive a well-formed disco // out of paranoia, check that we do receive a well-formed disco
// packet. // packet.
tc, err := net.ListenPacket("udp", net.JoinHostPort(addr, "0")) tc, err := net.ListenPacket(udpnet, net.JoinHostPort(addr, "0"))
if err != nil { if err != nil {
pc.Close() sock.Close()
return nil, fmt.Errorf("creating disco test socket: %w", err) return nil, fmt.Errorf("creating disco test socket: %w", err)
} }
defer tc.Close() defer tc.Close()
if _, err := tc.(*net.UDPConn).WriteToUDPAddrPort(testDiscoPacket, netip.MustParseAddrPort(testAddr)); err != nil { if _, err := tc.(*net.UDPConn).WriteToUDPAddrPort(testDiscoPacket, testAddr); err != nil {
pc.Close() sock.Close()
return nil, fmt.Errorf("writing disco test packet: %w", err) return nil, fmt.Errorf("writing disco test packet: %w", err)
} }
pc.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
var buf [1500]byte const selfTestTimeout = 100 * time.Millisecond
if err := sock.SetReadDeadline(time.Now().Add(selfTestTimeout)); err != nil {
sock.Close()
return nil, fmt.Errorf("setting socket timeout: %w", err)
}
var (
ctx = context.Background()
buf [1500]byte
)
for { for {
n, _, err := pc.ReadFrom(buf[:]) n, _, err := sock.Recvfrom(ctx, buf[:], 0)
if err != nil { if err != nil {
pc.Close() sock.Close()
return nil, fmt.Errorf("reading during raw disco self-test: %w", err) return nil, fmt.Errorf("reading during raw disco self-test: %w", err)
} }
if n < udpHeaderSize {
_ /* src */, _ /* dst */, payload := parseUDPPacket(buf[:n], family == "ip6")
if payload == nil {
continue continue
} }
if !bytes.Equal(buf[udpHeaderSize:n], testDiscoPacket) { if !bytes.Equal(payload, testDiscoPacket) {
c.discoLogf("listenRawDisco: self-test: received mismatched UDP packet of %d bytes", len(payload))
continue continue
} }
c.logf("[v1] listenRawDisco: self-test passed for %s", family)
break break
} }
pc.SetReadDeadline(time.Time{}) sock.SetReadDeadline(time.Time{})
go c.receiveDisco(pc, family == "ip6") go c.receiveDisco(sock, family == "ip6")
return pc, nil return sock, nil
} }
func (c *Conn) receiveDisco(pc net.PacketConn, isIPV6 bool) { // parseUDPPacket is a basic parser for UDP packets that returns the source and
// destination addresses, and the payload. The returned payload is a sub-slice
// of the input buffer.
//
// It expects to be called with a buffer that contains the entire UDP packet,
// including the IP header, and one that has been filtered with the BPF
// programs above.
//
// If an error occurs, it will return the zero values for all return values.
func parseUDPPacket(buf []byte, isIPv6 bool) (src, dst netip.AddrPort, payload []byte) {
// First, parse the IPv4 or IPv6 header to get to the UDP header. Since
// we assume this was filtered with BPF, we know that there will be no
// IPv6 extension headers.
var (
srcIP, dstIP netip.Addr
udp []byte
)
if isIPv6 {
// Basic length check to ensure that we don't panic
if len(buf) < ipv6.HeaderLen+udpHeaderSize {
return
}
// Extract the source and destination addresses from the IPv6
// header.
srcIP, _ = netip.AddrFromSlice(buf[8:24])
dstIP, _ = netip.AddrFromSlice(buf[24:40])
// We know that the UDP packet starts immediately after the IPv6
// packet.
udp = buf[ipv6.HeaderLen:]
} else {
// This is an IPv4 packet; read the length field from the header.
if len(buf) < ipv4.HeaderLen {
return
}
udpOffset := int((buf[0] & 0x0F) << 2)
if udpOffset+udpHeaderSize > len(buf) {
return
}
// Parse the source and destination IPs.
srcIP, _ = netip.AddrFromSlice(buf[12:16])
dstIP, _ = netip.AddrFromSlice(buf[16:20])
udp = buf[udpOffset:]
}
// Parse the ports
srcPort := binary.BigEndian.Uint16(udp[0:2])
dstPort := binary.BigEndian.Uint16(udp[2:4])
// The payload starts after the UDP header.
payload = udp[8:]
return netip.AddrPortFrom(srcIP, srcPort), netip.AddrPortFrom(dstIP, dstPort), payload
}
// ethernetProtoIPv4 returns the constant unix.ETH_P_IP, in network byte order.
// packet(7) sockets require that the 'protocol' argument be in network byte
// order; see:
//
// https://man7.org/linux/man-pages/man7/packet.7.html
//
// Instead of using htons at runtime, we can just hardcode the value here...
// but we also have a test that verifies that this is correct.
func ethernetProtoIPv4() int {
if cpu.IsBigEndian {
return 0x0800
} else {
return 0x0008
}
}
// ethernetProtoIPv6 returns the constant unix.ETH_P_IPV6, and is otherwise the
// same as ethernetProtoIPv4.
func ethernetProtoIPv6() int {
if cpu.IsBigEndian {
return 0x86dd
} else {
return 0xdd86
}
}
func (c *Conn) discoLogf(format string, args ...any) {
// Enable debug logging if we're debugging raw disco reads or if the
// magicsock component logs are on.
if debugRawDiscoReads() {
c.logf(format, args...)
} else {
c.dlogf(format, args...)
}
}
func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) {
// Given that we're parsing raw packets, be extra careful and recover
// from any panics in this function.
//
// If we didn't have a recover() here and panic'd, we'd take down the
// entire process since this function is the top of a goroutine, and Go
// will kill the process if a goroutine panics and it unwinds past the
// top-level function.
defer func() {
if err := recover(); err != nil {
c.logf("[unexpected] recovered from panic in receiveDisco(isIPv6=%v): %v", isIPV6, err)
}
}()
ctx := context.Background()
// Set up our loggers
var family string
if isIPV6 {
family = "ip6"
} else {
family = "ip4"
}
var (
prefix string = "disco raw " + family + ": "
logf logger.Logf = logger.WithPrefix(c.logf, prefix)
dlogf logger.Logf = logger.WithPrefix(c.discoLogf, prefix)
)
var buf [1500]byte var buf [1500]byte
for { for {
n, src, err := pc.ReadFrom(buf[:]) n, src, err := pc.Recvfrom(ctx, buf[:], 0)
if debugRawDiscoReads() { if debugRawDiscoReads() {
c.logf("raw disco read from %v = (%v, %v)", src, n, err) logf("read from %s = (%v, %v)", printSockaddr(src), n, err)
} }
if errors.Is(err, net.ErrClosed) { if err != nil && (errors.Is(err, net.ErrClosed) || err.Error() == "use of closed file") {
// EOF; no need to print an error
return return
} else if err != nil { } else if err != nil {
c.logf("disco raw reader failed: %v", err) logf("reader failed: %v", err)
return return
} }
if n < udpHeaderSize {
// Too small to be a valid UDP datagram, drop. srcAddr, dstAddr, payload := parseUDPPacket(buf[:n], family == "ip6")
if payload == nil {
// callee logged
continue continue
} }
dstPort := binary.BigEndian.Uint16(buf[2:4]) dstPort := dstAddr.Port()
if dstPort == 0 { if dstPort == 0 {
c.logf("[unexpected] disco raw: received packet for port 0") logf("[unexpected] received packet for port 0")
} }
var acceptPort uint16 var acceptPort uint16
@ -242,59 +423,58 @@ func (c *Conn) receiveDisco(pc net.PacketConn, isIPV6 bool) {
if acceptPort == 0 { if acceptPort == 0 {
// This should only typically happen if the receiving address family // This should only typically happen if the receiving address family
// was recently disabled. // was recently disabled.
c.dlogf("[v1] disco raw: dropping packet for port %d as acceptPort=0", dstPort) dlogf("[v1] dropping packet for port %d as acceptPort=0", dstPort)
continue continue
} }
// If the packet isn't destined for our local port, then we
// should drop it since it might be for another Tailscale
// process on the same machine, or NATed to a different machine
// if this is a router, etc.
//
// We get the local port to compare against inside the receive
// loop; we can't cache this beforehand because it can change
// if/when we rebind.
if dstPort != acceptPort { if dstPort != acceptPort {
c.dlogf("[v1] disco raw: dropping packet for port %d", dstPort) dlogf("[v1] dropping packet for port %d that isn't our local port", dstPort)
continue
}
srcIP, ok := netip.AddrFromSlice(src.(*net.IPAddr).IP)
if !ok {
c.logf("[unexpected] PacketConn.ReadFrom returned not-an-IP %v in from", src)
continue continue
} }
srcPort := binary.BigEndian.Uint16(buf[:2])
if srcIP.Is4() { if isIPV6 {
metricRecvDiscoPacketIPv4.Add(1)
} else {
metricRecvDiscoPacketIPv6.Add(1) metricRecvDiscoPacketIPv6.Add(1)
} else {
metricRecvDiscoPacketIPv4.Add(1)
} }
c.handleDiscoMessage(buf[udpHeaderSize:n], netip.AddrPortFrom(srcIP, srcPort), key.NodePublic{}, discoRXPathRawSocket) c.handleDiscoMessage(payload, srcAddr, key.NodePublic{}, discoRXPathRawSocket)
} }
} }
// setBPF installs filter as the BPF filter on conn. // printSockaddr is a helper function to pretty-print various sockaddr types.
// Ideally we would just use SetBPF as implemented in x/net/ipv4, func printSockaddr(sa unix.Sockaddr) string {
// but x/net/ipv6 doesn't implement it. And once you've written switch sa := sa.(type) {
// this code once, it turns out to be address family agnostic, so case *unix.SockaddrInet4:
// we might as well use it on both and get to use a net.PacketConn addr := netip.AddrFrom4(sa.Addr)
// directly for both families instead of being stuck with return netip.AddrPortFrom(addr, uint16(sa.Port)).String()
// different types. case *unix.SockaddrInet6:
func setBPF(conn net.PacketConn, filter []bpf.RawInstruction) error { addr := netip.AddrFrom16(sa.Addr)
sc, err := conn.(*net.IPConn).SyscallConn() return netip.AddrPortFrom(addr, uint16(sa.Port)).String()
if err != nil { case *unix.SockaddrLinklayer:
return err hwaddr := sa.Addr[:sa.Halen]
}
prog := &unix.SockFprog{ var buf strings.Builder
Len: uint16(len(filter)), fmt.Fprintf(&buf, "link(ty=0x%04x,if=%d):[", sa.Protocol, sa.Ifindex)
Filter: (*unix.SockFilter)(unsafe.Pointer(&filter[0])), for i, b := range hwaddr {
} if i > 0 {
var setErr error buf.WriteByte(':')
err = sc.Control(func(fd uintptr) { }
setErr = unix.SetsockoptSockFprog(int(fd), unix.SOL_SOCKET, unix.SO_ATTACH_FILTER, prog) fmt.Fprintf(&buf, "%02x", b)
}) }
if err != nil { buf.WriteByte(']')
return err return buf.String()
} default:
if setErr != nil { return fmt.Sprintf("unknown(%T)", sa)
return err
} }
return nil
} }
// trySetSocketBuffer attempts to set SO_SNDBUFFORCE and SO_RECVBUFFORCE which // trySetSocketBuffer attempts to set SO_SNDBUFFORCE and SO_RECVBUFFORCE which

@ -0,0 +1,148 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package magicsock
import (
"bytes"
"encoding/binary"
"net/netip"
"testing"
"golang.org/x/sys/cpu"
"golang.org/x/sys/unix"
"tailscale.com/disco"
)
func TestParseUDPPacket(t *testing.T) {
src4 := netip.MustParseAddrPort("127.0.0.1:12345")
dst4 := netip.MustParseAddrPort("127.0.0.2:54321")
src6 := netip.MustParseAddrPort("[::1]:12345")
dst6 := netip.MustParseAddrPort("[::2]:54321")
udp4Packet := []byte{
// IPv4 header
0x45, 0x00, 0x00, 0x26, 0x00, 0x00, 0x00, 0x00,
0x40, 0x11, 0x00, 0x00,
0x7f, 0x00, 0x00, 0x01, // source ip
0x7f, 0x00, 0x00, 0x02, // dest ip
// UDP header
0x30, 0x39, // src port
0xd4, 0x31, // dest port
0x00, 0x12, // length; 8 bytes header + 10 bytes payload = 18 bytes
0x00, 0x00, // checksum; unused
// Payload: disco magic plus 4 bytes
0x54, 0x53, 0xf0, 0x9f, 0x92, 0xac, 0x00, 0x01, 0x02, 0x03,
}
udp6Packet := []byte{
// IPv6 header
0x60, 0x00, 0x00, 0x00,
0x00, 0x12, // payload length
0x11, // next header: UDP
0x00, // hop limit; unused
// Source IP
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
// Dest IP
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02,
// UDP header
0x30, 0x39, // src port
0xd4, 0x31, // dest port
0x00, 0x12, // length; 8 bytes header + 10 bytes payload = 18 bytes
0x00, 0x00, // checksum; unused
// Payload: disco magic plus 4 bytes
0x54, 0x53, 0xf0, 0x9f, 0x92, 0xac, 0x00, 0x01, 0x02, 0x03,
}
// Verify that parsing the UDP packet works correctly.
t.Run("IPv4", func(t *testing.T) {
src, dst, payload := parseUDPPacket(udp4Packet, false)
if src != src4 {
t.Errorf("src = %v; want %v", src, src4)
}
if dst != dst4 {
t.Errorf("dst = %v; want %v", dst, dst4)
}
if !bytes.HasPrefix(payload, []byte(disco.Magic)) {
t.Errorf("payload = %x; must start with %x", payload, disco.Magic)
}
})
t.Run("IPv6", func(t *testing.T) {
src, dst, payload := parseUDPPacket(udp6Packet, true)
if src != src6 {
t.Errorf("src = %v; want %v", src, src6)
}
if dst != dst6 {
t.Errorf("dst = %v; want %v", dst, dst6)
}
if !bytes.HasPrefix(payload, []byte(disco.Magic)) {
t.Errorf("payload = %x; must start with %x", payload, disco.Magic)
}
})
t.Run("Truncated", func(t *testing.T) {
truncateBy := func(b []byte, n int) []byte {
if n >= len(b) {
return nil
}
return b[:len(b)-n]
}
src, dst, payload := parseUDPPacket(truncateBy(udp4Packet, 11), false)
if payload != nil {
t.Errorf("payload = %x; want nil", payload)
}
if src.IsValid() || dst.IsValid() {
t.Errorf("src = %v, dst = %v; want invalid", src, dst)
}
src, dst, payload = parseUDPPacket(truncateBy(udp6Packet, 11), true)
if payload != nil {
t.Errorf("payload = %x; want nil", payload)
}
if src.IsValid() || dst.IsValid() {
t.Errorf("src = %v, dst = %v; want invalid", src, dst)
}
})
}
func TestEthernetProto(t *testing.T) {
htons := func(x uint16) int {
// Network byte order is big-endian; write the value as
// big-endian to a byte slice and read it back in the native
// endian-ness. This is a no-op on a big-endian platform and a
// byte swap on a little-endian platform.
var b [2]byte
binary.BigEndian.PutUint16(b[:], x)
return int(binary.NativeEndian.Uint16(b[:]))
}
if v4 := ethernetProtoIPv4(); v4 != htons(unix.ETH_P_IP) {
t.Errorf("ethernetProtoIPv4 = 0x%04x; want 0x%04x", v4, htons(unix.ETH_P_IP))
}
if v6 := ethernetProtoIPv6(); v6 != htons(unix.ETH_P_IPV6) {
t.Errorf("ethernetProtoIPv6 = 0x%04x; want 0x%04x", v6, htons(unix.ETH_P_IPV6))
}
// As a way to verify that the htons function is working correctly,
// assert that the ETH_P_IP value returned from our function matches
// the value defined in the unix package based on whether the host is
// big-endian (network byte order) or little-endian.
if cpu.IsBigEndian {
if v4 := ethernetProtoIPv4(); v4 != unix.ETH_P_IP {
t.Errorf("ethernetProtoIPv4 = 0x%04x; want 0x%04x", v4, unix.ETH_P_IP)
}
} else {
if v4 := ethernetProtoIPv4(); v4 == unix.ETH_P_IP {
t.Errorf("ethernetProtoIPv4 = 0x%04x; want 0x%04x", v4, htons(unix.ETH_P_IP))
} else {
t.Logf("ethernetProtoIPv4 = 0x%04x, correctly different from 0x%04x", v4, unix.ETH_P_IP)
}
}
}
Loading…
Cancel
Save