Skip to content

Commit

Permalink
proto: ensure MarshalOptions are plumbed to all Size calls
Browse files Browse the repository at this point in the history
This used to not be necessary, but a subsequent change will change behavior
based on MarshalOptions.Deterministic, so it is important that we do not
accidentally drop MarshalOptions anywhere.

related to golang/protobuf#1609

Change-Id: I6b53352707d93939642a627eb41c930f8ac3157f
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/579995
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Lasse Folger <lassefolger@google.com>
  • Loading branch information
stapelberg committed Apr 18, 2024
1 parent 94bb78c commit d0f77ae
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 21 deletions.
36 changes: 21 additions & 15 deletions internal/impl/codec_field.go
Expand Up @@ -268,16 +268,17 @@ func isInitMessageInfo(p pointer, f *coderFieldInfo) error {
return f.mi.checkInitializedPointer(p.Elem())
}

func sizeMessage(m proto.Message, tagsize int, _ marshalOptions) int {
return protowire.SizeBytes(proto.Size(m)) + tagsize
func sizeMessage(m proto.Message, tagsize int, opts marshalOptions) int {
return protowire.SizeBytes(opts.Options().Size(m)) + tagsize
}

func appendMessage(b []byte, m proto.Message, wiretag uint64, opts marshalOptions) ([]byte, error) {
calculatedSize := proto.Size(m)
mopts := opts.Options()
calculatedSize := mopts.Size(m)
b = protowire.AppendVarint(b, wiretag)
b = protowire.AppendVarint(b, uint64(calculatedSize))
before := len(b)
b, err := opts.Options().MarshalAppend(b, m)
b, err := mopts.MarshalAppend(b, m)
if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
}
Expand Down Expand Up @@ -417,8 +418,8 @@ func consumeGroupType(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInf
return f.mi.unmarshalPointer(b, p.Elem(), f.num, opts)
}

func sizeGroup(m proto.Message, tagsize int, _ marshalOptions) int {
return 2*tagsize + proto.Size(m)
func sizeGroup(m proto.Message, tagsize int, opts marshalOptions) int {
return 2*tagsize + opts.Options().Size(m)
}

func appendGroup(b []byte, m proto.Message, wiretag uint64, opts marshalOptions) ([]byte, error) {
Expand Down Expand Up @@ -536,26 +537,28 @@ func isInitMessageSliceInfo(p pointer, f *coderFieldInfo) error {
return nil
}

func sizeMessageSlice(p pointer, goType reflect.Type, tagsize int, _ marshalOptions) int {
func sizeMessageSlice(p pointer, goType reflect.Type, tagsize int, opts marshalOptions) int {
mopts := opts.Options()
s := p.PointerSlice()
n := 0
for _, v := range s {
m := asMessage(v.AsValueOf(goType.Elem()))
n += protowire.SizeBytes(proto.Size(m)) + tagsize
n += protowire.SizeBytes(mopts.Size(m)) + tagsize
}
return n
}

func appendMessageSlice(b []byte, p pointer, wiretag uint64, goType reflect.Type, opts marshalOptions) ([]byte, error) {
mopts := opts.Options()
s := p.PointerSlice()
var err error
for _, v := range s {
m := asMessage(v.AsValueOf(goType.Elem()))
b = protowire.AppendVarint(b, wiretag)
siz := proto.Size(m)
siz := mopts.Size(m)
b = protowire.AppendVarint(b, uint64(siz))
before := len(b)
b, err = opts.Options().MarshalAppend(b, m)
b, err = mopts.MarshalAppend(b, m)
if err != nil {
return b, err
}
Expand Down Expand Up @@ -602,11 +605,12 @@ func isInitMessageSlice(p pointer, goType reflect.Type) error {
// Slices of messages

func sizeMessageSliceValue(listv protoreflect.Value, tagsize int, opts marshalOptions) int {
mopts := opts.Options()
list := listv.List()
n := 0
for i, llen := 0, list.Len(); i < llen; i++ {
m := list.Get(i).Message().Interface()
n += protowire.SizeBytes(proto.Size(m)) + tagsize
n += protowire.SizeBytes(mopts.Size(m)) + tagsize
}
return n
}
Expand All @@ -617,7 +621,7 @@ func appendMessageSliceValue(b []byte, listv protoreflect.Value, wiretag uint64,
for i, llen := 0, list.Len(); i < llen; i++ {
m := list.Get(i).Message().Interface()
b = protowire.AppendVarint(b, wiretag)
siz := proto.Size(m)
siz := mopts.Size(m)
b = protowire.AppendVarint(b, uint64(siz))
before := len(b)
var err error
Expand Down Expand Up @@ -675,11 +679,12 @@ var coderMessageSliceValue = valueCoderFuncs{
}

func sizeGroupSliceValue(listv protoreflect.Value, tagsize int, opts marshalOptions) int {
mopts := opts.Options()
list := listv.List()
n := 0
for i, llen := 0, list.Len(); i < llen; i++ {
m := list.Get(i).Message().Interface()
n += 2*tagsize + proto.Size(m)
n += 2*tagsize + mopts.Size(m)
}
return n
}
Expand Down Expand Up @@ -762,12 +767,13 @@ func makeGroupSliceFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type)
}
}

func sizeGroupSlice(p pointer, messageType reflect.Type, tagsize int, _ marshalOptions) int {
func sizeGroupSlice(p pointer, messageType reflect.Type, tagsize int, opts marshalOptions) int {
mopts := opts.Options()
s := p.PointerSlice()
n := 0
for _, v := range s {
m := asMessage(v.AsValueOf(messageType.Elem()))
n += 2*tagsize + proto.Size(m)
n += 2*tagsize + mopts.Size(m)
}
return n
}
Expand Down
28 changes: 22 additions & 6 deletions proto/encode.go
Expand Up @@ -75,6 +75,27 @@ type MarshalOptions struct {
UseCachedSize bool
}

// flags turns the specified MarshalOptions (user-facing) into
// protoiface.MarshalInputFlags (used internally by the marshaler).
//
// See impl.marshalOptions.Options for the inverse operation.
func (o MarshalOptions) flags() protoiface.MarshalInputFlags {
var flags protoiface.MarshalInputFlags

// Note: o.AllowPartial is always forced to true by MarshalOptions.marshal,
// which is why it is not a part of MarshalInputFlags.

if o.Deterministic {
flags |= protoiface.MarshalDeterministic
}

if o.UseCachedSize {
flags |= protoiface.MarshalUseCachedSize
}

return flags
}

// Marshal returns the wire-format encoding of m.
//
// This is the most common entry point for encoding a Protobuf message.
Expand Down Expand Up @@ -157,12 +178,7 @@ func (o MarshalOptions) marshal(b []byte, m protoreflect.Message) (out protoifac
in := protoiface.MarshalInput{
Message: m,
Buf: b,
}
if o.Deterministic {
in.Flags |= protoiface.MarshalDeterministic
}
if o.UseCachedSize {
in.Flags |= protoiface.MarshalUseCachedSize
Flags: o.flags(),
}
if methods.Size != nil {
sout := methods.Size(protoiface.SizeInput{
Expand Down
2 changes: 2 additions & 0 deletions proto/size.go
Expand Up @@ -34,6 +34,7 @@ func (o MarshalOptions) size(m protoreflect.Message) (size int) {
if methods != nil && methods.Size != nil {
out := methods.Size(protoiface.SizeInput{
Message: m,
Flags: o.flags(),
})
return out.Size
}
Expand All @@ -42,6 +43,7 @@ func (o MarshalOptions) size(m protoreflect.Message) (size int) {
// This case is mainly used for legacy types with a Marshal method.
out, _ := methods.Marshal(protoiface.MarshalInput{
Message: m,
Flags: o.flags(),
})
return len(out.Buf)
}
Expand Down

0 comments on commit d0f77ae

Please sign in to comment.