diff --git a/net/packet/ip6.go b/net/packet/ip6.go index cdff94093..a441646b5 100644 --- a/net/packet/ip6.go +++ b/net/packet/ip6.go @@ -5,13 +5,14 @@ package packet import ( + "encoding/binary" "fmt" "inet.af/netaddr" ) // IP6 is an IPv6 address. -type IP6 [16]byte +type IP6 [16]byte // TODO: maybe 2x uint64 would be faster for the type of ops we do? // IP6FromNetaddr converts a netaddr.IP to an IP6. Panics if !ip.Is6. func IP6FromNetaddr(ip netaddr.IP) IP6 { @@ -30,5 +31,73 @@ func (ip IP6) String() string { return ip.Netaddr().String() } +func (ip IP6) IsMulticast() bool { + return ip[0] == 0xFF +} + +func (ip IP6) IsLinkLocalUnicast() bool { + return ip[0] == 0xFE && ip[1] == 0x80 +} + // ip6HeaderLength is the length of an IPv6 header with no IP options. const ip6HeaderLength = 40 + +// IP6Header represents an IPv6 packet header. +type IP6Header struct { + IPProto IPProto + IPID uint32 // only lower 20 bits used + SrcIP IP6 + DstIP IP6 +} + +// Len implements Header. +func (h IP6Header) Len() int { + return ip6HeaderLength +} + +// Marshal implements Header. +func (h IP6Header) Marshal(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + + binary.BigEndian.PutUint32(buf[:4], h.IPID&0x000FFFFF) + buf[0] = 0x60 + binary.BigEndian.PutUint16(buf[4:6], uint16(len(buf)-ip6HeaderLength)) // Total length + buf[6] = uint8(h.IPProto) // Inner protocol + buf[7] = 64 // TTL + copy(buf[8:24], h.SrcIP[:]) + copy(buf[24:40], h.DstIP[:]) + + return nil +} + +// ToResponse implements Header. +func (h *IP6Header) ToResponse() { + h.SrcIP, h.DstIP = h.DstIP, h.SrcIP + // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. + h.IPID = (^h.IPID) & 0x000FFFFF +} + +// marshalPseudo serializes h into buf in the "pseudo-header" form +// required when calculating UDP checksums. +func (h IP6Header) marshalPseudo(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + + copy(buf[:16], h.SrcIP[:]) + copy(buf[16:32], h.DstIP[:]) + binary.BigEndian.PutUint32(buf[32:36], uint32(len(buf)-h.Len())) + buf[36] = 0 + buf[37] = 0 + buf[38] = 0 + buf[39] = 17 // NextProto + return nil +} diff --git a/net/packet/packet.go b/net/packet/packet.go index 0aa5b7351..7b2c12111 100644 --- a/net/packet/packet.go +++ b/net/packet/packet.go @@ -186,6 +186,10 @@ func (q *Parsed) decode4(b []byte) { q.DstPort = 0 q.dataofs = q.subofs + icmp4HeaderLength return + case IGMP: + // Keep IPProto, but don't parse anything else + // out. + return case TCP: if len(sub) < tcpHeaderLength { q.IPProto = Unknown diff --git a/net/packet/packet_test.go b/net/packet/packet_test.go index 951b9605a..32685323c 100644 --- a/net/packet/packet_test.go +++ b/net/packet/packet_test.go @@ -307,6 +307,29 @@ var udp4ReplyDecode = Parsed{ DstPort: 123, } +var igmpPacketBuffer = []byte{ + // IP header up to checksum + 0x46, 0xc0, 0x00, 0x20, 0x00, 0x00, 0x40, 0x00, 0x01, 0x02, 0x41, 0x22, + // source IP + 0xc0, 0xa8, 0x01, 0x52, + // destination IP + 0xe0, 0x00, 0x00, 0xfb, + // IGMP Membership Report + 0x94, 0x04, 0x00, 0x00, 0x16, 0x00, 0x09, 0x04, 0xe0, 0x00, 0x00, 0xfb, + //0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +} + +var igmpPacketDecode = Parsed{ + b: igmpPacketBuffer, + subofs: 24, + length: len(igmpPacketBuffer), + + IPVersion: 4, + IPProto: IGMP, + SrcIP4: mustIP4("192.168.1.82"), + DstIP4: mustIP4("224.0.0.251"), +} + func TestParsed(t *testing.T) { tests := []struct { name string @@ -319,6 +342,7 @@ func TestParsed(t *testing.T) { {"udp6", udp6RequestDecode, "UDP{[2001:559:bc13:5400:1749:4628:3934:e1b]:54276 > [2607:f8b0:400a:809::200e]:443}"}, {"icmp4", icmp4RequestDecode, "ICMPv4{1.2.3.4:0 > 5.6.7.8:0}"}, {"icmp6", icmp6PacketDecode, "ICMPv6{[fe80::fb57:1dea:9c39:8fb7]:0 > [ff02::2]:0}"}, + {"igmp", igmpPacketDecode, "IGMP{192.168.1.82:0 > 224.0.0.251:0}"}, {"unknown", unknownPacketDecode, "Unknown{???}"}, } @@ -353,6 +377,7 @@ func TestDecode(t *testing.T) { {"tcp6", tcp6RequestBuffer, tcp6RequestDecode}, {"udp4", udp4RequestBuffer, udp4RequestDecode}, {"udp6", udp6RequestBuffer, udp6RequestDecode}, + {"igmp", igmpPacketBuffer, igmpPacketDecode}, {"unknown", unknownPacketBuffer, unknownPacketDecode}, {"invalid4", invalid4RequestBuffer, invalid4RequestDecode}, } @@ -387,6 +412,7 @@ func BenchmarkDecode(b *testing.B) { {"udp6", udp6RequestBuffer}, {"icmp4", icmp4RequestBuffer}, {"icmp6", icmp6PacketBuffer}, + {"igmp", igmpPacketBuffer}, {"unknown", unknownPacketBuffer}, } diff --git a/net/packet/udp6.go b/net/packet/udp6.go new file mode 100644 index 000000000..0450eae9e --- /dev/null +++ b/net/packet/udp6.go @@ -0,0 +1,51 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package packet + +import "encoding/binary" + +// UDP6Header is an IPv6+UDP header. +type UDP6Header struct { + IP6Header + SrcPort uint16 + DstPort uint16 +} + +// Len implements Header. +func (h UDP6Header) Len() int { + return h.IP6Header.Len() + udpHeaderLength +} + +// Marshal implements Header. +func (h UDP6Header) Marshal(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + // The caller does not need to set this. + h.IPProto = UDP + + length := len(buf) - h.IP6Header.Len() + binary.BigEndian.PutUint16(buf[40:42], h.SrcPort) + binary.BigEndian.PutUint16(buf[42:44], h.DstPort) + binary.BigEndian.PutUint16(buf[44:46], uint16(length)) + binary.BigEndian.PutUint16(buf[46:48], 0) // blank checksum + + // UDP checksum with IP pseudo header. + h.IP6Header.marshalPseudo(buf) + binary.BigEndian.PutUint16(buf[46:48], ip4Checksum(buf[:])) + + h.IP6Header.Marshal(buf) + + return nil +} + +// ToResponse implements Header. +func (h *UDP6Header) ToResponse() { + h.SrcPort, h.DstPort = h.DstPort, h.SrcPort + h.IP6Header.ToResponse() +} diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index 466158fb2..04fe7f70a 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -20,38 +20,50 @@ import ( // Filter is a stateful packet filter. type Filter struct { logf logger.Logf - // localNets is the list of IP prefixes that we know to be - // "local" to this node. All packets coming in over tailscale - // must have a destination within localNets, regardless of the - // policy filter below. A nil localNets rejects all incoming - // traffic. + // local4 and local6 are the lists of IP prefixes that we know + // to be "local" to this node. All packets coming in over + // tailscale must have a destination within local4 or local6, + // regardless of the policy filter below. Zero values reject + // all incoming traffic. local4 []net4 - // matches4 is a list of match->action rules applied to all - // packets arriving over tailscale tunnels. Matches are - // checked in order, and processing stops at the first - // matching rule. The default policy if no rules match is to - // drop the packet. + local6 []net6 + // matches4 and matches6 are lists of match->action rules + // applied to all packets arriving over tailscale + // tunnels. Matches are checked in order, and processing stops + // at the first matching rule. The default policy if no rules + // match is to drop the packet. matches4 matches4 + matches6 matches6 // state is the connection tracking state attached to this // filter. It is used to allow incoming traffic that is a response // to an outbound connection that this node made, even if those // incoming packets don't get accepted by matches above. - state *filterState + state4 *filterState + state6 *filterState } -// tuple is a 4-tuple of source and destination IPv4 and port. It's +// tuple4 is a 4-tuple of source and destination IPv4 and port. It's // used as a lookup key in filterState. -type tuple struct { +type tuple4 struct { SrcIP packet.IP4 DstIP packet.IP4 SrcPort uint16 DstPort uint16 } +// tuple6 is a 4-tuple of source and destination IPv6 and port. It's +// used as a lookup key in filterState. +type tuple6 struct { + SrcIP packet.IP6 + DstIP packet.IP6 + SrcPort uint16 + DstPort uint16 +} + // filterState is a state cache of past seen packets. type filterState struct { mu sync.Mutex - lru *lru.Cache // of tuple + lru *lru.Cache // of tuple4 or tuple6 } // lruMax is the size of the LRU cache in filterState. @@ -93,21 +105,36 @@ const ( // everything. Use in tests only, as it permits some kinds of spoofing // attacks to reach the OS network stack. func NewAllowAllForTest(logf logger.Logf) *Filter { - any4 := netaddr.IPPrefix{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0} // TODO: IPv6 - m := Match{ - Srcs: []netaddr.IPPrefix{any4}, - Dsts: []NetPortRange{ - { - Net: any4, - Ports: PortRange{ - First: 0, - Last: 65535, + any4 := netaddr.IPPrefix{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0} + any6 := netaddr.IPPrefix{IP: netaddr.IPFrom16([16]byte{}), Bits: 0} + ms := []Match{ + { + Srcs: []netaddr.IPPrefix{any4}, + Dsts: []NetPortRange{ + { + Net: any4, + Ports: PortRange{ + First: 0, + Last: 65535, + }, + }, + }, + }, + { + Srcs: []netaddr.IPPrefix{any6}, + Dsts: []NetPortRange{ + { + Net: any6, + Ports: PortRange{ + First: 0, + Last: 65535, + }, }, }, }, } - return New([]Match{m}, []netaddr.IPPrefix{any4}, nil, logf) + return New(ms, []netaddr.IPPrefix{any4, any6}, nil, logf) } // NewAllowNone returns a packet filter that rejects everything. @@ -121,19 +148,26 @@ func NewAllowNone(logf logger.Logf) *Filter { // shares state with the previous one, to enable changing rules at // runtime without breaking existing stateful flows. func New(matches []Match, localNets []netaddr.IPPrefix, shareStateWith *Filter, logf logger.Logf) *Filter { - var state *filterState + var state4, state6 *filterState if shareStateWith != nil { - state = shareStateWith.state + state4 = shareStateWith.state4 + state6 = shareStateWith.state6 } else { - state = &filterState{ + state4 = &filterState{ + lru: lru.New(lruMax), + } + state6 = &filterState{ lru: lru.New(lruMax), } } f := &Filter{ logf: logf, matches4: newMatches4(matches), + matches6: newMatches6(matches), local4: nets4FromIPPrefixes(localNets), - state: state, + local6: nets6FromIPPrefixes(localNets), + state4: state4, + state6: state6, } return f } @@ -188,11 +222,24 @@ var dummyPacket = []byte{ func (f *Filter) CheckTCP(srcIP, dstIP netaddr.IP, dstPort uint16) Response { pkt := &packet.Parsed{} pkt.Decode(dummyPacket) // initialize private fields - pkt.IPVersion = 4 + switch { + case (srcIP.Is4() && dstIP.Is6()) || (srcIP.Is6() && srcIP.Is4()): + // Mistmatched address families, no filters will + // match. + return Drop + case srcIP.Is4(): + pkt.IPVersion = 4 + pkt.SrcIP4 = packet.IP4FromNetaddr(srcIP) + pkt.DstIP4 = packet.IP4FromNetaddr(dstIP) + case srcIP.Is6(): + pkt.IPVersion = 6 + pkt.SrcIP6 = packet.IP6FromNetaddr(srcIP) + pkt.DstIP6 = packet.IP6FromNetaddr(dstIP) + default: + panic("unreachable") + } pkt.IPProto = packet.TCP pkt.TCPFlags = packet.TCPSyn - pkt.SrcIP4 = packet.IP4FromNetaddr(srcIP) // TODO: IPv6 - pkt.DstIP4 = packet.IP4FromNetaddr(dstIP) pkt.SrcPort = 0 pkt.DstPort = dstPort @@ -209,7 +256,15 @@ func (f *Filter) RunIn(q *packet.Parsed, rf RunFlags) Response { return r } - r, why := f.runIn(q) + var why string + switch q.IPVersion { + case 4: + r, why = f.runIn4(q) + case 6: + r, why = f.runIn6(q) + default: + r, why = Drop, "not-ip" + } f.logRateLimit(rf, q, dir, r, why) return r } @@ -228,8 +283,7 @@ func (f *Filter) RunOut(q *packet.Parsed, rf RunFlags) Response { return r } -// runIn runs the input-specific part of the filter logic. -func (f *Filter) runIn(q *packet.Parsed) (r Response, why string) { +func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) { // A compromised peer could try to send us packets for // destinations we didn't explicitly advertise. This check is to // prevent that. @@ -237,11 +291,6 @@ func (f *Filter) runIn(q *packet.Parsed) (r Response, why string) { return Drop, "destination not allowed" } - if q.IPVersion == 6 { - // TODO: support IPv6. - return Drop, "no rules matched" - } - switch q.IPProto { case packet.ICMPv4: if q.IsEchoResponse() || q.IsError() { @@ -271,11 +320,11 @@ func (f *Filter) runIn(q *packet.Parsed) (r Response, why string) { return Accept, "tcp ok" } case packet.UDP: - t := tuple{q.SrcIP4, q.DstIP4, q.SrcPort, q.DstPort} + t := tuple4{q.SrcIP4, q.DstIP4, q.SrcPort, q.DstPort} - f.state.mu.Lock() - _, ok := f.state.lru.Get(t) - f.state.mu.Unlock() + f.state4.mu.Lock() + _, ok := f.state4.lru.Get(t) + f.state4.mu.Unlock() if ok { return Accept, "udp cached" @@ -289,15 +338,80 @@ func (f *Filter) runIn(q *packet.Parsed) (r Response, why string) { return Drop, "no rules matched" } +func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) { + // A compromised peer could try to send us packets for + // destinations we didn't explicitly advertise. This check is to + // prevent that. + if !ip6InList(q.DstIP6, f.local6) { + return Drop, "destination not allowed" + } + + switch q.IPProto { + case packet.ICMPv6: + if q.IsEchoResponse() || q.IsError() { + // ICMP responses are allowed. + // TODO(apenwarr): consider using conntrack state. + // We could choose to reject all packets that aren't + // related to an existing ICMP-Echo, TCP, or UDP + // session. + return Accept, "icmp response ok" + } else if f.matches6.matchIPsOnly(q) { + // If any port is open to an IP, allow ICMP to it. + return Accept, "icmp ok" + } + case packet.TCP: + // For TCP, we want to allow *outgoing* connections, + // which means we want to allow return packets on those + // connections. To make this restriction work, we need to + // allow non-SYN packets (continuation of an existing session) + // to arrive. This should be okay since a new incoming session + // can't be initiated without first sending a SYN. + // It happens to also be much faster. + // TODO(apenwarr): Skip the rest of decoding in this path? + if q.IPProto == packet.TCP && !q.IsTCPSyn() { + return Accept, "tcp non-syn" + } + if f.matches6.match(q) { + return Accept, "tcp ok" + } + case packet.UDP: + t := tuple6{q.SrcIP6, q.DstIP6, q.SrcPort, q.DstPort} + + f.state6.mu.Lock() + _, ok := f.state6.lru.Get(t) + f.state6.mu.Unlock() + + if ok { + return Accept, "udp cached" + } + if f.matches6.match(q) { + return Accept, "udp ok" + } + default: + return Drop, "Unknown proto" + } + return Drop, "no rules matched" +} + // runIn runs the output-specific part of the filter logic. func (f *Filter) runOut(q *packet.Parsed) (r Response, why string) { - if q.IPProto == packet.UDP { - t := tuple{q.DstIP4, q.SrcIP4, q.DstPort, q.SrcPort} - var ti interface{} = t // allocate once, rather than twice inside mutex + if q.IPProto != packet.UDP { + return Accept, "ok out" + } - f.state.mu.Lock() - f.state.lru.Add(ti, ti) - f.state.mu.Unlock() + switch q.IPVersion { + case 4: + t := tuple4{q.DstIP4, q.SrcIP4, q.DstPort, q.SrcPort} + var ti interface{} = t // allocate once, rather than twice inside mutex + f.state4.mu.Lock() + f.state4.lru.Add(ti, ti) + f.state4.mu.Unlock() + case 6: + t := tuple6{q.DstIP6, q.SrcIP6, q.DstPort, q.SrcPort} + var ti interface{} = t // allocate once, rather than twice inside mutex + f.state6.mu.Lock() + f.state6.lru.Add(ti, ti) + f.state6.mu.Unlock() } return Accept, "ok out" } @@ -334,17 +448,25 @@ func (f *Filter) pre(q *packet.Parsed, rf RunFlags, dir direction) Response { return Drop } - if q.IPVersion == 6 { - f.logRateLimit(rf, q, dir, Drop, "ipv6") - return Drop - } - if q.DstIP4.IsMulticast() { - f.logRateLimit(rf, q, dir, Drop, "multicast") - return Drop - } - if q.DstIP4.IsLinkLocalUnicast() { - f.logRateLimit(rf, q, dir, Drop, "link-local-unicast") - return Drop + switch q.IPVersion { + case 4: + if q.DstIP4.IsMulticast() { + f.logRateLimit(rf, q, dir, Drop, "multicast") + return Drop + } + if q.DstIP4.IsLinkLocalUnicast() { + f.logRateLimit(rf, q, dir, Drop, "link-local-unicast") + return Drop + } + case 6: + if q.DstIP6.IsMulticast() { + f.logRateLimit(rf, q, dir, Drop, "multicast") + return Drop + } + if q.DstIP6.IsLinkLocalUnicast() { + f.logRateLimit(rf, q, dir, Drop, "link-local-unicast") + return Drop + } } switch q.IPProto { @@ -362,61 +484,21 @@ func (f *Filter) pre(q *packet.Parsed, rf RunFlags, dir direction) Response { return noVerdict } -const ( - // ipv6AllRoutersLinkLocal is ff02::2 (All link-local routers) - ipv6AllRoutersLinkLocal = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - // ipv6AllMLDv2CapableRouters is ff02::16 (All MLDv2-capable routers) - ipv6AllMLDv2CapableRouters = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x16" -) - // omitDropLogging reports whether packet p, which has already been // deemed a packet to Drop, should bypass the [rate-limited] logging. // We don't want to log scary & spammy reject warnings for packets // that are totally normal, like IPv6 route announcements. func omitDropLogging(p *packet.Parsed, dir direction) bool { - b := p.Buffer() - switch dir { - case out: - switch p.IPVersion { - case 4: - // Parsed.Decode zeros out Parsed.IPProtocol for protocols - // it doesn't know about, so parse it out ourselves if needed. - ipProto := p.IPProto - if ipProto == 0 && len(b) > 8 { - ipProto = packet.IPProto(b[9]) - } - // Omit logging about outgoing IGMP. - if ipProto == packet.IGMP { - return true - } - if p.DstIP4.IsMulticast() || p.DstIP4.IsLinkLocalUnicast() { - return true - } - case 6: - if len(b) < 40 { - return false - } - src, dst := b[8:8+16], b[24:24+16] - // Omit logging for outgoing IPv6 ICMP-v6 queries to ff02::2, - // as sent by the OS, looking for routers. - if p.IPProto == packet.ICMPv6 { - if isLinkLocalV6(src) && string(dst) == ipv6AllRoutersLinkLocal { - return true - } - } - if string(dst) == ipv6AllMLDv2CapableRouters { - return true - } - // Actually, just catch all multicast. - if dst[0] == 0xff { - return true - } - } + if dir != out { + return false } - return false -} -// isLinkLocalV6 reports whether src is in fe80::/10. -func isLinkLocalV6(src []byte) bool { - return len(src) == 16 && src[0] == 0xfe && src[1]>>6 == 0x80>>6 + switch p.IPVersion { + case 4: + return p.DstIP4.IsMulticast() || p.DstIP4.IsLinkLocalUnicast() || p.IPProto == packet.IGMP + case 6: + return p.DstIP6.IsMulticast() || p.DstIP6.IsLinkLocalUnicast() + default: + return false + } } diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index 7466cdef5..aea7c9978 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -5,9 +5,7 @@ package filter import ( - "encoding/binary" "encoding/hex" - "encoding/json" "fmt" "strconv" "strings" @@ -18,189 +16,155 @@ import ( "tailscale.com/types/logger" ) -var Unknown = packet.Unknown -var ICMPv4 = packet.ICMPv4 -var TCP = packet.TCP -var UDP = packet.UDP -var Fragment = packet.Fragment - -func mustIP4(s string) packet.IP4 { - ip, err := netaddr.ParseIP(s) - if err != nil { - panic(err) - } - return packet.IP4FromNetaddr(ip) -} - -func pfx(s string) netaddr.IPPrefix { - pfx, err := netaddr.ParseIPPrefix(s) - if err != nil { - panic(err) - } - return pfx -} - -func nets(nets ...string) (ret []netaddr.IPPrefix) { - for _, s := range nets { - if i := strings.IndexByte(s, '/'); i == -1 { - ip, err := netaddr.ParseIP(s) - if err != nil { - panic(err) - } - bits := uint8(32) - if ip.Is6() { - bits = 128 - } - ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits}) - } else { - pfx, err := netaddr.ParseIPPrefix(s) - if err != nil { - panic(err) - } - ret = append(ret, pfx) - } - } - return ret -} - -func ports(s string) PortRange { - if s == "*" { - return PortRange{First: 0, Last: 65535} - } - - var fs, ls string - i := strings.IndexByte(s, '-') - if i == -1 { - fs = s - ls = fs - } else { - fs = s[:i] - ls = s[i+1:] - } - first, err := strconv.ParseInt(fs, 10, 16) - if err != nil { - panic(fmt.Sprintf("invalid NetPortRange %q", s)) - } - last, err := strconv.ParseInt(ls, 10, 16) - if err != nil { - panic(fmt.Sprintf("invalid NetPortRange %q", s)) - } - return PortRange{uint16(first), uint16(last)} -} - -func netports(netPorts ...string) (ret []NetPortRange) { - for _, s := range netPorts { - i := strings.LastIndexByte(s, ':') - if i == -1 { - panic(fmt.Sprintf("invalid NetPortRange %q", s)) - } - - npr := NetPortRange{ - Net: nets(s[:i])[0], - Ports: ports(s[i+1:]), - } - ret = append(ret, npr) +func newFilter(logf logger.Logf) *Filter { + matches := []Match{ + {Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("1.2.3.4:22", "5.6.7.8:23-24")}, + {Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("5.6.7.8:27-28")}, + {Srcs: nets("2.2.2.2"), Dsts: netports("8.1.1.1:22")}, + {Srcs: nets("0.0.0.0/0"), Dsts: netports("100.122.98.50:*")}, + {Srcs: nets("0.0.0.0/0"), Dsts: netports("0.0.0.0/0:443")}, + {Srcs: nets("153.1.1.1", "153.1.1.2", "153.3.3.3"), Dsts: netports("1.2.3.4:999")}, + {Srcs: nets("::1", "::2"), Dsts: netports("2001::1:22")}, + {Srcs: nets("::/0"), Dsts: netports("::/0:443")}, } - return ret -} -var matches = []Match{ - {Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("1.2.3.4:22", "5.6.7.8:23-24")}, - {Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("5.6.7.8:27-28")}, - {Srcs: nets("2.2.2.2"), Dsts: netports("8.1.1.1:22")}, - {Srcs: nets("0.0.0.0/0"), Dsts: netports("100.122.98.50:*")}, - {Srcs: nets("0.0.0.0/0"), Dsts: netports("0.0.0.0/0:443")}, - {Srcs: nets("153.1.1.1", "153.1.1.2", "153.3.3.3"), Dsts: netports("1.2.3.4:999")}, -} - -func newFilter(logf logger.Logf) *Filter { // Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8, // 102.102.102.102, 119.119.119.119, 8.1.0.0/16 - localNets := nets("100.122.98.50", "1.2.3.4", "5.6.7.8", "102.102.102.102", "119.119.119.119", "8.1.0.0/16") + localNets := nets("100.122.98.50", "1.2.3.4", "5.6.7.8", "102.102.102.102", "119.119.119.119", "8.1.0.0/16", "2001::/16") return New(matches, localNets, nil, logf) } -func TestMarshal(t *testing.T) { - for _, ent := range [][]Match{[]Match{matches[0]}, matches} { - b, err := json.Marshal(ent) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - mm2 := []Match{} - if err := json.Unmarshal(b, &mm2); err != nil { - t.Fatalf("unmarshal: %v (%v)", err, string(b)) - } - } -} - func TestFilter(t *testing.T) { acl := newFilter(t.Logf) - // check packet filtering based on the table type InOut struct { want Response p packet.Parsed } tests := []InOut{ - // Basic - {Accept, parsed(TCP, 0x08010101, 0x01020304, 999, 22)}, - {Accept, parsed(UDP, 0x08010101, 0x01020304, 999, 22)}, - {Accept, parsed(ICMPv4, 0x08010101, 0x01020304, 0, 0)}, - {Drop, parsed(TCP, 0x08010101, 0x01020304, 0, 0)}, - {Accept, parsed(TCP, 0x08010101, 0x01020304, 0, 22)}, - {Drop, parsed(TCP, 0x08010101, 0x01020304, 0, 21)}, - {Accept, parsed(TCP, 0x11223344, 0x08012233, 0, 443)}, - {Drop, parsed(TCP, 0x11223344, 0x08012233, 0, 444)}, - {Accept, parsed(TCP, 0x11223344, 0x647a6232, 0, 999)}, - {Accept, parsed(TCP, 0x11223344, 0x647a6232, 0, 0)}, + // allow 8.1.1.1 => 1.2.3.4:22 + {Accept, parsed(packet.TCP, "8.1.1.1", "1.2.3.4", 999, 22)}, + {Accept, parsed(packet.ICMPv4, "8.1.1.1", "1.2.3.4", 0, 0)}, + {Drop, parsed(packet.TCP, "8.1.1.1", "1.2.3.4", 0, 0)}, + {Accept, parsed(packet.TCP, "8.1.1.1", "1.2.3.4", 0, 22)}, + {Drop, parsed(packet.TCP, "8.1.1.1", "1.2.3.4", 0, 21)}, + // allow 8.2.2.2. => 1.2.3.4:22 + {Accept, parsed(packet.TCP, "8.2.2.2", "1.2.3.4", 0, 22)}, + {Drop, parsed(packet.TCP, "8.2.2.2", "1.2.3.4", 0, 23)}, + {Drop, parsed(packet.TCP, "8.3.3.3", "1.2.3.4", 0, 22)}, + // allow * => *:443 + {Accept, parsed(packet.TCP, "17.34.51.68", "8.1.34.51", 0, 443)}, + {Drop, parsed(packet.TCP, "17.34.51.68", "8.1.34.51", 0, 444)}, + // allow * => 100.122.98.50:* + {Accept, parsed(packet.TCP, "17.34.51.68", "100.122.98.50", 0, 999)}, + {Accept, parsed(packet.TCP, "17.34.51.68", "100.122.98.50", 0, 0)}, + + // allow ::1, ::2 => [2001::1]:22 + {Accept, parsed(packet.TCP, "::1", "2001::1", 0, 22)}, + {Accept, parsed(packet.ICMPv6, "::1", "2001::1", 0, 0)}, + {Accept, parsed(packet.TCP, "::2", "2001::1", 0, 22)}, + {Drop, parsed(packet.TCP, "::1", "2001::1", 0, 23)}, + {Drop, parsed(packet.TCP, "::1", "2001::2", 0, 22)}, + {Drop, parsed(packet.TCP, "::3", "2001::1", 0, 22)}, + // allow * => *:443 + {Accept, parsed(packet.TCP, "::1", "2001::1", 0, 443)}, + {Drop, parsed(packet.TCP, "::1", "2001::1", 0, 444)}, // localNets prefilter - accepted by policy filter, but // unexpected dst IP. - {Drop, parsed(TCP, 0x08010101, 0x10203040, 0, 443)}, - - // Stateful UDP. Note each packet is run through the input - // filter, then the output filter (which sets conntrack - // state). - // Initially empty cache - {Drop, parsed(UDP, 0x77777777, 0x66666666, 4242, 4343)}, - // Return packet from previous attempt is allowed - {Accept, parsed(UDP, 0x66666666, 0x77777777, 4343, 4242)}, - // Because of the return above, initial attempt is allowed now - {Accept, parsed(UDP, 0x77777777, 0x66666666, 4242, 4343)}, + {Drop, parsed(packet.TCP, "8.1.1.1", "16.32.48.64", 0, 443)}, + {Drop, parsed(packet.TCP, "1::", "2602::1", 0, 443)}, } for i, test := range tests { - if got, _ := acl.runIn(&test.p); test.want != got { - t.Errorf("#%d runIn got=%v want=%v packet:%v", i, got, test.want, test.p) + aclFunc := acl.runIn4 + if test.p.IPVersion == 6 { + aclFunc = acl.runIn6 } - if test.p.IPProto == TCP { - if got := acl.CheckTCP(test.p.SrcIP4.Netaddr(), test.p.DstIP4.Netaddr(), test.p.DstPort); test.want != got { + if got, why := aclFunc(&test.p); test.want != got { + t.Errorf("#%d runIn4 got=%v want=%v why=%q packet:%v", i, got, test.want, why, test.p) + } + if test.p.IPProto == packet.TCP { + var got Response + if test.p.IPVersion == 4 { + got = acl.CheckTCP(test.p.SrcIP4.Netaddr(), test.p.DstIP4.Netaddr(), test.p.DstPort) + } else { + got = acl.CheckTCP(test.p.SrcIP6.Netaddr(), test.p.DstIP6.Netaddr(), test.p.DstPort) + } + if test.want != got { t.Errorf("#%d CheckTCP got=%v want=%v packet:%v", i, got, test.want, test.p) } + // TCP and UDP are treated equivalently in the filter - verify that. + test.p.IPProto = packet.UDP + if got, why := aclFunc(&test.p); test.want != got { + t.Errorf("#%d runIn4 (UDP) got=%v want=%v why=%q packet:%v", i, got, test.want, why, test.p) + } } // Update UDP state _, _ = acl.runOut(&test.p) } } +func TestUDPState(t *testing.T) { + acl := newFilter(t.Logf) + flags := LogDrops | LogAccepts + + a4 := parsed(packet.UDP, "119.119.119.119", "102.102.102.102", 4242, 4343) + b4 := parsed(packet.UDP, "102.102.102.102", "119.119.119.119", 4343, 4242) + + // Unsollicited UDP traffic gets dropped + if got := acl.RunIn(&a4, flags); got != Drop { + t.Fatalf("incoming initial packet not dropped, got=%v: %v", got, a4) + } + // We talk to that peer + if got := acl.RunOut(&b4, flags); got != Accept { + t.Fatalf("outbound packet didn't egress, got=%v: %v", got, b4) + } + // Now, the same packet as before is allowed back. + if got := acl.RunIn(&a4, flags); got != Accept { + t.Fatalf("incoming response packet not accepted, got=%v: %v", got, a4) + } + + a6 := parsed(packet.UDP, "2001::2", "2001::1", 4242, 4343) + b6 := parsed(packet.UDP, "2001::1", "2001::2", 4343, 4242) + + // Unsollicited UDP traffic gets dropped + if got := acl.RunIn(&a6, flags); got != Drop { + t.Fatalf("incoming initial packet not dropped: %v", a4) + } + // We talk to that peer + if got := acl.RunOut(&b6, flags); got != Accept { + t.Fatalf("outbound packet didn't egress: %v", b4) + } + // Now, the same packet as before is allowed back. + if got := acl.RunIn(&a6, flags); got != Accept { + t.Fatalf("incoming response packet not accepted: %v", a4) + } +} + func TestNoAllocs(t *testing.T) { acl := newFilter(t.Logf) - tcpPacket := rawpacket(TCP, 0x08010101, 0x01020304, 999, 22, 0) - udpPacket := rawpacket(UDP, 0x08010101, 0x01020304, 999, 22, 0) + tcp4Packet := raw4(packet.TCP, "8.1.1.1", "1.2.3.4", 999, 22, 0) + udp4Packet := raw4(packet.UDP, "8.1.1.1", "1.2.3.4", 999, 22, 0) + tcp6Packet := raw6(packet.TCP, "2001::1", "2001::2", 999, 22, 0) + udp6Packet := raw6(packet.UDP, "2001::1", "2001::2", 999, 22, 0) tests := []struct { name string - in bool + dir direction want int packet []byte }{ - {"tcp_in", true, 0, tcpPacket}, - {"tcp_out", false, 0, tcpPacket}, - {"udp_in", true, 0, udpPacket}, + {"tcp4_in", in, 0, tcp4Packet}, + {"tcp6_in", in, 0, tcp6Packet}, + {"tcp4_out", out, 0, tcp4Packet}, + {"tcp6_out", out, 0, tcp6Packet}, + {"udp4_in", in, 0, udp4Packet}, + {"udp6_in", in, 0, udp6Packet}, // One alloc is inevitable (an lru cache update) - {"udp_out", false, 1, udpPacket}, + {"udp4_out", out, 1, udp4Packet}, + {"udp6_out", out, 1, udp6Packet}, } for _, test := range tests { @@ -208,9 +172,10 @@ func TestNoAllocs(t *testing.T) { got := int(testing.AllocsPerRun(1000, func() { q := &packet.Parsed{} q.Decode(test.packet) - if test.in { + switch test.dir { + case in: acl.RunIn(q, 0) - } else { + case out: acl.RunOut(q, 0) } })) @@ -231,11 +196,13 @@ func TestParseIP(t *testing.T) { wantErr string }{ {"8.8.8.8", 24, pfx("8.8.8.8/24"), ""}, + {"2601:1234::", 64, pfx("2601:1234::/64"), ""}, {"8.8.8.8", 33, noaddr, `invalid CIDR size 33 for host "8.8.8.8"`}, {"8.8.8.8", -1, noaddr, `invalid CIDR size -1 for host "8.8.8.8"`}, + {"2601:1234::", 129, noaddr, `invalid CIDR size 129 for host "2601:1234::"`}, {"0.0.0.0", 24, noaddr, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`}, + {"::", 64, noaddr, `ports="::": to allow all IP addresses, use *:port, not [::]:port`}, {"*", 24, pfx("0.0.0.0/0"), ""}, - {"fe80::1", 128, pfx("255.255.255.255/32"), `ports="fe80::1": invalid IPv4 address`}, } for _, tt := range tests { got, err := parseIP(tt.host, tt.bits) @@ -253,38 +220,42 @@ func TestParseIP(t *testing.T) { } func BenchmarkFilter(b *testing.B) { - acl := newFilter(b.Logf) + tcp4Packet := raw4(packet.TCP, "8.1.1.1", "1.2.3.4", 999, 22, 0) + udp4Packet := raw4(packet.UDP, "8.1.1.1", "1.2.3.4", 999, 22, 0) + icmp4Packet := raw4(packet.ICMPv4, "8.1.1.1", "1.2.3.4", 0, 0, 0) - tcpPacket := rawpacket(TCP, 0x08010101, 0x01020304, 999, 22, 0) - udpPacket := rawpacket(UDP, 0x08010101, 0x01020304, 999, 22, 0) - icmpPacket := rawpacket(ICMPv4, 0x08010101, 0x01020304, 0, 0, 0) - - tcpSynPacket := rawpacket(TCP, 0x08010101, 0x01020304, 999, 22, 0) - // TCP filtering is trivial (Accept) for non-SYN packets. - tcpSynPacket[33] = packet.TCPSyn + tcp6Packet := raw6(packet.TCP, "::1", "2001::1", 999, 22, 0) + udp6Packet := raw6(packet.UDP, "::1", "2001::1", 999, 22, 0) + icmp6Packet := raw6(packet.ICMPv6, "::1", "2001::1", 0, 0, 0) benches := []struct { name string - in bool + dir direction packet []byte }{ // Non-SYN TCP and ICMP have similar code paths in and out. - {"icmp", true, icmpPacket}, - {"tcp", true, tcpPacket}, - {"tcp_syn_in", true, tcpSynPacket}, - {"tcp_syn_out", false, tcpSynPacket}, - {"udp_in", true, udpPacket}, - {"udp_out", false, udpPacket}, + {"icmp4", in, icmp4Packet}, + {"tcp4_syn_in", in, tcp4Packet}, + {"tcp4_syn_out", out, tcp4Packet}, + {"udp4_in", in, udp4Packet}, + {"udp4_out", out, udp4Packet}, + {"icmp6", in, icmp6Packet}, + {"tcp6_syn_in", in, tcp6Packet}, + {"tcp6_syn_out", out, tcp6Packet}, + {"udp6_in", in, udp6Packet}, + {"udp6_out", out, udp6Packet}, } for _, bench := range benches { b.Run(bench.name, func(b *testing.B) { + acl := newFilter(b.Logf) b.ReportAllocs() + b.ResetTimer() for i := 0; i < b.N; i++ { q := &packet.Parsed{} q.Decode(bench.packet) // This branch seems to have no measurable impact on performance. - if bench.in { + if bench.dir == in { acl.RunIn(q, 0) } else { acl.RunOut(q, 0) @@ -302,11 +273,11 @@ func TestPreFilter(t *testing.T) { }{ {"empty", Accept, []byte{}}, {"short", Drop, []byte("short")}, - {"junk", Drop, rawdefault(Unknown, 10)}, - {"fragment", Accept, rawdefault(Fragment, 40)}, - {"tcp", noVerdict, rawdefault(TCP, 200)}, - {"udp", noVerdict, rawdefault(UDP, 200)}, - {"icmp", noVerdict, rawdefault(ICMPv4, 200)}, + {"junk", Drop, raw4default(packet.Unknown, 10)}, + {"fragment", Accept, raw4default(packet.Fragment, 40)}, + {"tcp", noVerdict, raw4default(packet.TCP, 0)}, + {"udp", noVerdict, raw4default(packet.UDP, 0)}, + {"icmp", noVerdict, raw4default(packet.ICMPv4, 0)}, } f := NewAllowNone(t.Logf) for _, testPacket := range packets { @@ -319,90 +290,6 @@ func TestPreFilter(t *testing.T) { } } -func parsed(proto packet.IPProto, src, dst packet.IP4, sport, dport uint16) packet.Parsed { - return packet.Parsed{ - IPProto: proto, - SrcIP4: src, - DstIP4: dst, - SrcPort: sport, - DstPort: dport, - TCPFlags: packet.TCPSyn, - } -} - -// rawpacket generates a packet with given source and destination ports and IPs -// and resizes the header to trimLength if it is nonzero. -func rawpacket(proto packet.IPProto, src, dst packet.IP4, sport, dport uint16, trimLength int) []byte { - var headerLength int - - switch proto { - case ICMPv4: - headerLength = 24 - case TCP: - headerLength = 40 - case UDP: - headerLength = 28 - default: - headerLength = 24 - } - if trimLength > headerLength { - headerLength = trimLength - } - if trimLength == 0 { - trimLength = headerLength - } - - bin := binary.BigEndian - hdr := make([]byte, headerLength) - hdr[0] = 0x45 - bin.PutUint16(hdr[2:4], uint16(trimLength)) - hdr[8] = 64 - bin.PutUint32(hdr[12:16], uint32(src)) - bin.PutUint32(hdr[16:20], uint32(dst)) - // ports - bin.PutUint16(hdr[20:22], sport) - bin.PutUint16(hdr[22:24], dport) - - switch proto { - case ICMPv4: - hdr[9] = 1 - case TCP: - hdr[9] = 6 - case UDP: - hdr[9] = 17 - case Fragment: - hdr[9] = 6 - // flags + fragOff - bin.PutUint16(hdr[6:8], (1<<13)|1234) - case Unknown: - default: - panic("unknown protocol") - } - - // Trim the header if requested - hdr = hdr[:trimLength] - - return hdr -} - -// rawdefault calls rawpacket with default ports and IPs. -func rawdefault(proto packet.IPProto, trimLength int) []byte { - ip := packet.IP4(0x08080808) // 8.8.8.8 - port := uint16(53) - return rawpacket(proto, ip, ip, port, port, trimLength) -} - -func parseHexPkt(t *testing.T, h string) *packet.Parsed { - t.Helper() - b, err := hex.DecodeString(strings.ReplaceAll(h, " ", "")) - if err != nil { - t.Fatalf("failed to read hex %q: %v", h, err) - } - p := new(packet.Parsed) - p.Decode(b) - return p -} - func TestOmitDropLogging(t *testing.T) { tests := []struct { name string @@ -469,3 +356,198 @@ func TestOmitDropLogging(t *testing.T) { }) } } + +func mustIP(s string) netaddr.IP { + ip, err := netaddr.ParseIP(s) + if err != nil { + panic(err) + } + return ip +} + +func parsed(proto packet.IPProto, src, dst string, sport, dport uint16) packet.Parsed { + sip, dip := mustIP(src), mustIP(dst) + + var ret packet.Parsed + ret.Decode(dummyPacket) + ret.IPProto = proto + ret.SrcPort = sport + ret.DstPort = dport + ret.TCPFlags = packet.TCPSyn + + if sip.Is4() { + ret.IPVersion = 4 + ret.SrcIP4 = packet.IP4FromNetaddr(sip) + ret.DstIP4 = packet.IP4FromNetaddr(dip) + } else { + ret.IPVersion = 6 + ret.SrcIP6 = packet.IP6FromNetaddr(sip) + ret.DstIP6 = packet.IP6FromNetaddr(dip) + } + + return ret +} + +func raw6(proto packet.IPProto, src, dst string, sport, dport uint16, trimLen int) []byte { + u := packet.UDP6Header{ + IP6Header: packet.IP6Header{ + SrcIP: packet.IP6FromNetaddr(mustIP(src)), + DstIP: packet.IP6FromNetaddr(mustIP(dst)), + }, + SrcPort: sport, + DstPort: dport, + } + + payload := make([]byte, 12) + // Set the right bit to look like a TCP SYN, if the packet ends up interpreted as TCP + payload[5] = packet.TCPSyn + + b := packet.Generate(&u, payload) // payload large enough to possibly be TCP + + // UDP marshaling clobbers IPProto, so override it here. + u.IP6Header.IPProto = proto + if err := u.IP6Header.Marshal(b); err != nil { + panic(err) + } + + if trimLen > 0 { + return b[:trimLen] + } else { + return b + } +} + +func raw4(proto packet.IPProto, src, dst string, sport, dport uint16, trimLength int) []byte { + u := packet.UDP4Header{ + IP4Header: packet.IP4Header{ + SrcIP: packet.IP4FromNetaddr(mustIP(src)), + DstIP: packet.IP4FromNetaddr(mustIP(dst)), + }, + SrcPort: sport, + DstPort: dport, + } + + payload := make([]byte, 12) + // Set the right bit to look like a TCP SYN, if the packet ends up interpreted as TCP + payload[5] = packet.TCPSyn + + b := packet.Generate(&u, payload) // payload large enough to possibly be TCP + + // UDP marshaling clobbers IPProto, so override it here. + switch proto { + case packet.Unknown, packet.Fragment: + default: + u.IP4Header.IPProto = proto + } + if err := u.IP4Header.Marshal(b); err != nil { + panic(err) + } + + if proto == packet.Fragment { + // Set some fragment offset. This makes the IP + // checksum wrong, but we don't validate the checksum + // when parsing. + b[7] = 255 + } + + if trimLength > 0 { + return b[:trimLength] + } else { + return b + } +} + +func raw4default(proto packet.IPProto, trimLength int) []byte { + return raw4(proto, "8.8.8.8", "8.8.8.8", 53, 53, trimLength) +} + +func parseHexPkt(t *testing.T, h string) *packet.Parsed { + t.Helper() + b, err := hex.DecodeString(strings.ReplaceAll(h, " ", "")) + if err != nil { + t.Fatalf("failed to read hex %q: %v", h, err) + } + p := new(packet.Parsed) + p.Decode(b) + return p +} + +func mustIP4(s string) packet.IP4 { + ip, err := netaddr.ParseIP(s) + if err != nil { + panic(err) + } + return packet.IP4FromNetaddr(ip) +} + +func pfx(s string) netaddr.IPPrefix { + pfx, err := netaddr.ParseIPPrefix(s) + if err != nil { + panic(err) + } + return pfx +} + +func nets(nets ...string) (ret []netaddr.IPPrefix) { + for _, s := range nets { + if i := strings.IndexByte(s, '/'); i == -1 { + ip, err := netaddr.ParseIP(s) + if err != nil { + panic(err) + } + bits := uint8(32) + if ip.Is6() { + bits = 128 + } + ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits}) + } else { + pfx, err := netaddr.ParseIPPrefix(s) + if err != nil { + panic(err) + } + ret = append(ret, pfx) + } + } + return ret +} + +func ports(s string) PortRange { + if s == "*" { + return PortRange{First: 0, Last: 65535} + } + + var fs, ls string + i := strings.IndexByte(s, '-') + if i == -1 { + fs = s + ls = fs + } else { + fs = s[:i] + ls = s[i+1:] + } + first, err := strconv.ParseInt(fs, 10, 16) + if err != nil { + panic(fmt.Sprintf("invalid NetPortRange %q", s)) + } + last, err := strconv.ParseInt(ls, 10, 16) + if err != nil { + panic(fmt.Sprintf("invalid NetPortRange %q", s)) + } + return PortRange{uint16(first), uint16(last)} +} + +func netports(netPorts ...string) (ret []NetPortRange) { + for _, s := range netPorts { + i := strings.LastIndexByte(s, ':') + if i == -1 { + panic(fmt.Sprintf("invalid NetPortRange %q", s)) + } + + npr := NetPortRange{ + Net: nets(s[:i])[0], + Ports: ports(s[i+1:]), + } + ret = append(ret, npr) + } + return ret +} diff --git a/wgengine/filter/match4.go b/wgengine/filter/match4.go index d239b497f..6e301ae80 100644 --- a/wgengine/filter/match4.go +++ b/wgengine/filter/match4.go @@ -66,8 +66,8 @@ func (npr npr4) String() string { } type match4 struct { - dsts []npr4 srcs []net4 + dsts []npr4 } type matches4 []match4 diff --git a/wgengine/filter/match6.go b/wgengine/filter/match6.go new file mode 100644 index 000000000..a5c182bf4 --- /dev/null +++ b/wgengine/filter/match6.go @@ -0,0 +1,153 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package filter + +import ( + "fmt" + "strings" + + "inet.af/netaddr" + "tailscale.com/net/packet" +) + +type net6 struct { + ip packet.IP6 + bits uint8 +} + +func net6FromIPPrefix(pfx netaddr.IPPrefix) net6 { + if !pfx.IP.Is6() { + panic("net6FromIPPrefix given non-ipv6 prefix") + } + return net6{ + ip: packet.IP6FromNetaddr(pfx.IP), + bits: pfx.Bits, + } +} + +func nets6FromIPPrefixes(pfxs []netaddr.IPPrefix) (ret []net6) { + for _, pfx := range pfxs { + if pfx.IP.Is6() { + ret = append(ret, net6FromIPPrefix(pfx)) + } + } + return ret +} + +func (n net6) Contains(ip packet.IP6) bool { + // Implementation stolen from inet.af/netaddr + bits := n.bits + for i := 0; bits > 0 && i < len(n.ip); i++ { + m := uint8(255) + if bits < 8 { + zeros := 8 - bits + m = m >> zeros << zeros + } + if n.ip[i]&m != ip[i]&m { + return false + } + if bits < 8 { + break + } + bits -= 8 + } + return true +} + +func (n net6) String() string { + switch n.bits { + case 128: + return n.ip.String() + case 0: + return "*" + default: + return fmt.Sprintf("%s/%d", n.ip, n.bits) + } +} + +type npr6 struct { + net net6 + ports PortRange +} + +func (npr npr6) String() string { + return fmt.Sprintf("%s:%s", npr.net, npr.ports) +} + +type match6 struct { + srcs []net6 + dsts []npr6 +} + +type matches6 []match6 + +func (ms matches6) String() string { + var b strings.Builder + for _, m := range ms { + fmt.Fprintf(&b, "%s => %s\n", m.srcs, m.dsts) + } + return b.String() +} + +func newMatches6(ms []Match) (ret matches6) { + for _, m := range ms { + var m6 match6 + for _, src := range m.Srcs { + if src.IP.Is6() { + m6.srcs = append(m6.srcs, net6FromIPPrefix(src)) + } + } + for _, dst := range m.Dsts { + if dst.Net.IP.Is6() { + m6.dsts = append(m6.dsts, npr6{net6FromIPPrefix(dst.Net), dst.Ports}) + } + } + if len(m6.srcs) > 0 && len(m6.dsts) > 0 { + ret = append(ret, m6) + } + } + return ret +} + +func (ms matches6) match(q *packet.Parsed) bool { + for _, m := range ms { + if !ip6InList(q.SrcIP6, m.srcs) { + continue + } + for _, dst := range m.dsts { + if !dst.net.Contains(q.DstIP6) { + continue + } + if !dst.ports.contains(q.DstPort) { + continue + } + return true + } + } + return false +} + +func (ms matches6) matchIPsOnly(q *packet.Parsed) bool { + for _, m := range ms { + if !ip6InList(q.SrcIP6, m.srcs) { + continue + } + for _, dst := range m.dsts { + if dst.net.Contains(q.DstIP6) { + return true + } + } + } + return false +} + +func ip6InList(ip packet.IP6, netlist []net6) bool { + for _, net := range netlist { + if net.Contains(ip) { + return true + } + } + return false +} diff --git a/wgengine/filter/tailcfg.go b/wgengine/filter/tailcfg.go index 02261cd8a..c498e0936 100644 --- a/wgengine/filter/tailcfg.go +++ b/wgengine/filter/tailcfg.go @@ -58,6 +58,11 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) { return mm, erracc } +var ( + zeroIP4 = netaddr.IPv4(0, 0, 0, 0) + zeroIP6 = netaddr.IPFrom16([16]byte{}) +) + func parseIP(host string, defaultBits int) (netaddr.IPPrefix, error) { if host == "*" { // User explicitly requested wildcard dst ip. @@ -69,15 +74,16 @@ func parseIP(host string, defaultBits int) (netaddr.IPPrefix, error) { if err != nil { return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IP address", host) } - if ip == netaddr.IPv4(0, 0, 0, 0) { + if ip == zeroIP4 { // For clarity, reject 0.0.0.0 as an input return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host) } - if !ip.Is4() { - // TODO: ipv6 - return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IPv4 address", host) + if ip == zeroIP6 { + // For clarity, reject :: as an input + return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not [::]:port", host) } - if defaultBits < 0 || defaultBits > 32 { + + if defaultBits < 0 || (ip.Is4() && defaultBits > 32) || (ip.Is6() && defaultBits > 128) { return netaddr.IPPrefix{}, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host) } return netaddr.IPPrefix{