Skip to content

Commit

Permalink
feat: close idle connections when server shutdown (#1155)
Browse files Browse the repository at this point in the history
* feat: close idle connections when server shutdown

* Fix redundant code

* Update test

* Update test
  • Loading branch information
ichxxx committed Nov 13, 2021
1 parent a94a2c3 commit 3b117f8
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 1 deletion.
32 changes: 31 additions & 1 deletion server.go
Expand Up @@ -414,9 +414,12 @@ type Server struct {
writerPool sync.Pool
hijackConnPool sync.Pool

// We need to know our listeners so we can close them in Shutdown().
// We need to know our listeners and idle connections so we can close them in Shutdown().
ln []net.Listener

idleConns map[net.Conn]struct{}
idleConnsMu sync.Mutex

mu sync.Mutex
open int32
stop int32
Expand Down Expand Up @@ -1835,6 +1838,8 @@ func (s *Server) Shutdown() error {
close(s.done)
}

s.closeIdleConns()

// Closing the listener will make Serve() call Stop on the worker pool.
// Setting .stop to 1 will make serveConn() break out of its loop.
// Now we just have to wait until all workers are done.
Expand Down Expand Up @@ -2424,6 +2429,7 @@ func (s *Server) serveConn(c net.Conn) (err error) {
}

func (s *Server) setState(nc net.Conn, state ConnState) {
s.trackConn(nc, state)
if hook := s.ConnState; hook != nil {
hook(nc, state)
}
Expand Down Expand Up @@ -2793,6 +2799,30 @@ func (s *Server) writeErrorResponse(bw *bufio.Writer, ctx *RequestCtx, serverNam
return bw
}

func (s *Server) trackConn(c net.Conn, state ConnState) {
s.idleConnsMu.Lock()
switch state {
case StateIdle:
if s.idleConns == nil {
s.idleConns = make(map[net.Conn]struct{})
}
s.idleConns[c] = struct{}{}

default:
delete(s.idleConns, c)
}
s.idleConnsMu.Unlock()
}

func (s *Server) closeIdleConns() {
s.idleConnsMu.Lock()
for c := range s.idleConns {
_ = c.Close()
}
s.idleConns = nil
s.idleConnsMu.Unlock()
}

// A ConnState represents the state of a client connection to a server.
// It's used by the optional Server.ConnState hook.
type ConnState int
Expand Down
41 changes: 41 additions & 0 deletions server_test.go
Expand Up @@ -3347,6 +3347,47 @@ func TestShutdownErr(t *testing.T) {
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
}

func TestShutdownCloseIdleConns(t *testing.T) {
t.Parallel()

ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Success("aaa/bbb", []byte("real response"))
},
}
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexepcted error: %s", err)
}
}()
conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexepcted error: %s", err)
}

if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Errorf("unexpected error: %s", err)
}
br := bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")

shutdownErr := make(chan error)
go func() {
shutdownErr <- s.Shutdown()
}()

timer := time.NewTimer(time.Second)
select {
case <-timer.C:
t.Fatal("idle connections not closed on shutdown")
case err = <-shutdownErr:
if err != nil {
t.Errorf("unexepcted error: %s", err)
}
}
}

func TestMultipleServe(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 3b117f8

Please sign in to comment.