Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate from aioredis to redis.asyncio #1074

Merged
merged 2 commits into from Dec 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
234 changes: 35 additions & 199 deletions aiogram/contrib/fsm_storage/redis.py
@@ -1,18 +1,18 @@
"""
This module has redis storage for finite-state machine based on `aioredis <https://github.com/aio-libs/aioredis>`_ driver
This module has redis storage for finite-state machine based on `redis <https://pypi.org/project/redis/>`_ driver.
"""

import asyncio
import logging
import typing
from abc import ABC, abstractmethod

import aioredis

from ...dispatcher.storage import BaseStorage
from ...utils import json
from ...utils.deprecated import deprecated

if typing.TYPE_CHECKING:
import aioredis

STATE_KEY = 'state'
STATE_DATA_KEY = 'data'
STATE_BUCKET_KEY = 'bucket'
Expand Down Expand Up @@ -67,6 +67,8 @@ async def redis(self) -> "aioredis.RedisConnection":
Get Redis connection
"""
# Use thread-safe asyncio Lock because this method without that is not safe
import aioredis

async with self._connection_lock:
if self._redis is None or self._redis.closed:
self._redis = await aioredis.create_connection((self._host, self._port),
Expand Down Expand Up @@ -207,138 +209,6 @@ async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
await self.set_record(chat=chat, user=user, state=record['state'], data=record_bucket, bucket=bucket)


class AioRedisAdapterBase(ABC):
"""Base aioredis adapter class."""

def __init__(
self,
host: str = "localhost",
port: int = 6379,
db: typing.Optional[int] = None,
password: typing.Optional[str] = None,
ssl: typing.Optional[bool] = None,
pool_size: int = 10,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
prefix: str = "fsm",
state_ttl: typing.Optional[int] = None,
data_ttl: typing.Optional[int] = None,
bucket_ttl: typing.Optional[int] = None,
**kwargs,
):
self._host = host
self._port = port
self._db = db
self._password = password
self._ssl = ssl
self._pool_size = pool_size
self._kwargs = kwargs
self._prefix = (prefix,)

self._state_ttl = state_ttl
self._data_ttl = data_ttl
self._bucket_ttl = bucket_ttl

self._redis: typing.Optional["aioredis.Redis"] = None
self._connection_lock = asyncio.Lock()

@abstractmethod
async def get_redis(self) -> aioredis.Redis:
"""Get Redis connection."""
pass

async def close(self):
"""Grace shutdown."""
pass

async def wait_closed(self):
"""Wait for grace shutdown finishes."""
pass

async def set(self, name, value, ex=None, **kwargs):
"""Set the value at key ``name`` to ``value``."""
if ex == 0:
ex = None
return await self._redis.set(name, value, ex=ex, **kwargs)

async def get(self, name, **kwargs):
"""Return the value at key ``name`` or None."""
return await self._redis.get(name, **kwargs)

async def delete(self, *names):
"""Delete one or more keys specified by ``names``"""
return await self._redis.delete(*names)

async def keys(self, pattern, **kwargs):
"""Returns a list of keys matching ``pattern``."""
return await self._redis.keys(pattern, **kwargs)

async def flushdb(self):
"""Delete all keys in the current database."""
return await self._redis.flushdb()


class AioRedisAdapterV1(AioRedisAdapterBase):
"""Redis adapter for aioredis v1."""

async def get_redis(self) -> aioredis.Redis:
"""Get Redis connection."""
async with self._connection_lock: # to prevent race
if self._redis is None or self._redis.closed:
self._redis = await aioredis.create_redis_pool(
(self._host, self._port),
db=self._db,
password=self._password,
ssl=self._ssl,
minsize=1,
maxsize=self._pool_size,
**self._kwargs,
)
return self._redis

async def close(self):
async with self._connection_lock:
if self._redis and not self._redis.closed:
self._redis.close()

async def wait_closed(self):
async with self._connection_lock:
if self._redis:
return await self._redis.wait_closed()
return True

async def get(self, name, **kwargs):
return await self._redis.get(name, encoding="utf8", **kwargs)

async def set(self, name, value, ex=None, **kwargs):
if ex == 0:
ex = None
return await self._redis.set(name, value, expire=ex, **kwargs)

async def keys(self, pattern, **kwargs):
"""Returns a list of keys matching ``pattern``."""
return await self._redis.keys(pattern, encoding="utf8", **kwargs)


class AioRedisAdapterV2(AioRedisAdapterBase):
"""Redis adapter for aioredis v2."""

async def get_redis(self) -> aioredis.Redis:
"""Get Redis connection."""
async with self._connection_lock: # to prevent race
if self._redis is None:
self._redis = aioredis.Redis(
host=self._host,
port=self._port,
db=self._db,
password=self._password,
ssl=self._ssl,
max_connections=self._pool_size,
decode_responses=True,
**self._kwargs,
)
return self._redis


class RedisStorage2(BaseStorage):
"""
Busted Redis-base storage for FSM.
Expand All @@ -356,7 +226,6 @@ class RedisStorage2(BaseStorage):
.. code-block:: python3

await dp.storage.close()
await dp.storage.wait_closed()

"""

Expand All @@ -375,75 +244,49 @@ def __init__(
bucket_ttl: typing.Optional[int] = None,
**kwargs,
):
self._host = host
self._port = port
self._db = db
self._password = password
self._ssl = ssl
self._pool_size = pool_size
self._kwargs = kwargs
self._prefix = (prefix,)
from redis.asyncio import Redis

self._redis: typing.Optional[Redis] = Redis(
host=host,
port=port,
db=db,
password=password,
ssl=ssl,
max_connections=pool_size,
decode_responses=True,
**kwargs,
)

self._prefix = (prefix,)
self._state_ttl = state_ttl
self._data_ttl = data_ttl
self._bucket_ttl = bucket_ttl

self._redis: typing.Optional[AioRedisAdapterBase] = None
self._connection_lock = asyncio.Lock()

@deprecated("This method will be removed in aiogram v3.0. "
"You should use your own instance of Redis.", stacklevel=3)
async def redis(self) -> aioredis.Redis:
adapter = await self._get_adapter()
return await adapter.get_redis()

async def _get_adapter(self) -> AioRedisAdapterBase:
"""Get adapter based on aioredis version."""
if self._redis is None:
redis_version = int(aioredis.__version__.split(".")[0])
connection_data = dict(
host=self._host,
port=self._port,
db=self._db,
password=self._password,
ssl=self._ssl,
pool_size=self._pool_size,
**self._kwargs,
)
if redis_version == 1:
self._redis = AioRedisAdapterV1(**connection_data)
elif redis_version == 2:
self._redis = AioRedisAdapterV2(**connection_data)
else:
raise RuntimeError(f"Unsupported aioredis version: {redis_version}")
await self._redis.get_redis()
async def redis(self) -> "aioredis.Redis":
return self._redis

def generate_key(self, *parts):
return ':'.join(self._prefix + tuple(map(str, parts)))

async def close(self):
if self._redis:
return await self._redis.close()
await self._redis.close()

async def wait_closed(self):
if self._redis:
await self._redis.wait_closed()
self._redis = None
pass

async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Optional[str]:
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_KEY)
redis = await self._get_adapter()
return await redis.get(key) or self.resolve_state(default)
return await self._redis.get(key) or self.resolve_state(default)

async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None) -> typing.Dict:
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_DATA_KEY)
redis = await self._get_adapter()
raw_result = await redis.get(key)
raw_result = await self._redis.get(key)
if raw_result:
return json.loads(raw_result)
return default or {}
Expand All @@ -452,21 +295,19 @@ async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: ty
state: typing.Optional[typing.AnyStr] = None):
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_KEY)
redis = await self._get_adapter()
if state is None:
await redis.delete(key)
await self._redis.delete(key)
else:
await redis.set(key, self.resolve_state(state), ex=self._state_ttl)
await self._redis.set(key, self.resolve_state(state), ex=self._state_ttl)

async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None):
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_DATA_KEY)
redis = await self._get_adapter()
if data:
await redis.set(key, json.dumps(data), ex=self._data_ttl)
await self._redis.set(key, json.dumps(data), ex=self._data_ttl)
else:
await redis.delete(key)
await self._redis.delete(key)

async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None, **kwargs):
Expand All @@ -483,8 +324,7 @@ async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: t
default: typing.Optional[dict] = None) -> typing.Dict:
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
redis = await self._get_adapter()
raw_result = await redis.get(key)
raw_result = await self._redis.get(key)
if raw_result:
return json.loads(raw_result)
return default or {}
Expand All @@ -493,11 +333,10 @@ async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: t
bucket: typing.Dict = None):
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
redis = await self._get_adapter()
if bucket:
await redis.set(key, json.dumps(bucket), ex=self._bucket_ttl)
await self._redis.set(key, json.dumps(bucket), ex=self._bucket_ttl)
else:
await redis.delete(key)
await self._redis.delete(key)

async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
Expand All @@ -515,24 +354,21 @@ async def reset_all(self, full=True):
:param full: clean DB or clean only states
:return:
"""
redis = await self._get_adapter()

if full:
await redis.flushdb()
await self._redis.flushdb()
else:
keys = await redis.keys(self.generate_key('*'))
await redis.delete(*keys)
keys = await self._redis.keys(self.generate_key('*'))
await self._redis.delete(*keys)

async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]:
"""
Get list of all stored chat's and user's

:return: list of tuples where first element is chat id and second is user id
"""
redis = await self._get_adapter()
result = []

keys = await redis.keys(self.generate_key('*', '*', STATE_KEY))
keys = await self._redis.keys(self.generate_key('*', '*', STATE_KEY))
for item in keys:
*_, chat, user, _ = item.split(':')
result.append((chat, user))
Expand Down