types/opt: support an explicit "unset" value for Bool

Updates #4843

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/5323/head
Brad Fitzpatrick 2 years ago committed by Brad Fitzpatrick
parent 3bb57504af
commit 8e821d7aa8

@ -10,9 +10,13 @@ import (
"strconv" "strconv"
) )
// Bool represents an optional boolean to be JSON-encoded. // Bool represents an optional boolean to be JSON-encoded. The string
// The string can be empty (for unknown or unspecified), or // is either "true", "false", or the enmpty string to mean unset.
// "true" or "false". //
// As a special case, the underlying string may also be the string
// "unset" as as a synonym for the empty string. This lets the
// explicit unset value be exchanged over an encoding/json "omitempty"
// field without it being dropped.
type Bool string type Bool string
func (b *Bool) Set(v bool) { func (b *Bool) Set(v bool) {
@ -22,11 +26,14 @@ func (b *Bool) Set(v bool) {
func (b *Bool) Clear() { *b = "" } func (b *Bool) Clear() { *b = "" }
func (b Bool) Get() (v bool, ok bool) { func (b Bool) Get() (v bool, ok bool) {
if b == "" { switch b {
return case "true":
return true, true
case "false":
return false, true
default:
return false, false
} }
v, err := strconv.ParseBool(string(b))
return v, err == nil
} }
// Scan implements database/sql.Scanner. // Scan implements database/sql.Scanner.
@ -74,7 +81,7 @@ func (b Bool) MarshalJSON() ([]byte, error) {
return trueBytes, nil return trueBytes, nil
case "false": case "false":
return falseBytes, nil return falseBytes, nil
case "": case "", "unset":
return nullBytes, nil return nullBytes, nil
} }
return nil, fmt.Errorf("invalid opt.Bool value %q", string(b)) return nil, fmt.Errorf("invalid opt.Bool value %q", string(b))
@ -94,7 +101,7 @@ func (b *Bool) UnmarshalJSON(j []byte) error {
return nil return nil
} }
if string(j) == "null" { if string(j) == "null" {
*b = "" *b = "unset"
return nil return nil
} }
return fmt.Errorf("invalid opt.Bool value %q", j) return fmt.Errorf("invalid opt.Bool value %q", j)

@ -12,9 +12,10 @@ import (
func TestBool(t *testing.T) { func TestBool(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
in any in any
want string // JSON want string // JSON
wantBack any
}{ }{
{ {
name: "null_for_unset", name: "null_for_unset",
@ -27,6 +28,15 @@ func TestBool(t *testing.T) {
False: "false", False: "false",
}, },
want: `{"True":true,"False":false,"Unset":null}`, want: `{"True":true,"False":false,"Unset":null}`,
wantBack: struct {
True Bool
False Bool
Unset Bool
}{
True: "true",
False: "false",
Unset: "unset",
},
}, },
{ {
name: "omitempty_unset", name: "omitempty_unset",
@ -40,6 +50,24 @@ func TestBool(t *testing.T) {
}, },
want: `{"True":true,"False":false}`, want: `{"True":true,"False":false}`,
}, },
{
name: "unset_marshals_as_null",
in: struct {
True Bool
False Bool
Foo Bool
}{
True: "true",
False: "false",
Foo: "unset",
},
want: `{"True":true,"False":false,"Foo":null}`,
wantBack: struct {
True Bool
False Bool
Foo Bool
}{"true", "false", "unset"},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -51,6 +79,10 @@ func TestBool(t *testing.T) {
t.Errorf("wrong JSON:\n got: %s\nwant: %s\n", j, tt.want) t.Errorf("wrong JSON:\n got: %s\nwant: %s\n", j, tt.want)
} }
wantBack := tt.in
if tt.wantBack != nil {
wantBack = tt.wantBack
}
// And back again: // And back again:
newVal := reflect.New(reflect.TypeOf(tt.in)) newVal := reflect.New(reflect.TypeOf(tt.in))
out := newVal.Interface() out := newVal.Interface()
@ -58,8 +90,8 @@ func TestBool(t *testing.T) {
t.Fatalf("Unmarshal %#q: %v", j, err) t.Fatalf("Unmarshal %#q: %v", j, err)
} }
got := newVal.Elem().Interface() got := newVal.Elem().Interface()
if !reflect.DeepEqual(tt.in, got) { if !reflect.DeepEqual(got, wantBack) {
t.Errorf("value mismatch\n got: %+v\nwant: %+v\n", got, tt.in) t.Errorf("value mismatch\n got: %+v\nwant: %+v\n", got, wantBack)
} }
}) })
} }
@ -79,11 +111,12 @@ func TestBoolEqualBool(t *testing.T) {
{"true", false, false}, {"true", false, false},
{"false", true, false}, {"false", true, false},
{"false", false, true}, {"false", false, true},
{"1", true, false}, // "1" is not true; only "true" is
{"True", true, false}, // "True" is not true; only "true" is
} }
for _, tt := range tests { for _, tt := range tests {
if got := tt.b.EqualBool(tt.v); got != tt.want { if got := tt.b.EqualBool(tt.v); got != tt.want {
t.Errorf("(%q).EqualBool(%v) = %v; want %v", string(tt.b), tt.v, got, tt.want) t.Errorf("(%q).EqualBool(%v) = %v; want %v", string(tt.b), tt.v, got, tt.want)
} }
} }
} }

Loading…
Cancel
Save