Skip to content

Commit

Permalink
feat: add ShutdownWithContext
Browse files Browse the repository at this point in the history
  • Loading branch information
li-jin-gou committed Sep 20, 2022
1 parent 2f1e949 commit 1a01bcd
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 0 deletions.
51 changes: 51 additions & 0 deletions server.go
Expand Up @@ -1856,6 +1856,57 @@ func (s *Server) Shutdown() error {
return nil
}

// ShutdownWithContext support shutdown with timeout and timeout happens after all listener already close.
func (s *Server) ShutdownWithContext(ctx context.Context) (err error) {
s.mu.Lock()
defer s.mu.Unlock()

atomic.StoreInt32(&s.stop, 1)
defer atomic.StoreInt32(&s.stop, 0)

if s.ln == nil {
return nil
}

for _, ln := range s.ln {
if err = ln.Close(); err != nil {
return err
}
}

if s.done != nil {
close(s.done)
}

// Closing the listener will make Serve() call Stop on the worker pool.
// Setting .stop to 1 will make serveConn() break out of its loop.
// Now we just have to wait until all workers are done or timeout.
ticker := time.NewTicker(time.Millisecond * 100)
defer ticker.Stop()
END:
for {
s.closeIdleConns()

if open := atomic.LoadInt32(&s.open); open == 0 {
break
}
// This is not an optimal solution but using a sync.WaitGroup
// here causes data races as it's hard to prevent Add() to be called
// while Wait() is waiting.
select {
case <-ctx.Done():
err = ctx.Err()
break END
case <-ticker.C:
continue
}
}

s.done = nil
s.ln = nil
return err
}

func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) {
for {
var c net.Conn
Expand Down
44 changes: 44 additions & 0 deletions server_test.go
Expand Up @@ -3597,6 +3597,50 @@ func TestShutdownCloseIdleConns(t *testing.T) {
}
}

func TestShutdownWithContext(t *testing.T) {
t.Parallel()

ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Success("aaa/bbb", []byte("real response"))
},
}
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexepcted error: %v", err)
}
}()
conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexepcted error: %v", err)
}

if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Errorf("unexpected error: %v", err)
}
br := bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()

shutdownErr := make(chan error)
go func() {
shutdownErr <- s.ShutdownWithContext(ctx)
}()

timer := time.NewTimer(time.Second)
select {
case <-timer.C:
t.Fatal("idle connections not closed on shutdown")
case err = <-shutdownErr:
if err == nil || err != context.DeadlineExceeded {
t.Fatalf("unexpected err %v. Expecting %v", err, context.DeadlineExceeded)
}
}
}

func TestMultipleServe(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 1a01bcd

Please sign in to comment.