From 37425b7af9f103a7f26c7f8cc9a04ddac8b805b4 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Sat, 7 May 2022 13:05:23 +0300 Subject: [PATCH] fix: call method not allowed handler once --- config.go | 2 +- example/basic/main.go | 27 ++++++++++++++++++++++ router_test.go | 53 +++++++++++++++++++++++++++++++------------ 3 files changed, 67 insertions(+), 15 deletions(-) diff --git a/config.go b/config.go index 8a9ff9d..e4250ea 100644 --- a/config.go +++ b/config.go @@ -32,7 +32,7 @@ func WithNotFoundHandler(handler HandlerFunc) Option { // handler just writes the status code http.StatusMethodNotAllowed. func WithMethodNotAllowedHandler(handler HandlerFunc) Option { return option(func(c *config) { - c.methodNotAllowedHandler = c.group.wrap(handler) + c.methodNotAllowedHandler = handler }) } diff --git a/example/basic/main.go b/example/basic/main.go index 5af0e03..38dca0f 100644 --- a/example/basic/main.go +++ b/example/basic/main.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "html/template" "log" "net/http" @@ -14,9 +15,12 @@ func main() { bunrouter.Use(reqlog.NewMiddleware( reqlog.FromEnv("BUNDEBUG"), )), + bunrouter.WithNotFoundHandler(notFoundHandler), + bunrouter.WithMethodNotAllowedHandler(methodNotAllowedHandler), ) router.GET("/", indexHandler) + router.POST("/405", indexHandler) // to test methodNotAllowedHandler router.WithGroup("/api", func(g *bunrouter.Group) { g.GET("/users/:id", debugHandler) @@ -39,6 +43,27 @@ func debugHandler(w http.ResponseWriter, req bunrouter.Request) error { }) } +func notFoundHandler(w http.ResponseWriter, req bunrouter.Request) error { + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf( + w, + "BunRouter can't find a route that matches %s", + req.URL.Path, + ) + return nil +} + +func methodNotAllowedHandler(w http.ResponseWriter, req bunrouter.Request) error { + w.WriteHeader(http.StatusMethodNotAllowed) + fmt.Fprintf( + w, + "BunRouter does have a route that matches %s, "+ + "but it does not handle method %s", + req.URL.Path, req.Method, + ) + return nil +} + var indexTmpl = `

Welcome

@@ -46,6 +71,8 @@ var indexTmpl = `
  • /api/users/123
  • /api/users/current
  • /api/users/foo/bar
  • +
  • /404
  • +
  • /405
  • ` diff --git a/router_test.go b/router_test.go index 4de34d8..f039db8 100644 --- a/router_test.go +++ b/router_test.go @@ -91,10 +91,10 @@ func testMethods(t *testing.T) { } func TestNotFound(t *testing.T) { - calledNotFound := false + var calledNotFound int - notFoundHandler := func(w http.ResponseWriter, r Request) error { - calledNotFound = true + notFoundHandler := func(w http.ResponseWriter, req Request) error { + calledNotFound++ return nil } @@ -102,21 +102,46 @@ func TestNotFound(t *testing.T) { router.GET("/user/abc", simpleHandler) w := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/abc/", nil) - router.ServeHTTP(w, r) + req, _ := http.NewRequest("GET", "/abc/", nil) - if w.Code != http.StatusNotFound { - t.Errorf("Expected error 404 from built-in not found handler but saw %d", w.Code) - } + router.ServeHTTP(w, req) + require.Equal(t, http.StatusNotFound, w.Code) + require.Equal(t, 0, calledNotFound) // Now try with a custome handler. router = New(WithNotFoundHandler(notFoundHandler)) router.GET("/user/abc", simpleHandler) - router.ServeHTTP(w, r) - if !calledNotFound { - t.Error("Custom not found handler was not called") + router.ServeHTTP(w, req) + require.Equal(t, http.StatusNotFound, w.Code) + require.Equal(t, 1, calledNotFound) +} + +func TestMethodNotAllowed(t *testing.T) { + var calledMethodNotAllowed int + + methodNotAllowedHandler := func(w http.ResponseWriter, req Request) error { + calledMethodNotAllowed++ + return nil } + + router := New() + router.POST("/abc", simpleHandler) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/abc", nil) + + router.ServeHTTP(w, req) + require.Equal(t, http.StatusMethodNotAllowed, w.Code) + require.Equal(t, 0, calledMethodNotAllowed) + + // Now try with a custome handler. + router = New(WithMethodNotAllowedHandler(methodNotAllowedHandler)) + router.POST("/abc", simpleHandler) + + router.ServeHTTP(w, req) + require.Equal(t, http.StatusMethodNotAllowed, w.Code) + require.Equal(t, 1, calledMethodNotAllowed) } func TestRedirect(t *testing.T) { @@ -379,7 +404,7 @@ func TestRedirectEscapedPath(t *testing.T) { require.Equal(t, "/Test%20P@th/", location) } -func TestMiddleware(t *testing.T) { +func TestMiddlewares(t *testing.T) { var execLog []string record := func(s string) { @@ -524,7 +549,7 @@ func TestCORSMiddleware(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) }) - t.Run("CORS to non-existant route", func(t *testing.T) { + t.Run("CORS to a non-existant route", func(t *testing.T) { w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodOptions, "/api", nil) router.ServeHTTP(w, req) @@ -710,7 +735,7 @@ func TestRoutesWithCommonPrefix(t *testing.T) { } } -func TestNotAllowedMiddleware(t *testing.T) { +func TestMethodNotAllowedWithMiddlewares(t *testing.T) { var stack []string middleware := func(next HandlerFunc) HandlerFunc {