Skip to content

Commit

Permalink
fix(hijack): reset userValues after hijack handler execution (#1199)
Browse files Browse the repository at this point in the history
* fix(hijack): reset userValues after hijack handler execution

* feat: add test

* fix: typo

* fix(test): race condition
  • Loading branch information
Sergio VS committed Jan 18, 2022
1 parent 9123060 commit 2aca3e8
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
6 changes: 4 additions & 2 deletions server.go
Expand Up @@ -2375,6 +2375,7 @@ func (s *Server) serveConn(c net.Conn) (err error) {

if hijackHandler != nil {
var hjr io.Reader = c
hctx := ctx
if br != nil {
hjr = br
br = nil
Expand All @@ -2394,7 +2395,7 @@ func (s *Server) serveConn(c net.Conn) (err error) {
if err != nil {
break
}
go hijackConnHandler(hjr, c, s, hijackHandler)
go hijackConnHandler(hctx, hjr, c, s, hijackHandler)
err = errHijacked
break
}
Expand Down Expand Up @@ -2446,7 +2447,7 @@ func (s *Server) setState(nc net.Conn, state ConnState) {
}
}

func hijackConnHandler(r io.Reader, c net.Conn, s *Server, h HijackHandler) {
func hijackConnHandler(ctx *RequestCtx, r io.Reader, c net.Conn, s *Server, h HijackHandler) {
hjc := s.acquireHijackConn(r, c)
h(hjc)

Expand All @@ -2457,6 +2458,7 @@ func hijackConnHandler(r io.Reader, c net.Conn, s *Server, h HijackHandler) {
c.Close()
s.releaseHijackConn(hjc)
}
ctx.ResetUserValues()
}

func (s *Server) acquireHijackConn(r io.Reader, c net.Conn) *hijackConn {
Expand Down
31 changes: 31 additions & 0 deletions server_test.go
Expand Up @@ -2088,6 +2088,37 @@ func TestServeConnKeepRequestAndResponseUntilResetUserValues(t *testing.T) {
}
}

func TestServeConnHijackResetUserValues(t *testing.T) {
t.Parallel()

rw := &readWriter{}
rw.r.WriteString("GET /foo HTTP/1.0\r\nConnection: keep-alive\r\nHost: google.com\r\n\r\n")
rw.r.WriteString("")

ch := make(chan struct{})
go func() {
err := ServeConn(rw, func(ctx *RequestCtx) {
ctx.Hijack(func(c net.Conn) {})
ctx.SetUserValue("myKey", &closerWithRequestCtx{
closeFunc: func(_ *RequestCtx) error {
close(ch)

return nil
}},
)
})
if err != nil {
t.Errorf("unexpected error in ServeConn: %s", err)
}
}()

select {
case <-ch:
case <-time.After(time.Second):
t.Errorf("Timeout: UserValues should be reset")
}
}

func TestServeConnNonHTTP11KeepAlive(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 2aca3e8

Please sign in to comment.