Skip to content

Commit

Permalink
server: fix ChainUnaryInterceptor and ChainStreamInterceptor to allow…
Browse files Browse the repository at this point in the history
… retrying handlers (#5666)
  • Loading branch information
yiminc committed Nov 22, 2022
1 parent e0a9f11 commit adfb915
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 30 deletions.
50 changes: 20 additions & 30 deletions server.go
Expand Up @@ -1150,21 +1150,16 @@ func chainUnaryServerInterceptors(s *Server) {

func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
// the struct ensures the variables are allocated together, rather than separately, since we
// know they should be garbage collected together. This saves 1 allocation and decreases
// time/call by about 10% on the microbenchmark.
var state struct {
i int
next UnaryHandler
}
state.next = func(ctx context.Context, req interface{}) (interface{}, error) {
if state.i == len(interceptors)-1 {
return interceptors[state.i](ctx, req, info, handler)
}
state.i++
return interceptors[state.i-1](ctx, req, info, state.next)
}
return state.next(ctx, req)
return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler))
}
}

func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler {
if curr == len(interceptors)-1 {
return finalHandler
}
return func(ctx context.Context, req interface{}) (interface{}, error) {
return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler))
}
}

Expand Down Expand Up @@ -1470,21 +1465,16 @@ func chainStreamServerInterceptors(s *Server) {

func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor {
return func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error {
// the struct ensures the variables are allocated together, rather than separately, since we
// know they should be garbage collected together. This saves 1 allocation and decreases
// time/call by about 10% on the microbenchmark.
var state struct {
i int
next StreamHandler
}
state.next = func(srv interface{}, ss ServerStream) error {
if state.i == len(interceptors)-1 {
return interceptors[state.i](srv, ss, info, handler)
}
state.i++
return interceptors[state.i-1](srv, ss, info, state.next)
}
return state.next(srv, ss)
return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler))
}
}

func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler {
if curr == len(interceptors)-1 {
return finalHandler
}
return func(srv interface{}, stream ServerStream) error {
return interceptors[curr+1](srv, stream, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler))
}
}

Expand Down
29 changes: 29 additions & 0 deletions server_test.go
Expand Up @@ -27,6 +27,7 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/status"
)
Expand Down Expand Up @@ -130,6 +131,34 @@ func (s) TestGetServiceInfo(t *testing.T) {
}
}

func (s) TestRetryChainedInterceptor(t *testing.T) {
var records []int
i1 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) {
records = append(records, 1)
// call handler twice to simulate a retry here.
handler(ctx, req)
return handler(ctx, req)
}
i2 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) {
records = append(records, 2)
return handler(ctx, req)
}
i3 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) {
records = append(records, 3)
return handler(ctx, req)
}

ii := chainUnaryInterceptors([]UnaryServerInterceptor{i1, i2, i3})

handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
}
ii(context.Background(), nil, nil, handler)
if !cmp.Equal(records, []int{1, 2, 3, 2, 3}) {
t.Fatalf("retry failed on chained interceptors: %v", records)
}
}

func (s) TestStreamContext(t *testing.T) {
expectedStream := &transport.Stream{}
ctx := NewContextWithServerTransportStream(context.Background(), expectedStream)
Expand Down

0 comments on commit adfb915

Please sign in to comment.