// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package packet import ( "encoding/binary" "encoding/json" "fmt" "log" "net" "strings" ) type IPProto int const ( Junk IPProto = iota Fragment ICMP UDP TCP ) // RFC1858: prevent overlapping fragment attacks. const MIN_FRAG = 60 + 20 // max IPv4 header + basic TCP header func (p IPProto) String() string { switch p { case Fragment: return "Frag" case ICMP: return "ICMP" case UDP: return "UDP" case TCP: return "TCP" default: return "Junk" } } type IP uint32 const IPAny = IP(0) func NewIP(b net.IP) IP { b4 := b.To4() if b4 == nil { panic(fmt.Sprintf("To4(%v) failed", b)) } return IP(binary.BigEndian.Uint32(b4)) } func (ip IP) String() string { if ip == 0 { return "*" } b := make([]byte, 4) binary.BigEndian.PutUint32(b, uint32(ip)) return fmt.Sprintf("%d.%d.%d.%d", b[0], b[1], b[2], b[3]) } func (ipp *IP) MarshalJSON() ([]byte, error) { s := "\"" + (*ipp).String() + "\"" return []byte(s), nil } func (ipp *IP) UnmarshalJSON(b []byte) error { var hostp *string err := json.Unmarshal(b, &hostp) if err != nil { return err } host := *hostp ip := net.ParseIP(host) if ip != nil && ip.IsUnspecified() { // For clarity, reject 0.0.0.0 as an input return fmt.Errorf("Ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host) } else if ip == nil && host == "*" { // User explicitly requested wildcard dst ip *ipp = IPAny } else { if ip != nil { ip = ip.To4() } if ip == nil || len(ip) != 4 { return fmt.Errorf("Ports=%#v: invalid IPv4 address", host) } *ipp = NewIP(ip) } return nil } const ( EchoReply uint8 = 0x00 EchoRequest uint8 = 0x08 ) const ( TCPSyn uint8 = 0x02 TCPAck uint8 = 0x10 TCPSynAck uint8 = TCPSyn | TCPAck ) type QDecode struct { b []byte // Packet buffer that this decodes subofs int // byte offset of IP subprotocol IPProto IPProto // IP subprotocol (UDP, TCP, etc) SrcIP IP // IP source address DstIP IP // IP destination address SrcPort uint16 // TCP/UDP source port DstPort uint16 // TCP/UDP destination port TCPFlags uint8 // TCP flags (SYN, ACK, etc) } func (q QDecode) String() string { if q.IPProto == Junk { return "Junk{}" } srcip := make([]byte, 4) dstip := make([]byte, 4) binary.BigEndian.PutUint32(srcip, uint32(q.SrcIP)) binary.BigEndian.PutUint32(dstip, uint32(q.DstIP)) return fmt.Sprintf("%v{%d.%d.%d.%d:%d > %d.%d.%d.%d:%d}", q.IPProto, srcip[0], srcip[1], srcip[2], srcip[3], q.SrcPort, dstip[0], dstip[1], dstip[2], dstip[3], q.DstPort) } // based on https://tools.ietf.org/html/rfc1071 func ipChecksum(b []byte) uint16 { var ac uint32 i := 0 n := len(b) for n >= 2 { ac += uint32(binary.BigEndian.Uint16(b[i : i+2])) n -= 2 i += 2 } if n == 1 { ac += uint32(b[i]) << 8 } for (ac >> 16) > 0 { ac = (ac >> 16) + (ac & 0xffff) } return uint16(^ac) } func GenICMP(srcIP, dstIP IP, ipid uint16, icmpType uint8, icmpCode uint8, payload []byte) []byte { if len(payload) < 4 { return nil } if len(payload) > 65535-24 { return nil } sz := 24 + len(payload) out := make([]byte, 24+len(payload)) out[0] = 0x45 // IPv4, 20-byte header out[1] = 0x00 // DHCP, ECN binary.BigEndian.PutUint16(out[2:4], uint16(sz)) binary.BigEndian.PutUint16(out[4:6], ipid) binary.BigEndian.PutUint16(out[6:8], 0) // flags, offset out[8] = 64 // TTL out[9] = 0x01 // ICMPv4 // out[10:12] = 0x00 // blank IP header checksum binary.BigEndian.PutUint32(out[12:16], uint32(srcIP)) binary.BigEndian.PutUint32(out[16:20], uint32(dstIP)) out[20] = icmpType out[21] = icmpCode //out[22:24] = 0x00 // blank ICMP checksum copy(out[24:len(out)], payload) binary.BigEndian.PutUint16(out[10:12], ipChecksum(out[0:20])) binary.BigEndian.PutUint16(out[22:24], ipChecksum(out)) return out } // An extremely simple packet decoder for basic IPv4 packet types. // It extracts only the subprotocol id, IP addresses, and (if any) ports, // and shouldn't need any memory allocation. func (q *QDecode) Decode(b []byte) { q.b = nil if len(b) < 20 { q.IPProto = Junk return } // Check that it's IPv4. // TODO(apenwarr): consider IPv6 support if ((b[0] & 0xF0) >> 4) != 4 { q.IPProto = Junk return } n := int(binary.BigEndian.Uint16(b[2:4])) if len(b) < n { // Packet was cut off before full IPv4 length. q.IPProto = Junk return } // If it's valid IPv4, then the IP addresses are valid q.SrcIP = IP(binary.BigEndian.Uint32(b[12:16])) q.DstIP = IP(binary.BigEndian.Uint32(b[16:20])) q.subofs = int((b[0] & 0x0F) * 4) sub := b[q.subofs:] // We don't care much about IP fragmentation, except insofar as it's // used for firewall bypass attacks. The trick is make the first // fragment of a TCP or UDP packet so short that it doesn't fit // the TCP or UDP header, so we can't read the port, in hope that // it'll sneak past. Then subsequent fragments fill it in, but we're // missing the first part of the header, so we can't read that either. // // A "perfectly correct" implementation would have to reassemble // fragments before deciding what to do. But the truth is there's // zero reason to send such a short first fragment, so we can treat // it as Junk. We can also treat any subsequent fragment that starts // at such a low offset as Junk. fragFlags := binary.BigEndian.Uint16(b[6:8]) moreFrags := (fragFlags & 0x20) != 0 fragOfs := fragFlags & 0x1FFF if fragOfs == 0 { // This is the first fragment if moreFrags && len(sub) < MIN_FRAG { // Suspiciously short first fragment, dump it. log.Printf("junk1!\n") q.IPProto = Junk return } // otherwise, this is either non-fragmented (the usual case) // or a big enough initial fragment that we can read the // whole subprotocol header. proto := b[9] switch proto { case 1: // ICMPv4 if len(sub) < 8 { q.IPProto = Junk return } q.IPProto = ICMP q.SrcPort = 0 q.DstPort = 0 q.b = b return case 6: // TCP if len(sub) < 20 { q.IPProto = Junk return } q.IPProto = TCP q.SrcPort = binary.BigEndian.Uint16(sub[0:2]) q.DstPort = binary.BigEndian.Uint16(sub[2:4]) q.TCPFlags = sub[13] & 0x3F q.b = b return case 17: // UDP if len(sub) < 8 { q.IPProto = Junk return } q.IPProto = UDP q.SrcPort = binary.BigEndian.Uint16(sub[0:2]) q.DstPort = binary.BigEndian.Uint16(sub[2:4]) q.b = b return default: q.IPProto = Junk return } } else { // This is a fragment other than the first one. if fragOfs < MIN_FRAG { // First frag was suspiciously short, so we can't // trust the followup either. q.IPProto = Junk return } // otherwise, we have to permit the fragment to slide through. // Second and later fragments don't have sub-headers. // Ideally, we would drop fragments that we can't identify, // but that would require statefulness. Anyway, receivers' // kernels know to drop fragments where the initial fragment // doesn't arrive. q.IPProto = Fragment return } } // Returns a subset of the IP subprotocol section. func (q *QDecode) Sub(begin, n int) []byte { return q.b[q.subofs+begin : q.subofs+begin+n] } // For a packet that is known to be IPv4, trim the buffer to its IPv4 length. // Sometimes packets arrive from an interface with extra bytes on the end. // This removes them. func (q *QDecode) Trim() []byte { n := binary.BigEndian.Uint16(q.b[2:4]) return q.b[0:n] } // For a decoded TCP packet, return true if it's a TCP SYN packet (ie. the // first packet in a new connection). func (q *QDecode) IsTCPSyn() bool { const Syn = 0x02 const Ack = 0x10 const SynAck = Syn | Ack return (q.TCPFlags & SynAck) == Syn } // For a packet that has already been decoded, check if it's an IPv4 ICMP // Echo Request. func (q *QDecode) IsEchoRequest() bool { if q.IPProto == ICMP && len(q.b) >= q.subofs+8 { return q.b[q.subofs] == EchoRequest && q.b[q.subofs+1] == 0 } return false } func (q *QDecode) EchoRespond() []byte { icmpid := binary.BigEndian.Uint16(q.Sub(4, 2)) b := q.Trim() return GenICMP(q.DstIP, q.SrcIP, icmpid, EchoReply, 0, b[q.subofs+4:]) } func Hexdump(b []byte) string { out := new(strings.Builder) for i := 0; i < len(b); i += 16 { if i > 0 { fmt.Fprintf(out, "\n") } fmt.Fprintf(out, " %04x ", i) j := 0 for ; j < 16 && i+j < len(b); j++ { if j == 8 { fmt.Fprintf(out, " ") } fmt.Fprintf(out, "%02x ", b[i+j]) } for ; j < 16; j++ { if j == 8 { fmt.Fprintf(out, " ") } fmt.Fprintf(out, " ") } fmt.Fprintf(out, " ") for j = 0; j < 16 && i+j < len(b); j++ { if b[i+j] >= 32 && b[i+j] < 128 { fmt.Fprintf(out, "%c", b[i+j]) } else { fmt.Fprintf(out, ".") } } } return out.String() }