Skip to content
This repository has been archived by the owner on Nov 2, 2023. It is now read-only.

Commit

Permalink
sdk/middleware/sqgin: use the actual request context instead of gin's
Browse files Browse the repository at this point in the history
Gin's context wrongly implements `context.Context` and doesn't wrap the
underlying request context at all. Therefore, we need to use the actual request
context `c.Request.Context()` so that the agent can properly manage the request
context, but also to correctly propagate values stored in the context.
  • Loading branch information
Julio Guerra committed Sep 30, 2020
2 parents 94ffdf0 + a555f99 commit a242284
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 6 deletions.
60 changes: 57 additions & 3 deletions sdk/middleware/sqecho/echo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package sqecho

import (
"context"
"errors"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -100,7 +101,7 @@ func TestMiddleware(t *testing.T) {
require.Equal(t, body, rec.Body.String())
})

t.Run("control flow", func(t *testing.T) {
t.Run("data and control flow", func(t *testing.T) {
middlewareResponseBody := testlib.RandUTF8String(4096)
middlewareResponseStatus := 433
handlerResponseBody := testlib.RandUTF8String(4096)
Expand Down Expand Up @@ -133,6 +134,11 @@ func TestMiddleware(t *testing.T) {
handler func(echo.Context) error
test func(t *testing.T, rec *httptest.ResponseRecorder, err error)
}{
//
// Control flow tests
// When an handlers, including middlewares, block.
//

{
name: "sqreen first/the middleware aborts before the handler",
middlewares: []echo.MiddlewareFunc{
Expand Down Expand Up @@ -341,6 +347,54 @@ func TestMiddleware(t *testing.T) {
require.Equal(t, middlewareResponseBody+handlerResponseBody+middlewareResponseBody, rec.Body.String())
},
},

//
// Context data flow tests
//
{
name: "middleware1, sqreen, middleware2, handler",
middlewares: []echo.MiddlewareFunc{
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Set("m10", "v10")
c.SetRequest(c.Request().WithContext(context.WithValue(c.Request().Context(), "m11", "v11")))
return next(c)
}
},
middleware(tc.agent),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Set("m20", "v20")
c.SetRequest(c.Request().WithContext(context.WithValue(c.Request().Context(), "m21", "v21")))
return next(c)
}
},
},
handler: func(c echo.Context) error {
// From Gin's context
if v, ok := c.Get("m10").(string); !ok || v != "v10" {
panic("couldn't get the context value m10")
}
if v, ok := c.Get("m20").(string); !ok || v != "v20" {
panic("couldn't get the context value m20")
}

// From the request context
reqCtx := c.Request().Context()
if v, ok := reqCtx.Value("m11").(string); !ok || v != "v11" {
panic("couldn't get the context value m11")
}
if v, ok := reqCtx.Value("m21").(string); !ok || v != "v21" {
panic("couldn't get the context value m21")
}

return c.NoContent(http.StatusOK)
},
test: func(t *testing.T, rec *httptest.ResponseRecorder, err error) {
require.NoError(t, err)
require.Equal(t, http.StatusOK, rec.Code)
},
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
Expand Down Expand Up @@ -378,8 +432,8 @@ func TestMiddleware(t *testing.T) {
agent.ExpectIsIPAllowed(mock.Anything).Return(false).Once()
agent.ExpectIsPathAllowed(mock.Anything).Return(false).Once()
var (
responseStatusCode int
responseContentType string
responseStatusCode int
responseContentType string
responseContentLength int64
)
agent.ExpectSendClosedRequestContext(mock.MatchedBy(func(recorded types.ClosedRequestContextFace) bool {
Expand Down
2 changes: 1 addition & 1 deletion sdk/middleware/sqgin/gin.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func middlewareHandler(agent protection_context.AgentFace, c *gingonic.Context)
requestReader := &requestReaderImpl{c: c}
responseWriter := &responseWriterImpl{c: c}

ctx, reqCtx, cancelHandlerContext := http_protection.NewRequestContext(c, agent, responseWriter, requestReader)
ctx, reqCtx, cancelHandlerContext := http_protection.NewRequestContext(c.Request.Context(), agent, responseWriter, requestReader)
if ctx == nil {
c.Next()
return
Expand Down
49 changes: 48 additions & 1 deletion sdk/middleware/sqgin/gin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package sqgin

import (
"context"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -91,7 +92,7 @@ func TestMiddleware(t *testing.T) {
})

// Test how the control flows between middleware and handler functions
t.Run("control flow", func(t *testing.T) {
t.Run("data and control flow", func(t *testing.T) {
middlewareResponseBody := testlib.RandUTF8String(4096)
middlewareResponseStatus := 433
handlerResponseBody := testlib.RandUTF8String(4096)
Expand Down Expand Up @@ -124,6 +125,11 @@ func TestMiddleware(t *testing.T) {
handler func(*gin.Context)
test func(t *testing.T, rec *httptest.ResponseRecorder)
}{
//
// Control flow tests
// When an handlers, including middlewares, block.
//

{
name: "sqreen first/next middleware aborts before the handler",
middlewares: []gin.HandlerFunc{
Expand Down Expand Up @@ -292,6 +298,47 @@ func TestMiddleware(t *testing.T) {
require.Equal(t, middlewareResponseBody, rec.Body.String())
},
},

//
// Context data flow tests
//
{
name: "middleware1, sqreen, middleware2, handler",
middlewares: []gin.HandlerFunc{
func(c *gin.Context) {
c.Set("m10", "v10")
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), "m11", "v11"))
},
middleware(tc.agent),
func(c *gin.Context) {
c.Set("m20", "v20")
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), "m21", "v21"))
},
},
handler: func(c *gin.Context) {
// From Gin's context
if v, ok := c.Value("m10").(string); !ok || v != "v10" {
panic("couldn't get the context value m10")
}
if v, ok := c.Value("m20").(string); !ok || v != "v20" {
panic("couldn't get the context value m20")
}

// From the request context
reqCtx := c.Request.Context()
if v, ok := reqCtx.Value("m11").(string); !ok || v != "v11" {
panic("couldn't get the context value m11")
}
if v, ok := reqCtx.Value("m21").(string); !ok || v != "v21" {
panic("couldn't get the context value m21")
}

c.Status(http.StatusOK)
},
test: func(t *testing.T, rec *httptest.ResponseRecorder) {
require.Equal(t, http.StatusOK, rec.Code)
},
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
Expand Down
31 changes: 30 additions & 1 deletion sdk/middleware/sqhttp/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package sqhttp

import (
"context"
"io"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -70,7 +71,7 @@ func TestMiddleware(t *testing.T) {
})

// Test how the control flows between middleware and handler functions
t.Run("control flow", func(t *testing.T) {
t.Run("data and control flow", func(t *testing.T) {
middlewareResponseBody := testlib.RandUTF8String(4096)
middlewareResponseStatus := 433
handlerResponseBody := testlib.RandUTF8String(4096)
Expand Down Expand Up @@ -102,6 +103,11 @@ func TestMiddleware(t *testing.T) {
handlers http.Handler
test func(t *testing.T, rec *httptest.ResponseRecorder)
}{
//
// Control flow tests
// When an handlers, including middlewares, block.
//

{
name: "sqreen first/handler writes the response",
handlers: middleware(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
Expand Down Expand Up @@ -163,6 +169,29 @@ func TestMiddleware(t *testing.T) {
require.Equal(t, middlewareResponseBody, rec.Body.String())
},
},

//
// Context data flow tests
//
{
name: "middleware, sqreen, handler",
handlers: func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(context.WithValue(r.Context(), "m", "v"))
next.ServeHTTP(w, r)
})
}(middleware(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if v, ok := ctx.Value("m").(string); !ok || v != "v" {
panic("couldn't get the context value m")
}

w.WriteHeader(http.StatusOK)
}, tc.agent)),
test: func(t *testing.T, rec *httptest.ResponseRecorder) {
require.Equal(t, http.StatusOK, rec.Code)
},
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
Expand Down

0 comments on commit a242284

Please sign in to comment.