Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(context): add ContextWithFallback feature flag (#3166) #3172

Merged
merged 1 commit into from Jun 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions context.go
Expand Up @@ -1158,23 +1158,23 @@ func (c *Context) SetAccepted(formats ...string) {

// Deadline returns that there is no deadline (ok==false) when c.Request has no Context.
func (c *Context) Deadline() (deadline time.Time, ok bool) {
if c.Request == nil || c.Request.Context() == nil {
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
return
}
return c.Request.Context().Deadline()
}

// Done returns nil (chan which will wait forever) when c.Request has no Context.
func (c *Context) Done() <-chan struct{} {
if c.Request == nil || c.Request.Context() == nil {
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
return nil
}
return c.Request.Context().Done()
}

// Err returns nil when c.Request has no Context.
func (c *Context) Err() error {
if c.Request == nil || c.Request.Context() == nil {
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
return nil
}
return c.Request.Context().Err()
Expand All @@ -1195,7 +1195,7 @@ func (c *Context) Value(key any) any {
return val
}
}
if c.Request == nil || c.Request.Context() == nil {
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
return nil
}
return c.Request.Context().Value(key)
Expand Down
115 changes: 103 additions & 12 deletions context_test.go
Expand Up @@ -2097,12 +2097,18 @@ func TestRemoteIPFail(t *testing.T) {
}

func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) {
c := &Context{}
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true

deadline, ok := c.Deadline()
assert.Zero(t, deadline)
assert.False(t, ok)

c2 := &Context{}
c2, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c2.engine.ContextWithFallback = true

c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
d := time.Now().Add(time.Second)
ctx, cancel := context.WithDeadline(context.Background(), d)
Expand All @@ -2114,10 +2120,16 @@ func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) {
}

func TestContextWithFallbackDoneFromRequestContext(t *testing.T) {
c := &Context{}
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true

assert.Nil(t, c.Done())

c2 := &Context{}
c2, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c2.engine.ContextWithFallback = true

c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
ctx, cancel := context.WithCancel(context.Background())
c2.Request = c2.Request.WithContext(ctx)
Expand All @@ -2126,10 +2138,16 @@ func TestContextWithFallbackDoneFromRequestContext(t *testing.T) {
}

func TestContextWithFallbackErrFromRequestContext(t *testing.T) {
c := &Context{}
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true

assert.Nil(t, c.Err())

c2 := &Context{}
c2, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c2.engine.ContextWithFallback = true

c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
ctx, cancel := context.WithCancel(context.Background())
c2.Request = c2.Request.WithContext(ctx)
Expand All @@ -2138,9 +2156,9 @@ func TestContextWithFallbackErrFromRequestContext(t *testing.T) {
assert.EqualError(t, c2.Err(), context.Canceled.Error())
}

type contextKey string

func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
type contextKey string

tests := []struct {
name string
getContextAndKey func() (*Context, any)
Expand All @@ -2150,7 +2168,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
name: "c with struct context key",
getContextAndKey: func() (*Context, any) {
var key struct{}
c := &Context{}
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
c.Request, _ = http.NewRequest("POST", "/", nil)
c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "value"))
return c, key
Expand All @@ -2160,7 +2180,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
{
name: "c with string context key",
getContextAndKey: func() (*Context, any) {
c := &Context{}
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
c.Request, _ = http.NewRequest("POST", "/", nil)
c.Request = c.Request.WithContext(context.WithValue(context.TODO(), contextKey("key"), "value"))
return c, contextKey("key")
Expand All @@ -2170,15 +2192,20 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
{
name: "c with nil http.Request",
getContextAndKey: func() (*Context, any) {
c := &Context{}
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
c.Request = nil
return c, "key"
},
value: nil,
},
{
name: "c with nil http.Request.Context()",
getContextAndKey: func() (*Context, any) {
c := &Context{}
c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
c.Request, _ = http.NewRequest("POST", "/", nil)
return c, "key"
},
Expand All @@ -2193,6 +2220,70 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
}
}

func TestContextCopyShouldNotCancel(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

ensureRequestIsOver := make(chan struct{})

wg := &sync.WaitGroup{}

r := New()
r.GET("/", func(ginctx *Context) {
wg.Add(1)

ginctx = ginctx.Copy()

// start async goroutine for calling srv
go func() {
defer wg.Done()

<-ensureRequestIsOver // ensure request is done

req, err := http.NewRequestWithContext(ginctx, http.MethodGet, srv.URL, nil)
must(err)

res, err := http.DefaultClient.Do(req)
if err != nil {
t.Error(fmt.Errorf("request error: %w", err))
return
}

if res.StatusCode != http.StatusOK {
t.Error(fmt.Errorf("unexpected status code: %s", res.Status))
}
}()
})

l, err := net.Listen("tcp", ":0")
must(err)
go func() {
s := &http.Server{
Handler: r,
}

must(s.Serve(l))
}()

addr := strings.Split(l.Addr().String(), ":")
res, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/", addr[len(addr)-1]))
if err != nil {
t.Error(fmt.Errorf("request error: %w", err))
return
}

close(ensureRequestIsOver)

if res.StatusCode != http.StatusOK {
t.Error(fmt.Errorf("unexpected status code: %s", res.Status))
return
}

wg.Wait()
}

func TestContextAddParam(t *testing.T) {
c := &Context{}
id := "id"
Expand Down
3 changes: 3 additions & 0 deletions gin.go
Expand Up @@ -147,6 +147,9 @@ type Engine struct {
// UseH2C enable h2c support.
UseH2C bool

// ContextWithFallback enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value() when Context.Request.Context() is not nil.
ContextWithFallback bool

delims render.Delims
secureJSONPrefix string
HTMLRender render.HTMLRender
Expand Down