Skip to content

Commit

Permalink
server: improve chained interceptors performance (#4524)
Browse files Browse the repository at this point in the history
  • Loading branch information
amenzhinsky committed Jun 25, 2021
1 parent e24ede5 commit 9b2fa9f
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 22 deletions.
48 changes: 26 additions & 22 deletions server.go
Expand Up @@ -1115,22 +1115,24 @@ func chainUnaryServerInterceptors(s *Server) {
} else if len(interceptors) == 1 {
chainedInt = interceptors[0]
} else {
chainedInt = func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler))
}
chainedInt = chainUnaryInterceptors(interceptors)
}

s.opts.unaryInt = chainedInt
}

// getChainUnaryHandler recursively generate the chained UnaryHandler
func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler {
if curr == len(interceptors)-1 {
return finalHandler
}

return func(ctx context.Context, req interface{}) (interface{}, error) {
return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler))
func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
var i int
var next UnaryHandler
next = func(ctx context.Context, req interface{}) (interface{}, error) {
if i == len(interceptors)-1 {
return interceptors[i](ctx, req, info, handler)
}
i++
return interceptors[i-1](ctx, req, info, next)
}
return next(ctx, req)
}
}

Expand Down Expand Up @@ -1398,22 +1400,24 @@ func chainStreamServerInterceptors(s *Server) {
} else if len(interceptors) == 1 {
chainedInt = interceptors[0]
} else {
chainedInt = func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error {
return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler))
}
chainedInt = chainStreamInterceptors(interceptors)
}

s.opts.streamInt = chainedInt
}

// getChainStreamHandler recursively generate the chained StreamHandler
func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler {
if curr == len(interceptors)-1 {
return finalHandler
}

return func(srv interface{}, ss ServerStream) error {
return interceptors[curr+1](srv, ss, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler))
func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor {
return func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error {
var i int
var next StreamHandler
next = func(srv interface{}, ss ServerStream) error {
if i == len(interceptors)-1 {
return interceptors[i](srv, ss, info, handler)
}
i++
return interceptors[i-1](srv, ss, info, next)
}
return next(srv, ss)
}
}

Expand Down
57 changes: 57 additions & 0 deletions server_test.go
Expand Up @@ -22,6 +22,7 @@ import (
"context"
"net"
"reflect"
"strconv"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -130,3 +131,59 @@ func (s) TestStreamContext(t *testing.T) {
t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, stream, ok, expectedStream)
}
}

func BenchmarkChainUnaryInterceptor(b *testing.B) {
for _, n := range []int{1, 3, 5, 10} {
n := n
b.Run(strconv.Itoa(n), func(b *testing.B) {
interceptors := make([]UnaryServerInterceptor, 0, n)
for i := 0; i < n; i++ {
interceptors = append(interceptors, func(
ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler,
) (interface{}, error) {
return handler(ctx, req)
})
}

s := NewServer(ChainUnaryInterceptor(interceptors...))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := s.opts.unaryInt(context.Background(), nil, nil,
func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
},
); err != nil {
b.Fatal(err)
}
}
})
}
}

func BenchmarkChainStreamInterceptor(b *testing.B) {
for _, n := range []int{1, 3, 5, 10} {
n := n
b.Run(strconv.Itoa(n), func(b *testing.B) {
interceptors := make([]StreamServerInterceptor, 0, n)
for i := 0; i < n; i++ {
interceptors = append(interceptors, func(
srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler,
) error {
return handler(srv, ss)
})
}

s := NewServer(ChainStreamInterceptor(interceptors...))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := s.opts.streamInt(nil, nil, nil, func(srv interface{}, stream ServerStream) error {
return nil
}); err != nil {
b.Fatal(err)
}
}
})
}
}

0 comments on commit 9b2fa9f

Please sign in to comment.