From 6cc1efa697caed9b0ed8a0f3b093fbe7da7125a6 Mon Sep 17 00:00:00 2001 From: Joshua Humphries Date: Wed, 16 Jun 2021 17:21:23 -0400 Subject: [PATCH] protoparse: report warnings when a file has unused imports (#403) --- desc/protoparse/errors.go | 19 +++++ desc/protoparse/linker.go | 67 +++++++++++++-- desc/protoparse/options.go | 79 +++++++++++++----- desc/protoparse/parser.go | 9 +- desc/protoparse/reporting_test.go | 134 +++++++++++++++++++++++++----- 5 files changed, 258 insertions(+), 50 deletions(-) diff --git a/desc/protoparse/errors.go b/desc/protoparse/errors.go index 77ead4d0..00ff6209 100644 --- a/desc/protoparse/errors.go +++ b/desc/protoparse/errors.go @@ -154,3 +154,22 @@ func (e errorWithFilename) Error() string { func (e errorWithFilename) Unwrap() error { return e.underlying } + +// ErrorUnusedImport may be passed to a warning reporter when an unused +// import is detected. The error the reporter receives will be wrapped +// with source position that indicates the file and line where the import +// statement appeared. +type ErrorUnusedImport interface { + error + UnusedImport() string +} + +type errUnusedImport string + +func (e errUnusedImport) Error() string { + return fmt.Sprintf("import %q not used", string(e)) +} + +func (e errUnusedImport) UnusedImport() string { + return string(e) +} diff --git a/desc/protoparse/linker.go b/desc/protoparse/linker.go index ab729333..0301541c 100644 --- a/desc/protoparse/linker.go +++ b/desc/protoparse/linker.go @@ -21,6 +21,7 @@ type linker struct { descriptorPool map[*dpb.FileDescriptorProto]map[string]proto.Message packageNamespaces map[*dpb.FileDescriptorProto]map[string]struct{} extensions map[string]map[int32]string + usedImports map[*dpb.FileDescriptorProto]map[string]struct{} } func newLinker(files *parseResults, errs *errorHandler) *linker { @@ -65,7 +66,7 @@ func (l *linker) linkFiles() (map[string]*desc.FileDescriptor, error) { // options that remain. for _, r := range l.files { fd := linked[r.fd.GetName()] - if err := interpretFileOptions(r, richFileDescriptorish{FileDescriptor: fd}); err != nil { + if err := interpretFileOptions(l, r, richFileDescriptorish{FileDescriptor: fd}); err != nil { return nil, err } // we should now have any message_set_wire_format options parsed @@ -288,6 +289,7 @@ func descriptorType(m proto.Message) string { func (l *linker) resolveReferences() error { l.extensions = map[string]map[int32]string{} + l.usedImports = map[*dpb.FileDescriptorProto]map[string]struct{}{} for _, filename := range l.filenames { r := l.files[filename] fd := r.fd @@ -577,7 +579,7 @@ opts: func (l *linker) resolve(fd *dpb.FileDescriptorProto, name string, onlyTypes bool, scopes []scope) (fqn string, element proto.Message, proto3 bool) { if strings.HasPrefix(name, ".") { // already fully-qualified - d, proto3 := l.findSymbol(fd, name[1:], false, map[*dpb.FileDescriptorProto]struct{}{}) + d, proto3 := l.findSymbol(fd, name[1:]) if d != nil { return name[1:], d, proto3 } @@ -629,7 +631,7 @@ func fileScope(fd *dpb.FileDescriptorProto, l *linker) scope { // packages are a hierarchy like C++ namespaces) prefixes := internal.CreatePrefixList(fd.GetPackage()) querySymbol := func(n string) (d proto.Message, isProto3 bool) { - return l.findSymbol(fd, n, false, map[*dpb.FileDescriptorProto]struct{}{}) + return l.findSymbol(fd, n) } return func(firstName, fullName string) (string, proto.Message, bool) { for _, prefix := range prefixes { @@ -699,6 +701,15 @@ func (l *linker) findSymbolInFile(name string, fd *dpb.FileDescriptorProto) prot return nil } +func (l *linker) markUsed(entryPoint, used *dpb.FileDescriptorProto) { + importsForFile := l.usedImports[entryPoint] + if importsForFile == nil { + importsForFile = map[string]struct{}{} + l.usedImports[entryPoint] = importsForFile + } + importsForFile[used.GetName()] = struct{}{} +} + func isAggregateDescriptor(m proto.Message) bool { if m == sentinelMissingSymbol { // this indicates the name matched a package, not a @@ -720,7 +731,11 @@ func isAggregateDescriptor(m proto.Message) bool { // definitively does not exist". var sentinelMissingSymbol = (*dpb.DescriptorProto)(nil) -func (l *linker) findSymbol(fd *dpb.FileDescriptorProto, name string, public bool, checked map[*dpb.FileDescriptorProto]struct{}) (element proto.Message, proto3 bool) { +func (l *linker) findSymbol(fd *dpb.FileDescriptorProto, name string) (element proto.Message, proto3 bool) { + return l.findSymbolRecursive(fd, fd, name, false, map[*dpb.FileDescriptorProto]struct{}{}) +} + +func (l *linker) findSymbolRecursive(entryPoint, fd *dpb.FileDescriptorProto, name string, public bool, checked map[*dpb.FileDescriptorProto]struct{}) (element proto.Message, proto3 bool) { if _, ok := checked[fd]; ok { // already checked this one return nil, false @@ -741,7 +756,8 @@ func (l *linker) findSymbol(fd *dpb.FileDescriptorProto, name string, public boo // we'll catch this error later continue } - if d, proto3 := l.findSymbol(depres.fd, name, true, checked); d != nil { + if d, proto3 := l.findSymbolRecursive(entryPoint, depres.fd, name, true, checked); d != nil { + l.markUsed(entryPoint, depres.fd) return d, proto3 } } @@ -752,7 +768,8 @@ func (l *linker) findSymbol(fd *dpb.FileDescriptorProto, name string, public boo // we'll catch this error later continue } - if d, proto3 := l.findSymbol(depres.fd, name, true, checked); d != nil { + if d, proto3 := l.findSymbolRecursive(entryPoint, depres.fd, name, true, checked); d != nil { + l.markUsed(entryPoint, depres.fd) return d, proto3 } } @@ -855,3 +872,41 @@ func (l *linker) linkFile(name string, rootImportLoc *SourcePos, seen []string, linked[name] = lfd return lfd, nil } + +func (l *linker) checkForUnusedImports(filename string) { + r := l.files[filename] + usedImports := l.usedImports[r.fd] + node := r.nodes[r.fd] + fileNode, _ := node.(*ast.FileNode) + for i, dep := range r.fd.Dependency { + if _, ok := usedImports[dep]; !ok { + isPublic := false + // it's fine if it's a public import + for _, j := range r.fd.PublicDependency { + if i == int(j) { + isPublic = true + break + } + } + if isPublic { + break + } + var pos *SourcePos + if fileNode != nil { + for _, decl := range fileNode.Decls { + imp, ok := decl.(*ast.ImportNode) + if !ok { + continue + } + if imp.Name.AsString() == dep { + pos = imp.Start() + } + } + } + if pos == nil { + pos = ast.UnknownPos(r.fd.GetName()) + } + r.errs.warn(pos, errUnusedImport(dep)) + } + } +} diff --git a/desc/protoparse/options.go b/desc/protoparse/options.go index 23b06d87..1d7ac000 100644 --- a/desc/protoparse/options.go +++ b/desc/protoparse/options.go @@ -679,11 +679,11 @@ func (er extRangeDescriptorish) GetExtensionRangeOptions() *dpb.ExtensionRangeOp return er.er.GetOptions() } -func interpretFileOptions(r *parseResult, fd fileDescriptorish) error { +func interpretFileOptions(l *linker, r *parseResult, fd fileDescriptorish) error { opts := fd.GetFileOptions() if opts != nil { if len(opts.UninterpretedOption) > 0 { - if remain, err := interpretOptions(r, fd, opts, opts.UninterpretedOption); err != nil { + if remain, err := interpretOptions(l, r, fd, opts, opts.UninterpretedOption); err != nil { return err } else { opts.UninterpretedOption = remain @@ -691,24 +691,24 @@ func interpretFileOptions(r *parseResult, fd fileDescriptorish) error { } } for _, md := range fd.GetMessageTypes() { - if err := interpretMessageOptions(r, md); err != nil { + if err := interpretMessageOptions(l, r, md); err != nil { return err } } for _, fld := range fd.GetExtensions() { - if err := interpretFieldOptions(r, fld); err != nil { + if err := interpretFieldOptions(l, r, fld); err != nil { return err } } for _, ed := range fd.GetEnumTypes() { - if err := interpretEnumOptions(r, ed); err != nil { + if err := interpretEnumOptions(l, r, ed); err != nil { return err } } for _, sd := range fd.GetServices() { opts := sd.GetServiceOptions() if len(opts.GetUninterpretedOption()) > 0 { - if remain, err := interpretOptions(r, sd, opts, opts.UninterpretedOption); err != nil { + if remain, err := interpretOptions(l, r, sd, opts, opts.UninterpretedOption); err != nil { return err } else { opts.UninterpretedOption = remain @@ -717,7 +717,7 @@ func interpretFileOptions(r *parseResult, fd fileDescriptorish) error { for _, mtd := range sd.GetMethods() { opts := mtd.GetMethodOptions() if len(opts.GetUninterpretedOption()) > 0 { - if remain, err := interpretOptions(r, mtd, opts, opts.UninterpretedOption); err != nil { + if remain, err := interpretOptions(l, r, mtd, opts, opts.UninterpretedOption); err != nil { return err } else { opts.UninterpretedOption = remain @@ -728,11 +728,11 @@ func interpretFileOptions(r *parseResult, fd fileDescriptorish) error { return nil } -func interpretMessageOptions(r *parseResult, md msgDescriptorish) error { +func interpretMessageOptions(l *linker, r *parseResult, md msgDescriptorish) error { opts := md.GetMessageOptions() if opts != nil { if len(opts.UninterpretedOption) > 0 { - if remain, err := interpretOptions(r, md, opts, opts.UninterpretedOption); err != nil { + if remain, err := interpretOptions(l, r, md, opts, opts.UninterpretedOption); err != nil { return err } else { opts.UninterpretedOption = remain @@ -740,14 +740,14 @@ func interpretMessageOptions(r *parseResult, md msgDescriptorish) error { } } for _, fld := range md.GetFields() { - if err := interpretFieldOptions(r, fld); err != nil { + if err := interpretFieldOptions(l, r, fld); err != nil { return err } } for _, ood := range md.GetOneOfs() { opts := ood.GetOneOfOptions() if len(opts.GetUninterpretedOption()) > 0 { - if remain, err := interpretOptions(r, ood, opts, opts.UninterpretedOption); err != nil { + if remain, err := interpretOptions(l, r, ood, opts, opts.UninterpretedOption); err != nil { return err } else { opts.UninterpretedOption = remain @@ -755,14 +755,14 @@ func interpretMessageOptions(r *parseResult, md msgDescriptorish) error { } } for _, fld := range md.GetNestedExtensions() { - if err := interpretFieldOptions(r, fld); err != nil { + if err := interpretFieldOptions(l, r, fld); err != nil { return err } } for _, er := range md.GetExtensionRanges() { opts := er.GetExtensionRangeOptions() if len(opts.GetUninterpretedOption()) > 0 { - if remain, err := interpretOptions(r, er, opts, opts.UninterpretedOption); err != nil { + if remain, err := interpretOptions(l, r, er, opts, opts.UninterpretedOption); err != nil { return err } else { opts.UninterpretedOption = remain @@ -770,19 +770,19 @@ func interpretMessageOptions(r *parseResult, md msgDescriptorish) error { } } for _, nmd := range md.GetNestedMessageTypes() { - if err := interpretMessageOptions(r, nmd); err != nil { + if err := interpretMessageOptions(l, r, nmd); err != nil { return err } } for _, ed := range md.GetNestedEnumTypes() { - if err := interpretEnumOptions(r, ed); err != nil { + if err := interpretEnumOptions(l, r, ed); err != nil { return err } } return nil } -func interpretFieldOptions(r *parseResult, fld fldDescriptorish) error { +func interpretFieldOptions(l *linker, r *parseResult, fld fldDescriptorish) error { opts := fld.GetFieldOptions() if len(opts.GetUninterpretedOption()) > 0 { uo := opts.UninterpretedOption @@ -824,7 +824,7 @@ func interpretFieldOptions(r *parseResult, fld fldDescriptorish) error { if len(uo) == 0 { // no real options, only pseudo-options above? clear out options fld.AsFieldDescriptorProto().Options = nil - } else if remain, err := interpretOptions(r, fld, opts, uo); err != nil { + } else if remain, err := interpretOptions(l, r, fld, opts, uo); err != nil { return err } else { opts.UninterpretedOption = remain @@ -898,11 +898,11 @@ func encodeDefaultBytes(b []byte) string { return buf.String() } -func interpretEnumOptions(r *parseResult, ed enumDescriptorish) error { +func interpretEnumOptions(l *linker, r *parseResult, ed enumDescriptorish) error { opts := ed.GetEnumOptions() if opts != nil { if len(opts.UninterpretedOption) > 0 { - if remain, err := interpretOptions(r, ed, opts, opts.UninterpretedOption); err != nil { + if remain, err := interpretOptions(l, r, ed, opts, opts.UninterpretedOption); err != nil { return err } else { opts.UninterpretedOption = remain @@ -912,7 +912,7 @@ func interpretEnumOptions(r *parseResult, ed enumDescriptorish) error { for _, evd := range ed.GetValues() { opts := evd.GetEnumValueOptions() if len(opts.GetUninterpretedOption()) > 0 { - if remain, err := interpretOptions(r, evd, opts, opts.UninterpretedOption); err != nil { + if remain, err := interpretOptions(l, r, evd, opts, opts.UninterpretedOption); err != nil { return err } else { opts.UninterpretedOption = remain @@ -922,8 +922,8 @@ func interpretEnumOptions(r *parseResult, ed enumDescriptorish) error { return nil } -func interpretOptions(res *parseResult, element descriptorish, opts proto.Message, uninterpreted []*dpb.UninterpretedOption) ([]*dpb.UninterpretedOption, error) { - optsd, err := desc.LoadMessageDescriptorForMessage(opts) +func interpretOptions(l *linker, res *parseResult, element descriptorish, opts proto.Message, uninterpreted []*dpb.UninterpretedOption) ([]*dpb.UninterpretedOption, error) { + optsd, err := loadMessageDescriptorForOptions(l, element.GetFile(), opts) if err != nil { if res.lenient { return uninterpreted, nil @@ -993,7 +993,7 @@ func interpretOptions(res *parseResult, element descriptorish, opts proto.Messag } } - // nw try to convert into the passed in message and fail if not successful + // now try to convert into the passed in message and fail if not successful if err := dm.ConvertToDeterministic(opts); err != nil { node := res.nodes[element.AsProto()] return nil, res.errs.handleError(ErrorWithSourcePos{Pos: node.Start(), Underlying: err}) @@ -1002,6 +1002,39 @@ func interpretOptions(res *parseResult, element descriptorish, opts proto.Messag return nil, nil } +func loadMessageDescriptorForOptions(l *linker, fd fileDescriptorish, opts proto.Message) (*desc.MessageDescriptor, error) { + // see if the file imports a custom version of descriptor.proto + fqn := proto.MessageName(opts) + d := findMessageDescriptorForOptions(l, fd, fqn) + if d != nil { + return d, nil + } + // fall back to built-in options descriptors + return desc.LoadMessageDescriptorForMessage(opts) +} + +func findMessageDescriptorForOptions(l *linker, fd fileDescriptorish, messageName string) *desc.MessageDescriptor { + d := fd.FindSymbol(messageName) + if d != nil { + md, _ := d.(*desc.MessageDescriptor) + return md + } + + // TODO: should this support public imports and be recursive? + for _, dep := range fd.GetDependencies() { + d := dep.FindSymbol(messageName) + if d != nil { + if l != nil { + l.markUsed(fd.AsProto().(*dpb.FileDescriptorProto), d.GetFile().AsFileDescriptorProto()) + } + md, _ := d.(*desc.MessageDescriptor) + return md + } + } + + return nil +} + func interpretField(res *parseResult, mc *messageContext, element descriptorish, dm *dynamic.Message, opt *dpb.UninterpretedOption, nameIndex int, pathPrefix []int32) (path []int32, err error) { var fld *desc.FieldDescriptor nm := opt.GetName()[nameIndex] diff --git a/desc/protoparse/parser.go b/desc/protoparse/parser.go index bbfec8d1..bdc000c2 100644 --- a/desc/protoparse/parser.go +++ b/desc/protoparse/parser.go @@ -228,10 +228,15 @@ func (p Parser) ParseFiles(filenames ...string) ([]*desc.FileDescriptor, error) // TODO: if this re-writes one of the names in filenames, lookups below will break protos = fixupFilenames(protos) } - linkedProtos, err := newLinker(results, errs).linkFiles() + l := newLinker(results, errs) + linkedProtos, err := l.linkFiles() if err != nil { return nil, err } + // Now we're done linking, so we can check to see if any imports were unused + for _, file := range filenames { + l.checkForUnusedImports(file) + } if p.IncludeSourceCodeInfo { for name, fd := range linkedProtos { pr := protos[name] @@ -316,7 +321,7 @@ func (p Parser) ParseFilesButDoNotLink(filenames ...string) ([]*dpb.FileDescript pr.errs.errReporter = func(err ErrorWithPos) error { return err } - _ = interpretFileOptions(pr, poorFileDescriptorish{FileDescriptorProto: fd}) + _ = interpretFileOptions(nil, pr, poorFileDescriptorish{FileDescriptorProto: fd}) } if p.IncludeSourceCodeInfo { fd.SourceCodeInfo = pr.generateSourceCodeInfo() diff --git a/desc/protoparse/reporting_test.go b/desc/protoparse/reporting_test.go index f8a870c6..87575782 100644 --- a/desc/protoparse/reporting_test.go +++ b/desc/protoparse/reporting_test.go @@ -235,38 +235,134 @@ func TestWarningReporting(t *testing.T) { } testCases := []struct { - source string + name string + sources map[string]string expectedNotices []string }{ { - source: `syntax = "proto2"; message Foo {}`, + name: "syntax proto2", + sources: map[string]string{ + "test.proto": `syntax = "proto2"; message Foo {}`, + }, }, { - source: `syntax = "proto3"; message Foo {}`, + name: "syntax proto3", + sources: map[string]string{ + "test.proto": `syntax = "proto3"; message Foo {}`, + }, }, { - source: `message Foo {}`, + name: "no syntax", + sources: map[string]string{ + "test.proto": `message Foo {}`, + }, expectedNotices: []string{ "test.proto:1:1: no syntax specified; defaulting to proto2 syntax", }, }, + { + name: "used import", + sources: map[string]string{ + "test.proto": `syntax = "proto3"; import "foo.proto"; message Foo { Bar bar = 1; }`, + "foo.proto": `syntax = "proto3"; message Bar { string name = 1; }`, + }, + }, + { + name: "used public import", + sources: map[string]string{ + "test.proto": `syntax = "proto3"; import "foo.proto"; message Foo { Bar bar = 1; }`, + // we're only asking to compile test.proto, so we won't report unused import for baz.proto + "foo.proto": `syntax = "proto3"; import public "bar.proto"; import "baz.proto";`, + "bar.proto": `syntax = "proto3"; message Bar { string name = 1; }`, + "baz.proto": `syntax = "proto3"; message Baz { }`, + }, + }, + { + name: "used nested public import", + sources: map[string]string{ + "test.proto": `syntax = "proto3"; import "foo.proto"; message Foo { Bar bar = 1; }`, + "foo.proto": `syntax = "proto3"; import public "baz.proto";`, + "baz.proto": `syntax = "proto3"; import public "bar.proto";`, + "bar.proto": `syntax = "proto3"; message Bar { string name = 1; }`, + }, + }, + { + name: "unused import", + sources: map[string]string{ + "test.proto": `syntax = "proto3"; import "foo.proto"; message Foo { string name = 1; }`, + "foo.proto": `syntax = "proto3"; message Bar { string name = 1; }`, + }, + expectedNotices: []string{ + `test.proto:1:20: import "foo.proto" not used`, + }, + }, + { + name: "multiple unused imports", + sources: map[string]string{ + "test.proto": `syntax = "proto3"; import "foo.proto"; import "bar.proto"; import "baz.proto"; message Test { Bar bar = 1; }`, + "foo.proto": `syntax = "proto3"; message Foo {};`, + "bar.proto": `syntax = "proto3"; message Bar {};`, + "baz.proto": `syntax = "proto3"; message Baz {};`, + }, + expectedNotices: []string{ + `test.proto:1:20: import "foo.proto" not used`, + `test.proto:1:60: import "baz.proto" not used`, + }, + }, + { + name: "unused public import is not reported", + sources: map[string]string{ + "test.proto": `syntax = "proto3"; import public "foo.proto"; message Foo { }`, + "foo.proto": `syntax = "proto3"; message Bar { string name = 1; }`, + }, + }, + { + name: "unused descriptor.proto import", + sources: map[string]string{ + "test.proto": `syntax = "proto3"; import "google/protobuf/descriptor.proto"; message Foo { }`, + }, + expectedNotices: []string{ + `test.proto:1:20: import "google/protobuf/descriptor.proto" not used`, + }, + }, + { + name: "explicitly used descriptor.proto import", + sources: map[string]string{ + "test.proto": `syntax = "proto3"; import "google/protobuf/descriptor.proto"; extend google.protobuf.MessageOptions { string foobar = 33333; }`, + }, + }, + { + // having options implicitly uses decriptor.proto + name: "implicitly used descriptor.proto import", + sources: map[string]string{ + "test.proto": `syntax = "proto3"; import "google/protobuf/descriptor.proto"; message Foo { option deprecated = true; }`, + }, + }, + { + // makes sure we can use a given descriptor.proto to override non-custom options + name: "implicitly used descriptor.proto import with new option", + sources: map[string]string{ + "test.proto": `syntax = "proto3"; import "google/protobuf/descriptor.proto"; message Foo { option foobar = 123; }`, + "google/protobuf/descriptor.proto": `syntax = "proto2"; package google.protobuf; message MessageOptions { optional fixed32 foobar = 99; }`, + }, + }, } for _, testCase := range testCases { - accessor := FileContentsFromMap(map[string]string{ - "test.proto": testCase.source, - }) - p := Parser{ - Accessor: accessor, - WarningReporter: rep, - } - msgs = nil - _, err := p.ParseFiles("test.proto") - testutil.Ok(t, err) + t.Run(testCase.name, func(t *testing.T) { + accessor := FileContentsFromMap(testCase.sources) + p := Parser{ + Accessor: accessor, + WarningReporter: rep, + } + msgs = nil + _, err := p.ParseFiles("test.proto") + testutil.Ok(t, err) - actualNotices := make([]string, len(msgs)) - for j, msg := range msgs { - actualNotices[j] = fmt.Sprintf("%s: %s", msg.pos, msg.text) - } - testutil.Eq(t, testCase.expectedNotices, actualNotices) + actualNotices := make([]string, len(msgs)) + for j, msg := range msgs { + actualNotices[j] = fmt.Sprintf("%s: %s", msg.pos, msg.text) + } + testutil.Eq(t, testCase.expectedNotices, actualNotices) + }) } }