Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding strategy support #89

Merged
merged 2 commits into from May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Expand Up @@ -5,6 +5,7 @@ go 1.21
require (
github.com/projectdiscovery/utils v0.0.89
github.com/stretchr/testify v1.9.0
golang.org/x/time v0.5.0
)

require (
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Expand Up @@ -6,6 +6,8 @@ github.com/projectdiscovery/utils v0.0.89 h1:ruH2bSkpX/rB7EPp2EV/rWyAubQVxCVU38n
github.com/projectdiscovery/utils v0.0.89/go.mod h1:Dwh5cxn7y97jvyYG3GmBvj0negfH9IjH15qXnzFNtOI=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down
55 changes: 50 additions & 5 deletions ratelimit.go
Expand Up @@ -5,20 +5,27 @@ import (
"math"
"sync/atomic"
"time"

"golang.org/x/time/rate"
)

// equals to -1
var minusOne = ^uint32(0)

// Limiter allows a burst of request during the defined duration
type Limiter struct {
strategy Strategy
maxCount atomic.Uint32
interval time.Duration
count atomic.Uint32
ticker *time.Ticker
tokens chan struct{}
ctx context.Context
// internal
cancelFunc context.CancelFunc

// wraps uber's leaky bucket limiter sizing it to the desired tokens per duration
leakyBucketLimiter *rate.Limiter
}

func (limiter *Limiter) run(ctx context.Context) {
Expand Down Expand Up @@ -46,12 +53,22 @@ func (limiter *Limiter) run(ctx context.Context) {

// Take one token from the bucket
func (limiter *Limiter) Take() {
<-limiter.tokens
switch limiter.strategy {
case LeakyBucket:
_ = limiter.leakyBucketLimiter.Wait(context.TODO())
default:
<-limiter.tokens
}
}

// CanTake checks if the rate limiter has any token
func (limiter *Limiter) CanTake() bool {
return limiter.count.Load() > 0
switch limiter.strategy {
case LeakyBucket:
return limiter.leakyBucketLimiter.Tokens() > 0
default:
return limiter.count.Load() > 0
}
}

// GetLimit returns current rate limit per given duration
Expand All @@ -62,17 +79,32 @@ func (limiter *Limiter) GetLimit() uint {
// GetLimit returns current rate limit per given duration
func (limiter *Limiter) SetLimit(max uint) {
limiter.maxCount.Store(uint32(max))
switch limiter.strategy {
case LeakyBucket:
limiter.leakyBucketLimiter.SetBurst(int(max))
default:
}
}

// GetLimit returns current rate limit per given duration
func (limiter *Limiter) SetDuration(d time.Duration) {
limiter.ticker.Reset(d)
limiter.interval = d
switch limiter.strategy {
case LeakyBucket:
limiter.leakyBucketLimiter.SetLimit(rate.Every(d))
default:
limiter.ticker.Reset(d)
}
}

// Stop the rate limiter canceling the internal context
func (limiter *Limiter) Stop() {
if limiter.cancelFunc != nil {
limiter.cancelFunc()
switch limiter.strategy {
case LeakyBucket: // NOP
default:
if limiter.cancelFunc != nil {
limiter.cancelFunc()
}
}
}

Expand All @@ -87,6 +119,8 @@ func New(ctx context.Context, max uint, duration time.Duration) *Limiter {
tokens: make(chan struct{}),
ctx: ctx,
cancelFunc: cancel,
strategy: None,
interval: duration,
}
limiter.maxCount.Store(uint32(max))
limiter.count.Store(uint32(max))
Expand All @@ -110,3 +144,14 @@ func NewUnlimited(ctx context.Context) *Limiter {

return limiter
}

// NewUnlimited create a bucket with approximated unlimited tokens
func NewLeakyBucket(ctx context.Context, max uint, duration time.Duration) *Limiter {
limiter := &Limiter{
strategy: LeakyBucket,
leakyBucketLimiter: rate.NewLimiter(rate.Every(duration), int(max)),
}
limiter.maxCount.Store(uint32(max))
limiter.interval = duration
return limiter
}
13 changes: 13 additions & 0 deletions ratelimit_test.go
Expand Up @@ -95,4 +95,17 @@ func TestRateLimit(t *testing.T) {
limiter.Take()
require.False(t, limiter.CanTake())
})

t.Run("LeakyBucket", func(t *testing.T) {
limiter := NewLeakyBucket(context.TODO(), 1, time.Second)

start := time.Now()
limiter.Take() // 0
limiter.Take() // 1s
limiter.Take() // 2s
limiter.Take() // 3s
took := time.Since(start)
expected := 3 * time.Second
require.True(t, took >= expected)
})
}
8 changes: 8 additions & 0 deletions strategy.go
@@ -0,0 +1,8 @@
package ratelimit

type Strategy uint8

const (
None Strategy = iota
LeakyBucket
)