diff --git a/tstime/rate/value.go b/tstime/rate/value.go index 55206f267..610f06bbd 100644 --- a/tstime/rate/value.go +++ b/tstime/rate/value.go @@ -4,6 +4,7 @@ package rate import ( + "encoding/json" "fmt" "math" "sync" @@ -181,3 +182,41 @@ func (r *Value) rateNow(now mono.Time) float64 { func (r *Value) normalizedIntegral() float64 { return r.halfLife() / math.Ln2 } + +type jsonValue struct { + // TODO: Use v2 "encoding/json" for native time.Duration formatting. + HalfLife string `json:"halfLife,omitempty,omitzero"` + Value float64 `json:"value,omitempty,omitzero"` + Updated mono.Time `json:"updated,omitempty,omitzero"` +} + +func (r *Value) MarshalJSON() ([]byte, error) { + if r == nil { + return []byte("null"), nil + } + r.mu.Lock() + defer r.mu.Unlock() + v := jsonValue{Value: r.value, Updated: r.updated} + if r.HalfLife > 0 { + v.HalfLife = r.HalfLife.String() + } + return json.Marshal(v) +} + +func (r *Value) UnmarshalJSON(b []byte) error { + var v jsonValue + if err := json.Unmarshal(b, &v); err != nil { + return err + } + halfLife, err := time.ParseDuration(v.HalfLife) + if err != nil && v.HalfLife != "" { + return fmt.Errorf("invalid halfLife: %w", err) + } + + r.mu.Lock() + defer r.mu.Unlock() + r.HalfLife = halfLife + r.value = v.Value + r.updated = v.Updated + return nil +} diff --git a/tstime/rate/value_test.go b/tstime/rate/value_test.go index 4a776b9d2..dd9a803b1 100644 --- a/tstime/rate/value_test.go +++ b/tstime/rate/value_test.go @@ -6,12 +6,14 @@ package rate import ( "flag" "math" + "reflect" "testing" "time" qt "github.com/frankban/quicktest" "github.com/google/go-cmp/cmp/cmpopts" "tailscale.com/tstime/mono" + "tailscale.com/util/must" ) const ( @@ -234,3 +236,26 @@ func BenchmarkValue(b *testing.B) { v.Add(1) } } + +func TestValueMarshal(t *testing.T) { + now := mono.Now() + tests := []struct { + val *Value + str string + }{ + {val: &Value{}, str: `{}`}, + {val: &Value{HalfLife: 5 * time.Minute}, str: `{"halfLife":"` + (5 * time.Minute).String() + `"}`}, + {val: &Value{value: 12345, updated: now}, str: `{"value":12345,"updated":` + string(must.Get(now.MarshalJSON())) + `}`}, + } + for _, tt := range tests { + str := string(must.Get(tt.val.MarshalJSON())) + if str != tt.str { + t.Errorf("string mismatch: got %v, want %v", str, tt.str) + } + var val Value + must.Do(val.UnmarshalJSON([]byte(str))) + if !reflect.DeepEqual(&val, tt.val) { + t.Errorf("value mismatch: %+v, want %+v", &val, tt.val) + } + } +}