From ae25fc6a8e2fa9c3a72fc325117330736fe729f6 Mon Sep 17 00:00:00 2001 From: Samuel Roth <2413031+sejr@users.noreply.github.com> Date: Mon, 12 Jul 2021 18:10:02 -0400 Subject: [PATCH] feat(uuid): Added support for NullUUID (#76) --- null.go | 120 ++++++++++++++++++++++++++ null_test.go | 238 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 358 insertions(+) create mode 100644 null.go create mode 100644 null_test.go diff --git a/null.go b/null.go new file mode 100644 index 0000000..95bb632 --- /dev/null +++ b/null.go @@ -0,0 +1,120 @@ +// Copyright 2021 Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package uuid + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "fmt" +) + +// NullUUID represents a UUID that may be null. +// NullUUID implements the Scanner interface so +// it can be used as a scan destination: +// +// var u uuid.NullUUID +// err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&u) +// ... +// if u.Valid { +// // use u.UUID +// } else { +// // NULL value +// } +// +type NullUUID struct { + UUID UUID + Valid bool // Valid is true if UUID is not NULL +} + +// Scan implements the Scanner interface. +func (nu *NullUUID) Scan(value interface{}) error { + if value == nil { + nu.UUID, nu.Valid = Nil, false + return nil + } + + err := nu.UUID.Scan(value) + if err != nil { + nu.Valid = false + return err + } + + nu.Valid = true + return nil +} + +// Value implements the driver Valuer interface. +func (nu NullUUID) Value() (driver.Value, error) { + if !nu.Valid { + return nil, nil + } + // Delegate to UUID Value function + return nu.UUID.Value() +} + +// MarshalBinary implements encoding.BinaryMarshaler. +func (nu NullUUID) MarshalBinary() ([]byte, error) { + if nu.Valid { + return nu.UUID[:], nil + } + + return []byte(nil), nil +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler. +func (nu *NullUUID) UnmarshalBinary(data []byte) error { + if len(data) != 16 { + return fmt.Errorf("invalid UUID (got %d bytes)", len(data)) + } + copy(nu.UUID[:], data) + nu.Valid = true + return nil +} + +// MarshalText implements encoding.TextMarshaler. +func (nu NullUUID) MarshalText() ([]byte, error) { + if nu.Valid { + return nu.UUID.MarshalText() + } + + return []byte{110, 117, 108, 108}, nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (nu *NullUUID) UnmarshalText(data []byte) error { + id, err := ParseBytes(data) + if err != nil { + nu.Valid = false + return err + } + nu.UUID = id + nu.Valid = true + return nil +} + +// MarshalJSON implements json.Marshaler. +func (nu NullUUID) MarshalJSON() ([]byte, error) { + if nu.Valid { + return json.Marshal(nu.UUID) + } + + return json.Marshal(nil) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (nu *NullUUID) UnmarshalJSON(data []byte) error { + null := []byte{110, 117, 108, 108} + if bytes.Equal(data, null) { + return nil // valid null UUID + } + + var u UUID + // tossing as we know u is valid + _ = json.Unmarshal(data, &u) + nu.Valid = true + nu.UUID = u + return nil +} diff --git a/null_test.go b/null_test.go new file mode 100644 index 0000000..b1988a4 --- /dev/null +++ b/null_test.go @@ -0,0 +1,238 @@ +// Copyright 2021 Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package uuid + +import ( + "bytes" + "encoding/json" + "fmt" + "testing" +) + +func TestNullUUIDScan(t *testing.T) { + var u UUID + var nu NullUUID + + uNilErr := u.Scan(nil) + nuNilErr := nu.Scan(nil) + if uNilErr != nil && + nuNilErr != nil && + uNilErr.Error() != nuNilErr.Error() { + t.Errorf("expected errors to be equal, got %s, %s", uNilErr, nuNilErr) + } + + uInvalidStringErr := u.Scan("test") + nuInvalidStringErr := nu.Scan("test") + if uInvalidStringErr != nil && + nuInvalidStringErr != nil && + uInvalidStringErr.Error() != nuInvalidStringErr.Error() { + t.Errorf("expected errors to be equal, got %s, %s", uInvalidStringErr, nuInvalidStringErr) + } + + valid := "12345678-abcd-1234-abcd-0123456789ab" + uValidErr := u.Scan(valid) + nuValidErr := nu.Scan(valid) + if uValidErr != nuValidErr { + t.Errorf("expected errors to be equal, got %s, %s", uValidErr, nuValidErr) + } +} + +func TestNullUUIDValue(t *testing.T) { + var u UUID + var nu NullUUID + + nuValue, nuErr := nu.Value() + if nuErr != nil { + t.Errorf("expected nil err, got err %s", nuErr) + } + if nuValue != nil { + t.Errorf("expected nil value, got non-nil %s", nuValue) + } + + u = MustParse("12345678-abcd-1234-abcd-0123456789ab") + nu = NullUUID{ + UUID: MustParse("12345678-abcd-1234-abcd-0123456789ab"), + Valid: true, + } + + uValue, uErr := u.Value() + nuValue, nuErr = nu.Value() + if uErr != nil { + t.Errorf("expected nil err, got err %s", uErr) + } + if nuErr != nil { + t.Errorf("expected nil err, got err %s", nuErr) + } + if uValue != nuValue { + t.Errorf("expected uuid %s and nulluuid %s to be equal ", uValue, nuValue) + } +} + +func TestNullUUIDMarshalText(t *testing.T) { + tests := []struct { + nullUUID NullUUID + }{ + { + nullUUID: NullUUID{}, + }, + { + nullUUID: NullUUID{ + UUID: MustParse("12345678-abcd-1234-abcd-0123456789ab"), + Valid: true, + }, + }, + } + for _, test := range tests { + var uText []byte + var uErr error + nuText, nuErr := test.nullUUID.MarshalText() + if test.nullUUID.Valid { + uText, uErr = test.nullUUID.UUID.MarshalText() + } else { + uText = []byte("null") + } + if nuErr != uErr { + t.Errorf("expected error %e, got %e", nuErr, uErr) + } + if !bytes.Equal(nuText, uText) { + t.Errorf("expected text data %s, got %s", string(nuText), string(uText)) + } + } +} + +func TestNullUUIDUnmarshalText(t *testing.T) { + tests := []struct { + nullUUID NullUUID + }{ + { + nullUUID: NullUUID{}, + }, + { + nullUUID: NullUUID{ + UUID: MustParse("12345678-abcd-1234-abcd-0123456789ab"), + Valid: true, + }, + }, + } + for _, test := range tests { + var uText []byte + var uErr error + nuText, nuErr := test.nullUUID.MarshalText() + if test.nullUUID.Valid { + uText, uErr = test.nullUUID.UUID.MarshalText() + } else { + uText = []byte("null") + } + if nuErr != uErr { + t.Errorf("expected error %e, got %e", nuErr, uErr) + } + if !bytes.Equal(nuText, uText) { + t.Errorf("expected text data %s, got %s", string(nuText), string(uText)) + } + } +} + +func TestNullUUIDMarshalBinary(t *testing.T) { + tests := []struct { + nullUUID NullUUID + }{ + { + nullUUID: NullUUID{}, + }, + { + nullUUID: NullUUID{ + UUID: MustParse("12345678-abcd-1234-abcd-0123456789ab"), + Valid: true, + }, + }, + } + for _, test := range tests { + var uBinary []byte + var uErr error + nuBinary, nuErr := test.nullUUID.MarshalBinary() + if test.nullUUID.Valid { + uBinary, uErr = test.nullUUID.UUID.MarshalBinary() + } else { + uBinary = []byte(nil) + } + if nuErr != uErr { + t.Errorf("expected error %e, got %e", nuErr, uErr) + } + if !bytes.Equal(nuBinary, uBinary) { + t.Errorf("expected binary data %s, got %s", string(nuBinary), string(uBinary)) + } + } +} + +func TestNullUUIDMarshalJSON(t *testing.T) { + jsonNull, _ := json.Marshal(nil) + jsonUUID, _ := json.Marshal(MustParse("12345678-abcd-1234-abcd-0123456789ab")) + tests := []struct { + nullUUID NullUUID + expected []byte + expectedErr error + }{ + { + nullUUID: NullUUID{}, + expected: jsonNull, + expectedErr: nil, + }, + { + nullUUID: NullUUID{ + UUID: MustParse(string(jsonUUID)), + Valid: true, + }, + expected: []byte(`"12345678-abcd-1234-abcd-0123456789ab"`), + expectedErr: nil, + }, + } + for _, test := range tests { + data, err := json.Marshal(&test.nullUUID) + if err != test.expectedErr { + t.Errorf("expected error %e, got %e", test.expectedErr, err) + } + if !bytes.Equal(data, test.expected) { + t.Errorf("expected json data %s, got %s", string(test.expected), string(data)) + } + } +} + +func TestNullUUIDUnmarshalJSON(t *testing.T) { + jsonNull, _ := json.Marshal(nil) + jsonUUID, _ := json.Marshal(MustParse("12345678-abcd-1234-abcd-0123456789ab")) + + var nu NullUUID + err := json.Unmarshal(jsonNull, &nu) + if err != nil || nu.Valid { + t.Errorf("expected nil when unmarshaling null, got %s", err) + } + err = json.Unmarshal(jsonUUID, &nu) + if err != nil || !nu.Valid { + t.Errorf("expected nil when unmarshaling null, got %s", err) + } +} + +func TestConformance(t *testing.T) { + input := []byte(`"12345678-abcd-1234-abcd-0123456789ab"`) + var n NullUUID + var u UUID + + err := json.Unmarshal(input, &n) + fmt.Printf("Unmarshal NullUUID: %+v %v\n", n, err) + err = json.Unmarshal(input, &u) + fmt.Printf("Unmarshal UUID: %+v %v\n", u, err) + + n = NullUUID{} + data, err := json.Marshal(&n) + fmt.Printf("Marshal Empty NullUUID %s %v\n", data, err) + + n.Valid = true + n.UUID = u + data, err = json.Marshal(&n) + fmt.Printf("Marshal Filled NullUUID %s %v\n", data, err) + + data, err = json.Marshal(&u) + fmt.Printf("Marshal UUID: %s %v\n", data, err) +}