diff --git a/net/packet/packet.go b/net/packet/packet.go index 4c83dd59a..48622eff4 100644 --- a/net/packet/packet.go +++ b/net/packet/packet.go @@ -22,12 +22,16 @@ const minFrag = 60 + 20 // max IPv4 header + basic TCP header type TCPFlag uint8 const ( - TCPFin TCPFlag = 0x01 - TCPSyn TCPFlag = 0x02 - TCPRst TCPFlag = 0x04 - TCPPsh TCPFlag = 0x08 - TCPAck TCPFlag = 0x10 - TCPSynAck TCPFlag = TCPSyn | TCPAck + TCPFin TCPFlag = 0x01 + TCPSyn TCPFlag = 0x02 + TCPRst TCPFlag = 0x04 + TCPPsh TCPFlag = 0x08 + TCPAck TCPFlag = 0x10 + TCPUrg TCPFlag = 0x20 + TCPECNEcho TCPFlag = 0x40 + TCPCWR TCPFlag = 0x80 + TCPSynAck TCPFlag = TCPSyn | TCPAck + TCPECNBits TCPFlag = TCPECNEcho | TCPCWR ) // Parsed is a minimal decoding of a packet suitable for use in filters. @@ -180,7 +184,7 @@ func (q *Parsed) decode4(b []byte) { } q.Src = q.Src.WithPort(binary.BigEndian.Uint16(sub[0:2])) q.Dst = q.Dst.WithPort(binary.BigEndian.Uint16(sub[2:4])) - q.TCPFlags = TCPFlag(sub[13]) & 0x3F + q.TCPFlags = TCPFlag(sub[13]) headerLength := (sub[12] & 0xF0) >> 2 q.dataofs = q.subofs + int(headerLength) return @@ -282,7 +286,7 @@ func (q *Parsed) decode6(b []byte) { } q.Src = q.Src.WithPort(binary.BigEndian.Uint16(sub[0:2])) q.Dst = q.Dst.WithPort(binary.BigEndian.Uint16(sub[2:4])) - q.TCPFlags = TCPFlag(sub[13]) & 0x3F + q.TCPFlags = TCPFlag(sub[13]) headerLength := (sub[12] & 0xF0) >> 2 q.dataofs = q.subofs + int(headerLength) return @@ -374,8 +378,14 @@ func (q *Parsed) Payload() []byte { return q.b[q.dataofs:q.length] } -// IsTCPSyn reports whether q is a TCP SYN packet -// (i.e. the first packet in a new connection). +// Transport returns the transport header and payload (IP subprotocol, such as TCP or UDP). +// This is a read-only view; that is, p retains the ownership of the buffer. +func (p *Parsed) Transport() []byte { + return p.b[p.subofs:] +} + +// IsTCPSyn reports whether q is a TCP SYN packet, +// without ACK set. (i.e. the first packet in a new connection) func (q *Parsed) IsTCPSyn() bool { return (q.TCPFlags & TCPSynAck) == TCPSyn } @@ -424,6 +434,40 @@ func (q *Parsed) IsEchoResponse() bool { } } +// RemoveECNBits modifies p and its underlying memory buffer to remove +// ECN bits, if any. It reports whether it did so. +// +// It currently only does the TCP flags. +func (p *Parsed) RemoveECNBits() bool { + if p.IPVersion == 0 { + return false + } + if p.IPProto != ipproto.TCP { + // TODO(bradfitz): handle non-TCP too? for now only trying to + // fix the Issue 2642 problem. + return false + } + if p.TCPFlags&TCPECNBits == 0 { + // Nothing to do. + return false + } + + // Clear flags. + + // First in the parsed output. + p.TCPFlags = p.TCPFlags & ^TCPECNBits + + // Then in the underlying memory. + tcp := p.Transport() + old := binary.BigEndian.Uint16(tcp[12:14]) + tcp[13] = byte(p.TCPFlags) + new := binary.BigEndian.Uint16(tcp[12:14]) + oldSum := binary.BigEndian.Uint16(tcp[16:18]) + newSum := ^checksumUpdate2ByteAlignedUint16(^oldSum, old, new) + binary.BigEndian.PutUint16(tcp[16:18], newSum) + return true +} + func Hexdump(b []byte) string { out := new(strings.Builder) for i := 0; i < len(b); i += 16 { @@ -455,3 +499,26 @@ func Hexdump(b []byte) string { } return out.String() } + +// From gVisor's unexported API: + +// checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated +// checksum. +// +// The value MUST begin at a 2-byte boundary in the original buffer. +func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 { + // As per RFC 1071 page 4, + //(4) Incremental Update + // + // ... + // + // To update the checksum, simply add the differences of the + // sixteen bit integers that have been changed. To see why this + // works, observe that every 16-bit integer has an additive inverse + // and that addition is associative. From this it follows that + // given the original value m, the new value m', and the old + // checksum C, the new checksum C' is: + // + // C' = C + (-m) + m' = C + (m' - m) + return checksumCombine(xsum, checksumCombine(new, ^old)) +} diff --git a/net/packet/packet_test.go b/net/packet/packet_test.go index c39abc3fb..4af18b304 100644 --- a/net/packet/packet_test.go +++ b/net/packet/packet_test.go @@ -6,7 +6,9 @@ package packet import ( "bytes" + "encoding/hex" "reflect" + "regexp" "testing" "inet.af/netaddr" @@ -561,3 +563,57 @@ func BenchmarkString(b *testing.B) { }) } } + +func TestRemoveECNBits(t *testing.T) { + // withECNHex is a TCP SYN packet with ECN bits set in the TCP + // header as captured by Wireshark on macOS against the + // Tailscale interface. In this packet (because it's a SYN + // control packet), the ECN bits are not set in the IP header. + const withECNHex = `45 00 00 40 00 00 40 00 + 40 06 0c 66 64 7b 65 28 64 7f 00 30 f1 ab 00 16 + 5a 7a 63 e8 00 00 00 00 b0 c2 ff ff 97 76 00 00 + 02 04 04 d8 01 03 03 06 01 01 08 0a 03 e1 bd 49 + 00 00 00 00 04 02 00 00` + + // Generated by hand-editing a pcap file in hexl-mode to set + // the TCP flags to just SYN (0x02), then loading that pcap + // file in wireshark to get the expected checksum value, then + // putting that checksum value (0x9836) in the file. + const wantStrippedHex = `45 00 00 40 00 00 40 00 + 40 06 0c 66 64 7b 65 28 64 7f 00 30 f1 ab 00 16 + 5a 7a 63 e8 00 00 00 00 b0 02 ff ff 98 36 00 00 + 02 04 04 d8 01 03 03 06 01 01 08 0a 03 e1 bd 49 + 00 00 00 00 04 02 00 00` + + var p Parsed + pktBuf := bytesOfHex(withECNHex) + p.Decode(pktBuf) + if want := TCPCWR | TCPECNEcho | TCPSyn; p.TCPFlags != want { + t.Fatalf("pre flags = %v; want %v", p.TCPFlags, want) + } + + if !p.RemoveECNBits() { + t.Fatal("didn't remove bits") + } + if want := TCPSyn; p.TCPFlags != want { + t.Fatalf("post flags = %v; want %v", p.TCPFlags, want) + } + wantPkt := bytesOfHex(wantStrippedHex) + if !bytes.Equal(pktBuf, wantPkt) { + t.Fatalf("wrong result.\n got: % 2x\nwant: % 2x\n", pktBuf, wantPkt) + } + + if p.RemoveECNBits() { + t.Fatal("unexpected true return value on second call") + } +} + +var nonHex = regexp.MustCompile(`[^0-9a-fA-F]+`) + +func bytesOfHex(s string) []byte { + b, err := hex.DecodeString(nonHex.ReplaceAllString(s, "")) + if err != nil { + panic(err) + } + return b +} diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index dbd3e4694..f92007059 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -512,6 +512,7 @@ func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper) filter.Respons case 6: pn = header.IPv6ProtocolNumber } + p.RemoveECNBits() // Issue 2642 if debugPackets { ns.logf("[v2] packet in (from %v): % x", p.Src, p.Buffer()) }