From 005dae391f39407664350d4e29ae3621dcdf88eb Mon Sep 17 00:00:00 2001 From: Yimin Chen Date: Thu, 22 Sep 2022 22:44:29 -0700 Subject: [PATCH 1/3] Fix chainUnaryInterceptors to allow retry --- server.go | 1 + server_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/server.go b/server.go index f4dde72b41f..49041bc9192 100644 --- a/server.go +++ b/server.go @@ -1162,6 +1162,7 @@ func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerIn return interceptors[state.i](ctx, req, info, handler) } state.i++ + defer func() { state.i-- }() return interceptors[state.i-1](ctx, req, info, state.next) } return state.next(ctx, req) diff --git a/server_test.go b/server_test.go index 7d4cf7bfc21..9494be30d1d 100644 --- a/server_test.go +++ b/server_test.go @@ -130,6 +130,34 @@ func (s) TestGetServiceInfo(t *testing.T) { } } +func (s) TestRetryChainedInterceptor(t *testing.T) { + var records []int + i1 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) { + records = append(records, 1) + // call handler twice to simulate a retry here. + handler(ctx, req) + return handler(ctx, req) + } + i2 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) { + records = append(records, 2) + return handler(ctx, req) + } + i3 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) { + records = append(records, 3) + return handler(ctx, req) + } + + ii := chainUnaryInterceptors([]UnaryServerInterceptor{i1, i2, i3}) + + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + } + ii(context.Background(), nil, nil, handler) + if !reflect.DeepEqual(records, []int{1, 2, 3, 2, 3}) { + t.Fatalf("retry failed on chained interceptors: %v", records) + } +} + func (s) TestStreamContext(t *testing.T) { expectedStream := &transport.Stream{} ctx := NewContextWithServerTransportStream(context.Background(), expectedStream) From b6ff0e73ec7b2c09760851dcee7ea3be74c4be6d Mon Sep 17 00:00:00 2001 From: Yimin Chen Date: Fri, 4 Nov 2022 21:32:36 -0700 Subject: [PATCH 2/3] Use recursive --- server.go | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/server.go b/server.go index 49041bc9192..346f75dba4b 100644 --- a/server.go +++ b/server.go @@ -1150,22 +1150,16 @@ func chainUnaryServerInterceptors(s *Server) { func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) { - // the struct ensures the variables are allocated together, rather than separately, since we - // know they should be garbage collected together. This saves 1 allocation and decreases - // time/call by about 10% on the microbenchmark. - var state struct { - i int - next UnaryHandler - } - state.next = func(ctx context.Context, req interface{}) (interface{}, error) { - if state.i == len(interceptors)-1 { - return interceptors[state.i](ctx, req, info, handler) - } - state.i++ - defer func() { state.i-- }() - return interceptors[state.i-1](ctx, req, info, state.next) - } - return state.next(ctx, req) + return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler)) + } +} + +func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler { + if curr == len(interceptors)-1 { + return finalHandler + } + return func(ctx context.Context, req interface{}) (interface{}, error) { + return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler)) } } From 64a18d23bcefd4c49107b941b2e93b5b6887200a Mon Sep 17 00:00:00 2001 From: Yimin Chen Date: Sat, 12 Nov 2022 23:31:44 -0800 Subject: [PATCH 3/3] Update chainStreamInterceptors --- server.go | 25 ++++++++++--------------- server_test.go | 3 ++- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/server.go b/server.go index 346f75dba4b..2ed550c91e5 100644 --- a/server.go +++ b/server.go @@ -1465,21 +1465,16 @@ func chainStreamServerInterceptors(s *Server) { func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor { return func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error { - // the struct ensures the variables are allocated together, rather than separately, since we - // know they should be garbage collected together. This saves 1 allocation and decreases - // time/call by about 10% on the microbenchmark. - var state struct { - i int - next StreamHandler - } - state.next = func(srv interface{}, ss ServerStream) error { - if state.i == len(interceptors)-1 { - return interceptors[state.i](srv, ss, info, handler) - } - state.i++ - return interceptors[state.i-1](srv, ss, info, state.next) - } - return state.next(srv, ss) + return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler)) + } +} + +func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler { + if curr == len(interceptors)-1 { + return finalHandler + } + return func(srv interface{}, stream ServerStream) error { + return interceptors[curr+1](srv, stream, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler)) } } diff --git a/server_test.go b/server_test.go index 9494be30d1d..85a8f5bf72e 100644 --- a/server_test.go +++ b/server_test.go @@ -27,6 +27,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/status" ) @@ -153,7 +154,7 @@ func (s) TestRetryChainedInterceptor(t *testing.T) { return nil, nil } ii(context.Background(), nil, nil, handler) - if !reflect.DeepEqual(records, []int{1, 2, 3, 2, 3}) { + if !cmp.Equal(records, []int{1, 2, 3, 2, 3}) { t.Fatalf("retry failed on chained interceptors: %v", records) } }