Skip to content

Commit

Permalink
proto: return an error instead of producing invalid wire format
Browse files Browse the repository at this point in the history
There currently is no risk of producing invalid wire format,
but that will change with subsequent changes regarding lazy decoding.

We have been running this change in production for about a month,
without ever triggering the check (until lazy decoding is involved).

related to golang/protobuf#1609

Change-Id: I3c5c956aee2fa81f99dea03ed2a977a1547081fc
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/579595
Auto-Submit: Michael Stapelberg <stapelberg@google.com>
Reviewed-by: Lasse Folger <lassefolger@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
  • Loading branch information
stapelberg authored and gopherbot committed Apr 18, 2024
1 parent 671c2db commit 94bb78c
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 7 deletions.
15 changes: 15 additions & 0 deletions internal/errors/errors.go
Expand Up @@ -87,3 +87,18 @@ func InvalidUTF8(name string) error {
func RequiredNotSet(name string) error {
return New("required field %v not set", name)
}

type SizeMismatchError struct {
Calculated, Measured int
}

func (e *SizeMismatchError) Error() string {
return fmt.Sprintf("size mismatch (see https://github.com/golang/protobuf/issues/1609): calculated=%d, measured=%d", e.Calculated, e.Measured)
}

func MismatchedSizeCalculation(calculated, measured int) error {
return &SizeMismatchError{
Calculated: calculated,
Measured: measured,
}
}
32 changes: 28 additions & 4 deletions internal/impl/codec_field.go
Expand Up @@ -233,9 +233,15 @@ func sizeMessageInfo(p pointer, f *coderFieldInfo, opts marshalOptions) int {
}

func appendMessageInfo(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
calculatedSize := f.mi.sizePointer(p.Elem(), opts)
b = protowire.AppendVarint(b, f.wiretag)
b = protowire.AppendVarint(b, uint64(f.mi.sizePointer(p.Elem(), opts)))
return f.mi.marshalAppendPointer(b, p.Elem(), opts)
b = protowire.AppendVarint(b, uint64(calculatedSize))
before := len(b)
b, err := f.mi.marshalAppendPointer(b, p.Elem(), opts)
if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
}
return b, err
}

func consumeMessageInfo(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
Expand Down Expand Up @@ -267,9 +273,15 @@ func sizeMessage(m proto.Message, tagsize int, _ marshalOptions) int {
}

func appendMessage(b []byte, m proto.Message, wiretag uint64, opts marshalOptions) ([]byte, error) {
calculatedSize := proto.Size(m)
b = protowire.AppendVarint(b, wiretag)
b = protowire.AppendVarint(b, uint64(proto.Size(m)))
return opts.Options().MarshalAppend(b, m)
b = protowire.AppendVarint(b, uint64(calculatedSize))
before := len(b)
b, err := opts.Options().MarshalAppend(b, m)
if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
}
return b, err
}

func consumeMessage(b []byte, m proto.Message, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
Expand Down Expand Up @@ -482,10 +494,14 @@ func appendMessageSliceInfo(b []byte, p pointer, f *coderFieldInfo, opts marshal
b = protowire.AppendVarint(b, f.wiretag)
siz := f.mi.sizePointer(v, opts)
b = protowire.AppendVarint(b, uint64(siz))
before := len(b)
b, err = f.mi.marshalAppendPointer(b, v, opts)
if err != nil {
return b, err
}
if measuredSize := len(b) - before; siz != measuredSize {
return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
}
}
return b, nil
}
Expand Down Expand Up @@ -538,10 +554,14 @@ func appendMessageSlice(b []byte, p pointer, wiretag uint64, goType reflect.Type
b = protowire.AppendVarint(b, wiretag)
siz := proto.Size(m)
b = protowire.AppendVarint(b, uint64(siz))
before := len(b)
b, err = opts.Options().MarshalAppend(b, m)
if err != nil {
return b, err
}
if measuredSize := len(b) - before; siz != measuredSize {
return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
}
}
return b, nil
}
Expand Down Expand Up @@ -599,11 +619,15 @@ func appendMessageSliceValue(b []byte, listv protoreflect.Value, wiretag uint64,
b = protowire.AppendVarint(b, wiretag)
siz := proto.Size(m)
b = protowire.AppendVarint(b, uint64(siz))
before := len(b)
var err error
b, err = mopts.MarshalAppend(b, m)
if err != nil {
return b, err
}
if measuredSize := len(b) - before; siz != measuredSize {
return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
}
}
return b, nil
}
Expand Down
15 changes: 13 additions & 2 deletions internal/impl/codec_map.go
Expand Up @@ -9,6 +9,7 @@ import (
"sort"

"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/reflect/protoreflect"
)
Expand Down Expand Up @@ -240,11 +241,16 @@ func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coder
size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
size += mapi.valFuncs.size(val, mapValTagSize, opts)
b = protowire.AppendVarint(b, uint64(size))
before := len(b)
b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
if err != nil {
return nil, err
}
return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
b, err = mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
if measuredSize := len(b) - before; size != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(size, measuredSize)
}
return b, err
} else {
key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
val := pointerOfValue(valrv)
Expand All @@ -259,7 +265,12 @@ func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coder
}
b = protowire.AppendVarint(b, mapi.valWiretag)
b = protowire.AppendVarint(b, uint64(valSize))
return f.mi.marshalAppendPointer(b, val, opts)
before := len(b)
b, err = f.mi.marshalAppendPointer(b, val, opts)
if measuredSize := len(b) - before; valSize != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(valSize, measuredSize)
}
return b, err
}
}

Expand Down
9 changes: 9 additions & 0 deletions proto/encode.go
Expand Up @@ -5,12 +5,17 @@
package proto

import (
"errors"
"fmt"

"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/order"
"google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"

protoerrors "google.golang.org/protobuf/internal/errors"
)

// MarshalOptions configures the marshaler.
Expand Down Expand Up @@ -175,6 +180,10 @@ func (o MarshalOptions) marshal(b []byte, m protoreflect.Message) (out protoifac
out.Buf, err = o.marshalMessageSlow(b, m)
}
if err != nil {
var mismatch *protoerrors.SizeMismatchError
if errors.As(err, &mismatch) {
return out, fmt.Errorf("marshaling %s: %v", string(m.Descriptor().FullName()), err)
}
return out, err
}
if allowPartial {
Expand Down
7 changes: 6 additions & 1 deletion proto/messageset.go
Expand Up @@ -47,11 +47,16 @@ func (o MarshalOptions) marshalMessageSet(b []byte, m protoreflect.Message) ([]b
func (o MarshalOptions) marshalMessageSetField(b []byte, fd protoreflect.FieldDescriptor, value protoreflect.Value) ([]byte, error) {
b = messageset.AppendFieldStart(b, fd.Number())
b = protowire.AppendTag(b, messageset.FieldMessage, protowire.BytesType)
b = protowire.AppendVarint(b, uint64(o.Size(value.Message().Interface())))
calculatedSize := o.Size(value.Message().Interface())
b = protowire.AppendVarint(b, uint64(calculatedSize))
before := len(b)
b, err := o.marshalMessage(b, value.Message())
if err != nil {
return b, err
}
if measuredSize := len(b) - before; calculatedSize != measuredSize {
return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
}
b = messageset.AppendFieldEnd(b)
return b, nil
}
Expand Down

0 comments on commit 94bb78c

Please sign in to comment.