Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Register custom methods #2107

Merged
34 changes: 30 additions & 4 deletions app.go
Expand Up @@ -110,6 +110,8 @@ type App struct {
latestRoute *Route
// TLS handler
tlsHandler *TLSHandler
// custom method check
customMethod bool
// Mount fields
mountFields *mountFields
}
Expand Down Expand Up @@ -380,6 +382,11 @@ type Config struct {
//
// Optional. Default: DefaultColors
ColorScheme Colors `json:"color_scheme"`

// RequestMethods provides customizibility for HTTP methods. You can add/remove methods as you wish.
//
// Optional. Defaukt: DefaultMethods
RequestMethods []string
}

// Static defines configuration options when defining static assets.
Expand Down Expand Up @@ -445,6 +452,19 @@ const (
DefaultCompressedFileSuffix = ".fiber.gz"
)

// HTTP methods enabled by default
var DefaultMethods = []string{
MethodGet,
MethodHead,
MethodPost,
MethodPut,
MethodDelete,
MethodConnect,
MethodOptions,
MethodTrace,
MethodPatch,
}

// DefaultErrorHandler that process return errors from handlers
var DefaultErrorHandler = func(c *Ctx, err error) error {
code := StatusInternalServerError
Expand All @@ -469,9 +489,6 @@ var DefaultErrorHandler = func(c *Ctx, err error) error {
func New(config ...Config) *App {
// Create a new app
app := &App{
// Create router stack
stack: make([][]*Route, len(intMethod)),
treeStack: make([]map[string][]*Route, len(intMethod)),
// Create Ctx pool
pool: sync.Pool{
New: func() interface{} {
Expand Down Expand Up @@ -538,12 +555,21 @@ func New(config ...Config) *App {
if app.config.Network == "" {
app.config.Network = NetworkTCP4
}
if len(app.config.RequestMethods) == 0 {
app.config.RequestMethods = DefaultMethods
} else {
app.customMethod = true
}

app.config.trustedProxiesMap = make(map[string]struct{}, len(app.config.TrustedProxies))
for _, ipAddress := range app.config.TrustedProxies {
app.handleTrustedProxy(ipAddress)
}

// Create router stack
app.stack = make([][]*Route, len(app.config.RequestMethods))
app.treeStack = make([]map[string][]*Route, len(app.config.RequestMethods))

// Override colors
app.config.ColorScheme = defaultColors(app.config.ColorScheme)

Expand Down Expand Up @@ -724,7 +750,7 @@ func (app *App) Static(prefix, root string, config ...Static) Router {

// All will register the handler on all HTTP methods
func (app *App) All(path string, handlers ...Handler) Router {
for _, method := range intMethod {
for _, method := range app.config.RequestMethods {
_ = app.Add(method, path, handlers...)
}
return app
Expand Down
62 changes: 47 additions & 15 deletions app_test.go
Expand Up @@ -435,13 +435,32 @@ func Test_App_Use_StrictRouting(t *testing.T) {
}

func Test_App_Add_Method_Test(t *testing.T) {
app := New()
defer func() {
if err := recover(); err != nil {
utils.AssertEqual(t, "add: invalid http method JOHN\n", fmt.Sprintf("%v", err))
utils.AssertEqual(t, "add: invalid http method JANE\n", fmt.Sprintf("%v", err))
}
}()

methods := append(DefaultMethods, "JOHN")
app := New(Config{
RequestMethods: methods,
})

app.Add("JOHN", "/doe", testEmptyHandler)

resp, err := app.Test(httptest.NewRequest("JOHN", "/doe", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, StatusOK, resp.StatusCode, "Status code")

resp, err = app.Test(httptest.NewRequest(MethodGet, "/doe", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, StatusMethodNotAllowed, resp.StatusCode, "Status code")

resp, err = app.Test(httptest.NewRequest("UNKNOWN", "/doe", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, StatusBadRequest, resp.StatusCode, "Status code")

app.Add("JANE", "/doe", testEmptyHandler)
}

// go test -run Test_App_GETOnly
Expand Down Expand Up @@ -487,7 +506,7 @@ func Test_App_Chaining(t *testing.T) {
return c.SendStatus(202)
})
// check handler count for registered HEAD route
utils.AssertEqual(t, 5, len(app.stack[methodInt(MethodHead)][0].Handlers), "app.Test(req)")
utils.AssertEqual(t, 5, len(app.stack[app.methodInt(MethodHead)][0].Handlers), "app.Test(req)")

req := httptest.NewRequest(MethodPost, "/john", nil)

Expand Down Expand Up @@ -1250,16 +1269,17 @@ func Test_App_Stack(t *testing.T) {
app.Post("/path3", testEmptyHandler)

stack := app.Stack()
utils.AssertEqual(t, 9, len(stack))
utils.AssertEqual(t, 3, len(stack[methodInt(MethodGet)]))
utils.AssertEqual(t, 3, len(stack[methodInt(MethodHead)]))
utils.AssertEqual(t, 2, len(stack[methodInt(MethodPost)]))
utils.AssertEqual(t, 1, len(stack[methodInt(MethodPut)]))
utils.AssertEqual(t, 1, len(stack[methodInt(MethodPatch)]))
utils.AssertEqual(t, 1, len(stack[methodInt(MethodDelete)]))
utils.AssertEqual(t, 1, len(stack[methodInt(MethodConnect)]))
utils.AssertEqual(t, 1, len(stack[methodInt(MethodOptions)]))
utils.AssertEqual(t, 1, len(stack[methodInt(MethodTrace)]))
methodList := app.config.RequestMethods
utils.AssertEqual(t, len(methodList), len(stack))
utils.AssertEqual(t, 3, len(stack[app.methodInt(MethodGet)]))
utils.AssertEqual(t, 3, len(stack[app.methodInt(MethodHead)]))
utils.AssertEqual(t, 2, len(stack[app.methodInt(MethodPost)]))
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodPut)]))
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodPatch)]))
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodDelete)]))
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodConnect)]))
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodOptions)]))
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodTrace)]))
}

// go test -run Test_App_HandlersCount
Expand Down Expand Up @@ -1513,6 +1533,19 @@ func Test_App_SetTLSHandler(t *testing.T) {
utils.AssertEqual(t, "example.golang", c.ClientHelloInfo().ServerName)
}

func Test_App_AddCustomRequestMethod(t *testing.T) {
methods := append(DefaultMethods, "TEST")
app := New(Config{
RequestMethods: methods,
})
appMethods := app.config.RequestMethods

// method name is always uppercase - https://datatracker.ietf.org/doc/html/rfc7231#section-4.1
utils.AssertEqual(t, len(app.stack), len(appMethods))
utils.AssertEqual(t, len(app.stack), len(appMethods))
utils.AssertEqual(t, "TEST", appMethods[len(appMethods)-1])
}

func TestApp_GetRoutes(t *testing.T) {
app := New()
app.Use(func(c *Ctx) error {
Expand All @@ -1524,7 +1557,7 @@ func TestApp_GetRoutes(t *testing.T) {
app.Delete("/delete", handler).Name("delete")
app.Post("/post", handler).Name("post")
routes := app.GetRoutes(false)
utils.AssertEqual(t, 11, len(routes))
utils.AssertEqual(t, 2+len(app.config.RequestMethods), len(routes))
methodMap := map[string]string{"/delete": "delete", "/post": "post"}
for _, route := range routes {
name, ok := methodMap[route.Path]
Expand All @@ -1540,5 +1573,4 @@ func TestApp_GetRoutes(t *testing.T) {
utils.AssertEqual(t, true, ok)
utils.AssertEqual(t, name, route.Name)
}

}
4 changes: 2 additions & 2 deletions ctx.go
Expand Up @@ -163,7 +163,7 @@ func (app *App) AcquireCtx(fctx *fasthttp.RequestCtx) *Ctx {
c.pathOriginal = app.getString(fctx.URI().PathOriginal())
// Set method
c.method = app.getString(fctx.Request.Header.Method())
c.methodINT = methodInt(c.method)
c.methodINT = app.methodInt(c.method)
// Attach *fasthttp.RequestCtx to ctx
c.fasthttp = fctx
// reset base uri
Expand Down Expand Up @@ -906,7 +906,7 @@ func (c *Ctx) Location(path string) {
func (c *Ctx) Method(override ...string) string {
if len(override) > 0 {
method := utils.ToUpper(override[0])
mINT := methodInt(method)
mINT := c.app.methodInt(method)
if mINT == -1 {
return c.method
}
Expand Down
2 changes: 1 addition & 1 deletion group.go
Expand Up @@ -133,7 +133,7 @@ func (grp *Group) Static(prefix, root string, config ...Static) Router {

// All will register the handler on all HTTP methods
func (grp *Group) All(path string, handlers ...Handler) Router {
for _, method := range intMethod {
for _, method := range grp.app.config.RequestMethods {
_ = grp.Add(method, path, handlers...)
}
return grp
Expand Down
74 changes: 37 additions & 37 deletions helpers.go
Expand Up @@ -78,8 +78,9 @@ func (app *App) quoteString(raw string) string {
}

// Scan stack if other methods match the request
func methodExist(ctx *Ctx) (exist bool) {
for i := 0; i < len(intMethod); i++ {
func (app *App) methodExist(ctx *Ctx) (exist bool) {
methods := app.config.RequestMethods
for i := 0; i < len(methods); i++ {
// Skip original method
if ctx.methodINT == i {
continue
Expand Down Expand Up @@ -109,7 +110,7 @@ func methodExist(ctx *Ctx) (exist bool) {
// We matched
exist = true
// Add method to Allow header
ctx.Append(HeaderAllow, intMethod[i])
ctx.Append(HeaderAllow, methods[i])
// Break stack loop
break
}
Expand Down Expand Up @@ -331,42 +332,41 @@ var getBytesImmutable = func(s string) (b []byte) {
}

// HTTP methods and their unique INTs
func methodInt(s string) int {
switch s {
case MethodGet:
return 0
case MethodHead:
return 1
case MethodPost:
return 2
case MethodPut:
return 3
case MethodDelete:
return 4
case MethodConnect:
return 5
case MethodOptions:
return 6
case MethodTrace:
return 7
case MethodPatch:
return 8
default:
return -1
func (app *App) methodInt(s string) int {
// For better performance
if !app.customMethod {
switch s {
case MethodGet:
return 0
case MethodHead:
return 1
case MethodPost:
return 2
case MethodPut:
return 3
case MethodDelete:
return 4
case MethodConnect:
return 5
case MethodOptions:
return 6
case MethodTrace:
return 7
case MethodPatch:
return 8
default:
return -1
}
}

// For method customization
for i, v := range app.config.RequestMethods {
if s == v {
return i
}
}
}

// HTTP methods slice
var intMethod = []string{
MethodGet,
MethodHead,
MethodPost,
MethodPut,
MethodDelete,
MethodConnect,
MethodOptions,
MethodTrace,
MethodPatch,
return -1
}

// HTTP methods were copied from net/http.
Expand Down
12 changes: 6 additions & 6 deletions router.go
Expand Up @@ -139,7 +139,7 @@ func (app *App) next(c *Ctx) (match bool, err error) {

// If no match, scan stack again if other methods match the request
// Moved from app.handler because middleware may break the route chain
if !c.matched && methodExist(c) {
if !c.matched && app.methodExist(c) {
err = ErrMethodNotAllowed
}
return
Expand Down Expand Up @@ -216,7 +216,7 @@ func (app *App) register(method, pathRaw string, group *Group, handlers ...Handl
// Uppercase HTTP methods
method = utils.ToUpper(method)
// Check if the HTTP method is valid unless it's USE
if method != methodUse && methodInt(method) == -1 {
if method != methodUse && app.methodInt(method) == -1 {
panic(fmt.Sprintf("add: invalid http method %s\n", method))
}
// A route requires atleast one ctx handler
Expand Down Expand Up @@ -277,7 +277,7 @@ func (app *App) register(method, pathRaw string, group *Group, handlers ...Handl
// Middleware route matches all HTTP methods
if isUse {
// Add route to all HTTP methods stack
for _, m := range intMethod {
for _, m := range app.config.RequestMethods {
// Create a route copy to avoid duplicates during compression
r := route
app.addRoute(m, &r)
Expand Down Expand Up @@ -435,7 +435,7 @@ func (app *App) addRoute(method string, route *Route, isMounted ...bool) {
}

// Get unique HTTP method identifier
m := methodInt(method)
m := app.methodInt(method)

// prevent identically route registration
l := len(app.stack[m])
Expand Down Expand Up @@ -469,7 +469,7 @@ func (app *App) buildTree() *App {
}

// loop all the methods and stacks and create the prefix tree
for m := range intMethod {
for m := range app.config.RequestMethods {
tsMap := make(map[string][]*Route)
for _, route := range app.stack[m] {
treePath := ""
Expand All @@ -483,7 +483,7 @@ func (app *App) buildTree() *App {
}

// loop the methods and tree stacks and add global stack and sort everything
for m := range intMethod {
for m := range app.config.RequestMethods {
tsMap := app.treeStack[m]
for treePart := range tsMap {
if treePart != "" {
Expand Down