Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds multilimiter + adaptive ratelimiter #18

Merged
merged 2 commits into from Dec 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Expand Up @@ -8,6 +8,13 @@

A Golang rate limit implementation which allows burst of request during the defined duration.


### Differences with 'golang.org/x/time/rate#Limiter'

The original library, i.e. `golang.org/x/time/rate`, implements the classic **token bucket** algorithm: it allows a burst of tokens and then refills them one unit at a time at the specified rate. This implementation is a variant that allows a burst of tokens just like the token bucket algorithm, but the refill happens all at once after the defined interval.

This allows scanners to respect maximum defined rate limits, pause until the allowed interval hits, and then process again at maximum speed. The original library slowed down requests according to the refill ratio.

## Example

An Example showing usage of ratelimit as a library is specified below:
Expand Down
26 changes: 26 additions & 0 deletions adaptive_ratelimit_test.go
@@ -0,0 +1,26 @@
package ratelimit_test

import (
"context"
"testing"
"time"

"github.com/projectdiscovery/ratelimit"
"github.com/stretchr/testify/require"
)

// TestAdaptiveRateLimit simulates hitting a server-side rate limit (429)
// mid-run and recalibrating the limiter: sleep for the Retry-After period,
// then continue at the newly discovered limit of 30 requests per 5 seconds.
// Total elapsed time is expected to round to 10s (5s sleep + ~5s throttled).
func TestAdaptiveRateLimit(t *testing.T) {
	limiter := ratelimit.NewUnlimited(context.Background())
	start := time.Now()

	for i := 0; i < 132; i++ {
		limiter.Take()
		// got 429 / hit ratelimit after 100 requests
		if i == 100 {
			// Retry-After and new limiter (calibrate using different strategies)
			// new expected ratelimit: 30 requests every 5 seconds
			limiter.SleepandReset(5*time.Second, 30, 5*time.Second)
		}
	}
	// testify convention: expected value first, actual second
	require.Equal(t, 10*time.Second, time.Since(start).Round(time.Second))
}
1 change: 0 additions & 1 deletion example/main.go
Expand Up @@ -9,7 +9,6 @@ import (
)

func main() {

// create a rate limiter by passing context, max tasks/tokens , time interval
limiter := ratelimit.New(context.Background(), 5, time.Duration(10*time.Second))

Expand Down
100 changes: 100 additions & 0 deletions keyratelimit.go
@@ -0,0 +1,100 @@
package ratelimit

import (
"context"
"fmt"
"time"
)

// Options of MultiLimiter
type Options struct {
	Key string // Unique Identifier
	// IsUnlimited, when true, skips validation of the remaining fields and
	// creates an (approximately) unlimited bucket for Key.
	IsUnlimited bool
	// MaxCount is the number of tokens granted per Duration window.
	MaxCount uint
	// Duration is the refill interval of the bucket.
	Duration time.Duration
}

// Validate checks that the given MultiLimiter Options are usable.
// Unlimited buckets are always valid; limited ones need a non-empty
// key, a non-zero token count, and a non-zero refill interval.
func (o *Options) Validate() error {
	if o.IsUnlimited {
		return nil
	}
	switch {
	case o.Key == "":
		return fmt.Errorf("empty keys not allowed")
	case o.MaxCount == 0:
		return fmt.Errorf("maxcount cannot be zero")
	case o.Duration == 0:
		return fmt.Errorf("time duration not set")
	}
	return nil
}

// MultiLimiter is a wrapper around Limiter that can limit based on a key.
type MultiLimiter struct {
	// limiters holds one bucket per key.
	// NOTE(review): a plain map with no mutex — concurrent Add while another
	// goroutine calls Take/GetLimit is a data race; confirm callers serialize Add.
	limiters map[string]*Limiter
	ctx      context.Context
}

// Add creates and tracks a new rate-limit bucket for the key in opts.
// It returns an error if the options are invalid or the key already exists.
func (m *MultiLimiter) Add(opts *Options) error {
	if err := opts.Validate(); err != nil {
		return err
	}
	// idiomatic single-statement existence check
	if _, ok := m.limiters[opts.Key]; ok {
		return fmt.Errorf("key already exists")
	}
	var rlimiter *Limiter
	if opts.IsUnlimited {
		rlimiter = NewUnlimited(m.ctx)
	} else {
		rlimiter = New(m.ctx, opts.MaxCount, opts.Duration)
	}
	m.limiters[opts.Key] = rlimiter
	return nil
}

// GetLimit returns the current rate limit of the given key.
// It returns an error if the key was never added.
func (m *MultiLimiter) GetLimit(key string) (uint, error) {
	limiter, ok := m.limiters[key]
	if !ok || limiter == nil {
		return 0, fmt.Errorf("key does not exist")
	}
	return limiter.GetLimit(), nil
}

// Take removes one token from the bucket of the given key, blocking until
// a token is available. It returns an error if the key is not present.
func (m *MultiLimiter) Take(key string) error {
	limiter, ok := m.limiters[key]
	if !ok || limiter == nil {
		return fmt.Errorf("key does not exist")
	}
	limiter.Take()
	return nil
}

// SleepandReset stops the limiter of the key in opts, sleeps for sleepTime,
// and restarts it with the new limit from opts (used for adaptive rate limiting).
// It returns an error if the options are invalid or the key does not exist.
func (m *MultiLimiter) SleepandReset(sleepTime time.Duration, opts *Options) error {
	if err := opts.Validate(); err != nil {
		return err
	}
	limiter, ok := m.limiters[opts.Key]
	if !ok || limiter == nil {
		return fmt.Errorf("key does not exist")
	}
	limiter.SleepandReset(sleepTime, opts.MaxCount, opts.Duration)
	return nil
}

// NewMultiLimiter creates a keyed rate limiter and registers an initial
// bucket described by opts. It returns an error if opts are invalid.
func NewMultiLimiter(ctx context.Context, opts *Options) (*MultiLimiter, error) {
	if err := opts.Validate(); err != nil {
		return nil, err
	}
	multilimiter := &MultiLimiter{
		ctx:      ctx,
		limiters: map[string]*Limiter{},
	}
	return multilimiter, multilimiter.Add(opts)
}
53 changes: 53 additions & 0 deletions keyratelimit_test.go
@@ -0,0 +1,53 @@
package ratelimit_test

import (
"context"
"sync"
"testing"
"time"

"github.com/projectdiscovery/ratelimit"
"github.com/stretchr/testify/require"
)

// TestMultiLimiter verifies that two independently keyed buckets each
// enforce their own limit: 100 tokens per 3s means 201 Takes need > 6s.
func TestMultiLimiter(t *testing.T) {
	limiter, err := ratelimit.NewMultiLimiter(context.Background(), &ratelimit.Options{
		Key:         "default",
		IsUnlimited: false,
		MaxCount:    100,
		Duration:    time.Duration(3) * time.Second,
	})
	require.Nil(t, err)

	// Register the second bucket BEFORE starting any goroutine that calls
	// Take: MultiLimiter's map is not synchronized, so a concurrent
	// Add/Take would be a data race.
	err = limiter.Add(&ratelimit.Options{
		Key:         "one",
		IsUnlimited: false,
		MaxCount:    100,
		Duration:    time.Duration(3) * time.Second,
	})
	require.Nil(t, err)

	wg := &sync.WaitGroup{}
	for _, key := range []string{"default", "one"} {
		key := key // capture loop variable for the goroutine
		wg.Add(1)
		go func() {
			defer wg.Done()
			start := time.Now()
			for i := 0; i < 201; i++ {
				errx := limiter.Take(key)
				require.Nil(t, errx, "failed to take")
			}
			require.Greater(t, time.Since(start), time.Duration(6)*time.Second)
		}()
	}
	wg.Wait()
}
63 changes: 49 additions & 14 deletions ratelimit.go
Expand Up @@ -13,16 +13,21 @@ type Limiter struct {
ticker *time.Ticker
tokens chan struct{}
ctx context.Context
// internal
cancelFunc context.CancelFunc
}

func (limiter *Limiter) run() {
func (limiter *Limiter) run(ctx context.Context) {
for {
if limiter.count == 0 {
<-limiter.ticker.C
limiter.count = limiter.maxCount
}

select {
case <-ctx.Done():
// Internal Context
limiter.ticker.Stop()
return
case <-limiter.ctx.Done():
limiter.ticker.Stop()
return
Expand All @@ -39,30 +44,60 @@ func (rateLimiter *Limiter) Take() {
<-rateLimiter.tokens
}

// GetLimit reports how many tokens this limiter grants per refill interval.
func (limiter *Limiter) GetLimit() uint {
	return limiter.maxCount
}

// SleepandReset stops timer removes all tokens and resets with new limit (used for Adaptive Ratelimiting)
func (ratelimiter *Limiter) SleepandReset(sleepTime time.Duration, newLimit uint, duration time.Duration) {
	// stop existing Limiter using internalContext
	ratelimiter.cancelFunc()
	// drain any token
	// NOTE(review): closing tokens while run() may still be attempting a send
	// risks a "send on closed channel" panic unless run() observes the
	// cancellation first — confirm the ordering guarantee here.
	close(ratelimiter.tokens)
	<-ratelimiter.tokens
	// sleep for the given back-off period (e.g. a Retry-After value)
	time.Sleep(sleepTime)
	// reset counters/ticker/channel and start a fresh refill goroutine
	// NOTE(review): these fields are written without synchronization; callers
	// must not invoke Take concurrently with SleepandReset — verify callers.
	ratelimiter.maxCount = newLimit
	ratelimiter.count = newLimit
	ratelimiter.ticker = time.NewTicker(duration)
	ratelimiter.tokens = make(chan struct{})
	ctx, cancel := context.WithCancel(context.TODO())
	ratelimiter.cancelFunc = cancel
	go ratelimiter.run(ctx)
}

// New creates a new limiter instance granting max tokens per duration.
// The supplied ctx stops the limiter permanently; an internal context is
// used so SleepandReset can stop and restart the refill goroutine.
func New(ctx context.Context, max uint, duration time.Duration) *Limiter {
	internalctx, cancel := context.WithCancel(context.TODO())

	limiter := &Limiter{
		// max is already uint — no conversion needed
		maxCount:   max,
		count:      max,
		ticker:     time.NewTicker(duration),
		tokens:     make(chan struct{}),
		ctx:        ctx,
		cancelFunc: cancel,
	}
	go limiter.run(internalctx)

	return limiter
}

// NewUnlimited creates a bucket with an approximately unlimited number of
// tokens (math.MaxUint refilled every millisecond).
func NewUnlimited(ctx context.Context) *Limiter {
	refillCtx, stopRefill := context.WithCancel(context.TODO())

	unlimited := &Limiter{
		maxCount:   math.MaxUint,
		count:      math.MaxUint,
		ticker:     time.NewTicker(time.Millisecond),
		tokens:     make(chan struct{}),
		ctx:        ctx,
		cancelFunc: stopRefill,
	}
	go unlimited.run(refillCtx)

	return unlimited
}