diff --git a/wgengine/packet/packet.go b/wgengine/packet/packet.go index e2cd08fc7..55513c553 100644 --- a/wgengine/packet/packet.go +++ b/wgengine/packet/packet.go @@ -10,6 +10,8 @@ import ( "log" "net" "strings" + + "tailscale.com/types/strbuilder" ) type IPProto int @@ -23,7 +25,7 @@ const ( ) // RFC1858: prevent overlapping fragment attacks. -const MIN_FRAG = 60 + 20 // max IPv4 header + basic TCP header +const minFrag = 60 + 20 // max IPv4 header + basic TCP header func (p IPProto) String() string { switch p { @@ -40,8 +42,11 @@ func (p IPProto) String() string { } } +// IP is an IPv4 address. type IP uint32 +// NewIP converts a standard library IP address into an IP. +// It panics if b is not an IPv4 address. func NewIP(b net.IP) IP { b4 := b.To4() if b4 == nil { @@ -51,22 +56,21 @@ func NewIP(b net.IP) IP { } func (ip IP) String() string { - b := make([]byte, 4) - binary.BigEndian.PutUint32(b, uint32(ip)) - return fmt.Sprintf("%d.%d.%d.%d", b[0], b[1], b[2], b[3]) + return fmt.Sprintf("%d.%d.%d.%d", byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip)) } +// ICMP types. const ( - EchoReply uint8 = 0x00 - EchoRequest uint8 = 0x08 - Unreachable uint8 = 0x03 - TimeExceeded uint8 = 0x0B + ICMPEchoReply = 0x00 + ICMPEchoRequest = 0x08 + ICMPUnreachable = 0x03 + ICMPTimeExceeded = 0x0b ) const ( - TCPSyn uint8 = 0x02 - TCPAck uint8 = 0x10 - TCPSynAck uint8 = TCPSyn | TCPAck + TCPSyn = 0x02 + TCPAck = 0x10 + TCPSynAck = TCPSyn | TCPAck ) type QDecode struct { @@ -81,18 +85,30 @@ type QDecode struct { TCPFlags uint8 // TCP flags (SYN, ACK, etc) } -func (q QDecode) String() string { +func (q *QDecode) String() string { if q.IPProto == Junk { return "Junk{}" } - srcip := make([]byte, 4) - dstip := make([]byte, 4) - binary.BigEndian.PutUint32(srcip, uint32(q.SrcIP)) - binary.BigEndian.PutUint32(dstip, uint32(q.DstIP)) - return fmt.Sprintf("%v{%d.%d.%d.%d:%d > %d.%d.%d.%d:%d}", - q.IPProto, - srcip[0], srcip[1], srcip[2], srcip[3], q.SrcPort, - dstip[0], dstip[1], dstip[2], dstip[3], q.DstPort) + sb := strbuilder.Get() + sb.WriteString(q.IPProto.String()) + sb.WriteByte('{') + writeIPPort(sb, q.SrcIP, q.SrcPort) + sb.WriteString(" > ") + writeIPPort(sb, q.DstIP, q.DstPort) + sb.WriteByte('}') + return sb.String() +} + +func writeIPPort(sb *strbuilder.Builder, ip IP, port uint16) { + sb.WriteUint(uint64(byte(ip >> 24))) + sb.WriteByte('.') + sb.WriteUint(uint64(byte(ip >> 16))) + sb.WriteByte('.') + sb.WriteUint(uint64(byte(ip >> 8))) + sb.WriteByte('.') + sb.WriteUint(uint64(byte(ip))) + sb.WriteByte(':') + sb.WriteUint(uint64(port)) } // based on https://tools.ietf.org/html/rfc1071 @@ -114,7 +130,12 @@ func ipChecksum(b []byte) uint16 { return uint16(^ac) } -func GenICMP(srcIP, dstIP IP, ipid uint16, icmpType uint8, icmpCode uint8, payload []byte) []byte { +var put16 = binary.BigEndian.PutUint16 +var put32 = binary.BigEndian.PutUint32 + +// GenICMP returns the bytes of an ICMP packet. +// If payload is too short or too long, it returns nil. +func GenICMP(srcIP, dstIP IP, ipid uint16, icmpType, icmpCode uint8, payload []byte) []byte { if len(payload) < 4 { return nil } @@ -126,22 +147,22 @@ func GenICMP(srcIP, dstIP IP, ipid uint16, icmpType uint8, icmpCode uint8, paylo out := make([]byte, 24+len(payload)) out[0] = 0x45 // IPv4, 20-byte header out[1] = 0x00 // DHCP, ECN - binary.BigEndian.PutUint16(out[2:4], uint16(sz)) - binary.BigEndian.PutUint16(out[4:6], ipid) - binary.BigEndian.PutUint16(out[6:8], 0) // flags, offset - out[8] = 64 // TTL - out[9] = 0x01 // ICMPv4 + put16(out[2:4], uint16(sz)) + put16(out[4:6], ipid) + put16(out[6:8], 0) // flags, offset + out[8] = 64 // TTL + out[9] = 0x01 // ICMPv4 // out[10:12] = 0x00 // blank IP header checksum - binary.BigEndian.PutUint32(out[12:16], uint32(srcIP)) - binary.BigEndian.PutUint32(out[16:20], uint32(dstIP)) + put32(out[12:16], uint32(srcIP)) + put32(out[16:20], uint32(dstIP)) out[20] = icmpType out[21] = icmpCode //out[22:24] = 0x00 // blank ICMP checksum copy(out[24:], payload) - binary.BigEndian.PutUint16(out[10:12], ipChecksum(out[0:20])) - binary.BigEndian.PutUint16(out[22:24], ipChecksum(out)) + put16(out[10:12], ipChecksum(out[0:20])) + put16(out[22:24], ipChecksum(out)) return out } @@ -193,7 +214,7 @@ func (q *QDecode) Decode(b []byte) { fragOfs := fragFlags & 0x1FFF if fragOfs == 0 { // This is the first fragment - if moreFrags && len(sub) < MIN_FRAG { + if moreFrags && len(sub) < minFrag { // Suspiciously short first fragment, dump it. log.Printf("junk1!\n") q.IPProto = Junk @@ -241,7 +262,7 @@ func (q *QDecode) Decode(b []byte) { } } else { // This is a fragment other than the first one. - if fragOfs < MIN_FRAG { + if fragOfs < minFrag { // First frag was suspiciously short, so we can't // trust the followup either. q.IPProto = Junk @@ -263,57 +284,52 @@ func (q *QDecode) Sub(begin, n int) []byte { return q.b[q.subofs+begin : q.subofs+begin+n] } -// For a packet that is known to be IPv4, trim the buffer to its IPv4 length. +// Trim trims the buffer to its IPv4 length. // Sometimes packets arrive from an interface with extra bytes on the end. // This removes them. func (q *QDecode) Trim() []byte { n := binary.BigEndian.Uint16(q.b[2:4]) - return q.b[0:n] + return q.b[:n] } -// For a decoded TCP packet, return true if it's a TCP SYN packet (ie. the +// IsTCPSyn reports whether q is a TCP SYN packet (i.e. the // first packet in a new connection). func (q *QDecode) IsTCPSyn() bool { - const Syn = 0x02 - const Ack = 0x10 - const SynAck = Syn | Ack - return (q.TCPFlags & SynAck) == Syn + return (q.TCPFlags & TCPSynAck) == TCPSyn } -// For a packet that has already been decoded, check if it's an IPv4 ICMP -// "Error" packet. +// IsError reports whether q is an IPv4 ICMP "Error" packet. func (q *QDecode) IsError() bool { if q.IPProto == ICMP && len(q.b) >= q.subofs+8 { switch q.b[q.subofs] { - case Unreachable, TimeExceeded: + case ICMPUnreachable, ICMPTimeExceeded: return true } } return false } -// For a packet that has already been decoded, check if it's an IPv4 ICMP -// Echo Request. +// IsEchoRequest reports whether q is an IPv4 ICMP Echo Request. func (q *QDecode) IsEchoRequest() bool { if q.IPProto == ICMP && len(q.b) >= q.subofs+8 { - return q.b[q.subofs] == EchoRequest && q.b[q.subofs+1] == 0 + return q.b[q.subofs] == ICMPEchoRequest && q.b[q.subofs+1] == 0 } return false } -// For a packet that has already been decoded, check if it's an IPv4 ICMP -// Echo Response. +// IsEchoRequest reports whether q is an IPv4 ICMP Echo Response. func (q *QDecode) IsEchoResponse() bool { if q.IPProto == ICMP && len(q.b) >= q.subofs+8 { - return q.b[q.subofs] == EchoReply && q.b[q.subofs+1] == 0 + return q.b[q.subofs] == ICMPEchoReply && q.b[q.subofs+1] == 0 } return false } +// EchoResponse returns an IPv4 ICMP echo reply to the request in q. func (q *QDecode) EchoRespond() []byte { icmpid := binary.BigEndian.Uint16(q.Sub(4, 2)) b := q.Trim() - return GenICMP(q.DstIP, q.SrcIP, icmpid, EchoReply, 0, b[q.subofs+4:]) + return GenICMP(q.DstIP, q.SrcIP, icmpid, ICMPEchoReply, 0, b[q.subofs+4:]) } func Hexdump(b []byte) string { diff --git a/wgengine/packet/packet_test.go b/wgengine/packet/packet_test.go new file mode 100644 index 000000000..15b8ab2e9 --- /dev/null +++ b/wgengine/packet/packet_test.go @@ -0,0 +1,49 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package packet + +import ( + "net" + "testing" +) + +func TestIPString(t *testing.T) { + const str = "1.2.3.4" + ip := NewIP(net.ParseIP(str)) + + var got string + allocs := testing.AllocsPerRun(1000, func() { + got = ip.String() + }) + + if got != str { + t.Errorf("got %q; want %q", got, str) + } + if allocs != 1 { + t.Errorf("allocs = %v; want 1", allocs) + } +} + +func TestQDecodeString(t *testing.T) { + q := QDecode{ + IPProto: TCP, + SrcIP: NewIP(net.ParseIP("1.2.3.4")), + SrcPort: 123, + DstIP: NewIP(net.ParseIP("5.6.7.8")), + DstPort: 567, + } + got := q.String() + want := "TCP{1.2.3.4:123 > 5.6.7.8:567}" + if got != want { + t.Errorf("got %q; want %q", got, want) + } + + allocs := testing.AllocsPerRun(1000, func() { + got = q.String() + }) + if allocs != 1 { + t.Errorf("allocs = %v; want 1", allocs) + } +} diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 9a960fde6..c47fb0fa7 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -322,7 +322,7 @@ func (e *userspaceEngine) pinger(peerKey wgcfg.Key, ips []wgcfg.IP) { return } for _, dstIP := range dstIPs { - b := packet.GenICMP(srcIP, dstIP, ipid, packet.EchoRequest, 0, payload) + b := packet.GenICMP(srcIP, dstIP, ipid, packet.ICMPEchoRequest, 0, payload) e.tundev.InjectOutbound(b) } ipid++