From f5012baad77962bc68079dcaabeaed4971aab8b8 Mon Sep 17 00:00:00 2001 From: kemq1 Date: Mon, 16 May 2022 21:32:44 +0800 Subject: [PATCH 1/3] add user middleware chain function --- rest/engine.go | 37 +++++++++++++++++++---------- rest/engine_test.go | 58 +++++++++++++++++++++++++++++++++++++++++++++ rest/server.go | 8 +++++++ 3 files changed, 90 insertions(+), 13 deletions(-) diff --git a/rest/engine.go b/rest/engine.go index 0e238d11d176..d110123fee00 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -32,6 +32,7 @@ type engine struct { shedder load.Shedder priorityShedder load.Shedder tlsConfig *tls.Config + chain *alice.Chain } func newEngine(c RestConf) *engine { @@ -84,19 +85,25 @@ func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, met func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics, route Route, verifier func(chain alice.Chain) alice.Chain) error { - chain := alice.New( - handler.TracingHandler(ng.conf.Name, route.Path), - ng.getLogHandler(), - handler.PrometheusHandler(route.Path), - handler.MaxConns(ng.conf.MaxConns), - handler.BreakerHandler(route.Method, route.Path, metrics), - handler.SheddingHandler(ng.getShedder(fr.priority), metrics), - handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)), - handler.RecoverHandler, - handler.MetricHandler(metrics), - handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)), - handler.GunzipHandler, - ) + chain := alice.Chain{} + if ng.chain == nil { + chain = alice.New( + handler.TracingHandler(ng.conf.Name, route.Path), + ng.getLogHandler(), + handler.PrometheusHandler(route.Path), + handler.MaxConns(ng.conf.MaxConns), + handler.BreakerHandler(route.Method, route.Path, metrics), + handler.SheddingHandler(ng.getShedder(fr.priority), metrics), + handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)), + handler.RecoverHandler, + handler.MetricHandler(metrics), + handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)), + handler.GunzipHandler, + ) + } else { + chain = *ng.chain + } + chain = ng.appendAuthHandler(fr, chain, verifier) for _, middleware := range ng.middlewares { @@ -188,6 +195,10 @@ func (ng *engine) setTlsConfig(cfg *tls.Config) { ng.tlsConfig = cfg } +func (ng *engine) setChainConfig(chain *alice.Chain) { + ng.chain = chain +} + func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) { ng.unauthorizedCallback = callback } diff --git a/rest/engine_test.go b/rest/engine_test.go index a8df1dd9386c..b99a68237e5b 100644 --- a/rest/engine_test.go +++ b/rest/engine_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/justinas/alice" "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/conf" "github.com/zeromicro/go-zero/core/logx" @@ -229,6 +230,47 @@ func TestEngine_checkedMaxBytes(t *testing.T) { } } +func TestEngine_checkedChain(t *testing.T) { + var called int32 + middleware1 := func() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&called, 1) + next.ServeHTTP(w, r) + atomic.AddInt32(&called, 1) + }) + } + } + middleware2 := func() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&called, 1) + next.ServeHTTP(w, r) + atomic.AddInt32(&called, 1) + }) + } + } + + chain := alice.New(middleware1(), middleware2()) + chains := WithChain(&chain) + + server := MustNewServer(RestConf{}, chains) + server.router = chainRouter{} + server.AddRoutes( + []Route{ + { + Method: http.MethodGet, + Path: "/", + Handler: func(_ http.ResponseWriter, _ *http.Request) { + atomic.AddInt32(&called, 1) + }, + }, + }, + ) + server.ngin.bindRoutes(chainRouter{}) + assert.Equal(t, int32(5), atomic.LoadInt32(&called)) +} + func TestEngine_notFoundHandler(t *testing.T) { logx.Disable() @@ -312,3 +354,19 @@ func (m mockedRouter) SetNotFoundHandler(_ http.Handler) { func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) { } + +type chainRouter struct{} + +func (c chainRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) { +} + +func (c chainRouter) Handle(_, _ string, handler http.Handler) error { + handler.ServeHTTP(nil, nil) + return nil +} + +func (c chainRouter) SetNotFoundHandler(_ http.Handler) { +} + +func (c chainRouter) SetNotAllowedHandler(_ http.Handler) { +} diff --git a/rest/server.go b/rest/server.go index 669f7084c514..8d0f952bad08 100644 --- a/rest/server.go +++ b/rest/server.go @@ -7,6 +7,7 @@ import ( "path" "time" + "github.com/justinas/alice" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/rest/handler" "github.com/zeromicro/go-zero/rest/httpx" @@ -237,6 +238,13 @@ func WithTLSConfig(cfg *tls.Config) RunOption { } } +// WithChain returns a RunOption that with given chain config. +func WithChain(chain *alice.Chain) RunOption { + return func(svr *Server) { + svr.ngin.setChainConfig(chain) + } +} + // WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set. func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption { return func(svr *Server) { From 73f816e86d4e71cee55c3adf0054979b584fb44d Mon Sep 17 00:00:00 2001 From: kemq1 Date: Tue, 17 May 2022 10:57:59 +0800 Subject: [PATCH 2/3] fix staticcheck SA4006 --- rest/engine.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rest/engine.go b/rest/engine.go index d110123fee00..348cc7328ca9 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -85,7 +85,7 @@ func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, met func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics, route Route, verifier func(chain alice.Chain) alice.Chain) error { - chain := alice.Chain{} + var chain alice.Chain if ng.chain == nil { chain = alice.New( handler.TracingHandler(ng.conf.Name, route.Path), From 875600343ee86a9d6cd80a4fdc89f23b22d7b9d0 Mon Sep 17 00:00:00 2001 From: kemq1 Date: Thu, 2 Jun 2022 17:48:16 +0800 Subject: [PATCH 3/3] chang code Implementation style --- rest/engine_test.go | 6 +----- rest/server.go | 8 ++++++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/rest/engine_test.go b/rest/engine_test.go index b99a68237e5b..53209adb4af6 100644 --- a/rest/engine_test.go +++ b/rest/engine_test.go @@ -9,7 +9,6 @@ import ( "testing" "time" - "github.com/justinas/alice" "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/conf" "github.com/zeromicro/go-zero/core/logx" @@ -251,10 +250,7 @@ func TestEngine_checkedChain(t *testing.T) { } } - chain := alice.New(middleware1(), middleware2()) - chains := WithChain(&chain) - - server := MustNewServer(RestConf{}, chains) + server := MustNewServer(RestConf{}, WithChain(middleware1(), middleware2())) server.router = chainRouter{} server.AddRoutes( []Route{ diff --git a/rest/server.go b/rest/server.go index 8d0f952bad08..124b4d20cb84 100644 --- a/rest/server.go +++ b/rest/server.go @@ -239,9 +239,13 @@ func WithTLSConfig(cfg *tls.Config) RunOption { } // WithChain returns a RunOption that with given chain config. -func WithChain(chain *alice.Chain) RunOption { +func WithChain(middlewares ...func(http.Handler) http.Handler) RunOption { return func(svr *Server) { - svr.ngin.setChainConfig(chain) + chain := alice.New() + for _, middleware := range middlewares { + chain = chain.Append(middleware) + } + svr.ngin.setChainConfig(&chain) } }