Skip to content

Commit

Permalink
server: add ForceServerCodec() to set a custom encoding.Codec on the …
Browse files Browse the repository at this point in the history
…server (#4205)
  • Loading branch information
ash2k committed May 6, 2021
1 parent d426aa5 commit d2d6bda
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 2 deletions.
4 changes: 2 additions & 2 deletions rpc_util.go
Expand Up @@ -429,9 +429,9 @@ func (o ContentSubtypeCallOption) before(c *callInfo) error {
}
func (o ContentSubtypeCallOption) after(c *callInfo, attempt *csAttempt) {}

// ForceCodec returns a CallOption that will set the given Codec to be
// ForceCodec returns a CallOption that will set codec to be
// used for all request and response messages for a call. The result of calling
// String() will be used as the content-subtype in a case-insensitive manner.
// Name() will be used as the content-subtype in a case-insensitive manner.
//
// See Content-Type on
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
Expand Down
29 changes: 29 additions & 0 deletions server.go
Expand Up @@ -279,6 +279,35 @@ func CustomCodec(codec Codec) ServerOption {
})
}

// ForceServerCodec returns a ServerOption that sets a codec for message
// marshaling and unmarshaling.
//
// This will override any lookups by content-subtype for Codecs registered
// with RegisterCodec.
//
// See Content-Type on
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details. Also see the documentation on RegisterCodec and
// CallContentSubtype for more details on the interaction between encoding.Codec
// and content-subtype.
//
// This function is provided for advanced users; prefer to register codecs
// using encoding.RegisterCodec.
// The server will automatically use registered codecs based on the incoming
// requests' headers. See also
// https://github.com/grpc/grpc-go/blob/master/Documentation/encoding.md#using-a-codec.
// Will be supported throughout 1.x.
//
// Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func ForceServerCodec(codec encoding.Codec) ServerOption {
return newFuncServerOption(func(o *serverOptions) {
o.codec = codec
})
}

// RPCCompressor returns a ServerOption that sets a compressor for outbound
// messages. For backward compatibility, all outbound messages will be sent
// using this compressor, regardless of incoming message compression. By
Expand Down
58 changes: 58 additions & 0 deletions test/end2end_test.go
Expand Up @@ -5284,6 +5284,37 @@ func (s) TestGRPCMethod(t *testing.T) {
}
}

func (s) TestForceServerCodec(t *testing.T) {
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
},
}
codec := &countingProtoCodec{}
if err := ss.Start([]grpc.ServerOption{grpc.ForceServerCodec(codec)}); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
}

unmarshalCount := atomic.LoadInt32(&codec.unmarshalCount)
const wantUnmarshalCount = 1
if unmarshalCount != wantUnmarshalCount {
t.Fatalf("protoCodec.unmarshalCount = %d; want %d", unmarshalCount, wantUnmarshalCount)
}
marshalCount := atomic.LoadInt32(&codec.marshalCount)
const wantMarshalCount = 1
if marshalCount != wantMarshalCount {
t.Fatalf("protoCodec.marshalCount = %d; want %d", marshalCount, wantMarshalCount)
}
}

func (s) TestUnaryProxyDoesNotForwardMetadata(t *testing.T) {
const mdkey = "somedata"

Expand Down Expand Up @@ -5653,6 +5684,33 @@ func (c *errCodec) Name() string {
return "Fermat's near-miss."
}

type countingProtoCodec struct {
marshalCount int32
unmarshalCount int32
}

func (p *countingProtoCodec) Marshal(v interface{}) ([]byte, error) {
atomic.AddInt32(&p.marshalCount, 1)
vv, ok := v.(proto.Message)
if !ok {
return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v)
}
return proto.Marshal(vv)
}

func (p *countingProtoCodec) Unmarshal(data []byte, v interface{}) error {
atomic.AddInt32(&p.unmarshalCount, 1)
vv, ok := v.(proto.Message)
if !ok {
return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v)
}
return proto.Unmarshal(data, vv)
}

func (*countingProtoCodec) Name() string {
return "proto"
}

func (s) TestEncodeDoesntPanic(t *testing.T) {
for _, e := range listTestEnv() {
testEncodeDoesntPanic(t, e)
Expand Down

0 comments on commit d2d6bda

Please sign in to comment.