Skip to content

Commit

Permalink
Adds the Allow header on 405 response (#776)
Browse files Browse the repository at this point in the history
  • Loading branch information
EwenQuim committed Jul 13, 2023
1 parent 7f28096 commit 4b14b83
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 7 deletions.
1 change: 1 addition & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ type Context struct {

// methodNotAllowed hint
methodNotAllowed bool
methodsAllowed []methodTyp // allowed methods in case of a 405
}

// Reset a routing context to its initial state.
Expand Down
20 changes: 13 additions & 7 deletions mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,11 +378,11 @@ func (mx *Mux) NotFoundHandler() http.HandlerFunc {

// MethodNotAllowedHandler returns the default Mux 405 responder whenever
// a method cannot be resolved for a route.
func (mx *Mux) MethodNotAllowedHandler() http.HandlerFunc {
func (mx *Mux) MethodNotAllowedHandler(methodsAllowed ...methodTyp) http.HandlerFunc {
if mx.methodNotAllowedHandler != nil {
return mx.methodNotAllowedHandler
}
return methodNotAllowedHandler
return methodNotAllowedHandler(methodsAllowed...)
}

// handle registers a http.Handler in the routing tree for a particular http method
Expand Down Expand Up @@ -445,7 +445,7 @@ func (mx *Mux) routeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
if rctx.methodNotAllowed {
mx.MethodNotAllowedHandler().ServeHTTP(w, r)
mx.MethodNotAllowedHandler(rctx.methodsAllowed...).ServeHTTP(w, r)
} else {
mx.NotFoundHandler().ServeHTTP(w, r)
}
Expand Down Expand Up @@ -480,8 +480,14 @@ func (mx *Mux) updateRouteHandler() {
}

// methodNotAllowedHandler is a helper function to respond with a 405,
// method not allowed.
func methodNotAllowedHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(405)
w.Write(nil)
// method not allowed. It sets the Allow header with the list of allowed
// methods for the route.
func methodNotAllowedHandler(methodsAllowed ...methodTyp) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
for _, m := range methodsAllowed {
w.Header().Add("Allow", reverseMethodMap[m])
}
w.WriteHeader(405)
w.Write(nil)
}
}
39 changes: 39 additions & 0 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,43 @@ func TestMuxNestedNotFound(t *testing.T) {
}
}

func TestMethodNotAllowed(t *testing.T) {
r := NewRouter()

r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hi, get"))
})

r.Head("/hi", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hi, head"))
})

ts := httptest.NewServer(r)
defer ts.Close()

t.Run("Registered Method", func(t *testing.T) {
resp, _ := testRequest(t, ts, "GET", "/hi", nil)
if resp.StatusCode != 200 {
t.Fatal(resp.Status)
}
if resp.Header.Values("Allow") != nil {
t.Fatal("allow should be empty when method is registered")
}
})

t.Run("Unregistered Method", func(t *testing.T) {
resp, _ := testRequest(t, ts, "POST", "/hi", nil)
if resp.StatusCode != 405 {
t.Fatal(resp.Status)
}
allowedMethods := resp.Header.Values("Allow")
if len(allowedMethods) != 2 || allowedMethods[0] != "GET" || allowedMethods[1] != "HEAD" {
t.Fatal("Allow header should contain 2 headers: GET, HEAD. Received: ", allowedMethods)

}
})
}

func TestMuxNestedMethodNotAllowed(t *testing.T) {
r := NewRouter()
r.Get("/root", func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -1771,6 +1808,7 @@ func BenchmarkMux(b *testing.B) {
mx := NewRouter()
mx.Get("/", h1)
mx.Get("/hi", h2)
mx.Post("/hi-post", h2) // used to benchmark 405 responses
mx.Get("/sup/{id}/and/{this}", h3)
mx.Get("/sup/{id}/{bar:foo}/{this}", h3)

Expand All @@ -1787,6 +1825,7 @@ func BenchmarkMux(b *testing.B) {
routes := []string{
"/",
"/hi",
"/hi-post",
"/sup/123/and/this",
"/sup/123/foo/this",
"/sharing/z/aBc", // subrouter-1
Expand Down
26 changes: 26 additions & 0 deletions tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ var methodMap = map[string]methodTyp{
http.MethodTrace: mTRACE,
}

var reverseMethodMap = map[methodTyp]string{
mCONNECT: http.MethodConnect,
mDELETE: http.MethodDelete,
mGET: http.MethodGet,
mHEAD: http.MethodHead,
mOPTIONS: http.MethodOptions,
mPATCH: http.MethodPatch,
mPOST: http.MethodPost,
mPUT: http.MethodPut,
mTRACE: http.MethodTrace,
}

// RegisterMethod adds support for custom HTTP method handlers, available
// via Router#Method and Router#MethodFunc
func RegisterMethod(method string) {
Expand Down Expand Up @@ -454,6 +466,13 @@ func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node {
return xn
}

for endpoints := range xn.endpoints {
if endpoints == mALL || endpoints == mSTUB {
continue
}
rctx.methodsAllowed = append(rctx.methodsAllowed, endpoints)
}

// flag that the routing context found a route, but not a corresponding
// supported method
rctx.methodNotAllowed = true
Expand Down Expand Up @@ -493,6 +512,13 @@ func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node {
return xn
}

for endpoints := range xn.endpoints {
if endpoints == mALL || endpoints == mSTUB {
continue
}
rctx.methodsAllowed = append(rctx.methodsAllowed, endpoints)
}

// flag that the routing context found a route, but not a corresponding
// supported method
rctx.methodNotAllowed = true
Expand Down

0 comments on commit 4b14b83

Please sign in to comment.