From 0cdec1ab8560a73379dc2a64c3c9d3e17940f66d Mon Sep 17 00:00:00 2001 From: hackerman <3372410+aeneasr@users.noreply.github.com> Date: Tue, 30 Jan 2024 16:18:24 +0100 Subject: [PATCH] feat: add nullable bool defaulting to false (#758) --- sqlxx/types.go | 50 +++++++++++++++++++++++++++++++++++++++++++++ sqlxx/types_test.go | 36 ++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+) diff --git a/sqlxx/types.go b/sqlxx/types.go index defadb29..2d167a46 100644 --- a/sqlxx/types.go +++ b/sqlxx/types.go @@ -163,6 +163,56 @@ func (ns *NullBool) UnmarshalJSON(data []byte) error { return errors.WithStack(json.Unmarshal(data, &ns.Bool)) } +// FalsyNullBool represents a bool that may be null. +// It JSON decodes to false if null. +// +// swagger:type bool +// swagger:model falsyNullBool +type FalsyNullBool struct { + Bool bool + Valid bool // Valid is true if Bool is not NULL +} + +// Scan implements the Scanner interface. +func (ns *FalsyNullBool) Scan(value interface{}) error { + var d = sql.NullBool{} + if err := d.Scan(value); err != nil { + return err + } + + ns.Bool = d.Bool + ns.Valid = d.Valid + return nil +} + +// Value implements the driver Valuer interface. +func (ns FalsyNullBool) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return ns.Bool, nil +} + +// MarshalJSON returns m as the JSON encoding of m. +func (ns FalsyNullBool) MarshalJSON() ([]byte, error) { + if !ns.Valid { + return []byte("false"), nil + } + return json.Marshal(ns.Bool) +} + +// UnmarshalJSON sets *m to a copy of data. +func (ns *FalsyNullBool) UnmarshalJSON(data []byte) error { + if ns == nil { + return errors.New("json.RawMessage: UnmarshalJSON on nil pointer") + } + if len(data) == 0 || string(data) == "null" { + return nil + } + ns.Valid = true + return errors.WithStack(json.Unmarshal(data, &ns.Bool)) +} + // swagger:type string // swagger:model nullString type NullString string diff --git a/sqlxx/types_test.go b/sqlxx/types_test.go index 1b686672..b19afd65 100644 --- a/sqlxx/types_test.go +++ b/sqlxx/types_test.go @@ -64,6 +64,42 @@ func TestNullBoolMarshalJSON(t *testing.T) { } } +func TestNullBoolDefaultFalseMarshalJSON(t *testing.T) { + type outer struct { + Bool *FalsyNullBool `json:"null_bool,omitempty"` + } + + for k, tc := range []struct { + in *outer + expected string + }{ + {in: &outer{&FalsyNullBool{Valid: false, Bool: true}}, expected: "{\"null_bool\":false}"}, + {in: &outer{&FalsyNullBool{Valid: false, Bool: false}}, expected: "{\"null_bool\":false}"}, + {in: &outer{&FalsyNullBool{Valid: true, Bool: true}}, expected: "{\"null_bool\":true}"}, + {in: &outer{&FalsyNullBool{Valid: true, Bool: false}}, expected: "{\"null_bool\":false}"}, + {in: &outer{}, expected: "{}"}, + } { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + out, err := json.Marshal(tc.in) + require.NoError(t, err) + assert.EqualValues(t, tc.expected, string(out)) + + var actual outer + require.NoError(t, json.Unmarshal(out, &actual)) + if tc.in.Bool == nil { + assert.Nil(t, actual.Bool) + return + } else if !tc.in.Bool.Valid { + assert.False(t, actual.Bool.Bool) + return + } + + assert.EqualValues(t, tc.in.Bool.Bool, actual.Bool.Bool) + assert.EqualValues(t, tc.in.Bool.Valid, actual.Bool.Valid) + }) + } +} + func TestNullInt64MarshalJSON(t *testing.T) { type outer struct { Int64 *NullInt64 `json:"null_int,omitempty"`