Skip to content

Commit

Permalink
Improve interceptors chaining
Browse files Browse the repository at this point in the history
  • Loading branch information
amenzhinsky committed Jun 8, 2021
1 parent 0956b12 commit a1c389e
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions server.go
Expand Up @@ -1115,22 +1115,22 @@ func chainUnaryServerInterceptors(s *Server) {
} else if len(interceptors) == 1 {
chainedInt = interceptors[0]
} else {
chainedInt = func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler))
for i := len(interceptors) - 1; i >= 0; i-- {
chainedInt = chainUnaryInterceptors(chainedInt, interceptors[i])
}
}

s.opts.unaryInt = chainedInt
}

// getChainUnaryHandler recursively generate the chained UnaryHandler
func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler {
if curr == len(interceptors)-1 {
return finalHandler
func chainUnaryInterceptors(curr, next UnaryServerInterceptor) UnaryServerInterceptor {
if curr == nil {
return next
}

return func(ctx context.Context, req interface{}) (interface{}, error) {
return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler))
return func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
return next(ctx, req, info, func(ctx context.Context, req interface{}) (interface{}, error) {
return curr(ctx, req, info, handler)
})
}
}

Expand Down Expand Up @@ -1396,22 +1396,22 @@ func chainStreamServerInterceptors(s *Server) {
} else if len(interceptors) == 1 {
chainedInt = interceptors[0]
} else {
chainedInt = func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error {
return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler))
for i := len(interceptors) - 1; i >= 0; i-- {
chainedInt = chainStreamInterceptors(chainedInt, interceptors[i])
}
}

s.opts.streamInt = chainedInt
}

// getChainStreamHandler recursively generate the chained StreamHandler
func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler {
if curr == len(interceptors)-1 {
return finalHandler
func chainStreamInterceptors(curr, next StreamServerInterceptor) StreamServerInterceptor {
if curr == nil {
return next
}

return func(srv interface{}, ss ServerStream) error {
return interceptors[curr+1](srv, ss, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler))
return func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error {
return next(srv, ss, info, func(srv interface{}, stream ServerStream) error {
return curr(srv, ss, info, handler)
})
}
}

Expand Down

0 comments on commit a1c389e

Please sign in to comment.