diff --git a/server.go b/server.go index f4dde72b41f..2ed550c91e5 100644 --- a/server.go +++ b/server.go @@ -1150,21 +1150,16 @@ func chainUnaryServerInterceptors(s *Server) { func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) { - // the struct ensures the variables are allocated together, rather than separately, since we - // know they should be garbage collected together. This saves 1 allocation and decreases - // time/call by about 10% on the microbenchmark. - var state struct { - i int - next UnaryHandler - } - state.next = func(ctx context.Context, req interface{}) (interface{}, error) { - if state.i == len(interceptors)-1 { - return interceptors[state.i](ctx, req, info, handler) - } - state.i++ - return interceptors[state.i-1](ctx, req, info, state.next) - } - return state.next(ctx, req) + return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler)) + } +} + +func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler { + if curr == len(interceptors)-1 { + return finalHandler + } + return func(ctx context.Context, req interface{}) (interface{}, error) { + return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler)) } } @@ -1470,21 +1465,16 @@ func chainStreamServerInterceptors(s *Server) { func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor { return func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error { - // the struct ensures the variables are allocated together, rather than separately, since we - // know they should be garbage collected together. This saves 1 allocation and decreases - // time/call by about 10% on the microbenchmark. - var state struct { - i int - next StreamHandler - } - state.next = func(srv interface{}, ss ServerStream) error { - if state.i == len(interceptors)-1 { - return interceptors[state.i](srv, ss, info, handler) - } - state.i++ - return interceptors[state.i-1](srv, ss, info, state.next) - } - return state.next(srv, ss) + return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler)) + } +} + +func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler { + if curr == len(interceptors)-1 { + return finalHandler + } + return func(srv interface{}, stream ServerStream) error { + return interceptors[curr+1](srv, stream, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler)) } } diff --git a/server_test.go b/server_test.go index 7d4cf7bfc21..85a8f5bf72e 100644 --- a/server_test.go +++ b/server_test.go @@ -27,6 +27,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/status" ) @@ -130,6 +131,34 @@ func (s) TestGetServiceInfo(t *testing.T) { } } +func (s) TestRetryChainedInterceptor(t *testing.T) { + var records []int + i1 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) { + records = append(records, 1) + // call handler twice to simulate a retry here. + handler(ctx, req) + return handler(ctx, req) + } + i2 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) { + records = append(records, 2) + return handler(ctx, req) + } + i3 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) { + records = append(records, 3) + return handler(ctx, req) + } + + ii := chainUnaryInterceptors([]UnaryServerInterceptor{i1, i2, i3}) + + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + } + ii(context.Background(), nil, nil, handler) + if !cmp.Equal(records, []int{1, 2, 3, 2, 3}) { + t.Fatalf("retry failed on chained interceptors: %v", records) + } +} + func (s) TestStreamContext(t *testing.T) { expectedStream := &transport.Stream{} ctx := NewContextWithServerTransportStream(context.Background(), expectedStream)