// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause //go:build linux package xdp import ( "bytes" "errors" "fmt" "net/netip" "testing" "github.com/cilium/ebpf" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/checksum" "gvisor.dev/gvisor/pkg/tcpip/header" "tailscale.com/net/stun" ) type xdpAction uint32 func (x xdpAction) String() string { switch x { case xdpActionAborted: return "XDP_ABORTED" case xdpActionDrop: return "XDP_DROP" case xdpActionPass: return "XDP_PASS" case xdpActionTX: return "XDP_TX" case xdpActionRedirect: return "XDP_REDIRECT" default: return fmt.Sprintf("unknown(%d)", x) } } const ( xdpActionAborted xdpAction = iota xdpActionDrop xdpActionPass xdpActionTX xdpActionRedirect ) const ( ethHLen = 14 udpHLen = 8 ipv4HLen = 20 ipv6HLen = 40 ) const ( defaultSTUNPort = 3478 defaultTTL = 64 reqSrcPort = uint16(1025) ) var ( reqEthSrc = tcpip.LinkAddress([]byte{0x00, 0x00, 0x5e, 0x00, 0x53, 0x01}) reqEthDst = tcpip.LinkAddress([]byte{0x00, 0x00, 0x5e, 0x00, 0x53, 0x02}) reqIPv4Src = netip.MustParseAddr("192.0.2.1") reqIPv4Dst = netip.MustParseAddr("192.0.2.2") reqIPv6Src = netip.MustParseAddr("2001:db8::1") reqIPv6Dst = netip.MustParseAddr("2001:db8::2") ) var testTXID = stun.TxID([12]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b}) type ipv4Mutations struct { ipHeaderFn func(header.IPv4) udpHeaderFn func(header.UDP) stunReqFn func([]byte) } func getIPv4STUNBindingReq(mutations *ipv4Mutations) []byte { req := stun.Request(testTXID) if mutations != nil && mutations.stunReqFn != nil { mutations.stunReqFn(req) } payloadLen := len(req) totalLen := ipv4HLen + udpHLen + payloadLen b := make([]byte, ethHLen+totalLen) ipv4H := header.IPv4(b[ethHLen:]) ethH := header.Ethernet(b) ethFields := header.EthernetFields{ SrcAddr: reqEthSrc, DstAddr: reqEthDst, Type: unix.ETH_P_IP, } ethH.Encode(ðFields) ipFields := header.IPv4Fields{ SrcAddr: tcpip.AddrFrom4(reqIPv4Src.As4()), DstAddr: tcpip.AddrFrom4(reqIPv4Dst.As4()), Protocol: unix.IPPROTO_UDP, TTL: defaultTTL, TotalLength: uint16(totalLen), } ipv4H.Encode(&ipFields) ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) if mutations != nil && mutations.ipHeaderFn != nil { mutations.ipHeaderFn(ipv4H) } udpH := header.UDP(b[ethHLen+ipv4HLen:]) udpFields := header.UDPFields{ SrcPort: reqSrcPort, DstPort: defaultSTUNPort, Length: uint16(udpHLen + payloadLen), Checksum: 0, } udpH.Encode(&udpFields) copy(b[ethHLen+ipv4HLen+udpHLen:], req) cs := header.PseudoHeaderChecksum( unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udpHLen+payloadLen), ) cs = checksum.Checksum(req, cs) udpH.SetChecksum(^udpH.CalculateChecksum(cs)) if mutations != nil && mutations.udpHeaderFn != nil { mutations.udpHeaderFn(udpH) } return b } type ipv6Mutations struct { ipHeaderFn func(header.IPv6) udpHeaderFn func(header.UDP) stunReqFn func([]byte) } func getIPv6STUNBindingReq(mutations *ipv6Mutations) []byte { req := stun.Request(testTXID) if mutations != nil && mutations.stunReqFn != nil { mutations.stunReqFn(req) } payloadLen := len(req) src := netip.MustParseAddr("2001:db8::1") dst := netip.MustParseAddr("2001:db8::2") b := make([]byte, ethHLen+ipv6HLen+udpHLen+payloadLen) ipv6H := header.IPv6(b[ethHLen:]) ethH := header.Ethernet(b) ethFields := header.EthernetFields{ SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x5e, 0x00, 0x53, 0x01}), DstAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x5e, 0x00, 0x53, 0x02}), Type: unix.ETH_P_IPV6, } ethH.Encode(ðFields) ipFields := header.IPv6Fields{ SrcAddr: tcpip.AddrFrom16(src.As16()), DstAddr: tcpip.AddrFrom16(dst.As16()), TransportProtocol: unix.IPPROTO_UDP, HopLimit: 64, PayloadLength: uint16(udpHLen + payloadLen), } ipv6H.Encode(&ipFields) if mutations != nil && mutations.ipHeaderFn != nil { mutations.ipHeaderFn(ipv6H) } udpH := header.UDP(b[ethHLen+ipv6HLen:]) udpFields := header.UDPFields{ SrcPort: 1025, DstPort: defaultSTUNPort, Length: uint16(udpHLen + payloadLen), Checksum: 0, } udpH.Encode(&udpFields) copy(b[ethHLen+ipv6HLen+udpHLen:], req) cs := header.PseudoHeaderChecksum( unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udpHLen+payloadLen), ) cs = checksum.Checksum(req, cs) udpH.SetChecksum(^udpH.CalculateChecksum(cs)) if mutations != nil && mutations.udpHeaderFn != nil { mutations.udpHeaderFn(udpH) } return b } func getIPv4STUNBindingResp() []byte { addrPort := netip.AddrPortFrom(reqIPv4Src, reqSrcPort) resp := stun.Response(testTXID, addrPort) payloadLen := len(resp) totalLen := ipv4HLen + udpHLen + payloadLen b := make([]byte, ethHLen+totalLen) ipv4H := header.IPv4(b[ethHLen:]) ethH := header.Ethernet(b) ethFields := header.EthernetFields{ SrcAddr: reqEthDst, DstAddr: reqEthSrc, Type: unix.ETH_P_IP, } ethH.Encode(ðFields) ipFields := header.IPv4Fields{ SrcAddr: tcpip.AddrFrom4(reqIPv4Dst.As4()), DstAddr: tcpip.AddrFrom4(reqIPv4Src.As4()), Protocol: unix.IPPROTO_UDP, TTL: defaultTTL, TotalLength: uint16(totalLen), } ipv4H.Encode(&ipFields) ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) udpH := header.UDP(b[ethHLen+ipv4HLen:]) udpFields := header.UDPFields{ SrcPort: defaultSTUNPort, DstPort: reqSrcPort, Length: uint16(udpHLen + payloadLen), Checksum: 0, } udpH.Encode(&udpFields) copy(b[ethHLen+ipv4HLen+udpHLen:], resp) cs := header.PseudoHeaderChecksum( unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udpHLen+payloadLen), ) cs = checksum.Checksum(resp, cs) udpH.SetChecksum(^udpH.CalculateChecksum(cs)) return b } func getIPv6STUNBindingResp() []byte { addrPort := netip.AddrPortFrom(reqIPv6Src, reqSrcPort) resp := stun.Response(testTXID, addrPort) payloadLen := len(resp) totalLen := ipv6HLen + udpHLen + payloadLen b := make([]byte, ethHLen+totalLen) ipv6H := header.IPv6(b[ethHLen:]) ethH := header.Ethernet(b) ethFields := header.EthernetFields{ SrcAddr: reqEthDst, DstAddr: reqEthSrc, Type: unix.ETH_P_IPV6, } ethH.Encode(ðFields) ipFields := header.IPv6Fields{ SrcAddr: tcpip.AddrFrom16(reqIPv6Dst.As16()), DstAddr: tcpip.AddrFrom16(reqIPv6Src.As16()), TransportProtocol: unix.IPPROTO_UDP, HopLimit: defaultTTL, PayloadLength: uint16(udpHLen + payloadLen), } ipv6H.Encode(&ipFields) udpH := header.UDP(b[ethHLen+ipv6HLen:]) udpFields := header.UDPFields{ SrcPort: defaultSTUNPort, DstPort: reqSrcPort, Length: uint16(udpHLen + payloadLen), Checksum: 0, } udpH.Encode(&udpFields) copy(b[ethHLen+ipv6HLen+udpHLen:], resp) cs := header.PseudoHeaderChecksum( unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udpHLen+payloadLen), ) cs = checksum.Checksum(resp, cs) udpH.SetChecksum(^udpH.CalculateChecksum(cs)) return b } func TestXDP(t *testing.T) { ipv4STUNBindingReqTX := getIPv4STUNBindingReq(nil) ipv6STUNBindingReqTX := getIPv6STUNBindingReq(nil) ipv4STUNBindingReqIPCsumPass := getIPv4STUNBindingReq(&ipv4Mutations{ ipHeaderFn: func(ipv4H header.IPv4) { oldCS := ipv4H.Checksum() newCS := oldCS for newCS == 0 || newCS == oldCS { newCS++ } ipv4H.SetChecksum(newCS) }, }) ipv4STUNBindingReqIHLPass := getIPv4STUNBindingReq(&ipv4Mutations{ ipHeaderFn: func(ipv4H header.IPv4) { ipv4H[0] &= 0xF0 }, }) ipv4STUNBindingReqIPVerPass := getIPv4STUNBindingReq(&ipv4Mutations{ ipHeaderFn: func(ipv4H header.IPv4) { ipv4H[0] &= 0x0F }, }) ipv4STUNBindingReqIPProtoPass := getIPv4STUNBindingReq(&ipv4Mutations{ ipHeaderFn: func(ipv4H header.IPv4) { ipv4H[9] = unix.IPPROTO_TCP }, }) ipv4STUNBindingReqFragOffsetPass := getIPv4STUNBindingReq(&ipv4Mutations{ ipHeaderFn: func(ipv4H header.IPv4) { ipv4H.SetFlagsFragmentOffset(ipv4H.Flags(), 8) }, }) ipv4STUNBindingReqFlagsMFPass := getIPv4STUNBindingReq(&ipv4Mutations{ ipHeaderFn: func(ipv4H header.IPv4) { ipv4H.SetFlagsFragmentOffset(header.IPv4FlagMoreFragments, 0) }, }) ipv4STUNBindingReqTotLenPass := getIPv4STUNBindingReq(&ipv4Mutations{ ipHeaderFn: func(ipv4H header.IPv4) { ipv4H.SetTotalLength(ipv4H.TotalLength() + 1) ipv4H.SetChecksum(0) ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) }, }) ipv6STUNBindingReqIPVerPass := getIPv6STUNBindingReq(&ipv6Mutations{ ipHeaderFn: func(ipv6H header.IPv6) { ipv6H[0] &= 0x0F }, udpHeaderFn: func(udp header.UDP) {}, }) ipv6STUNBindingReqNextHdrPass := getIPv6STUNBindingReq(&ipv6Mutations{ ipHeaderFn: func(ipv6H header.IPv6) { ipv6H.SetNextHeader(unix.IPPROTO_TCP) }, udpHeaderFn: func(udp header.UDP) {}, }) ipv6STUNBindingReqPayloadLenPass := getIPv6STUNBindingReq(&ipv6Mutations{ ipHeaderFn: func(ipv6H header.IPv6) { ipv6H.SetPayloadLength(ipv6H.PayloadLength() + 1) }, udpHeaderFn: func(udp header.UDP) {}, }) ipv4STUNBindingReqUDPCsumPass := getIPv4STUNBindingReq(&ipv4Mutations{ udpHeaderFn: func(udpH header.UDP) { oldCS := udpH.Checksum() newCS := oldCS for newCS == 0 || newCS == oldCS { newCS++ } udpH.SetChecksum(newCS) }, }) ipv6STUNBindingReqUDPCsumPass := getIPv6STUNBindingReq(&ipv6Mutations{ udpHeaderFn: func(udpH header.UDP) { oldCS := udpH.Checksum() newCS := oldCS for newCS == 0 || newCS == oldCS { newCS++ } udpH.SetChecksum(newCS) }, }) ipv4STUNBindingReqSTUNTypePass := getIPv4STUNBindingReq(&ipv4Mutations{ stunReqFn: func(req []byte) { req[1] = ^req[1] }, }) ipv6STUNBindingReqSTUNTypePass := getIPv6STUNBindingReq(&ipv6Mutations{ stunReqFn: func(req []byte) { req[1] = ^req[1] }, }) ipv4STUNBindingReqSTUNMagicPass := getIPv4STUNBindingReq(&ipv4Mutations{ stunReqFn: func(req []byte) { req[4] = ^req[4] }, }) ipv6STUNBindingReqSTUNMagicPass := getIPv6STUNBindingReq(&ipv6Mutations{ stunReqFn: func(req []byte) { req[4] = ^req[4] }, }) ipv4STUNBindingReqSTUNAttrsLenPass := getIPv4STUNBindingReq(&ipv4Mutations{ stunReqFn: func(req []byte) { req[2] = ^req[2] }, }) ipv6STUNBindingReqSTUNAttrsLenPass := getIPv6STUNBindingReq(&ipv6Mutations{ stunReqFn: func(req []byte) { req[2] = ^req[2] }, }) ipv4STUNBindingReqSTUNSWValPass := getIPv4STUNBindingReq(&ipv4Mutations{ stunReqFn: func(req []byte) { req[24] = ^req[24] }, }) ipv6STUNBindingReqSTUNSWValPass := getIPv6STUNBindingReq(&ipv6Mutations{ stunReqFn: func(req []byte) { req[24] = ^req[24] }, }) ipv4STUNBindingReqSTUNFirstAttrPass := getIPv4STUNBindingReq(&ipv4Mutations{ stunReqFn: func(req []byte) { req[21] = ^req[21] }, }) ipv6STUNBindingReqSTUNFirstAttrPass := getIPv6STUNBindingReq(&ipv6Mutations{ stunReqFn: func(req []byte) { req[21] = ^req[21] }, }) cases := []struct { name string packetIn []byte wantCode xdpAction wantPacketOut []byte wantMetrics map[bpfCountersKey]uint64 }{ { name: "ipv4 STUN Binding Request TX", packetIn: ipv4STUNBindingReqTX, wantCode: xdpActionTX, wantPacketOut: getIPv4STUNBindingResp(), wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_TX_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_TX_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: uint64(len(getIPv4STUNBindingResp())), }, }, { name: "ipv6 STUN Binding Request TX", packetIn: ipv6STUNBindingReqTX, wantCode: xdpActionTX, wantPacketOut: getIPv6STUNBindingResp(), wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_TX_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_TX_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: uint64(len(getIPv6STUNBindingResp())), }, }, { name: "ipv4 STUN Binding Request invalid ip csum PASS", packetIn: ipv4STUNBindingReqIPCsumPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqIPCsumPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_IP_CSUM), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_IP_CSUM), }: uint64(len(ipv4STUNBindingReqIPCsumPass)), }, }, { name: "ipv4 STUN Binding Request ihl PASS", packetIn: ipv4STUNBindingReqIHLPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqIHLPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: uint64(len(ipv4STUNBindingReqIHLPass)), }, }, { name: "ipv4 STUN Binding Request ip version PASS", packetIn: ipv4STUNBindingReqIPVerPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqIPVerPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: uint64(len(ipv4STUNBindingReqIPVerPass)), }, }, { name: "ipv4 STUN Binding Request ip proto PASS", packetIn: ipv4STUNBindingReqIPProtoPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqIPProtoPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: uint64(len(ipv4STUNBindingReqIPProtoPass)), }, }, { name: "ipv4 STUN Binding Request frag offset PASS", packetIn: ipv4STUNBindingReqFragOffsetPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqFragOffsetPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: uint64(len(ipv4STUNBindingReqFragOffsetPass)), }, }, { name: "ipv4 STUN Binding Request flags mf PASS", packetIn: ipv4STUNBindingReqFlagsMFPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqFlagsMFPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: uint64(len(ipv4STUNBindingReqFlagsMFPass)), }, }, { name: "ipv4 STUN Binding Request tot len PASS", packetIn: ipv4STUNBindingReqTotLenPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqTotLenPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: uint64(len(ipv4STUNBindingReqTotLenPass)), }, }, { name: "ipv6 STUN Binding Request ip version PASS", packetIn: ipv6STUNBindingReqIPVerPass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqIPVerPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: uint64(len(ipv6STUNBindingReqIPVerPass)), }, }, { name: "ipv6 STUN Binding Request next hdr PASS", packetIn: ipv6STUNBindingReqNextHdrPass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqNextHdrPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: uint64(len(ipv6STUNBindingReqNextHdrPass)), }, }, { name: "ipv6 STUN Binding Request payload len PASS", packetIn: ipv6STUNBindingReqPayloadLenPass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqPayloadLenPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: uint64(len(ipv6STUNBindingReqPayloadLenPass)), }, }, { name: "ipv4 STUN Binding Request UDP csum PASS", packetIn: ipv4STUNBindingReqUDPCsumPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqUDPCsumPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_UDP_CSUM), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_UDP_CSUM), }: uint64(len(ipv4STUNBindingReqUDPCsumPass)), }, }, { name: "ipv6 STUN Binding Request UDP csum PASS", packetIn: ipv6STUNBindingReqUDPCsumPass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqUDPCsumPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_UDP_CSUM), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_UDP_CSUM), }: uint64(len(ipv6STUNBindingReqUDPCsumPass)), }, }, { name: "ipv4 STUN Binding Request STUN type PASS", packetIn: ipv4STUNBindingReqSTUNTypePass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqSTUNTypePass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: uint64(len(ipv4STUNBindingReqSTUNTypePass)), }, }, { name: "ipv6 STUN Binding Request STUN type PASS", packetIn: ipv6STUNBindingReqSTUNTypePass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqSTUNTypePass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: uint64(len(ipv6STUNBindingReqSTUNTypePass)), }, }, { name: "ipv4 STUN Binding Request STUN magic PASS", packetIn: ipv4STUNBindingReqSTUNMagicPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqSTUNMagicPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: uint64(len(ipv4STUNBindingReqSTUNMagicPass)), }, }, { name: "ipv6 STUN Binding Request STUN magic PASS", packetIn: ipv6STUNBindingReqSTUNMagicPass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqSTUNMagicPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: uint64(len(ipv6STUNBindingReqSTUNMagicPass)), }, }, { name: "ipv4 STUN Binding Request STUN attrs len PASS", packetIn: ipv4STUNBindingReqSTUNAttrsLenPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqSTUNAttrsLenPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: uint64(len(ipv4STUNBindingReqSTUNAttrsLenPass)), }, }, { name: "ipv6 STUN Binding Request STUN attrs len PASS", packetIn: ipv6STUNBindingReqSTUNAttrsLenPass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqSTUNAttrsLenPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED), }: uint64(len(ipv6STUNBindingReqSTUNAttrsLenPass)), }, }, { name: "ipv4 STUN Binding Request STUN SW val PASS", packetIn: ipv4STUNBindingReqSTUNSWValPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqSTUNSWValPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_SW_ATTR_VAL), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_SW_ATTR_VAL), }: uint64(len(ipv4STUNBindingReqSTUNSWValPass)), }, }, { name: "ipv6 STUN Binding Request STUN SW val PASS", packetIn: ipv6STUNBindingReqSTUNSWValPass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqSTUNSWValPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_SW_ATTR_VAL), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_SW_ATTR_VAL), }: uint64(len(ipv6STUNBindingReqSTUNSWValPass)), }, }, { name: "ipv4 STUN Binding Request STUN first attr PASS", packetIn: ipv4STUNBindingReqSTUNFirstAttrPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqSTUNFirstAttrPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNEXPECTED_FIRST_STUN_ATTR), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNEXPECTED_FIRST_STUN_ATTR), }: uint64(len(ipv4STUNBindingReqSTUNFirstAttrPass)), }, }, { name: "ipv6 STUN Binding Request STUN first attr PASS", packetIn: ipv6STUNBindingReqSTUNFirstAttrPass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqSTUNFirstAttrPass, wantMetrics: map[bpfCountersKey]uint64{ { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNEXPECTED_FIRST_STUN_ATTR), }: 1, { Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6), Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL), ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNEXPECTED_FIRST_STUN_ATTR), }: uint64(len(ipv6STUNBindingReqSTUNFirstAttrPass)), }, }, } server, err := NewSTUNServer(&STUNServerConfig{DeviceName: "fake", DstPort: defaultSTUNPort}, &noAttachOption{}) if err != nil { if errors.Is(err, unix.EPERM) { // TODO(jwhited): get this running t.Skip("skipping due to EPERM error; test requires elevated privileges") } t.Fatalf("error constructing STUN server: %v", err) } defer server.Close() clearCounters := func() error { server.metrics.last = make(map[bpfCountersKey]uint64) var cur, next bpfCountersKey keys := make([]bpfCountersKey, 0) for err = server.objs.CountersMap.NextKey(nil, &next); ; err = server.objs.CountersMap.NextKey(cur, &next) { if err != nil { if errors.Is(err, ebpf.ErrKeyNotExist) { break } return err } keys = append(keys, next) cur = next } for _, key := range keys { err = server.objs.CountersMap.Delete(&key) if err != nil { return err } } err = server.objs.CountersMap.NextKey(nil, &next) if !errors.Is(err, ebpf.ErrKeyNotExist) { return errors.New("counters map is not empty") } return nil } for _, c := range cases { t.Run(c.name, func(t *testing.T) { err = clearCounters() if err != nil { t.Fatalf("error clearing counters: %v", err) } opts := ebpf.RunOptions{ Data: c.packetIn, DataOut: make([]byte, 1514), } got, err := server.objs.XdpProgFunc.Run(&opts) if err != nil { t.Fatalf("error running program: %v", err) } if xdpAction(got) != c.wantCode { t.Fatalf("got code: %s != %s", xdpAction(got), c.wantCode) } if !bytes.Equal(opts.DataOut, c.wantPacketOut) { t.Fatal("packets not equal") } err = server.updateMetrics() if err != nil { t.Fatalf("error updating metrics: %v", err) } if c.wantMetrics != nil { for k, v := range c.wantMetrics { gotCounter, ok := server.metrics.last[k] if !ok { t.Errorf("expected counter at key %+v not found", k) } if gotCounter != v { t.Errorf("key: %+v gotCounter: %d != %d", k, gotCounter, v) } } for k := range server.metrics.last { _, ok := c.wantMetrics[k] if !ok { t.Errorf("counter at key: %+v incremented unexpectedly", k) } } } }) } } func TestCountersMapKey(t *testing.T) { if bpfCounterKeyAfCOUNTER_KEY_AF_LEN > 256 { t.Error("COUNTER_KEY_AF_LEN no longer fits within uint8") } if bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_BYTES_ACTION_LEN > 256 { t.Error("COUNTER_KEY_PACKETS_BYTES_ACTION no longer fits within uint8") } if bpfCounterKeyProgEndCOUNTER_KEY_END_LEN > 256 { t.Error("COUNTER_KEY_END_LEN no longer fits within uint8") } if len(pbaToOutcomeLV) != int(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_BYTES_ACTION_LEN) { t.Error("pbaToOutcomeLV is not in sync with xdp.c") } if len(progEndLV) != int(bpfCounterKeyProgEndCOUNTER_KEY_END_LEN) { t.Error("progEndLV is not in sync with xdp.c") } if len(packetCounterKeys)+len(bytesCounterKeys) != int(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_BYTES_ACTION_LEN) { t.Error("packetCounterKeys and/or bytesCounterKeys is not in sync with xdp.c") } if len(pbaToOutcomeLV) != int(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_BYTES_ACTION_LEN) { t.Error("pbaToOutcomeLV is not in sync with xdp.c") } }