Skip to content

Commit

Permalink
Merge pull request #3 from confluentinc/middleware
Browse files Browse the repository at this point in the history
Add Middleware endpoint to ValidationHandler for use in gorilla/mux and others
  • Loading branch information
codyaray committed Jun 5, 2020
2 parents 00911cd + e46af0f commit abe831a
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 7 deletions.
62 changes: 58 additions & 4 deletions openapi3filter/validation_error_test.go
Expand Up @@ -547,7 +547,7 @@ func (e *mockErrorEncoder) Encode(ctx context.Context, err error, w http.Respons
e.W = w
}

func runTest(t *testing.T, handler http.Handler, encoder ErrorEncoder, req *http.Request) *http.Response {
func runTest_ServeHTTP(t *testing.T, handler http.Handler, encoder ErrorEncoder, req *http.Request) *http.Response {
h := &ValidationHandler{
Handler: handler,
ErrorEncoder: encoder,
Expand All @@ -560,6 +560,18 @@ func runTest(t *testing.T, handler http.Handler, encoder ErrorEncoder, req *http
return w.Result()
}

func runTest_Middleware(t *testing.T, handler http.Handler, encoder ErrorEncoder, req *http.Request) *http.Response {
h := &ValidationHandler{
ErrorEncoder: encoder,
SwaggerFile: "fixtures/petstore.json",
}
err := h.Load()
require.NoError(t, err)
w := httptest.NewRecorder()
h.Middleware(handler).ServeHTTP(w, req)
return w.Result()
}

func TestValidationHandler_ServeHTTP(t *testing.T) {
t.Run("errors on invalid requests", func(t *testing.T) {
httpCtx := context.WithValue(context.Background(), "pig", "tails")
Expand All @@ -569,7 +581,49 @@ func TestValidationHandler_ServeHTTP(t *testing.T) {

handler := &testHandler{}
encoder := &mockErrorEncoder{}
runTest(t, handler, encoder.Encode, r)
runTest_ServeHTTP(t, handler, encoder.Encode, r)

require.False(t, handler.Called)
require.True(t, encoder.Called)
require.Equal(t, httpCtx, encoder.Ctx)
require.NotNil(t, encoder.Err)
})

t.Run("passes valid requests through", func(t *testing.T) {
r := newPetstoreRequest(t, http.MethodGet, "/pet/findByStatus?status=sold", nil)

handler := &testHandler{}
encoder := &mockErrorEncoder{}
runTest_ServeHTTP(t, handler, encoder.Encode, r)

require.True(t, handler.Called)
require.False(t, encoder.Called)
})

t.Run("uses error encoder", func(t *testing.T) {
r := newPetstoreRequest(t, http.MethodPost, "/pet", bytes.NewBufferString(`{"name":"Bahama","photoUrls":"http://cat"}`))

handler := &testHandler{}
encoder := &ValidationErrorEncoder{Encoder: (ErrorEncoder)(DefaultErrorEncoder)}
resp := runTest_ServeHTTP(t, handler, encoder.Encode, r)

body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, http.StatusUnprocessableEntity, resp.StatusCode)
require.Equal(t, "[422][][] Field must be set to array or not be present [source pointer=/photoUrls]", string(body))
})
}

func TestValidationHandler_Middleware(t *testing.T) {
t.Run("errors on invalid requests", func(t *testing.T) {
httpCtx := context.WithValue(context.Background(), "pig", "tails")
r, err := http.NewRequest(http.MethodGet, "http://unknown-host.com/v2/pet", nil)
require.NoError(t, err)
r = r.WithContext(httpCtx)

handler := &testHandler{}
encoder := &mockErrorEncoder{}
runTest_Middleware(t, handler, encoder.Encode, r)

require.False(t, handler.Called)
require.True(t, encoder.Called)
Expand All @@ -582,7 +636,7 @@ func TestValidationHandler_ServeHTTP(t *testing.T) {

handler := &testHandler{}
encoder := &mockErrorEncoder{}
runTest(t, handler, encoder.Encode, r)
runTest_Middleware(t, handler, encoder.Encode, r)

require.True(t, handler.Called)
require.False(t, encoder.Called)
Expand All @@ -593,7 +647,7 @@ func TestValidationHandler_ServeHTTP(t *testing.T) {

handler := &testHandler{}
encoder := &ValidationErrorEncoder{Encoder: (ErrorEncoder)(DefaultErrorEncoder)}
resp := runTest(t, handler, encoder.Encode, r)
resp := runTest_Middleware(t, handler, encoder.Encode, r)

body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
Expand Down
24 changes: 21 additions & 3 deletions openapi3filter/validation_handler.go
Expand Up @@ -61,15 +61,33 @@ func (h *ValidationHandler) LoadSwagger() error {
}

func (h *ValidationHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
err := h.validateRequest(r)
if err != nil {
h.ErrorEncoder(r.Context(), err, w)
if handled := h.before(w, r); handled {
return
}
// TODO: validateResponse
h.Handler.ServeHTTP(w, r)
}

// Middleware implements gorilla/mux MiddlewareFunc
func (h *ValidationHandler) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if handled := h.before(w, r); handled {
return
}
// TODO: validateResponse
next.ServeHTTP(w, r)
})
}

func (h *ValidationHandler) before(w http.ResponseWriter, r *http.Request) (handled bool) {
err := h.validateRequest(r)
if err != nil {
h.ErrorEncoder(r.Context(), err, w)
return true
}
return false
}

func (h *ValidationHandler) validateRequest(r *http.Request) error {
// Find route
route, pathParams, err := h.router.FindRoute(r.Method, r.URL)
Expand Down

0 comments on commit abe831a

Please sign in to comment.