From 7446950424da5c8a04d26f294da95f6c98a4174a Mon Sep 17 00:00:00 2001 From: Ayushman <31277910+Spartan09@users.noreply.github.com> Date: Fri, 19 Jan 2024 04:32:13 +0530 Subject: [PATCH] Extend Handle method to parse HTTP method in pattern (#897) --- mux.go | 14 ++++++++++ mux_test.go | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/mux.go b/mux.go index 735ab232..56fa4d28 100644 --- a/mux.go +++ b/mux.go @@ -107,6 +107,20 @@ func (mx *Mux) Use(middlewares ...func(http.Handler) http.Handler) { // Handle adds the route `pattern` that matches any http method to // execute the `handler` http.Handler. func (mx *Mux) Handle(pattern string, handler http.Handler) { + parts := strings.SplitN(pattern, " ", 2) + if len(parts) == 2 { + methodStr := strings.ToUpper(parts[0]) + path := parts[1] + + method, ok := methodMap[methodStr] + if !ok { + panic("chi: invalid HTTP method specified in pattern: " + methodStr) + } + + mx.handle(method, path, handler) + return + } + mx.handle(mALL, pattern, handler) } diff --git a/mux_test.go b/mux_test.go index 350b137d..9190cb5d 100644 --- a/mux_test.go +++ b/mux_test.go @@ -640,6 +640,81 @@ func TestMuxWith(t *testing.T) { } } +func TestMuxHandlePatternValidation(t *testing.T) { + testCases := []struct { + name string + pattern string + shouldPanic bool + method string // Method to be used for the test request + path string // Path to be used for the test request + expectedBody string // Expected response body + expectedStatus int // Expected HTTP status code + }{ + // Valid patterns + { + name: "Valid pattern without HTTP GET", + pattern: "/user/{id}", + shouldPanic: false, + method: "GET", + path: "/user/123", + expectedBody: "without-prefix GET", + expectedStatus: http.StatusOK, + }, + { + name: "Valid pattern with HTTP POST", + pattern: "POST /products/{id}", + shouldPanic: false, + method: "POST", + path: "/products/456", + expectedBody: "with-prefix POST", + expectedStatus: http.StatusOK, + }, + // Invalid patterns + { + name: "Invalid pattern with no method", + pattern: "INVALID/user/{id}", + shouldPanic: true, + }, + { + name: "Invalid pattern with supported method", + pattern: "GET/user/{id}", + shouldPanic: true, + }, + { + name: "Invalid pattern with unsupported method", + pattern: "UNSUPPORTED /unsupported-method", + shouldPanic: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil && !tc.shouldPanic { + t.Errorf("Unexpected panic for pattern %s", tc.pattern) + } + }() + + r := NewRouter() + r.Handle(tc.pattern, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(tc.expectedBody)) + })) + + if !tc.shouldPanic { + // Use testRequest for valid patterns + ts := httptest.NewServer(r) + defer ts.Close() + + resp, body := testRequest(t, ts, tc.method, tc.path, nil) + if body != tc.expectedBody || resp.StatusCode != tc.expectedStatus { + t.Errorf("Expected status %d and body %s; got status %d and body %s for pattern %s", + tc.expectedStatus, tc.expectedBody, resp.StatusCode, body, tc.pattern) + } + } + }) + } +} + func TestRouterFromMuxWith(t *testing.T) { t.Parallel()