From f74a52bd67bc1432d1d080e3d23180253980d4b5 Mon Sep 17 00:00:00 2001 From: Bharath Ramesh Date: Fri, 4 Dec 2020 07:57:26 +0000 Subject: [PATCH] Added XML support for NullDecimal * Added UnmarshalText and MarshalText support for NullDecimal * Add tests for XML operations on NullDecimal --- decimal.go | 25 ++++++++++++++++++ decimal_test.go | 69 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+) diff --git a/decimal.go b/decimal.go index c0404f1d..524f2890 100644 --- a/decimal.go +++ b/decimal.go @@ -1307,6 +1307,31 @@ func (d NullDecimal) MarshalJSON() ([]byte, error) { return d.Decimal.MarshalJSON() } +// UnmarshalText implements the encoding.TextUnmarshaler interface for XML +// deserialization +func (d *NullDecimal) UnmarshalText(text []byte) error { + str := string(text) + + if str == "" { + d.Valid = false + return nil + } + if err := d.Decimal.UnmarshalText(text); err != nil { + return err + } + d.Valid = true + return nil +} + +// MarshalText implements the encoding.TextMarshaler interface for XML +// serialization. +func (d NullDecimal) MarshalText() (text []byte, err error) { + if !d.Valid { + return []byte{}, nil + } + return d.Decimal.MarshalText() +} + // Trig functions // Atan returns the arctangent, in radians, of x. diff --git a/decimal_test.go b/decimal_test.go index d7ec5e42..ad9d82b1 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -766,6 +766,75 @@ func TestBadXML(t *testing.T) { } } +func TestNullDecimalXML(t *testing.T) { + for _, x := range testTable { + s := x.short + var doc struct { + XMLName xml.Name `xml:"account"` + Amount NullDecimal `xml:"amount"` + } + docStr := `` + s + `` + err := xml.Unmarshal([]byte(docStr), &doc) + if err != nil { + t.Errorf("error unmarshaling %s: %v", docStr, err) + } else if doc.Amount.Decimal.String() != s { + t.Errorf("expected %s, got %s (%s, %d)", + s, doc.Amount.Decimal.String(), + doc.Amount.Decimal.value.String(), doc.Amount.Decimal.exp) + } + + out, err := xml.Marshal(&doc) + if err != nil { + t.Errorf("error marshaling %+v: %v", doc, err) + } else if string(out) != docStr { + t.Errorf("expected %s, got %s", docStr, string(out)) + } + } + + var doc struct { + XMLName xml.Name `xml:"account"` + Amount NullDecimal `xml:"amount"` + } + + docStr := `` + err := xml.Unmarshal([]byte(docStr), &doc) + if err != nil { + t.Errorf("error unmarshaling: %s: %v", docStr, err) + } else if doc.Amount.Valid { + t.Errorf("expected null value to have Valid = false, got Valid = true and Decimal = %s (%s, %d)", + doc.Amount.Decimal.String(), + doc.Amount.Decimal.value.String(), doc.Amount.Decimal.exp) + } + + expected := `` + out, err := xml.Marshal(&doc) + if err != nil { + t.Errorf("error marshaling %+v: %v", doc, err) + } else if string(out) != expected { + t.Errorf("expected %s, got %s", expected, string(out)) + } +} + +func TestNullDecimalBadXML(t *testing.T) { + for _, testCase := range []string{ + "o_o", + "7", + ``, + `nope`, + `0.333`, + } { + var doc struct { + XMLName xml.Name `xml:"account"` + Amount NullDecimal `xml:"amount"` + } + err := xml.Unmarshal([]byte(testCase), &doc) + if err == nil { + t.Errorf("expected error, got %+v", doc) + } + } +} + func TestDecimal_rescale(t *testing.T) { type Inp struct { int int64