Skip to content
This repository has been archived by the owner on Feb 24, 2024. It is now read-only.

added assert middleware to assert handler's behavior. (fix #2339) #2345

Merged
merged 2 commits into from
Oct 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
45 changes: 35 additions & 10 deletions middleware.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package buffalo

import (
"net/http"
"reflect"
"runtime"
"strings"
Expand Down Expand Up @@ -123,15 +124,38 @@ func (ms *MiddlewareStack) Replace(mw1 MiddlewareFunc, mw2 MiddlewareFunc) {
ms.stack = stack
}

func (ms *MiddlewareStack) handler(info RouteInfo) Handler {
h := info.Handler
if len(ms.stack) > 0 {
mh := func(_ Handler) Handler {
return h
// assertMiddleware is a hidden middleware that works just befor and after the
// actual handler runs to make it sure everything is OK with the Handler
// specification.
//
// It writes response header with `http.StatusOK` if the request handler exited
// without error but the response status is still zero. Setting response is the
// responsibility of handler but this middleware make it sure the response
// should be compatible with middleware specification.
//
// See also: https://github.com/gobuffalo/buffalo/issues/2339
func assertMiddleware(handler Handler) Handler {
return func(c Context) error {
err := handler(c)
if err != nil {
return err
}

tstack := []MiddlewareFunc{mh}
if res, ok := c.Response().(*Response); ok {
if res.Status == 0 {
res.WriteHeader(http.StatusOK)
c.Logger().Debug("warning: handler exited without setting the response status. 200 OK will be used.")
}
}

return err
}
}

func (ms *MiddlewareStack) handler(info RouteInfo) Handler {
tstack := []MiddlewareFunc{assertMiddleware}

if len(ms.stack) > 0 {
sl := len(ms.stack) - 1
for i := sl; i >= 0; i-- {
mw := ms.stack[i]
Expand All @@ -140,12 +164,13 @@ func (ms *MiddlewareStack) handler(info RouteInfo) Handler {
tstack = append(tstack, mw)
}
}
}

for _, mw := range tstack {
h = mw(h)
}
return h
h := info.Handler
for _, mw := range tstack {
h = mw(h)
}

return h
}

Expand Down
65 changes: 65 additions & 0 deletions middleware_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package buffalo

import (
"fmt"
"net/http"
"testing"

Expand Down Expand Up @@ -242,3 +243,67 @@ func Test_Middleware_Remove(t *testing.T) {
_ = w.HTML("/no_log_autos/1").Get()
r.Len(log, 0)
}

func Test_AssertMiddleware_NilStatus200(t *testing.T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great! Thanks @sio4 !

r := require.New(t)
var status int

a := New(Options{})
a.Use(func(h Handler) Handler {
return func(c Context) error {
err := h(c)

res, ok := c.Response().(*Response)
r.True(ok)
status = res.Status

return err
}
})

a.GET("/200", func(c Context) error {
c.Response().WriteHeader(http.StatusOK) // explicitly set
return nil
})

a.GET("/404", func(c Context) error {
c.Response().WriteHeader(http.StatusNotFound) //explicitly set
return nil
})

a.GET("/nil", func(c Context) error {
return nil // return nil without setting response status. should be OK
})

a.GET("/500", func(c Context) error {
return fmt.Errorf("error") // return error
})

a.GET("/502", func(c Context) error {
return HTTPError{Status: http.StatusBadGateway} // return HTTPError
})

a.GET("/panic", func(c Context) error {
panic("hoy hoy")
})

tests := []struct {
path string
code int
status int
}{
{"/200", http.StatusOK, http.StatusOK}, // when the handler set response code explicitly (e.g. 200, 404)
{"/404", http.StatusNotFound, http.StatusNotFound},
{"/nil", http.StatusOK, http.StatusOK}, // when the handler returns nil without setting status code
{"/502", http.StatusBadGateway, 0}, // set by defaultErrorHandler, when the handler just returns error
{"/500", http.StatusInternalServerError, 0}, // set by defaultErrorHandler, when the handler returns HTTPError
{"/panic", http.StatusInternalServerError, 0}, // set by PanicHandler
}
w := httptest.New(a)

for _, tc := range tests {
res := w.HTML(tc.path).Get()
r.Equal(tc.status, status)
r.Equal(tc.code, res.Code)
}
}