diff --git a/server.go b/server.go
index a7e93b174e..988c2890d0 100644
--- a/server.go
+++ b/server.go
@@ -1816,6 +1816,17 @@ func (s *Server) Serve(ln net.Listener) error {
 //
 // Shutdown does not close keepalive connections so its recommended to set ReadTimeout and IdleTimeout to something else than 0.
 func (s *Server) Shutdown() error {
+	return s.ShutdownWithContext(context.Background())
+}
+
+// ShutdownWithContext gracefully shuts down the server without interrupting any active connections.
+// ShutdownWithContext works by first closing all open listeners and then waiting for all connections to return to idle, or for the context to expire, before shutting down.
+//
+// When ShutdownWithContext is called, Serve, ListenAndServe, and ListenAndServeTLS immediately return nil.
+// Make sure the program doesn't exit and waits instead for ShutdownWithContext to return.
+//
+// ShutdownWithContext does not close keepalive connections so it's recommended to set ReadTimeout and IdleTimeout to something other than 0.
+func (s *Server) ShutdownWithContext(ctx context.Context) (err error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
@@ -1827,7 +1838,7 @@ func (s *Server) Shutdown() error {
 	}
 
 	for _, ln := range s.ln {
-		if err := ln.Close(); err != nil {
+		if err = ln.Close(); err != nil {
 			return err
 		}
 	}
@@ -1838,7 +1849,10 @@ func (s *Server) Shutdown() error {
 
 	// Closing the listener will make Serve() call Stop on the worker pool.
 	// Setting .stop to 1 will make serveConn() break out of its loop.
-	// Now we just have to wait until all workers are done.
+	// Now we just have to wait until all workers are done or the context expires.
+	ticker := time.NewTicker(time.Millisecond * 100)
+	defer ticker.Stop()
+END:
 	for {
 		s.closeIdleConns()
 
@@ -1848,12 +1862,18 @@ func (s *Server) Shutdown() error {
 		// This is not an optimal solution but using a sync.WaitGroup
 		// here causes data races as it's hard to prevent Add() to be called
 		// while Wait() is waiting.
-		time.Sleep(time.Millisecond * 100)
+		select {
+		case <-ctx.Done():
+			err = ctx.Err()
+			break END
+		case <-ticker.C:
+			continue
+		}
 	}
 
 	s.done = nil
 	s.ln = nil
-	return nil
+	return err
 }
 
 func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) {
diff --git a/server_test.go b/server_test.go
index 1a2adcc582..237559a2ea 100644
--- a/server_test.go
+++ b/server_test.go
@@ -19,6 +19,7 @@ import (
 	"runtime"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -3597,6 +3598,57 @@ func TestShutdownCloseIdleConns(t *testing.T) {
 	}
 }
 
+func TestShutdownWithContext(t *testing.T) {
+	t.Parallel()
+
+	ln := fasthttputil.NewInmemoryListener()
+	s := &Server{
+		Handler: func(ctx *RequestCtx) {
+			time.Sleep(5 * time.Second)
+			ctx.Success("aaa/bbb", []byte("real response"))
+		},
+	}
+	go func() {
+		if err := s.Serve(ln); err != nil {
+			t.Errorf("unexpected error: %v", err)
+		}
+	}()
+	time.Sleep(1 * time.Second)
+	go func() {
+		conn, err := ln.Dial()
+		if err != nil {
+			t.Errorf("unexpected error: %v", err)
+		}
+
+		if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
+			t.Errorf("unexpected error: %v", err)
+		}
+		br := bufio.NewReader(conn)
+		verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
+	}()
+
+	time.Sleep(1 * time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
+	defer cancel()
+	shutdownErr := make(chan error)
+	go func() {
+		shutdownErr <- s.ShutdownWithContext(ctx)
+	}()
+
+	timer := time.NewTimer(time.Second)
+	select {
+	case <-timer.C:
+		t.Fatal("shutdown did not return before the test timeout")
+	case err := <-shutdownErr:
+		if err == nil || err != context.DeadlineExceeded {
+			t.Fatalf("unexpected err %v. Expecting %v", err, context.DeadlineExceeded)
+		}
+	}
+	if atomic.LoadInt32(&s.open) != 1 {
+		t.Fatalf("unexpected open connection num: %#v. Expecting %#v", atomic.LoadInt32(&s.open), 1)
+	}
+}
+
 func TestMultipleServe(t *testing.T) {
 	t.Parallel()
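
For reference, a minimal usage sketch of the new API (not part of the diff). It assumes this patch is applied to github.com/valyala/fasthttp; the listen address, timeouts, and handler body are illustrative placeholders.

package main

import (
	"context"
	"log"
	"os"
	"os/signal"
	"time"

	"github.com/valyala/fasthttp"
)

func main() {
	s := &fasthttp.Server{
		// Non-zero ReadTimeout/IdleTimeout so keepalive connections eventually go idle,
		// as recommended by the Shutdown/ShutdownWithContext doc comments.
		ReadTimeout: 5 * time.Second,
		IdleTimeout: 30 * time.Second,
		Handler: func(ctx *fasthttp.RequestCtx) {
			ctx.SetBodyString("hello")
		},
	}

	go func() {
		if err := s.ListenAndServe(":8080"); err != nil {
			log.Printf("server stopped: %v", err)
		}
	}()

	// Block until the process receives an interrupt signal.
	stop := make(chan os.Signal, 1)
	signal.Notify(stop, os.Interrupt)
	<-stop

	// Give in-flight requests up to 10 seconds to finish; if the context
	// expires first, ShutdownWithContext returns the context's error.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := s.ShutdownWithContext(ctx); err != nil {
		log.Printf("shutdown did not complete cleanly: %v", err)
	}
}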