diff --git a/openapi3filter/validation_error_test.go b/openapi3filter/validation_error_test.go index 74e5a55c8..8a889295f 100644 --- a/openapi3filter/validation_error_test.go +++ b/openapi3filter/validation_error_test.go @@ -514,7 +514,7 @@ func (e *mockErrorEncoder) Encode(ctx context.Context, err error, w http.Respons e.W = w } -func runTest(t *testing.T, handler http.Handler, encoder ErrorEncoder, req *http.Request) *http.Response { +func runTest_ServeHTTP(t *testing.T, handler http.Handler, encoder ErrorEncoder, req *http.Request) *http.Response { h := &ValidationHandler{ Handler: handler, ErrorEncoder: encoder, @@ -527,6 +527,18 @@ func runTest(t *testing.T, handler http.Handler, encoder ErrorEncoder, req *http return w.Result() } +func runTest_Middleware(t *testing.T, handler http.Handler, encoder ErrorEncoder, req *http.Request) *http.Response { + h := &ValidationHandler{ + ErrorEncoder: encoder, + SwaggerFile: "fixtures/petstore.json", + } + err := h.Load() + require.NoError(t, err) + w := httptest.NewRecorder() + h.Middleware(handler).ServeHTTP(w, req) + return w.Result() +} + func TestValidationHandler_ServeHTTP(t *testing.T) { t.Run("errors on invalid requests", func(t *testing.T) { httpCtx := context.WithValue(context.Background(), "pig", "tails") @@ -536,7 +548,49 @@ func TestValidationHandler_ServeHTTP(t *testing.T) { handler := &testHandler{} encoder := &mockErrorEncoder{} - runTest(t, handler, encoder.Encode, r) + runTest_ServeHTTP(t, handler, encoder.Encode, r) + + require.False(t, handler.Called) + require.True(t, encoder.Called) + require.Equal(t, httpCtx, encoder.Ctx) + require.NotNil(t, encoder.Err) + }) + + t.Run("passes valid requests through", func(t *testing.T) { + r := newPetstoreRequest(t, http.MethodGet, "/pet/findByStatus?status=sold", nil) + + handler := &testHandler{} + encoder := &mockErrorEncoder{} + runTest_ServeHTTP(t, handler, encoder.Encode, r) + + require.True(t, handler.Called) + require.False(t, encoder.Called) + }) + + t.Run("uses error encoder", func(t *testing.T) { + r := newPetstoreRequest(t, http.MethodPost, "/pet", bytes.NewBufferString(`{"name":"Bahama","photoUrls":"http://cat"}`)) + + handler := &testHandler{} + encoder := &ValidationErrorEncoder{Encoder: (ErrorEncoder)(DefaultErrorEncoder)} + resp := runTest_ServeHTTP(t, handler, encoder.Encode, r) + + body, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusUnprocessableEntity, resp.StatusCode) + require.Equal(t, "[422][][] Field must be set to array or not be present [source pointer=/photoUrls]", string(body)) + }) +} + +func TestValidationHandler_Middleware(t *testing.T) { + t.Run("errors on invalid requests", func(t *testing.T) { + httpCtx := context.WithValue(context.Background(), "pig", "tails") + r, err := http.NewRequest(http.MethodGet, "http://unknown-host.com/v2/pet", nil) + require.NoError(t, err) + r = r.WithContext(httpCtx) + + handler := &testHandler{} + encoder := &mockErrorEncoder{} + runTest_Middleware(t, handler, encoder.Encode, r) require.False(t, handler.Called) require.True(t, encoder.Called) @@ -549,7 +603,7 @@ func TestValidationHandler_ServeHTTP(t *testing.T) { handler := &testHandler{} encoder := &mockErrorEncoder{} - runTest(t, handler, encoder.Encode, r) + runTest_Middleware(t, handler, encoder.Encode, r) require.True(t, handler.Called) require.False(t, encoder.Called) @@ -560,7 +614,7 @@ func TestValidationHandler_ServeHTTP(t *testing.T) { handler := &testHandler{} encoder := &ValidationErrorEncoder{Encoder: (ErrorEncoder)(DefaultErrorEncoder)} - resp := runTest(t, handler, encoder.Encode, r) + resp := runTest_Middleware(t, handler, encoder.Encode, r) body, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) diff --git a/openapi3filter/validation_handler.go b/openapi3filter/validation_handler.go index f72ef1cec..11a60e4b2 100644 --- a/openapi3filter/validation_handler.go +++ b/openapi3filter/validation_handler.go @@ -42,15 +42,33 @@ func (h *ValidationHandler) Load() error { } func (h *ValidationHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - err := h.validateRequest(r) - if err != nil { - h.ErrorEncoder(r.Context(), err, w) + if handled := h.before(w, r); handled { return } // TODO: validateResponse h.Handler.ServeHTTP(w, r) } +// Middleware implements gorilla/mux MiddlewareFunc +func (h *ValidationHandler) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if handled := h.before(w, r); handled { + return + } + // TODO: validateResponse + next.ServeHTTP(w, r) + }) +} + +func (h *ValidationHandler) before(w http.ResponseWriter, r *http.Request) (handled bool) { + err := h.validateRequest(r) + if err != nil { + h.ErrorEncoder(r.Context(), err, w) + return true + } + return false +} + func (h *ValidationHandler) validateRequest(r *http.Request) error { // Find route route, pathParams, err := h.router.FindRoute(r.Method, r.URL)