Skip to content
This repository has been archived by the owner on Nov 2, 2023. It is now read-only.

sdk/middleware/sqgin: fix request ctx #157

Merged
merged 3 commits into from
Sep 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
60 changes: 57 additions & 3 deletions sdk/middleware/sqecho/echo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package sqecho

import (
"context"
"errors"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -100,7 +101,7 @@ func TestMiddleware(t *testing.T) {
require.Equal(t, body, rec.Body.String())
})

t.Run("control flow", func(t *testing.T) {
t.Run("data and control flow", func(t *testing.T) {
middlewareResponseBody := testlib.RandUTF8String(4096)
middlewareResponseStatus := 433
handlerResponseBody := testlib.RandUTF8String(4096)
Expand Down Expand Up @@ -133,6 +134,11 @@ func TestMiddleware(t *testing.T) {
handler func(echo.Context) error
test func(t *testing.T, rec *httptest.ResponseRecorder, err error)
}{
//
// Control flow tests
// When an handlers, including middlewares, block.
//

{
name: "sqreen first/the middleware aborts before the handler",
middlewares: []echo.MiddlewareFunc{
Expand Down Expand Up @@ -341,6 +347,54 @@ func TestMiddleware(t *testing.T) {
require.Equal(t, middlewareResponseBody+handlerResponseBody+middlewareResponseBody, rec.Body.String())
},
},

//
// Context data flow tests
//
{
name: "middleware1, sqreen, middleware2, handler",
middlewares: []echo.MiddlewareFunc{
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Set("m10", "v10")
c.SetRequest(c.Request().WithContext(context.WithValue(c.Request().Context(), "m11", "v11")))
return next(c)
}
},
middleware(tc.agent),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Set("m20", "v20")
c.SetRequest(c.Request().WithContext(context.WithValue(c.Request().Context(), "m21", "v21")))
return next(c)
}
},
},
handler: func(c echo.Context) error {
// From Gin's context
if v, ok := c.Get("m10").(string); !ok || v != "v10" {
panic("couldn't get the context value m10")
}
if v, ok := c.Get("m20").(string); !ok || v != "v20" {
panic("couldn't get the context value m20")
}

// From the request context
reqCtx := c.Request().Context()
if v, ok := reqCtx.Value("m11").(string); !ok || v != "v11" {
panic("couldn't get the context value m11")
}
if v, ok := reqCtx.Value("m21").(string); !ok || v != "v21" {
panic("couldn't get the context value m21")
}

return c.NoContent(http.StatusOK)
},
test: func(t *testing.T, rec *httptest.ResponseRecorder, err error) {
require.NoError(t, err)
require.Equal(t, http.StatusOK, rec.Code)
},
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
Expand Down Expand Up @@ -378,8 +432,8 @@ func TestMiddleware(t *testing.T) {
agent.ExpectIsIPAllowed(mock.Anything).Return(false).Once()
agent.ExpectIsPathAllowed(mock.Anything).Return(false).Once()
var (
responseStatusCode int
responseContentType string
responseStatusCode int
responseContentType string
responseContentLength int64
)
agent.ExpectSendClosedRequestContext(mock.MatchedBy(func(recorded types.ClosedRequestContextFace) bool {
Expand Down
2 changes: 1 addition & 1 deletion sdk/middleware/sqgin/gin.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func middlewareHandler(agent protection_context.AgentFace, c *gingonic.Context)
requestReader := &requestReaderImpl{c: c}
responseWriter := &responseWriterImpl{c: c}

ctx, reqCtx, cancelHandlerContext := http_protection.NewRequestContext(c, agent, responseWriter, requestReader)
ctx, reqCtx, cancelHandlerContext := http_protection.NewRequestContext(c.Request.Context(), agent, responseWriter, requestReader)
if ctx == nil {
c.Next()
return
Expand Down
49 changes: 48 additions & 1 deletion sdk/middleware/sqgin/gin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package sqgin

import (
"context"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -91,7 +92,7 @@ func TestMiddleware(t *testing.T) {
})

// Test how the control flows between middleware and handler functions
t.Run("control flow", func(t *testing.T) {
t.Run("data and control flow", func(t *testing.T) {
middlewareResponseBody := testlib.RandUTF8String(4096)
middlewareResponseStatus := 433
handlerResponseBody := testlib.RandUTF8String(4096)
Expand Down Expand Up @@ -124,6 +125,11 @@ func TestMiddleware(t *testing.T) {
handler func(*gin.Context)
test func(t *testing.T, rec *httptest.ResponseRecorder)
}{
//
// Control flow tests
// When an handlers, including middlewares, block.
//

{
name: "sqreen first/next middleware aborts before the handler",
middlewares: []gin.HandlerFunc{
Expand Down Expand Up @@ -292,6 +298,47 @@ func TestMiddleware(t *testing.T) {
require.Equal(t, middlewareResponseBody, rec.Body.String())
},
},

//
// Context data flow tests
//
{
name: "middleware1, sqreen, middleware2, handler",
middlewares: []gin.HandlerFunc{
func(c *gin.Context) {
c.Set("m10", "v10")
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), "m11", "v11"))
},
middleware(tc.agent),
func(c *gin.Context) {
c.Set("m20", "v20")
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), "m21", "v21"))
},
},
handler: func(c *gin.Context) {
// From Gin's context
if v, ok := c.Value("m10").(string); !ok || v != "v10" {
panic("couldn't get the context value m10")
}
if v, ok := c.Value("m20").(string); !ok || v != "v20" {
panic("couldn't get the context value m20")
}

// From the request context
reqCtx := c.Request.Context()
if v, ok := reqCtx.Value("m11").(string); !ok || v != "v11" {
panic("couldn't get the context value m11")
}
if v, ok := reqCtx.Value("m21").(string); !ok || v != "v21" {
panic("couldn't get the context value m21")
}

c.Status(http.StatusOK)
},
test: func(t *testing.T, rec *httptest.ResponseRecorder) {
require.Equal(t, http.StatusOK, rec.Code)
},
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
Expand Down
31 changes: 30 additions & 1 deletion sdk/middleware/sqhttp/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package sqhttp

import (
"context"
"io"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -70,7 +71,7 @@ func TestMiddleware(t *testing.T) {
})

// Test how the control flows between middleware and handler functions
t.Run("control flow", func(t *testing.T) {
t.Run("data and control flow", func(t *testing.T) {
middlewareResponseBody := testlib.RandUTF8String(4096)
middlewareResponseStatus := 433
handlerResponseBody := testlib.RandUTF8String(4096)
Expand Down Expand Up @@ -102,6 +103,11 @@ func TestMiddleware(t *testing.T) {
handlers http.Handler
test func(t *testing.T, rec *httptest.ResponseRecorder)
}{
//
// Control flow tests
// When an handlers, including middlewares, block.
//

{
name: "sqreen first/handler writes the response",
handlers: middleware(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
Expand Down Expand Up @@ -163,6 +169,29 @@ func TestMiddleware(t *testing.T) {
require.Equal(t, middlewareResponseBody, rec.Body.String())
},
},

//
// Context data flow tests
//
{
name: "middleware, sqreen, handler",
handlers: func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(context.WithValue(r.Context(), "m", "v"))
next.ServeHTTP(w, r)
})
}(middleware(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if v, ok := ctx.Value("m").(string); !ok || v != "v" {
panic("couldn't get the context value m")
}

w.WriteHeader(http.StatusOK)
}, tc.agent)),
test: func(t *testing.T, rec *httptest.ResponseRecorder) {
require.Equal(t, http.StatusOK, rec.Code)
},
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
Expand Down