Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add user middleware chain function #1913

Merged
merged 7 commits into from Jun 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 24 additions & 13 deletions rest/engine.go
Expand Up @@ -33,6 +33,7 @@ type engine struct {
shedder load.Shedder
priorityShedder load.Shedder
tlsConfig *tls.Config
chain *alice.Chain
}

func newEngine(c RestConf) *engine {
Expand Down Expand Up @@ -85,19 +86,25 @@ func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, met

func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
route Route, verifier func(chain alice.Chain) alice.Chain) error {
chain := alice.New(
handler.TracingHandler(ng.conf.Name, route.Path),
ng.getLogHandler(),
handler.PrometheusHandler(route.Path),
handler.MaxConns(ng.conf.MaxConns),
handler.BreakerHandler(route.Method, route.Path, metrics),
handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)),
handler.RecoverHandler,
handler.MetricHandler(metrics),
handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
handler.GunzipHandler,
)
var chain alice.Chain
if ng.chain == nil {
chain = alice.New(
handler.TracingHandler(ng.conf.Name, route.Path),
ng.getLogHandler(),
handler.PrometheusHandler(route.Path),
handler.MaxConns(ng.conf.MaxConns),
handler.BreakerHandler(route.Method, route.Path, metrics),
handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)),
handler.RecoverHandler,
handler.MetricHandler(metrics),
handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
handler.GunzipHandler,
)
} else {
chain = *ng.chain
}

chain = ng.appendAuthHandler(fr, chain, verifier)

for _, middleware := range ng.middlewares {
Expand Down Expand Up @@ -206,6 +213,10 @@ func (ng *engine) setTlsConfig(cfg *tls.Config) {
ng.tlsConfig = cfg
}

func (ng *engine) setChainConfig(chain *alice.Chain) {
ng.chain = chain
}

func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
ng.unauthorizedCallback = callback
}
Expand Down
54 changes: 54 additions & 0 deletions rest/engine_test.go
Expand Up @@ -229,6 +229,44 @@ func TestEngine_checkedMaxBytes(t *testing.T) {
}
}

func TestEngine_checkedChain(t *testing.T) {
var called int32
middleware1 := func() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&called, 1)
next.ServeHTTP(w, r)
atomic.AddInt32(&called, 1)
})
}
}
middleware2 := func() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&called, 1)
next.ServeHTTP(w, r)
atomic.AddInt32(&called, 1)
})
}
}

server := MustNewServer(RestConf{}, WithChain(middleware1(), middleware2()))
server.router = chainRouter{}
server.AddRoutes(
[]Route{
{
Method: http.MethodGet,
Path: "/",
Handler: func(_ http.ResponseWriter, _ *http.Request) {
atomic.AddInt32(&called, 1)
},
},
},
)
server.ngin.bindRoutes(chainRouter{})
assert.Equal(t, int32(5), atomic.LoadInt32(&called))
}

func TestEngine_notFoundHandler(t *testing.T) {
logx.Disable()

Expand Down Expand Up @@ -343,3 +381,19 @@ func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {

func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
}

type chainRouter struct{}

func (c chainRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
}

func (c chainRouter) Handle(_, _ string, handler http.Handler) error {
handler.ServeHTTP(nil, nil)
return nil
}

func (c chainRouter) SetNotFoundHandler(_ http.Handler) {
}

func (c chainRouter) SetNotAllowedHandler(_ http.Handler) {
}
12 changes: 12 additions & 0 deletions rest/server.go
Expand Up @@ -7,6 +7,7 @@ import (
"path"
"time"

"github.com/justinas/alice"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/rest/handler"
"github.com/zeromicro/go-zero/rest/httpx"
Expand Down Expand Up @@ -242,6 +243,17 @@ func WithTLSConfig(cfg *tls.Config) RunOption {
}
}

// WithChain returns a RunOption that with given chain config.
func WithChain(middlewares ...func(http.Handler) http.Handler) RunOption {
return func(svr *Server) {
chain := alice.New()
for _, middleware := range middlewares {
chain = chain.Append(middleware)
}
svr.ngin.setChainConfig(&chain)
}
}

// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
return func(svr *Server) {
Expand Down