diff --git a/server.go b/server.go index a7e93b174e..9f66de7900 100644 --- a/server.go +++ b/server.go @@ -1856,6 +1856,57 @@ func (s *Server) Shutdown() error { return nil } +// ShutdownWithContext support shutdown with timeout and timeout happens after all listener already close. +func (s *Server) ShutdownWithContext(ctx context.Context) (err error) { + s.mu.Lock() + defer s.mu.Unlock() + + atomic.StoreInt32(&s.stop, 1) + defer atomic.StoreInt32(&s.stop, 0) + + if s.ln == nil { + return nil + } + + for _, ln := range s.ln { + if err = ln.Close(); err != nil { + return err + } + } + + if s.done != nil { + close(s.done) + } + + // Closing the listener will make Serve() call Stop on the worker pool. + // Setting .stop to 1 will make serveConn() break out of its loop. + // Now we just have to wait until all workers are done or timeout. + ticker := time.NewTicker(time.Millisecond * 100) + defer ticker.Stop() +END: + for { + s.closeIdleConns() + + if open := atomic.LoadInt32(&s.open); open == 0 { + break + } + // This is not an optimal solution but using a sync.WaitGroup + // here causes data races as it's hard to prevent Add() to be called + // while Wait() is waiting. + select { + case <-ctx.Done(): + err = ctx.Err() + break END + case <-ticker.C: + continue + } + } + + s.done = nil + s.ln = nil + return err +} + func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) { for { var c net.Conn diff --git a/server_test.go b/server_test.go index 1a2adcc582..c38532db47 100644 --- a/server_test.go +++ b/server_test.go @@ -3597,6 +3597,50 @@ func TestShutdownCloseIdleConns(t *testing.T) { } } +func TestShutdownWithContext(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.Success("aaa/bbb", []byte("real response")) + }, + } + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexepcted error: %v", err) + } + }() + conn, err := ln.Dial() + if err != nil { + t.Fatalf("unexepcted error: %v", err) + } + + if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Errorf("unexpected error: %v", err) + } + br := bufio.NewReader(conn) + verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + shutdownErr := make(chan error) + go func() { + shutdownErr <- s.ShutdownWithContext(ctx) + }() + + timer := time.NewTimer(time.Second) + select { + case <-timer.C: + t.Fatal("idle connections not closed on shutdown") + case err = <-shutdownErr: + if err == nil || err != context.DeadlineExceeded { + t.Fatalf("unexpected err %v. Expecting %v", err, context.DeadlineExceeded) + } + } +} + func TestMultipleServe(t *testing.T) { t.Parallel()