diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index 7b253378e..878eaeab9 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -158,7 +158,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/util/singleflight from tailscale.com/net/dnscache tailscale.com/util/slicesx from tailscale.com/net/dnscache+ tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli - tailscale.com/util/vizerror from tailscale.com/types/ipproto + tailscale.com/util/vizerror from tailscale.com/types/ipproto+ 💣 tailscale.com/util/winutil from tailscale.com/hostinfo+ W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate tailscale.com/version from tailscale.com/cmd/tailscale/cli+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index c0b94b113..10e417f61 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -363,7 +363,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/util/systemd from tailscale.com/control/controlclient+ tailscale.com/util/testenv from tailscale.com/ipn/ipnlocal+ tailscale.com/util/uniq from tailscale.com/wgengine/magicsock+ - tailscale.com/util/vizerror from tailscale.com/types/ipproto + tailscale.com/util/vizerror from tailscale.com/types/ipproto+ 💣 tailscale.com/util/winutil from tailscale.com/control/controlclient+ W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/util/osdiag+ W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal diff --git a/tailcfg/proto_port_range.go b/tailcfg/proto_port_range.go new file mode 100644 index 000000000..e19a3b578 --- /dev/null +++ b/tailcfg/proto_port_range.go @@ -0,0 +1,160 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailcfg + +import ( + "errors" + "fmt" + "strconv" + "strings" + + "tailscale.com/types/ipproto" + "tailscale.com/util/vizerror" +) + +// ProtoPortRange is used to encode "proto:port" format. +// The following formats are supported: +// +// "*" allows all TCP, UDP and ICMP traffic on all ports. +// "" allows all TCP, UDP and ICMP traffic on the specified ports. +// "proto:*" allows traffic of the specified proto on all ports. +// "proto:" allows traffic of the specified proto on the specified port. +// +// Ports are either a single port number or a range of ports (e.g. "80-90"). +// String named protocols support names that ipproto.Proto accepts. +type ProtoPortRange struct { + // Proto is the IP protocol number. + // If Proto is 0, it means TCP+UDP+ICMP(4+6). + Proto int + Ports PortRange +} + +func (ppr ProtoPortRange) String() string { + if ppr.Proto == 0 { + if ppr.Ports == PortRangeAny { + return "*" + } + } + var buf strings.Builder + if ppr.Proto != 0 { + // Proto.MarshalText is infallible. + text, _ := ipproto.Proto(ppr.Proto).MarshalText() + buf.Write(text) + buf.Write([]byte(":")) + } + pr := ppr.Ports + if pr.First == pr.Last { + fmt.Fprintf(&buf, "%d", pr.First) + } else if pr == PortRangeAny { + buf.WriteByte('*') + } else { + fmt.Fprintf(&buf, "%d-%d", pr.First, pr.Last) + } + return buf.String() +} + +// ParseProtoPortRanges parses a slice of IP port range fields. +func ParseProtoPortRanges(ips []string) ([]ProtoPortRange, error) { + var out []ProtoPortRange + for _, p := range ips { + ppr, err := parseProtoPortRange(p) + if err != nil { + return nil, err + } + out = append(out, *ppr) + } + return out, nil +} + +func parseProtoPortRange(ipProtoPort string) (*ProtoPortRange, error) { + if ipProtoPort == "" { + return nil, errors.New("empty string") + } + if ipProtoPort == "*" { + return &ProtoPortRange{Ports: PortRangeAny}, nil + } + if !strings.Contains(ipProtoPort, ":") { + ipProtoPort = "*:" + ipProtoPort + } + protoStr, portRange, err := parseHostPortRange(ipProtoPort) + if err != nil { + return nil, err + } + if protoStr == "" { + return nil, errors.New("empty protocol") + } + + ppr := &ProtoPortRange{ + Ports: portRange, + } + if protoStr == "*" { + return ppr, nil + } + var ipProto ipproto.Proto + if err := ipProto.UnmarshalText([]byte(protoStr)); err != nil { + return nil, err + } + ppr.Proto = int(ipProto) + return ppr, nil +} + +// parseHostPortRange parses hostport as HOST:PORTS where HOST is +// returned unchanged and PORTS is is either "*" or PORTLOW-PORTHIGH ranges. +func parseHostPortRange(hostport string) (host string, ports PortRange, err error) { + hostport = strings.ToLower(hostport) + colon := strings.LastIndexByte(hostport, ':') + if colon < 0 { + return "", ports, vizerror.New("hostport must contain a colon (\":\")") + } + host = hostport[:colon] + portlist := hostport[colon+1:] + + if strings.Contains(host, ",") { + return "", ports, vizerror.New("host cannot contain a comma (\",\")") + } + + if portlist == "*" { + // Special case: permit hostname:* as a port wildcard. + return host, PortRangeAny, nil + } + + if len(portlist) == 0 { + return "", ports, vizerror.Errorf("invalid port list: %#v", portlist) + } + + if strings.Count(portlist, "-") > 1 { + return "", ports, vizerror.Errorf("port range %#v: too many dashes(-)", portlist) + } + + firstStr, lastStr, isRange := strings.Cut(portlist, "-") + + var first, last uint64 + first, err = strconv.ParseUint(firstStr, 10, 16) + if err != nil { + return "", ports, vizerror.Errorf("port range %#v: invalid first integer", portlist) + } + + if isRange { + last, err = strconv.ParseUint(lastStr, 10, 16) + if err != nil { + return "", ports, vizerror.Errorf("port range %#v: invalid last integer", portlist) + } + } else { + last = first + } + + if first == 0 { + return "", ports, vizerror.Errorf("port range %#v: first port must be >0, or use '*' for wildcard", portlist) + } + + if first > last { + return "", ports, vizerror.Errorf("port range %#v: first port must be >= last port", portlist) + } + + return host, newPortRange(uint16(first), uint16(last)), nil +} + +func newPortRange(first, last uint16) PortRange { + return PortRange{First: first, Last: last} +} diff --git a/tailcfg/proto_port_range_test.go b/tailcfg/proto_port_range_test.go new file mode 100644 index 000000000..5729a4a33 --- /dev/null +++ b/tailcfg/proto_port_range_test.go @@ -0,0 +1,90 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailcfg + +import ( + "errors" + "testing" + + "tailscale.com/types/ipproto" +) + +func TestProtoPortRangeParsing(t *testing.T) { + pr := func(s, e uint16) PortRange { + return PortRange{First: s, Last: e} + } + tests := []struct { + in string + out ProtoPortRange + err error + }{ + {in: "tcp:80", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: pr(80, 80)}}, + {in: "80", out: ProtoPortRange{Ports: pr(80, 80)}}, + {in: "*", out: ProtoPortRange{Ports: PortRangeAny}}, + {in: "*:*", out: ProtoPortRange{Ports: PortRangeAny}}, + {in: "tcp:*", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: PortRangeAny}}, + { + in: "tcp:", + err: errors.New(`invalid port list: ""`), + }, + { + in: ":80", + err: errors.New(`empty protocol`), + }, + { + in: "", + err: errors.New(`empty string`), + }, + } + + for _, tc := range tests { + t.Run(tc.in, func(t *testing.T) { + ppr, err := parseProtoPortRange(tc.in) + if gotErr, wantErr := err != nil, tc.err != nil; gotErr != wantErr { + t.Fatalf("got err %v; want %v", err, tc.err) + } else if gotErr { + if err.Error() != tc.err.Error() { + t.Fatalf("got err %q; want %q", err, tc.err) + } + return + } + if *ppr != tc.out { + t.Fatalf("got %v; want %v", ppr, tc.out) + } + }) + } +} + +func TestProtoPortRangeString(t *testing.T) { + tests := []struct { + input ProtoPortRange + want string + }{ + {ProtoPortRange{}, "0"}, + + // Zero protocol. + {ProtoPortRange{Ports: PortRangeAny}, "*"}, + {ProtoPortRange{Ports: PortRange{23, 23}}, "23"}, + {ProtoPortRange{Ports: PortRange{80, 120}}, "80-120"}, + + // Non-zero unnamed protocol. + {ProtoPortRange{Proto: 100, Ports: PortRange{80, 80}}, "100:80"}, + {ProtoPortRange{Proto: 200, Ports: PortRange{101, 105}}, "200:101-105"}, + + // Non-zero named protocol. + {ProtoPortRange{Proto: 1, Ports: PortRangeAny}, "icmp:*"}, + {ProtoPortRange{Proto: 2, Ports: PortRangeAny}, "igmp:*"}, + {ProtoPortRange{Proto: 6, Ports: PortRange{10, 13}}, "tcp:10-13"}, + {ProtoPortRange{Proto: 17, Ports: PortRangeAny}, "udp:*"}, + {ProtoPortRange{Proto: 0x84, Ports: PortRange{999, 999}}, "sctp:999"}, + {ProtoPortRange{Proto: 0x3a, Ports: PortRangeAny}, "ipv6-icmp:*"}, + {ProtoPortRange{Proto: 0x21, Ports: PortRangeAny}, "dccp:*"}, + {ProtoPortRange{Proto: 0x2f, Ports: PortRangeAny}, "gre:*"}, + } + for _, tc := range tests { + if got := tc.input.String(); got != tc.want { + t.Errorf("String for %v: got %q, want %q", tc.input, got, tc.want) + } + } +}