From 7c1ba558cf5c1cd8be9b3f2723b8c91a71f5f524 Mon Sep 17 00:00:00 2001 From: Josh Humphries Date: Wed, 16 Feb 2022 13:07:32 -0500 Subject: [PATCH] improve server reflection 1. Support alternate source of descriptors, like for RPC servers that get their descriptors dynamically and are dynamic proxies 2. Use the new protobuf API v2 stuff to get the descriptors, which is much more sane than the old APIs --- reflection/serverreflection.go | 346 +++++++++------------------- reflection/serverreflection_test.go | 63 +++-- 2 files changed, 147 insertions(+), 262 deletions(-) diff --git a/reflection/serverreflection.go b/reflection/serverreflection.go index 9b387dddee58..704a04a215cb 100644 --- a/reflection/serverreflection.go +++ b/reflection/serverreflection.go @@ -37,24 +37,19 @@ To register server reflection on a gRPC server: package reflection // import "google.golang.org/grpc/reflection" import ( - "bytes" - "compress/gzip" - "fmt" + "errors" "io" - "io/ioutil" "sort" "sync" - "github.com/golang/protobuf/proto" - dpb "github.com/golang/protobuf/protoc-gen-go/descriptor" "google.golang.org/grpc" "google.golang.org/grpc/codes" rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" - "google.golang.org/protobuf/types/dynamicpb" ) // GRPCServer is the interface provided by a gRPC server. It is implemented by @@ -65,291 +60,165 @@ type GRPCServer interface { GetServiceInfo() map[string]grpc.ServiceInfo } +// ExtensionResolver is the interface used to query details about extensions. +// This interface is satisfied by protoregistry.GlobalTypes. +type ExtensionResolver interface { + protoregistry.ExtensionTypeResolver + RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) +} + var _ GRPCServer = (*grpc.Server)(nil) type serverReflectionServer struct { rpb.UnimplementedServerReflectionServer - s GRPCServer + s GRPCServer + descResolver protodesc.Resolver + extResolver ExtensionResolver - initSymbols sync.Once - serviceNames []string - symbols map[string]*dpb.FileDescriptorProto // map of fully-qualified names to files + initServiceNames sync.Once + serviceNames []string } -// Register registers the server reflection service on the given gRPC server. -func Register(s GRPCServer) { - rpb.RegisterServerReflectionServer(s, &serverReflectionServer{ - s: s, - }) +// ServerOptions represents the options used to construct a reflection server. +// +// Either Server or ServiceNames must be populated, but not both. These control +// what services are advertised by the server in the ListServices capability of +// the reflection service. If neither is provided, the returned server can still +// serve descriptors, but it will advertise no service names. +// +// The given DescriptorResolver will be used to resolve symbols and files by +// name. If not present, protoregistry.GlobalFiles will be used. The given +// ExtensionResolver will be used to resolve extensions. If not present, +// protoregistry.GlobalTypes will be used. +type ServerOptions struct { + // An RPC server, whose exposed services are made available via service + // reflection. + Server GRPCServer + // The list of service names. This should only be populated if Server is + // nil. + ServiceNames []string + // Optional resolver used to load descriptors. + DescriptorResolver protodesc.Resolver + // Optional resolver used to query for known extensions. + ExtensionResolver ExtensionResolver } -func (s *serverReflectionServer) getSymbols() (svcNames []string, symbolIndex map[string]*dpb.FileDescriptorProto) { - s.initSymbols.Do(func() { - serviceInfo := s.s.GetServiceInfo() - - s.symbols = map[string]*dpb.FileDescriptorProto{} - s.serviceNames = make([]string, 0, len(serviceInfo)) - processed := map[string]struct{}{} - for svc, info := range serviceInfo { - s.serviceNames = append(s.serviceNames, svc) - fdenc, ok := parseMetadata(info.Metadata) - if !ok { - continue - } - fd, err := decodeFileDesc(fdenc) - if err != nil { - continue - } - s.processFile(fd, processed) - } - sort.Strings(s.serviceNames) - }) - - return s.serviceNames, s.symbols -} - -func (s *serverReflectionServer) processFile(fd *dpb.FileDescriptorProto, processed map[string]struct{}) { - filename := fd.GetName() - if _, ok := processed[filename]; ok { - return - } - processed[filename] = struct{}{} - - prefix := fd.GetPackage() - - for _, msg := range fd.MessageType { - s.processMessage(fd, prefix, msg) - } - for _, en := range fd.EnumType { - s.processEnum(fd, prefix, en) +// NewServer returns a reflection server implementation using the given options. +// It returns an error if the given options are invalid. +func NewServer(opts ServerOptions) (rpb.ServerReflectionServer, error) { + if opts.Server != nil && len(opts.ServiceNames) > 0 { + return nil, errors.New("options must specify either Server or ServiceNames, not both") } - for _, ext := range fd.Extension { - s.processField(fd, prefix, ext) + if opts.DescriptorResolver == nil { + opts.DescriptorResolver = protoregistry.GlobalFiles } - for _, svc := range fd.Service { - svcName := fqn(prefix, svc.GetName()) - s.symbols[svcName] = fd - for _, meth := range svc.Method { - name := fqn(svcName, meth.GetName()) - s.symbols[name] = fd - } + if opts.ExtensionResolver == nil { + opts.ExtensionResolver = protoregistry.GlobalTypes } - - for _, dep := range fd.Dependency { - fdenc := proto.FileDescriptor(dep) - fdDep, err := decodeFileDesc(fdenc) - if err != nil { - continue - } - s.processFile(fdDep, processed) - } -} - -func (s *serverReflectionServer) processMessage(fd *dpb.FileDescriptorProto, prefix string, msg *dpb.DescriptorProto) { - msgName := fqn(prefix, msg.GetName()) - s.symbols[msgName] = fd - - for _, nested := range msg.NestedType { - s.processMessage(fd, msgName, nested) - } - for _, en := range msg.EnumType { - s.processEnum(fd, msgName, en) - } - for _, ext := range msg.Extension { - s.processField(fd, msgName, ext) - } - for _, fld := range msg.Field { - s.processField(fd, msgName, fld) - } - for _, oneof := range msg.OneofDecl { - oneofName := fqn(msgName, oneof.GetName()) - s.symbols[oneofName] = fd - } -} - -func (s *serverReflectionServer) processEnum(fd *dpb.FileDescriptorProto, prefix string, en *dpb.EnumDescriptorProto) { - enName := fqn(prefix, en.GetName()) - s.symbols[enName] = fd - - for _, val := range en.Value { - valName := fqn(enName, val.GetName()) - s.symbols[valName] = fd - } -} - -func (s *serverReflectionServer) processField(fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto) { - fldName := fqn(prefix, fld.GetName()) - s.symbols[fldName] = fd -} - -func fqn(prefix, name string) string { - if prefix == "" { - return name - } - return prefix + "." + name + return &serverReflectionServer{ + s: opts.Server, + descResolver: opts.DescriptorResolver, + extResolver: opts.ExtensionResolver, + serviceNames: opts.ServiceNames, + }, nil } -// decodeFileDesc does decompression and unmarshalling on the given -// file descriptor byte slice. -func decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) { - raw, err := decompress(enc) - if err != nil { - return nil, fmt.Errorf("failed to decompress enc: %v", err) - } - - fd := new(dpb.FileDescriptorProto) - if err := proto.Unmarshal(raw, fd); err != nil { - return nil, fmt.Errorf("bad descriptor: %v", err) - } - return fd, nil -} - -// decompress does gzip decompression. -func decompress(b []byte) ([]byte, error) { - r, err := gzip.NewReader(bytes.NewReader(b)) - if err != nil { - return nil, fmt.Errorf("bad gzipped descriptor: %v", err) - } - out, err := ioutil.ReadAll(r) +// Register registers the server reflection service on the given gRPC server. +func Register(s GRPCServer) { + svr, err := NewServer(ServerOptions{Server: s}) if err != nil { - return nil, fmt.Errorf("bad gzipped descriptor: %v", err) + panic(err) // should not be possible } - return out, nil + rpb.RegisterServerReflectionServer(s, svr) } -func fileDescContainingExtension(typeName string, ext int32) (*dpb.FileDescriptorProto, error) { - desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(typeName)) - if err != nil { - return nil, err - } - m := dynamicpb.NewMessage(desc.(protoreflect.MessageDescriptor)) - - var extDesc *proto.ExtensionDesc - for id, desc := range proto.RegisteredExtensions(m) { - if id == ext { - extDesc = desc - break +func (s *serverReflectionServer) init() { + s.initServiceNames.Do(func() { + if s.s == nil { + // no need to init; service names were specified at construction + return } - } - - if extDesc == nil { - return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext) - } - - return decodeFileDesc(proto.FileDescriptor(extDesc.Filename)) + serviceInfo := s.s.GetServiceInfo() + s.serviceNames = make([]string, 0, len(serviceInfo)) + for svc := range serviceInfo { + s.serviceNames = append(s.serviceNames, svc) + } + sort.Strings(s.serviceNames) + }) } // fileDescWithDependencies returns a slice of serialized fileDescriptors in // wire format ([]byte). The fileDescriptors will include fd and all the // transitive dependencies of fd with names not in sentFileDescriptors. -func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]bool) ([][]byte, error) { - r := [][]byte{} - queue := []*dpb.FileDescriptorProto{fd} +func (s *serverReflectionServer) fileDescWithDependencies(fd protoreflect.FileDescriptor, sentFileDescriptors map[string]struct{}) ([][]byte, error) { + var r [][]byte + queue := []protoreflect.FileDescriptor{fd} for len(queue) > 0 { currentfd := queue[0] queue = queue[1:] - if sent := sentFileDescriptors[currentfd.GetName()]; len(r) == 0 || !sent { - sentFileDescriptors[currentfd.GetName()] = true - currentfdEncoded, err := proto.Marshal(currentfd) + if _, sent := sentFileDescriptors[currentfd.Path()]; len(r) == 0 || !sent { + sentFileDescriptors[currentfd.Path()] = struct{}{} + fdProto := protodesc.ToFileDescriptorProto(currentfd) + currentfdEncoded, err := proto.Marshal(fdProto) if err != nil { return nil, err } r = append(r, currentfdEncoded) } - for _, dep := range currentfd.Dependency { - fdenc := proto.FileDescriptor(dep) - fdDep, err := decodeFileDesc(fdenc) - if err != nil { - continue - } - queue = append(queue, fdDep) + for i := 0; i < currentfd.Imports().Len(); i++ { + queue = append(queue, currentfd.Imports().Get(i)) } } return r, nil } -// fileDescEncodingByFilename finds the file descriptor for given filename, -// finds all of its previously unsent transitive dependencies, does marshalling -// on them, and returns the marshalled result. -func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]bool) ([][]byte, error) { - enc := proto.FileDescriptor(name) - if enc == nil { - return nil, fmt.Errorf("unknown file: %v", name) - } - fd, err := decodeFileDesc(enc) - if err != nil { - return nil, err - } - return fileDescWithDependencies(fd, sentFileDescriptors) -} - -// parseMetadata finds the file descriptor bytes specified meta. -// For SupportPackageIsVersion4, m is the name of the proto file, we -// call proto.FileDescriptor to get the byte slice. -// For SupportPackageIsVersion3, m is a byte slice itself. -func parseMetadata(meta interface{}) ([]byte, bool) { - // Check if meta is the file name. - if fileNameForMeta, ok := meta.(string); ok { - return proto.FileDescriptor(fileNameForMeta), true - } - - // Check if meta is the byte slice. - if enc, ok := meta.([]byte); ok { - return enc, true - } - - return nil, false -} - // fileDescEncodingContainingSymbol finds the file descriptor containing the // given symbol, finds all of its previously unsent transitive dependencies, // does marshalling on them, and returns the marshalled result. The given symbol // can be a type, a service or a method. -func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) { - _, symbols := s.getSymbols() - fd := symbols[name] - if fd == nil { - // Check if it's a type name that was not present in the - // transitive dependencies of the registered services. - desc, err := protoregistry.GlobalTypes.FindMessageByName(protoreflect.FullName(name)) - if err != nil { - return nil, err - } - fd = protodesc.ToFileDescriptorProto(desc.Descriptor().ParentFile()) +func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]struct{}) ([][]byte, error) { + d, err := s.descResolver.FindDescriptorByName(protoreflect.FullName(name)) + if err != nil { + return nil, err } - return fileDescWithDependencies(fd, sentFileDescriptors) + return s.fileDescWithDependencies(d.ParentFile(), sentFileDescriptors) } // fileDescEncodingContainingExtension finds the file descriptor containing // given extension, finds all of its previously unsent transitive dependencies, // does marshalling on them, and returns the marshalled result. -func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) { - fd, err := fileDescContainingExtension(typeName, extNum) +func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]struct{}) ([][]byte, error) { + xt, err := s.extResolver.FindExtensionByNumber(protoreflect.FullName(typeName), protoreflect.FieldNumber(extNum)) if err != nil { return nil, err } - return fileDescWithDependencies(fd, sentFileDescriptors) + return s.fileDescWithDependencies(xt.TypeDescriptor().ParentFile(), sentFileDescriptors) } // allExtensionNumbersForTypeName returns all extension numbers for the given type. func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) { - desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(name)) - if err != nil { - return nil, err - } - m := dynamicpb.NewMessage(desc.(protoreflect.MessageDescriptor)) - - exts := proto.RegisteredExtensions(m) - extNums := make([]int32, 0, len(exts)) - for id := range exts { - extNums = append(extNums, id) + var numbers []int32 + s.extResolver.RangeExtensionsByMessage(protoreflect.FullName(name), func(xt protoreflect.ExtensionType) bool { + numbers = append(numbers, int32(xt.TypeDescriptor().Number())) + return true + }) + sort.Slice(numbers, func(i, j int) bool { + return numbers[i] < numbers[j] + }) + if len(numbers) == 0 { + // maybe return an error if given type name is not known + if _, err := s.descResolver.FindDescriptorByName(protoreflect.FullName(name)); err != nil { + return nil, err + } } - return extNums, nil + return numbers, nil } // ServerReflectionInfo is the reflection service handler. func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error { - sentFileDescriptors := make(map[string]bool) + s.init() + + sentFileDescriptors := make(map[string]struct{}) for { in, err := stream.Recv() if err == io.EOF { @@ -365,7 +234,11 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } switch req := in.MessageRequest.(type) { case *rpb.ServerReflectionRequest_FileByFilename: - b, err := s.fileDescEncodingByFilename(req.FileByFilename, sentFileDescriptors) + var b [][]byte + fd, err := s.descResolver.FindFileByPath(req.FileByFilename) + if err == nil { + b, err = s.fileDescWithDependencies(fd, sentFileDescriptors) + } if err != nil { out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{ @@ -426,9 +299,8 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } } case *rpb.ServerReflectionRequest_ListServices: - svcNames, _ := s.getSymbols() - serviceResponses := make([]*rpb.ServiceResponse, len(svcNames)) - for i, n := range svcNames { + serviceResponses := make([]*rpb.ServiceResponse, len(s.serviceNames)) + for i, n := range s.serviceNames { serviceResponses[i] = &rpb.ServiceResponse{ Name: n, } diff --git a/reflection/serverreflection_test.go b/reflection/serverreflection_test.go index 730f51bf012d..77d9f57f708f 100644 --- a/reflection/serverreflection_test.go +++ b/reflection/serverreflection_test.go @@ -27,14 +27,13 @@ import ( "testing" "time" - "github.com/golang/protobuf/proto" - dpb "github.com/golang/protobuf/protoc-gen-go/descriptor" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/grpctest" rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" pb "google.golang.org/grpc/reflection/grpc_testing" pbv3 "google.golang.org/grpc/reflection/grpc_testing_not_regenerate" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" @@ -43,14 +42,14 @@ import ( ) var ( - s = &serverReflectionServer{} + s *serverReflectionServer // fileDescriptor of each test proto file. - fdTest *dpb.FileDescriptorProto - fdTestv3 *dpb.FileDescriptorProto - fdProto2 *dpb.FileDescriptorProto - fdProto2Ext *dpb.FileDescriptorProto - fdProto2Ext2 *dpb.FileDescriptorProto - fdDynamic *dpb.FileDescriptorProto + fdTest *descriptorpb.FileDescriptorProto + fdTestv3 *descriptorpb.FileDescriptorProto + fdProto2 *descriptorpb.FileDescriptorProto + fdProto2Ext *descriptorpb.FileDescriptorProto + fdProto2Ext2 *descriptorpb.FileDescriptorProto + fdDynamic *descriptorpb.FileDescriptorProto // reflection descriptors. fdDynamicFile protoreflect.FileDescriptor // fileDescriptor marshalled. @@ -62,6 +61,14 @@ var ( fdDynamicByte []byte ) +func init() { + svr, err := NewServer(ServerOptions{}) + if err != nil { + panic(err) + } + s = svr.(*serverReflectionServer) +} + const defaultTestTimeout = 10 * time.Second type x struct { @@ -72,23 +79,20 @@ func Test(t *testing.T) { grpctest.RunSubTests(t, x{}) } -func loadFileDesc(filename string) (*dpb.FileDescriptorProto, []byte) { - enc := proto.FileDescriptor(filename) - if enc == nil { - panic(fmt.Sprintf("failed to find fd for file: %v", filename)) - } - fd, err := decodeFileDesc(enc) +func loadFileDesc(filename string) (*descriptorpb.FileDescriptorProto, []byte) { + fd, err := protoregistry.GlobalFiles.FindFileByPath(filename) if err != nil { - panic(fmt.Sprintf("failed to decode enc: %v", err)) + panic(err) } - b, err := proto.Marshal(fd) + fdProto := protodesc.ToFileDescriptorProto(fd) + b, err := proto.Marshal(fdProto) if err != nil { panic(fmt.Sprintf("failed to marshal fd: %v", err)) } - return fd, b + return fdProto, b } -func loadFileDescDynamic(b []byte) (*dpb.FileDescriptorProto, protoreflect.FileDescriptor, []byte) { +func loadFileDescDynamic(b []byte) (*descriptorpb.FileDescriptorProto, protoreflect.FileDescriptor, []byte) { m := new(descriptorpb.FileDescriptorProto) if err := proto.Unmarshal(b, m); err != nil { panic(fmt.Sprintf("failed to unmarshal dynamic proto raw descriptor")) @@ -127,7 +131,7 @@ func (x) TestFileDescContainingExtension(t *testing.T) { for _, test := range []struct { st string extNum int32 - want *dpb.FileDescriptorProto + want *descriptorpb.FileDescriptorProto }{ {"grpc.testing.ToBeExtended", 13, fdProto2Ext}, {"grpc.testing.ToBeExtended", 17, fdProto2Ext}, @@ -135,9 +139,18 @@ func (x) TestFileDescContainingExtension(t *testing.T) { {"grpc.testing.ToBeExtended", 23, fdProto2Ext2}, {"grpc.testing.ToBeExtended", 29, fdProto2Ext2}, } { - fd, err := fileDescContainingExtension(test.st, test.extNum) - if err != nil || !proto.Equal(fd, test.want) { - t.Errorf("fileDescContainingExtension(%q) = %q, %v, want %q, ", test.st, fd, err, test.want) + fd, err := s.fileDescEncodingContainingExtension(test.st, test.extNum, map[string]struct{}{}) + if err != nil { + t.Errorf("fileDescContainingExtension(%q) return error: %v", test.st, err) + continue + } + var actualFd descriptorpb.FileDescriptorProto + if err := proto.Unmarshal(fd[0], &actualFd); err != nil { + t.Errorf("fileDescContainingExtension(%q) return invalid bytes: %v", test.st, err) + continue + } + if !proto.Equal(&actualFd, test.want) { + t.Errorf("fileDescContainingExtension(%q) returned %q, but wanted %q", test.st, &actualFd, test.want) } } } @@ -348,7 +361,7 @@ func testFileContainingSymbol(t *testing.T, stream rpb.ServerReflection_ServerRe {"grpc.testingv3.SearchResponseV3.Result.Value.val", fdTestv3Byte}, {"grpc.testingv3.SearchResponseV3.Result.Value.str", fdTestv3Byte}, {"grpc.testingv3.SearchResponseV3.State", fdTestv3Byte}, - {"grpc.testingv3.SearchResponseV3.State.FRESH", fdTestv3Byte}, + {"grpc.testingv3.SearchResponseV3.FRESH", fdTestv3Byte}, // Test dynamic symbols {"grpc.testing.DynamicService", fdDynamicByte}, {"grpc.testing.DynamicReq", fdDynamicByte}, @@ -578,7 +591,7 @@ func testListServices(t *testing.T, stream rpb.ServerReflection_ServerReflection } } -func registerDynamicProto(srv *grpc.Server, fdp *dpb.FileDescriptorProto, fd protoreflect.FileDescriptor) { +func registerDynamicProto(srv *grpc.Server, fdp *descriptorpb.FileDescriptorProto, fd protoreflect.FileDescriptor) { type emptyInterface interface{} for i := 0; i < fd.Services().Len(); i++ {