diff --git a/app.go b/app.go index fcad05bfce..a5763ed26e 100644 --- a/app.go +++ b/app.go @@ -110,6 +110,8 @@ type App struct { latestRoute *Route // TLS handler tlsHandler *TLSHandler + // custom method check + customMethod bool // Mount fields mountFields *mountFields } @@ -380,6 +382,11 @@ type Config struct { // // Optional. Default: DefaultColors ColorScheme Colors `json:"color_scheme"` + + // RequestMethods provides customizibility for HTTP methods. You can add/remove methods as you wish. + // + // Optional. Defaukt: DefaultMethods + RequestMethods []string } // Static defines configuration options when defining static assets. @@ -445,6 +452,19 @@ const ( DefaultCompressedFileSuffix = ".fiber.gz" ) +// HTTP methods enabled by default +var DefaultMethods = []string{ + MethodGet, + MethodHead, + MethodPost, + MethodPut, + MethodDelete, + MethodConnect, + MethodOptions, + MethodTrace, + MethodPatch, +} + // DefaultErrorHandler that process return errors from handlers var DefaultErrorHandler = func(c *Ctx, err error) error { code := StatusInternalServerError @@ -469,9 +489,6 @@ var DefaultErrorHandler = func(c *Ctx, err error) error { func New(config ...Config) *App { // Create a new app app := &App{ - // Create router stack - stack: make([][]*Route, len(intMethod)), - treeStack: make([]map[string][]*Route, len(intMethod)), // Create Ctx pool pool: sync.Pool{ New: func() interface{} { @@ -538,12 +555,21 @@ func New(config ...Config) *App { if app.config.Network == "" { app.config.Network = NetworkTCP4 } + if len(app.config.RequestMethods) == 0 { + app.config.RequestMethods = DefaultMethods + } else { + app.customMethod = true + } app.config.trustedProxiesMap = make(map[string]struct{}, len(app.config.TrustedProxies)) for _, ipAddress := range app.config.TrustedProxies { app.handleTrustedProxy(ipAddress) } + // Create router stack + app.stack = make([][]*Route, len(app.config.RequestMethods)) + app.treeStack = make([]map[string][]*Route, len(app.config.RequestMethods)) + // Override colors app.config.ColorScheme = defaultColors(app.config.ColorScheme) @@ -724,7 +750,7 @@ func (app *App) Static(prefix, root string, config ...Static) Router { // All will register the handler on all HTTP methods func (app *App) All(path string, handlers ...Handler) Router { - for _, method := range intMethod { + for _, method := range app.config.RequestMethods { _ = app.Add(method, path, handlers...) } return app diff --git a/app_test.go b/app_test.go index 2e87114b26..c3b6565757 100644 --- a/app_test.go +++ b/app_test.go @@ -435,13 +435,32 @@ func Test_App_Use_StrictRouting(t *testing.T) { } func Test_App_Add_Method_Test(t *testing.T) { - app := New() defer func() { if err := recover(); err != nil { - utils.AssertEqual(t, "add: invalid http method JOHN\n", fmt.Sprintf("%v", err)) + utils.AssertEqual(t, "add: invalid http method JANE\n", fmt.Sprintf("%v", err)) } }() + + methods := append(DefaultMethods, "JOHN") + app := New(Config{ + RequestMethods: methods, + }) + app.Add("JOHN", "/doe", testEmptyHandler) + + resp, err := app.Test(httptest.NewRequest("JOHN", "/doe", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, StatusOK, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/doe", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, StatusMethodNotAllowed, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest("UNKNOWN", "/doe", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, StatusBadRequest, resp.StatusCode, "Status code") + + app.Add("JANE", "/doe", testEmptyHandler) } // go test -run Test_App_GETOnly @@ -487,7 +506,7 @@ func Test_App_Chaining(t *testing.T) { return c.SendStatus(202) }) // check handler count for registered HEAD route - utils.AssertEqual(t, 5, len(app.stack[methodInt(MethodHead)][0].Handlers), "app.Test(req)") + utils.AssertEqual(t, 5, len(app.stack[app.methodInt(MethodHead)][0].Handlers), "app.Test(req)") req := httptest.NewRequest(MethodPost, "/john", nil) @@ -1250,16 +1269,17 @@ func Test_App_Stack(t *testing.T) { app.Post("/path3", testEmptyHandler) stack := app.Stack() - utils.AssertEqual(t, 9, len(stack)) - utils.AssertEqual(t, 3, len(stack[methodInt(MethodGet)])) - utils.AssertEqual(t, 3, len(stack[methodInt(MethodHead)])) - utils.AssertEqual(t, 2, len(stack[methodInt(MethodPost)])) - utils.AssertEqual(t, 1, len(stack[methodInt(MethodPut)])) - utils.AssertEqual(t, 1, len(stack[methodInt(MethodPatch)])) - utils.AssertEqual(t, 1, len(stack[methodInt(MethodDelete)])) - utils.AssertEqual(t, 1, len(stack[methodInt(MethodConnect)])) - utils.AssertEqual(t, 1, len(stack[methodInt(MethodOptions)])) - utils.AssertEqual(t, 1, len(stack[methodInt(MethodTrace)])) + methodList := app.config.RequestMethods + utils.AssertEqual(t, len(methodList), len(stack)) + utils.AssertEqual(t, 3, len(stack[app.methodInt(MethodGet)])) + utils.AssertEqual(t, 3, len(stack[app.methodInt(MethodHead)])) + utils.AssertEqual(t, 2, len(stack[app.methodInt(MethodPost)])) + utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodPut)])) + utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodPatch)])) + utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodDelete)])) + utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodConnect)])) + utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodOptions)])) + utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodTrace)])) } // go test -run Test_App_HandlersCount @@ -1513,6 +1533,19 @@ func Test_App_SetTLSHandler(t *testing.T) { utils.AssertEqual(t, "example.golang", c.ClientHelloInfo().ServerName) } +func Test_App_AddCustomRequestMethod(t *testing.T) { + methods := append(DefaultMethods, "TEST") + app := New(Config{ + RequestMethods: methods, + }) + appMethods := app.config.RequestMethods + + // method name is always uppercase - https://datatracker.ietf.org/doc/html/rfc7231#section-4.1 + utils.AssertEqual(t, len(app.stack), len(appMethods)) + utils.AssertEqual(t, len(app.stack), len(appMethods)) + utils.AssertEqual(t, "TEST", appMethods[len(appMethods)-1]) +} + func TestApp_GetRoutes(t *testing.T) { app := New() app.Use(func(c *Ctx) error { @@ -1524,7 +1557,7 @@ func TestApp_GetRoutes(t *testing.T) { app.Delete("/delete", handler).Name("delete") app.Post("/post", handler).Name("post") routes := app.GetRoutes(false) - utils.AssertEqual(t, 11, len(routes)) + utils.AssertEqual(t, 2+len(app.config.RequestMethods), len(routes)) methodMap := map[string]string{"/delete": "delete", "/post": "post"} for _, route := range routes { name, ok := methodMap[route.Path] @@ -1540,5 +1573,4 @@ func TestApp_GetRoutes(t *testing.T) { utils.AssertEqual(t, true, ok) utils.AssertEqual(t, name, route.Name) } - } diff --git a/ctx.go b/ctx.go index dcbe712526..ac2e25855e 100644 --- a/ctx.go +++ b/ctx.go @@ -163,7 +163,7 @@ func (app *App) AcquireCtx(fctx *fasthttp.RequestCtx) *Ctx { c.pathOriginal = app.getString(fctx.URI().PathOriginal()) // Set method c.method = app.getString(fctx.Request.Header.Method()) - c.methodINT = methodInt(c.method) + c.methodINT = app.methodInt(c.method) // Attach *fasthttp.RequestCtx to ctx c.fasthttp = fctx // reset base uri @@ -906,7 +906,7 @@ func (c *Ctx) Location(path string) { func (c *Ctx) Method(override ...string) string { if len(override) > 0 { method := utils.ToUpper(override[0]) - mINT := methodInt(method) + mINT := c.app.methodInt(method) if mINT == -1 { return c.method } diff --git a/group.go b/group.go index 93826807e9..a312fc5518 100644 --- a/group.go +++ b/group.go @@ -133,7 +133,7 @@ func (grp *Group) Static(prefix, root string, config ...Static) Router { // All will register the handler on all HTTP methods func (grp *Group) All(path string, handlers ...Handler) Router { - for _, method := range intMethod { + for _, method := range grp.app.config.RequestMethods { _ = grp.Add(method, path, handlers...) } return grp diff --git a/helpers.go b/helpers.go index c611b917f1..aa9511da82 100644 --- a/helpers.go +++ b/helpers.go @@ -78,8 +78,9 @@ func (app *App) quoteString(raw string) string { } // Scan stack if other methods match the request -func methodExist(ctx *Ctx) (exist bool) { - for i := 0; i < len(intMethod); i++ { +func (app *App) methodExist(ctx *Ctx) (exist bool) { + methods := app.config.RequestMethods + for i := 0; i < len(methods); i++ { // Skip original method if ctx.methodINT == i { continue @@ -109,7 +110,7 @@ func methodExist(ctx *Ctx) (exist bool) { // We matched exist = true // Add method to Allow header - ctx.Append(HeaderAllow, intMethod[i]) + ctx.Append(HeaderAllow, methods[i]) // Break stack loop break } @@ -331,42 +332,41 @@ var getBytesImmutable = func(s string) (b []byte) { } // HTTP methods and their unique INTs -func methodInt(s string) int { - switch s { - case MethodGet: - return 0 - case MethodHead: - return 1 - case MethodPost: - return 2 - case MethodPut: - return 3 - case MethodDelete: - return 4 - case MethodConnect: - return 5 - case MethodOptions: - return 6 - case MethodTrace: - return 7 - case MethodPatch: - return 8 - default: - return -1 +func (app *App) methodInt(s string) int { + // For better performance + if !app.customMethod { + switch s { + case MethodGet: + return 0 + case MethodHead: + return 1 + case MethodPost: + return 2 + case MethodPut: + return 3 + case MethodDelete: + return 4 + case MethodConnect: + return 5 + case MethodOptions: + return 6 + case MethodTrace: + return 7 + case MethodPatch: + return 8 + default: + return -1 + } + } + + // For method customization + for i, v := range app.config.RequestMethods { + if s == v { + return i + } } -} -// HTTP methods slice -var intMethod = []string{ - MethodGet, - MethodHead, - MethodPost, - MethodPut, - MethodDelete, - MethodConnect, - MethodOptions, - MethodTrace, - MethodPatch, + return -1 } // HTTP methods were copied from net/http. diff --git a/router.go b/router.go index d83c186863..84ff606443 100644 --- a/router.go +++ b/router.go @@ -139,7 +139,7 @@ func (app *App) next(c *Ctx) (match bool, err error) { // If no match, scan stack again if other methods match the request // Moved from app.handler because middleware may break the route chain - if !c.matched && methodExist(c) { + if !c.matched && app.methodExist(c) { err = ErrMethodNotAllowed } return @@ -216,7 +216,7 @@ func (app *App) register(method, pathRaw string, group *Group, handlers ...Handl // Uppercase HTTP methods method = utils.ToUpper(method) // Check if the HTTP method is valid unless it's USE - if method != methodUse && methodInt(method) == -1 { + if method != methodUse && app.methodInt(method) == -1 { panic(fmt.Sprintf("add: invalid http method %s\n", method)) } // A route requires atleast one ctx handler @@ -277,7 +277,7 @@ func (app *App) register(method, pathRaw string, group *Group, handlers ...Handl // Middleware route matches all HTTP methods if isUse { // Add route to all HTTP methods stack - for _, m := range intMethod { + for _, m := range app.config.RequestMethods { // Create a route copy to avoid duplicates during compression r := route app.addRoute(m, &r) @@ -435,7 +435,7 @@ func (app *App) addRoute(method string, route *Route, isMounted ...bool) { } // Get unique HTTP method identifier - m := methodInt(method) + m := app.methodInt(method) // prevent identically route registration l := len(app.stack[m]) @@ -469,7 +469,7 @@ func (app *App) buildTree() *App { } // loop all the methods and stacks and create the prefix tree - for m := range intMethod { + for m := range app.config.RequestMethods { tsMap := make(map[string][]*Route) for _, route := range app.stack[m] { treePath := "" @@ -483,7 +483,7 @@ func (app *App) buildTree() *App { } // loop the methods and tree stacks and add global stack and sort everything - for m := range intMethod { + for m := range app.config.RequestMethods { tsMap := app.treeStack[m] for treePart := range tsMap { if treePart != "" {