Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #18 from projectdiscovery/issue-17-key-ratelimiter
adds mulitlimiter + adaptive ratelimiter
- Loading branch information
Showing
6 changed files
with
235 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
package ratelimit_test | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
"time" | ||
|
||
"github.com/projectdiscovery/ratelimit" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestAdaptiveRateLimit(t *testing.T) { | ||
limiter := ratelimit.NewUnlimited(context.Background()) | ||
start := time.Now() | ||
|
||
for i := 0; i < 132; i++ { | ||
limiter.Take() | ||
// got 429 / hit ratelimit after 100 | ||
if i == 100 { | ||
// Retry-After and new limiter (calibrate using different statergies) | ||
// new expected ratelimit 30req every 5 sec | ||
limiter.SleepandReset(time.Duration(5)*time.Second, 30, time.Duration(5)*time.Second) | ||
} | ||
} | ||
require.Equal(t, time.Since(start).Round(time.Second), time.Duration(10)*time.Second) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
package ratelimit | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"time" | ||
) | ||
|
||
// Options of MultiLimiter | ||
type Options struct { | ||
Key string // Unique Identifier | ||
IsUnlimited bool | ||
MaxCount uint | ||
Duration time.Duration | ||
} | ||
|
||
// Validate given MultiLimiter Options | ||
func (o *Options) Validate() error { | ||
if !o.IsUnlimited { | ||
if o.Key == "" { | ||
return fmt.Errorf("empty keys not allowed") | ||
} | ||
if o.MaxCount == 0 { | ||
return fmt.Errorf("maxcount cannot be zero") | ||
} | ||
if o.Duration == 0 { | ||
return fmt.Errorf("time duration not set") | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
// MultiLimiter is wrapper around Limiter than can limit based on a key | ||
type MultiLimiter struct { | ||
limiters map[string]*Limiter | ||
ctx context.Context | ||
} | ||
|
||
// Adds new bucket with key | ||
func (m *MultiLimiter) Add(opts *Options) error { | ||
if err := opts.Validate(); err != nil { | ||
return err | ||
} | ||
_, ok := m.limiters[opts.Key] | ||
if ok { | ||
return fmt.Errorf("key already exists") | ||
} | ||
var rlimiter *Limiter | ||
if opts.IsUnlimited { | ||
rlimiter = NewUnlimited(m.ctx) | ||
} else { | ||
rlimiter = New(m.ctx, opts.MaxCount, opts.Duration) | ||
} | ||
m.limiters[opts.Key] = rlimiter | ||
return nil | ||
} | ||
|
||
// GetLimit returns current ratelimit of given key | ||
func (m *MultiLimiter) GetLimit(key string) (uint, error) { | ||
limiter, ok := m.limiters[key] | ||
if !ok || limiter == nil { | ||
return 0, fmt.Errorf("key doesnot exist") | ||
} | ||
return limiter.GetLimit(), nil | ||
} | ||
|
||
// Take one token from bucket returns error if key not present | ||
func (m *MultiLimiter) Take(key string) error { | ||
limiter, ok := m.limiters[key] | ||
if !ok || limiter == nil { | ||
return fmt.Errorf("key doesnot exist") | ||
} | ||
limiter.Take() | ||
return nil | ||
} | ||
|
||
// SleepandReset stops timer removes all tokens and resets with new limit (used for Adaptive Ratelimiting) | ||
func (m *MultiLimiter) SleepandReset(SleepTime time.Duration, opts *Options) error { | ||
if err := opts.Validate(); err != nil { | ||
return err | ||
} | ||
limiter, ok := m.limiters[opts.Key] | ||
if !ok || limiter == nil { | ||
return fmt.Errorf("key doesnot exist") | ||
} | ||
limiter.SleepandReset(SleepTime, opts.MaxCount, opts.Duration) | ||
return nil | ||
} | ||
|
||
// NewMultiLimiter : Limits | ||
func NewMultiLimiter(ctx context.Context, opts *Options) (*MultiLimiter, error) { | ||
if err := opts.Validate(); err != nil { | ||
return nil, err | ||
} | ||
multilimiter := &MultiLimiter{ | ||
ctx: ctx, | ||
limiters: map[string]*Limiter{}, | ||
} | ||
return multilimiter, multilimiter.Add(opts) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
package ratelimit_test | ||
|
||
import ( | ||
"context" | ||
"sync" | ||
"testing" | ||
"time" | ||
|
||
"github.com/projectdiscovery/ratelimit" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestMultiLimiter(t *testing.T) { | ||
limiter, err := ratelimit.NewMultiLimiter(context.Background(), &ratelimit.Options{ | ||
Key: "default", | ||
IsUnlimited: false, | ||
MaxCount: 100, | ||
Duration: time.Duration(3) * time.Second, | ||
}) | ||
require.Nil(t, err) | ||
wg := &sync.WaitGroup{} | ||
|
||
wg.Add(1) | ||
go func() { | ||
defer wg.Done() | ||
defaultStart := time.Now() | ||
for i := 0; i < 201; i++ { | ||
errx := limiter.Take("default") | ||
require.Nil(t, errx, "failed to take") | ||
} | ||
require.Greater(t, time.Since(defaultStart), time.Duration(6)*time.Second) | ||
}() | ||
|
||
err = limiter.Add(&ratelimit.Options{ | ||
Key: "one", | ||
IsUnlimited: false, | ||
MaxCount: 100, | ||
Duration: time.Duration(3) * time.Second, | ||
}) | ||
require.Nil(t, err) | ||
|
||
wg.Add(1) | ||
go func() { | ||
defer wg.Done() | ||
oneStart := time.Now() | ||
for i := 0; i < 201; i++ { | ||
errx := limiter.Take("one") | ||
require.Nil(t, errx) | ||
} | ||
require.Greater(t, time.Since(oneStart), time.Duration(6)*time.Second) | ||
}() | ||
wg.Wait() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters