From 1d05928a6157d149eb5980607dd28811ea3196ba Mon Sep 17 00:00:00 2001 From: Yimin Chen Date: Thu, 22 Sep 2022 22:44:29 -0700 Subject: [PATCH] Fix chainUnaryInterceptors to allow retry --- server.go | 1 + server_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/server.go b/server.go index 6ef3df67d9e5..878f4e450ca9 100644 --- a/server.go +++ b/server.go @@ -1152,6 +1152,7 @@ func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerIn return interceptors[state.i](ctx, req, info, handler) } state.i++ + defer func() { state.i-- }() return interceptors[state.i-1](ctx, req, info, state.next) } return state.next(ctx, req) diff --git a/server_test.go b/server_test.go index 7d4cf7bfc21e..9494be30d1de 100644 --- a/server_test.go +++ b/server_test.go @@ -130,6 +130,34 @@ func (s) TestGetServiceInfo(t *testing.T) { } } +func (s) TestRetryChainedInterceptor(t *testing.T) { + var records []int + i1 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) { + records = append(records, 1) + // call handler twice to simulate a retry here. + handler(ctx, req) + return handler(ctx, req) + } + i2 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) { + records = append(records, 2) + return handler(ctx, req) + } + i3 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) { + records = append(records, 3) + return handler(ctx, req) + } + + ii := chainUnaryInterceptors([]UnaryServerInterceptor{i1, i2, i3}) + + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + } + ii(context.Background(), nil, nil, handler) + if !reflect.DeepEqual(records, []int{1, 2, 3, 2, 3}) { + t.Fatalf("retry failed on chained interceptors: %v", records) + } +} + func (s) TestStreamContext(t *testing.T) { expectedStream := &transport.Stream{} ctx := NewContextWithServerTransportStream(context.Background(), expectedStream)