Skip to content

Commit

Permalink
fix: call method not allowed handler once
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed May 7, 2022
1 parent 2533c46 commit 37425b7
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 15 deletions.
2 changes: 1 addition & 1 deletion config.go
Expand Up @@ -32,7 +32,7 @@ func WithNotFoundHandler(handler HandlerFunc) Option {
// handler just writes the status code http.StatusMethodNotAllowed.
func WithMethodNotAllowedHandler(handler HandlerFunc) Option {
return option(func(c *config) {
c.methodNotAllowedHandler = c.group.wrap(handler)
c.methodNotAllowedHandler = handler
})
}

Expand Down
27 changes: 27 additions & 0 deletions example/basic/main.go
@@ -1,6 +1,7 @@
package main

import (
"fmt"
"html/template"
"log"
"net/http"
Expand All @@ -14,9 +15,12 @@ func main() {
bunrouter.Use(reqlog.NewMiddleware(
reqlog.FromEnv("BUNDEBUG"),
)),
bunrouter.WithNotFoundHandler(notFoundHandler),
bunrouter.WithMethodNotAllowedHandler(methodNotAllowedHandler),
)

router.GET("/", indexHandler)
router.POST("/405", indexHandler) // to test methodNotAllowedHandler

router.WithGroup("/api", func(g *bunrouter.Group) {
g.GET("/users/:id", debugHandler)
Expand All @@ -39,13 +43,36 @@ func debugHandler(w http.ResponseWriter, req bunrouter.Request) error {
})
}

func notFoundHandler(w http.ResponseWriter, req bunrouter.Request) error {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(
w,
"<html>BunRouter can't find a route that matches <strong>%s</strong></html>",
req.URL.Path,
)
return nil
}

func methodNotAllowedHandler(w http.ResponseWriter, req bunrouter.Request) error {
w.WriteHeader(http.StatusMethodNotAllowed)
fmt.Fprintf(
w,
"<html>BunRouter does have a route that matches <strong>%s</strong>, "+
"but it does not handle method <strong>%s</strong></html>",
req.URL.Path, req.Method,
)
return nil
}

var indexTmpl = `
<html>
<h1>Welcome</h1>
<ul>
<li><a href="/api/users/123">/api/users/123</a></li>
<li><a href="/api/users/current">/api/users/current</a></li>
<li><a href="/api/users/foo/bar">/api/users/foo/bar</a></li>
<li><a href="/404">/404</a></li>
<li><a href="/405">/405</a></li>
</ul>
</html>
`
Expand Down
53 changes: 39 additions & 14 deletions router_test.go
Expand Up @@ -91,32 +91,57 @@ func testMethods(t *testing.T) {
}

func TestNotFound(t *testing.T) {
calledNotFound := false
var calledNotFound int

notFoundHandler := func(w http.ResponseWriter, r Request) error {
calledNotFound = true
notFoundHandler := func(w http.ResponseWriter, req Request) error {
calledNotFound++
return nil
}

router := New()
router.GET("/user/abc", simpleHandler)

w := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/abc/", nil)
router.ServeHTTP(w, r)
req, _ := http.NewRequest("GET", "/abc/", nil)

if w.Code != http.StatusNotFound {
t.Errorf("Expected error 404 from built-in not found handler but saw %d", w.Code)
}
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNotFound, w.Code)
require.Equal(t, 0, calledNotFound)

// Now try with a custome handler.
router = New(WithNotFoundHandler(notFoundHandler))
router.GET("/user/abc", simpleHandler)

router.ServeHTTP(w, r)
if !calledNotFound {
t.Error("Custom not found handler was not called")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNotFound, w.Code)
require.Equal(t, 1, calledNotFound)
}

func TestMethodNotAllowed(t *testing.T) {
var calledMethodNotAllowed int

methodNotAllowedHandler := func(w http.ResponseWriter, req Request) error {
calledMethodNotAllowed++
return nil
}

router := New()
router.POST("/abc", simpleHandler)

w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/abc", nil)

router.ServeHTTP(w, req)
require.Equal(t, http.StatusMethodNotAllowed, w.Code)
require.Equal(t, 0, calledMethodNotAllowed)

// Now try with a custome handler.
router = New(WithMethodNotAllowedHandler(methodNotAllowedHandler))
router.POST("/abc", simpleHandler)

router.ServeHTTP(w, req)
require.Equal(t, http.StatusMethodNotAllowed, w.Code)
require.Equal(t, 1, calledMethodNotAllowed)
}

func TestRedirect(t *testing.T) {
Expand Down Expand Up @@ -379,7 +404,7 @@ func TestRedirectEscapedPath(t *testing.T) {
require.Equal(t, "/Test%20P@th/", location)
}

func TestMiddleware(t *testing.T) {
func TestMiddlewares(t *testing.T) {
var execLog []string

record := func(s string) {
Expand Down Expand Up @@ -524,7 +549,7 @@ func TestCORSMiddleware(t *testing.T) {
require.Equal(t, http.StatusOK, w.Code)
})

t.Run("CORS to non-existant route", func(t *testing.T) {
t.Run("CORS to a non-existant route", func(t *testing.T) {
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodOptions, "/api", nil)
router.ServeHTTP(w, req)
Expand Down Expand Up @@ -710,7 +735,7 @@ func TestRoutesWithCommonPrefix(t *testing.T) {
}
}

func TestNotAllowedMiddleware(t *testing.T) {
func TestMethodNotAllowedWithMiddlewares(t *testing.T) {
var stack []string

middleware := func(next HandlerFunc) HandlerFunc {
Expand Down

0 comments on commit 37425b7

Please sign in to comment.