Skip to content

Commit

Permalink
ForceServerCodec() to set a custom encoding.Codec on the server
Browse files Browse the repository at this point in the history
  • Loading branch information
ash2k committed May 5, 2021
1 parent 0fc0397 commit 177ff4f
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 2 deletions.
4 changes: 2 additions & 2 deletions rpc_util.go
Expand Up @@ -429,9 +429,9 @@ func (o ContentSubtypeCallOption) before(c *callInfo) error {
}
func (o ContentSubtypeCallOption) after(c *callInfo, attempt *csAttempt) {}

// ForceCodec returns a CallOption that will set the given Codec to be
// ForceCodec returns a CallOption that will set codec to be
// used for all request and response messages for a call. The result of calling
// String() will be used as the content-subtype in a case-insensitive manner.
// Name() will be used as the content-subtype in a case-insensitive manner.
//
// See Content-Type on
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
Expand Down
25 changes: 25 additions & 0 deletions server.go
Expand Up @@ -279,6 +279,31 @@ func CustomCodec(codec Codec) ServerOption {
})
}

// ForceServerCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling.
//
// This will override any lookups by content-subtype for Codecs registered with RegisterCodec.
//
// See Content-Type on
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details. Also see the documentation on RegisterCodec and
// CallContentSubtype for more details on the interaction between encoding.Codec and
// content-subtype.
//
// This function is provided for advanced users; prefer to register codecs using encoding.RegisterCodec.
// The server will automatically use registered codecs based on the incoming requests' headers.
// See also https://github.com/grpc/grpc-go/blob/master/Documentation/encoding.md#using-a-codec.
// Will be supported throughout 1.x.
//
// Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func ForceServerCodec(codec encoding.Codec) ServerOption {
return newFuncServerOption(func(o *serverOptions) {
o.codec = codec
})
}

// RPCCompressor returns a ServerOption that sets a compressor for outbound
// messages. For backward compatibility, all outbound messages will be sent
// using this compressor, regardless of incoming message compression. By
Expand Down
58 changes: 58 additions & 0 deletions test/end2end_test.go
Expand Up @@ -5284,6 +5284,37 @@ func (s) TestGRPCMethod(t *testing.T) {
}
}

func (s) TestForceServerCodec(t *testing.T) {
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
},
}
codec := &protoCodec{}
if err := ss.Start([]grpc.ServerOption{grpc.ForceServerCodec(codec)}); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
}

unmarshalCount := atomic.LoadInt32(&codec.unmarshalCount)
var wantUnmarshalCount int32 = 1
if unmarshalCount != wantUnmarshalCount {
t.Fatalf("protoCodec.unmarshalCount = %d; want %d", unmarshalCount, wantUnmarshalCount)
}
marshalCount := atomic.LoadInt32(&codec.marshalCount)
var wantMarshalCount int32 = 1
if marshalCount != wantMarshalCount {
t.Fatalf("protoCodec.marshalCount = %d; want %d", marshalCount, wantMarshalCount)
}
}

func (s) TestUnaryProxyDoesNotForwardMetadata(t *testing.T) {
const mdkey = "somedata"

Expand Down Expand Up @@ -5653,6 +5684,33 @@ func (c *errCodec) Name() string {
return "Fermat's near-miss."
}

type protoCodec struct {
marshalCount int32
unmarshalCount int32
}

func (p *protoCodec) Marshal(v interface{}) ([]byte, error) {
atomic.AddInt32(&p.marshalCount, 1)
vv, ok := v.(proto.Message)
if !ok {
return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v)
}
return proto.Marshal(vv)
}

func (p *protoCodec) Unmarshal(data []byte, v interface{}) error {
atomic.AddInt32(&p.unmarshalCount, 1)
vv, ok := v.(proto.Message)
if !ok {
return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v)
}
return proto.Unmarshal(data, vv)
}

func (*protoCodec) Name() string {
return "proto"
}

func (s) TestEncodeDoesntPanic(t *testing.T) {
for _, e := range listTestEnv() {
testEncodeDoesntPanic(t, e)
Expand Down

0 comments on commit 177ff4f

Please sign in to comment.