diff --git a/tailcfg/proto_port_range.go b/tailcfg/proto_port_range.go index e19a3b578..f65c58804 100644 --- a/tailcfg/proto_port_range.go +++ b/tailcfg/proto_port_range.go @@ -13,6 +13,11 @@ import ( "tailscale.com/util/vizerror" ) +var ( + errEmptyProtocol = errors.New("empty protocol") + errEmptyString = errors.New("empty string") +) + // ProtoPortRange is used to encode "proto:port" format. // The following formats are supported: // @@ -30,6 +35,28 @@ type ProtoPortRange struct { Ports PortRange } +// UnmarshalText implements the encoding.TextUnmarshaler interface. See +// ProtoPortRange for the format. +func (ppr *ProtoPortRange) UnmarshalText(text []byte) error { + ppr2, err := parseProtoPortRange(string(text)) + if err != nil { + return err + } + *ppr = *ppr2 + return nil +} + +// MarshalText implements the encoding.TextMarshaler interface. See +// ProtoPortRange for the format. +func (ppr *ProtoPortRange) MarshalText() ([]byte, error) { + if ppr.Proto == 0 && ppr.Ports == (PortRange{}) { + return []byte{}, nil + } + return []byte(ppr.String()), nil +} + +// String implements the stringer interface. See ProtoPortRange for the +// format. func (ppr ProtoPortRange) String() string { if ppr.Proto == 0 { if ppr.Ports == PortRangeAny { @@ -69,7 +96,7 @@ func ParseProtoPortRanges(ips []string) ([]ProtoPortRange, error) { func parseProtoPortRange(ipProtoPort string) (*ProtoPortRange, error) { if ipProtoPort == "" { - return nil, errors.New("empty string") + return nil, errEmptyString } if ipProtoPort == "*" { return &ProtoPortRange{Ports: PortRangeAny}, nil @@ -82,7 +109,7 @@ func parseProtoPortRange(ipProtoPort string) (*ProtoPortRange, error) { return nil, err } if protoStr == "" { - return nil, errors.New("empty protocol") + return nil, errEmptyProtocol } ppr := &ProtoPortRange{ diff --git a/tailcfg/proto_port_range_test.go b/tailcfg/proto_port_range_test.go index 5729a4a33..59ccc9be4 100644 --- a/tailcfg/proto_port_range_test.go +++ b/tailcfg/proto_port_range_test.go @@ -4,12 +4,15 @@ package tailcfg import ( - "errors" + "encoding" "testing" "tailscale.com/types/ipproto" + "tailscale.com/util/vizerror" ) +var _ encoding.TextUnmarshaler = (*ProtoPortRange)(nil) + func TestProtoPortRangeParsing(t *testing.T) { pr := func(s, e uint16) PortRange { return PortRange{First: s, Last: e} @@ -26,30 +29,28 @@ func TestProtoPortRangeParsing(t *testing.T) { {in: "tcp:*", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: PortRangeAny}}, { in: "tcp:", - err: errors.New(`invalid port list: ""`), + err: vizerror.Errorf("invalid port list: %#v", ""), }, { in: ":80", - err: errors.New(`empty protocol`), + err: errEmptyProtocol, }, { in: "", - err: errors.New(`empty string`), + err: errEmptyString, }, } for _, tc := range tests { t.Run(tc.in, func(t *testing.T) { - ppr, err := parseProtoPortRange(tc.in) - if gotErr, wantErr := err != nil, tc.err != nil; gotErr != wantErr { - t.Fatalf("got err %v; want %v", err, tc.err) - } else if gotErr { - if err.Error() != tc.err.Error() { - t.Fatalf("got err %q; want %q", err, tc.err) + var ppr ProtoPortRange + err := ppr.UnmarshalText([]byte(tc.in)) + if tc.err != err { + if err == nil || tc.err.Error() != err.Error() { + t.Fatalf("want err=%v, got %v", tc.err, err) } - return } - if *ppr != tc.out { + if ppr != tc.out { t.Fatalf("got %v; want %v", ppr, tc.out) } }) @@ -88,3 +89,43 @@ func TestProtoPortRangeString(t *testing.T) { } } } + +func TestProtoPortRangeRoundTrip(t *testing.T) { + tests := []struct { + input ProtoPortRange + text string + }{ + {ProtoPortRange{Ports: PortRangeAny}, "*"}, + {ProtoPortRange{Ports: PortRange{23, 23}}, "23"}, + {ProtoPortRange{Ports: PortRange{80, 120}}, "80-120"}, + {ProtoPortRange{Proto: 100, Ports: PortRange{80, 80}}, "100:80"}, + {ProtoPortRange{Proto: 200, Ports: PortRange{101, 105}}, "200:101-105"}, + {ProtoPortRange{Proto: 1, Ports: PortRangeAny}, "icmp:*"}, + {ProtoPortRange{Proto: 2, Ports: PortRangeAny}, "igmp:*"}, + {ProtoPortRange{Proto: 6, Ports: PortRange{10, 13}}, "tcp:10-13"}, + {ProtoPortRange{Proto: 17, Ports: PortRangeAny}, "udp:*"}, + {ProtoPortRange{Proto: 0x84, Ports: PortRange{999, 999}}, "sctp:999"}, + {ProtoPortRange{Proto: 0x3a, Ports: PortRangeAny}, "ipv6-icmp:*"}, + {ProtoPortRange{Proto: 0x21, Ports: PortRangeAny}, "dccp:*"}, + {ProtoPortRange{Proto: 0x2f, Ports: PortRangeAny}, "gre:*"}, + } + + for _, tc := range tests { + out, err := tc.input.MarshalText() + if err != nil { + t.Errorf("MarshalText for %v: %v", tc.input, err) + continue + } + if got := string(out); got != tc.text { + t.Errorf("MarshalText for %#v: got %q, want %q", tc.input, got, tc.text) + } + var ppr ProtoPortRange + if err := ppr.UnmarshalText(out); err != nil { + t.Errorf("UnmarshalText for %q: err=%v", tc.text, err) + continue + } + if ppr != tc.input { + t.Errorf("round trip error for %q: got %v, want %#v", tc.text, ppr, tc.input) + } + } +}