Skip to content

Commit

Permalink
feat: impl asyncio version of RateLimit
Browse files Browse the repository at this point in the history
  • Loading branch information
pyto86pri committed Dec 15, 2020
1 parent 0c1bd4f commit 6a068fc
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 0 deletions.
12 changes: 12 additions & 0 deletions pyratelimit/aio/__init__.py
@@ -0,0 +1,12 @@
from pyratelimit import Per
from pyratelimit.aio.token_bucket import TokenBucket


class RateLimit:
"""A rate limit."""

def __init__(self, per: Per, burst: int) -> None:
self._token_bucket = TokenBucket(per.calls / per.period, burst)

async def wait(self, n: int = 1) -> None:
await self._token_bucket.consume(n)
85 changes: 85 additions & 0 deletions pyratelimit/aio/token_bucket.py
@@ -0,0 +1,85 @@
from asyncio import Condition, Lock, sleep
from asyncio.tasks import Task, create_task
from typing import Callable, Coroutine, Final, List, Tuple


def _create_periodic_task(
f: Callable[[], Coroutine[object, object, None]], period: float
) -> Task[None]:
async def wrapper() -> None:
while True:
await sleep(period)
await f()

return create_task(wrapper())


class TokenBucket:
"""A token bucket."""

def __init__(self, rate: float, bucket_size: int) -> None:
"""Constructor for TokenBucket.
Args:
rate (float): The number of tokens added to the bucket per second.
A token is added to the bucket every 1/rate seconds.
bucket_size (int): The maximum number of tokens the bucket can hold.
Raises:
ValueError: When rate or bucket_size less than or equal to 0.
"""
if rate <= 0:
raise ValueError("rate must be > 0")
if bucket_size <= 0:
raise ValueError("bucket size must be > 0")
self._rate: Final[float] = rate
self._bucket_size: Final[int] = bucket_size
self.n_token = bucket_size
self._cond = Condition(Lock())
_token_filler_worker.register(self)

async def fill(self, n: int = 1) -> None:
"""Fill the bucket with n tokens."""
async with self._cond:
self.n_token = min(self.n_token + n, self._bucket_size)
self._cond.notify()

async def consume(self, n: int = 1) -> None:
"""Consume n tokens from the bucket."""
async with self._cond:
while self.n_token < n:
await self._cond.wait()
else:
self.n_token -= n


class TokenFillerWorker:
"""A worker for filling buckets with tokens periodically"""

def __init__(self) -> None:
self._scheduled: List[Tuple[float, TokenBucket]] = []
self._stopping = False
self._tasks: List[Task[None]] = []

def register(self, tb: TokenBucket) -> None:
"""Register a token bucket to the worker.
Args:
tb (TokenBucket): A token bucket.
Raises:
RuntimeError: When called after stopping the worker.
"""
if self._stopping:
raise RuntimeError("Token filler worker already stopped.")
self._tasks.append(_create_periodic_task(tb.fill, 1 / tb._rate))

def stop(self) -> None:
"""Stop the worker."""
self._stopping = True
for task in self._tasks:
task.cancel()
self._tasks.clear()


_token_filler_worker = TokenFillerWorker()

0 comments on commit 6a068fc

Please sign in to comment.