// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package packet import ( "encoding/binary" "fmt" "strings" "tailscale.com/types/strbuilder" ) // RFC1858: prevent overlapping fragment attacks. const minFrag = 60 + 20 // max IPv4 header + basic TCP header const ( TCPSyn = 0x02 TCPAck = 0x10 TCPSynAck = TCPSyn | TCPAck ) var ( get16 = binary.BigEndian.Uint16 get32 = binary.BigEndian.Uint32 put16 = binary.BigEndian.PutUint16 put32 = binary.BigEndian.PutUint32 ) // ParsedPacket is a minimal decoding of a packet suitable for use in filters. // // In general, it only supports IPv4. The IPv6 parsing is very minimal. type ParsedPacket struct { // b is the byte buffer that this decodes. b []byte // subofs is the offset of IP subprotocol. subofs int // dataofs is the offset of IP subprotocol payload. dataofs int // length is the total length of the packet. // This is not the same as len(b) because b can have trailing zeros. length int IPVersion uint8 // 4, 6, or 0 IPProto IPProto // IP subprotocol (UDP, TCP, etc); the NextHeader field for IPv6 SrcIP IP // IP source address (not used for IPv6) DstIP IP // IP destination address (not used for IPv6) SrcPort uint16 // TCP/UDP source port DstPort uint16 // TCP/UDP destination port TCPFlags uint8 // TCP flags (SYN, ACK, etc) } // NextHeader type NextHeader uint8 func (p *ParsedPacket) String() string { if p.IPVersion == 6 { return fmt.Sprintf("IPv6{Proto=%d}", p.IPProto) } switch p.IPProto { case Unknown: return "Unknown{???}" } sb := strbuilder.Get() sb.WriteString(p.IPProto.String()) sb.WriteByte('{') writeIPPort(sb, p.SrcIP, p.SrcPort) sb.WriteString(" > ") writeIPPort(sb, p.DstIP, p.DstPort) sb.WriteByte('}') return sb.String() } func writeIPPort(sb *strbuilder.Builder, ip IP, port uint16) { sb.WriteUint(uint64(byte(ip >> 24))) sb.WriteByte('.') sb.WriteUint(uint64(byte(ip >> 16))) sb.WriteByte('.') sb.WriteUint(uint64(byte(ip >> 8))) sb.WriteByte('.') sb.WriteUint(uint64(byte(ip))) sb.WriteByte(':') sb.WriteUint(uint64(port)) } // based on https://tools.ietf.org/html/rfc1071 func ipChecksum(b []byte) uint16 { var ac uint32 i := 0 n := len(b) for n >= 2 { ac += uint32(get16(b[i : i+2])) n -= 2 i += 2 } if n == 1 { ac += uint32(b[i]) << 8 } for (ac >> 16) > 0 { ac = (ac >> 16) + (ac & 0xffff) } return uint16(^ac) } // Decode extracts data from the packet in b into q. // It performs extremely simple packet decoding for basic IPv4 packet types. // It extracts only the subprotocol id, IP addresses, and (if any) ports, // and shouldn't need any memory allocation. func (q *ParsedPacket) Decode(b []byte) { q.b = b if len(b) < ipHeaderLength { q.IPVersion = 0 q.IPProto = Unknown return } // Check that it's IPv4. // TODO(apenwarr): consider IPv6 support q.IPVersion = (b[0] & 0xF0) >> 4 switch q.IPVersion { case 4: q.IPProto = IPProto(b[9]) case 6: q.IPProto = IPProto(b[6]) // "Next Header" field return default: q.IPVersion = 0 q.IPProto = Unknown return } q.length = int(get16(b[2:4])) if len(b) < q.length { // Packet was cut off before full IPv4 length. q.IPProto = Unknown return } // If it's valid IPv4, then the IP addresses are valid q.SrcIP = IP(get32(b[12:16])) q.DstIP = IP(get32(b[16:20])) q.subofs = int((b[0] & 0x0F) << 2) sub := b[q.subofs:] // We don't care much about IP fragmentation, except insofar as it's // used for firewall bypass attacks. The trick is make the first // fragment of a TCP or UDP packet so short that it doesn't fit // the TCP or UDP header, so we can't read the port, in hope that // it'll sneak past. Then subsequent fragments fill it in, but we're // missing the first part of the header, so we can't read that either. // // A "perfectly correct" implementation would have to reassemble // fragments before deciding what to do. But the truth is there's // zero reason to send such a short first fragment, so we can treat // it as Unknown. We can also treat any subsequent fragment that starts // at such a low offset as Unknown. fragFlags := get16(b[6:8]) moreFrags := (fragFlags & 0x20) != 0 fragOfs := fragFlags & 0x1FFF if fragOfs == 0 { // This is the first fragment if moreFrags && len(sub) < minFrag { // Suspiciously short first fragment, dump it. q.IPProto = Unknown return } // otherwise, this is either non-fragmented (the usual case) // or a big enough initial fragment that we can read the // whole subprotocol header. switch q.IPProto { case ICMP: if len(sub) < icmpHeaderLength { q.IPProto = Unknown return } q.SrcPort = 0 q.DstPort = 0 q.dataofs = q.subofs + icmpHeaderLength return case TCP: if len(sub) < tcpHeaderLength { q.IPProto = Unknown return } q.SrcPort = get16(sub[0:2]) q.DstPort = get16(sub[2:4]) q.TCPFlags = sub[13] & 0x3F headerLength := (sub[12] & 0xF0) >> 2 q.dataofs = q.subofs + int(headerLength) return case UDP: if len(sub) < udpHeaderLength { q.IPProto = Unknown return } q.SrcPort = get16(sub[0:2]) q.DstPort = get16(sub[2:4]) q.dataofs = q.subofs + udpHeaderLength return default: q.IPProto = Unknown return } } else { // This is a fragment other than the first one. if fragOfs < minFrag { // First frag was suspiciously short, so we can't // trust the followup either. q.IPProto = Unknown return } // otherwise, we have to permit the fragment to slide through. // Second and later fragments don't have sub-headers. // Ideally, we would drop fragments that we can't identify, // but that would require statefulness. Anyway, receivers' // kernels know to drop fragments where the initial fragment // doesn't arrive. q.IPProto = Fragment return } } func (q *ParsedPacket) IPHeader() IPHeader { ipid := get16(q.b[4:6]) return IPHeader{ IPID: ipid, IPProto: q.IPProto, SrcIP: q.SrcIP, DstIP: q.DstIP, } } func (q *ParsedPacket) ICMPHeader() ICMPHeader { return ICMPHeader{ IPHeader: q.IPHeader(), Type: ICMPType(q.b[q.subofs+0]), Code: ICMPCode(q.b[q.subofs+1]), } } func (q *ParsedPacket) UDPHeader() UDPHeader { return UDPHeader{ IPHeader: q.IPHeader(), SrcPort: q.SrcPort, DstPort: q.DstPort, } } // Buffer returns the entire packet buffer. // This is a read-only view; that is, q retains the ownership of the buffer. func (q *ParsedPacket) Buffer() []byte { return q.b } // Sub returns the IP subprotocol section. // This is a read-only view; that is, q retains the ownership of the buffer. func (q *ParsedPacket) Sub(begin, n int) []byte { return q.b[q.subofs+begin : q.subofs+begin+n] } // Payload returns the payload of the IP subprotocol section. // This is a read-only view; that is, q retains the ownership of the buffer. func (q *ParsedPacket) Payload() []byte { return q.b[q.dataofs:q.length] } // Trim trims the buffer to its IPv4 length. // Sometimes packets arrive from an interface with extra bytes on the end. // This removes them. func (q *ParsedPacket) Trim() []byte { return q.b[:q.length] } // IsTCPSyn reports whether q is a TCP SYN packet // (i.e. the first packet in a new connection). func (q *ParsedPacket) IsTCPSyn() bool { return (q.TCPFlags & TCPSynAck) == TCPSyn } // IsError reports whether q is an IPv4 ICMP "Error" packet. func (q *ParsedPacket) IsError() bool { if q.IPProto == ICMP && len(q.b) >= q.subofs+8 { switch ICMPType(q.b[q.subofs]) { case ICMPUnreachable, ICMPTimeExceeded: return true } } return false } // IsEchoRequest reports whether q is an IPv4 ICMP Echo Request. func (q *ParsedPacket) IsEchoRequest() bool { if q.IPProto == ICMP && len(q.b) >= q.subofs+8 { return ICMPType(q.b[q.subofs]) == ICMPEchoRequest && ICMPCode(q.b[q.subofs+1]) == ICMPNoCode } return false } // IsEchoRequest reports whether q is an IPv4 ICMP Echo Response. func (q *ParsedPacket) IsEchoResponse() bool { if q.IPProto == ICMP && len(q.b) >= q.subofs+8 { return ICMPType(q.b[q.subofs]) == ICMPEchoReply && ICMPCode(q.b[q.subofs+1]) == ICMPNoCode } return false } func Hexdump(b []byte) string { out := new(strings.Builder) for i := 0; i < len(b); i += 16 { if i > 0 { fmt.Fprintf(out, "\n") } fmt.Fprintf(out, " %04x ", i) j := 0 for ; j < 16 && i+j < len(b); j++ { if j == 8 { fmt.Fprintf(out, " ") } fmt.Fprintf(out, "%02x ", b[i+j]) } for ; j < 16; j++ { if j == 8 { fmt.Fprintf(out, " ") } fmt.Fprintf(out, " ") } fmt.Fprintf(out, " ") for j = 0; j < 16 && i+j < len(b); j++ { if b[i+j] >= 32 && b[i+j] < 128 { fmt.Fprintf(out, "%c", b[i+j]) } else { fmt.Fprintf(out, ".") } } } return out.String() }