diff --git a/net/packet/ip6.go b/net/packet/ip6.go index a441646b5..407a93216 100644 --- a/net/packet/ip6.go +++ b/net/packet/ip6.go @@ -12,19 +12,25 @@ import ( ) // IP6 is an IPv6 address. -type IP6 [16]byte // TODO: maybe 2x uint64 would be faster for the type of ops we do? +type IP6 struct { + Hi, Lo uint64 +} // IP6FromNetaddr converts a netaddr.IP to an IP6. Panics if !ip.Is6. func IP6FromNetaddr(ip netaddr.IP) IP6 { if !ip.Is6() { panic(fmt.Sprintf("IP6FromNetaddr called with non-v6 addr %q", ip)) } - return IP6(ip.As16()) + b := ip.As16() + return IP6{binary.BigEndian.Uint64(b[:8]), binary.BigEndian.Uint64(b[8:])} } // Netaddr converts ip to a netaddr.IP. func (ip IP6) Netaddr() netaddr.IP { - return netaddr.IPFrom16(ip) + var b [16]byte + binary.BigEndian.PutUint64(b[:8], ip.Hi) + binary.BigEndian.PutUint64(b[8:], ip.Lo) + return netaddr.IPFrom16(b) } func (ip IP6) String() string { @@ -32,11 +38,11 @@ func (ip IP6) String() string { } func (ip IP6) IsMulticast() bool { - return ip[0] == 0xFF + return (ip.Hi >> 56) == 0xFF } func (ip IP6) IsLinkLocalUnicast() bool { - return ip[0] == 0xFE && ip[1] == 0x80 + return (ip.Hi >> 48) == 0xFE80 } // ip6HeaderLength is the length of an IPv6 header with no IP options. @@ -69,8 +75,10 @@ func (h IP6Header) Marshal(buf []byte) error { binary.BigEndian.PutUint16(buf[4:6], uint16(len(buf)-ip6HeaderLength)) // Total length buf[6] = uint8(h.IPProto) // Inner protocol buf[7] = 64 // TTL - copy(buf[8:24], h.SrcIP[:]) - copy(buf[24:40], h.DstIP[:]) + binary.BigEndian.PutUint64(buf[8:16], h.SrcIP.Hi) + binary.BigEndian.PutUint64(buf[16:24], h.SrcIP.Lo) + binary.BigEndian.PutUint64(buf[24:32], h.DstIP.Hi) + binary.BigEndian.PutUint64(buf[32:40], h.DstIP.Lo) return nil } @@ -92,8 +100,10 @@ func (h IP6Header) marshalPseudo(buf []byte) error { return errLargePacket } - copy(buf[:16], h.SrcIP[:]) - copy(buf[16:32], h.DstIP[:]) + binary.BigEndian.PutUint64(buf[:8], h.SrcIP.Hi) + binary.BigEndian.PutUint64(buf[8:16], h.SrcIP.Lo) + binary.BigEndian.PutUint64(buf[16:24], h.DstIP.Hi) + binary.BigEndian.PutUint64(buf[24:32], h.DstIP.Lo) binary.BigEndian.PutUint32(buf[32:36], uint32(len(buf)-h.Len())) buf[36] = 0 buf[37] = 0 diff --git a/net/packet/packet.go b/net/packet/packet.go index 7b2c12111..2dc42c937 100644 --- a/net/packet/packet.go +++ b/net/packet/packet.go @@ -248,8 +248,10 @@ func (q *Parsed) decode6(b []byte) { return } - copy(q.SrcIP6[:], b[8:24]) - copy(q.DstIP6[:], b[24:40]) + q.SrcIP6.Hi = binary.BigEndian.Uint64(b[8:16]) + q.SrcIP6.Lo = binary.BigEndian.Uint64(b[16:24]) + q.DstIP6.Hi = binary.BigEndian.Uint64(b[24:32]) + q.DstIP6.Lo = binary.BigEndian.Uint64(b[32:40]) // We don't support any IPv6 extension headers. Don't try to // be clever. Therefore, the IP subprotocol always starts at diff --git a/wgengine/filter/match6.go b/wgengine/filter/match6.go index a5c182bf4..f00011bed 100644 --- a/wgengine/filter/match6.go +++ b/wgengine/filter/match6.go @@ -6,6 +6,7 @@ package filter import ( "fmt" + "math/bits" "strings" "inet.af/netaddr" @@ -14,16 +15,24 @@ import ( type net6 struct { ip packet.IP6 - bits uint8 + mask packet.IP6 } func net6FromIPPrefix(pfx netaddr.IPPrefix) net6 { if !pfx.IP.Is6() { panic("net6FromIPPrefix given non-ipv6 prefix") } + var mask packet.IP6 + if pfx.Bits > 64 { + mask.Hi = ^uint64(0) + mask.Lo = (^uint64(0) << (128 - pfx.Bits)) + } else { + mask.Hi = (^uint64(0) << (64 - pfx.Bits)) + } + return net6{ ip: packet.IP6FromNetaddr(pfx.IP), - bits: pfx.Bits, + mask: mask, } } @@ -37,33 +46,22 @@ func nets6FromIPPrefixes(pfxs []netaddr.IPPrefix) (ret []net6) { } func (n net6) Contains(ip packet.IP6) bool { - // Implementation stolen from inet.af/netaddr - bits := n.bits - for i := 0; bits > 0 && i < len(n.ip); i++ { - m := uint8(255) - if bits < 8 { - zeros := 8 - bits - m = m >> zeros << zeros - } - if n.ip[i]&m != ip[i]&m { - return false - } - if bits < 8 { - break - } - bits -= 8 - } - return true + return ((n.ip.Hi&n.mask.Hi) == (ip.Hi&n.mask.Hi) && + (n.ip.Lo&n.mask.Lo) == (ip.Lo&n.mask.Lo)) +} + +func (n net6) Bits() int { + return 128 - bits.TrailingZeros64(n.mask.Hi) - bits.TrailingZeros64(n.mask.Lo) } func (n net6) String() string { - switch n.bits { + switch n.Bits() { case 128: return n.ip.String() case 0: return "*" default: - return fmt.Sprintf("%s/%d", n.ip, n.bits) + return fmt.Sprintf("%s/%d", n.ip, n.Bits()) } }