diff --git a/desc/protoparse/linker.go b/desc/protoparse/linker.go index 49f91136..a9a3ba31 100644 --- a/desc/protoparse/linker.go +++ b/desc/protoparse/linker.go @@ -15,11 +15,12 @@ import ( ) type linker struct { - files map[string]*parseResult - filenames []string - errs *errorHandler - descriptorPool map[*dpb.FileDescriptorProto]map[string]proto.Message - extensions map[string]map[int32]string + files map[string]*parseResult + filenames []string + errs *errorHandler + descriptorPool map[*dpb.FileDescriptorProto]map[string]proto.Message + packageNamespaces map[*dpb.FileDescriptorProto]map[string]struct{} + extensions map[string]map[int32]string } func newLinker(files *parseResults, errs *errorHandler) *linker { @@ -85,12 +86,14 @@ func (l *linker) linkFiles() (map[string]*desc.FileDescriptor, error) { func (l *linker) createDescriptorPool() error { l.descriptorPool = map[*dpb.FileDescriptorProto]map[string]proto.Message{} + l.packageNamespaces = map[*dpb.FileDescriptorProto]map[string]struct{}{} for _, filename := range l.filenames { r := l.files[filename] fd := r.fd pool := map[string]proto.Message{} l.descriptorPool[fd] = pool prefix := fd.GetPackage() + l.packageNamespaces[fd] = namespacesFromPackage(prefix) if prefix != "" { prefix += "." } @@ -153,6 +156,23 @@ func (l *linker) createDescriptorPool() error { return nil } +func namespacesFromPackage(pkg string) map[string]struct{} { + if pkg == "" { + return nil + } + offs := 0 + pkgs := map[string]struct{}{} + pkgs[pkg] = struct{}{} + for { + pos := strings.IndexByte(pkg[offs:], '.') + if pos == -1 { + return pkgs + } + pkgs[pkg[:offs+pos]] = struct{}{} + offs = offs + pos + 1 + } +} + func addMessageToPool(r *parseResult, pool map[string]proto.Message, errs *errorHandler, prefix string, md *dpb.DescriptorProto) error { fqn := prefix + md.GetName() if err := addToPool(r, pool, errs, fqn, md); err != nil { @@ -313,7 +333,7 @@ func (l *linker) resolveEnumTypes(r *parseResult, fd *dpb.FileDescriptorProto, p func (l *linker) resolveMessageTypes(r *parseResult, fd *dpb.FileDescriptorProto, prefix string, md *dpb.DescriptorProto, scopes []scope) error { fqn := prefix + md.GetName() - scope := messageScope(fqn, isProto3(fd), l.descriptorPool[fd]) + scope := messageScope(fqn, isProto3(fd), l.descriptorPool[fd], l.packageNamespaces[fd]) scopes = append(scopes, scope) prefix = fqn + "." @@ -597,7 +617,6 @@ func fileScope(fd *dpb.FileDescriptorProto, l *linker) scope { // packages are a hierarchy like C++ namespaces) prefixes := internal.CreatePrefixList(fd.GetPackage()) return func(firstName, fullName string) (string, proto.Message, bool) { - prefixMatch := false for _, prefix := range prefixes { var n1, n string if prefix == "" { @@ -605,38 +624,22 @@ func fileScope(fd *dpb.FileDescriptorProto, l *linker) scope { n1, n = fullName, fullName } else { n = prefix + "." + fullName - if prefixMatch { - // it must be in this prefix, so search for full - // name (first name must be a package component) - n1 = n - } else { - n1 = prefix + "." + firstName - } + n1 = prefix + "." + firstName } d, proto3 := l.findSymbol(fd, n1, n, false, map[*dpb.FileDescriptorProto]struct{}{}) if d != nil { return n, d, proto3 } - if prefixMatch { - // we were supposed to find the symbol with this prefix but - // didn't, so we need to go ahead and return sentinel value - return n, sentinelMissingSymbol, false - } - if strings.HasSuffix(prefix, "."+firstName) { - // the first name matches the end of this prefix, which means - // the *next* scope is the match - prefixMatch = true - } } return "", nil, false } } -func messageScope(messageName string, proto3 bool, filePool map[string]proto.Message) scope { +func messageScope(messageName string, proto3 bool, filePool map[string]proto.Message, packages map[string]struct{}) scope { return func(firstName, fullName string) (string, proto.Message, bool) { n1 := messageName + "." + firstName n := messageName + "." + fullName - d := findSymbolInPool(n1, n, filePool) + d := findSymbolInPool(n1, n, filePool, packages) if d != nil { return n, d, proto3 } @@ -644,10 +647,16 @@ func messageScope(messageName string, proto3 bool, filePool map[string]proto.Mes } } -func findSymbolInPool(firstName, fullName string, pool map[string]proto.Message) proto.Message { +func findSymbolInPool(firstName, fullName string, pool map[string]proto.Message, pkgs map[string]struct{}) proto.Message { d, ok := pool[firstName] if !ok { - return nil + _, ok := pkgs[firstName] + if !ok { + return nil + } + // this sentinel means the name is a valid namespace but + // does not refer to a descriptor + d = sentinelMissingSymbol } if firstName == fullName { return d @@ -665,6 +674,12 @@ func findSymbolInPool(firstName, fullName string, pool map[string]proto.Message) } func isAggregateDescriptor(m proto.Message) bool { + if m == sentinelMissingSymbol { + // this indicates the name matched a package, not a + // descriptor, but a package is an aggregate so + // we return true + return true + } switch m.(type) { case *dpb.DescriptorProto, *dpb.EnumDescriptorProto, *dpb.ServiceDescriptorProto: return true @@ -685,7 +700,7 @@ func (l *linker) findSymbol(fd *dpb.FileDescriptorProto, firstName, fullName str return nil, false } checked[fd] = struct{}{} - d := findSymbolInPool(firstName, fullName, l.descriptorPool[fd]) + d := findSymbolInPool(firstName, fullName, l.descriptorPool[fd], l.packageNamespaces[fd]) if d != nil { return d, isProto3(fd) } diff --git a/desc/protoparse/linker_test.go b/desc/protoparse/linker_test.go index 9c50a9bf..6b0afcc9 100644 --- a/desc/protoparse/linker_test.go +++ b/desc/protoparse/linker_test.go @@ -89,6 +89,13 @@ func TestLinkerValidation(t *testing.T) { input map[string]string errMsg string }{ + { + map[string]string{ + "foo.proto": `syntax = "proto3"; package namespace.a; import "foo2.proto"; message Foo{ b.Bar b = 1; }`, + "foo2.proto": `syntax = "proto3"; package namespace.b; message Bar{}`, + }, + "", // should succeed + }, { map[string]string{ "foo.proto": "import \"foo2.proto\"; message fubar{}",