diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 34f933e37..121647a20 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -171,7 +171,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ L 💣 github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+ L github.com/mdlayher/netlink/nltest from github.com/google/nftables L github.com/mdlayher/sdnotify from tailscale.com/util/systemd - L 💣 github.com/mdlayher/socket from github.com/mdlayher/netlink + L 💣 github.com/mdlayher/socket from github.com/mdlayher/netlink+ github.com/miekg/dns from tailscale.com/net/dns/recursive 💣 github.com/mitchellh/go-ps from tailscale.com/safesocket github.com/modern-go/concurrent from github.com/json-iterator/go diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 4e5da410a..eed37c7d4 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -139,7 +139,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L 💣 github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+ L github.com/mdlayher/netlink/nltest from github.com/google/nftables L github.com/mdlayher/sdnotify from tailscale.com/util/systemd - L 💣 github.com/mdlayher/socket from github.com/mdlayher/netlink + L 💣 github.com/mdlayher/socket from github.com/mdlayher/netlink+ github.com/miekg/dns from tailscale.com/net/dns/recursive 💣 github.com/mitchellh/go-ps from tailscale.com/safesocket L github.com/pierrec/lz4/v4 from github.com/u-root/uio/uio diff --git a/net/packet/packet.go b/net/packet/packet.go index dc870414a..c9521ad46 100644 --- a/net/packet/packet.go +++ b/net/packet/packet.go @@ -393,6 +393,11 @@ func (q *Parsed) Buffer() []byte { // Payload returns the payload of the IP subprotocol section. // This is a read-only view; that is, q retains the ownership of the buffer. func (q *Parsed) Payload() []byte { + // If the packet is truncated, return nothing instead of crashing. + if q.length > len(q.b) || q.dataofs > len(q.b) { + return nil + } + return q.b[q.dataofs:q.length] } diff --git a/wgengine/magicsock/magicsock_linux.go b/wgengine/magicsock/magicsock_linux.go index a647c90d2..f658c016b 100644 --- a/wgengine/magicsock/magicsock_linux.go +++ b/wgengine/magicsock/magicsock_linux.go @@ -5,28 +5,37 @@ package magicsock import ( "bytes" + "context" "encoding/binary" "errors" "fmt" "io" "net" "net/netip" + "strings" "syscall" "time" - "unsafe" + "github.com/mdlayher/socket" "golang.org/x/net/bpf" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "golang.org/x/sys/cpu" "golang.org/x/sys/unix" + "tailscale.com/disco" "tailscale.com/envknob" "tailscale.com/net/netns" + "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/nettype" ) const ( - udpHeaderSize = 8 - ipv6FragmentHeaderSize = 8 + udpHeaderSize = 8 + + // discoMinHeaderSize is the minimum size of the disco header in bytes. + discoMinHeaderSize = len(disco.Magic) + 32 /* key length */ + disco.NonceLen ) // Enable/disable using raw sockets to receive disco traffic. @@ -38,8 +47,17 @@ var debugRawDiscoReads = envknob.RegisterBool("TS_DEBUG_RAW_DISCO") // These are our BPF filters that we use for testing packets. var ( magicsockFilterV4 = []bpf.Instruction{ - // For raw UDPv4 sockets, BPF receives the entire IP packet to - // inspect. + // For raw sockets (with ETH_P_IP set), the BPF program + // receives the entire IPv4 packet, but not the Ethernet + // header. + + // Double-check that this is a UDP packet; we shouldn't be + // seeing anything else given how we create our AF_PACKET + // socket, but an extra check here is cheap, and matches the + // check that we do in the IPv6 path. + bpf.LoadAbsolute{Off: 9, Size: 1}, + bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(ipproto.UDP), SkipTrue: 1, SkipFalse: 0}, + bpf.RetConstant{Val: 0x0}, // Disco packets are so small they should never get // fragmented, and we don't want to handle reassembly. @@ -53,6 +71,25 @@ var ( // Load IP header length into X register. bpf.LoadMemShift{Off: 0}, + // Verify that we have a packet that's big enough to (possibly) + // contain a disco packet. + // + // The length of an IPv4 disco packet is composed of: + // - 8 bytes for the UDP header + // - N bytes for the disco packet header + // + // bpf will implicitly return 0 ("skip") if attempting an + // out-of-bounds load, so we can check the length of the packet + // loading a byte from that offset here. We subtract 1 byte + // from the offset to ensure that we accept a packet that's + // exactly the minimum size. + // + // We use LoadIndirect; since we loaded the start of the packet's + // payload into the X register, above, we don't need to add + // ipv4.HeaderLen to the offset (and this properly handles IPv4 + // extensions). + bpf.LoadIndirect{Off: uint32(udpHeaderSize + discoMinHeaderSize - 1), Size: 1}, + // Get the first 4 bytes of the UDP packet, compare with our magic number bpf.LoadIndirect{Off: udpHeaderSize, Size: 4}, bpf.JumpIf{Cond: bpf.JumpEqual, Val: discoMagic1, SkipTrue: 0, SkipFalse: 3}, @@ -82,25 +119,24 @@ var ( // and thus we'd rather be conservative here and possibly not receive // disco packets rather than slow down the system. magicsockFilterV6 = []bpf.Instruction{ - // For raw UDPv6 sockets, BPF receives _only_ the UDP header onwards, not an entire IP packet. - // - // https://stackoverflow.com/questions/24514333/using-bpf-with-sock-dgram-on-linux-machine - // https://blog.cloudflare.com/epbf_sockets_hop_distance/ - // - // This is especially confusing because this *isn't* true for - // IPv4; see the following code from the 'ping' utility that - // corroborates this: - // - // https://github.com/iputils/iputils/blob/1ab5fa/ping/ping.c#L1667-L1676 - // https://github.com/iputils/iputils/blob/1ab5fa/ping/ping6_common.c#L933-L941 + // Do a bounds check to ensure we have enough space for a disco + // packet; see the comment in the IPv4 BPF program for more + // details. + bpf.LoadAbsolute{Off: uint32(ipv6.HeaderLen + udpHeaderSize + discoMinHeaderSize - 1), Size: 1}, + + // Verify that the 'next header' value of the IPv6 packet is + // UDP, which is what we're expecting; if it's anything else + // (including extension headers), we skip the packet. + bpf.LoadAbsolute{Off: 6, Size: 1}, + bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(ipproto.UDP), SkipTrue: 0, SkipFalse: 5}, // Compare with our magic number. Start by loading and // comparing the first 4 bytes of the UDP payload. - bpf.LoadAbsolute{Off: udpHeaderSize, Size: 4}, + bpf.LoadAbsolute{Off: ipv6.HeaderLen + udpHeaderSize, Size: 4}, bpf.JumpIf{Cond: bpf.JumpEqual, Val: discoMagic1, SkipTrue: 0, SkipFalse: 3}, // Compare the next 2 bytes - bpf.LoadAbsolute{Off: udpHeaderSize + 4, Size: 2}, + bpf.LoadAbsolute{Off: ipv6.HeaderLen + udpHeaderSize + 4, Size: 2}, bpf.JumpIf{Cond: bpf.JumpEqual, Val: discoMagic2, SkipTrue: 0, SkipFalse: 1}, // Accept the whole packet @@ -140,21 +176,24 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) { } var ( - network string + udpnet string addr string - testAddr string + proto int + testAddr netip.AddrPort prog []bpf.Instruction ) switch family { case "ip4": - network = "ip4:17" + udpnet = "udp4" addr = "0.0.0.0" - testAddr = "127.0.0.1:1" + proto = ethernetProtoIPv4() + testAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 1) prog = magicsockFilterV4 case "ip6": - network = "ip6:17" + udpnet = "udp6" addr = "::" - testAddr = "[::1]:1" + proto = ethernetProtoIPv6() + testAddr = netip.AddrPortFrom(netip.IPv6Loopback(), 1) prog = magicsockFilterV6 default: return nil, fmt.Errorf("unsupported address family %q", family) @@ -165,72 +204,214 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) { return nil, fmt.Errorf("assembling filter: %w", err) } - pc, err := net.ListenPacket(network, addr) + sock, err := socket.Socket( + unix.AF_PACKET, + unix.SOCK_DGRAM, + proto, + "afpacket", + nil, // no config + ) if err != nil { - return nil, fmt.Errorf("creating packet conn: %w", err) + return nil, fmt.Errorf("creating AF_PACKET socket: %w", err) } - if err := setBPF(pc, asm); err != nil { - pc.Close() + if err := sock.SetBPF(asm); err != nil { + sock.Close() return nil, fmt.Errorf("installing BPF filter: %w", err) } // If all the above succeeds, we should be ready to receive. Just // out of paranoia, check that we do receive a well-formed disco // packet. - tc, err := net.ListenPacket("udp", net.JoinHostPort(addr, "0")) + tc, err := net.ListenPacket(udpnet, net.JoinHostPort(addr, "0")) if err != nil { - pc.Close() + sock.Close() return nil, fmt.Errorf("creating disco test socket: %w", err) } defer tc.Close() - if _, err := tc.(*net.UDPConn).WriteToUDPAddrPort(testDiscoPacket, netip.MustParseAddrPort(testAddr)); err != nil { - pc.Close() + if _, err := tc.(*net.UDPConn).WriteToUDPAddrPort(testDiscoPacket, testAddr); err != nil { + sock.Close() return nil, fmt.Errorf("writing disco test packet: %w", err) } - pc.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - var buf [1500]byte + + const selfTestTimeout = 100 * time.Millisecond + if err := sock.SetReadDeadline(time.Now().Add(selfTestTimeout)); err != nil { + sock.Close() + return nil, fmt.Errorf("setting socket timeout: %w", err) + } + + var ( + ctx = context.Background() + buf [1500]byte + ) for { - n, _, err := pc.ReadFrom(buf[:]) + n, _, err := sock.Recvfrom(ctx, buf[:], 0) if err != nil { - pc.Close() + sock.Close() return nil, fmt.Errorf("reading during raw disco self-test: %w", err) } - if n < udpHeaderSize { + + _ /* src */, _ /* dst */, payload := parseUDPPacket(buf[:n], family == "ip6") + if payload == nil { continue } - if !bytes.Equal(buf[udpHeaderSize:n], testDiscoPacket) { + if !bytes.Equal(payload, testDiscoPacket) { + c.discoLogf("listenRawDisco: self-test: received mismatched UDP packet of %d bytes", len(payload)) continue } + c.logf("[v1] listenRawDisco: self-test passed for %s", family) break } - pc.SetReadDeadline(time.Time{}) + sock.SetReadDeadline(time.Time{}) - go c.receiveDisco(pc, family == "ip6") - return pc, nil + go c.receiveDisco(sock, family == "ip6") + return sock, nil } -func (c *Conn) receiveDisco(pc net.PacketConn, isIPV6 bool) { +// parseUDPPacket is a basic parser for UDP packets that returns the source and +// destination addresses, and the payload. The returned payload is a sub-slice +// of the input buffer. +// +// It expects to be called with a buffer that contains the entire UDP packet, +// including the IP header, and one that has been filtered with the BPF +// programs above. +// +// If an error occurs, it will return the zero values for all return values. +func parseUDPPacket(buf []byte, isIPv6 bool) (src, dst netip.AddrPort, payload []byte) { + // First, parse the IPv4 or IPv6 header to get to the UDP header. Since + // we assume this was filtered with BPF, we know that there will be no + // IPv6 extension headers. + var ( + srcIP, dstIP netip.Addr + udp []byte + ) + if isIPv6 { + // Basic length check to ensure that we don't panic + if len(buf) < ipv6.HeaderLen+udpHeaderSize { + return + } + + // Extract the source and destination addresses from the IPv6 + // header. + srcIP, _ = netip.AddrFromSlice(buf[8:24]) + dstIP, _ = netip.AddrFromSlice(buf[24:40]) + + // We know that the UDP packet starts immediately after the IPv6 + // packet. + udp = buf[ipv6.HeaderLen:] + } else { + // This is an IPv4 packet; read the length field from the header. + if len(buf) < ipv4.HeaderLen { + return + } + udpOffset := int((buf[0] & 0x0F) << 2) + if udpOffset+udpHeaderSize > len(buf) { + return + } + + // Parse the source and destination IPs. + srcIP, _ = netip.AddrFromSlice(buf[12:16]) + dstIP, _ = netip.AddrFromSlice(buf[16:20]) + udp = buf[udpOffset:] + } + + // Parse the ports + srcPort := binary.BigEndian.Uint16(udp[0:2]) + dstPort := binary.BigEndian.Uint16(udp[2:4]) + + // The payload starts after the UDP header. + payload = udp[8:] + return netip.AddrPortFrom(srcIP, srcPort), netip.AddrPortFrom(dstIP, dstPort), payload +} + +// ethernetProtoIPv4 returns the constant unix.ETH_P_IP, in network byte order. +// packet(7) sockets require that the 'protocol' argument be in network byte +// order; see: +// +// https://man7.org/linux/man-pages/man7/packet.7.html +// +// Instead of using htons at runtime, we can just hardcode the value here... +// but we also have a test that verifies that this is correct. +func ethernetProtoIPv4() int { + if cpu.IsBigEndian { + return 0x0800 + } else { + return 0x0008 + } +} + +// ethernetProtoIPv6 returns the constant unix.ETH_P_IPV6, and is otherwise the +// same as ethernetProtoIPv4. +func ethernetProtoIPv6() int { + if cpu.IsBigEndian { + return 0x86dd + } else { + return 0xdd86 + } +} + +func (c *Conn) discoLogf(format string, args ...any) { + // Enable debug logging if we're debugging raw disco reads or if the + // magicsock component logs are on. + if debugRawDiscoReads() { + c.logf(format, args...) + } else { + c.dlogf(format, args...) + } +} + +func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) { + // Given that we're parsing raw packets, be extra careful and recover + // from any panics in this function. + // + // If we didn't have a recover() here and panic'd, we'd take down the + // entire process since this function is the top of a goroutine, and Go + // will kill the process if a goroutine panics and it unwinds past the + // top-level function. + defer func() { + if err := recover(); err != nil { + c.logf("[unexpected] recovered from panic in receiveDisco(isIPv6=%v): %v", isIPV6, err) + } + }() + + ctx := context.Background() + + // Set up our loggers + var family string + if isIPV6 { + family = "ip6" + } else { + family = "ip4" + } + var ( + prefix string = "disco raw " + family + ": " + logf logger.Logf = logger.WithPrefix(c.logf, prefix) + dlogf logger.Logf = logger.WithPrefix(c.discoLogf, prefix) + ) + var buf [1500]byte for { - n, src, err := pc.ReadFrom(buf[:]) + n, src, err := pc.Recvfrom(ctx, buf[:], 0) if debugRawDiscoReads() { - c.logf("raw disco read from %v = (%v, %v)", src, n, err) + logf("read from %s = (%v, %v)", printSockaddr(src), n, err) } - if errors.Is(err, net.ErrClosed) { + if err != nil && (errors.Is(err, net.ErrClosed) || err.Error() == "use of closed file") { + // EOF; no need to print an error return } else if err != nil { - c.logf("disco raw reader failed: %v", err) + logf("reader failed: %v", err) return } - if n < udpHeaderSize { - // Too small to be a valid UDP datagram, drop. + + srcAddr, dstAddr, payload := parseUDPPacket(buf[:n], family == "ip6") + if payload == nil { + // callee logged continue } - dstPort := binary.BigEndian.Uint16(buf[2:4]) + dstPort := dstAddr.Port() if dstPort == 0 { - c.logf("[unexpected] disco raw: received packet for port 0") + logf("[unexpected] received packet for port 0") } var acceptPort uint16 @@ -242,59 +423,58 @@ func (c *Conn) receiveDisco(pc net.PacketConn, isIPV6 bool) { if acceptPort == 0 { // This should only typically happen if the receiving address family // was recently disabled. - c.dlogf("[v1] disco raw: dropping packet for port %d as acceptPort=0", dstPort) + dlogf("[v1] dropping packet for port %d as acceptPort=0", dstPort) continue } + // If the packet isn't destined for our local port, then we + // should drop it since it might be for another Tailscale + // process on the same machine, or NATed to a different machine + // if this is a router, etc. + // + // We get the local port to compare against inside the receive + // loop; we can't cache this beforehand because it can change + // if/when we rebind. if dstPort != acceptPort { - c.dlogf("[v1] disco raw: dropping packet for port %d", dstPort) - continue - } - - srcIP, ok := netip.AddrFromSlice(src.(*net.IPAddr).IP) - if !ok { - c.logf("[unexpected] PacketConn.ReadFrom returned not-an-IP %v in from", src) + dlogf("[v1] dropping packet for port %d that isn't our local port", dstPort) continue } - srcPort := binary.BigEndian.Uint16(buf[:2]) - if srcIP.Is4() { - metricRecvDiscoPacketIPv4.Add(1) - } else { + if isIPV6 { metricRecvDiscoPacketIPv6.Add(1) + } else { + metricRecvDiscoPacketIPv4.Add(1) } - c.handleDiscoMessage(buf[udpHeaderSize:n], netip.AddrPortFrom(srcIP, srcPort), key.NodePublic{}, discoRXPathRawSocket) + c.handleDiscoMessage(payload, srcAddr, key.NodePublic{}, discoRXPathRawSocket) } } -// setBPF installs filter as the BPF filter on conn. -// Ideally we would just use SetBPF as implemented in x/net/ipv4, -// but x/net/ipv6 doesn't implement it. And once you've written -// this code once, it turns out to be address family agnostic, so -// we might as well use it on both and get to use a net.PacketConn -// directly for both families instead of being stuck with -// different types. -func setBPF(conn net.PacketConn, filter []bpf.RawInstruction) error { - sc, err := conn.(*net.IPConn).SyscallConn() - if err != nil { - return err - } - prog := &unix.SockFprog{ - Len: uint16(len(filter)), - Filter: (*unix.SockFilter)(unsafe.Pointer(&filter[0])), - } - var setErr error - err = sc.Control(func(fd uintptr) { - setErr = unix.SetsockoptSockFprog(int(fd), unix.SOL_SOCKET, unix.SO_ATTACH_FILTER, prog) - }) - if err != nil { - return err - } - if setErr != nil { - return err +// printSockaddr is a helper function to pretty-print various sockaddr types. +func printSockaddr(sa unix.Sockaddr) string { + switch sa := sa.(type) { + case *unix.SockaddrInet4: + addr := netip.AddrFrom4(sa.Addr) + return netip.AddrPortFrom(addr, uint16(sa.Port)).String() + case *unix.SockaddrInet6: + addr := netip.AddrFrom16(sa.Addr) + return netip.AddrPortFrom(addr, uint16(sa.Port)).String() + case *unix.SockaddrLinklayer: + hwaddr := sa.Addr[:sa.Halen] + + var buf strings.Builder + fmt.Fprintf(&buf, "link(ty=0x%04x,if=%d):[", sa.Protocol, sa.Ifindex) + for i, b := range hwaddr { + if i > 0 { + buf.WriteByte(':') + } + fmt.Fprintf(&buf, "%02x", b) + } + buf.WriteByte(']') + return buf.String() + default: + return fmt.Sprintf("unknown(%T)", sa) } - return nil } // trySetSocketBuffer attempts to set SO_SNDBUFFORCE and SO_RECVBUFFORCE which diff --git a/wgengine/magicsock/magicsock_linux_test.go b/wgengine/magicsock/magicsock_linux_test.go new file mode 100644 index 000000000..6b86b04f2 --- /dev/null +++ b/wgengine/magicsock/magicsock_linux_test.go @@ -0,0 +1,148 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "bytes" + "encoding/binary" + "net/netip" + "testing" + + "golang.org/x/sys/cpu" + "golang.org/x/sys/unix" + "tailscale.com/disco" +) + +func TestParseUDPPacket(t *testing.T) { + src4 := netip.MustParseAddrPort("127.0.0.1:12345") + dst4 := netip.MustParseAddrPort("127.0.0.2:54321") + + src6 := netip.MustParseAddrPort("[::1]:12345") + dst6 := netip.MustParseAddrPort("[::2]:54321") + + udp4Packet := []byte{ + // IPv4 header + 0x45, 0x00, 0x00, 0x26, 0x00, 0x00, 0x00, 0x00, + 0x40, 0x11, 0x00, 0x00, + 0x7f, 0x00, 0x00, 0x01, // source ip + 0x7f, 0x00, 0x00, 0x02, // dest ip + + // UDP header + 0x30, 0x39, // src port + 0xd4, 0x31, // dest port + 0x00, 0x12, // length; 8 bytes header + 10 bytes payload = 18 bytes + 0x00, 0x00, // checksum; unused + + // Payload: disco magic plus 4 bytes + 0x54, 0x53, 0xf0, 0x9f, 0x92, 0xac, 0x00, 0x01, 0x02, 0x03, + } + udp6Packet := []byte{ + // IPv6 header + 0x60, 0x00, 0x00, 0x00, + 0x00, 0x12, // payload length + 0x11, // next header: UDP + 0x00, // hop limit; unused + + // Source IP + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // Dest IP + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + + // UDP header + 0x30, 0x39, // src port + 0xd4, 0x31, // dest port + 0x00, 0x12, // length; 8 bytes header + 10 bytes payload = 18 bytes + 0x00, 0x00, // checksum; unused + + // Payload: disco magic plus 4 bytes + 0x54, 0x53, 0xf0, 0x9f, 0x92, 0xac, 0x00, 0x01, 0x02, 0x03, + } + + // Verify that parsing the UDP packet works correctly. + t.Run("IPv4", func(t *testing.T) { + src, dst, payload := parseUDPPacket(udp4Packet, false) + if src != src4 { + t.Errorf("src = %v; want %v", src, src4) + } + if dst != dst4 { + t.Errorf("dst = %v; want %v", dst, dst4) + } + if !bytes.HasPrefix(payload, []byte(disco.Magic)) { + t.Errorf("payload = %x; must start with %x", payload, disco.Magic) + } + }) + t.Run("IPv6", func(t *testing.T) { + src, dst, payload := parseUDPPacket(udp6Packet, true) + if src != src6 { + t.Errorf("src = %v; want %v", src, src6) + } + if dst != dst6 { + t.Errorf("dst = %v; want %v", dst, dst6) + } + if !bytes.HasPrefix(payload, []byte(disco.Magic)) { + t.Errorf("payload = %x; must start with %x", payload, disco.Magic) + } + }) + t.Run("Truncated", func(t *testing.T) { + truncateBy := func(b []byte, n int) []byte { + if n >= len(b) { + return nil + } + return b[:len(b)-n] + } + + src, dst, payload := parseUDPPacket(truncateBy(udp4Packet, 11), false) + if payload != nil { + t.Errorf("payload = %x; want nil", payload) + } + if src.IsValid() || dst.IsValid() { + t.Errorf("src = %v, dst = %v; want invalid", src, dst) + } + + src, dst, payload = parseUDPPacket(truncateBy(udp6Packet, 11), true) + if payload != nil { + t.Errorf("payload = %x; want nil", payload) + } + if src.IsValid() || dst.IsValid() { + t.Errorf("src = %v, dst = %v; want invalid", src, dst) + } + }) +} + +func TestEthernetProto(t *testing.T) { + htons := func(x uint16) int { + // Network byte order is big-endian; write the value as + // big-endian to a byte slice and read it back in the native + // endian-ness. This is a no-op on a big-endian platform and a + // byte swap on a little-endian platform. + var b [2]byte + binary.BigEndian.PutUint16(b[:], x) + return int(binary.NativeEndian.Uint16(b[:])) + } + + if v4 := ethernetProtoIPv4(); v4 != htons(unix.ETH_P_IP) { + t.Errorf("ethernetProtoIPv4 = 0x%04x; want 0x%04x", v4, htons(unix.ETH_P_IP)) + } + if v6 := ethernetProtoIPv6(); v6 != htons(unix.ETH_P_IPV6) { + t.Errorf("ethernetProtoIPv6 = 0x%04x; want 0x%04x", v6, htons(unix.ETH_P_IPV6)) + } + + // As a way to verify that the htons function is working correctly, + // assert that the ETH_P_IP value returned from our function matches + // the value defined in the unix package based on whether the host is + // big-endian (network byte order) or little-endian. + if cpu.IsBigEndian { + if v4 := ethernetProtoIPv4(); v4 != unix.ETH_P_IP { + t.Errorf("ethernetProtoIPv4 = 0x%04x; want 0x%04x", v4, unix.ETH_P_IP) + } + } else { + if v4 := ethernetProtoIPv4(); v4 == unix.ETH_P_IP { + t.Errorf("ethernetProtoIPv4 = 0x%04x; want 0x%04x", v4, htons(unix.ETH_P_IP)) + } else { + t.Logf("ethernetProtoIPv4 = 0x%04x, correctly different from 0x%04x", v4, unix.ETH_P_IP) + } + } +}