diff --git a/middleware/url_format.go b/middleware/url_format.go index 919eb0fe..da27ee9b 100644 --- a/middleware/url_format.go +++ b/middleware/url_format.go @@ -8,6 +8,10 @@ import ( "github.com/go-chi/chi/v5" ) +const ( + errRouteContextNil = "RouteContext was nil." +) + var ( // URLFormatCtxKey is the context.Context key to store the URL format data // for a request. @@ -52,7 +56,12 @@ func URLFormat(next http.Handler) http.Handler { path := r.URL.Path rctx := chi.RouteContext(r.Context()) - if rctx != nil && rctx.RoutePath != "" { + if rctx == nil { + http.Error(w, errRouteContextNil, http.StatusInternalServerError) + return + } + + if rctx.RoutePath != "" { path = rctx.RoutePath } diff --git a/middleware/url_format_test.go b/middleware/url_format_test.go index e1ac324f..08873e82 100644 --- a/middleware/url_format_test.go +++ b/middleware/url_format_test.go @@ -1,8 +1,10 @@ package middleware import ( + "context" "net/http" "net/http/httptest" + "strings" "testing" "github.com/go-chi/chi/v5" @@ -67,3 +69,30 @@ func TestURLFormatInSubRouter(t *testing.T) { t.Fatalf(resp) } } + +func TestURLFormatWithoutChiRouteContext(t *testing.T) { + r := chi.NewRouter() + + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + newCtx := context.WithValue(r.Context(), chi.RouteCtxKey, nil) + next.ServeHTTP(w, r.WithContext(newCtx)) + }) + }) + r.Use(URLFormat) + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("OK")) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + resp, respBody := testRequest(t, ts, "GET", "/", nil) + if resp.StatusCode != http.StatusInternalServerError { + t.Fatalf("non 500 response: %v", resp.StatusCode) + } + + if strings.TrimSpace(respBody) != errRouteContextNil { + t.Fatalf("Expected error message: %s, but got: %s", errRouteContextNil, respBody) + } +}