Skip to content

Commit

Permalink
refactor: move chan to main for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
johejo committed Apr 4, 2020
1 parent a4e531a commit 44ed2ff
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 31 deletions.
8 changes: 2 additions & 6 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@ import (
"errors"
"net"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"

"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
Expand All @@ -19,6 +16,7 @@ import (
type Server struct {
Handler http.Handler
Opts *Options
stop chan struct{} // channel for waiting shutdown
}

// ListenAndServe will serve traffic on HTTP or HTTPS depending on TLS options
Expand Down Expand Up @@ -134,9 +132,7 @@ func (s *Server) serve(listener net.Listener) {
// See https://golang.org/pkg/net/http/#Server.Shutdown
idleConnsClosed := make(chan struct{})
go func() {
sigint := make(chan os.Signal, 1)
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
<-sigint
<-s.stop // wait notification for stopping server

// We received an interrupt signal, shut down.
if err := srv.Shutdown(context.Background()); err != nil {
Expand Down
37 changes: 12 additions & 25 deletions http_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
package main

import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"sync"
"syscall"
"testing"
"time"

"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -163,25 +159,16 @@ func TestRedirectNotWhenHTTPS(t *testing.T) {
}

func TestGracefulShutdown(t *testing.T) {
signals := []syscall.Signal{syscall.SIGINT, syscall.SIGTERM}

for i, signal := range signals {
name := fmt.Sprintf("%s", signal)
t.Run(name, func(t *testing.T) {
opts := NewOptions()
opts.HTTPAddress = fmt.Sprintf(":%d", 4180+i)
srv := Server{Handler: http.DefaultServeMux, Opts: opts}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
srv.ServeHTTP()
}()
time.Sleep(500 * time.Millisecond)
if err := syscall.Kill(os.Getpid(), signal); err != nil {
t.Fatal(err)
}
wg.Wait()
})
}
opts := NewOptions()
stop := make(chan struct{}, 1)
srv := Server{Handler: http.DefaultServeMux, Opts: opts, stop: stop}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
srv.ServeHTTP()
}()

stop <- struct{}{} // emulate catching signals
wg.Wait()
}
10 changes: 10 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import (
"math/rand"
"net/http"
"os"
"os/signal"
"runtime"
"strings"
"syscall"
"time"

"github.com/BurntSushi/toml"
Expand Down Expand Up @@ -204,6 +206,14 @@ func main() {
s := &Server{
Handler: handler,
Opts: opts,
stop: make(chan struct{}, 1),
}
// Observe signals in background goroutine.
go func() {
sigint := make(chan os.Signal, 1)
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
<-sigint
s.stop <- struct{}{} // notify having caught signal
}()
s.ListenAndServe()
}

0 comments on commit 44ed2ff

Please sign in to comment.