From e23851d89cce1909f269b8109ed68196d62c10fe Mon Sep 17 00:00:00 2001 From: Mikhail Mazurskiy Date: Wed, 5 May 2021 21:14:57 +1000 Subject: [PATCH] ForceServerCodec() to set a custom encoding.Codec on the server --- rpc_util.go | 4 +-- server.go | 29 ++++++++++++++++++++++ test/end2end_test.go | 58 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 2 deletions(-) diff --git a/rpc_util.go b/rpc_util.go index c0a1208f2f3..c8ae0e4444c 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -429,9 +429,9 @@ func (o ContentSubtypeCallOption) before(c *callInfo) error { } func (o ContentSubtypeCallOption) after(c *callInfo, attempt *csAttempt) {} -// ForceCodec returns a CallOption that will set the given Codec to be +// ForceCodec returns a CallOption that will set codec to be // used for all request and response messages for a call. The result of calling -// String() will be used as the content-subtype in a case-insensitive manner. +// Name() will be used as the content-subtype in a case-insensitive manner. // // See Content-Type on // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for diff --git a/server.go b/server.go index b2793ab00b5..d2bc884277a 100644 --- a/server.go +++ b/server.go @@ -279,6 +279,35 @@ func CustomCodec(codec Codec) ServerOption { }) } +// ForceServerCodec returns a ServerOption that sets a codec for message +// marshaling and unmarshaling. +// +// This will override any lookups by content-subtype for Codecs registered +// with RegisterCodec. +// +// See Content-Type on +// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for +// more details. Also see the documentation on RegisterCodec and +// CallContentSubtype for more details on the interaction between encoding.Codec +// and content-subtype. +// +// This function is provided for advanced users; prefer to register codecs +// using encoding.RegisterCodec. +// The server will automatically use registered codecs based on the incoming +// requests' headers. See also +// https://github.com/grpc/grpc-go/blob/master/Documentation/encoding.md#using-a-codec. +// Will be supported throughout 1.x. +// +// Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func ForceServerCodec(codec encoding.Codec) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.codec = codec + }) +} + // RPCCompressor returns a ServerOption that sets a compressor for outbound // messages. For backward compatibility, all outbound messages will be sent // using this compressor, regardless of incoming message compression. By diff --git a/test/end2end_test.go b/test/end2end_test.go index 1baf2e347d1..86fc480669c 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -5284,6 +5284,37 @@ func (s) TestGRPCMethod(t *testing.T) { } } +func (s) TestForceServerCodec(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + codec := &countingProtoCodec{} + if err := ss.Start([]grpc.ServerOption{grpc.ForceServerCodec(codec)}); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err) + } + + unmarshalCount := atomic.LoadInt32(&codec.unmarshalCount) + const wantUnmarshalCount = 1 + if unmarshalCount != wantUnmarshalCount { + t.Fatalf("protoCodec.unmarshalCount = %d; want %d", unmarshalCount, wantUnmarshalCount) + } + marshalCount := atomic.LoadInt32(&codec.marshalCount) + const wantMarshalCount = 1 + if marshalCount != wantMarshalCount { + t.Fatalf("protoCodec.marshalCount = %d; want %d", marshalCount, wantMarshalCount) + } +} + func (s) TestUnaryProxyDoesNotForwardMetadata(t *testing.T) { const mdkey = "somedata" @@ -5653,6 +5684,33 @@ func (c *errCodec) Name() string { return "Fermat's near-miss." } +type countingProtoCodec struct { + marshalCount int32 + unmarshalCount int32 +} + +func (p *countingProtoCodec) Marshal(v interface{}) ([]byte, error) { + atomic.AddInt32(&p.marshalCount, 1) + vv, ok := v.(proto.Message) + if !ok { + return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v) + } + return proto.Marshal(vv) +} + +func (p *countingProtoCodec) Unmarshal(data []byte, v interface{}) error { + atomic.AddInt32(&p.unmarshalCount, 1) + vv, ok := v.(proto.Message) + if !ok { + return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v) + } + return proto.Unmarshal(data, vv) +} + +func (*countingProtoCodec) Name() string { + return "proto" +} + func (s) TestEncodeDoesntPanic(t *testing.T) { for _, e := range listTestEnv() { testEncodeDoesntPanic(t, e)