diff --git a/config.go b/config.go
index 8a9ff9d..e4250ea 100644
--- a/config.go
+++ b/config.go
@@ -32,7 +32,7 @@ func WithNotFoundHandler(handler HandlerFunc) Option {
// handler just writes the status code http.StatusMethodNotAllowed.
func WithMethodNotAllowedHandler(handler HandlerFunc) Option {
return option(func(c *config) {
- c.methodNotAllowedHandler = c.group.wrap(handler)
+ c.methodNotAllowedHandler = handler
})
}
diff --git a/example/basic/main.go b/example/basic/main.go
index 5af0e03..38dca0f 100644
--- a/example/basic/main.go
+++ b/example/basic/main.go
@@ -1,6 +1,7 @@
package main
import (
+ "fmt"
"html/template"
"log"
"net/http"
@@ -14,9 +15,12 @@ func main() {
bunrouter.Use(reqlog.NewMiddleware(
reqlog.FromEnv("BUNDEBUG"),
)),
+ bunrouter.WithNotFoundHandler(notFoundHandler),
+ bunrouter.WithMethodNotAllowedHandler(methodNotAllowedHandler),
)
router.GET("/", indexHandler)
+ router.POST("/405", indexHandler) // to test methodNotAllowedHandler
router.WithGroup("/api", func(g *bunrouter.Group) {
g.GET("/users/:id", debugHandler)
@@ -39,6 +43,27 @@ func debugHandler(w http.ResponseWriter, req bunrouter.Request) error {
})
}
+func notFoundHandler(w http.ResponseWriter, req bunrouter.Request) error {
+ w.WriteHeader(http.StatusNotFound)
+ fmt.Fprintf(
+ w,
+ "BunRouter can't find a route that matches %s",
+ req.URL.Path,
+ )
+ return nil
+}
+
+func methodNotAllowedHandler(w http.ResponseWriter, req bunrouter.Request) error {
+ w.WriteHeader(http.StatusMethodNotAllowed)
+ fmt.Fprintf(
+ w,
+ "BunRouter does have a route that matches %s, "+
+ "but it does not handle method %s",
+ req.URL.Path, req.Method,
+ )
+ return nil
+}
+
var indexTmpl = `
Welcome
@@ -46,6 +71,8 @@ var indexTmpl = `
/api/users/123
/api/users/current
/api/users/foo/bar
+ /404
+ /405
`
diff --git a/router_test.go b/router_test.go
index 4de34d8..f039db8 100644
--- a/router_test.go
+++ b/router_test.go
@@ -91,10 +91,10 @@ func testMethods(t *testing.T) {
}
func TestNotFound(t *testing.T) {
- calledNotFound := false
+ var calledNotFound int
- notFoundHandler := func(w http.ResponseWriter, r Request) error {
- calledNotFound = true
+ notFoundHandler := func(w http.ResponseWriter, req Request) error {
+ calledNotFound++
return nil
}
@@ -102,21 +102,46 @@ func TestNotFound(t *testing.T) {
router.GET("/user/abc", simpleHandler)
w := httptest.NewRecorder()
- r, _ := http.NewRequest("GET", "/abc/", nil)
- router.ServeHTTP(w, r)
+ req, _ := http.NewRequest("GET", "/abc/", nil)
- if w.Code != http.StatusNotFound {
- t.Errorf("Expected error 404 from built-in not found handler but saw %d", w.Code)
- }
+ router.ServeHTTP(w, req)
+ require.Equal(t, http.StatusNotFound, w.Code)
+ require.Equal(t, 0, calledNotFound)
// Now try with a custome handler.
router = New(WithNotFoundHandler(notFoundHandler))
router.GET("/user/abc", simpleHandler)
- router.ServeHTTP(w, r)
- if !calledNotFound {
- t.Error("Custom not found handler was not called")
+ router.ServeHTTP(w, req)
+ require.Equal(t, http.StatusNotFound, w.Code)
+ require.Equal(t, 1, calledNotFound)
+}
+
+func TestMethodNotAllowed(t *testing.T) {
+ var calledMethodNotAllowed int
+
+ methodNotAllowedHandler := func(w http.ResponseWriter, req Request) error {
+ calledMethodNotAllowed++
+ return nil
}
+
+ router := New()
+ router.POST("/abc", simpleHandler)
+
+ w := httptest.NewRecorder()
+ req, _ := http.NewRequest("GET", "/abc", nil)
+
+ router.ServeHTTP(w, req)
+ require.Equal(t, http.StatusMethodNotAllowed, w.Code)
+ require.Equal(t, 0, calledMethodNotAllowed)
+
+ // Now try with a custome handler.
+ router = New(WithMethodNotAllowedHandler(methodNotAllowedHandler))
+ router.POST("/abc", simpleHandler)
+
+ router.ServeHTTP(w, req)
+ require.Equal(t, http.StatusMethodNotAllowed, w.Code)
+ require.Equal(t, 1, calledMethodNotAllowed)
}
func TestRedirect(t *testing.T) {
@@ -379,7 +404,7 @@ func TestRedirectEscapedPath(t *testing.T) {
require.Equal(t, "/Test%20P@th/", location)
}
-func TestMiddleware(t *testing.T) {
+func TestMiddlewares(t *testing.T) {
var execLog []string
record := func(s string) {
@@ -524,7 +549,7 @@ func TestCORSMiddleware(t *testing.T) {
require.Equal(t, http.StatusOK, w.Code)
})
- t.Run("CORS to non-existant route", func(t *testing.T) {
+ t.Run("CORS to a non-existant route", func(t *testing.T) {
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodOptions, "/api", nil)
router.ServeHTTP(w, req)
@@ -710,7 +735,7 @@ func TestRoutesWithCommonPrefix(t *testing.T) {
}
}
-func TestNotAllowedMiddleware(t *testing.T) {
+func TestMethodNotAllowedWithMiddlewares(t *testing.T) {
var stack []string
middleware := func(next HandlerFunc) HandlerFunc {