From 73c40c77b0ede8744f3a1bcaa9b2d3b9d74b6c09 Mon Sep 17 00:00:00 2001 From: Dmytro Shynkevych Date: Tue, 2 Jun 2020 08:09:20 -0400 Subject: [PATCH] filter: prevent escape of QDecode to the heap (#417) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Performance impact: name old time/op new time/op delta Filter/tcp_in-4 70.7ns ± 1% 30.9ns ± 1% -56.30% (p=0.008 n=5+5) Filter/tcp_out-4 58.6ns ± 0% 19.4ns ± 0% -66.87% (p=0.000 n=5+4) Filter/udp_in-4 96.8ns ± 2% 55.5ns ± 0% -42.64% (p=0.016 n=5+4) Filter/udp_out-4 120ns ± 1% 79ns ± 1% -33.87% (p=0.008 n=5+5) Signed-off-by: Dmytro Shynkevych --- wgengine/filter/filter.go | 18 +++- wgengine/filter/filter_test.go | 172 ++++++++++++++++++++++++++------- 2 files changed, 153 insertions(+), 37 deletions(-) diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index d775d2763..efb632328 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -138,16 +138,26 @@ var acceptBucket = rate.NewLimiter(rate.Every(10*time.Second), 3) var dropBucket = rate.NewLimiter(rate.Every(5*time.Second), 10) func (f *Filter) logRateLimit(runflags RunFlags, b []byte, q *packet.QDecode, r Response, why string) { + var verdict string + if r == Drop && (runflags&LogDrops) != 0 && dropBucket.Allow() { + verdict = "Drop" + runflags &= HexdumpDrops + } else if r == Accept && (runflags&LogAccepts) != 0 && acceptBucket.Allow() { + verdict = "Accept" + runflags &= HexdumpAccepts + } + + // Note: it is crucial that q.String() be called only if {accept,drop}Bucket.Allow() passes, + // since it causes an allocation. + if verdict != "" { var qs string if q == nil { qs = fmt.Sprintf("(%d bytes)", len(b)) } else { qs = q.String() } - f.logf("Drop: %v %v %s\n%s", qs, len(b), why, maybeHexdump(runflags&HexdumpDrops, b)) - } else if r == Accept && (runflags&LogAccepts) != 0 && acceptBucket.Allow() { - f.logf("Accept: %v %v %s\n%s", q, len(b), why, maybeHexdump(runflags&HexdumpAccepts, b)) + f.logf("%s: %s %d %s\n%s", verdict, qs, len(b), why, maybeHexdump(runflags, b)) } } @@ -254,7 +264,7 @@ func (f *Filter) pre(b []byte, q *packet.QDecode, rf RunFlags) Response { if q.IPProto == packet.Junk { // Junk packets are dangerous; always drop them. - f.logRateLimit(rf, b, q, Drop, "junk!") + f.logRateLimit(rf, b, q, Drop, "junk") return Drop } else if q.IPProto == packet.Fragment { // Fragments after the first always need to be passed through. diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index 72d0d5bf1..3397721df 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -7,9 +7,9 @@ package filter import ( "encoding/binary" "encoding/json" - "net" "testing" + "tailscale.com/types/logger" "tailscale.com/wgengine/packet" ) @@ -43,26 +43,29 @@ func netpr(ip IP, bits int, start, end uint16) []NetPortRange { } } -func TestFilter(t *testing.T) { - mm := Matches{ - {Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: []NetPortRange{ - NetPortRange{Net{0x01020304, Netmask(32)}, PortRange{22, 22}}, - NetPortRange{Net{0x05060708, Netmask(32)}, PortRange{23, 24}}, - }}, - {Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: ippr(0x05060708, 27, 28)}, - {Srcs: nets([]IP{0x02020202}), Dsts: ippr(0x08010101, 22, 22)}, - {Srcs: []Net{NetAny}, Dsts: ippr(0x647a6232, 0, 65535)}, - {Srcs: []Net{NetAny}, Dsts: netpr(0, 0, 443, 443)}, - {Srcs: nets([]IP{0x99010101, 0x99010102, 0x99030303}), Dsts: ippr(0x01020304, 999, 999)}, - } +var matches = Matches{ + {Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: []NetPortRange{ + NetPortRange{Net{0x01020304, Netmask(32)}, PortRange{22, 22}}, + NetPortRange{Net{0x05060708, Netmask(32)}, PortRange{23, 24}}, + }}, + {Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: ippr(0x05060708, 27, 28)}, + {Srcs: nets([]IP{0x02020202}), Dsts: ippr(0x08010101, 22, 22)}, + {Srcs: []Net{NetAny}, Dsts: ippr(0x647a6232, 0, 65535)}, + {Srcs: []Net{NetAny}, Dsts: netpr(0, 0, 443, 443)}, + {Srcs: nets([]IP{0x99010101, 0x99010102, 0x99030303}), Dsts: ippr(0x01020304, 999, 999)}, +} + +func newFilter(logf logger.Logf) *Filter { // Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8, // 102.102.102.102, 119.119.119.119, 8.1.0.0/16 localNets := nets([]IP{0x647a6232, 0x01020304, 0x05060708, 0x66666666, 0x77777777}) localNets = append(localNets, Net{IP(0x08010000), Netmask(16)}) - acl := New(mm, localNets, nil, t.Logf) + return New(matches, localNets, nil, logf) +} - for _, ent := range []Matches{Matches{mm[0]}, mm} { +func TestMarshal(t *testing.T) { + for _, ent := range []Matches{Matches{matches[0]}, matches} { b, err := json.Marshal(ent) if err != nil { t.Fatalf("marshal: %v", err) @@ -73,7 +76,10 @@ func TestFilter(t *testing.T) { t.Fatalf("unmarshal: %v (%v)", err, string(b)) } } +} +func TestFilter(t *testing.T) { + acl := newFilter(t.Logf) // check packet filtering based on the table type InOut struct { @@ -116,6 +122,83 @@ func TestFilter(t *testing.T) { } } +func TestNoAllocs(t *testing.T) { + acl := newFilter(t.Logf) + + tcpPacket := rawpacket(TCP, 0x08010101, 0x01020304, 999, 22, 0) + udpPacket := rawpacket(UDP, 0x08010101, 0x01020304, 999, 22, 0) + + tests := []struct { + name string + in bool + want int + packet []byte + }{ + {"tcp_in", true, 0, tcpPacket}, + {"tcp_out", false, 0, tcpPacket}, + {"udp_in", true, 0, udpPacket}, + // One alloc is inevitable (an lru cache update) + {"udp_out", false, 1, udpPacket}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got := int(testing.AllocsPerRun(1000, func() { + var q QDecode + if test.in { + acl.RunIn(test.packet, &q, 0) + } else { + acl.RunOut(test.packet, &q, 0) + } + })) + + if got > test.want { + t.Errorf("got %d allocs per run; want at most %d", got, test.want) + } + }) + } +} + +func BenchmarkFilter(b *testing.B) { + acl := newFilter(b.Logf) + + tcpPacket := rawpacket(TCP, 0x08010101, 0x01020304, 999, 22, 0) + udpPacket := rawpacket(UDP, 0x08010101, 0x01020304, 999, 22, 0) + icmpPacket := rawpacket(ICMP, 0x08010101, 0x01020304, 0, 0, 0) + + tcpSynPacket := rawpacket(TCP, 0x08010101, 0x01020304, 999, 22, 0) + // TCP filtering is trivial (Accept) for non-SYN packets. + tcpSynPacket[33] = packet.TCPSyn + + benches := []struct { + name string + in bool + packet []byte + }{ + // Non-SYN TCP and ICMP have similar code paths in and out. + {"icmp", true, icmpPacket}, + {"tcp", true, tcpPacket}, + {"tcp_syn_in", true, tcpSynPacket}, + {"tcp_syn_out", false, tcpSynPacket}, + {"udp_in", true, udpPacket}, + {"udp_out", false, udpPacket}, + } + + for _, bench := range benches { + b.Run(bench.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + var q QDecode + // This branch seems to have no measurable impact on performance. + if bench.in { + acl.RunIn(bench.packet, &q, 0) + } else { + acl.RunOut(bench.packet, &q, 0) + } + } + }) + } +} + func TestPreFilter(t *testing.T) { packets := []struct { desc string @@ -124,11 +207,11 @@ func TestPreFilter(t *testing.T) { }{ {"empty", Accept, []byte{}}, {"short", Drop, []byte("short")}, - {"junk", Drop, rawpacket(Junk, 10)}, - {"fragment", Accept, rawpacket(Fragment, 40)}, - {"tcp", noVerdict, rawpacket(TCP, 200)}, - {"udp", noVerdict, rawpacket(UDP, 200)}, - {"icmp", noVerdict, rawpacket(ICMP, 200)}, + {"junk", Drop, rawdefault(Junk, 10)}, + {"fragment", Accept, rawdefault(Fragment, 40)}, + {"tcp", noVerdict, rawdefault(TCP, 200)}, + {"udp", noVerdict, rawdefault(UDP, 200)}, + {"icmp", noVerdict, rawdefault(ICMP, 200)}, } f := NewAllowNone(t.Logf) for _, testPacket := range packets { @@ -150,22 +233,38 @@ func qdecode(proto packet.IPProto, src, dst packet.IP, sport, dport uint16) QDec } } -func rawpacket(proto packet.IPProto, len uint16) []byte { - bl := len - if len < 24 { - bl = 24 +// rawpacket generates a packet with given source and destination ports and IPs +// and resizes the header to trimLength if it is nonzero. +func rawpacket(proto packet.IPProto, src, dst packet.IP, sport, dport uint16, trimLength int) []byte { + var headerLength int + + switch proto { + case ICMP: + headerLength = 24 + case TCP: + headerLength = 40 + case UDP: + headerLength = 28 + default: + headerLength = 24 + } + if trimLength > headerLength { + headerLength = trimLength } + if trimLength == 0 { + trimLength = headerLength + } + bin := binary.BigEndian - hdr := make([]byte, bl) + hdr := make([]byte, headerLength) hdr[0] = 0x45 - bin.PutUint16(hdr[2:4], len) + bin.PutUint16(hdr[2:4], uint16(trimLength)) hdr[8] = 64 - ip := net.IPv4(8, 8, 8, 8).To4() - copy(hdr[12:16], ip) - copy(hdr[16:20], ip) + bin.PutUint32(hdr[12:16], uint32(src)) + bin.PutUint32(hdr[16:20], uint32(dst)) // ports - bin.PutUint16(hdr[20:22], 53) - bin.PutUint16(hdr[22:24], 53) + bin.PutUint16(hdr[20:22], sport) + bin.PutUint16(hdr[22:24], dport) switch proto { case ICMP: @@ -183,8 +282,15 @@ func rawpacket(proto packet.IPProto, len uint16) []byte { panic("unknown protocol") } - // Truncate the header if requested - hdr = hdr[:len] + // Trim the header if requested + hdr = hdr[:trimLength] return hdr } + +// rawdefault calls rawpacket with default ports and IPs. +func rawdefault(proto packet.IPProto, trimLength int) []byte { + ip := IP(0x08080808) // 8.8.8.8 + port := uint16(53) + return rawpacket(proto, ip, ip, port, port, trimLength) +}