diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index efb632328..2945dbe1b 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -137,7 +137,7 @@ func maybeHexdump(flag RunFlags, b []byte) string { var acceptBucket = rate.NewLimiter(rate.Every(10*time.Second), 3) var dropBucket = rate.NewLimiter(rate.Every(5*time.Second), 10) -func (f *Filter) logRateLimit(runflags RunFlags, b []byte, q *packet.QDecode, r Response, why string) { +func (f *Filter) logRateLimit(runflags RunFlags, b []byte, q *packet.ParsedPacket, r Response, why string) { var verdict string if r == Drop && (runflags&LogDrops) != 0 && dropBucket.Allow() { @@ -161,7 +161,7 @@ func (f *Filter) logRateLimit(runflags RunFlags, b []byte, q *packet.QDecode, r } } -func (f *Filter) RunIn(b []byte, q *packet.QDecode, rf RunFlags) Response { +func (f *Filter) RunIn(b []byte, q *packet.ParsedPacket, rf RunFlags) Response { r := f.pre(b, q, rf) if r == Accept || r == Drop { // already logged @@ -173,7 +173,7 @@ func (f *Filter) RunIn(b []byte, q *packet.QDecode, rf RunFlags) Response { return r } -func (f *Filter) RunOut(b []byte, q *packet.QDecode, rf RunFlags) Response { +func (f *Filter) RunOut(b []byte, q *packet.ParsedPacket, rf RunFlags) Response { r := f.pre(b, q, rf) if r == Drop || r == Accept { // already logged @@ -184,7 +184,7 @@ func (f *Filter) RunOut(b []byte, q *packet.QDecode, rf RunFlags) Response { return r } -func (f *Filter) runIn(q *packet.QDecode) (r Response, why string) { +func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) { // A compromised peer could try to send us packets for // destinations we didn't explicitly advertise. This check is to // prevent that. @@ -239,7 +239,7 @@ func (f *Filter) runIn(q *packet.QDecode) (r Response, why string) { return Drop, "no rules matched" } -func (f *Filter) runOut(q *packet.QDecode) (r Response, why string) { +func (f *Filter) runOut(q *packet.ParsedPacket) (r Response, why string) { if q.IPProto == packet.UDP { t := tuple{q.DstIP, q.SrcIP, q.DstPort, q.SrcPort} var ti interface{} = t // allocate once, rather than twice inside mutex @@ -251,7 +251,7 @@ func (f *Filter) runOut(q *packet.QDecode) (r Response, why string) { return Accept, "ok out" } -func (f *Filter) pre(b []byte, q *packet.QDecode, rf RunFlags) Response { +func (f *Filter) pre(b []byte, q *packet.ParsedPacket, rf RunFlags) Response { if len(b) == 0 { // wireguard keepalive packet, always permit. return Accept @@ -262,13 +262,17 @@ func (f *Filter) pre(b []byte, q *packet.QDecode, rf RunFlags) Response { } q.Decode(b) - if q.IPProto == packet.Junk { - // Junk packets are dangerous; always drop them. - f.logRateLimit(rf, b, q, Drop, "junk") + switch q.IPProto { + case packet.Unknown: + // Unknown packets are dangerous; always drop them. + f.logRateLimit(rf, b, q, Drop, "unknown") + return Drop + case packet.IPv6: + f.logRateLimit(rf, b, q, Drop, "ipv6") return Drop - } else if q.IPProto == packet.Fragment { + case packet.Fragment: // Fragments after the first always need to be passed through. - // Very small fragments are considered Junk by QDecode. + // Very small fragments are considered Junk by ParsedPacket. f.logRateLimit(rf, b, q, Accept, "fragment") return Accept } diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index 3397721df..fc0fef231 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -14,10 +14,10 @@ import ( ) // Type aliases only in test code: (but ideally nowhere) -type QDecode = packet.QDecode +type ParsedPacket = packet.ParsedPacket type IP = packet.IP -var Junk = packet.Junk +var Unknown = packet.Unknown var ICMP = packet.ICMP var TCP = packet.TCP var UDP = packet.UDP @@ -84,34 +84,34 @@ func TestFilter(t *testing.T) { type InOut struct { want Response - p QDecode + p ParsedPacket } tests := []InOut{ // Basic - {Accept, qdecode(TCP, 0x08010101, 0x01020304, 999, 22)}, - {Accept, qdecode(UDP, 0x08010101, 0x01020304, 999, 22)}, - {Accept, qdecode(ICMP, 0x08010101, 0x01020304, 0, 0)}, - {Drop, qdecode(TCP, 0x08010101, 0x01020304, 0, 0)}, - {Accept, qdecode(TCP, 0x08010101, 0x01020304, 0, 22)}, - {Drop, qdecode(TCP, 0x08010101, 0x01020304, 0, 21)}, - {Accept, qdecode(TCP, 0x11223344, 0x08012233, 0, 443)}, - {Drop, qdecode(TCP, 0x11223344, 0x08012233, 0, 444)}, - {Accept, qdecode(TCP, 0x11223344, 0x647a6232, 0, 999)}, - {Accept, qdecode(TCP, 0x11223344, 0x647a6232, 0, 0)}, + {Accept, parsed(TCP, 0x08010101, 0x01020304, 999, 22)}, + {Accept, parsed(UDP, 0x08010101, 0x01020304, 999, 22)}, + {Accept, parsed(ICMP, 0x08010101, 0x01020304, 0, 0)}, + {Drop, parsed(TCP, 0x08010101, 0x01020304, 0, 0)}, + {Accept, parsed(TCP, 0x08010101, 0x01020304, 0, 22)}, + {Drop, parsed(TCP, 0x08010101, 0x01020304, 0, 21)}, + {Accept, parsed(TCP, 0x11223344, 0x08012233, 0, 443)}, + {Drop, parsed(TCP, 0x11223344, 0x08012233, 0, 444)}, + {Accept, parsed(TCP, 0x11223344, 0x647a6232, 0, 999)}, + {Accept, parsed(TCP, 0x11223344, 0x647a6232, 0, 0)}, // localNets prefilter - accepted by policy filter, but // unexpected dst IP. - {Drop, qdecode(TCP, 0x08010101, 0x10203040, 0, 443)}, + {Drop, parsed(TCP, 0x08010101, 0x10203040, 0, 443)}, // Stateful UDP. Note each packet is run through the input // filter, then the output filter (which sets conntrack // state). // Initially empty cache - {Drop, qdecode(UDP, 0x77777777, 0x66666666, 4242, 4343)}, + {Drop, parsed(UDP, 0x77777777, 0x66666666, 4242, 4343)}, // Return packet from previous attempt is allowed - {Accept, qdecode(UDP, 0x66666666, 0x77777777, 4343, 4242)}, + {Accept, parsed(UDP, 0x66666666, 0x77777777, 4343, 4242)}, // Because of the return above, initial attempt is allowed now - {Accept, qdecode(UDP, 0x77777777, 0x66666666, 4242, 4343)}, + {Accept, parsed(UDP, 0x77777777, 0x66666666, 4242, 4343)}, } for i, test := range tests { if got, _ := acl.runIn(&test.p); test.want != got { @@ -144,7 +144,7 @@ func TestNoAllocs(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { got := int(testing.AllocsPerRun(1000, func() { - var q QDecode + var q ParsedPacket if test.in { acl.RunIn(test.packet, &q, 0) } else { @@ -187,7 +187,7 @@ func BenchmarkFilter(b *testing.B) { for _, bench := range benches { b.Run(bench.name, func(b *testing.B) { for i := 0; i < b.N; i++ { - var q QDecode + var q ParsedPacket // This branch seems to have no measurable impact on performance. if bench.in { acl.RunIn(bench.packet, &q, 0) @@ -207,7 +207,7 @@ func TestPreFilter(t *testing.T) { }{ {"empty", Accept, []byte{}}, {"short", Drop, []byte("short")}, - {"junk", Drop, rawdefault(Junk, 10)}, + {"junk", Drop, rawdefault(Unknown, 10)}, {"fragment", Accept, rawdefault(Fragment, 40)}, {"tcp", noVerdict, rawdefault(TCP, 200)}, {"udp", noVerdict, rawdefault(UDP, 200)}, @@ -215,15 +215,15 @@ func TestPreFilter(t *testing.T) { } f := NewAllowNone(t.Logf) for _, testPacket := range packets { - got := f.pre([]byte(testPacket.b), &QDecode{}, LogDrops|LogAccepts) + got := f.pre([]byte(testPacket.b), &ParsedPacket{}, LogDrops|LogAccepts) if got != testPacket.want { t.Errorf("%q got=%v want=%v packet:\n%s", testPacket.desc, got, testPacket.want, packet.Hexdump(testPacket.b)) } } } -func qdecode(proto packet.IPProto, src, dst packet.IP, sport, dport uint16) QDecode { - return QDecode{ +func parsed(proto packet.IPProto, src, dst packet.IP, sport, dport uint16) ParsedPacket { + return ParsedPacket{ IPProto: proto, SrcIP: src, DstIP: dst, @@ -277,7 +277,7 @@ func rawpacket(proto packet.IPProto, src, dst packet.IP, sport, dport uint16, tr hdr[9] = 6 // flags + fragOff bin.PutUint16(hdr[6:8], (1<<13)|1234) - case Junk: + case Unknown: default: panic("unknown protocol") } diff --git a/wgengine/filter/match.go b/wgengine/filter/match.go index 665b1df82..2632405c8 100644 --- a/wgengine/filter/match.go +++ b/wgengine/filter/match.go @@ -133,7 +133,7 @@ func ipInList(ip packet.IP, netlist []Net) bool { return false } -func matchIPPorts(mm Matches, q *packet.QDecode) bool { +func matchIPPorts(mm Matches, q *packet.ParsedPacket) bool { for _, acl := range mm { for _, dst := range acl.Dsts { if !dst.Net.Includes(q.DstIP) { @@ -153,7 +153,7 @@ func matchIPPorts(mm Matches, q *packet.QDecode) bool { return false } -func matchIPWithoutPorts(mm Matches, q *packet.QDecode) bool { +func matchIPWithoutPorts(mm Matches, q *packet.ParsedPacket) bool { for _, acl := range mm { for _, dst := range acl.Dsts { if !dst.Net.Includes(q.DstIP) { diff --git a/wgengine/packet/header.go b/wgengine/packet/header.go new file mode 100644 index 000000000..54f67704e --- /dev/null +++ b/wgengine/packet/header.go @@ -0,0 +1,48 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package packet + +import ( + "errors" + "math" +) + +const tcpHeaderLength = 20 + +// maxPacketLength is the largest length that all headers support. +// IPv4 headers using uint16 for this forces an upper bound of 64KB. +const maxPacketLength = math.MaxUint16 + +var ( + errSmallBuffer = errors.New("buffer too small") + errLargePacket = errors.New("packet too large") +) + +// Header is a packet header capable of marshaling itself into a byte buffer. +type Header interface { + // Len returns the length of the header after marshaling. + Len() int + // Marshal serializes the header into buf in wire format. + // It clobbers the header region, which is the first h.Length() bytes of buf. + // It explicitly initializes every byte of the header region, + // so pre-zeroing it on reuse is not required. It does not allocate memory. + // It fails if and only if len(buf) < Length(). + Marshal(buf []byte) error + // ToResponse transforms the header into one for a response packet. + // For instance, this swaps the source and destination IPs. + ToResponse() +} + +// Generate generates a new packet with the given header and payload. +// Unlike Header.Marshal, this does allocate memory. +func Generate(h Header, payload []byte) []byte { + hlen := h.Len() + buf := make([]byte, hlen+len(payload)) + + copy(buf[hlen:], payload) + h.Marshal(buf) + + return buf +} diff --git a/wgengine/packet/icmp.go b/wgengine/packet/icmp.go new file mode 100644 index 000000000..c4cb7b149 --- /dev/null +++ b/wgengine/packet/icmp.go @@ -0,0 +1,78 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package packet + +type ICMPType uint8 + +const ( + ICMPEchoReply ICMPType = 0x00 + ICMPEchoRequest ICMPType = 0x08 + ICMPUnreachable ICMPType = 0x03 + ICMPTimeExceeded ICMPType = 0x0b +) + +func (t ICMPType) String() string { + switch t { + case ICMPEchoReply: + return "EchoReply" + case ICMPEchoRequest: + return "EchoRequest" + case ICMPUnreachable: + return "Unreachable" + case ICMPTimeExceeded: + return "TimeExceeded" + default: + return "Unknown" + } +} + +type ICMPCode uint8 + +const ( + ICMPNoCode ICMPCode = 0 +) + +// ICMPHeader represents an ICMP packet header. +type ICMPHeader struct { + IPHeader + Type ICMPType + Code ICMPCode +} + +const ( + icmpHeaderLength = 4 + // icmpTotalHeaderLength is the length of all headers in a ICMP packet. + icmpAllHeadersLength = ipHeaderLength + icmpHeaderLength +) + +func (ICMPHeader) Len() int { + return icmpAllHeadersLength +} + +func (h ICMPHeader) Marshal(buf []byte) error { + if len(buf) < icmpAllHeadersLength { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + // The caller does not need to set this. + h.IPProto = ICMP + + buf[20] = uint8(h.Type) + buf[21] = uint8(h.Code) + + h.IPHeader.Marshal(buf) + + put16(buf[22:24], ipChecksum(buf)) + + return nil +} + +func (h *ICMPHeader) ToResponse() { + h.Type = ICMPEchoReply + h.Code = ICMPNoCode + h.IPHeader.ToResponse() +} diff --git a/wgengine/packet/ip.go b/wgengine/packet/ip.go new file mode 100644 index 000000000..71bbb3cb6 --- /dev/null +++ b/wgengine/packet/ip.go @@ -0,0 +1,127 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package packet + +import ( + "fmt" + "net" +) + +// IP is an IPv4 address. +type IP uint32 + +// NewIP converts a standard library IP address into an IP. +// It panics if b is not an IPv4 address. +func NewIP(b net.IP) IP { + b4 := b.To4() + if b4 == nil { + panic(fmt.Sprintf("To4(%v) failed", b)) + } + return IP(get32(b4)) +} + +func (ip IP) String() string { + return fmt.Sprintf("%d.%d.%d.%d", byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip)) +} + +// IPProto is either a real IP protocol (ITCP, UDP, ...) or an special value like Unknown. +// If it is a real IP protocol, its value corresponds to its IP protocol number. +type IPProto uint8 + +const ( + // Unknown represents an unknown or unsupported protocol; it's deliberately the zero value. + Unknown IPProto = 0x00 + ICMP IPProto = 0x01 + TCP IPProto = 0x06 + UDP IPProto = 0x11 + // IPv6 and Fragment are special values. They're not really IPProto values + // so we're using the unassigned 0xFE and 0xFF values for them. + // TODO(dmytro): special values should be taken out of here. + IPv6 IPProto = 0xFE + Fragment IPProto = 0xFF +) + +func (p IPProto) String() string { + switch p { + case Fragment: + return "Frag" + case ICMP: + return "ICMP" + case UDP: + return "UDP" + case TCP: + return "TCP" + case IPv6: + return "IPv6" + default: + return "Unknown" + } +} + +// IPHeader represents an IP packet header. +type IPHeader struct { + IPProto IPProto + IPID uint16 + SrcIP IP + DstIP IP +} + +const ipHeaderLength = 20 + +func (IPHeader) Len() int { + return ipHeaderLength +} + +func (h IPHeader) Marshal(buf []byte) error { + if len(buf) < ipHeaderLength { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + + buf[0] = 0x40 | (ipHeaderLength >> 2) // IPv4 + buf[1] = 0x00 // DHCP, ECN + put16(buf[2:4], uint16(len(buf))) + put16(buf[4:6], h.IPID) + put16(buf[6:8], 0) // flags, offset + buf[8] = 64 // TTL + buf[9] = uint8(h.IPProto) + put16(buf[10:12], 0) // blank IP header checksum + put32(buf[12:16], uint32(h.SrcIP)) + put32(buf[16:20], uint32(h.DstIP)) + + put16(buf[10:12], ipChecksum(buf[0:20])) + + return nil +} + +// MarshalPseudo serializes the header into buf in pseudo format. +// It clobbers the header region, which is the first h.Length() bytes of buf. +// It explicitly initializes every byte of the header region, +// so pre-zeroing it on reuse is not required. It does not allocate memory. +func (h IPHeader) MarshalPseudo(buf []byte) error { + if len(buf) < ipHeaderLength { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + + length := len(buf) - ipHeaderLength + put32(buf[8:12], uint32(h.SrcIP)) + put32(buf[12:16], uint32(h.DstIP)) + buf[16] = 0x0 + buf[17] = uint8(h.IPProto) + put16(buf[18:20], uint16(length)) + + return nil +} + +func (h *IPHeader) ToResponse() { + h.SrcIP, h.DstIP = h.DstIP, h.SrcIP + // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. + h.IPID = ^h.IPID +} diff --git a/wgengine/packet/packet.go b/wgengine/packet/packet.go index 55513c553..92bc1eda8 100644 --- a/wgengine/packet/packet.go +++ b/wgengine/packet/packet.go @@ -7,75 +7,39 @@ package packet import ( "encoding/binary" "fmt" - "log" - "net" "strings" "tailscale.com/types/strbuilder" ) -type IPProto int - -const ( - Junk IPProto = iota - Fragment - ICMP - UDP - TCP -) - // RFC1858: prevent overlapping fragment attacks. const minFrag = 60 + 20 // max IPv4 header + basic TCP header -func (p IPProto) String() string { - switch p { - case Fragment: - return "Frag" - case ICMP: - return "ICMP" - case UDP: - return "UDP" - case TCP: - return "TCP" - default: - return "Junk" - } -} - -// IP is an IPv4 address. -type IP uint32 - -// NewIP converts a standard library IP address into an IP. -// It panics if b is not an IPv4 address. -func NewIP(b net.IP) IP { - b4 := b.To4() - if b4 == nil { - panic(fmt.Sprintf("To4(%v) failed", b)) - } - return IP(binary.BigEndian.Uint32(b4)) -} - -func (ip IP) String() string { - return fmt.Sprintf("%d.%d.%d.%d", byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip)) -} - -// ICMP types. -const ( - ICMPEchoReply = 0x00 - ICMPEchoRequest = 0x08 - ICMPUnreachable = 0x03 - ICMPTimeExceeded = 0x0b -) - const ( TCPSyn = 0x02 TCPAck = 0x10 TCPSynAck = TCPSyn | TCPAck ) -type QDecode struct { - b []byte // Packet buffer that this decodes - subofs int // byte offset of IP subprotocol +var ( + get16 = binary.BigEndian.Uint16 + get32 = binary.BigEndian.Uint32 + + put16 = binary.BigEndian.PutUint16 + put32 = binary.BigEndian.PutUint32 +) + +// ParsedPacket is a minimal decoding of a packet suitable for use in filters. +type ParsedPacket struct { + // b is the byte buffer that this decodes. + b []byte + // subofs is the offset of IP subprotocol. + subofs int + // dataofs is the offset of IP subprotocol payload. + dataofs int + // length is the total length of the packet. + // This is not the same as len(b) because b can have trailing zeros. + length int IPProto IPProto // IP subprotocol (UDP, TCP, etc) SrcIP IP // IP source address @@ -85,9 +49,12 @@ type QDecode struct { TCPFlags uint8 // TCP flags (SYN, ACK, etc) } -func (q *QDecode) String() string { - if q.IPProto == Junk { - return "Junk{}" +func (q *ParsedPacket) String() string { + switch q.IPProto { + case IPv6: + return "IPv6{???}" + case Unknown: + return "Unknown{???}" } sb := strbuilder.Get() sb.WriteString(q.IPProto.String()) @@ -117,7 +84,7 @@ func ipChecksum(b []byte) uint16 { i := 0 n := len(b) for n >= 2 { - ac += uint32(binary.BigEndian.Uint16(b[i : i+2])) + ac += uint32(get16(b[i : i+2])) n -= 2 i += 2 } @@ -130,71 +97,44 @@ func ipChecksum(b []byte) uint16 { return uint16(^ac) } -var put16 = binary.BigEndian.PutUint16 -var put32 = binary.BigEndian.PutUint32 - -// GenICMP returns the bytes of an ICMP packet. -// If payload is too short or too long, it returns nil. -func GenICMP(srcIP, dstIP IP, ipid uint16, icmpType, icmpCode uint8, payload []byte) []byte { - if len(payload) < 4 { - return nil - } - if len(payload) > 65535-24 { - return nil - } - - sz := 24 + len(payload) - out := make([]byte, 24+len(payload)) - out[0] = 0x45 // IPv4, 20-byte header - out[1] = 0x00 // DHCP, ECN - put16(out[2:4], uint16(sz)) - put16(out[4:6], ipid) - put16(out[6:8], 0) // flags, offset - out[8] = 64 // TTL - out[9] = 0x01 // ICMPv4 - // out[10:12] = 0x00 // blank IP header checksum - put32(out[12:16], uint32(srcIP)) - put32(out[16:20], uint32(dstIP)) - - out[20] = icmpType - out[21] = icmpCode - //out[22:24] = 0x00 // blank ICMP checksum - copy(out[24:], payload) - - put16(out[10:12], ipChecksum(out[0:20])) - put16(out[22:24], ipChecksum(out)) - return out -} - -// An extremely simple packet decoder for basic IPv4 packet types. +// Decode extracts data from the packet in b into q. +// It performs extremely simple packet decoding for basic IPv4 packet types. // It extracts only the subprotocol id, IP addresses, and (if any) ports, // and shouldn't need any memory allocation. -func (q *QDecode) Decode(b []byte) { +func (q *ParsedPacket) Decode(b []byte) { q.b = nil - if len(b) < 20 { - q.IPProto = Junk + if len(b) < ipHeaderLength { + q.IPProto = Unknown return } + // Check that it's IPv4. // TODO(apenwarr): consider IPv6 support - if ((b[0] & 0xF0) >> 4) != 4 { - q.IPProto = Junk + switch (b[0] & 0xF0) >> 4 { + case 4: + q.IPProto = IPProto(b[9]) + // continue + case 6: + q.IPProto = IPv6 + return + default: + q.IPProto = Unknown return } - n := int(binary.BigEndian.Uint16(b[2:4])) - if len(b) < n { + q.length = int(get16(b[2:4])) + if len(b) < q.length { // Packet was cut off before full IPv4 length. - q.IPProto = Junk + q.IPProto = Unknown return } // If it's valid IPv4, then the IP addresses are valid - q.SrcIP = IP(binary.BigEndian.Uint32(b[12:16])) - q.DstIP = IP(binary.BigEndian.Uint32(b[16:20])) + q.SrcIP = IP(get32(b[12:16])) + q.DstIP = IP(get32(b[16:20])) - q.subofs = int((b[0] & 0x0F) * 4) + q.subofs = int((b[0] & 0x0F) << 2) sub := b[q.subofs:] // We don't care much about IP fragmentation, except insofar as it's @@ -207,57 +147,56 @@ func (q *QDecode) Decode(b []byte) { // A "perfectly correct" implementation would have to reassemble // fragments before deciding what to do. But the truth is there's // zero reason to send such a short first fragment, so we can treat - // it as Junk. We can also treat any subsequent fragment that starts - // at such a low offset as Junk. - fragFlags := binary.BigEndian.Uint16(b[6:8]) + // it as Unknown. We can also treat any subsequent fragment that starts + // at such a low offset as Unknown. + fragFlags := get16(b[6:8]) moreFrags := (fragFlags & 0x20) != 0 fragOfs := fragFlags & 0x1FFF if fragOfs == 0 { // This is the first fragment if moreFrags && len(sub) < minFrag { // Suspiciously short first fragment, dump it. - log.Printf("junk1!\n") - q.IPProto = Junk + q.IPProto = Unknown return } // otherwise, this is either non-fragmented (the usual case) // or a big enough initial fragment that we can read the // whole subprotocol header. - proto := b[9] - switch proto { - case 1: // ICMPv4 - if len(sub) < 8 { - q.IPProto = Junk + switch q.IPProto { + case ICMP: + if len(sub) < icmpHeaderLength { + q.IPProto = Unknown return } - q.IPProto = ICMP q.SrcPort = 0 q.DstPort = 0 q.b = b + q.dataofs = q.subofs + icmpHeaderLength return - case 6: // TCP - if len(sub) < 20 { - q.IPProto = Junk + case TCP: + if len(sub) < tcpHeaderLength { + q.IPProto = Unknown return } - q.IPProto = TCP - q.SrcPort = binary.BigEndian.Uint16(sub[0:2]) - q.DstPort = binary.BigEndian.Uint16(sub[2:4]) + q.SrcPort = get16(sub[0:2]) + q.DstPort = get16(sub[2:4]) q.TCPFlags = sub[13] & 0x3F q.b = b + headerLength := (sub[12] & 0xF0) >> 2 + q.dataofs = q.subofs + int(headerLength) return - case 17: // UDP - if len(sub) < 8 { - q.IPProto = Junk + case UDP: + if len(sub) < udpHeaderLength { + q.IPProto = Unknown return } - q.IPProto = UDP - q.SrcPort = binary.BigEndian.Uint16(sub[0:2]) - q.DstPort = binary.BigEndian.Uint16(sub[2:4]) + q.SrcPort = get16(sub[0:2]) + q.DstPort = get16(sub[2:4]) q.b = b + q.dataofs = q.subofs + udpHeaderLength return default: - q.IPProto = Junk + q.IPProto = Unknown return } } else { @@ -265,7 +204,7 @@ func (q *QDecode) Decode(b []byte) { if fragOfs < minFrag { // First frag was suspiciously short, so we can't // trust the followup either. - q.IPProto = Junk + q.IPProto = Unknown return } // otherwise, we have to permit the fragment to slide through. @@ -279,29 +218,59 @@ func (q *QDecode) Decode(b []byte) { } } -// Returns a subset of the IP subprotocol section. -func (q *QDecode) Sub(begin, n int) []byte { +func (q *ParsedPacket) IPHeader() IPHeader { + ipid := get16(q.b[4:6]) + return IPHeader{ + IPID: ipid, + IPProto: q.IPProto, + SrcIP: q.SrcIP, + DstIP: q.DstIP, + } +} + +func (q *ParsedPacket) ICMPHeader() ICMPHeader { + return ICMPHeader{ + IPHeader: q.IPHeader(), + Type: ICMPType(q.b[q.subofs+0]), + Code: ICMPCode(q.b[q.subofs+1]), + } +} + +func (q *ParsedPacket) UDPHeader() UDPHeader { + return UDPHeader{ + IPHeader: q.IPHeader(), + SrcPort: q.SrcPort, + DstPort: q.DstPort, + } +} + +// Sub returns the IP subprotocol section. +func (q *ParsedPacket) Sub(begin, n int) []byte { return q.b[q.subofs+begin : q.subofs+begin+n] } +// Payload returns the payload of the IP subprotocol section. +func (q *ParsedPacket) Payload() []byte { + return q.b[q.dataofs:q.length] +} + // Trim trims the buffer to its IPv4 length. // Sometimes packets arrive from an interface with extra bytes on the end. // This removes them. -func (q *QDecode) Trim() []byte { - n := binary.BigEndian.Uint16(q.b[2:4]) - return q.b[:n] +func (q *ParsedPacket) Trim() []byte { + return q.b[:q.length] } -// IsTCPSyn reports whether q is a TCP SYN packet (i.e. the -// first packet in a new connection). -func (q *QDecode) IsTCPSyn() bool { +// IsTCPSyn reports whether q is a TCP SYN packet +// (i.e. the first packet in a new connection). +func (q *ParsedPacket) IsTCPSyn() bool { return (q.TCPFlags & TCPSynAck) == TCPSyn } // IsError reports whether q is an IPv4 ICMP "Error" packet. -func (q *QDecode) IsError() bool { +func (q *ParsedPacket) IsError() bool { if q.IPProto == ICMP && len(q.b) >= q.subofs+8 { - switch q.b[q.subofs] { + switch ICMPType(q.b[q.subofs]) { case ICMPUnreachable, ICMPTimeExceeded: return true } @@ -310,28 +279,23 @@ func (q *QDecode) IsError() bool { } // IsEchoRequest reports whether q is an IPv4 ICMP Echo Request. -func (q *QDecode) IsEchoRequest() bool { +func (q *ParsedPacket) IsEchoRequest() bool { if q.IPProto == ICMP && len(q.b) >= q.subofs+8 { - return q.b[q.subofs] == ICMPEchoRequest && q.b[q.subofs+1] == 0 + return ICMPType(q.b[q.subofs]) == ICMPEchoRequest && + ICMPCode(q.b[q.subofs+1]) == ICMPNoCode } return false } // IsEchoRequest reports whether q is an IPv4 ICMP Echo Response. -func (q *QDecode) IsEchoResponse() bool { +func (q *ParsedPacket) IsEchoResponse() bool { if q.IPProto == ICMP && len(q.b) >= q.subofs+8 { - return q.b[q.subofs] == ICMPEchoReply && q.b[q.subofs+1] == 0 + return ICMPType(q.b[q.subofs]) == ICMPEchoReply && + ICMPCode(q.b[q.subofs+1]) == ICMPNoCode } return false } -// EchoResponse returns an IPv4 ICMP echo reply to the request in q. -func (q *QDecode) EchoRespond() []byte { - icmpid := binary.BigEndian.Uint16(q.Sub(4, 2)) - b := q.Trim() - return GenICMP(q.DstIP, q.SrcIP, icmpid, ICMPEchoReply, 0, b[q.subofs+4:]) -} - func Hexdump(b []byte) string { out := new(strings.Builder) for i := 0; i < len(b); i += 16 { diff --git a/wgengine/packet/packet_test.go b/wgengine/packet/packet_test.go index 15b8ab2e9..11f75e1e3 100644 --- a/wgengine/packet/packet_test.go +++ b/wgengine/packet/packet_test.go @@ -5,7 +5,9 @@ package packet import ( + "bytes" "net" + "reflect" "testing" ) @@ -26,24 +28,312 @@ func TestIPString(t *testing.T) { } } -func TestQDecodeString(t *testing.T) { - q := QDecode{ - IPProto: TCP, - SrcIP: NewIP(net.ParseIP("1.2.3.4")), - SrcPort: 123, - DstIP: NewIP(net.ParseIP("5.6.7.8")), - DstPort: 567, +var icmpRequestBuffer = []byte{ + // IP header up to checksum + 0x45, 0x00, 0x00, 0x27, 0xde, 0xad, 0x00, 0x00, 0x40, 0x01, 0x8c, 0x15, + // source ip + 0x01, 0x02, 0x03, 0x04, + // destination ip + 0x05, 0x06, 0x07, 0x08, + // ICMP header + 0x08, 0x00, 0x7d, 0x22, + // "request_payload" + 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, +} + +var icmpRequestDecode = ParsedPacket{ + b: icmpRequestBuffer, + subofs: 20, + dataofs: 24, + length: len(icmpRequestBuffer), + + IPProto: ICMP, + SrcIP: NewIP(net.ParseIP("1.2.3.4")), + DstIP: NewIP(net.ParseIP("5.6.7.8")), + SrcPort: 0, + DstPort: 0, +} + +var icmpReplyBuffer = []byte{ + 0x45, 0x00, 0x00, 0x25, 0x21, 0x52, 0x00, 0x00, 0x40, 0x01, 0x49, 0x73, + // source ip + 0x05, 0x06, 0x07, 0x08, + // destination ip + 0x01, 0x02, 0x03, 0x04, + // ICMP header + 0x00, 0x00, 0xe6, 0x9e, + // "reply_payload" + 0x72, 0x65, 0x70, 0x6c, 0x79, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, +} + +var icmpReplyDecode = ParsedPacket{ + b: icmpReplyBuffer, + subofs: 20, + dataofs: 24, + length: len(icmpReplyBuffer), + + IPProto: ICMP, + SrcIP: NewIP(net.ParseIP("1.2.3.4")), + DstIP: NewIP(net.ParseIP("5.6.7.8")), + SrcPort: 0, + DstPort: 0, +} + +// IPv6 Router Solicitation +var ipv6PacketBuffer = []byte{ + 0x60, 0x00, 0x00, 0x00, 0x00, 0x08, 0x3a, 0xff, + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xfb, 0x57, 0x1d, 0xea, 0x9c, 0x39, 0x8f, 0xb7, + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + 0x85, 0x00, 0x38, 0x04, 0x00, 0x00, 0x00, 0x00, +} + +var ipv6PacketDecode = ParsedPacket{ + IPProto: IPv6, +} + +// This is a malformed IPv4 packet. +// Namely, the string "tcp_payload" follows the first byte of the IPv4 header. +var unknownPacketBuffer = []byte{ + 0x45, 0x74, 0x63, 0x70, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, +} + +var unknownPacketDecode = ParsedPacket{ + IPProto: Unknown, +} + +var tcpPacketBuffer = []byte{ + // IP header up to checksum + 0x45, 0x00, 0x00, 0x37, 0xde, 0xad, 0x00, 0x00, 0x40, 0x06, 0x49, 0x5f, + // source ip + 0x01, 0x02, 0x03, 0x04, + // destination ip + 0x05, 0x06, 0x07, 0x08, + // TCP header with SYN, ACK set + 0x00, 0x7b, 0x02, 0x37, 0x00, 0x00, 0x12, 0x34, 0x00, 0x00, 0x00, 0x00, + 0x50, 0x12, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, + // "request_payload" + 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, +} + +var tcpPacketDecode = ParsedPacket{ + b: tcpPacketBuffer, + subofs: 20, + dataofs: 40, + length: len(tcpPacketBuffer), + + IPProto: TCP, + SrcIP: NewIP(net.ParseIP("1.2.3.4")), + DstIP: NewIP(net.ParseIP("5.6.7.8")), + SrcPort: 123, + DstPort: 567, + TCPFlags: TCPSynAck, +} + +var udpRequestBuffer = []byte{ + // IP header up to checksum + 0x45, 0x00, 0x00, 0x2b, 0xde, 0xad, 0x00, 0x00, 0x40, 0x11, 0x8c, 0x01, + // source ip + 0x01, 0x02, 0x03, 0x04, + // destination ip + 0x05, 0x06, 0x07, 0x08, + // UDP header + 0x00, 0x7b, 0x02, 0x37, 0x00, 0x17, 0x72, 0x1d, + // "request_payload" + 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, +} + +var udpRequestDecode = ParsedPacket{ + b: udpRequestBuffer, + subofs: 20, + dataofs: 28, + length: len(udpRequestBuffer), + + IPProto: UDP, + SrcIP: NewIP(net.ParseIP("1.2.3.4")), + DstIP: NewIP(net.ParseIP("5.6.7.8")), + SrcPort: 123, + DstPort: 567, +} + +var udpReplyBuffer = []byte{ + // IP header up to checksum + 0x45, 0x00, 0x00, 0x29, 0x21, 0x52, 0x00, 0x00, 0x40, 0x11, 0x49, 0x5f, + // source ip + 0x05, 0x06, 0x07, 0x08, + // destination ip + 0x01, 0x02, 0x03, 0x04, + // UDP header + 0x02, 0x37, 0x00, 0x7b, 0x00, 0x15, 0xd3, 0x9d, + // "reply_payload" + 0x72, 0x65, 0x70, 0x6c, 0x79, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, +} + +var udpReplyDecode = ParsedPacket{ + b: udpReplyBuffer, + subofs: 20, + dataofs: 28, + length: len(udpReplyBuffer), + + IPProto: UDP, + SrcIP: NewIP(net.ParseIP("1.2.3.4")), + DstIP: NewIP(net.ParseIP("5.6.7.8")), + SrcPort: 567, + DstPort: 123, +} + +func TestParsedPacket(t *testing.T) { + tests := []struct { + name string + qdecode ParsedPacket + want string + }{ + {"tcp", tcpPacketDecode, "TCP{1.2.3.4:123 > 5.6.7.8:567}"}, + {"icmp", icmpRequestDecode, "ICMP{1.2.3.4:0 > 5.6.7.8:0}"}, + {"unknown", unknownPacketDecode, "Unknown{???}"}, + {"ipv6", ipv6PacketDecode, "IPv6{???}"}, } - got := q.String() - want := "TCP{1.2.3.4:123 > 5.6.7.8:567}" - if got != want { - t.Errorf("got %q; want %q", got, want) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.qdecode.String() + if got != tt.want { + t.Errorf("got %q; want %q", got, tt.want) + } + }) } allocs := testing.AllocsPerRun(1000, func() { - got = q.String() + tests[0].qdecode.String() }) if allocs != 1 { t.Errorf("allocs = %v; want 1", allocs) } } + +func TestDecode(t *testing.T) { + tests := []struct { + name string + buf []byte + want ParsedPacket + }{ + {"icmp", icmpRequestBuffer, icmpRequestDecode}, + {"ipv6", ipv6PacketBuffer, ipv6PacketDecode}, + {"unknown", unknownPacketBuffer, unknownPacketDecode}, + {"tcp", tcpPacketBuffer, tcpPacketDecode}, + {"udp", udpRequestBuffer, udpRequestDecode}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got ParsedPacket + got.Decode(tt.buf) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } + + allocs := testing.AllocsPerRun(1000, func() { + var got ParsedPacket + got.Decode(tests[0].buf) + }) + if allocs != 0 { + t.Errorf("allocs = %v; want 0", allocs) + } +} + +func BenchmarkDecode(b *testing.B) { + benches := []struct { + name string + buf []byte + }{ + {"icmp", icmpRequestBuffer}, + {"unknown", unknownPacketBuffer}, + {"tcp", tcpPacketBuffer}, + } + + for _, bench := range benches { + b.Run(bench.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + var p ParsedPacket + p.Decode(bench.buf) + } + }) + } +} + +func TestMarshalRequest(t *testing.T) { + // Too small to hold our packets, but only barely. + var small [20]byte + var large [64]byte + + icmpHeader := icmpRequestDecode.ICMPHeader() + udpHeader := udpRequestDecode.UDPHeader() + tests := []struct { + name string + header Header + want []byte + }{ + {"icmp", &icmpHeader, icmpRequestBuffer}, + {"udp", &udpHeader, udpRequestBuffer}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.header.Marshal(small[:]) + if err != errSmallBuffer { + t.Errorf("got err: nil; want: %s", errSmallBuffer) + } + + dataOffset := tt.header.Len() + dataLength := copy(large[dataOffset:], []byte("request_payload")) + end := dataOffset + dataLength + err = tt.header.Marshal(large[:end]) + + if err != nil { + t.Errorf("got err: %s; want nil", err) + } + + if !bytes.Equal(large[:end], tt.want) { + t.Errorf("got %x; want %x", large[:end], tt.want) + } + }) + } +} + +func TestMarshalResponse(t *testing.T) { + var buf [64]byte + + icmpHeader := icmpRequestDecode.ICMPHeader() + udpHeader := udpRequestDecode.UDPHeader() + + tests := []struct { + name string + header Header + want []byte + }{ + {"icmp", &icmpHeader, icmpReplyBuffer}, + {"udp", &udpHeader, udpReplyBuffer}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.header.ToResponse() + + dataOffset := tt.header.Len() + dataLength := copy(buf[dataOffset:], []byte("reply_payload")) + end := dataOffset + dataLength + err := tt.header.Marshal(buf[:end]) + + if err != nil { + t.Errorf("got err: %s; want nil", err) + } + + if !bytes.Equal(buf[:end], tt.want) { + t.Errorf("got %x; want %x", buf[:end], tt.want) + } + }) + } +} diff --git a/wgengine/packet/udp.go b/wgengine/packet/udp.go new file mode 100644 index 000000000..76cc9c922 --- /dev/null +++ b/wgengine/packet/udp.go @@ -0,0 +1,53 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package packet + +// UDPHeader represents an UDP packet header. +type UDPHeader struct { + IPHeader + SrcPort uint16 + DstPort uint16 +} + +const ( + udpHeaderLength = 8 + // udpTotalHeaderLength is the length of all headers in a UDP packet. + udpTotalHeaderLength = ipHeaderLength + udpHeaderLength +) + +func (UDPHeader) Len() int { + return udpTotalHeaderLength +} + +func (h UDPHeader) Marshal(buf []byte) error { + if len(buf) < udpTotalHeaderLength { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + // The caller does not need to set this. + h.IPProto = UDP + + length := len(buf) - h.IPHeader.Len() + put16(buf[20:22], h.SrcPort) + put16(buf[22:24], h.DstPort) + put16(buf[24:26], uint16(length)) + put16(buf[26:28], 0) // blank checksum + + h.IPHeader.MarshalPseudo(buf) + + // UDP checksum with IP pseudo header. + put16(buf[26:28], ipChecksum(buf[8:])) + + h.IPHeader.Marshal(buf) + + return nil +} + +func (h *UDPHeader) ToResponse() { + h.SrcPort, h.DstPort = h.DstPort, h.SrcPort + h.IPHeader.ToResponse() +} diff --git a/wgengine/tstun/tun.go b/wgengine/tstun/tun.go index c798fc37b..a4ed8dd08 100644 --- a/wgengine/tstun/tun.go +++ b/wgengine/tstun/tun.go @@ -176,10 +176,11 @@ func (t *TUN) filterOut(buf []byte) filter.Response { return filter.Drop } - var q packet.QDecode - if filt.RunOut(buf, &q, t.filterFlags) == filter.Accept { + var p packet.ParsedPacket + if filt.RunOut(buf, &p, t.filterFlags) == filter.Accept { return filter.Accept } + return filter.Drop } @@ -218,13 +219,15 @@ func (t *TUN) filterIn(buf []byte) filter.Response { return filter.Drop } - var q packet.QDecode - if filt.RunIn(buf, &q, t.filterFlags) == filter.Accept { + var p packet.ParsedPacket + if filt.RunIn(buf, &p, t.filterFlags) == filter.Accept { // Only in fake mode, answer any incoming pings. - if q.IsEchoRequest() { + if p.IsEchoRequest() { ft, ok := t.tdev.(*fakeTUN) if ok { - packet := q.EchoRespond() + header := p.ICMPHeader() + header.ToResponse() + packet := packet.Generate(&header, p.Payload()) ft.Write(packet, 0) // We already handled it, stop. return filter.Drop @@ -232,6 +235,7 @@ func (t *TUN) filterIn(buf []byte) filter.Response { } return filter.Accept } + return filter.Drop } diff --git a/wgengine/userspace.go b/wgengine/userspace.go index e5f2abc4a..6e9bbd52b 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -284,6 +284,14 @@ func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, src close(p.done) }() + header := packet.ICMPHeader{ + IPHeader: packet.IPHeader{ + SrcIP: srcIP, + }, + Type: packet.ICMPEchoRequest, + Code: packet.ICMPNoCode, + } + // sendFreq is slightly longer than sprayFreq in magicsock to ensure // that if these ping packets are the only source of early packets // sent to the peer, that each one will be sprayed. @@ -298,7 +306,7 @@ func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, src payload := []byte("magicsock_spray") // no meaning - ipid := uint16(1) + header.IPID = 1 t := time.NewTicker(sendFreq) defer t.Stop() for { @@ -311,12 +319,13 @@ func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, src return } for _, dstIP := range dstIPs { - b := packet.GenICMP(srcIP, dstIP, ipid, packet.ICMPEchoRequest, 0, payload) + header.DstIP = dstIP + // InjectOutbound take ownership of the packet, so we allocate. + b := packet.Generate(&header, payload) p.e.tundev.InjectOutbound(b) } - ipid++ + header.IPID++ } - } // pinger sends ping packets for a few seconds.