forked from aiogram/aiogram
/
mongo.py
217 lines (169 loc) · 7.74 KB
/
mongo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
"""
This module provides a MongoDB-backed storage for the finite-state machine,
based on the `motor <https://github.com/mongodb/motor>`_ driver.
"""
from typing import Union, Dict, Optional, List, Tuple, AnyStr
try:
import pymongo
import motor
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
except ModuleNotFoundError as e:
import warnings
warnings.warn("Install motor with `pip install motor`")
raise e
from ...dispatcher.storage import BaseStorage
# Names of the MongoDB collections used by MongoStorage.
# Each holds documents addressed by the ('chat', 'user') pair.
STATE = 'aiogram_state'
DATA = 'aiogram_data'
BUCKET = 'aiogram_bucket'
# Every storage collection; iterated when the chat/user index is created.
COLLECTIONS = (STATE, DATA, BUCKET)
class MongoStorage(BaseStorage):
    """
    Mongo-based storage for FSM.

    Usage:

    .. code-block:: python3

        storage = MongoStorage(host='localhost', port=27017, db_name='aiogram_fsm')
        dp = Dispatcher(bot, storage=storage)

    Close the Mongo client connections on shutdown:

    .. code-block:: python3

        await dp.storage.close()
        await dp.storage.wait_closed()
    """

    def __init__(self, host='localhost', port=27017, db_name='aiogram_fsm', uri=None,
                 username=None, password=None, index=True, **kwargs):
        """
        Remember the connection settings; no connection is opened here.

        :param host: Mongo host, used only when ``uri`` is not given
        :param port: Mongo port, used only when ``uri`` is not given
        :param db_name: name of the database holding the FSM collections
        :param uri: full connection URI; when set it takes precedence over
            ``host``/``port``/``username``/``password``
        :param username: user name embedded into the generated URI
            (only when ``password`` is set as well)
        :param password: password embedded into the generated URI
        :param index: create the unique ``chat_user_idx`` index on the
            storage collections when the database is first accessed
        :param kwargs: extra options forwarded to ``AsyncIOMotorClient``
        """
        self._host = host
        self._port = port
        self._db_name: str = db_name
        self._uri = uri
        self._username = username
        self._password = password
        self._kwargs = kwargs  # custom client options like SSL configuration, etc.

        # Both are created lazily by get_client() / get_db().
        self._mongo: Optional[AsyncIOMotorClient] = None
        self._db: Optional[AsyncIOMotorDatabase] = None

        self._index = index

    async def get_client(self) -> AsyncIOMotorClient:
        """
        Return the cached Motor client, creating it on first use.

        :return: the client built from ``uri`` if one was given, otherwise
            from host/port (and optional credentials)
        :raises pymongo.errors.ConfigurationError: re-raised from the client
            constructor when the URI/options are rejected
        """
        if isinstance(self._mongo, AsyncIOMotorClient):
            return self._mongo

        if self._uri:
            try:
                self._mongo = AsyncIOMotorClient(self._uri, **self._kwargs)
            except pymongo.errors.ConfigurationError as e:
                # str(e) instead of e.args[0]: an exception without args must
                # not turn into an IndexError that hides the real problem.
                if "query() got an unexpected keyword argument 'lifetime'" in str(e):
                    import logging
                    logger = logging.getLogger("aiogram")
                    logger.warning("Run `pip install dnspython==1.16.0` in order to fix ConfigurationError. "
                                   "More information: "
                                   "https://github.com/mongodb/mongo-python-driver/pull/423#issuecomment-528998245")
                raise
            return self._mongo

        uri = 'mongodb://'

        # set username + password
        # NOTE: credentials are interpolated verbatim, so reserved characters
        # (e.g. '@', ':') must already be percent-escaped by the caller.
        if self._username and self._password:
            uri += f'{self._username}:{self._password}@'

        # set host and port (optional)
        uri += f'{self._host}:{self._port}' if self._host else f'localhost:{self._port}'

        # define and return client; forward the custom client options here too,
        # otherwise they would only take effect when ``uri`` is used
        self._mongo = AsyncIOMotorClient(uri, **self._kwargs)
        return self._mongo

    async def get_db(self) -> AsyncIOMotorDatabase:
        """
        Return the cached Mongo database, resolving it on first use.

        This is a coroutine and must be awaited. On the first call the
        indexes are created as well (when ``index`` was enabled).
        """
        if isinstance(self._db, AsyncIOMotorDatabase):
            return self._db

        mongo = await self.get_client()
        self._db = mongo.get_database(self._db_name)

        if self._index:
            await self.apply_index(self._db)
        return self._db

    @staticmethod
    async def apply_index(db):
        """
        Create the unique (chat, user) index on every storage collection.

        :param db: database object whose collections get the index
        """
        for collection in COLLECTIONS:
            await db[collection].create_index(keys=[('chat', 1), ('user', 1)],
                                              name="chat_user_idx", unique=True, background=True)

    async def close(self):
        """Close the Mongo client if one has been created."""
        if self._mongo:
            self._mongo.close()

    async def wait_closed(self):
        """Nothing to wait for; always returns ``True``."""
        return True

    async def set_state(self, *,
                        chat: Union[str, int, None] = None,
                        user: Union[str, int, None] = None,
                        state: Optional[AnyStr] = None):
        """
        Store the state for the chat/user pair.

        A ``None`` state removes the stored document instead of saving it.
        """
        chat, user = self.check_address(chat=chat, user=user)
        db = await self.get_db()

        if state is None:
            await db[STATE].delete_one(filter={'chat': chat, 'user': user})
        else:
            await db[STATE].update_one(
                filter={'chat': chat, 'user': user},
                update={'$set': {'state': self.resolve_state(state)}},
                upsert=True,
            )

    async def get_state(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
                        default: Optional[str] = None) -> Optional[str]:
        """
        Return the stored state for the chat/user pair.

        :return: the stored state, or the resolved ``default`` when no
            document exists for the pair
        """
        chat, user = self.check_address(chat=chat, user=user)
        db = await self.get_db()
        result = await db[STATE].find_one(filter={'chat': chat, 'user': user})

        return result.get('state') if result else self.resolve_state(default)

    async def set_data(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
                       data: Optional[Dict] = None):
        """
        Replace the data stored for the chat/user pair.

        Empty or missing ``data`` removes the stored document.
        """
        chat, user = self.check_address(chat=chat, user=user)
        db = await self.get_db()

        if not data:
            await db[DATA].delete_one(filter={'chat': chat, 'user': user})
        else:
            await db[DATA].update_one(filter={'chat': chat, 'user': user},
                                      update={'$set': {'data': data}}, upsert=True)

    async def get_data(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
                       default: Optional[dict] = None) -> Dict:
        """
        Return the data stored for the chat/user pair.

        :return: the stored data, or ``default`` (an empty dict when
            ``default`` is falsy) when no document exists
        """
        chat, user = self.check_address(chat=chat, user=user)
        db = await self.get_db()
        result = await db[DATA].find_one(filter={'chat': chat, 'user': user})

        if result:
            return result.get('data')
        return default or {}

    async def update_data(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
                          data: Optional[Dict] = None, **kwargs):
        """
        Merge ``data`` and ``kwargs`` into the stored data.

        Implemented as read, update, write back; it is not a single
        database operation.
        """
        if data is None:
            data = {}
        temp_data = await self.get_data(chat=chat, user=user, default={})
        temp_data.update(data, **kwargs)
        await self.set_data(chat=chat, user=user, data=temp_data)

    def has_bucket(self):
        """This storage supports buckets."""
        return True

    async def get_bucket(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
                         default: Optional[dict] = None) -> Dict:
        """
        Return the bucket stored for the chat/user pair.

        :return: the stored bucket, or ``default`` (an empty dict when
            ``default`` is falsy) when nothing usable is stored
        """
        chat, user = self.check_address(chat=chat, user=user)
        db = await self.get_db()
        result = await db[BUCKET].find_one(filter={'chat': chat, 'user': user})

        # set_bucket() stores None when called with its default argument;
        # treat that like "no bucket" so callers always get a dict back.
        bucket = result.get('bucket') if result else None
        if bucket is None:
            return default or {}
        return bucket

    async def set_bucket(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
                         bucket: Optional[Dict] = None):
        """Replace the bucket stored for the chat/user pair (upsert)."""
        chat, user = self.check_address(chat=chat, user=user)
        db = await self.get_db()

        await db[BUCKET].update_one(filter={'chat': chat, 'user': user},
                                    update={'$set': {'bucket': bucket}}, upsert=True)

    async def update_bucket(self, *, chat: Union[str, int, None] = None,
                            user: Union[str, int, None] = None,
                            bucket: Optional[Dict] = None, **kwargs):
        """
        Merge ``bucket`` and ``kwargs`` into the stored bucket.

        Implemented as read, update, write back; it is not a single
        database operation.
        """
        if bucket is None:
            bucket = {}
        temp_bucket = await self.get_bucket(chat=chat, user=user)
        temp_bucket.update(bucket, **kwargs)
        await self.set_bucket(chat=chat, user=user, bucket=temp_bucket)

    async def reset_all(self, full=True):
        """
        Reset states in DB

        :param full: clean DB or clean only states
        :return:
        """
        db = await self.get_db()

        await db[STATE].drop()

        if full:
            await db[DATA].drop()
            await db[BUCKET].drop()

        # Dropping a collection also drops its indexes, and get_db() only
        # applies them the first time the database is resolved, so restore
        # them here. Re-creating an identical index is a no-op.
        if self._index:
            await self.apply_index(db)

    async def get_states_list(self) -> List[Tuple[int, int]]:
        """
        Get list of all stored chat's and user's

        :return: list of tuples where first element is chat id and second is user id
        """
        db = await self.get_db()
        items = await db[STATE].find().to_list(length=None)
        return [(int(item['chat']), int(item['user'])) for item in items]