diff --git a/types/opt/bool.go b/types/opt/bool.go index 4ef991c0f..86e32aea2 100644 --- a/types/opt/bool.go +++ b/types/opt/bool.go @@ -10,9 +10,13 @@ import ( "strconv" ) -// Bool represents an optional boolean to be JSON-encoded. -// The string can be empty (for unknown or unspecified), or -// "true" or "false". +// Bool represents an optional boolean to be JSON-encoded. The string +// is either "true", "false", or the enmpty string to mean unset. +// +// As a special case, the underlying string may also be the string +// "unset" as as a synonym for the empty string. This lets the +// explicit unset value be exchanged over an encoding/json "omitempty" +// field without it being dropped. type Bool string func (b *Bool) Set(v bool) { @@ -22,11 +26,14 @@ func (b *Bool) Set(v bool) { func (b *Bool) Clear() { *b = "" } func (b Bool) Get() (v bool, ok bool) { - if b == "" { - return + switch b { + case "true": + return true, true + case "false": + return false, true + default: + return false, false } - v, err := strconv.ParseBool(string(b)) - return v, err == nil } // Scan implements database/sql.Scanner. @@ -74,7 +81,7 @@ func (b Bool) MarshalJSON() ([]byte, error) { return trueBytes, nil case "false": return falseBytes, nil - case "": + case "", "unset": return nullBytes, nil } return nil, fmt.Errorf("invalid opt.Bool value %q", string(b)) @@ -94,7 +101,7 @@ func (b *Bool) UnmarshalJSON(j []byte) error { return nil } if string(j) == "null" { - *b = "" + *b = "unset" return nil } return fmt.Errorf("invalid opt.Bool value %q", j) diff --git a/types/opt/bool_test.go b/types/opt/bool_test.go index 9b3424ee5..d6ab30daf 100644 --- a/types/opt/bool_test.go +++ b/types/opt/bool_test.go @@ -12,9 +12,10 @@ import ( func TestBool(t *testing.T) { tests := []struct { - name string - in any - want string // JSON + name string + in any + want string // JSON + wantBack any }{ { name: "null_for_unset", @@ -27,6 +28,15 @@ func TestBool(t *testing.T) { False: "false", }, want: `{"True":true,"False":false,"Unset":null}`, + wantBack: struct { + True Bool + False Bool + Unset Bool + }{ + True: "true", + False: "false", + Unset: "unset", + }, }, { name: "omitempty_unset", @@ -40,6 +50,24 @@ func TestBool(t *testing.T) { }, want: `{"True":true,"False":false}`, }, + { + name: "unset_marshals_as_null", + in: struct { + True Bool + False Bool + Foo Bool + }{ + True: "true", + False: "false", + Foo: "unset", + }, + want: `{"True":true,"False":false,"Foo":null}`, + wantBack: struct { + True Bool + False Bool + Foo Bool + }{"true", "false", "unset"}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -51,6 +79,10 @@ func TestBool(t *testing.T) { t.Errorf("wrong JSON:\n got: %s\nwant: %s\n", j, tt.want) } + wantBack := tt.in + if tt.wantBack != nil { + wantBack = tt.wantBack + } // And back again: newVal := reflect.New(reflect.TypeOf(tt.in)) out := newVal.Interface() @@ -58,8 +90,8 @@ func TestBool(t *testing.T) { t.Fatalf("Unmarshal %#q: %v", j, err) } got := newVal.Elem().Interface() - if !reflect.DeepEqual(tt.in, got) { - t.Errorf("value mismatch\n got: %+v\nwant: %+v\n", got, tt.in) + if !reflect.DeepEqual(got, wantBack) { + t.Errorf("value mismatch\n got: %+v\nwant: %+v\n", got, wantBack) } }) } @@ -79,11 +111,12 @@ func TestBoolEqualBool(t *testing.T) { {"true", false, false}, {"false", true, false}, {"false", false, true}, + {"1", true, false}, // "1" is not true; only "true" is + {"True", true, false}, // "True" is not true; only "true" is } for _, tt := range tests { if got := tt.b.EqualBool(tt.v); got != tt.want { t.Errorf("(%q).EqualBool(%v) = %v; want %v", string(tt.b), tt.v, got, tt.want) } } - }