Skip to content

Commit

Permalink
feat(uuid): Added support for NullUUID
Browse files Browse the repository at this point in the history
  • Loading branch information
sejr committed Jul 12, 2021
1 parent bfb86fa commit f344a5b
Show file tree
Hide file tree
Showing 2 changed files with 358 additions and 0 deletions.
120 changes: 120 additions & 0 deletions null.go
@@ -0,0 +1,120 @@
// Copyright 2021 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package uuid

import (
"bytes"
"database/sql/driver"
"encoding/json"
"fmt"
)

// NullUUID represents a UUID that may be null.
// NullUUID implements the Scanner interface so
// it can be used as a scan destination:
//
// var u uuid.NullUUID
// err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&u)
// ...
// if u.Valid {
// // use u.UUID
// } else {
// // NULL value
// }
//
type NullUUID struct {
UUID UUID
Valid bool // Valid is true if UUID is not NULL
}

// Scan implements the Scanner interface.
func (nu *NullUUID) Scan(value interface{}) error {
if value == nil {
nu.UUID, nu.Valid = Nil, false
return nil
}

err := nu.UUID.Scan(value)
if err != nil {
nu.Valid = false
return err
}

nu.Valid = true
return nil
}

// Value implements the driver Valuer interface.
func (nu NullUUID) Value() (driver.Value, error) {
if !nu.Valid {
return nil, nil
}
// Delegate to UUID Value function
return nu.UUID.Value()
}

// MarshalBinary implements encoding.BinaryMarshaler.
func (nu NullUUID) MarshalBinary() ([]byte, error) {
if nu.Valid {
return nu.UUID[:], nil
}

return []byte(nil), nil
}

// UnmarshalBinary implements encoding.BinaryUnmarshaler.
func (nu *NullUUID) UnmarshalBinary(data []byte) error {
if len(data) != 16 {
return fmt.Errorf("invalid UUID (got %d bytes)", len(data))
}
copy(nu.UUID[:], data)
nu.Valid = true
return nil
}

// MarshalText implements encoding.TextMarshaler.
func (nu NullUUID) MarshalText() ([]byte, error) {
if nu.Valid {
return nu.UUID.MarshalText()
}

return []byte{110, 117, 108, 108}, nil
}

// UnmarshalText implements encoding.TextUnmarshaler.
func (nu *NullUUID) UnmarshalText(data []byte) error {
id, err := ParseBytes(data)
if err != nil {
nu.Valid = false
return err
}
nu.UUID = id
nu.Valid = true
return nil
}

// MarshalJSON implements json.Marshaler.
func (nu NullUUID) MarshalJSON() ([]byte, error) {
if nu.Valid {
return json.Marshal(nu.UUID)
}

return json.Marshal(nil)
}

// UnmarshalJSON implements json.Unmarshaler.
func (nu *NullUUID) UnmarshalJSON(data []byte) error {
null := []byte{110, 117, 108, 108}
if bytes.Equal(data, null) {
return nil // valid null UUID
}

var u UUID
// tossing as we know u is valid
_ = json.Unmarshal(data, &u)
nu.Valid = true
nu.UUID = u
return nil
}
238 changes: 238 additions & 0 deletions null_test.go
@@ -0,0 +1,238 @@
// Copyright 2021 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package uuid

import (
"bytes"
"encoding/json"
"fmt"
"testing"
)

func TestNullUUIDScan(t *testing.T) {
var u UUID
var nu NullUUID

uNilErr := u.Scan(nil)
nuNilErr := nu.Scan(nil)
if uNilErr != nil &&
nuNilErr != nil &&
uNilErr.Error() != nuNilErr.Error() {
t.Errorf("expected errors to be equal, got %s, %s", uNilErr, nuNilErr)
}

uInvalidStringErr := u.Scan("test")
nuInvalidStringErr := nu.Scan("test")
if uInvalidStringErr != nil &&
nuInvalidStringErr != nil &&
uInvalidStringErr.Error() != nuInvalidStringErr.Error() {
t.Errorf("expected errors to be equal, got %s, %s", uInvalidStringErr, nuInvalidStringErr)
}

valid := "12345678-abcd-1234-abcd-0123456789ab"
uValidErr := u.Scan(valid)
nuValidErr := nu.Scan(valid)
if uValidErr != nuValidErr {
t.Errorf("expected errors to be equal, got %s, %s", uValidErr, nuValidErr)
}
}

func TestNullUUIDValue(t *testing.T) {
var u UUID
var nu NullUUID

nuValue, nuErr := nu.Value()
if nuErr != nil {
t.Errorf("expected nil err, got err %s", nuErr)
}
if nuValue != nil {
t.Errorf("expected nil value, got non-nil %s", nuValue)
}

u = MustParse("12345678-abcd-1234-abcd-0123456789ab")
nu = NullUUID{
UUID: MustParse("12345678-abcd-1234-abcd-0123456789ab"),
Valid: true,
}

uValue, uErr := u.Value()
nuValue, nuErr = nu.Value()
if uErr != nil {
t.Errorf("expected nil err, got err %s", uErr)
}
if nuErr != nil {
t.Errorf("expected nil err, got err %s", nuErr)
}
if uValue != nuValue {
t.Errorf("expected uuid %s and nulluuid %s to be equal ", uValue, nuValue)
}
}

func TestNullUUIDMarshalText(t *testing.T) {
tests := []struct {
nullUUID NullUUID
}{
{
nullUUID: NullUUID{},
},
{
nullUUID: NullUUID{
UUID: MustParse("12345678-abcd-1234-abcd-0123456789ab"),
Valid: true,
},
},
}
for _, test := range tests {
var uText []byte
var uErr error
nuText, nuErr := test.nullUUID.MarshalText()
if test.nullUUID.Valid {
uText, uErr = test.nullUUID.UUID.MarshalText()
} else {
uText = []byte("null")
}
if nuErr != uErr {
t.Errorf("expected error %e, got %e", nuErr, uErr)
}
if !bytes.Equal(nuText, uText) {
t.Errorf("expected text data %s, got %s", string(nuText), string(uText))
}
}
}

func TestNullUUIDUnmarshalText(t *testing.T) {
tests := []struct {
nullUUID NullUUID
}{
{
nullUUID: NullUUID{},
},
{
nullUUID: NullUUID{
UUID: MustParse("12345678-abcd-1234-abcd-0123456789ab"),
Valid: true,
},
},
}
for _, test := range tests {
var uText []byte
var uErr error
nuText, nuErr := test.nullUUID.MarshalText()
if test.nullUUID.Valid {
uText, uErr = test.nullUUID.UUID.MarshalText()
} else {
uText = []byte("null")
}
if nuErr != uErr {
t.Errorf("expected error %e, got %e", nuErr, uErr)
}
if !bytes.Equal(nuText, uText) {
t.Errorf("expected text data %s, got %s", string(nuText), string(uText))
}
}
}

func TestNullUUIDMarshalBinary(t *testing.T) {
tests := []struct {
nullUUID NullUUID
}{
{
nullUUID: NullUUID{},
},
{
nullUUID: NullUUID{
UUID: MustParse("12345678-abcd-1234-abcd-0123456789ab"),
Valid: true,
},
},
}
for _, test := range tests {
var uBinary []byte
var uErr error
nuBinary, nuErr := test.nullUUID.MarshalBinary()
if test.nullUUID.Valid {
uBinary, uErr = test.nullUUID.UUID.MarshalBinary()
} else {
uBinary = []byte(nil)
}
if nuErr != uErr {
t.Errorf("expected error %e, got %e", nuErr, uErr)
}
if !bytes.Equal(nuBinary, uBinary) {
t.Errorf("expected binary data %s, got %s", string(nuBinary), string(uBinary))
}
}
}

func TestNullUUIDMarshalJSON(t *testing.T) {
jsonNull, _ := json.Marshal(nil)
jsonUUID, _ := json.Marshal(MustParse("12345678-abcd-1234-abcd-0123456789ab"))
tests := []struct {
nullUUID NullUUID
expected []byte
expectedErr error
}{
{
nullUUID: NullUUID{},
expected: jsonNull,
expectedErr: nil,
},
{
nullUUID: NullUUID{
UUID: MustParse(string(jsonUUID)),
Valid: true,
},
expected: []byte(`"12345678-abcd-1234-abcd-0123456789ab"`),
expectedErr: nil,
},
}
for _, test := range tests {
data, err := json.Marshal(&test.nullUUID)
if err != test.expectedErr {
t.Errorf("expected error %e, got %e", test.expectedErr, err)
}
if !bytes.Equal(data, test.expected) {
t.Errorf("expected json data %s, got %s", string(test.expected), string(data))
}
}
}

func TestNullUUIDUnmarshalJSON(t *testing.T) {
jsonNull, _ := json.Marshal(nil)
jsonUUID, _ := json.Marshal(MustParse("12345678-abcd-1234-abcd-0123456789ab"))

var nu NullUUID
err := json.Unmarshal(jsonNull, &nu)
if err != nil || nu.Valid {
t.Errorf("expected nil when unmarshaling null, got %s", err)
}
err = json.Unmarshal(jsonUUID, &nu)
if err != nil || !nu.Valid {
t.Errorf("expected nil when unmarshaling null, got %s", err)
}
}

func TestConformance(t *testing.T) {
input := []byte(`"12345678-abcd-1234-abcd-0123456789ab"`)
var n NullUUID
var u UUID

err := json.Unmarshal(input, &n)
fmt.Printf("Unmarshal NullUUID: %+v %v\n", n, err)
err = json.Unmarshal(input, &u)
fmt.Printf("Unmarshal UUID: %+v %v\n", u, err)

n = NullUUID{}
data, err := json.Marshal(&n)
fmt.Printf("Marshal Empty NullUUID %s %v\n", data, err)

n.Valid = true
n.UUID = u
data, err = json.Marshal(&n)
fmt.Printf("Marshal Filled NullUUID %s %v\n", data, err)

data, err = json.Marshal(&u)
fmt.Printf("Marshal UUID: %s %v\n", data, err)
}

0 comments on commit f344a5b

Please sign in to comment.