From 93e724cbb0b46d4751b8e6b04974f11e3bbae75d Mon Sep 17 00:00:00 2001 From: Divjot Arora Date: Wed, 24 Mar 2021 20:54:06 -0400 Subject: [PATCH] Error if BSON cstrings contain null bytes --- bson/bsonrw/value_writer.go | 15 ++++++++++++++- bson/extjson_prose_test.go | 8 ++++++++ bson/marshal_test.go | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 1 deletion(-) diff --git a/bson/bsonrw/value_writer.go b/bson/bsonrw/value_writer.go index 3717198366..7b7d7ad3f2 100644 --- a/bson/bsonrw/value_writer.go +++ b/bson/bsonrw/value_writer.go @@ -12,6 +12,7 @@ import ( "io" "math" "strconv" + "strings" "sync" "go.mongodb.org/mongo-driver/bson/bsontype" @@ -247,7 +248,12 @@ func (vw *valueWriter) invalidTransitionError(destination mode, name string, mod func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error { switch vw.stack[vw.frame].mode { case mElement: - vw.buf = bsoncore.AppendHeader(vw.buf, t, vw.stack[vw.frame].key) + key := vw.stack[vw.frame].key + if !isValidCString(key) { + return errors.New("BSON element key cannot contain null bytes") + } + + vw.buf = bsoncore.AppendHeader(vw.buf, t, key) case mValue: // TODO: Do this with a cache of the first 1000 or so array keys. vw.buf = bsoncore.AppendHeader(vw.buf, t, strconv.Itoa(vw.stack[vw.frame].arrkey)) @@ -430,6 +436,9 @@ func (vw *valueWriter) WriteObjectID(oid primitive.ObjectID) error { } func (vw *valueWriter) WriteRegex(pattern string, options string) error { + if !isValidCString(pattern) || !isValidCString(options) { + return errors.New("BSON regex values cannot contain null bytes") + } if err := vw.writeElementHeader(bsontype.Regex, mode(0), "WriteRegex"); err != nil { return err } @@ -602,3 +611,7 @@ func (vw *valueWriter) writeLength() error { vw.buf[start+3] = byte(length >> 24) return nil } + +func isValidCString(cs string) bool { + return !strings.ContainsRune(cs, '\x00') +} diff --git a/bson/extjson_prose_test.go b/bson/extjson_prose_test.go index 211ab5e699..6b884e667e 100644 --- a/bson/extjson_prose_test.go +++ b/bson/extjson_prose_test.go @@ -45,3 +45,11 @@ func TestExtJSON(t *testing.T) { }) } } + +func TestExtJSONNullBytes(t *testing.T) { + t.Run("element keys", func(t *testing.T) { + doc := D{{"a\x00", "foo"}} + res, err := MarshalExtJSON(doc, false, false) + assert.NotNil(t, err, "expected MarshalExtJSON error but got nil with result %v", string(res)) + }) +} diff --git a/bson/marshal_test.go b/bson/marshal_test.go index 7e570676b9..319870522d 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -8,6 +8,7 @@ package bson import ( "bytes" + "errors" "fmt" "reflect" "testing" @@ -267,3 +268,35 @@ func TestCachingEncodersNotSharedAcrossRegistries(t *testing.T) { }) }) } + +func TestNullBytes(t *testing.T) { + t.Run("element keys", func(t *testing.T) { + doc := D{{"a\x00", "foobar"}} + res, err := Marshal(doc) + want := errors.New("BSON element key cannot contain null bytes") + assert.Equal(t, want, err, "expected Marshal error %v, got error %v with result %q", want, err, Raw(res)) + }) + + t.Run("regex values", func(t *testing.T) { + wantErr := errors.New("BSON regex values cannot contain null bytes") + + testCases := []struct { + name string + pattern string + options string + }{ + {"null bytes in pattern", "a\x00", "i"}, + {"null bytes in options", "pattern", "i\x00"}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + regex := primitive.Regex{ + Pattern: tc.pattern, + Options: tc.options, + } + res, err := Marshal(D{{"foo", regex}}) + assert.Equal(t, wantErr, err, "expected Marshal error %v, got error %v with result %q", wantErr, err, Raw(res)) + }) + } + }) +}