-
Notifications
You must be signed in to change notification settings - Fork 4.1k
/
quotas_rate_limit.go
337 lines (281 loc) · 10.3 KB
/
quotas_rate_limit.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
package quotas
import (
"encoding/hex"
"fmt"
"math"
"strconv"
"sync"
"time"
"github.com/armon/go-metrics"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/metricsutil"
"github.com/hashicorp/vault/sdk/helper/cryptoutil"
"github.com/sethvargo/go-limiter"
"github.com/sethvargo/go-limiter/httplimit"
"github.com/sethvargo/go-limiter/memorystore"
)
// Tunables and environment knobs for rate limit quotas.
const (
	// DefaultRateLimitPurgeInterval is how often, by default, a RateLimitQuota
	// sweeps out client rate limiters that have gone stale.
	DefaultRateLimitPurgeInterval = time.Minute

	// DefaultRateLimitStaleAge is, by default, how long a client limiter may
	// go unused before it counts as stale.
	DefaultRateLimitStaleAge = time.Minute * 3

	// EnvVaultEnableRateLimitAuditLogging names the environment variable that
	// turns on audit logging for requests rejected by a rate limit quota.
	EnvVaultEnableRateLimitAuditLogging = "VAULT_ENABLE_RATE_LIMIT_AUDIT_LOGGING"
)
// Compile-time guarantee that *RateLimitQuota satisfies the Quota interface;
// the build fails here if a required method goes missing.
var (
	_ Quota = (*RateLimitQuota)(nil)
)
// RateLimitQuota is a quota rule that caps how many requests a client may make
// per interval within a namespace or mount. The exported, JSON-tagged fields
// are the rule's configuration; the unexported fields are runtime state that
// initialize sets up.
type RateLimitQuota struct {
	// ID uniquely identifies this quota rule.
	ID string `json:"id"`

	// Type records which kind of quota this rule is.
	Type Type `json:"type"`

	// Name is the quota rule's name.
	Name string `json:"name"`

	// NamespacePath is the namespace path the quota applies to.
	NamespacePath string `json:"namespace_path"`

	// MountPath is the mount path the quota applies to.
	MountPath string `json:"mount_path"`

	// Rate is how many requests are permitted per Interval.
	Rate float64 `json:"rate"`

	// Interval is the time window over which Rate is enforced.
	Interval time.Duration `json:"interval"`

	// BlockInterval, when non-zero, is how long a client that reaches the
	// rate limit has all of its requests blocked.
	BlockInterval time.Duration `json:"block_interval"`

	// lock is allocated by initialize and held while initializing and while
	// reading or updating the purge-goroutine bookkeeping below.
	lock *sync.RWMutex

	// store holds the per-client rate limiters; initialize creates it.
	store limiter.Store

	// logger is the logger handed to initialize.
	logger log.Logger

	// metricSink receives the rate limit violation counter.
	metricSink *metricsutil.ClusterMetricSink

	// purgeInterval is the purge goroutine's tick period and the store's
	// sweep interval.
	purgeInterval time.Duration

	// staleAge is the idle time after which the store treats a client as stale.
	staleAge time.Duration

	// blockedClients maps a client address to the time.Time it was blocked.
	blockedClients sync.Map

	// purgeBlocked reports whether the blocked-client purge goroutine is running.
	purgeBlocked bool

	// closePurgeBlockedCh is closed by close to stop the purge goroutine.
	closePurgeBlockedCh chan struct{}
}
// NewRateLimitQuota builds a RateLimitQuota that allows rate requests per
// interval for clients of the given namespace and mount paths.
//
// A zero interval is accepted here; initialize later replaces it with one
// second. When block is positive, a client that reaches the limit has its
// subsequent requests refused until block has elapsed.
//
// The ID is a freshly generated UUID. If generation fails, the ID is left
// empty and initialize derives one from a hash of name. purgeInterval and
// staleAge start at DefaultRateLimitPurgeInterval and DefaultRateLimitStaleAge.
func NewRateLimitQuota(name, nsPath, mountPath string, rate float64, interval, block time.Duration) *RateLimitQuota {
	quota := &RateLimitQuota{
		Type:          TypeRateLimit,
		Name:          name,
		NamespacePath: nsPath,
		MountPath:     mountPath,
		Rate:          rate,
		Interval:      interval,
		BlockInterval: block,
		purgeInterval: DefaultRateLimitPurgeInterval,
		staleAge:      DefaultRateLimitStaleAge,
	}

	// On failure the ID intentionally stays empty so that initialize can
	// derive a deterministic one from the name.
	if id, err := uuid.GenerateUUID(); err == nil {
		quota.ID = id
	}
	return quota
}
// Clone returns a new RateLimitQuota carrying a copy of rlq's configuration.
// This covers the exported fields plus purgeInterval and staleAge.
//
// Runtime state is deliberately not copied: the lock, rate-limit store,
// blocked clients and purge-goroutine bookkeeping. The clone has a nil lock
// and store, so it must go through initialize before use.
func (rlq *RateLimitQuota) Clone() *RateLimitQuota {
	return &RateLimitQuota{
		ID:            rlq.ID,
		Type:          rlq.Type,
		Name:          rlq.Name,
		NamespacePath: rlq.NamespacePath,
		MountPath:     rlq.MountPath,
		Rate:          rlq.Rate,
		Interval:      rlq.Interval,
		BlockInterval: rlq.BlockInterval,
		// initialize only fills these in when they are zero. Copying them
		// keeps non-default values instead of silently reverting the clone
		// to the package defaults.
		purgeInterval: rlq.purgeInterval,
		staleAge:      rlq.staleAge,
	}
}
// initialize validates the quota's configuration and (re)builds its runtime
// state. Every step after acquiring rlq.lock runs under that lock; the lock
// itself is allocated on first use. The steps are:
//   - NamespacePath defaults to "root" and a zero Interval defaults to one
//     second;
//   - a non-positive Rate or a negative BlockInterval is rejected with an
//     error;
//   - the logger is replaced when a non-nil one is passed, and the metric sink
//     is set only if none was set before;
//   - an empty ID is derived deterministically from a hash of Name;
//   - zero purgeInterval and staleAge values are set to their defaults;
//   - a fresh client rate-limit store replaces any existing one, and the set
//     of blocked clients is emptied, so all per-client state is discarded;
//   - the blocked-client purge goroutine is started if BlockInterval is
//     positive and the goroutine is not already marked as running.
func (rlq *RateLimitQuota) initialize(logger log.Logger, ms *metricsutil.ClusterMetricSink) error {
	if rlq.lock == nil {
		rlq.lock = new(sync.RWMutex)
	}

	rlq.lock.Lock()
	defer rlq.lock.Unlock()

	// Memdb requires a non-empty value for indexing
	if rlq.NamespacePath == "" {
		rlq.NamespacePath = "root"
	}

	if rlq.Interval == 0 {
		rlq.Interval = time.Second
	}

	if rlq.Rate <= 0 {
		return fmt.Errorf("invalid rate: %v", rlq.Rate)
	}

	if rlq.BlockInterval < 0 {
		return fmt.Errorf("invalid block interval: %v", rlq.BlockInterval)
	}

	if logger != nil {
		rlq.logger = logger
	}

	if rlq.metricSink == nil {
		rlq.metricSink = ms
	}

	if rlq.ID == "" {
		// A lease which was created with a blank ID may have been persisted
		// to storage already (this is the case up to release 1.6.2.)
		// So, performance standby nodes could call initialize() on their copy
		// of the lease; for consistency we need to generate an ID that is
		// deterministic. That ensures later invalidation removes the original
		// lease from the memdb, instead of creating a duplicate.
		rlq.ID = hex.EncodeToString(cryptoutil.Blake2b256Hash(rlq.Name))
	}

	// Set purgeInterval if coming from a previous version where purgeInterval was
	// not defined.
	if rlq.purgeInterval == 0 {
		rlq.purgeInterval = DefaultRateLimitPurgeInterval
	}

	// Set staleAge if coming from a previous version where staleAge was not defined.
	if rlq.staleAge == 0 {
		rlq.staleAge = DefaultRateLimitStaleAge
	}

	// NOTE(review): a fractional Rate below 0.5 rounds to 0 tokens here;
	// confirm how memorystore treats a zero token count.
	rlStore, err := memorystore.New(&memorystore.Config{
		Tokens:        uint64(math.Round(rlq.Rate)), // allow 'rlq.Rate' number of requests per 'Interval'
		Interval:      rlq.Interval,                 // time interval in which to enforce rate limiting
		SweepInterval: rlq.purgeInterval,            // how often stale clients are removed
		SweepMinTTL:   rlq.staleAge,                 // how long since the last request a client is considered stale
	})
	if err != nil {
		return err
	}

	// NOTE(review): a store left by an earlier initialize call is replaced
	// without being closed; confirm whether re-initialization happens and,
	// if so, whether the old store should be closed here.
	rlq.store = rlStore

	// Empty the blocked-client set in place rather than assigning a new
	// sync.Map. Assigning would overwrite the field while a purge goroutine
	// started by an earlier call (or a concurrent allow) may be reading it,
	// which is a data race.
	rlq.blockedClients.Range(func(key, _ interface{}) bool {
		rlq.blockedClients.Delete(key)
		return true
	})

	if rlq.BlockInterval > 0 && !rlq.purgeBlocked {
		rlq.purgeBlocked = true
		rlq.closePurgeBlockedCh = make(chan struct{})
		go rlq.purgeBlockedClients()
	}

	return nil
}
// purgeBlockedClients blocks, sweeping the blocked-client set once per
// purgeInterval until closePurgeBlockedCh is closed (or receives a value).
//
// On each tick, every client whose block has lasted at least BlockInterval is
// removed. The block's age is measured from its recorded block time to the
// tick time. When told to stop, it stops the ticker, clears purgeBlocked under
// the write lock, and returns.
func (rlq *RateLimitQuota) purgeBlockedClients() {
	rlq.lock.RLock()
	period := rlq.purgeInterval
	rlq.lock.RUnlock()

	ticker := time.NewTicker(period)

	for {
		select {
		case <-rlq.closePurgeBlockedCh:
			ticker.Stop()

			rlq.lock.Lock()
			rlq.purgeBlocked = false
			rlq.lock.Unlock()
			return

		case now := <-ticker.C:
			rlq.blockedClients.Range(func(client, since interface{}) bool {
				if now.Sub(since.(time.Time)) >= rlq.BlockInterval {
					rlq.blockedClients.Delete(client)
				}
				return true
			})
		}
	}
}
// getPurgeBlocked reports, under the read lock, whether the blocked-client
// purge goroutine is currently marked as running.
func (rlq *RateLimitQuota) getPurgeBlocked() bool {
	rlq.lock.RLock()
	running := rlq.purgeBlocked
	rlq.lock.RUnlock()

	return running
}
// numBlockedClients counts the entries currently in the blocked-client set,
// holding the read lock while it counts. Entries whose block has expired but
// have not yet been purged are included.
func (rlq *RateLimitQuota) numBlockedClients() int {
	rlq.lock.RLock()
	defer rlq.lock.RUnlock()

	var count int
	tally := func(interface{}, interface{}) bool {
		count++
		return true
	}
	rlq.blockedClients.Range(tally)

	return count
}
// quotaID reports the quota rule's identifier (its ID field).
func (rlq *RateLimitQuota) quotaID() (id string) {
	id = rlq.ID
	return id
}
// QuotaName reports the quota rule's name (its Name field).
func (rlq *RateLimitQuota) QuotaName() (name string) {
	name = rlq.Name
	return name
}
// allow decides whether req is admitted by this rate limit quota.
//
// It returns an error if req.ClientAddress is empty. Otherwise:
//   - If the client is in the blocked set and less than BlockInterval has
//     passed since it was blocked, the request is denied without consuming a
//     token. If the block has expired, the client is removed from the set and
//     evaluation continues.
//   - A token is taken from the client's limiter in the store, and the
//     rate-limit limit/remaining/reset headers are filled in from the result.
//   - If no token was available and BlockInterval is positive, the client is
//     added to the blocked set.
//
// For every denied request, a Retry-After header is set and the
// quota.rate_limit.violation counter is incremented. Second counts in the
// headers are rounded up, so a client is never told to retry too early.
func (rlq *RateLimitQuota) allow(req *Request) (Response, error) {
	resp := Response{
		Headers: make(map[string]string),
	}

	if req.ClientAddress == "" {
		return resp, fmt.Errorf("missing request client address in quota request")
	}

	// secondsUntil renders the time left until t as whole seconds, rounding
	// up. Truncating would report "0" while a client is still blocked for
	// part of a second.
	secondsUntil := func(t time.Time) string {
		return strconv.Itoa(int(math.Ceil(time.Until(t).Seconds())))
	}

	var retryAfter string
	defer func() {
		if !resp.Allowed {
			resp.Headers[httplimit.HeaderRetryAfter] = retryAfter
			rlq.metricSink.IncrCounterWithLabels(
				[]string{"quota", "rate_limit", "violation"},
				1,
				[]metrics.Label{{Name: "name", Value: rlq.Name}},
			)
		}
	}()

	// Presence in blockedClients does not prove the client is still blocked.
	// The purge goroutine only runs every purgeInterval, so an expired entry
	// may linger until then. Check the block's age explicitly.
	if v, ok := rlq.blockedClients.Load(req.ClientAddress); ok {
		blockedAt := v.(time.Time)
		if time.Since(blockedAt) < rlq.BlockInterval {
			// Still blocked: deny without consuming a token.
			resp.Allowed = false
			retryAfter = secondsUntil(blockedAt.Add(rlq.BlockInterval))
			return resp, nil
		}
		// The block has expired: forget it and evaluate the request normally.
		rlq.blockedClients.Delete(req.ClientAddress)
	}

	limit, remaining, reset, allowed := rlq.store.Take(req.ClientAddress)
	resp.Allowed = allowed
	resp.Headers[httplimit.HeaderRateLimitLimit] = strconv.FormatUint(limit, 10)
	resp.Headers[httplimit.HeaderRateLimitRemaining] = strconv.FormatUint(remaining, 10)
	resp.Headers[httplimit.HeaderRateLimitReset] = secondsUntil(time.Unix(0, int64(reset)))
	retryAfter = resp.Headers[httplimit.HeaderRateLimitReset]

	// Blocking is enabled exactly when BlockInterval is positive; initialize
	// rejects negative values. Test that field rather than purgeBlocked:
	// purgeBlocked is written by the purge goroutine under rlq.lock, so
	// reading it here without the lock is a data race.
	if !resp.Allowed && rlq.BlockInterval > 0 {
		blockedAt := time.Now()
		retryAfter = secondsUntil(blockedAt.Add(rlq.BlockInterval))
		rlq.blockedClients.Store(req.ClientAddress, blockedAt)
	}

	return resp, nil
}
// close tells the blocked-client purge goroutine to stop, if it is marked as
// running, and then closes the rate-limit store if one exists. It returns the
// store's Close error, or nil when there is no store.
// It should be called with the write lock held.
//
// NOTE(review): purgeBlocked is only reset by the purge goroutine once it can
// take the write lock. A second close call made while the caller still holds
// that lock would therefore close the channel twice; confirm callers never
// do this.
func (rlq *RateLimitQuota) close() error {
	if rlq.purgeBlocked {
		close(rlq.closePurgeBlockedCh)
	}

	if rlq.store == nil {
		return nil
	}
	return rlq.store.Close()
}
// handleRemount sets MountPath to toPath. Presumably toPath is the mount's new
// location after a remount; confirm with callers.
func (rlq *RateLimitQuota) handleRemount(toPath string) {
	newMount := toPath
	rlq.MountPath = newMount
}