From c2e8e2880a0239f4c9c73d8341c42644887e4f43 Mon Sep 17 00:00:00 2001 From: Tarun Koyalwar Date: Tue, 13 Dec 2022 22:04:08 +0530 Subject: [PATCH 1/2] multilimiter genesis --- example/main.go | 13 +++++++ keyratelimit.go | 89 ++++++++++++++++++++++++++++++++++++++++++++ keyratelimit_test.go | 30 +++++++++++++++ 3 files changed, 132 insertions(+) create mode 100644 keyratelimit.go create mode 100644 keyratelimit_test.go diff --git a/example/main.go b/example/main.go index 94f739c..29aeab2 100644 --- a/example/main.go +++ b/example/main.go @@ -10,6 +10,19 @@ import ( func main() { + fmt.Printf("[+] Complete Tasks Using MulitLimiter with unique key\n") + + multiLimiter := ratelimit.NewMultiLimiter(context.Background()) + multiLimiter.Add("default", 10) + save1 := time.Now() + + for i := 0; i < 11; i++ { + multiLimiter.Take("default") + fmt.Printf("MulitKey Task %v completed after %v\n", i, time.Since(save1)) + } + + fmt.Printf("\n[+] Complete Tasks Using Limiter\n") + // create a rate limiter by passing context, max tasks/tokens , time interval limiter := ratelimit.New(context.Background(), 5, time.Duration(10*time.Second)) diff --git a/keyratelimit.go b/keyratelimit.go new file mode 100644 index 0000000..08974ee --- /dev/null +++ b/keyratelimit.go @@ -0,0 +1,89 @@ +package ratelimit + +import ( + "context" + "fmt" + "sync" + "time" +) + +/* +Note: +This is somewhat modified version of TokenBucket +Here we consider buffer channel as a bucket +*/ + +// MultiLimiter allows burst of request during defined duration for each key +type MultiLimiter struct { + ticker *time.Ticker + tokens sync.Map // map of buffered channels map[string](chan struct{}) + ctx context.Context +} + +func (m *MultiLimiter) run() { + for { + select { + case <-m.ctx.Done(): + m.ticker.Stop() + return + + case <-m.ticker.C: + // Iterate and fill buffers to their capacity on every tick + m.tokens.Range(func(key, value any) bool { + tokenChan := value.(chan struct{}) + if len(tokenChan) == cap(tokenChan) { + // no need to fill buffer/bucket + return true + } else { + for i := 0; i < cap(tokenChan)-len(tokenChan); i++ { + // fill bucket/buffer with tokens + tokenChan <- struct{}{} + } + } + // if it returns false range is stopped + return true + }) + } + } +} + +// Adds new bucket with key and given tokenrate returns error if it already exists1 +func (m *MultiLimiter) Add(key string, tokensPerMinute uint) error { + _, ok := m.tokens.Load(key) + if ok { + return fmt.Errorf("key already exists") + } + // create a buffered channel of size `tokenPerMinute` + tokenChan := make(chan struct{}, tokensPerMinute) + for i := 0; i < int(tokensPerMinute); i++ { + // fill bucket/buffer with tokens + tokenChan <- struct{}{} + } + m.tokens.Store(key, tokenChan) + return nil +} + +// Take one token from bucket / buffer returns error if key not present +func (m *MultiLimiter) Take(key string) error { + tokenValue, ok := m.tokens.Load(key) + if !ok { + return fmt.Errorf("key doesnot exist") + } + tokenChan := tokenValue.(chan struct{}) + <-tokenChan + + return nil +} + +// NewMultiLimiter : Limits +func NewMultiLimiter(ctx context.Context) *MultiLimiter { + multilimiter := &MultiLimiter{ + ticker: time.NewTicker(time.Minute), // different implementation than ratelimit + ctx: ctx, + tokens: sync.Map{}, + } + + go multilimiter.run() + + return multilimiter +} diff --git a/keyratelimit_test.go b/keyratelimit_test.go new file mode 100644 index 0000000..0c503b0 --- /dev/null +++ b/keyratelimit_test.go @@ -0,0 +1,30 @@ +package ratelimit_test + +import ( + "context" + "testing" + "time" + + "github.com/projectdiscovery/ratelimit" + "github.com/stretchr/testify/require" +) + +func TestMultiLimiter(t *testing.T) { + limiter := ratelimit.NewMultiLimiter(context.Background()) + + // 20 tokens every 1 minute + err := limiter.Add("default", 20) + require.Nil(t, err, "failed to add new key") + + before := time.Now() + // take 21 tokens + for i := 0; i < 21; i++ { + err2 := limiter.Take("default") + require.Nil(t, err2, "failed to take") + } + actual := time.Since(before) + expected := time.Duration(time.Minute) + + require.Greater(t, actual, expected) + +} From cb63b7589961c7e1528ed802f99d3d1871a838f0 Mon Sep 17 00:00:00 2001 From: Tarun Koyalwar Date: Thu, 15 Dec 2022 02:07:28 +0530 Subject: [PATCH 2/2] multilimiter,adaptive ratelimit --- README.md | 7 +++ adaptive_ratelimit_test.go | 26 ++++++++ example/main.go | 14 ----- keyratelimit.go | 125 ++++++++++++++++++++----------------- keyratelimit_test.go | 51 ++++++++++----- ratelimit.go | 63 ++++++++++++++----- 6 files changed, 187 insertions(+), 99 deletions(-) create mode 100644 adaptive_ratelimit_test.go diff --git a/README.md b/README.md index 2dadcb0..7336fe8 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,13 @@ A Golang rate limit implementation which allows burst of request during the defined duration. + +### Differences with 'golang.org/x/time/rate#Limiter' + +The original library i.e `golang.org/x/time/rate` implements classic **token bucket** algorithm allowing a burst of tokens and a refill that happens at a specified ratio by one unit at a time whereas this implementation is a variant that allows a burst of tokens just like "the token bucket" algorithm, but the refill happens entirely at the defined ratio. + +This allows scanners to respect maximum defined rate limits, pause until the allowed interval hits, and then process again at maximum speed. The original library slowed down requests according to the refill ratio. + ## Example An Example showing usage of ratelimit as a library is specified below: diff --git a/adaptive_ratelimit_test.go b/adaptive_ratelimit_test.go new file mode 100644 index 0000000..0c8bb6d --- /dev/null +++ b/adaptive_ratelimit_test.go @@ -0,0 +1,26 @@ +package ratelimit_test + +import ( + "context" + "testing" + "time" + + "github.com/projectdiscovery/ratelimit" + "github.com/stretchr/testify/require" +) + +func TestAdaptiveRateLimit(t *testing.T) { + limiter := ratelimit.NewUnlimited(context.Background()) + start := time.Now() + + for i := 0; i < 132; i++ { + limiter.Take() + // got 429 / hit ratelimit after 100 + if i == 100 { + // Retry-After and new limiter (calibrate using different statergies) + // new expected ratelimit 30req every 5 sec + limiter.SleepandReset(time.Duration(5)*time.Second, 30, time.Duration(5)*time.Second) + } + } + require.Equal(t, time.Since(start).Round(time.Second), time.Duration(10)*time.Second) +} diff --git a/example/main.go b/example/main.go index 29aeab2..92a3520 100644 --- a/example/main.go +++ b/example/main.go @@ -9,20 +9,6 @@ import ( ) func main() { - - fmt.Printf("[+] Complete Tasks Using MulitLimiter with unique key\n") - - multiLimiter := ratelimit.NewMultiLimiter(context.Background()) - multiLimiter.Add("default", 10) - save1 := time.Now() - - for i := 0; i < 11; i++ { - multiLimiter.Take("default") - fmt.Printf("MulitKey Task %v completed after %v\n", i, time.Since(save1)) - } - - fmt.Printf("\n[+] Complete Tasks Using Limiter\n") - // create a rate limiter by passing context, max tasks/tokens , time interval limiter := ratelimit.New(context.Background(), 5, time.Duration(10*time.Second)) diff --git a/keyratelimit.go b/keyratelimit.go index 08974ee..b4af9bf 100644 --- a/keyratelimit.go +++ b/keyratelimit.go @@ -3,87 +3,98 @@ package ratelimit import ( "context" "fmt" - "sync" "time" ) -/* -Note: -This is somewhat modified version of TokenBucket -Here we consider buffer channel as a bucket -*/ - -// MultiLimiter allows burst of request during defined duration for each key -type MultiLimiter struct { - ticker *time.Ticker - tokens sync.Map // map of buffered channels map[string](chan struct{}) - ctx context.Context +// Options of MultiLimiter +type Options struct { + Key string // Unique Identifier + IsUnlimited bool + MaxCount uint + Duration time.Duration } -func (m *MultiLimiter) run() { - for { - select { - case <-m.ctx.Done(): - m.ticker.Stop() - return - - case <-m.ticker.C: - // Iterate and fill buffers to their capacity on every tick - m.tokens.Range(func(key, value any) bool { - tokenChan := value.(chan struct{}) - if len(tokenChan) == cap(tokenChan) { - // no need to fill buffer/bucket - return true - } else { - for i := 0; i < cap(tokenChan)-len(tokenChan); i++ { - // fill bucket/buffer with tokens - tokenChan <- struct{}{} - } - } - // if it returns false range is stopped - return true - }) +// Validate given MultiLimiter Options +func (o *Options) Validate() error { + if !o.IsUnlimited { + if o.Key == "" { + return fmt.Errorf("empty keys not allowed") + } + if o.MaxCount == 0 { + return fmt.Errorf("maxcount cannot be zero") + } + if o.Duration == 0 { + return fmt.Errorf("time duration not set") } } + return nil } -// Adds new bucket with key and given tokenrate returns error if it already exists1 -func (m *MultiLimiter) Add(key string, tokensPerMinute uint) error { - _, ok := m.tokens.Load(key) +// MultiLimiter is wrapper around Limiter than can limit based on a key +type MultiLimiter struct { + limiters map[string]*Limiter + ctx context.Context +} + +// Adds new bucket with key +func (m *MultiLimiter) Add(opts *Options) error { + if err := opts.Validate(); err != nil { + return err + } + _, ok := m.limiters[opts.Key] if ok { return fmt.Errorf("key already exists") } - // create a buffered channel of size `tokenPerMinute` - tokenChan := make(chan struct{}, tokensPerMinute) - for i := 0; i < int(tokensPerMinute); i++ { - // fill bucket/buffer with tokens - tokenChan <- struct{}{} + var rlimiter *Limiter + if opts.IsUnlimited { + rlimiter = NewUnlimited(m.ctx) + } else { + rlimiter = New(m.ctx, opts.MaxCount, opts.Duration) } - m.tokens.Store(key, tokenChan) + m.limiters[opts.Key] = rlimiter return nil } -// Take one token from bucket / buffer returns error if key not present +// GetLimit returns current ratelimit of given key +func (m *MultiLimiter) GetLimit(key string) (uint, error) { + limiter, ok := m.limiters[key] + if !ok || limiter == nil { + return 0, fmt.Errorf("key doesnot exist") + } + return limiter.GetLimit(), nil +} + +// Take one token from bucket returns error if key not present func (m *MultiLimiter) Take(key string) error { - tokenValue, ok := m.tokens.Load(key) - if !ok { + limiter, ok := m.limiters[key] + if !ok || limiter == nil { return fmt.Errorf("key doesnot exist") } - tokenChan := tokenValue.(chan struct{}) - <-tokenChan + limiter.Take() + return nil +} +// SleepandReset stops timer removes all tokens and resets with new limit (used for Adaptive Ratelimiting) +func (m *MultiLimiter) SleepandReset(SleepTime time.Duration, opts *Options) error { + if err := opts.Validate(); err != nil { + return err + } + limiter, ok := m.limiters[opts.Key] + if !ok || limiter == nil { + return fmt.Errorf("key doesnot exist") + } + limiter.SleepandReset(SleepTime, opts.MaxCount, opts.Duration) return nil } // NewMultiLimiter : Limits -func NewMultiLimiter(ctx context.Context) *MultiLimiter { +func NewMultiLimiter(ctx context.Context, opts *Options) (*MultiLimiter, error) { + if err := opts.Validate(); err != nil { + return nil, err + } multilimiter := &MultiLimiter{ - ticker: time.NewTicker(time.Minute), // different implementation than ratelimit - ctx: ctx, - tokens: sync.Map{}, + ctx: ctx, + limiters: map[string]*Limiter{}, } - - go multilimiter.run() - - return multilimiter + return multilimiter, multilimiter.Add(opts) } diff --git a/keyratelimit_test.go b/keyratelimit_test.go index 0c503b0..47d4a44 100644 --- a/keyratelimit_test.go +++ b/keyratelimit_test.go @@ -2,6 +2,7 @@ package ratelimit_test import ( "context" + "sync" "testing" "time" @@ -10,21 +11,43 @@ import ( ) func TestMultiLimiter(t *testing.T) { - limiter := ratelimit.NewMultiLimiter(context.Background()) + limiter, err := ratelimit.NewMultiLimiter(context.Background(), &ratelimit.Options{ + Key: "default", + IsUnlimited: false, + MaxCount: 100, + Duration: time.Duration(3) * time.Second, + }) + require.Nil(t, err) + wg := &sync.WaitGroup{} - // 20 tokens every 1 minute - err := limiter.Add("default", 20) - require.Nil(t, err, "failed to add new key") + wg.Add(1) + go func() { + defer wg.Done() + defaultStart := time.Now() + for i := 0; i < 201; i++ { + errx := limiter.Take("default") + require.Nil(t, errx, "failed to take") + } + require.Greater(t, time.Since(defaultStart), time.Duration(6)*time.Second) + }() - before := time.Now() - // take 21 tokens - for i := 0; i < 21; i++ { - err2 := limiter.Take("default") - require.Nil(t, err2, "failed to take") - } - actual := time.Since(before) - expected := time.Duration(time.Minute) - - require.Greater(t, actual, expected) + err = limiter.Add(&ratelimit.Options{ + Key: "one", + IsUnlimited: false, + MaxCount: 100, + Duration: time.Duration(3) * time.Second, + }) + require.Nil(t, err) + wg.Add(1) + go func() { + defer wg.Done() + oneStart := time.Now() + for i := 0; i < 201; i++ { + errx := limiter.Take("one") + require.Nil(t, errx) + } + require.Greater(t, time.Since(oneStart), time.Duration(6)*time.Second) + }() + wg.Wait() } diff --git a/ratelimit.go b/ratelimit.go index 5a67686..f292fd2 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -13,16 +13,21 @@ type Limiter struct { ticker *time.Ticker tokens chan struct{} ctx context.Context + // internal + cancelFunc context.CancelFunc } -func (limiter *Limiter) run() { +func (limiter *Limiter) run(ctx context.Context) { for { if limiter.count == 0 { <-limiter.ticker.C limiter.count = limiter.maxCount } - select { + case <-ctx.Done(): + // Internal Context + limiter.ticker.Stop() + return case <-limiter.ctx.Done(): limiter.ticker.Stop() return @@ -39,30 +44,60 @@ func (rateLimiter *Limiter) Take() { <-rateLimiter.tokens } +// GetLimit returns current rate limit per given duration +func (ratelimiter *Limiter) GetLimit() uint { + return ratelimiter.maxCount +} + +// SleepandReset stops timer removes all tokens and resets with new limit (used for Adaptive Ratelimiting) +func (ratelimiter *Limiter) SleepandReset(sleepTime time.Duration, newLimit uint, duration time.Duration) { + // stop existing Limiter using internalContext + ratelimiter.cancelFunc() + // drain any token + close(ratelimiter.tokens) + <-ratelimiter.tokens + // sleep + time.Sleep(sleepTime) + //reset and start + ratelimiter.maxCount = newLimit + ratelimiter.count = newLimit + ratelimiter.ticker = time.NewTicker(duration) + ratelimiter.tokens = make(chan struct{}) + ctx, cancel := context.WithCancel(context.TODO()) + ratelimiter.cancelFunc = cancel + go ratelimiter.run(ctx) +} + // New creates a new limiter instance with the tokens amount and the interval func New(ctx context.Context, max uint, duration time.Duration) *Limiter { + internalctx, cancel := context.WithCancel(context.TODO()) + limiter := &Limiter{ - maxCount: uint(max), - count: uint(max), - ticker: time.NewTicker(duration), - tokens: make(chan struct{}), - ctx: ctx, + maxCount: uint(max), + count: uint(max), + ticker: time.NewTicker(duration), + tokens: make(chan struct{}), + ctx: ctx, + cancelFunc: cancel, } - go limiter.run() + go limiter.run(internalctx) return limiter } // NewUnlimited create a bucket with approximated unlimited tokens func NewUnlimited(ctx context.Context) *Limiter { + internalctx, cancel := context.WithCancel(context.TODO()) + limiter := &Limiter{ - maxCount: math.MaxUint, - count: math.MaxUint, - ticker: time.NewTicker(time.Millisecond), - tokens: make(chan struct{}), - ctx: ctx, + maxCount: math.MaxUint, + count: math.MaxUint, + ticker: time.NewTicker(time.Millisecond), + tokens: make(chan struct{}), + ctx: ctx, + cancelFunc: cancel, } - go limiter.run() + go limiter.run(internalctx) return limiter }