/
limiter_test.go
130 lines (106 loc) · 3.8 KB
/
limiter_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package limiter
import (
"fmt"
"testing"
"time"
)
// TestConstructor builds a limiter with New(nil), sets its max to 1, and
// verifies the values reported by GetMax, GetMessage and GetStatusCode.
// The message and status code are never set here, so the expected values are
// whatever New itself supplies.
func TestConstructor(t *testing.T) {
	lim := New(nil).SetMax(1)

	if got := lim.GetMax(); got != 1 {
		t.Errorf("Max field is incorrect. Value: %v", got)
	}
	if got := lim.GetMessage(); got != "You have reached maximum request limit." {
		t.Errorf("Message field is incorrect. Value: %v", got)
	}
	if got := lim.GetStatusCode(); got != 429 {
		t.Errorf("StatusCode field is incorrect. Value: %v", got)
	}
}
func TestConstructorExpiringBuckets(t *testing.T) {
lmt := New(&ExpirableOptions{DefaultExpirationTTL: time.Second, ExpireJobInterval: 0}).SetMax(1)
if lmt.GetMax() != 1 {
t.Errorf("Max field is incorrect. Value: %v", lmt.GetMax())
}
if lmt.GetMessage() != "You have reached maximum request limit." {
t.Errorf("Message field is incorrect. Value: %v", lmt.GetMessage())
}
if lmt.GetStatusCode() != 429 {
t.Errorf("StatusCode field is incorrect. Value: %v", lmt.GetStatusCode())
}
}
// TestLimitReached drives LimitReached three times for one key on a limiter
// configured with max 1 and burst 1. The test expects the first call to pass,
// an immediate second call to be rejected, and a call made one second later
// to pass again.
func TestLimitReached(t *testing.T) {
	const key = "127.0.0.1|/"
	lim := New(nil).SetMax(1).SetBurst(1)

	if lim.LimitReached(key) {
		t.Error("First time count should not reached the limit.")
	}

	if !lim.LimitReached(key) {
		t.Error("Second time count should return true because it exceeds 1 request per second.")
	}

	// Let the one-second window elapse before the final attempt.
	time.Sleep(1 * time.Second)

	if lim.LimitReached(key) {
		t.Error("Third time count should not reached the limit because the 1 second window has passed.")
	}
}
// TestFloatingLimitReached uses a fractional max (0.1, which the test treats
// as one request per ten seconds) with burst 1. It expects the first call to
// pass, the second to be rejected, a call seven seconds later to still be
// rejected, and a call after three further seconds (ten in total) to pass.
func TestFloatingLimitReached(t *testing.T) {
	const key = "127.0.0.1|/"
	lim := New(nil).SetMax(0.1).SetBurst(1)

	if lim.LimitReached(key) {
		t.Error("First time count should not reached the limit.")
	}

	if !lim.LimitReached(key) {
		t.Error("Second time count should return true because it exceeds 1 request per 10 seconds.")
	}

	// Seven seconds in: still inside the ten-second window.
	time.Sleep(7 * time.Second)
	if !lim.LimitReached(key) {
		t.Error("Third time count should return true because it exceeds 1 request per 10 seconds.")
	}

	// Three more seconds brings the total wait to ten.
	time.Sleep(3 * time.Second)
	if lim.LimitReached(key) {
		t.Error("Fourth time count should not reached the limit because the 10 second window has passed.")
	}
}
func TestLimitReachedWithCustomTokenBucketTTL(t *testing.T) {
lmt := New(&ExpirableOptions{DefaultExpirationTTL: time.Second, ExpireJobInterval: 0}).SetMax(1).SetBurst(1)
key := "127.0.0.1|/"
if lmt.LimitReached(key) == true {
t.Error("First time count should not reached the limit.")
}
if lmt.LimitReached(key) == false {
t.Error("Second time count should return true because it exceeds 1 request per second.")
}
<-time.After(1 * time.Second)
if lmt.LimitReached(key) == true {
t.Error("Third time count should not reached the limit because the 1 second window has passed.")
}
}
// TestMuchHigherMaxRequests configures a limiter with max 1000 and burst 1,
// then issues 1000 calls for one key spaced 1/1000 of a second apart. None of
// those paced calls may be rejected. One extra call issued immediately
// afterwards, with no spacing, must be rejected.
func TestMuchHigherMaxRequests(t *testing.T) {
	numRequests := 1000
	delay := (1 * time.Second) / time.Duration(numRequests)

	lmt := New(nil).SetMax(float64(numRequests)).SetBurst(1)
	key := "127.0.0.1|/"

	for i := 0; i < numRequests; i++ {
		// Pace the calls so they arrive at exactly the configured rate.
		time.Sleep(delay)

		if lmt.LimitReached(key) {
			t.Errorf("N(%v) limit should not be reached.", i)
		}
	}

	// This is call number numRequests+1; report it as such so the message
	// agrees with TestMuchHigherMaxRequestsWithCustomTokenBucketTTL.
	if !lmt.LimitReached(key) {
		t.Errorf("N(%v) limit should be reached because it exceeds %v request per second.", numRequests+1, numRequests)
	}
}
func TestMuchHigherMaxRequestsWithCustomTokenBucketTTL(t *testing.T) {
numRequests := 1000
delay := (1 * time.Second) / time.Duration(numRequests)
lmt := New(&ExpirableOptions{DefaultExpirationTTL: time.Minute, ExpireJobInterval: time.Minute}).SetMax(float64(numRequests)).SetBurst(1)
key := "127.0.0.1|/"
for i := 0; i < numRequests; i++ {
time.Sleep(delay)
if lmt.LimitReached(key) == true {
fmt.Printf("N(%v) limit should not be reached.\n", i)
}
}
if lmt.LimitReached(key) == false {
t.Errorf("N(%v) limit should be reached because it exceeds %v request per second.", numRequests+1, numRequests)
}
}