Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: fix ChainUnaryInterceptor and ChainStreamInterceptor to allow retrying handlers #5666

Merged
merged 3 commits into from Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 20 additions & 30 deletions server.go
Expand Up @@ -1150,21 +1150,16 @@ func chainUnaryServerInterceptors(s *Server) {

func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
// the struct ensures the variables are allocated together, rather than separately, since we
// know they should be garbage collected together. This saves 1 allocation and decreases
// time/call by about 10% on the microbenchmark.
var state struct {
i int
next UnaryHandler
}
state.next = func(ctx context.Context, req interface{}) (interface{}, error) {
if state.i == len(interceptors)-1 {
return interceptors[state.i](ctx, req, info, handler)
}
state.i++
return interceptors[state.i-1](ctx, req, info, state.next)
}
return state.next(ctx, req)
return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler))
}
}

func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler {
yiminc marked this conversation as resolved.
Show resolved Hide resolved
if curr == len(interceptors)-1 {
return finalHandler
}
return func(ctx context.Context, req interface{}) (interface{}, error) {
return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler))
}
}

Expand Down Expand Up @@ -1470,21 +1465,16 @@ func chainStreamServerInterceptors(s *Server) {

func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor {
return func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error {
// the struct ensures the variables are allocated together, rather than separately, since we
// know they should be garbage collected together. This saves 1 allocation and decreases
// time/call by about 10% on the microbenchmark.
var state struct {
i int
next StreamHandler
}
state.next = func(srv interface{}, ss ServerStream) error {
if state.i == len(interceptors)-1 {
return interceptors[state.i](srv, ss, info, handler)
}
state.i++
return interceptors[state.i-1](srv, ss, info, state.next)
}
return state.next(srv, ss)
return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler))
}
}

func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler {
if curr == len(interceptors)-1 {
return finalHandler
}
return func(srv interface{}, stream ServerStream) error {
return interceptors[curr+1](srv, stream, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler))
}
}

Expand Down
29 changes: 29 additions & 0 deletions server_test.go
Expand Up @@ -27,6 +27,7 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/status"
)
Expand Down Expand Up @@ -130,6 +131,34 @@ func (s) TestGetServiceInfo(t *testing.T) {
}
}

func (s) TestRetryChainedInterceptor(t *testing.T) {
var records []int
i1 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) {
records = append(records, 1)
// call handler twice to simulate a retry here.
handler(ctx, req)
return handler(ctx, req)
}
i2 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) {
records = append(records, 2)
return handler(ctx, req)
}
i3 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) {
records = append(records, 3)
return handler(ctx, req)
}

ii := chainUnaryInterceptors([]UnaryServerInterceptor{i1, i2, i3})

handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
}
ii(context.Background(), nil, nil, handler)
if !cmp.Equal(records, []int{1, 2, 3, 2, 3}) {
t.Fatalf("retry failed on chained interceptors: %v", records)
}
}

func (s) TestStreamContext(t *testing.T) {
expectedStream := &transport.Stream{}
ctx := NewContextWithServerTransportStream(context.Background(), expectedStream)
Expand Down