diff --git a/server.go b/server.go index d90f3fcd3bf..db431ea485f 100644 --- a/server.go +++ b/server.go @@ -1115,22 +1115,22 @@ func chainUnaryServerInterceptors(s *Server) { } else if len(interceptors) == 1 { chainedInt = interceptors[0] } else { - chainedInt = func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) { - return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler)) + for i := len(interceptors) - 1; i >= 0; i-- { + chainedInt = chainUnaryInterceptors(chainedInt, interceptors[i]) } } s.opts.unaryInt = chainedInt } -// getChainUnaryHandler recursively generate the chained UnaryHandler -func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler { - if curr == len(interceptors)-1 { - return finalHandler +func chainUnaryInterceptors(curr, next UnaryServerInterceptor) UnaryServerInterceptor { + if curr == nil { + return next } - - return func(ctx context.Context, req interface{}) (interface{}, error) { - return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler)) + return func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) { + return next(ctx, req, info, func(ctx context.Context, req interface{}) (interface{}, error) { + return curr(ctx, req, info, handler) + }) } } @@ -1396,22 +1396,22 @@ func chainStreamServerInterceptors(s *Server) { } else if len(interceptors) == 1 { chainedInt = interceptors[0] } else { - chainedInt = func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error { - return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler)) + for i := len(interceptors) - 1; i >= 0; i-- { + chainedInt = chainStreamInterceptors(chainedInt, interceptors[i]) } } s.opts.streamInt = chainedInt } -// getChainStreamHandler recursively generate the chained StreamHandler -func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler { - if curr == len(interceptors)-1 { - return finalHandler +func chainStreamInterceptors(curr, next StreamServerInterceptor) StreamServerInterceptor { + if curr == nil { + return next } - - return func(srv interface{}, ss ServerStream) error { - return interceptors[curr+1](srv, ss, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler)) + return func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error { + return next(srv, ss, info, func(srv interface{}, stream ServerStream) error { + return curr(srv, ss, info, handler) + }) } }