Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add slice range checks to UnmarshalBinary(). #232

Merged
merged 2 commits into from Jun 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 11 additions & 1 deletion decimal.go
Expand Up @@ -1152,12 +1152,22 @@ func (d Decimal) MarshalJSON() ([]byte, error) {
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. As a string representation
// is already used when encoding to text, this method stores that string as []byte
func (d *Decimal) UnmarshalBinary(data []byte) error {
// Verify we have at least 5 bytes, 4 for the exponent and at least 1 more
// for the GOB encoded value.
if len(data) < 5 {
return fmt.Errorf("error decoding binary %v: expected at least 5 bytes, got %d", data, len(data))
}

// Extract the exponent
d.exp = int32(binary.BigEndian.Uint32(data[:4]))

// Extract the value
d.value = new(big.Int)
return d.value.GobDecode(data[4:])
if err := d.value.GobDecode(data[4:]); err != nil {
return fmt.Errorf("error decoding binary %v: %s", data, err)
}

return nil
}

// MarshalBinary implements the encoding.BinaryMarshaler interface.
Expand Down
18 changes: 18 additions & 0 deletions decimal_test.go
Expand Up @@ -2798,6 +2798,24 @@ func TestBinary(t *testing.T) {
}
}

func TestBinary_DataTooShort(t *testing.T) {
var d Decimal

err := d.UnmarshalBinary(nil) // nil slice has length 0
if err == nil {
t.Fatalf("expected error, got %v", d)
}
}

func TestBinary_InvalidValue(t *testing.T) {
var d Decimal

err := d.UnmarshalBinary([]byte{0, 0, 0, 0, 'x'}) // valid exponent, invalid value
if err == nil {
t.Fatalf("expected error, got %v", d)
}
}

func slicesEqual(a, b []byte) bool {
for i, val := range a {
if b[i] != val {
Expand Down