-
Notifications
You must be signed in to change notification settings - Fork 2
/
concurrentlimit_test.go
164 lines (144 loc) · 3.62 KB
/
concurrentlimit_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
package concurrentlimit
import (
"context"
"errors"
"net"
"net/http"
"strconv"
"sync"
"syscall"
"testing"
"time"
)
// TestNoLimit checks that the limiter returned by NoLimit never refuses a
// Start call, and that every end function it hands out can be invoked.
func TestNoLimit(t *testing.T) {
	const calls = 10000
	limiter := NoLimit()

	ends := make([]func(), 0, calls)
	for n := 0; n < calls; n++ {
		end, err := limiter.Start()
		if err != nil {
			t.Fatal("NoLimit should never return an error")
		}
		ends = append(ends, end)
	}

	// release every slot acquired above; none of these calls may panic
	for i := range ends {
		ends[i]()
	}
}
// TestLimiterRace acquires all permitted slots of a New(permitted) limiter
// from separate goroutines, checks that further Start calls are rejected with
// ErrLimited, and that releasing one slot lets the next Start succeed.
func TestLimiterRace(t *testing.T) {
	const permitted = 100
	limiter := New(permitted)

	// start the limiter in separate goroutines so hopefully the race detector can find bugs
	var wg sync.WaitGroup
	endFuncs := make(chan func(), permitted)
	for i := 0; i < permitted; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			end, err := limiter.Start()
			if err != nil {
				// t.Fatal may not be called from a non-test goroutine, so report
				// and bail out. Do not queue end: on error it is presumably nil
				// and calling it later would panic.
				t.Error("Limiter must allow the first N calls", err)
				return
			}
			endFuncs <- end
		}()
	}
	wg.Wait()
	if t.Failed() {
		// the remaining checks assume all permitted slots are held
		t.FailNow()
	}

	// the next calls must fail
	for i := 0; i < 5; i++ {
		end, err := limiter.Start()
		if end != nil || !errors.Is(err, ErrLimited) {
			t.Fatalf("Limiter must block calls after the first N calls: %p %#v", end, err)
		}
	}

	// Call one end function: the next Start call should work
	end := <-endFuncs
	end()
	end, err := limiter.Start()
	if end == nil || err != nil {
		t.Fatal("The next call must succeed after end is called", err)
	}
	endFuncs <- end
	close(endFuncs)

	// calling all the end functions should work
	for end := range endFuncs {
		end()
	}
}
// blockForConcurrent is an http.Handler that holds every request open until
// unblock is closed. It never writes to the ResponseWriter, so once released
// net/http answers each request with 200 OK and an empty body.
type blockForConcurrent struct {
	// unblock is closed to release all waiting (and any later) requests.
	unblock chan struct{}
}

// ServeHTTP ignores the request and waits until h.unblock is closed.
func (h *blockForConcurrent) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
	<-h.unblock
}
// TestHTTP starts a server with ListenAndServe that permits `permitted`
// concurrent requests and issues permitted+1 requests whose handlers block.
// Exactly one request must be rejected with 429 Too Many Requests. The rest
// must succeed once the handlers are released.
func TestHTTP(t *testing.T) {
	// set up a rate limited HTTP server
	const permitted = 3

	// Pick a port that is currently free. ListenAndServe binds testServer.Addr
	// itself, so this listener has to be closed again. Another process could
	// take the port in between, which is a small, accepted source of flakiness.
	listener, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatal(err)
	}
	port := listener.Addr().(*net.TCPAddr).Port
	if err := listener.Close(); err != nil {
		t.Fatal(err)
	}
	httpAddr := "localhost:" + strconv.Itoa(port)
	serverURL := "http://" + httpAddr

	// start the server
	handler := &blockForConcurrent{make(chan struct{})}
	var unblockOnce sync.Once
	// unblock may be called from both the result loop and the deferred cleanup.
	unblock := func() { unblockOnce.Do(func() { close(handler.unblock) }) }
	testServer := &http.Server{
		Addr:    httpAddr,
		Handler: handler,
	}
	serverDone := make(chan struct{})
	go func() {
		defer close(serverDone)
		// must allow more connections than requests, otherwise it waits for the connection to close
		err := ListenAndServe(testServer, permitted, permitted*2)
		if !errors.Is(err, http.ErrServerClosed) {
			t.Error("expected HTTP server to be shutdown; err:", err)
		}
	}()
	defer func() {
		// Shutdown waits for in-flight requests, so release any handlers still
		// blocked (e.g. after an early t.Fatal); otherwise it would hang.
		unblock()
		if err := testServer.Shutdown(context.Background()); err != nil {
			t.Error("shutting down HTTP server:", err)
		}
		// Wait for the server goroutine so it never reports after the test returns.
		<-serverDone
	}()

	// Each client goroutine sends exactly one result, so the loop below can
	// always collect permitted+1 values. The buffer lets clients finish even
	// if the test stops reading early.
	type result struct {
		status int
		err    error
	}
	results := make(chan result, permitted+1)
	get := func() (int, error) {
		// retry while the server goroutine may not be listening yet
		const attempts = 100
		var lastErr error
		for attempt := 0; attempt < attempts; attempt++ {
			resp, err := http.Get(serverURL)
			if err != nil {
				if errors.Is(err, syscall.ECONNREFUSED) {
					// race with the server starting up: try again
					lastErr = err
					time.Sleep(10 * time.Millisecond)
					continue
				}
				return 0, err
			}
			resp.Body.Close()
			return resp.StatusCode, nil
		}
		return 0, errors.New("failed after too many attempts: " + lastErr.Error())
	}
	for i := 0; i < permitted+1; i++ {
		go func() {
			status, err := get()
			results <- result{status, err}
		}()
	}

	okCount := 0
	rateLimitedCount := 0
	// If the limiter let every request through, all handlers would block and
	// no response would arrive; fail instead of hanging.
	timeout := time.After(10 * time.Second)
	for i := 0; i < permitted+1; i++ {
		var res result
		select {
		case res = <-results:
		case <-timeout:
			t.Fatal("timed out waiting for HTTP responses")
		}
		if i == 0 {
			// unblock the handlers on the first response, no matter what it is
			unblock()
		}
		if res.err != nil {
			t.Error(res.err)
			continue
		}
		switch res.status {
		case http.StatusOK:
			okCount++
		case http.StatusTooManyRequests:
			rateLimitedCount++
		default:
			t.Error("unexpected HTTP status code:", res.status)
		}
	}
	if !(okCount == permitted && rateLimitedCount == 1) {
		t.Error("unexpected OK and rate limited response counts:", okCount, rateLimitedCount)
	}
}