diff --git a/decimal.go b/decimal.go index 9a80dc53..0287e3fc 100644 --- a/decimal.go +++ b/decimal.go @@ -1603,6 +1603,13 @@ type NullDecimal struct { Valid bool } +func NewNullDecimal(d Decimal) NullDecimal { + return NullDecimal{ + Decimal: d, + Valid: true, + } +} + // Scan implements the sql.Scanner interface for database deserialization. func (d *NullDecimal) Scan(value interface{}) error { if value == nil { diff --git a/decimal_test.go b/decimal_test.go index 686b8bbc..53d59a43 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -3269,6 +3269,18 @@ func TestTan(t *testing.T) { } } +func TestNewNullDecimal(t *testing.T) { + d := NewFromInt(1) + nd := NewNullDecimal(d) + + if !nd.Valid { + t.Errorf("expected NullDecimal to be valid") + } + if nd.Decimal != d { + t.Errorf("expected NullDecimal to hold the provided Decimal") + } +} + func ExampleNewFromFloat32() { fmt.Println(NewFromFloat32(123.123123123123).String()) fmt.Println(NewFromFloat32(.123123123123123).String())