diff --git a/README.md b/README.md index 2dadcb0..7336fe8 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,13 @@ A Golang rate limit implementation which allows burst of request during the defined duration. + +### Differences with 'golang.org/x/time/rate#Limiter' + +The original library i.e. `golang.org/x/time/rate` implements the classic **token bucket** algorithm allowing a burst of tokens and a refill that happens at a specified ratio by one unit at a time whereas this implementation is a variant that allows a burst of tokens just like "the token bucket" algorithm, but the refill happens entirely at the defined ratio. + +This allows scanners to respect maximum defined rate limits, pause until the allowed interval hits, and then process again at maximum speed. The original library slowed down requests according to the refill ratio. + ## Example An Example showing usage of ratelimit as a library is specified below: diff --git a/adaptive_ratelimit_test.go b/adaptive_ratelimit_test.go new file mode 100644 index 0000000..0c8bb6d --- /dev/null +++ b/adaptive_ratelimit_test.go @@ -0,0 +1,26 @@ +package ratelimit_test + +import ( + "context" + "testing" + "time" + + "github.com/projectdiscovery/ratelimit" + "github.com/stretchr/testify/require" +) + +func TestAdaptiveRateLimit(t *testing.T) { + limiter := ratelimit.NewUnlimited(context.Background()) + start := time.Now() + + for i := 0; i < 132; i++ { + limiter.Take() + // got 429 / hit ratelimit after 100 + if i == 100 { + // Retry-After and new limiter (calibrate using different strategies) + // new expected ratelimit 30req every 5 sec + limiter.SleepandReset(time.Duration(5)*time.Second, 30, time.Duration(5)*time.Second) + } + } + require.Equal(t, time.Since(start).Round(time.Second), time.Duration(10)*time.Second) +} diff --git a/example/main.go b/example/main.go index 94f739c..92a3520 100644 --- a/example/main.go +++ b/example/main.go @@ -9,7 +9,6 @@ import ( ) func main() { - // create a rate limiter by passing 
context, max tasks/tokens , time interval limiter := ratelimit.New(context.Background(), 5, time.Duration(10*time.Second)) diff --git a/keyratelimit.go b/keyratelimit.go new file mode 100644 index 0000000..b4af9bf --- /dev/null +++ b/keyratelimit.go @@ -0,0 +1,100 @@ +package ratelimit + +import ( + "context" + "fmt" + "time" +) + +// Options of MultiLimiter +type Options struct { + Key string // Unique Identifier + IsUnlimited bool + MaxCount uint + Duration time.Duration +} + +// Validate given MultiLimiter Options +func (o *Options) Validate() error { + if !o.IsUnlimited { + if o.Key == "" { + return fmt.Errorf("empty keys not allowed") + } + if o.MaxCount == 0 { + return fmt.Errorf("maxcount cannot be zero") + } + if o.Duration == 0 { + return fmt.Errorf("time duration not set") + } + } + return nil } + +// MultiLimiter is wrapper around Limiter that can limit based on a key +type MultiLimiter struct { + limiters map[string]*Limiter + ctx context.Context +} + +// Add adds a new bucket with key +func (m *MultiLimiter) Add(opts *Options) error { + if err := opts.Validate(); err != nil { + return err + } + _, ok := m.limiters[opts.Key] + if ok { + return fmt.Errorf("key already exists") + } + var rlimiter *Limiter + if opts.IsUnlimited { + rlimiter = NewUnlimited(m.ctx) + } else { + rlimiter = New(m.ctx, opts.MaxCount, opts.Duration) + } + m.limiters[opts.Key] = rlimiter + return nil +} + +// GetLimit returns current ratelimit of given key +func (m *MultiLimiter) GetLimit(key string) (uint, error) { + limiter, ok := m.limiters[key] + if !ok || limiter == nil { + return 0, fmt.Errorf("key doesnot exist") + } + return limiter.GetLimit(), nil +} + +// Take one token from bucket returns error if key not present +func (m *MultiLimiter) Take(key string) error { + limiter, ok := m.limiters[key] + if !ok || limiter == nil { + return fmt.Errorf("key doesnot exist") + } + limiter.Take() + return nil +} + +// SleepandReset stops timer removes all tokens and resets with new limit 
(used for Adaptive Ratelimiting) +func (m *MultiLimiter) SleepandReset(SleepTime time.Duration, opts *Options) error { + if err := opts.Validate(); err != nil { + return err + } + limiter, ok := m.limiters[opts.Key] + if !ok || limiter == nil { + return fmt.Errorf("key doesnot exist") + } + limiter.SleepandReset(SleepTime, opts.MaxCount, opts.Duration) + return nil +} + +// NewMultiLimiter : Limits +func NewMultiLimiter(ctx context.Context, opts *Options) (*MultiLimiter, error) { + if err := opts.Validate(); err != nil { + return nil, err + } + multilimiter := &MultiLimiter{ + ctx: ctx, + limiters: map[string]*Limiter{}, + } + return multilimiter, multilimiter.Add(opts) +} diff --git a/keyratelimit_test.go b/keyratelimit_test.go new file mode 100644 index 0000000..47d4a44 --- /dev/null +++ b/keyratelimit_test.go @@ -0,0 +1,53 @@ +package ratelimit_test + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/projectdiscovery/ratelimit" + "github.com/stretchr/testify/require" +) + +func TestMultiLimiter(t *testing.T) { + limiter, err := ratelimit.NewMultiLimiter(context.Background(), &ratelimit.Options{ + Key: "default", + IsUnlimited: false, + MaxCount: 100, + Duration: time.Duration(3) * time.Second, + }) + require.Nil(t, err) + wg := &sync.WaitGroup{} + + wg.Add(1) + go func() { + defer wg.Done() + defaultStart := time.Now() + for i := 0; i < 201; i++ { + errx := limiter.Take("default") + require.Nil(t, errx, "failed to take") + } + require.Greater(t, time.Since(defaultStart), time.Duration(6)*time.Second) + }() + + err = limiter.Add(&ratelimit.Options{ + Key: "one", + IsUnlimited: false, + MaxCount: 100, + Duration: time.Duration(3) * time.Second, + }) + require.Nil(t, err) + + wg.Add(1) + go func() { + defer wg.Done() + oneStart := time.Now() + for i := 0; i < 201; i++ { + errx := limiter.Take("one") + require.Nil(t, errx) + } + require.Greater(t, time.Since(oneStart), time.Duration(6)*time.Second) + }() + wg.Wait() +} diff --git a/ratelimit.go 
b/ratelimit.go index 5a67686..f292fd2 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -13,16 +13,21 @@ type Limiter struct { ticker *time.Ticker tokens chan struct{} ctx context.Context + // internal + cancelFunc context.CancelFunc } -func (limiter *Limiter) run() { +func (limiter *Limiter) run(ctx context.Context) { for { if limiter.count == 0 { <-limiter.ticker.C limiter.count = limiter.maxCount } - select { + case <-ctx.Done(): + // Internal Context + limiter.ticker.Stop() + return case <-limiter.ctx.Done(): limiter.ticker.Stop() return @@ -39,30 +44,60 @@ func (rateLimiter *Limiter) Take() { <-rateLimiter.tokens } +// GetLimit returns current rate limit per given duration +func (ratelimiter *Limiter) GetLimit() uint { + return ratelimiter.maxCount +} + +// SleepandReset stops timer removes all tokens and resets with new limit (used for Adaptive Ratelimiting) +func (ratelimiter *Limiter) SleepandReset(sleepTime time.Duration, newLimit uint, duration time.Duration) { + // stop existing Limiter using internalContext + ratelimiter.cancelFunc() + // drain any token + close(ratelimiter.tokens) + <-ratelimiter.tokens + // sleep + time.Sleep(sleepTime) + //reset and start + ratelimiter.maxCount = newLimit + ratelimiter.count = newLimit + ratelimiter.ticker = time.NewTicker(duration) + ratelimiter.tokens = make(chan struct{}) + ctx, cancel := context.WithCancel(context.TODO()) + ratelimiter.cancelFunc = cancel + go ratelimiter.run(ctx) +} + // New creates a new limiter instance with the tokens amount and the interval func New(ctx context.Context, max uint, duration time.Duration) *Limiter { + internalctx, cancel := context.WithCancel(context.TODO()) + limiter := &Limiter{ - maxCount: uint(max), - count: uint(max), - ticker: time.NewTicker(duration), - tokens: make(chan struct{}), - ctx: ctx, + maxCount: uint(max), + count: uint(max), + ticker: time.NewTicker(duration), + tokens: make(chan struct{}), + ctx: ctx, + cancelFunc: cancel, } - go limiter.run() + go 
limiter.run(internalctx) return limiter } // NewUnlimited create a bucket with approximated unlimited tokens func NewUnlimited(ctx context.Context) *Limiter { + internalctx, cancel := context.WithCancel(context.TODO()) + limiter := &Limiter{ - maxCount: math.MaxUint, - count: math.MaxUint, - ticker: time.NewTicker(time.Millisecond), - tokens: make(chan struct{}), - ctx: ctx, + maxCount: math.MaxUint, + count: math.MaxUint, + ticker: time.NewTicker(time.Millisecond), + tokens: make(chan struct{}), + ctx: ctx, + cancelFunc: cancel, } - go limiter.run() + go limiter.run(internalctx) return limiter }