diff --git a/cmd/k8s-operator/proxy_test.go b/cmd/k8s-operator/proxy_test.go index 741c1fe85..9940e80c9 100644 --- a/cmd/k8s-operator/proxy_test.go +++ b/cmd/k8s-operator/proxy_test.go @@ -45,15 +45,15 @@ func TestImpersonationHeaders(t *testing.T) { emailish: "foo@example.com", capMap: tailcfg.PeerCapMap{ capabilityName: { - []byte(`{"impersonate":{"groups":["group1","group2"]}}`), - []byte(`{"impersonate":{"groups":["group1","group3"]}}`), // One group is duplicated. - []byte(`{"impersonate":{"groups":["group4"]}}`), - []byte(`{"impersonate":{"groups":["group2"]}}`), // duplicate + tailcfg.RawMessage(`{"impersonate":{"groups":["group1","group2"]}}`), + tailcfg.RawMessage(`{"impersonate":{"groups":["group1","group3"]}}`), // One group is duplicated. + tailcfg.RawMessage(`{"impersonate":{"groups":["group4"]}}`), + tailcfg.RawMessage(`{"impersonate":{"groups":["group2"]}}`), // duplicate // These should be ignored, but should parse correctly. - []byte(`{}`), - []byte(`{"impersonate":{}}`), - []byte(`{"impersonate":{"groups":[]}}`), + tailcfg.RawMessage(`{}`), + tailcfg.RawMessage(`{"impersonate":{}}`), + tailcfg.RawMessage(`{"impersonate":{"groups":[]}}`), }, }, wantHeaders: http.Header{ @@ -67,7 +67,7 @@ func TestImpersonationHeaders(t *testing.T) { tags: []string{"tag:foo", "tag:bar"}, capMap: tailcfg.PeerCapMap{ capabilityName: { - []byte(`{"impersonate":{"groups":["group1"]}}`), + tailcfg.RawMessage(`{"impersonate":{"groups":["group1"]}}`), }, }, wantHeaders: http.Header{ @@ -81,7 +81,7 @@ func TestImpersonationHeaders(t *testing.T) { tags: []string{"tag:foo", "tag:bar"}, capMap: tailcfg.PeerCapMap{ capabilityName: { - []byte(`[]`), + tailcfg.RawMessage(`[]`), }, }, wantHeaders: http.Header{}, diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 89fc678f9..50e5d568d 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -216,6 +216,31 @@ func (emptyStructJSONSlice) MarshalJSON() ([]byte, error) { func (emptyStructJSONSlice) UnmarshalJSON([]byte) error { return nil } +// RawMessage is a raw encoded JSON value. It implements Marshaler and +// Unmarshaler and can be used to delay JSON decoding or precompute a JSON +// encoding. +// +// It is like json.RawMessage but is a string instead of a []byte to better +// portray immutable data. +type RawMessage string + +// MarshalJSON returns m as the JSON encoding of m. +func (m RawMessage) MarshalJSON() ([]byte, error) { + if m == "" { + return []byte("null"), nil + } + return []byte(m), nil +} + +// UnmarshalJSON sets *m to a copy of data. +func (m *RawMessage) UnmarshalJSON(data []byte) error { + if m == nil { + return errors.New("RawMessage: UnmarshalJSON on nil pointer") + } + *m = RawMessage(data) + return nil +} + type Node struct { ID NodeID StableID StableNodeID @@ -1256,7 +1281,7 @@ const ( // // The values are opaque to Tailscale, but are passed through from the ACLs to // the application via the WhoIs API. -type PeerCapMap map[PeerCapability][]json.RawMessage +type PeerCapMap map[PeerCapability][]RawMessage // UnmarshalCapJSON unmarshals each JSON value in cm[cap] as T. // If cap does not exist in cm, it returns (nil, nil). @@ -1269,7 +1294,7 @@ func UnmarshalCapJSON[T any](cm PeerCapMap, cap PeerCapability) ([]T, error) { out := make([]T, 0, len(vals)) for _, v := range vals { var t T - if err := json.Unmarshal(v, &t); err != nil { + if err := json.Unmarshal([]byte(v), &t); err != nil { return nil, err } out = append(out, t) diff --git a/tailcfg/tailcfg_test.go b/tailcfg/tailcfg_test.go index 5ed99c9cd..db8a27299 100644 --- a/tailcfg/tailcfg_test.go +++ b/tailcfg/tailcfg_test.go @@ -726,3 +726,82 @@ func TestUnmarshalHealth(t *testing.T) { } } } + +func TestRawMessage(t *testing.T) { + // Create a few types of json.RawMessages and then marshal them back and + // forth to make sure they round-trip. + + type rule struct { + Ports []int `json:",omitempty"` + } + tests := []struct { + name string + val map[string][]rule + wire map[string][]RawMessage + }{ + { + name: "nil", + val: nil, + wire: nil, + }, + { + name: "empty", + val: map[string][]rule{}, + wire: map[string][]RawMessage{}, + }, + { + name: "one", + val: map[string][]rule{ + "foo": {{Ports: []int{1, 2, 3}}}, + }, + wire: map[string][]RawMessage{ + "foo": { + `{"Ports":[1,2,3]}`, + }, + }, + }, + { + name: "many", + val: map[string][]rule{ + "foo": {{Ports: []int{1, 2, 3}}}, + "bar": {{Ports: []int{4, 5, 6}}, {Ports: []int{7, 8, 9}}}, + "baz": nil, + "abc": {}, + "def": {{}}, + }, + wire: map[string][]RawMessage{ + "foo": { + `{"Ports":[1,2,3]}`, + }, + "bar": { + `{"Ports":[4,5,6]}`, + `{"Ports":[7,8,9]}`, + }, + "baz": nil, + "abc": {}, + "def": {"{}"}, + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + j := must.Get(json.Marshal(tc.val)) + var gotWire map[string][]RawMessage + if err := json.Unmarshal(j, &gotWire); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if !reflect.DeepEqual(gotWire, tc.wire) { + t.Errorf("got %#v; want %#v", gotWire, tc.wire) + } + + j = must.Get(json.Marshal(tc.wire)) + var gotVal map[string][]rule + if err := json.Unmarshal(j, &gotVal); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if !reflect.DeepEqual(gotVal, tc.val) { + t.Errorf("got %#v; want %#v", gotVal, tc.val) + } + }) + } +} diff --git a/wgengine/filter/filter_clone.go b/wgengine/filter/filter_clone.go index e34ae7685..97366d83c 100644 --- a/wgengine/filter/filter_clone.go +++ b/wgengine/filter/filter_clone.go @@ -6,7 +6,6 @@ package filter import ( - "encoding/json" "net/netip" "tailscale.com/tailcfg" @@ -49,12 +48,7 @@ func (src *CapMatch) Clone() *CapMatch { } dst := new(CapMatch) *dst = *src - if src.Values != nil { - dst.Values = make([]json.RawMessage, len(src.Values)) - for i := range dst.Values { - dst.Values[i] = append(src.Values[i][:0:0], src.Values[i]...) - } - } + dst.Values = append(src.Values[:0:0], src.Values...) return dst } @@ -62,5 +56,5 @@ func (src *CapMatch) Clone() *CapMatch { var _CapMatchCloneNeedsRegeneration = CapMatch(struct { Dst netip.Prefix Cap tailcfg.PeerCapability - Values []json.RawMessage + Values []tailcfg.RawMessage }{}) diff --git a/wgengine/filter/match.go b/wgengine/filter/match.go index 4475c2332..2d6f71d19 100644 --- a/wgengine/filter/match.go +++ b/wgengine/filter/match.go @@ -4,7 +4,6 @@ package filter import ( - "encoding/json" "fmt" "net/netip" "strings" @@ -60,7 +59,7 @@ type CapMatch struct { // Values are the raw JSON values of the capability. // See tailcfg.PeerCapability and tailcfg.PeerCapMap for details. - Values []json.RawMessage + Values []tailcfg.RawMessage } // Match matches packets from any IP address in Srcs to any ip:port in