Skip to content

Commit

Permalink
Merge pull request #108455 from Argh4k/race-conditions
Browse files Browse the repository at this point in the history
Copy request in timeout handler

Kubernetes-commit: 9bb5823b83c2929b059498b1e59c08261257126b
  • Loading branch information
k8s-publishing-bot committed Mar 24, 2022
2 parents 96bc518 + 253e375 commit d4cb74b
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 1 deletion.
6 changes: 5 additions & 1 deletion pkg/server/filters/timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
resultCh := make(chan interface{})
var tw timeoutWriter
tw, w = newTimeoutWriter(w)

// Make a copy of request and work on it in new goroutine
// to avoid race condition when accessing/modifying request (e.g. headers)
rCopy := r.Clone(r.Context())
go func() {
defer func() {
err := recover()
Expand All @@ -107,7 +111,7 @@ func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
resultCh <- err
}()
t.handler.ServeHTTP(w, r)
t.handler.ServeHTTP(w, rCopy)
}()
select {
case err := <-resultCh:
Expand Down
107 changes: 107 additions & 0 deletions pkg/server/filters/timeout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,113 @@ func TestTimeoutHeaders(t *testing.T) {
res.Body.Close()
}

func TestTimeoutRequestHeaders(t *testing.T) {
origReallyCrash := runtime.ReallyCrash
runtime.ReallyCrash = false
defer func() {
runtime.ReallyCrash = origReallyCrash
}()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Add dummy request info, otherwise we skip postTimeoutFn
ctx = request.WithRequestInfo(ctx, &request.RequestInfo{})

withDeadline := func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
handler.ServeHTTP(w, req.WithContext(ctx))
})
}

ts := httptest.NewServer(
withDeadline(
WithTimeoutForNonLongRunningRequests(
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// trigger the timeout
cancel()
// mutate request Headers
// Authorization filter does it for example
for j := 0; j < 10000; j++ {
req.Header.Set("Test", "post")
}
}),
func(r *http.Request, requestInfo *request.RequestInfo) bool {
return false
},
),
),
)
defer ts.Close()

client := &http.Client{}
req, err := http.NewRequest(http.MethodPatch, ts.URL, nil)
if err != nil {
t.Fatal(err)
}
res, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusGatewayTimeout {
t.Errorf("got res.StatusCde %d; expected %d", res.StatusCode, http.StatusServiceUnavailable)
}
res.Body.Close()
}

func TestTimeoutWithLogging(t *testing.T) {
origReallyCrash := runtime.ReallyCrash
runtime.ReallyCrash = false
defer func() {
runtime.ReallyCrash = origReallyCrash
}()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

withDeadline := func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
handler.ServeHTTP(w, req.WithContext(ctx))
})
}

ts := httptest.NewServer(
WithHTTPLogging(
withDeadline(
WithTimeoutForNonLongRunningRequests(
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// trigger the timeout
cancel()
// mutate request Headers
// Authorization filter does it for example
for j := 0; j < 10000; j++ {
req.Header.Set("Test", "post")
}
}),
func(r *http.Request, requestInfo *request.RequestInfo) bool {
return false
},
),
),
),
)
defer ts.Close()

client := &http.Client{}
req, err := http.NewRequest(http.MethodPatch, ts.URL, nil)
if err != nil {
t.Fatal(err)
}
res, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusGatewayTimeout {
t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusServiceUnavailable)
}
res.Body.Close()
}

func TestErrConnKilled(t *testing.T) {
var buf bytes.Buffer
klog.SetOutput(&buf)
Expand Down

0 comments on commit d4cb74b

Please sign in to comment.