From 60b4f5f52d70851dc4802f68ef96794302bdf854 Mon Sep 17 00:00:00 2001 From: Angelo Fallaria Date: Sat, 17 Feb 2024 08:11:00 +0800 Subject: [PATCH] feat: update HTTP method parsing in patterns for `Handle` and `HandleFunc` (#900) --- mux.go | 16 +++++++--------- mux_test.go | 30 +++++++++++++++++++----------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/mux.go b/mux.go index 56fa4d28..240ae676 100644 --- a/mux.go +++ b/mux.go @@ -109,15 +109,7 @@ func (mx *Mux) Use(middlewares ...func(http.Handler) http.Handler) { func (mx *Mux) Handle(pattern string, handler http.Handler) { parts := strings.SplitN(pattern, " ", 2) if len(parts) == 2 { - methodStr := strings.ToUpper(parts[0]) - path := parts[1] - - method, ok := methodMap[methodStr] - if !ok { - panic("chi: invalid HTTP method specified in pattern: " + methodStr) - } - - mx.handle(method, path, handler) + mx.Method(parts[0], parts[1], handler) return } @@ -127,6 +119,12 @@ func (mx *Mux) Handle(pattern string, handler http.Handler) { // HandleFunc adds the route `pattern` that matches any http method to // execute the `handlerFn` http.HandlerFunc. func (mx *Mux) HandleFunc(pattern string, handlerFn http.HandlerFunc) { + parts := strings.SplitN(pattern, " ", 2) + if len(parts) == 2 { + mx.Method(parts[0], parts[1], handlerFn) + return + } + mx.handle(mALL, pattern, handlerFn) } diff --git a/mux_test.go b/mux_test.go index 9190cb5d..82f089e7 100644 --- a/mux_test.go +++ b/mux_test.go @@ -691,24 +691,32 @@ func TestMuxHandlePatternValidation(t *testing.T) { t.Run(tc.name, func(t *testing.T) { defer func() { if r := recover(); r != nil && !tc.shouldPanic { - t.Errorf("Unexpected panic for pattern %s", tc.pattern) + t.Errorf("Unexpected panic for pattern %s:\n%v", tc.pattern, r) } }() - r := NewRouter() - r.Handle(tc.pattern, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r1 := NewRouter() + r1.Handle(tc.pattern, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(tc.expectedBody)) })) + // Test that HandleFunc also handles method patterns + r2 := NewRouter() + r2.HandleFunc(tc.pattern, func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(tc.expectedBody)) + }) + if !tc.shouldPanic { - // Use testRequest for valid patterns - ts := httptest.NewServer(r) - defer ts.Close() - - resp, body := testRequest(t, ts, tc.method, tc.path, nil) - if body != tc.expectedBody || resp.StatusCode != tc.expectedStatus { - t.Errorf("Expected status %d and body %s; got status %d and body %s for pattern %s", - tc.expectedStatus, tc.expectedBody, resp.StatusCode, body, tc.pattern) + for _, r := range []Router{r1, r2} { + // Use testRequest for valid patterns + ts := httptest.NewServer(r) + defer ts.Close() + + resp, body := testRequest(t, ts, tc.method, tc.path, nil) + if body != tc.expectedBody || resp.StatusCode != tc.expectedStatus { + t.Errorf("Expected status %d and body %s; got status %d and body %s for pattern %s", + tc.expectedStatus, tc.expectedBody, resp.StatusCode, body, tc.pattern) + } } } })