diff --git a/server.go b/server.go index db431ea485fb..c663f1bb293d 100644 --- a/server.go +++ b/server.go @@ -1115,22 +1115,24 @@ func chainUnaryServerInterceptors(s *Server) { } else if len(interceptors) == 1 { chainedInt = interceptors[0] } else { - for i := len(interceptors) - 1; i >= 0; i-- { - chainedInt = chainUnaryInterceptors(chainedInt, interceptors[i]) - } + chainedInt = chainUnaryInterceptors(interceptors) } s.opts.unaryInt = chainedInt } -func chainUnaryInterceptors(curr, next UnaryServerInterceptor) UnaryServerInterceptor { - if curr == nil { - return next - } +func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) { - return next(ctx, req, info, func(ctx context.Context, req interface{}) (interface{}, error) { - return curr(ctx, req, info, handler) - }) + var i int + var next UnaryHandler + next = func(ctx context.Context, req interface{}) (interface{}, error) { + if i == len(interceptors)-1 { + return interceptors[i](ctx, req, info, handler) + } + i++ + return interceptors[i-1](ctx, req, info, next) + } + return next(ctx, req) } } @@ -1396,22 +1398,24 @@ func chainStreamServerInterceptors(s *Server) { } else if len(interceptors) == 1 { chainedInt = interceptors[0] } else { - for i := len(interceptors) - 1; i >= 0; i-- { - chainedInt = chainStreamInterceptors(chainedInt, interceptors[i]) - } + chainedInt = chainStreamInterceptors(interceptors) } s.opts.streamInt = chainedInt } -func chainStreamInterceptors(curr, next StreamServerInterceptor) StreamServerInterceptor { - if curr == nil { - return next - } +func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor { return func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error { - return next(srv, ss, info, func(srv interface{}, stream ServerStream) error { - return curr(srv, ss, info, handler) - }) + var i int + var next StreamHandler + next = func(srv interface{}, ss ServerStream) error { + if i == len(interceptors)-1 { + return interceptors[i](srv, ss, info, handler) + } + i++ + return interceptors[i-1](srv, ss, info, handler) + } + return next(srv, ss) } }