diff --git a/reflection/serverreflection.go b/reflection/serverreflection.go index 9b387dddee5..704a04a215c 100644 --- a/reflection/serverreflection.go +++ b/reflection/serverreflection.go @@ -37,24 +37,19 @@ To register server reflection on a gRPC server: package reflection // import "google.golang.org/grpc/reflection" import ( - "bytes" - "compress/gzip" - "fmt" + "errors" "io" - "io/ioutil" "sort" "sync" - "github.com/golang/protobuf/proto" - dpb "github.com/golang/protobuf/protoc-gen-go/descriptor" "google.golang.org/grpc" "google.golang.org/grpc/codes" rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" - "google.golang.org/protobuf/types/dynamicpb" ) // GRPCServer is the interface provided by a gRPC server. It is implemented by @@ -65,291 +60,165 @@ type GRPCServer interface { GetServiceInfo() map[string]grpc.ServiceInfo } +// ExtensionResolver is the interface used to query details about extensions. +// This interface is satisfied by protoregistry.GlobalTypes. +type ExtensionResolver interface { + protoregistry.ExtensionTypeResolver + RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) +} + var _ GRPCServer = (*grpc.Server)(nil) type serverReflectionServer struct { rpb.UnimplementedServerReflectionServer - s GRPCServer + s GRPCServer + descResolver protodesc.Resolver + extResolver ExtensionResolver - initSymbols sync.Once - serviceNames []string - symbols map[string]*dpb.FileDescriptorProto // map of fully-qualified names to files + initServiceNames sync.Once + serviceNames []string } -// Register registers the server reflection service on the given gRPC server. -func Register(s GRPCServer) { - rpb.RegisterServerReflectionServer(s, &serverReflectionServer{ - s: s, - }) +// ServerOptions represents the options used to construct a reflection server. +// +// Either Server or ServiceNames must be populated, but not both. These control +// what services are advertised by the server in the ListServices capability of +// the reflection service. If neither is provided, the returned server can still +// serve descriptors, but it will advertise no service names. +// +// The given DescriptorResolver will be used to resolve symbols and files by +// name. If not present, protoregistry.GlobalFiles will be used. The given +// ExtensionResolver will be used to resolve extensions. If not present, +// protoregistry.GlobalTypes will be used. +type ServerOptions struct { + // An RPC server, whose exposed services are made available via service + // reflection. + Server GRPCServer + // The list of service names. This should only be populated if Server is + // nil. + ServiceNames []string + // Optional resolver used to load descriptors. + DescriptorResolver protodesc.Resolver + // Optional resolver used to query for known extensions. + ExtensionResolver ExtensionResolver } -func (s *serverReflectionServer) getSymbols() (svcNames []string, symbolIndex map[string]*dpb.FileDescriptorProto) { - s.initSymbols.Do(func() { - serviceInfo := s.s.GetServiceInfo() - - s.symbols = map[string]*dpb.FileDescriptorProto{} - s.serviceNames = make([]string, 0, len(serviceInfo)) - processed := map[string]struct{}{} - for svc, info := range serviceInfo { - s.serviceNames = append(s.serviceNames, svc) - fdenc, ok := parseMetadata(info.Metadata) - if !ok { - continue - } - fd, err := decodeFileDesc(fdenc) - if err != nil { - continue - } - s.processFile(fd, processed) - } - sort.Strings(s.serviceNames) - }) - - return s.serviceNames, s.symbols -} - -func (s *serverReflectionServer) processFile(fd *dpb.FileDescriptorProto, processed map[string]struct{}) { - filename := fd.GetName() - if _, ok := processed[filename]; ok { - return - } - processed[filename] = struct{}{} - - prefix := fd.GetPackage() - - for _, msg := range fd.MessageType { - s.processMessage(fd, prefix, msg) - } - for _, en := range fd.EnumType { - s.processEnum(fd, prefix, en) +// NewServer returns a reflection server implementation using the given options. +// It returns an error if the given options are invalid. +func NewServer(opts ServerOptions) (rpb.ServerReflectionServer, error) { + if opts.Server != nil && len(opts.ServiceNames) > 0 { + return nil, errors.New("options must specify either Server or ServiceNames, not both") } - for _, ext := range fd.Extension { - s.processField(fd, prefix, ext) + if opts.DescriptorResolver == nil { + opts.DescriptorResolver = protoregistry.GlobalFiles } - for _, svc := range fd.Service { - svcName := fqn(prefix, svc.GetName()) - s.symbols[svcName] = fd - for _, meth := range svc.Method { - name := fqn(svcName, meth.GetName()) - s.symbols[name] = fd - } + if opts.ExtensionResolver == nil { + opts.ExtensionResolver = protoregistry.GlobalTypes } - - for _, dep := range fd.Dependency { - fdenc := proto.FileDescriptor(dep) - fdDep, err := decodeFileDesc(fdenc) - if err != nil { - continue - } - s.processFile(fdDep, processed) - } -} - -func (s *serverReflectionServer) processMessage(fd *dpb.FileDescriptorProto, prefix string, msg *dpb.DescriptorProto) { - msgName := fqn(prefix, msg.GetName()) - s.symbols[msgName] = fd - - for _, nested := range msg.NestedType { - s.processMessage(fd, msgName, nested) - } - for _, en := range msg.EnumType { - s.processEnum(fd, msgName, en) - } - for _, ext := range msg.Extension { - s.processField(fd, msgName, ext) - } - for _, fld := range msg.Field { - s.processField(fd, msgName, fld) - } - for _, oneof := range msg.OneofDecl { - oneofName := fqn(msgName, oneof.GetName()) - s.symbols[oneofName] = fd - } -} - -func (s *serverReflectionServer) processEnum(fd *dpb.FileDescriptorProto, prefix string, en *dpb.EnumDescriptorProto) { - enName := fqn(prefix, en.GetName()) - s.symbols[enName] = fd - - for _, val := range en.Value { - valName := fqn(enName, val.GetName()) - s.symbols[valName] = fd - } -} - -func (s *serverReflectionServer) processField(fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto) { - fldName := fqn(prefix, fld.GetName()) - s.symbols[fldName] = fd -} - -func fqn(prefix, name string) string { - if prefix == "" { - return name - } - return prefix + "." + name + return &serverReflectionServer{ + s: opts.Server, + descResolver: opts.DescriptorResolver, + extResolver: opts.ExtensionResolver, + serviceNames: opts.ServiceNames, + }, nil } -// decodeFileDesc does decompression and unmarshalling on the given -// file descriptor byte slice. -func decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) { - raw, err := decompress(enc) - if err != nil { - return nil, fmt.Errorf("failed to decompress enc: %v", err) - } - - fd := new(dpb.FileDescriptorProto) - if err := proto.Unmarshal(raw, fd); err != nil { - return nil, fmt.Errorf("bad descriptor: %v", err) - } - return fd, nil -} - -// decompress does gzip decompression. -func decompress(b []byte) ([]byte, error) { - r, err := gzip.NewReader(bytes.NewReader(b)) - if err != nil { - return nil, fmt.Errorf("bad gzipped descriptor: %v", err) - } - out, err := ioutil.ReadAll(r) +// Register registers the server reflection service on the given gRPC server. +func Register(s GRPCServer) { + svr, err := NewServer(ServerOptions{Server: s}) if err != nil { - return nil, fmt.Errorf("bad gzipped descriptor: %v", err) + panic(err) // should not be possible } - return out, nil + rpb.RegisterServerReflectionServer(s, svr) } -func fileDescContainingExtension(typeName string, ext int32) (*dpb.FileDescriptorProto, error) { - desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(typeName)) - if err != nil { - return nil, err - } - m := dynamicpb.NewMessage(desc.(protoreflect.MessageDescriptor)) - - var extDesc *proto.ExtensionDesc - for id, desc := range proto.RegisteredExtensions(m) { - if id == ext { - extDesc = desc - break +func (s *serverReflectionServer) init() { + s.initServiceNames.Do(func() { + if s.s == nil { + // no need to init; service names were specified at construction + return } - } - - if extDesc == nil { - return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext) - } - - return decodeFileDesc(proto.FileDescriptor(extDesc.Filename)) + serviceInfo := s.s.GetServiceInfo() + s.serviceNames = make([]string, 0, len(serviceInfo)) + for svc := range serviceInfo { + s.serviceNames = append(s.serviceNames, svc) + } + sort.Strings(s.serviceNames) + }) } // fileDescWithDependencies returns a slice of serialized fileDescriptors in // wire format ([]byte). The fileDescriptors will include fd and all the // transitive dependencies of fd with names not in sentFileDescriptors. -func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]bool) ([][]byte, error) { - r := [][]byte{} - queue := []*dpb.FileDescriptorProto{fd} +func (s *serverReflectionServer) fileDescWithDependencies(fd protoreflect.FileDescriptor, sentFileDescriptors map[string]struct{}) ([][]byte, error) { + var r [][]byte + queue := []protoreflect.FileDescriptor{fd} for len(queue) > 0 { currentfd := queue[0] queue = queue[1:] - if sent := sentFileDescriptors[currentfd.GetName()]; len(r) == 0 || !sent { - sentFileDescriptors[currentfd.GetName()] = true - currentfdEncoded, err := proto.Marshal(currentfd) + if _, sent := sentFileDescriptors[currentfd.Path()]; len(r) == 0 || !sent { + sentFileDescriptors[currentfd.Path()] = struct{}{} + fdProto := protodesc.ToFileDescriptorProto(currentfd) + currentfdEncoded, err := proto.Marshal(fdProto) if err != nil { return nil, err } r = append(r, currentfdEncoded) } - for _, dep := range currentfd.Dependency { - fdenc := proto.FileDescriptor(dep) - fdDep, err := decodeFileDesc(fdenc) - if err != nil { - continue - } - queue = append(queue, fdDep) + for i := 0; i < currentfd.Imports().Len(); i++ { + queue = append(queue, currentfd.Imports().Get(i)) } } return r, nil } -// fileDescEncodingByFilename finds the file descriptor for given filename, -// finds all of its previously unsent transitive dependencies, does marshalling -// on them, and returns the marshalled result. -func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]bool) ([][]byte, error) { - enc := proto.FileDescriptor(name) - if enc == nil { - return nil, fmt.Errorf("unknown file: %v", name) - } - fd, err := decodeFileDesc(enc) - if err != nil { - return nil, err - } - return fileDescWithDependencies(fd, sentFileDescriptors) -} - -// parseMetadata finds the file descriptor bytes specified meta. -// For SupportPackageIsVersion4, m is the name of the proto file, we -// call proto.FileDescriptor to get the byte slice. -// For SupportPackageIsVersion3, m is a byte slice itself. -func parseMetadata(meta interface{}) ([]byte, bool) { - // Check if meta is the file name. - if fileNameForMeta, ok := meta.(string); ok { - return proto.FileDescriptor(fileNameForMeta), true - } - - // Check if meta is the byte slice. - if enc, ok := meta.([]byte); ok { - return enc, true - } - - return nil, false -} - // fileDescEncodingContainingSymbol finds the file descriptor containing the // given symbol, finds all of its previously unsent transitive dependencies, // does marshalling on them, and returns the marshalled result. The given symbol // can be a type, a service or a method. -func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) { - _, symbols := s.getSymbols() - fd := symbols[name] - if fd == nil { - // Check if it's a type name that was not present in the - // transitive dependencies of the registered services. - desc, err := protoregistry.GlobalTypes.FindMessageByName(protoreflect.FullName(name)) - if err != nil { - return nil, err - } - fd = protodesc.ToFileDescriptorProto(desc.Descriptor().ParentFile()) +func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]struct{}) ([][]byte, error) { + d, err := s.descResolver.FindDescriptorByName(protoreflect.FullName(name)) + if err != nil { + return nil, err } - return fileDescWithDependencies(fd, sentFileDescriptors) + return s.fileDescWithDependencies(d.ParentFile(), sentFileDescriptors) } // fileDescEncodingContainingExtension finds the file descriptor containing // given extension, finds all of its previously unsent transitive dependencies, // does marshalling on them, and returns the marshalled result. -func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) { - fd, err := fileDescContainingExtension(typeName, extNum) +func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]struct{}) ([][]byte, error) { + xt, err := s.extResolver.FindExtensionByNumber(protoreflect.FullName(typeName), protoreflect.FieldNumber(extNum)) if err != nil { return nil, err } - return fileDescWithDependencies(fd, sentFileDescriptors) + return s.fileDescWithDependencies(xt.TypeDescriptor().ParentFile(), sentFileDescriptors) } // allExtensionNumbersForTypeName returns all extension numbers for the given type. func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) { - desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(name)) - if err != nil { - return nil, err - } - m := dynamicpb.NewMessage(desc.(protoreflect.MessageDescriptor)) - - exts := proto.RegisteredExtensions(m) - extNums := make([]int32, 0, len(exts)) - for id := range exts { - extNums = append(extNums, id) + var numbers []int32 + s.extResolver.RangeExtensionsByMessage(protoreflect.FullName(name), func(xt protoreflect.ExtensionType) bool { + numbers = append(numbers, int32(xt.TypeDescriptor().Number())) + return true + }) + sort.Slice(numbers, func(i, j int) bool { + return numbers[i] < numbers[j] + }) + if len(numbers) == 0 { + // maybe return an error if given type name is not known + if _, err := s.descResolver.FindDescriptorByName(protoreflect.FullName(name)); err != nil { + return nil, err + } } - return extNums, nil + return numbers, nil } // ServerReflectionInfo is the reflection service handler. func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error { - sentFileDescriptors := make(map[string]bool) + s.init() + + sentFileDescriptors := make(map[string]struct{}) for { in, err := stream.Recv() if err == io.EOF { @@ -365,7 +234,11 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } switch req := in.MessageRequest.(type) { case *rpb.ServerReflectionRequest_FileByFilename: - b, err := s.fileDescEncodingByFilename(req.FileByFilename, sentFileDescriptors) + var b [][]byte + fd, err := s.descResolver.FindFileByPath(req.FileByFilename) + if err == nil { + b, err = s.fileDescWithDependencies(fd, sentFileDescriptors) + } if err != nil { out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{ @@ -426,9 +299,8 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } } case *rpb.ServerReflectionRequest_ListServices: - svcNames, _ := s.getSymbols() - serviceResponses := make([]*rpb.ServiceResponse, len(svcNames)) - for i, n := range svcNames { + serviceResponses := make([]*rpb.ServiceResponse, len(s.serviceNames)) + for i, n := range s.serviceNames { serviceResponses[i] = &rpb.ServiceResponse{ Name: n, } diff --git a/reflection/serverreflection_test.go b/reflection/serverreflection_test.go index 730f51bf012..77d9f57f708 100644 --- a/reflection/serverreflection_test.go +++ b/reflection/serverreflection_test.go @@ -27,14 +27,13 @@ import ( "testing" "time" - "github.com/golang/protobuf/proto" - dpb "github.com/golang/protobuf/protoc-gen-go/descriptor" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/grpctest" rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" pb "google.golang.org/grpc/reflection/grpc_testing" pbv3 "google.golang.org/grpc/reflection/grpc_testing_not_regenerate" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" @@ -43,14 +42,14 @@ import ( ) var ( - s = &serverReflectionServer{} + s *serverReflectionServer // fileDescriptor of each test proto file. - fdTest *dpb.FileDescriptorProto - fdTestv3 *dpb.FileDescriptorProto - fdProto2 *dpb.FileDescriptorProto - fdProto2Ext *dpb.FileDescriptorProto - fdProto2Ext2 *dpb.FileDescriptorProto - fdDynamic *dpb.FileDescriptorProto + fdTest *descriptorpb.FileDescriptorProto + fdTestv3 *descriptorpb.FileDescriptorProto + fdProto2 *descriptorpb.FileDescriptorProto + fdProto2Ext *descriptorpb.FileDescriptorProto + fdProto2Ext2 *descriptorpb.FileDescriptorProto + fdDynamic *descriptorpb.FileDescriptorProto // reflection descriptors. fdDynamicFile protoreflect.FileDescriptor // fileDescriptor marshalled. @@ -62,6 +61,14 @@ var ( fdDynamicByte []byte ) +func init() { + svr, err := NewServer(ServerOptions{}) + if err != nil { + panic(err) + } + s = svr.(*serverReflectionServer) +} + const defaultTestTimeout = 10 * time.Second type x struct { @@ -72,23 +79,20 @@ func Test(t *testing.T) { grpctest.RunSubTests(t, x{}) } -func loadFileDesc(filename string) (*dpb.FileDescriptorProto, []byte) { - enc := proto.FileDescriptor(filename) - if enc == nil { - panic(fmt.Sprintf("failed to find fd for file: %v", filename)) - } - fd, err := decodeFileDesc(enc) +func loadFileDesc(filename string) (*descriptorpb.FileDescriptorProto, []byte) { + fd, err := protoregistry.GlobalFiles.FindFileByPath(filename) if err != nil { - panic(fmt.Sprintf("failed to decode enc: %v", err)) + panic(err) } - b, err := proto.Marshal(fd) + fdProto := protodesc.ToFileDescriptorProto(fd) + b, err := proto.Marshal(fdProto) if err != nil { panic(fmt.Sprintf("failed to marshal fd: %v", err)) } - return fd, b + return fdProto, b } -func loadFileDescDynamic(b []byte) (*dpb.FileDescriptorProto, protoreflect.FileDescriptor, []byte) { +func loadFileDescDynamic(b []byte) (*descriptorpb.FileDescriptorProto, protoreflect.FileDescriptor, []byte) { m := new(descriptorpb.FileDescriptorProto) if err := proto.Unmarshal(b, m); err != nil { panic(fmt.Sprintf("failed to unmarshal dynamic proto raw descriptor")) @@ -127,7 +131,7 @@ func (x) TestFileDescContainingExtension(t *testing.T) { for _, test := range []struct { st string extNum int32 - want *dpb.FileDescriptorProto + want *descriptorpb.FileDescriptorProto }{ {"grpc.testing.ToBeExtended", 13, fdProto2Ext}, {"grpc.testing.ToBeExtended", 17, fdProto2Ext}, @@ -135,9 +139,18 @@ func (x) TestFileDescContainingExtension(t *testing.T) { {"grpc.testing.ToBeExtended", 23, fdProto2Ext2}, {"grpc.testing.ToBeExtended", 29, fdProto2Ext2}, } { - fd, err := fileDescContainingExtension(test.st, test.extNum) - if err != nil || !proto.Equal(fd, test.want) { - t.Errorf("fileDescContainingExtension(%q) = %q, %v, want %q, ", test.st, fd, err, test.want) + fd, err := s.fileDescEncodingContainingExtension(test.st, test.extNum, map[string]struct{}{}) + if err != nil { + t.Errorf("fileDescContainingExtension(%q) return error: %v", test.st, err) + continue + } + var actualFd descriptorpb.FileDescriptorProto + if err := proto.Unmarshal(fd[0], &actualFd); err != nil { + t.Errorf("fileDescContainingExtension(%q) return invalid bytes: %v", test.st, err) + continue + } + if !proto.Equal(&actualFd, test.want) { + t.Errorf("fileDescContainingExtension(%q) returned %q, but wanted %q", test.st, &actualFd, test.want) } } } @@ -348,7 +361,7 @@ func testFileContainingSymbol(t *testing.T, stream rpb.ServerReflection_ServerRe {"grpc.testingv3.SearchResponseV3.Result.Value.val", fdTestv3Byte}, {"grpc.testingv3.SearchResponseV3.Result.Value.str", fdTestv3Byte}, {"grpc.testingv3.SearchResponseV3.State", fdTestv3Byte}, - {"grpc.testingv3.SearchResponseV3.State.FRESH", fdTestv3Byte}, + {"grpc.testingv3.SearchResponseV3.FRESH", fdTestv3Byte}, // Test dynamic symbols {"grpc.testing.DynamicService", fdDynamicByte}, {"grpc.testing.DynamicReq", fdDynamicByte}, @@ -578,7 +591,7 @@ func testListServices(t *testing.T, stream rpb.ServerReflection_ServerReflection } } -func registerDynamicProto(srv *grpc.Server, fdp *dpb.FileDescriptorProto, fd protoreflect.FileDescriptor) { +func registerDynamicProto(srv *grpc.Server, fdp *descriptorpb.FileDescriptorProto, fd protoreflect.FileDescriptor) { type emptyInterface interface{} for i := 0; i < fd.Services().Len(); i++ {