diff --git a/rest/engine.go b/rest/engine.go index da21d8c1fbdc..d7dc82212504 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -33,6 +33,7 @@ type engine struct { shedder load.Shedder priorityShedder load.Shedder tlsConfig *tls.Config + chain *alice.Chain } func newEngine(c RestConf) *engine { @@ -85,19 +86,25 @@ func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, met func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics, route Route, verifier func(chain alice.Chain) alice.Chain) error { - chain := alice.New( - handler.TracingHandler(ng.conf.Name, route.Path), - ng.getLogHandler(), - handler.PrometheusHandler(route.Path), - handler.MaxConns(ng.conf.MaxConns), - handler.BreakerHandler(route.Method, route.Path, metrics), - handler.SheddingHandler(ng.getShedder(fr.priority), metrics), - handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)), - handler.RecoverHandler, - handler.MetricHandler(metrics), - handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)), - handler.GunzipHandler, - ) + var chain alice.Chain + if ng.chain == nil { + chain = alice.New( + handler.TracingHandler(ng.conf.Name, route.Path), + ng.getLogHandler(), + handler.PrometheusHandler(route.Path), + handler.MaxConns(ng.conf.MaxConns), + handler.BreakerHandler(route.Method, route.Path, metrics), + handler.SheddingHandler(ng.getShedder(fr.priority), metrics), + handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)), + handler.RecoverHandler, + handler.MetricHandler(metrics), + handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)), + handler.GunzipHandler, + ) + } else { + chain = *ng.chain + } + chain = ng.appendAuthHandler(fr, chain, verifier) for _, middleware := range ng.middlewares { @@ -206,6 +213,10 @@ func (ng *engine) setTlsConfig(cfg *tls.Config) { ng.tlsConfig = cfg } +func (ng *engine) setChainConfig(chain *alice.Chain) { + ng.chain = chain +} + func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) { ng.unauthorizedCallback = callback } diff --git a/rest/engine_test.go b/rest/engine_test.go index 57780783c8f9..1c84bff7112f 100644 --- a/rest/engine_test.go +++ b/rest/engine_test.go @@ -229,6 +229,44 @@ func TestEngine_checkedMaxBytes(t *testing.T) { } } +func TestEngine_checkedChain(t *testing.T) { + var called int32 + middleware1 := func() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&called, 1) + next.ServeHTTP(w, r) + atomic.AddInt32(&called, 1) + }) + } + } + middleware2 := func() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&called, 1) + next.ServeHTTP(w, r) + atomic.AddInt32(&called, 1) + }) + } + } + + server := MustNewServer(RestConf{}, WithChain(middleware1(), middleware2())) + server.router = chainRouter{} + server.AddRoutes( + []Route{ + { + Method: http.MethodGet, + Path: "/", + Handler: func(_ http.ResponseWriter, _ *http.Request) { + atomic.AddInt32(&called, 1) + }, + }, + }, + ) + server.ngin.bindRoutes(chainRouter{}) + assert.Equal(t, int32(5), atomic.LoadInt32(&called)) +} + func TestEngine_notFoundHandler(t *testing.T) { logx.Disable() @@ -343,3 +381,19 @@ func (m mockedRouter) SetNotFoundHandler(_ http.Handler) { func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) { } + +type chainRouter struct{} + +func (c chainRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) { +} + +func (c chainRouter) Handle(_, _ string, handler http.Handler) error { + handler.ServeHTTP(nil, nil) + return nil +} + +func (c chainRouter) SetNotFoundHandler(_ http.Handler) { +} + +func (c chainRouter) SetNotAllowedHandler(_ http.Handler) { +} diff --git a/rest/server.go b/rest/server.go index 51a4dbc09d5c..681543167a57 100644 --- a/rest/server.go +++ b/rest/server.go @@ -7,6 +7,7 @@ import ( "path" "time" + "github.com/justinas/alice" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/rest/handler" "github.com/zeromicro/go-zero/rest/httpx" @@ -242,6 +243,17 @@ func WithTLSConfig(cfg *tls.Config) RunOption { } } +// WithChain returns a RunOption that with given chain config. +func WithChain(middlewares ...func(http.Handler) http.Handler) RunOption { + return func(svr *Server) { + chain := alice.New() + for _, middleware := range middlewares { + chain = chain.Append(middleware) + } + svr.ngin.setChainConfig(&chain) + } +} + // WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set. func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption { return func(svr *Server) {