diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index b5ed82a54..082d8a0f5 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -300,9 +300,9 @@ var dummyPacket = []byte{ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, } -// CheckTCP determines whether TCP traffic from srcIP to dstIP:dstPort -// is allowed. -func (f *Filter) CheckTCP(srcIP, dstIP netip.Addr, dstPort uint16) Response { +// Check determines whether traffic from srcIP to dstIP:dstPort is allowed +// using protocol proto. +func (f *Filter) Check(srcIP, dstIP netip.Addr, dstPort uint16, proto ipproto.Proto) Response { pkt := &packet.Parsed{} pkt.Decode(dummyPacket) // initialize private fields switch { @@ -319,12 +319,20 @@ func (f *Filter) CheckTCP(srcIP, dstIP netip.Addr, dstPort uint16) Response { } pkt.Src = netip.AddrPortFrom(srcIP, 0) pkt.Dst = netip.AddrPortFrom(dstIP, dstPort) - pkt.IPProto = ipproto.TCP - pkt.TCPFlags = packet.TCPSyn + pkt.IPProto = proto + if proto == ipproto.TCP { + pkt.TCPFlags = packet.TCPSyn + } return f.RunIn(pkt, 0) } +// CheckTCP determines whether TCP traffic from srcIP to dstIP:dstPort +// is allowed. +func (f *Filter) CheckTCP(srcIP, dstIP netip.Addr, dstPort uint16) Response { + return f.Check(srcIP, dstIP, dstPort, ipproto.TCP) +} + // CapsWithValues appends to base the capabilities that srcIP has talking // to dstIP. func (f *Filter) CapsWithValues(srcIP, dstIP netip.Addr) tailcfg.PeerCapMap {