diff --git a/echo.go b/echo.go index 5ae8a1424..0933b4b93 100644 --- a/echo.go +++ b/echo.go @@ -3,41 +3,40 @@ Package echo implements high performance, minimalist Go web framework. Example: - package main + package main - import ( - "net/http" + import ( + "net/http" - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" - ) + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" + ) - // Handler - func hello(c echo.Context) error { - return c.String(http.StatusOK, "Hello, World!") - } + // Handler + func hello(c echo.Context) error { + return c.String(http.StatusOK, "Hello, World!") + } - func main() { - // Echo instance - e := echo.New() + func main() { + // Echo instance + e := echo.New() - // Middleware - e.Use(middleware.Logger()) - e.Use(middleware.Recover()) + // Middleware + e.Use(middleware.Logger()) + e.Use(middleware.Recover()) - // Routes - e.GET("/", hello) + // Routes + e.GET("/", hello) - // Start server - e.Logger.Fatal(e.Start(":1323")) - } + // Start server + e.Logger.Fatal(e.Start(":1323")) + } Learn more at https://echo.labstack.com */ package echo import ( - "bytes" stdContext "context" "crypto/tls" "errors" @@ -62,20 +61,28 @@ import ( type ( // Echo is the top-level framework instance. + // + // Goroutine safety: Do not mutate Echo instance fields after server has started. Accessing these + // fields from handlers/middlewares and changing field values at the same time leads to data-races. + // Adding new routes after the server has been started is also not safe! Echo struct { filesystem common // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get // listener address info (on which interface/port was listener binded) without having data races. - startupMutex sync.RWMutex + startupMutex sync.RWMutex + colorer *color.Color + + // premiddleware are middlewares that are run before routing is done. In case a pre-middleware returns + // an error the router is not executed and the request will end up in the global error handler. + premiddleware []MiddlewareFunc + middleware []MiddlewareFunc + maxParam *int + router *Router + routers map[string]*Router + pool sync.Pool + StdLogger *stdLog.Logger - colorer *color.Color - premiddleware []MiddlewareFunc - middleware []MiddlewareFunc - maxParam *int - router *Router - routers map[string]*Router - pool sync.Pool Server *http.Server TLSServer *http.Server Listener net.Listener @@ -93,6 +100,9 @@ type ( Logger Logger IPExtractor IPExtractor ListenerNetwork string + + // OnAddRouteHandler is called when Echo adds new route to specific host router. + OnAddRouteHandler func(host string, route Route, handler HandlerFunc, middleware []MiddlewareFunc) } // Route contains a handler and information for matching against requests. @@ -527,21 +537,20 @@ func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route { return e.file(path, file, e.GET, m...) } -func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - name := handlerName(handler) +func (e *Echo) add(host, method, path string, handler HandlerFunc, middlewares ...MiddlewareFunc) *Route { router := e.findRouter(host) - // FIXME: when handler+middleware are both nil ... make it behave like handler removal - router.Add(method, path, func(c Context) error { - h := applyMiddleware(handler, middleware...) + //FIXME: when handler+middleware are both nil ... make it behave like handler removal + name := handlerName(handler) + route := router.add(method, path, name, func(c Context) error { + h := applyMiddleware(handler, middlewares...) return h(c) }) - r := &Route{ - Method: method, - Path: path, - Name: name, + + if e.OnAddRouteHandler != nil { + e.OnAddRouteHandler(host, *route, handler, middlewares) } - e.router.routes[method+path] = r - return r + + return route } // Add registers a new route for an HTTP method and path with matching handler @@ -578,35 +587,13 @@ func (e *Echo) URL(h HandlerFunc, params ...interface{}) string { // Reverse generates an URL from route name and provided parameters. func (e *Echo) Reverse(name string, params ...interface{}) string { - uri := new(bytes.Buffer) - ln := len(params) - n := 0 - for _, r := range e.router.routes { - if r.Name == name { - for i, l := 0, len(r.Path); i < l; i++ { - if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln { - for ; i < l && r.Path[i] != '/'; i++ { - } - uri.WriteString(fmt.Sprintf("%v", params[n])) - n++ - } - if i < l { - uri.WriteByte(r.Path[i]) - } - } - break - } - } - return uri.String() + return e.router.Reverse(name, params...) } -// Routes returns the registered routes. +// Routes returns the registered routes for default router. +// In case when Echo serves multiple hosts/domains use `e.Routers()["domain2.site"].Routes()` to get specific host routes. func (e *Echo) Routes() []*Route { - routes := make([]*Route, 0, len(e.router.routes)) - for _, v := range e.router.routes { - routes = append(routes, v) - } - return routes + return e.router.Routes() } // AcquireContext returns an empty `Context` instance from the pool. @@ -913,8 +900,8 @@ func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc { // GetPath returns RawPath, if it's empty returns Path from URL // Difference between RawPath and Path is: -// * Path is where request path is stored. Value is stored in decoded form: /%47%6f%2f becomes /Go/. -// * RawPath is an optional field which only gets set if the default encoding is different from Path. +// - Path is where request path is stored. Value is stored in decoded form: /%47%6f%2f becomes /Go/. +// - RawPath is an optional field which only gets set if the default encoding is different from Path. func GetPath(r *http.Request) string { path := r.URL.RawPath if path == "" { diff --git a/echo_test.go b/echo_test.go index 6bece4fd3..ba7a1e7de 100644 --- a/echo_test.go +++ b/echo_test.go @@ -531,9 +531,9 @@ func TestEchoRoutes(t *testing.T) { } } -func TestEchoRoutesHandleHostsProperly(t *testing.T) { +func TestEchoRoutesHandleAdditionalHosts(t *testing.T) { e := New() - h := e.Host("route.com") + domain2Router := e.Host("domain2.router.com") routes := []*Route{ {http.MethodGet, "/users/:user/events", ""}, {http.MethodGet, "/users/:user/events/public", ""}, @@ -541,24 +541,61 @@ func TestEchoRoutesHandleHostsProperly(t *testing.T) { {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, } for _, r := range routes { - h.Add(r.Method, r.Path, func(c Context) error { + domain2Router.Add(r.Method, r.Path, func(c Context) error { return c.String(http.StatusOK, "OK") }) } + e.Add(http.MethodGet, "/api", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) - if assert.Equal(t, len(routes), len(e.Routes())) { - for _, r := range e.Routes() { - found := false - for _, rr := range routes { - if r.Method == rr.Method && r.Path == rr.Path { - found = true - break - } + domain2Routes := e.Routers()["domain2.router.com"].Routes() + + assert.Len(t, domain2Routes, len(routes)) + for _, r := range domain2Routes { + found := false + for _, rr := range routes { + if r.Method == rr.Method && r.Path == rr.Path { + found = true + break } - if !found { - t.Errorf("Route %s %s not found", r.Method, r.Path) + } + if !found { + t.Errorf("Route %s %s not found", r.Method, r.Path) + } + } +} + +func TestEchoRoutesHandleDefaultHost(t *testing.T) { + e := New() + routes := []*Route{ + {http.MethodGet, "/users/:user/events", ""}, + {http.MethodGet, "/users/:user/events/public", ""}, + {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, + {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, + } + for _, r := range routes { + e.Add(r.Method, r.Path, func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + } + e.Host("subdomain.mysite.site").Add(http.MethodGet, "/api", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + + defaultRouterRoutes := e.Routes() + assert.Len(t, defaultRouterRoutes, len(routes)) + for _, r := range defaultRouterRoutes { + found := false + for _, rr := range routes { + if r.Method == rr.Method && r.Path == rr.Path { + found = true + break } } + if !found { + t.Errorf("Route %s %s not found", r.Method, r.Path) + } } } @@ -1424,6 +1461,44 @@ func TestEchoListenerNetworkInvalid(t *testing.T) { assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323")) } +func TestEcho_OnAddRouteHandler(t *testing.T) { + type rr struct { + host string + route Route + handler HandlerFunc + middleware []MiddlewareFunc + } + dummyHandler := func(Context) error { return nil } + e := New() + + added := make([]rr, 0) + e.OnAddRouteHandler = func(host string, route Route, handler HandlerFunc, middleware []MiddlewareFunc) { + added = append(added, rr{ + host: host, + route: route, + handler: handler, + middleware: middleware, + }) + } + + e.GET("/static", NotFoundHandler) + e.Host("domain.site").GET("/static/*", dummyHandler, func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + return next(c) + } + }) + + assert.Len(t, added, 2) + + assert.Equal(t, "", added[0].host) + assert.Equal(t, Route{Method: http.MethodGet, Path: "/static", Name: "github.com/labstack/echo/v4.glob..func1"}, added[0].route) + assert.Len(t, added[0].middleware, 0) + + assert.Equal(t, "domain.site", added[1].host) + assert.Equal(t, Route{Method: http.MethodGet, Path: "/static/*", Name: "github.com/labstack/echo/v4.TestEcho_OnAddRouteHandler.func1"}, added[1].route) + assert.Len(t, added[1].middleware, 1) +} + func TestEchoReverse(t *testing.T) { e := New() dummyHandler := func(Context) error { return nil } @@ -1451,14 +1526,27 @@ func TestEchoReverseHandleHostProperly(t *testing.T) { dummyHandler := func(Context) error { return nil } e := New() + + // routes added to the default router are different form different hosts + e.GET("/static", dummyHandler).Name = "default-host /static" + e.GET("/static/*", dummyHandler).Name = "xxx" + + // different host h := e.Host("the_host") - h.GET("/static", dummyHandler).Name = "/static" - h.GET("/static/*", dummyHandler).Name = "/static/*" + h.GET("/static", dummyHandler).Name = "host2 /static" + h.GET("/static/v2/*", dummyHandler).Name = "xxx" + + assert.Equal(t, "/static", e.Reverse("default-host /static")) + // when actual route does not have params and we provide some to Reverse we should get that route url back + assert.Equal(t, "/static", e.Reverse("default-host /static", "missing param")) + + host2Router := e.Routers()["the_host"] + assert.Equal(t, "/static", host2Router.Reverse("host2 /static")) + assert.Equal(t, "/static", host2Router.Reverse("host2 /static", "missing param")) + + assert.Equal(t, "/static/v2/*", host2Router.Reverse("xxx")) + assert.Equal(t, "/static/v2/foo.txt", host2Router.Reverse("xxx", "foo.txt")) - assert.Equal(t, "/static", e.Reverse("/static")) - assert.Equal(t, "/static", e.Reverse("/static", "missing param")) - assert.Equal(t, "/static/*", e.Reverse("/static/*")) - assert.Equal(t, "/static/foo.txt", e.Reverse("/static/*", "foo.txt")) } func TestEcho_ListenerAddr(t *testing.T) { diff --git a/router.go b/router.go index 23c5bd3ba..86a986a29 100644 --- a/router.go +++ b/router.go @@ -2,6 +2,7 @@ package echo import ( "bytes" + "fmt" "net/http" ) @@ -141,6 +142,51 @@ func NewRouter(e *Echo) *Router { } } +// Routes returns the registered routes. +func (r *Router) Routes() []*Route { + routes := make([]*Route, 0, len(r.routes)) + for _, v := range r.routes { + routes = append(routes, v) + } + return routes +} + +// Reverse generates an URL from route name and provided parameters. +func (r *Router) Reverse(name string, params ...interface{}) string { + uri := new(bytes.Buffer) + ln := len(params) + n := 0 + for _, route := range r.routes { + if route.Name == name { + for i, l := 0, len(route.Path); i < l; i++ { + if (route.Path[i] == ':' || route.Path[i] == '*') && n < ln { + for ; i < l && route.Path[i] != '/'; i++ { + } + uri.WriteString(fmt.Sprintf("%v", params[n])) + n++ + } + if i < l { + uri.WriteByte(route.Path[i]) + } + } + break + } + } + return uri.String() +} + +func (r *Router) add(method, path, name string, h HandlerFunc) *Route { + r.Add(method, path, h) + + route := &Route{ + Method: method, + Path: path, + Name: name, + } + r.routes[method+path] = route + return route +} + // Add registers a new route for method and path with matching handler. func (r *Router) Add(method, path string, h HandlerFunc) { // Validate path diff --git a/router_test.go b/router_test.go index a95421011..825170a3f 100644 --- a/router_test.go +++ b/router_test.go @@ -914,19 +914,22 @@ func TestRouterParamWithSlash(t *testing.T) { // Searching route for "/a/c/f" should match "/a/*/f" // When route `4) /a/*/f` is not added then request for "/a/c/f" should match "/:e/c/f" // -// +----------+ -// +-----+ "/" root +--------------------+--------------------------+ -// | +----------+ | | -// | | | -// +-------v-------+ +---v---------+ +-------v---+ -// | "a/" (static) +---------------+ | ":" (param) | | "*" (any) | -// +-+----------+--+ | +-----------+-+ +-----------+ -// | | | | +// +----------+ +// +-----+ "/" root +--------------------+--------------------------+ +// | +----------+ | | +// | | | +// +-------v-------+ +---v---------+ +-------v---+ +// | "a/" (static) +---------------+ | ":" (param) | | "*" (any) | +// +-+----------+--+ | +-----------+-+ +-----------+ +// | | | | +// // +---------------v+ +-- ---v------+ +------v----+ +-----v-----------+ // | "c/d" (static) | | ":" (param) | | "*" (any) | | "/c/f" (static) | // +---------+------+ +--------+----+ +----------++ +-----------------+ -// | | | -// | | | +// +// | | | +// | | | +// // +---------v----+ +------v--------+ +------v--------+ // | "f" (static) | | "/c" (static) | | "/f" (static) | // +--------------+ +---------------+ +---------------+ @@ -998,22 +1001,22 @@ func TestRouteMultiLevelBacktracking(t *testing.T) { // // Request for "/a/c/f" should match "/:e/c/f" // -// +-0,7--------+ -// | "/" (root) |----------------------------------+ -// +------------+ | -// | | | -// | | | -// +-1,6-----------+ | | +-8-----------+ +------v----+ -// | "a/" (static) +<--+ +--------->+ ":" (param) | | "*" (any) | -// +---------------+ +-------------+ +-----------+ -// | | | -// +-2--------v-----+ +v-3,5--------+ +-9------v--------+ -// | "c/d" (static) | | ":" (param) | | "/c/f" (static) | -// +----------------+ +-------------+ +-----------------+ -// | -// +-4--v----------+ -// | "/c" (static) | -// +---------------+ +// +-0,7--------+ +// | "/" (root) |----------------------------------+ +// +------------+ | +// | | | +// | | | +// +-1,6-----------+ | | +-8-----------+ +------v----+ +// | "a/" (static) +<--+ +--------->+ ":" (param) | | "*" (any) | +// +---------------+ +-------------+ +-----------+ +// | | | +// +-2--------v-----+ +v-3,5--------+ +-9------v--------+ +// | "c/d" (static) | | ":" (param) | | "/c/f" (static) | +// +----------------+ +-------------+ +-----------------+ +// | +// +-4--v----------+ +// | "/c" (static) | +// +---------------+ func TestRouteMultiLevelBacktracking2(t *testing.T) { e := New() r := e.router @@ -2695,6 +2698,87 @@ func TestRouterHandleMethodOptions(t *testing.T) { } } +func TestRouter_Routes(t *testing.T) { + type rr struct { + method string + path string + name string + } + var testCases = []struct { + name string + givenRoutes []rr + expect []rr + }{ + { + name: "ok, multiple", + givenRoutes: []rr{ + {method: http.MethodGet, path: "/static", name: "/static"}, + {method: http.MethodGet, path: "/static/*", name: "/static/*"}, + }, + expect: []rr{ + {method: http.MethodGet, path: "/static", name: "/static"}, + {method: http.MethodGet, path: "/static/*", name: "/static/*"}, + }, + }, + { + name: "ok, no routes", + givenRoutes: []rr{}, + expect: []rr{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dummyHandler := func(Context) error { return nil } + + e := New() + route := e.router + + for _, tmp := range tc.givenRoutes { + route.add(tmp.method, tmp.path, tmp.name, dummyHandler) + } + + // Add does not add route. because of backwards compatibility we can not change this method signature + route.Add("LOCK", "/users", handlerFunc) + + result := route.Routes() + assert.Len(t, result, len(tc.expect)) + for _, r := range result { + for _, tmp := range tc.expect { + if tmp.name == r.Name { + assert.Equal(t, tmp.method, r.Method) + assert.Equal(t, tmp.path, r.Path) + } + } + } + }) + } +} + +func TestRouter_Reverse(t *testing.T) { + e := New() + r := e.router + dummyHandler := func(Context) error { return nil } + + r.add(http.MethodGet, "/static", "/static", dummyHandler) + r.add(http.MethodGet, "/static/*", "/static/*", dummyHandler) + r.add(http.MethodGet, "/params/:foo", "/params/:foo", dummyHandler) + r.add(http.MethodGet, "/params/:foo/bar/:qux", "/params/:foo/bar/:qux", dummyHandler) + r.add(http.MethodGet, "/params/:foo/bar/:qux/*", "/params/:foo/bar/:qux/*", dummyHandler) + + assert.Equal(t, "/static", r.Reverse("/static")) + assert.Equal(t, "/static", r.Reverse("/static", "missing param")) + assert.Equal(t, "/static/*", r.Reverse("/static/*")) + assert.Equal(t, "/static/foo.txt", r.Reverse("/static/*", "foo.txt")) + + assert.Equal(t, "/params/:foo", r.Reverse("/params/:foo")) + assert.Equal(t, "/params/one", r.Reverse("/params/:foo", "one")) + assert.Equal(t, "/params/:foo/bar/:qux", r.Reverse("/params/:foo/bar/:qux")) + assert.Equal(t, "/params/one/bar/:qux", r.Reverse("/params/:foo/bar/:qux", "one")) + assert.Equal(t, "/params/one/bar/two", r.Reverse("/params/:foo/bar/:qux", "one", "two")) + assert.Equal(t, "/params/one/bar/two/three", r.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three")) +} + func TestRouterAllowHeaderForAnyOtherMethodType(t *testing.T) { e := New() r := e.router