Skip to content

Commit

Permalink
adding strategy support
Browse files Browse the repository at this point in the history
  • Loading branch information
Mzack9999 committed Apr 16, 2024
1 parent f65d2d5 commit a5ff5fc
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 5 deletions.
1 change: 1 addition & 0 deletions go.mod
Expand Up @@ -5,6 +5,7 @@ go 1.21
require (
github.com/projectdiscovery/utils v0.0.89
github.com/stretchr/testify v1.9.0
golang.org/x/time v0.5.0
)

require (
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Expand Up @@ -6,6 +6,8 @@ github.com/projectdiscovery/utils v0.0.89 h1:ruH2bSkpX/rB7EPp2EV/rWyAubQVxCVU38n
github.com/projectdiscovery/utils v0.0.89/go.mod h1:Dwh5cxn7y97jvyYG3GmBvj0negfH9IjH15qXnzFNtOI=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down
55 changes: 50 additions & 5 deletions ratelimit.go
Expand Up @@ -5,20 +5,27 @@ import (
"math"
"sync/atomic"
"time"

"golang.org/x/time/rate"
)

// equals to -1
var minusOne = ^uint32(0)

// Limiter allows a burst of request during the defined duration
type Limiter struct {
strategy Strategy
maxCount atomic.Uint32
interval time.Duration
count atomic.Uint32
ticker *time.Ticker
tokens chan struct{}
ctx context.Context
// internal
cancelFunc context.CancelFunc

// wraps uber's leaky bucket limiter sizing it to the desired tokens per duration
leakyBucketLimiter *rate.Limiter
}

func (limiter *Limiter) run(ctx context.Context) {
Expand Down Expand Up @@ -46,12 +53,22 @@ func (limiter *Limiter) run(ctx context.Context) {

// Take one token from the bucket
func (limiter *Limiter) Take() {
<-limiter.tokens
switch limiter.strategy {
case LeakyBucket:
limiter.leakyBucketLimiter.Wait(context.TODO())

Check failure on line 58 in ratelimit.go

View workflow job for this annotation

GitHub Actions / Lint Test

Error return value of `limiter.leakyBucketLimiter.Wait` is not checked (errcheck)
default:
<-limiter.tokens
}
}

// CanTake checks if the rate limiter has any token
func (limiter *Limiter) CanTake() bool {
return limiter.count.Load() > 0
switch limiter.strategy {
case LeakyBucket:
return limiter.leakyBucketLimiter.Tokens() > 0
default:
return limiter.count.Load() > 0
}
}

// GetLimit returns current rate limit per given duration
Expand All @@ -62,17 +79,32 @@ func (limiter *Limiter) GetLimit() uint {
// GetLimit returns current rate limit per given duration
func (limiter *Limiter) SetLimit(max uint) {
limiter.maxCount.Store(uint32(max))
switch limiter.strategy {
case LeakyBucket:
limiter.leakyBucketLimiter.SetBurst(int(max))
default:
}
}

// GetLimit returns current rate limit per given duration
func (limiter *Limiter) SetDuration(d time.Duration) {
limiter.ticker.Reset(d)
limiter.interval = d
switch limiter.strategy {
case LeakyBucket:
limiter.leakyBucketLimiter.SetLimit(rate.Every(d))
default:
limiter.ticker.Reset(d)
}
}

// Stop the rate limiter canceling the internal context
func (limiter *Limiter) Stop() {
if limiter.cancelFunc != nil {
limiter.cancelFunc()
switch limiter.strategy {
case LeakyBucket: // NOP
default:
if limiter.cancelFunc != nil {
limiter.cancelFunc()
}
}
}

Expand All @@ -87,6 +119,8 @@ func New(ctx context.Context, max uint, duration time.Duration) *Limiter {
tokens: make(chan struct{}),
ctx: ctx,
cancelFunc: cancel,
strategy: None,
interval: duration,
}
limiter.maxCount.Store(uint32(max))
limiter.count.Store(uint32(max))
Expand All @@ -110,3 +144,14 @@ func NewUnlimited(ctx context.Context) *Limiter {

return limiter
}

// NewUnlimited create a bucket with approximated unlimited tokens
func NewLeakyBucket(ctx context.Context, max uint, duration time.Duration) *Limiter {
limiter := &Limiter{
strategy: LeakyBucket,
leakyBucketLimiter: rate.NewLimiter(rate.Every(duration), int(max)),
}
limiter.maxCount.Store(uint32(max))
limiter.interval = duration
return limiter
}
13 changes: 13 additions & 0 deletions ratelimit_test.go
Expand Up @@ -95,4 +95,17 @@ func TestRateLimit(t *testing.T) {
limiter.Take()
require.False(t, limiter.CanTake())
})

t.Run("LeakyBucket", func(t *testing.T) {
limiter := NewLeakyBucket(context.TODO(), 1, time.Second)

start := time.Now()
limiter.Take() // 0
limiter.Take() // 1s
limiter.Take() // 2s
limiter.Take() // 3s
took := time.Since(start)
expected := 3 * time.Second
require.True(t, took >= expected)
})
}
8 changes: 8 additions & 0 deletions strategy.go
@@ -0,0 +1,8 @@
package ratelimit

type Strategy uint8

const (
None Strategy = iota
LeakyBucket
)

0 comments on commit a5ff5fc

Please sign in to comment.