diff --git a/middleware.go b/middleware.go index cf2b26dc..cb51c565 100644 --- a/middleware.go +++ b/middleware.go @@ -58,22 +58,17 @@ func CORSMethodMiddleware(r *Router) MiddlewareFunc { func getAllMethodsForRoute(r *Router, req *http.Request) ([]string, error) { var allMethods []string - err := r.Walk(func(route *Route, _ *Router, _ []*Route) error { - for _, m := range route.matchers { - if _, ok := m.(*routeRegexp); ok { - if m.Match(req, &RouteMatch{}) { - methods, err := route.GetMethods() - if err != nil { - return err - } - - allMethods = append(allMethods, methods...) - } - break + for _, route := range r.routes { + var match RouteMatch + if route.Match(req, &match) || match.MatchErr == ErrMethodMismatch { + methods, err := route.GetMethods() + if err != nil { + return nil, err } + + allMethods = append(allMethods, methods...) } - return nil - }) + } - return allMethods, err + return allMethods, nil } diff --git a/middleware_test.go b/middleware_test.go index 27647afe..e9f0ef55 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -478,6 +478,26 @@ func TestCORSMethodMiddleware(t *testing.T) { } } +func TestCORSMethodMiddlewareSubrouter(t *testing.T) { + router := NewRouter().StrictSlash(true) + + subrouter := router.PathPrefix("/test").Subrouter() + subrouter.HandleFunc("/hello", stringHandler("a")).Methods(http.MethodGet, http.MethodOptions, http.MethodPost) + subrouter.HandleFunc("/hello/{name}", stringHandler("b")).Methods(http.MethodGet, http.MethodOptions) + + subrouter.Use(CORSMethodMiddleware(subrouter)) + + rw := NewRecorder() + req := newRequest("GET", "/test/hello/asdf") + router.ServeHTTP(rw, req) + + actualMethods := rw.Header().Get("Access-Control-Allow-Methods") + expectedMethods := "GET,OPTIONS" + if actualMethods != expectedMethods { + t.Fatalf("expected methods %q but got: %q", expectedMethods, actualMethods) + } +} + func TestMiddlewareOnMultiSubrouter(t *testing.T) { first := "first" second := "second"