From cba12a570e8caa1fafabc2d41afc97bd7a83d758 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 6 Aug 2022 23:25:43 +0300 Subject: [PATCH] Allow arbitrary HTTP method types to be added as routes --- echo.go | 5 +++- router.go | 19 ++++++++++-- router_test.go | 80 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 3 deletions(-) diff --git a/echo.go b/echo.go index 5b10d586e..5738578df 100644 --- a/echo.go +++ b/echo.go @@ -492,8 +492,11 @@ func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) *R return e.Add(RouteNotFound, path, h, m...) } -// Any registers a new route for all HTTP methods and path with matching handler +// Any registers a new route for all HTTP methods (supported by Echo) and path with matching handler // in the router with optional route-level middleware. +// +// Note: this method only adds specific set of supported HTTP methods as handler and is not true +// "catch-any-arbitrary-method" way of matching requests. func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { routes := make([]*Route, len(methods)) for i, m := range methods { diff --git a/router.go b/router.go index 90102a294..23c5bd3ba 100644 --- a/router.go +++ b/router.go @@ -51,6 +51,7 @@ type ( put *routeMethod trace *routeMethod report *routeMethod + anyOther map[string]*routeMethod allowHeader string } ) @@ -75,7 +76,8 @@ func (m *routeMethods) isHandler() bool { m.propfind != nil || m.put != nil || m.trace != nil || - m.report != nil + m.report != nil || + len(m.anyOther) != 0 // RouteNotFound/404 is not considered as a handler } @@ -121,6 +123,10 @@ func (m *routeMethods) updateAllowHeader() { if m.report != nil { buf.WriteString(", REPORT") } + for method := range m.anyOther { // for simplicity, we use map and therefore order is not deterministic here + buf.WriteString(", ") + buf.WriteString(method) + } m.allowHeader = buf.String() } @@ -408,6 +414,15 @@ func (n *node) addMethod(method string, h *routeMethod) { case RouteNotFound: n.notFoundHandler = h return // RouteNotFound/404 is not considered as a handler so no further logic needs to be executed + default: + if n.methods.anyOther == nil { + n.methods.anyOther = make(map[string]*routeMethod) + } + if h.handler == nil { + delete(n.methods.anyOther, method) + } else { + n.methods.anyOther[method] = h + } } n.methods.updateAllowHeader() @@ -439,7 +454,7 @@ func (n *node) findMethod(method string) *routeMethod { case REPORT: return n.methods.report default: // RouteNotFound/404 is not considered as a handler - return nil + return n.methods.anyOther[method] } } diff --git a/router_test.go b/router_test.go index 1b0c409b6..a95421011 100644 --- a/router_test.go +++ b/router_test.go @@ -716,6 +716,67 @@ func TestRouterParam(t *testing.T) { } } +func TestRouter_addAndMatchAllSupportedMethods(t *testing.T) { + var testCases = []struct { + name string + givenNoAddRoute bool + whenMethod string + expectPath string + expectError string + }{ + {name: "ok, CONNECT", whenMethod: http.MethodConnect}, + {name: "ok, DELETE", whenMethod: http.MethodDelete}, + {name: "ok, GET", whenMethod: http.MethodGet}, + {name: "ok, HEAD", whenMethod: http.MethodHead}, + {name: "ok, OPTIONS", whenMethod: http.MethodOptions}, + {name: "ok, PATCH", whenMethod: http.MethodPatch}, + {name: "ok, POST", whenMethod: http.MethodPost}, + {name: "ok, PROPFIND", whenMethod: PROPFIND}, + {name: "ok, PUT", whenMethod: http.MethodPut}, + {name: "ok, TRACE", whenMethod: http.MethodTrace}, + {name: "ok, REPORT", whenMethod: REPORT}, + {name: "ok, NON_TRADITIONAL_METHOD", whenMethod: "NON_TRADITIONAL_METHOD"}, + { + name: "ok, NOT_EXISTING_METHOD", + whenMethod: "NOT_EXISTING_METHOD", + givenNoAddRoute: true, + expectPath: "/*", + expectError: "code=405, message=Method Not Allowed", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + e.GET("/*", handlerFunc) + + if !tc.givenNoAddRoute { + e.Add(tc.whenMethod, "/my/*", handlerFunc) + } + + req := httptest.NewRequest(tc.whenMethod, "/my/some-url", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec).(*context) + + e.router.Find(tc.whenMethod, "/my/some-url", c) + err := c.handler(c) + + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + expectPath := "/my/*" + if tc.expectPath != "" { + expectPath = tc.expectPath + } + assert.Equal(t, expectPath, c.Path()) + }) + } +} + func TestMethodNotAllowedAndNotFound(t *testing.T) { e := New() r := e.router @@ -2634,6 +2695,25 @@ func TestRouterHandleMethodOptions(t *testing.T) { } } +func TestRouterAllowHeaderForAnyOtherMethodType(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/users", handlerFunc) + r.Add("COPY", "/users", handlerFunc) + r.Add("LOCK", "/users", handlerFunc) + + req := httptest.NewRequest("TEST", "/users", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec).(*context) + + r.Find("TEST", "/users", c) + err := c.handler(c) + + assert.EqualError(t, err, "code=405, message=Method Not Allowed") + assert.ElementsMatch(t, []string{"COPY", "GET", "LOCK", "OPTIONS"}, strings.Split(c.Response().Header().Get(HeaderAllow), ", ")) +} + func benchmarkRouterRoutes(b *testing.B, routes []*Route, routesToFind []*Route) { e := New() r := e.router