diff --git a/call_test.go b/call_test.go index 48424fef9f7..3280109f4fb 100644 --- a/call_test.go +++ b/call_test.go @@ -127,8 +127,6 @@ type server struct { channelzID *channelz.Identifier } -type ctxKey string - func newTestServer() *server { return &server{ startedErr: make(chan error, 1), @@ -211,298 +209,3 @@ func (s *server) stop() { s.conns = nil s.mu.Unlock() } - -func setUp(t *testing.T, port int, maxStreams uint32) (*server, *ClientConn) { - return setUpWithOptions(t, port, maxStreams) -} - -func setUpWithOptions(t *testing.T, port int, maxStreams uint32, dopts ...DialOption) (*server, *ClientConn) { - server := newTestServer() - go server.start(t, port, maxStreams) - server.wait(t, 2*time.Second) - addr := "localhost:" + server.port - dopts = append(dopts, WithBlock(), WithInsecure(), WithCodec(testCodec{})) - cc, err := Dial(addr, dopts...) - if err != nil { - t.Fatalf("Failed to create ClientConn: %v", err) - } - return server, cc -} - -func (s) TestUnaryClientInterceptor(t *testing.T) { - parentKey := ctxKey("parentKey") - - interceptor := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { - if ctx.Value(parentKey) == nil { - t.Fatalf("interceptor should have %v in context", parentKey) - } - return invoker(ctx, method, req, reply, cc, opts...) - } - - server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(interceptor)) - defer func() { - cc.Close() - server.stop() - }() - - var reply string - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0) - if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse { - t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want ", err) - } -} - -func (s) TestChainUnaryClientInterceptor(t *testing.T) { - var ( - parentKey = ctxKey("parentKey") - firstIntKey = ctxKey("firstIntKey") - secondIntKey = ctxKey("secondIntKey") - ) - - firstInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { - if ctx.Value(parentKey) == nil { - t.Fatalf("first interceptor should have %v in context", parentKey) - } - if ctx.Value(firstIntKey) != nil { - t.Fatalf("first interceptor should not have %v in context", firstIntKey) - } - if ctx.Value(secondIntKey) != nil { - t.Fatalf("first interceptor should not have %v in context", secondIntKey) - } - firstCtx := context.WithValue(ctx, firstIntKey, 1) - err := invoker(firstCtx, method, req, reply, cc, opts...) - *(reply.(*string)) += "1" - return err - } - - secondInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { - if ctx.Value(parentKey) == nil { - t.Fatalf("second interceptor should have %v in context", parentKey) - } - if ctx.Value(firstIntKey) == nil { - t.Fatalf("second interceptor should have %v in context", firstIntKey) - } - if ctx.Value(secondIntKey) != nil { - t.Fatalf("second interceptor should not have %v in context", secondIntKey) - } - secondCtx := context.WithValue(ctx, secondIntKey, 2) - err := invoker(secondCtx, method, req, reply, cc, opts...) - *(reply.(*string)) += "2" - return err - } - - lastInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { - if ctx.Value(parentKey) == nil { - t.Fatalf("last interceptor should have %v in context", parentKey) - } - if ctx.Value(firstIntKey) == nil { - t.Fatalf("last interceptor should have %v in context", firstIntKey) - } - if ctx.Value(secondIntKey) == nil { - t.Fatalf("last interceptor should have %v in context", secondIntKey) - } - err := invoker(ctx, method, req, reply, cc, opts...) - *(reply.(*string)) += "3" - return err - } - - server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainUnaryInterceptor(firstInt, secondInt, lastInt)) - defer func() { - cc.Close() - server.stop() - }() - - var reply string - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0) - if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse+"321" { - t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want ", err) - } -} - -func (s) TestChainOnBaseUnaryClientInterceptor(t *testing.T) { - var ( - parentKey = ctxKey("parentKey") - baseIntKey = ctxKey("baseIntKey") - ) - - baseInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { - if ctx.Value(parentKey) == nil { - t.Fatalf("base interceptor should have %v in context", parentKey) - } - if ctx.Value(baseIntKey) != nil { - t.Fatalf("base interceptor should not have %v in context", baseIntKey) - } - baseCtx := context.WithValue(ctx, baseIntKey, 1) - return invoker(baseCtx, method, req, reply, cc, opts...) - } - - chainInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { - if ctx.Value(parentKey) == nil { - t.Fatalf("chain interceptor should have %v in context", parentKey) - } - if ctx.Value(baseIntKey) == nil { - t.Fatalf("chain interceptor should have %v in context", baseIntKey) - } - return invoker(ctx, method, req, reply, cc, opts...) - } - - server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(baseInt), WithChainUnaryInterceptor(chainInt)) - defer func() { - cc.Close() - server.stop() - }() - - var reply string - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0) - if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse { - t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want ", err) - } -} - -func (s) TestChainStreamClientInterceptor(t *testing.T) { - var ( - parentKey = ctxKey("parentKey") - firstIntKey = ctxKey("firstIntKey") - secondIntKey = ctxKey("secondIntKey") - ) - - firstInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) { - if ctx.Value(parentKey) == nil { - t.Fatalf("first interceptor should have %v in context", parentKey) - } - if ctx.Value(firstIntKey) != nil { - t.Fatalf("first interceptor should not have %v in context", firstIntKey) - } - if ctx.Value(secondIntKey) != nil { - t.Fatalf("first interceptor should not have %v in context", secondIntKey) - } - firstCtx := context.WithValue(ctx, firstIntKey, 1) - return streamer(firstCtx, desc, cc, method, opts...) - } - - secondInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) { - if ctx.Value(parentKey) == nil { - t.Fatalf("second interceptor should have %v in context", parentKey) - } - if ctx.Value(firstIntKey) == nil { - t.Fatalf("second interceptor should have %v in context", firstIntKey) - } - if ctx.Value(secondIntKey) != nil { - t.Fatalf("second interceptor should not have %v in context", secondIntKey) - } - secondCtx := context.WithValue(ctx, secondIntKey, 2) - return streamer(secondCtx, desc, cc, method, opts...) - } - - lastInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) { - if ctx.Value(parentKey) == nil { - t.Fatalf("last interceptor should have %v in context", parentKey) - } - if ctx.Value(firstIntKey) == nil { - t.Fatalf("last interceptor should have %v in context", firstIntKey) - } - if ctx.Value(secondIntKey) == nil { - t.Fatalf("last interceptor should have %v in context", secondIntKey) - } - return streamer(ctx, desc, cc, method, opts...) - } - - server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainStreamInterceptor(firstInt, secondInt, lastInt)) - defer func() { - cc.Close() - server.stop() - }() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0) - _, err := cc.NewStream(parentCtx, &StreamDesc{}, "/foo/bar") - if err != nil { - t.Fatalf("grpc.NewStream(_, _, _) = %v, want ", err) - } -} - -func (s) TestInvoke(t *testing.T) { - server, cc := setUp(t, 0, math.MaxUint32) - var reply string - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse { - t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want ", err) - } - cc.Close() - server.stop() -} - -func (s) TestInvokeLargeErr(t *testing.T) { - server, cc := setUp(t, 0, math.MaxUint32) - var reply string - req := "hello" - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - err := cc.Invoke(ctx, "/foo/bar", &req, &reply) - if _, ok := status.FromError(err); !ok { - t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") - } - if status.Code(err) != codes.Internal || len(errorDesc(err)) != sizeLargeErr { - t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want an error of code %d and desc size %d", err, codes.Internal, sizeLargeErr) - } - cc.Close() - server.stop() -} - -// TestInvokeErrorSpecialChars checks that error messages don't get mangled. -func (s) TestInvokeErrorSpecialChars(t *testing.T) { - server, cc := setUp(t, 0, math.MaxUint32) - var reply string - req := "weird error" - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - err := cc.Invoke(ctx, "/foo/bar", &req, &reply) - if _, ok := status.FromError(err); !ok { - t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") - } - if got, want := errorDesc(err), weirdError; got != want { - t.Fatalf("grpc.Invoke(_, _, _, _, _) error = %q, want %q", got, want) - } - cc.Close() - server.stop() -} - -// TestInvokeCancel checks that an Invoke with a canceled context is not sent. -func (s) TestInvokeCancel(t *testing.T) { - server, cc := setUp(t, 0, math.MaxUint32) - var reply string - req := "canceled" - for i := 0; i < 100; i++ { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - cc.Invoke(ctx, "/foo/bar", &req, &reply) - } - if canceled != 0 { - t.Fatalf("received %d of 100 canceled requests", canceled) - } - cc.Close() - server.stop() -} - -// TestInvokeCancelClosedNonFail checks that a canceled non-failfast RPC -// on a closed client will terminate. -func (s) TestInvokeCancelClosedNonFailFast(t *testing.T) { - server, cc := setUp(t, 0, math.MaxUint32) - var reply string - cc.Close() - req := "hello" - ctx, cancel := context.WithCancel(context.Background()) - cancel() - if err := cc.Invoke(ctx, "/foo/bar", &req, &reply, WaitForReady(true)); err == nil { - t.Fatalf("canceled invoke on closed connection should fail") - } - server.stop() -} diff --git a/test/interceptor_test.go b/test/interceptor_test.go new file mode 100644 index 00000000000..34a7cad5cc5 --- /dev/null +++ b/test/interceptor_test.go @@ -0,0 +1,279 @@ +/* + * + * Copyright 2022 gRPC authors. + + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package test + +import ( + "context" + "fmt" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils" + testpb "google.golang.org/grpc/test/grpc_testing" +) + +type parentCtxkey struct{} +type firstInterceptorCtxkey struct{} +type secondInterceptorCtxkey struct{} +type baseInterceptorCtxKey struct{} + +const ( + parentCtxVal = "parent" + firstInterceptorCtxVal = "firstInterceptor" + secondInterceptorCtxVal = "secondInterceptor" + baseInterceptorCtxVal = "baseInterceptor" +) + +// TestUnaryClientInterceptor_ContextValuePropagation verifies that a unary +// interceptor receives context values specified in the context passed to the +// RPC call. +func (s) TestUnaryClientInterceptor_ContextValuePropagation(t *testing.T) { + errCh := testutils.NewChannel() + unaryInt := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { + errCh.Send(fmt.Errorf("unaryInt got %q in context.Val, want %q", got, parentCtxVal)) + } + errCh.Send(nil) + return invoker(ctx, method, req, reply, cc, opts...) + } + + // Start a stub server and use the above unary interceptor while creating a + // ClientConn to it. + ss := &stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, + } + if err := ss.Start(nil, grpc.WithUnaryInterceptor(unaryInt)); err != nil { + t.Fatalf("Failed to start stub server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err := ss.Client.EmptyCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal), &testpb.Empty{}); err != nil { + t.Fatalf("ss.Client.EmptyCall() failed: %v", err) + } + val, err := errCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for unary interceptor to be invoked: %v", err) + } + if val != nil { + t.Fatalf("unary interceptor failed: %v", val) + } +} + +// TestChainUnaryClientInterceptor_ContextValuePropagation verifies that a chain +// of unary interceptors receive context values specified in the original call +// as well as the ones specified by prior interceptors in the chain. +func (s) TestChainUnaryClientInterceptor_ContextValuePropagation(t *testing.T) { + errCh := testutils.NewChannel() + firstInt := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { + errCh.SendContext(ctx, fmt.Errorf("first interceptor got %q in context.Val, want %q", got, parentCtxVal)) + } + if ctx.Value(firstInterceptorCtxkey{}) != nil { + errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", firstInterceptorCtxkey{})) + } + if ctx.Value(secondInterceptorCtxkey{}) != nil { + errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", secondInterceptorCtxkey{})) + } + firstCtx := context.WithValue(ctx, firstInterceptorCtxkey{}, firstInterceptorCtxVal) + return invoker(firstCtx, method, req, reply, cc, opts...) + } + + secondInt := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { + errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, parentCtxVal)) + } + if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal { + errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal)) + } + if ctx.Value(secondInterceptorCtxkey{}) != nil { + errCh.SendContext(ctx, fmt.Errorf("second interceptor should not have %T in context", secondInterceptorCtxkey{})) + } + secondCtx := context.WithValue(ctx, secondInterceptorCtxkey{}, secondInterceptorCtxVal) + return invoker(secondCtx, method, req, reply, cc, opts...) + } + + lastInt := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { + errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, parentCtxVal)) + } + if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal { + errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal)) + } + if got, ok := ctx.Value(secondInterceptorCtxkey{}).(string); !ok || got != secondInterceptorCtxVal { + errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, secondInterceptorCtxVal)) + } + errCh.SendContext(ctx, nil) + return invoker(ctx, method, req, reply, cc, opts...) + } + + // Start a stub server and use the above chain of interceptors while creating + // a ClientConn to it. + ss := &stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, + } + if err := ss.Start(nil, grpc.WithChainUnaryInterceptor(firstInt, secondInt, lastInt)); err != nil { + t.Fatalf("Failed to start stub server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err := ss.Client.EmptyCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal), &testpb.Empty{}); err != nil { + t.Fatalf("ss.Client.EmptyCall() failed: %v", err) + } + val, err := errCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for unary interceptor to be invoked: %v", err) + } + if val != nil { + t.Fatalf("unary interceptor failed: %v", val) + } +} + +// TestChainOnBaseUnaryClientInterceptor_ContextValuePropagation verifies that +// unary interceptors specified as a base interceptor or as a chain interceptor +// receive context values specified in the original call as well as the ones +// specified by interceptors in the chain. +func (s) TestChainOnBaseUnaryClientInterceptor_ContextValuePropagation(t *testing.T) { + errCh := testutils.NewChannel() + baseInt := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { + errCh.SendContext(ctx, fmt.Errorf("base interceptor got %q in context.Val, want %q", got, parentCtxVal)) + } + if ctx.Value(baseInterceptorCtxKey{}) != nil { + errCh.SendContext(ctx, fmt.Errorf("baseinterceptor should not have %T in context", baseInterceptorCtxKey{})) + } + baseCtx := context.WithValue(ctx, baseInterceptorCtxKey{}, baseInterceptorCtxVal) + return invoker(baseCtx, method, req, reply, cc, opts...) + } + + chainInt := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { + errCh.SendContext(ctx, fmt.Errorf("chain interceptor got %q in context.Val, want %q", got, parentCtxVal)) + } + if got, ok := ctx.Value(baseInterceptorCtxKey{}).(string); !ok || got != baseInterceptorCtxVal { + errCh.SendContext(ctx, fmt.Errorf("chain interceptor got %q in context.Val, want %q", got, baseInterceptorCtxVal)) + } + errCh.SendContext(ctx, nil) + return invoker(ctx, method, req, reply, cc, opts...) + } + + // Start a stub server and use the above chain of interceptors while creating + // a ClientConn to it. + ss := &stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, + } + if err := ss.Start(nil, grpc.WithUnaryInterceptor(baseInt), grpc.WithChainUnaryInterceptor(chainInt)); err != nil { + t.Fatalf("Failed to start stub server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err := ss.Client.EmptyCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal), &testpb.Empty{}); err != nil { + t.Fatalf("ss.Client.EmptyCall() failed: %v", err) + } + val, err := errCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for unary interceptor to be invoked: %v", err) + } + if val != nil { + t.Fatalf("unary interceptor failed: %v", val) + } +} + +// TestChainStreamClientInterceptor_ContextValuePropagation verifies that a +// chain of stream interceptors receive context values specified in the original +// call as well as the ones specified by the prior interceptors in the chain. +func (s) TestChainStreamClientInterceptor_ContextValuePropagation(t *testing.T) { + errCh := testutils.NewChannel() + firstInt := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { + errCh.SendContext(ctx, fmt.Errorf("first interceptor got %q in context.Val, want %q", got, parentCtxVal)) + } + if ctx.Value(firstInterceptorCtxkey{}) != nil { + errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", firstInterceptorCtxkey{})) + } + if ctx.Value(secondInterceptorCtxkey{}) != nil { + errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", secondInterceptorCtxkey{})) + } + firstCtx := context.WithValue(ctx, firstInterceptorCtxkey{}, firstInterceptorCtxVal) + return streamer(firstCtx, desc, cc, method, opts...) + } + + secondInt := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { + errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, parentCtxVal)) + } + if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal { + errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal)) + } + if ctx.Value(secondInterceptorCtxkey{}) != nil { + errCh.SendContext(ctx, fmt.Errorf("second interceptor should not have %T in context", secondInterceptorCtxkey{})) + } + secondCtx := context.WithValue(ctx, secondInterceptorCtxkey{}, secondInterceptorCtxVal) + return streamer(secondCtx, desc, cc, method, opts...) + } + + lastInt := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { + errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, parentCtxVal)) + } + if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal { + errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal)) + } + if got, ok := ctx.Value(secondInterceptorCtxkey{}).(string); !ok || got != secondInterceptorCtxVal { + errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, secondInterceptorCtxVal)) + } + errCh.SendContext(ctx, nil) + return streamer(ctx, desc, cc, method, opts...) + } + + // Start a stub server and use the above chain of interceptors while creating + // a ClientConn to it. + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + if _, err := stream.Recv(); err != nil { + return err + } + return stream.Send(&testpb.StreamingOutputCallResponse{}) + }, + } + if err := ss.Start(nil, grpc.WithChainStreamInterceptor(firstInt, secondInt, lastInt)); err != nil { + t.Fatalf("Failed to start stub server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err := ss.Client.FullDuplexCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal)); err != nil { + t.Fatalf("ss.Client.FullDuplexCall() failed: %v", err) + } + val, err := errCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for stream interceptor to be invoked: %v", err) + } + if val != nil { + t.Fatalf("stream interceptor failed: %v", val) + } +} diff --git a/test/invoke_test.go b/test/invoke_test.go new file mode 100644 index 00000000000..49ad9044ee3 --- /dev/null +++ b/test/invoke_test.go @@ -0,0 +1,152 @@ +/* + * + * Copyright 2022 gRPC authors. + + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package test + +import ( + "context" + "strings" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/status" + testpb "google.golang.org/grpc/test/grpc_testing" +) + +// TestInvoke verifies a straightforward invocation of ClientConn.Invoke(). +func (s) TestInvoke(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Failed to start stub server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := ss.CC.Invoke(ctx, "/grpc.testing.TestService/EmptyCall", &testpb.Empty{}, &testpb.Empty{}); err != nil { + t.Fatalf("grpc.Invoke(\"/grpc.testing.TestService/EmptyCall\") failed: %v", err) + } +} + +// TestInvokeLargeErr verifies an invocation of ClientConn.Invoke() where the +// server returns a really large error message. +func (s) TestInvokeLargeErr(t *testing.T) { + largeErrorStr := strings.Repeat("A", 1024*1024) + ss := &stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, status.Error(codes.Internal, largeErrorStr) + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Failed to start stub server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + err := ss.CC.Invoke(ctx, "/grpc.testing.TestService/EmptyCall", &testpb.Empty{}, &testpb.Empty{}) + if err == nil { + t.Fatal("grpc.Invoke(\"/grpc.testing.TestService/EmptyCall\") succeeded when expected to fail") + } + st, ok := status.FromError(err) + if !ok { + t.Fatal("grpc.Invoke(\"/grpc.testing.TestService/EmptyCall\") received non-status error") + } + if status.Code(err) != codes.Internal || st.Message() != largeErrorStr { + t.Fatalf("grpc.Invoke(\"/grpc.testing.TestService/EmptyCall\") failed with error: %v, want an error of code %d and desc size %d", err, codes.Internal, len(largeErrorStr)) + } +} + +// TestInvokeErrorSpecialChars tests an invocation of ClientConn.Invoke() and +// verifies that error messages don't get mangled. +func (s) TestInvokeErrorSpecialChars(t *testing.T) { + const weirdError = "format verbs: %v%s" + ss := &stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, status.Error(codes.Internal, weirdError) + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Failed to start stub server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + err := ss.CC.Invoke(ctx, "/grpc.testing.TestService/EmptyCall", &testpb.Empty{}, &testpb.Empty{}) + if err == nil { + t.Fatal("grpc.Invoke(\"/grpc.testing.TestService/EmptyCall\") succeeded when expected to fail") + } + st, ok := status.FromError(err) + if !ok { + t.Fatal("grpc.Invoke(\"/grpc.testing.TestService/EmptyCall\") received non-status error") + } + if status.Code(err) != codes.Internal || st.Message() != weirdError { + t.Fatalf("grpc.Invoke(\"/grpc.testing.TestService/EmptyCall\") failed with error: %v, want %v", err, weirdError) + } +} + +// TestInvokeCancel tests an invocation of ClientConn.Invoke() with a cancelled +// context and verifies that the request is not actually sent to the server. +func (s) TestInvokeCancel(t *testing.T) { + cancelled := 0 + ss := &stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + cancelled++ + return &testpb.Empty{}, nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Failed to start stub server: %v", err) + } + defer ss.Stop() + + for i := 0; i < 100; i++ { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + ss.CC.Invoke(ctx, "/grpc.testing.TestService/EmptyCall", &testpb.Empty{}, &testpb.Empty{}) + } + if cancelled != 0 { + t.Fatalf("server received %d of 100 cancelled requests", cancelled) + } +} + +// TestInvokeCancelClosedNonFail tests an invocation of ClientConn.Invoke() with +// a cancelled non-failfast RPC on a closed ClientConn and verifies that the +// call terminates with an error. +func (s) TestInvokeCancelClosedNonFailFast(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Failed to start stub server: %v", err) + } + defer ss.Stop() + + ss.CC.Close() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := ss.CC.Invoke(ctx, "/grpc.testing.TestService/EmptyCall", &testpb.Empty{}, &testpb.Empty{}, grpc.WaitForReady(true)); err == nil { + t.Fatal("ClientConn.Invoke() on closed connection succeeded when expected to fail") + } +}