Skip to content

Commit

Permalink
Add support for registering handlers for 404 routes
Browse files Browse the repository at this point in the history
  • Loading branch information
aldas committed Jul 10, 2022
1 parent ddb66e1 commit f33b7a6
Show file tree
Hide file tree
Showing 6 changed files with 372 additions and 22 deletions.
13 changes: 13 additions & 0 deletions echo.go
Expand Up @@ -183,6 +183,8 @@ const (
PROPFIND = "PROPFIND"
// REPORT Method can be used to get information about a resource, see rfc 3253
REPORT = "REPORT"
// RouteNotFound is special method type for routes handling "route not found" (404) cases
RouteNotFound = "echo_route_not_found"
)

// Headers
Expand Down Expand Up @@ -480,6 +482,16 @@ func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return e.Add(http.MethodTrace, path, h, m...)
}

// RouteNotFound registers a special-case route which is executed when no other route is found (i.e. HTTP 404 cases)
// for current request URL.
// Path supports static and named/any parameters just like other http method is defined. Generally path is ended with
// wildcard/match-any character (`/*`, `/download/*` etc).
//
// Example: `e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })`
func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return e.Add(RouteNotFound, path, h, m...)
}

// Any registers a new route for all HTTP methods and path with matching handler
// in the router with optional route-level middleware.
func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
Expand Down Expand Up @@ -515,6 +527,7 @@ func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route {
func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
name := handlerName(handler)
router := e.findRouter(host)
// FIXME: when handler+middleware are both nil ... make it behave like handler removal
router.Add(method, path, func(c Context) error {
h := applyMiddleware(handler, middleware...)
return h(c)
Expand Down
64 changes: 64 additions & 0 deletions echo_test.go
Expand Up @@ -766,6 +766,70 @@ func TestEchoNotFound(t *testing.T) {
assert.Equal(t, http.StatusNotFound, rec.Code)
}

func TestEcho_RouteNotFound(t *testing.T) {
var testCases = []struct {
name string
whenURL string
expectRoute interface{}
expectCode int
}{
{
name: "404, route to static not found handler /a/c/xx",
whenURL: "/a/c/xx",
expectRoute: "GET /a/c/xx",
expectCode: http.StatusNotFound,
},
{
name: "404, route to path param not found handler /a/:file",
whenURL: "/a/echo.exe",
expectRoute: "GET /a/:file",
expectCode: http.StatusNotFound,
},
{
name: "404, route to any not found handler /*",
whenURL: "/b/echo.exe",
expectRoute: "GET /*",
expectCode: http.StatusNotFound,
},
{
name: "200, route /a/c/df to /a/c/df",
whenURL: "/a/c/df",
expectRoute: "GET /a/c/df",
expectCode: http.StatusOK,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()

okHandler := func(c Context) error {
return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
}
notFoundHandler := func(c Context) error {
return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path())
}

e.GET("/", okHandler)
e.GET("/a/c/df", okHandler)
e.GET("/a/b*", okHandler)
e.PUT("/*", okHandler)

e.RouteNotFound("/a/c/xx", notFoundHandler) // static
e.RouteNotFound("/a/:file", notFoundHandler) // param
e.RouteNotFound("/*", notFoundHandler) // any

req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()

e.ServeHTTP(rec, req)

assert.Equal(t, tc.expectCode, rec.Code)
assert.Equal(t, tc.expectRoute, rec.Body.String())
})
}
}

func TestEchoMethodNotAllowed(t *testing.T) {
e := New()

Expand Down
7 changes: 7 additions & 0 deletions group.go
Expand Up @@ -107,6 +107,13 @@ func (g *Group) File(path, file string) {
g.file(path, file, g.GET)
}

// RouteNotFound implements `Echo#RouteNotFound()` for sub-routes within the Group.
//
// Example: `g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })`
func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.Add(RouteNotFound, path, h, m...)
}

// Add implements `Echo#Add()` for sub-routes within the Group.
func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
// Combine into a new slice to avoid accidentally passing the same slice for
Expand Down
65 changes: 65 additions & 0 deletions group_test.go
Expand Up @@ -119,3 +119,68 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
assert.Equal(t, "/*", m)

}

func TestGroup_RouteNotFound(t *testing.T) {
var testCases = []struct {
name string
whenURL string
expectRoute interface{}
expectCode int
}{
{
name: "404, route to static not found handler /group/a/c/xx",
whenURL: "/group/a/c/xx",
expectRoute: "GET /group/a/c/xx",
expectCode: http.StatusNotFound,
},
{
name: "404, route to path param not found handler /group/a/:file",
whenURL: "/group/a/echo.exe",
expectRoute: "GET /group/a/:file",
expectCode: http.StatusNotFound,
},
{
name: "404, route to any not found handler /group/*",
whenURL: "/group/b/echo.exe",
expectRoute: "GET /group/*",
expectCode: http.StatusNotFound,
},
{
name: "200, route /group/a/c/df to /group/a/c/df",
whenURL: "/group/a/c/df",
expectRoute: "GET /group/a/c/df",
expectCode: http.StatusOK,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
g := e.Group("/group")

okHandler := func(c Context) error {
return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
}
notFoundHandler := func(c Context) error {
return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path())
}

g.GET("/", okHandler)
g.GET("/a/c/df", okHandler)
g.GET("/a/b*", okHandler)
g.PUT("/*", okHandler)

g.RouteNotFound("/a/c/xx", notFoundHandler) // static
g.RouteNotFound("/a/:file", notFoundHandler) // param
g.RouteNotFound("/*", notFoundHandler) // any

req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()

e.ServeHTTP(rec, req)

assert.Equal(t, tc.expectCode, rec.Code)
assert.Equal(t, tc.expectRoute, rec.Body.String())
})
}
}
60 changes: 38 additions & 22 deletions router.go
Expand Up @@ -28,6 +28,9 @@ type (
isLeaf bool
// isHandler indicates that node has at least one handler registered to it
isHandler bool

// notFoundHandler is handler registered with RouteNotFound method and is executed for 404 cases
notFoundHandler HandlerFunc
}
kind uint8
children []*node
Expand Down Expand Up @@ -68,6 +71,7 @@ func (m *methodHandler) isHandler() bool {
m.put != nil ||
m.trace != nil ||
m.report != nil
// RouteNotFound/404 is not considered as a handler
}

func (m *methodHandler) updateAllowHeader() {
Expand Down Expand Up @@ -369,14 +373,12 @@ func (n *node) addHandler(method string, h HandlerFunc) {
n.methodHandler.trace = h
case REPORT:
n.methodHandler.report = h
case RouteNotFound:
n.notFoundHandler = h
}

n.methodHandler.updateAllowHeader()
if h != nil {
n.isHandler = true
} else {
n.isHandler = n.methodHandler.isHandler()
}
n.isHandler = n.methodHandler.isHandler()
}

func (n *node) findHandler(method string) HandlerFunc {
Expand All @@ -403,7 +405,7 @@ func (n *node) findHandler(method string) HandlerFunc {
return n.methodHandler.trace
case REPORT:
return n.methodHandler.report
default:
default: // RouteNotFound/404 is not considered as a handler
return nil
}
}
Expand Down Expand Up @@ -506,7 +508,7 @@ func (r *Router) Find(method, path string, c Context) {
// No matching prefix, let's backtrack to the first possible alternative node of the decision path
nk, ok := backtrackToNextNodeKind(staticKind)
if !ok {
return // No other possibilities on the decision path
return // No other possibilities on the decision path, handler will be whatever context is reset to.
} else if nk == paramKind {
goto Param
// NOTE: this case (backtracking from static node to previous any node) can not happen by current any matching logic. Any node is end of search currently
Expand All @@ -522,15 +524,21 @@ func (r *Router) Find(method, path string, c Context) {
search = search[lcpLen:]
searchIndex = searchIndex + lcpLen

// Finish routing if no remaining search and we are on a node with handler and matching method type
if search == "" && currentNode.isHandler {
// check if current node has handler registered for http method we are looking for. we store currentNode as
// best matching in case we do no find no more routes matching this path+method
if previousBestMatchNode == nil {
previousBestMatchNode = currentNode
}
if h := currentNode.findHandler(method); h != nil {
matchedHandler = h
// Finish routing if is no request path remaining to search
if search == "" {
// in case of node that is handler we have exact method type match or something for 405 to use
if currentNode.isHandler {
// check if current node has handler registered for http method we are looking for. we store currentNode as
// best matching in case we do no find no more routes matching this path+method
if previousBestMatchNode == nil {
previousBestMatchNode = currentNode
}
if h := currentNode.findHandler(method); h != nil {
matchedHandler = h
break
}
} else if currentNode.notFoundHandler != nil {
matchedHandler = currentNode.notFoundHandler
break
}
}
Expand All @@ -550,7 +558,8 @@ func (r *Router) Find(method, path string, c Context) {
i := 0
l := len(search)
if currentNode.isLeaf {
// when param node does not have any children then param node should act similarly to any node - consider all remaining search as match
// when param node does not have any children (path param is last piece of route path) then param node should
// act similarly to any node - consider all remaining search as match
i = l
} else {
for ; i < l && search[i] != '/'; i++ {
Expand All @@ -575,13 +584,16 @@ func (r *Router) Find(method, path string, c Context) {
searchIndex += +len(search)
search = ""

// check if current node has handler registered for http method we are looking for. we store currentNode as
// best matching in case we do no find no more routes matching this path+method
if h := currentNode.findHandler(method); h != nil {
matchedHandler = h
break
}
// we store currentNode as best matching in case we do not find more routes matching this path+method. Needed for 405
if previousBestMatchNode == nil {
previousBestMatchNode = currentNode
}
if h := currentNode.findHandler(method); h != nil {
matchedHandler = h
if currentNode.notFoundHandler != nil {
matchedHandler = currentNode.notFoundHandler
break
}
}
Expand All @@ -604,6 +616,8 @@ func (r *Router) Find(method, path string, c Context) {
return // nothing matched at all
}

// matchedHandler could be method+path handler that we matched or notFoundHandler from node with matching path
// user provided not found (404) handler has priority over generic method not found (405) handler or global 404 handler
if matchedHandler != nil {
ctx.handler = matchedHandler
} else {
Expand All @@ -612,7 +626,9 @@ func (r *Router) Find(method, path string, c Context) {
currentNode = previousBestMatchNode

ctx.handler = NotFoundHandler
if currentNode.isHandler {
if currentNode.notFoundHandler != nil {
ctx.handler = currentNode.notFoundHandler
} else if currentNode.isHandler {
ctx.Set(ContextKeyHeaderAllow, currentNode.methodHandler.allowHeader)
ctx.handler = MethodNotAllowedHandler
if method == http.MethodOptions {
Expand Down

0 comments on commit f33b7a6

Please sign in to comment.