Refactored initialization of some data stores and event brokers
This also fixes the missing close-on-exit behavior when the underlying client was implicitly created via the from_url() class method.
agronholm committed May 15, 2024
1 parent 4c3effd commit c8ad1b5
Showing 8 changed files with 165 additions and 118 deletions.
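
As a quick illustration of the new usage (a sketch, not part of the diff; the module paths and connection URLs are assumptions drawn from the examples below):

# Before: clients created via from_url() were never closed by the store/broker
#     event_broker = RedisEventBroker.from_url("redis://localhost")
# After: the initializer takes either a ready-made client or a connection URL,
# and a client created from a URL is closed when the broker/store exits
from apscheduler.datastores.mongodb import MongoDBDataStore
from apscheduler.eventbrokers.redis import RedisEventBroker

event_broker = RedisEventBroker("redis://localhost")
data_store = MongoDBDataStore("mongodb://localhost:27017")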
5 changes: 5 additions & 0 deletions docs/versionhistory.rst
@@ -16,6 +16,9 @@ APScheduler, see the :doc:`migration section <migration>`.
``DataStore`` implementation, rather than the scheduler, for consistency with the
``acquire_jobs()`` method
- **BREAKING** The ``started_at`` field was moved from ``Job`` to ``JobResult``
- **BREAKING** Removed the ``from_url()`` class methods of ``SQLAlchemyDataStore``,
``MongoDBDataStore`` and ``RedisEventBroker`` in favor of the ability to pass a
connection URL to the initializer
- Added the ability to pause and unpause schedules (PR by @WillDaSilva)
- Added the ``scheduled_start`` field to the ``JobAcquired`` event
- Added the ``scheduled_start`` and ``started_at`` fields to the ``JobReleased`` event
@@ -32,6 +35,8 @@ APScheduler, see the :doc:`migration section <migration>`.
- Fixed ``SQLAlchemyDataStore`` not respecting custom schema name when creating enums
- Fixed skipped intervals with overlapping schedules in ``AndTrigger``
(`#911 <https://github.com/agronholm/apscheduler/issues/911>`_; PR by Bennett Meares)
- Fixed implicitly created client instances in data stores and event brokers not being
closed along with the store/broker

**4.0.0a4**

2 changes: 1 addition & 1 deletion examples/web/wsgi_flask.py
@@ -35,7 +35,7 @@ def hello_world():

engine = create_engine("postgresql+psycopg://postgres:secret@localhost/testdb")
data_store = SQLAlchemyDataStore(engine)
event_broker = RedisEventBroker.from_url("redis://localhost")
event_broker = RedisEventBroker("redis://localhost")
scheduler = Scheduler(data_store, event_broker)
scheduler.add_schedule(tick, IntervalTrigger(seconds=1), id="tick")
scheduler.start_in_background()
2 changes: 1 addition & 1 deletion examples/web/wsgi_noframework.py
@@ -37,7 +37,7 @@ def application(environ, start_response):

engine = create_engine("mysql+pymysql://root:secret@localhost/testdb")
data_store = SQLAlchemyDataStore(engine)
event_broker = RedisEventBroker.from_url("redis://localhost")
event_broker = RedisEventBroker("redis://localhost")
scheduler = Scheduler(data_store, event_broker)
scheduler.add_schedule(tick, IntervalTrigger(seconds=1), id="tick")
scheduler.start_in_background()
48 changes: 31 additions & 17 deletions src/apscheduler/datastores/mongodb.py
@@ -138,17 +138,27 @@ class MongoDBDataStore(BaseExternalDataStore):
Operations are retried (in accordance with ``retry_settings``) when an operation
raises :exc:`pymongo.errors.ConnectionFailure`.
:param client: a PyMongo client
:param client_or_uri: a PyMongo client or a MongoDB connection URI
:param database: name of the database to use
.. note:: The data store will not manage the life cycle of any client instance
passed to it, so you need to close the client yourself when you're done with
it.
.. note:: Datetimes are stored as integers along with their UTC offsets, rather
than as BSON datetimes, because BSON datetimes are only accurate to the
millisecond while Python datetimes are accurate to the microsecond.
"""

client: MongoClient = attrs.field(validator=instance_of(MongoClient))
database: str = attrs.field(default="apscheduler", kw_only=True)
client_or_uri: MongoClient | str = attrs.field(
validator=instance_of((MongoClient, str))
)
database: str = attrs.field(
default="apscheduler", kw_only=True, validator=instance_of(str)
)

_client: MongoClient = attrs.field(init=False)
_close_on_exit: bool = attrs.field(init=False, default=False)
_task_attrs: ClassVar[list[str]] = [field.name for field in attrs.fields(Task)]
_schedule_attrs: ClassVar[list[str]] = [
field.name for field in attrs.fields(Schedule)
@@ -164,6 +174,12 @@ def _temporary_failure_exceptions(self) -> tuple[type[Exception], ...]:
return (ConnectionFailure,)

def __attrs_post_init__(self) -> None:
if isinstance(self.client_or_uri, str):
self._client = MongoClient(self.client_or_uri)
self._close_on_exit = True
else:
self._client = self.client_or_uri

type_registry = TypeRegistry(
[
CustomEncoder(timedelta, timedelta.total_seconds),
@@ -176,19 +192,14 @@ def __attrs_post_init__(self) -> None:
type_registry=type_registry,
uuid_representation=UuidRepresentation.STANDARD,
)
database = self.client.get_database(self.database, codec_options=codec_options)
database = self._client.get_database(self.database, codec_options=codec_options)
self._tasks = database["tasks"]
self._schedules = database["schedules"]
self._jobs = database["jobs"]
self._jobs_results = database["job_results"]

@classmethod
def from_url(cls, uri: str, **options) -> MongoDBDataStore:
client = MongoClient(uri)
return cls(client, **options)

def _initialize(self) -> None:
with self.client.start_session() as session:
with self._client.start_session() as session:
if self.start_from_scratch:
self._tasks.delete_many({}, session=session)
self._schedules.delete_many({}, session=session)
@@ -205,8 +216,11 @@ def _initialize(self) -> None:
async def start(
self, exit_stack: AsyncExitStack, event_broker: EventBroker, logger: Logger
) -> None:
if self._close_on_exit:
exit_stack.push_async_callback(to_thread.run_sync, self._client.close)

await super().start(exit_stack, event_broker, logger)
server_info = await to_thread.run_sync(self.client.server_info)
server_info = await to_thread.run_sync(self._client.server_info)
if server_info["versionArray"] < [4, 0]:
raise RuntimeError(
f"MongoDB server must be at least v4.0; current version = "
@@ -340,7 +354,7 @@ async def add_schedule(
async def remove_schedules(self, ids: Iterable[str]) -> None:
filters = {"_id": {"$in": list(ids)}} if ids is not None else {}
async for attempt in self._retry():
with attempt, self.client.start_session() as session:
with attempt, self._client.start_session() as session:
async with await AsyncCursor.create(
lambda: self._schedules.find(
filters, projection=["_id", "task_id"], session=session
@@ -359,7 +373,7 @@ async def remove_schedules(self, ids: Iterable[str]) -> None:

async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
async for attempt in self._retry():
with attempt, self.client.start_session() as session:
with attempt, self._client.start_session() as session:
schedules: list[Schedule] = []
now = datetime.now(timezone.utc).timestamp()
async with await AsyncCursor.create(
@@ -443,7 +457,7 @@ async def release_schedules(

if requests:
async for attempt in self._retry():
with attempt, self.client.start_session() as session:
with attempt, self._client.start_session() as session:
await to_thread.run_sync(
lambda: self._schedules.bulk_write(
requests, ordered=False, session=session
@@ -525,7 +539,7 @@ async def acquire_jobs(
self, scheduler_id: str, limit: int | None = None
) -> list[Job]:
async for attempt in self._retry():
with attempt, self.client.start_session() as session:
with attempt, self._client.start_session() as session:
now = datetime.now(timezone.utc)
async with await AsyncCursor.create(
lambda: self._jobs.find(
@@ -613,7 +627,7 @@ async def acquire_jobs(

async def release_job(self, scheduler_id: str, job: Job, result: JobResult) -> None:
async for attempt in self._retry():
with attempt, self.client.start_session() as session:
with attempt, self._client.start_session() as session:
# Record the job result
if result.expires_at > result.finished_at:
document = result.marshal(self.serializer)
@@ -659,7 +673,7 @@ async def get_job_result(self, job_id: UUID) -> JobResult | None:
async def cleanup(self) -> None:
# Clean up expired job results
async for attempt in self._retry():
with attempt, self.client.start_session() as session:
with attempt, self._client.start_session() as session:
# Purge expired job results
now = datetime.now(timezone.utc).timestamp()

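
The close-on-exit fix above hinges on registering the client's close() method on the scheduler's AsyncExitStack only when the store created the client itself. A minimal, runnable sketch of that pattern, assuming only anyio; FakeClient is a hypothetical stand-in for MongoClient:

from contextlib import AsyncExitStack

import anyio
from anyio import to_thread

class FakeClient:
    def close(self) -> None:
        print("client closed")

async def main() -> None:
    client = FakeClient()
    async with AsyncExitStack() as exit_stack:
        # Callbacks run in LIFO order when the stack unwinds; wrapping the
        # blocking close() in to_thread.run_sync keeps it off the event loop,
        # mirroring what the MongoDB store does in start()
        exit_stack.push_async_callback(to_thread.run_sync, client.close)
        ...  # the data store would do its work here
    # Leaving the "async with" block prints "client closed"

anyio.run(main)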
68 changes: 36 additions & 32 deletions src/apscheduler/datastores/sqlalchemy.py
@@ -1,6 +1,5 @@
from __future__ import annotations

import sys
from collections import defaultdict
from collections.abc import AsyncGenerator, Mapping, Sequence
from contextlib import AsyncExitStack, asynccontextmanager
@@ -15,6 +14,7 @@
import sniffio
import tenacity
from anyio import CancelScope, to_thread
from attr.validators import instance_of
from sqlalchemy import (
BigInteger,
Boolean,
@@ -32,6 +32,7 @@
Uuid,
and_,
bindparam,
create_engine,
false,
or_,
select,
@@ -41,6 +42,7 @@
CompileError,
IntegrityError,
InterfaceError,
InvalidRequestError,
ProgrammingError,
)
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
@@ -70,11 +72,6 @@
from ..abc import EventBroker
from .base import BaseExternalDataStore

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


class EmulatedTimestampTZ(TypeDecorator[datetime]):
impl = Unicode(32)
@@ -123,15 +120,22 @@ class SQLAlchemyDataStore(BaseExternalDataStore):
* MySQL (asyncmy driver)
* aiosqlite
:param engine: an asynchronous SQLAlchemy engine
:param engine_or_url: a SQLAlchemy URL or engine (preferably asynchronous, but can
be synchronous)
:param schema: a database schema name to use, if not the default
.. note:: The data store will not manage the life cycle of any engine instance
passed to it, so you need to dispose of the engine yourself when you're done
with it.
"""

engine: Engine | AsyncEngine
schema: str | None = attrs.field(default=None)
max_poll_time: float | None = attrs.field(default=1)
max_idle_time: float = attrs.field(default=60)
engine_or_url: str | URL | Engine | AsyncEngine = attrs.field(
validator=instance_of((str, URL, Engine, AsyncEngine))
)
schema: str | None = attrs.field(kw_only=True, default=None)

_engine: Engine | AsyncEngine = attrs.field(init=False)
_close_on_exit: bool = attrs.field(init=False, default=False)
_supports_update_returning: bool = attrs.field(init=False, default=False)
_supports_tzaware_timestamps: bool = attrs.field(init=False, default=False)
_supports_native_interval: bool = attrs.field(init=False, default=False)
@@ -143,34 +147,30 @@ class SQLAlchemyDataStore(BaseExternalDataStore):
_t_job_results: Table = attrs.field(init=False)

def __attrs_post_init__(self) -> None:
if isinstance(self.engine_or_url, (str, URL)):
try:
self._engine = create_async_engine(self.engine_or_url)
except InvalidRequestError:
self._engine = create_engine(self.engine_or_url)

self._close_on_exit = True
else:
self._engine = self.engine_or_url

# Generate the table definitions
prefix = f"{self.schema}." if self.schema else ""
self._supports_tzaware_timestamps = self.engine.dialect.name in (
self._supports_tzaware_timestamps = self._engine.dialect.name in (
"postgresql",
"oracle",
)
self._supports_native_interval = self.engine.dialect.name == "postgresql"
self._supports_native_interval = self._engine.dialect.name == "postgresql"
self._metadata = self.get_table_definitions()
self._t_metadata = self._metadata.tables[prefix + "metadata"]
self._t_tasks = self._metadata.tables[prefix + "tasks"]
self._t_schedules = self._metadata.tables[prefix + "schedules"]
self._t_jobs = self._metadata.tables[prefix + "jobs"]
self._t_job_results = self._metadata.tables[prefix + "job_results"]

@classmethod
def from_url(cls: type[Self], url: str | URL, **options) -> Self:
"""
Create a new asynchronous SQLAlchemy data store.
:param url: an SQLAlchemy URL to pass to :func:`~sqlalchemy.create_engine`
(must use an async dialect like ``asyncpg`` or ``asyncmy``)
:param options: keyword arguments to pass to the initializer of this class
:return: the newly created data store
"""
engine = create_async_engine(url, future=True)
return cls(engine, **options)

def _retry(self) -> tenacity.AsyncRetrying:
def after_attempt(retry_state: tenacity.RetryCallState) -> None:
self._logger.warning(
@@ -196,13 +196,13 @@ async def _begin_transaction(
# A shielded cancel scope is injected into the exit stack to allow finalization
# to occur even when the surrounding cancel scope is cancelled
async with AsyncExitStack() as exit_stack:
if isinstance(self.engine, AsyncEngine):
async_cm = self.engine.begin()
if isinstance(self._engine, AsyncEngine):
async_cm = self._engine.begin()
conn = await async_cm.__aenter__()
exit_stack.enter_context(CancelScope(shield=True))
exit_stack.push_async_exit(async_cm.__aexit__)
else:
cm = self.engine.begin()
cm = self._engine.begin()
conn = await to_thread.run_sync(cm.__enter__)
exit_stack.enter_context(CancelScope(shield=True))
exit_stack.push_async_exit(partial(to_thread.run_sync, cm.__exit__))
@@ -229,7 +229,7 @@ async def _execute(
@property
def _temporary_failure_exceptions(self) -> tuple[type[Exception], ...]:
# SQlite does not use the network, so it doesn't have "temporary" failures
if self.engine.dialect.name == "sqlite":
if self._engine.dialect.name == "sqlite":
return ()

return InterfaceError, OSError
@@ -336,13 +336,17 @@ def get_table_definitions(self) -> MetaData:
async def start(
self, exit_stack: AsyncExitStack, event_broker: EventBroker, logger: Logger
) -> None:
await super().start(exit_stack, event_broker, logger)
asynclib = sniffio.current_async_library() or "(unknown)"
if asynclib != "asyncio":
raise RuntimeError(
f"This data store requires asyncio; currently running: {asynclib}"
)

if self._close_on_exit:
exit_stack.push_async_callback(self._engine.dispose)

await super().start(exit_stack, event_broker, logger)

# Verify that the schema is in place
async for attempt in self._retry():
with attempt:
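
The URL branch of __attrs_post_init__ above first tries to build an async engine and falls back to a synchronous one when SQLAlchemy rejects the driver. A standalone sketch of that dispatch (the SQLite URLs are illustrative assumptions; the async case requires the aiosqlite driver to be installed):

from sqlalchemy import create_engine
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.ext.asyncio import create_async_engine

def engine_from_url(url: str):
    try:
        # Succeeds for async dialects such as sqlite+aiosqlite or
        # postgresql+asyncpg
        return create_async_engine(url)
    except InvalidRequestError:
        # A sync-only driver makes create_async_engine() raise
        # InvalidRequestError, so build a regular Engine instead
        return create_engine(url)

print(type(engine_from_url("sqlite+aiosqlite://")).__name__)  # AsyncEngine
print(type(engine_from_url("sqlite://")).__name__)            # Engine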
