Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor protoparse, fixes multiple issues #316

Merged
merged 2 commits into from
Apr 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 6 additions & 1 deletion desc/builder/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ func (flb *FieldBuilder) GetExtendeeTypeName() string {
}
}

func (flb *FieldBuilder) buildProto(path []int32, sourceInfo *dpb.SourceCodeInfo) (*dpb.FieldDescriptorProto, error) {
func (flb *FieldBuilder) buildProto(path []int32, sourceInfo *dpb.SourceCodeInfo, isMessageSet bool) (*dpb.FieldDescriptorProto, error) {
addCommentsTo(sourceInfo, path, &flb.comments)

var lbl *dpb.FieldDescriptorProto_Label
Expand All @@ -508,6 +508,11 @@ func (flb *FieldBuilder) buildProto(path []int32, sourceInfo *dpb.SourceCodeInfo
def = proto.String(flb.Default)
}

maxTag := internal.GetMaxTag(isMessageSet)
if flb.number > maxTag {
return nil, fmt.Errorf("tag for field %s cannot be above max %d", GetFullyQualifiedName(flb), maxTag)
}

return &dpb.FieldDescriptorProto{
Name: proto.String(flb.name),
Number: proto.Int32(flb.number),
Expand Down
9 changes: 8 additions & 1 deletion desc/builder/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ func (fb *FileBuilder) buildProto(deps []*desc.FileDescriptor) (*dpb.FileDescrip
extensions := make([]*dpb.FieldDescriptorProto, 0, len(fb.extensions))
for _, exb := range fb.extensions {
path := append(path, internal.File_extensionsTag, int32(len(extensions)))
if exd, err := exb.buildProto(path, &sourceInfo); err != nil {
if exd, err := exb.buildProto(path, &sourceInfo, isExtendeeMessageSet(exb)); err != nil {
return nil, err
} else {
extensions = append(extensions, exd)
Expand Down Expand Up @@ -699,6 +699,13 @@ func (fb *FileBuilder) buildProto(deps []*desc.FileDescriptor) (*dpb.FileDescrip
}, nil
}

func isExtendeeMessageSet(flb *FieldBuilder) bool {
if flb.localExtendee != nil {
return flb.localExtendee.Options.GetMessageSetWireFormat()
}
return flb.foreignExtendee.GetMessageOptions().GetMessageSetWireFormat()
}

// Build constructs a file descriptor based on the contents of this file
// builder. If there are any problems constructing the descriptor, including
// resolving symbols referenced by the builder or failing to meet certain
Expand Down
6 changes: 3 additions & 3 deletions desc/builder/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ func (mb *MessageBuilder) buildProto(path []int32, sourceInfo *dpb.SourceCodeInf
for _, b := range mb.fieldsAndOneOfs {
if flb, ok := b.(*FieldBuilder); ok {
fldpath := append(path, internal.Message_fieldsTag, int32(len(fields)))
fld, err := flb.buildProto(fldpath, sourceInfo)
fld, err := flb.buildProto(fldpath, sourceInfo, mb.Options.GetMessageSetWireFormat())
if err != nil {
return nil, err
}
Expand All @@ -729,7 +729,7 @@ func (mb *MessageBuilder) buildProto(path []int32, sourceInfo *dpb.SourceCodeInf
oneOfs = append(oneOfs, ood)
for _, flb := range oob.choices {
path := append(path, internal.Message_fieldsTag, int32(len(fields)))
fld, err := flb.buildProto(path, sourceInfo)
fld, err := flb.buildProto(path, sourceInfo, mb.Options.GetMessageSetWireFormat())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -775,7 +775,7 @@ func (mb *MessageBuilder) buildProto(path []int32, sourceInfo *dpb.SourceCodeInf
nestedExtensions := make([]*dpb.FieldDescriptorProto, 0, len(mb.nestedExtensions))
for _, exb := range mb.nestedExtensions {
path := append(path, internal.Message_extensionsTag, int32(len(nestedExtensions)))
if exd, err := exb.buildProto(path, sourceInfo); err != nil {
if exd, err := exb.buildProto(path, sourceInfo, isExtendeeMessageSet(exb)); err != nil {
return nil, err
} else {
nestedExtensions = append(nestedExtensions, exd)
Expand Down
22 changes: 20 additions & 2 deletions desc/internal/util.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
package internal

import (
"math"
"unicode"
"unicode/utf8"
)

const (
// MaxTag is the maximum allowed tag number for a field.
MaxTag = 536870911 // 2^29 - 1
// MaxNormalTag is the maximum allowed tag number for a field in a normal message.
MaxNormalTag = 536870911 // 2^29 - 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I might switch this from MaxNormalTag/MaxTag to MaxTag/MaxTagMessageSetWireFormat since the message set max tag is the edge case, but more a matter of preference

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted a name like MaxTag to be the actual maximum value -- the greater of the two -- because I thought that would be more clear.

Maybe MaxNormalTag, MaxMessageSetTag, and MaxTag = MaxMessageSetTag?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea that works! Or keep as is, either way :-)

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In addition to adding GetMaxTag(messageSet bool), I also left these other constants, but tweaked them per discussion above.


// MaxMessageSetTag is the maximum allowed tag number of a field in a message that
// uses the message set wire format.
MaxMessageSetTag = math.MaxInt32 - 1

// MaxTag is the maximum allowed tag number. (It is the same as MaxMessageSetTag
// since that is the absolute highest allowed.)
MaxTag = MaxMessageSetTag

// SpecialReservedStart is the first tag in a range that is reserved and not
// allowed for use in message definitions.
Expand Down Expand Up @@ -268,3 +277,12 @@ func CreatePrefixList(pkg string) []string {

return prefixes
}

// GetMaxTag returns the max tag number allowed, based on whether a message uses
// message set wire format or not.
func GetMaxTag(isMessageSet bool) int32 {
if isMessageSet {
return MaxMessageSetTag
}
return MaxNormalTag
}
69 changes: 65 additions & 4 deletions desc/protoparse/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,33 +373,65 @@ func (n *compoundStringNode) value() interface{} {
return n.val
}

type intLiteral interface {
asInt32(min, max int32) (int32, bool)
value() interface{}
}

type intLiteralNode struct {
basicNode
val uint64
}

var _ intLiteral = (*intLiteralNode)(nil)

func (n *intLiteralNode) value() interface{} {
return n.val
}

func (n *intLiteralNode) asInt32(min, max int32) (int32, bool) {
if (min >= 0 && n.val < uint64(min)) || n.val > uint64(max) {
return 0, false
}
return int32(n.val), true
}

type compoundUintNode struct {
basicCompositeNode
val uint64
}

var _ intLiteral = (*compoundUintNode)(nil)

func (n *compoundUintNode) value() interface{} {
return n.val
}

func (n *compoundUintNode) asInt32(min, max int32) (int32, bool) {
if (min >= 0 && n.val < uint64(min)) || n.val > uint64(max) {
return 0, false
}
return int32(n.val), true
}

type compoundIntNode struct {
basicCompositeNode
val int64
}

var _ intLiteral = (*compoundIntNode)(nil)

func (n *compoundIntNode) value() interface{} {
return n.val
}

func (n *compoundIntNode) asInt32(min, max int32) (int32, bool) {
if n.val < int64(min) || n.val > int64(max) {
return 0, false
}
return int32(n.val), true
}

type floatLiteralNode struct {
basicNode
val float64
Expand Down Expand Up @@ -728,16 +760,45 @@ type extensionRangeNode struct {

type rangeNode struct {
basicCompositeNode
stNode, enNode node
st, en int32
startNode, endNode node
endMax bool
}

func (n *rangeNode) rangeStart() node {
return n.stNode
return n.startNode
}

func (n *rangeNode) rangeEnd() node {
return n.enNode
if n.endNode == nil {
return n.startNode
}
return n.endNode
}

func (n *rangeNode) startValue() interface{} {
return n.startNode.(intLiteral).value()
}

func (n *rangeNode) startValueAsInt32(min, max int32) (int32, bool) {
return n.startNode.(intLiteral).asInt32(min, max)
}

func (n *rangeNode) endValue() interface{} {
l, ok := n.endNode.(intLiteral)
if !ok {
return nil
}
return l.value()
}

func (n *rangeNode) endValueAsInt32(min, max int32) (int32, bool) {
if n.endMax {
return max, true
}
if n.endNode == nil {
return n.startValueAsInt32(min, max)
}
return n.endNode.(intLiteral).asInt32(min, max)
}

type reservedNode struct {
Expand Down