diff --git a/handler.go b/handler.go index 0e866ca..8f4c1de 100644 --- a/handler.go +++ b/handler.go @@ -5,6 +5,7 @@ package flamego import ( + "fmt" "net/http" "reflect" @@ -52,7 +53,7 @@ func (invoke teapotInvoker) Invoke([]interface{}) ([]reflect.Value, error) { // gain up to 3x performance improvement. func validateAndWrapHandler(h Handler, wrapper func(Handler) Handler) Handler { if reflect.TypeOf(h).Kind() != reflect.Func { - panic("handler must be a callable function") + panic(fmt.Sprintf("handler must be a callable function, but got %T", h)) } if inject.IsFastInvoker(h) { diff --git a/router.go b/router.go index e5e155e..2c257d3 100644 --- a/router.go +++ b/router.go @@ -50,10 +50,12 @@ type Router interface { Trace(routePath string, handlers ...Handler) *Route // Any is a shortcut for `r.Route("*", routePath, handlers)`. Any(routePath string, handlers ...Handler) *Route - // Routes is a shortcut of adding same handlers for different HTTP methods. + // Routes is a shortcut of adding route with same list of handlers for different + // HTTP methods. // // Example: - // f.Routes("/", "GET,POST", handlers) + // f.Routes("/", http.MethodGet, http.MethodPost, handlers...) + // f.Routes("/", "GET,POST", handlers...) Routes(routePath, methods string, handlers ...Handler) *Route // NotFound configures a http.HandlerFunc to be called when no matching route is // found. When it is not set, http.NotFound is used. Be sure to set @@ -272,9 +274,24 @@ func (r *router) Routes(routePath, methods string, handlers ...Handler) *Route { panic("empty methods") } - var route *Route + var ms []string for _, m := range strings.Split(methods, ",") { - route = r.Route(strings.TrimSpace(m), routePath, handlers) + ms = append(ms, strings.TrimSpace(m)) + } + + // Collect methods from handlers if they are strings + for i, h := range handlers { + m, ok := h.(string) + if !ok { + handlers = handlers[i:] + break + } + ms = append(ms, m) + } + + var route *Route + for _, m := range ms { + route = r.Route(m, routePath, handlers) } return route } diff --git a/router_test.go b/router_test.go index c49ea9e..dbf95ea 100644 --- a/router_test.go +++ b/router_test.go @@ -119,29 +119,52 @@ func TestRouter_Route(t *testing.T) { func TestRouter_Routes(t *testing.T) { ctx := newMockContext() - contextCreator := func(w http.ResponseWriter, r *http.Request, params route.Params, handlers []Handler, urlPath urlPather) internalContext { + contextCreator := func(_ http.ResponseWriter, _ *http.Request, params route.Params, _ []Handler, _ urlPather) internalContext { ctx.MockContext.ParamFunc.SetDefaultHook(func(s string) string { return params[s] }) return ctx } - r := newRouter(contextCreator) - r.Routes("/routes", "GET,POST", func() {}) + t.Run("use single string", func(t *testing.T) { + r := newRouter(contextCreator) - for _, m := range []string{http.MethodGet, http.MethodPost} { - gotRoute := "" - ctx.run_ = func() { gotRoute = ctx.Param("route") } + r.Routes("/routes", "GET,POST", func() {}) - resp := httptest.NewRecorder() - req, err := http.NewRequest(m, "/routes", nil) - assert.Nil(t, err) + for _, m := range []string{http.MethodGet, http.MethodPost} { + gotRoute := "" + ctx.run_ = func() { gotRoute = ctx.Param("route") } - r.ServeHTTP(resp, req) + resp := httptest.NewRecorder() + req, err := http.NewRequest(m, "/routes", nil) + assert.Nil(t, err) - assert.Equal(t, http.StatusOK, resp.Code) - assert.Equal(t, "/routes", gotRoute) - } + r.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, "/routes", gotRoute) + } + }) + + t.Run("use multiple strings", func(t *testing.T) { + r := newRouter(contextCreator) + + r.Routes("/routes", http.MethodGet, http.MethodPost, func() {}) + + for _, m := range []string{http.MethodGet, http.MethodPost} { + gotRoute := "" + ctx.run_ = func() { gotRoute = ctx.Param("route") } + + resp := httptest.NewRecorder() + req, err := http.NewRequest(m, "/routes", nil) + assert.Nil(t, err) + + r.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, "/routes", gotRoute) + } + }) } func TestRouter_AutoHead(t *testing.T) {