Skip to content

Commit

Permalink
add typing parameters
Browse files Browse the repository at this point in the history
Fixed typing issue where :paramref:`.revision.process_revision_directives`
was not fully typed; additionally ensured all ``Callable`` and ``Dict``
arguments to :meth:`.EnvironmentContext.configure` include parameters in
the typing declaration.

Change-Id: I3ac389992f357359439be5659af33525fc290f96
Fixes: #1110
  • Loading branch information
zzzeek committed Nov 15, 2022
1 parent c76fd4b commit cfe87f5
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 57 deletions.
10 changes: 5 additions & 5 deletions alembic/command.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import os
from typing import Callable
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
Expand All @@ -15,6 +14,7 @@
if TYPE_CHECKING:
from alembic.config import Config
from alembic.script.base import Script
from .runtime.environment import ProcessRevisionDirectiveFn


def list_templates(config):
Expand Down Expand Up @@ -124,7 +124,7 @@ def revision(
version_path: Optional[str] = None,
rev_id: Optional[str] = None,
depends_on: Optional[str] = None,
process_revision_directives: Callable = None,
process_revision_directives: Optional[ProcessRevisionDirectiveFn] = None,
) -> Union[Optional["Script"], List[Optional["Script"]]]:
"""Create a new revision file.
Expand Down Expand Up @@ -243,9 +243,9 @@ def retrieve_migrations(rev, context):
def merge(
config: "Config",
revisions: str,
message: str = None,
branch_label: str = None,
rev_id: str = None,
message: Optional[str] = None,
branch_label: Optional[str] = None,
rev_id: Optional[str] = None,
) -> Optional["Script"]:
"""Merge two revisions together. Creates a new migration file.
Expand Down
2 changes: 1 addition & 1 deletion alembic/config.py
Expand Up @@ -99,7 +99,7 @@ def __init__(
stdout: TextIO = sys.stdout,
cmd_opts: Optional[Namespace] = None,
config_args: util.immutabledict = util.immutabledict(),
attributes: dict = None,
attributes: Optional[dict] = None,
) -> None:
"""Construct a new :class:`.Config`"""
self.config_file_name = file_
Expand Down
22 changes: 13 additions & 9 deletions alembic/context.pyi
Expand Up @@ -19,13 +19,13 @@ if TYPE_CHECKING:
from sqlalchemy.sql.schema import MetaData

from .config import Config
from .operations import MigrateOperation
from .runtime.migration import _ProxyTransaction
from .runtime.migration import MigrationContext
from .script import ScriptDirectory

### end imports ###

def begin_transaction() -> Union["_ProxyTransaction", ContextManager]:
def begin_transaction() -> Union[_ProxyTransaction, ContextManager]:
"""Return a context manager that will
enclose an operation within a "transaction",
as defined by the environment's offline
Expand Down Expand Up @@ -75,29 +75,33 @@ def configure(
connection: Optional[Connection] = None,
url: Optional[str] = None,
dialect_name: Optional[str] = None,
dialect_opts: Optional[dict] = None,
dialect_opts: Optional[Dict[str, Any]] = None,
transactional_ddl: Optional[bool] = None,
transaction_per_migration: bool = False,
output_buffer: Optional[TextIO] = None,
starting_rev: Optional[str] = None,
tag: Optional[str] = None,
template_args: Optional[dict] = None,
template_args: Optional[Dict[str, Any]] = None,
render_as_batch: bool = False,
target_metadata: Optional[MetaData] = None,
include_name: Optional[Callable] = None,
include_object: Optional[Callable] = None,
include_name: Optional[Callable[..., bool]] = None,
include_object: Optional[Callable[..., bool]] = None,
include_schemas: bool = False,
process_revision_directives: Optional[Callable] = None,
process_revision_directives: Optional[
Callable[
[MigrationContext, Tuple[str, str], List[MigrateOperation]], None
]
] = None,
compare_type: bool = False,
compare_server_default: bool = False,
render_item: Optional[Callable] = None,
render_item: Optional[Callable[..., bool]] = None,
literal_binds: bool = False,
upgrade_token: str = "upgrades",
downgrade_token: str = "downgrades",
alembic_module_prefix: str = "op.",
sqlalchemy_module_prefix: str = "sa.",
user_module_prefix: Optional[str] = None,
on_version_apply: Optional[Callable] = None,
on_version_apply: Optional[Callable[..., None]] = None,
**kw: Any,
) -> None:
"""Configure a :class:`.MigrationContext` within this
Expand Down
10 changes: 5 additions & 5 deletions alembic/op.pyi
Expand Up @@ -35,8 +35,8 @@ if TYPE_CHECKING:

from .operations.ops import BatchOperations
from .operations.ops import MigrateOperation
from .runtime.migration import MigrationContext
from .util.sqla_compat import _literal_bindparam

### end imports ###

def add_column(
Expand Down Expand Up @@ -1082,13 +1082,13 @@ def get_bind() -> Connection:
"""

def get_context():
def get_context() -> MigrationContext:
"""Return the :class:`.MigrationContext` object that's
currently in use.
"""

def implementation_for(op_cls: Any) -> Callable:
def implementation_for(op_cls: Any) -> Callable[..., Any]:
"""Register an implementation for a given :class:`.MigrateOperation`.
This is part of the operation extensibility API.
Expand All @@ -1101,7 +1101,7 @@ def implementation_for(op_cls: Any) -> Callable:

def inline_literal(
value: Union[str, int], type_: None = None
) -> "_literal_bindparam":
) -> _literal_bindparam:
"""Produce an 'inline literal' expression, suitable for
using in an INSERT, UPDATE, or DELETE statement.
Expand Down Expand Up @@ -1152,7 +1152,7 @@ def invoke(operation: MigrateOperation) -> Any:

def register_operation(
name: str, sourcename: Optional[str] = None
) -> Callable:
) -> Callable[..., Any]:
"""Register a new operation for this class.
This method is normally used to add new operations
Expand Down
24 changes: 12 additions & 12 deletions alembic/operations/base.py
Expand Up @@ -25,6 +25,7 @@
from ..util.compat import formatannotation_fwdref
from ..util.compat import inspect_formatargspec
from ..util.compat import inspect_getfullargspec
from ..util.sqla_compat import _literal_bindparam


NoneType = type(None)
Expand All @@ -39,7 +40,6 @@
from .ops import MigrateOperation
from ..ddl import DefaultImpl
from ..runtime.migration import MigrationContext
from ..util.sqla_compat import _literal_bindparam

__all__ = ("Operations", "BatchOperations")

Expand Down Expand Up @@ -80,8 +80,8 @@ class Operations(util.ModuleClsProxy):

def __init__(
self,
migration_context: "MigrationContext",
impl: Optional["BatchOperationsImpl"] = None,
migration_context: MigrationContext,
impl: Optional[BatchOperationsImpl] = None,
) -> None:
"""Construct a new :class:`.Operations`
Expand All @@ -100,7 +100,7 @@ def __init__(
@classmethod
def register_operation(
cls, name: str, sourcename: Optional[str] = None
) -> Callable:
) -> Callable[..., Any]:
"""Register a new operation for this class.
This method is normally used to add new operations
Expand Down Expand Up @@ -188,7 +188,7 @@ def %(name)s%(args)s:
return register

@classmethod
def implementation_for(cls, op_cls: Any) -> Callable:
def implementation_for(cls, op_cls: Any) -> Callable[..., Any]:
"""Register an implementation for a given :class:`.MigrateOperation`.
This is part of the operation extensibility API.
Expand All @@ -208,8 +208,8 @@ def decorate(fn):
@classmethod
@contextmanager
def context(
cls, migration_context: "MigrationContext"
) -> Iterator["Operations"]:
cls, migration_context: MigrationContext
) -> Iterator[Operations]:
op = Operations(migration_context)
op._install_proxy()
yield op
Expand Down Expand Up @@ -382,15 +382,15 @@ def batch_alter_table(
yield batch_op
impl.flush()

def get_context(self):
def get_context(self) -> MigrationContext:
"""Return the :class:`.MigrationContext` object that's
currently in use.
"""

return self.migration_context

def invoke(self, operation: "MigrateOperation") -> Any:
def invoke(self, operation: MigrateOperation) -> Any:
"""Given a :class:`.MigrateOperation`, invoke it in terms of
this :class:`.Operations` instance.
Expand All @@ -400,7 +400,7 @@ def invoke(self, operation: "MigrateOperation") -> Any:
)
return fn(self, operation)

def f(self, name: str) -> "conv":
def f(self, name: str) -> conv:
"""Indicate a string name that has already had a naming convention
applied to it.
Expand Down Expand Up @@ -440,7 +440,7 @@ def f(self, name: str) -> "conv":

def inline_literal(
self, value: Union[str, int], type_: None = None
) -> "_literal_bindparam":
) -> _literal_bindparam:
r"""Produce an 'inline literal' expression, suitable for
using in an INSERT, UPDATE, or DELETE statement.
Expand Down Expand Up @@ -484,7 +484,7 @@ def inline_literal(
"""
return sqla_compat._literal_bindparam(None, value, type_=type_)

def get_bind(self) -> "Connection":
def get_bind(self) -> Connection:
"""Return the current 'bind'.
Under normal circumstances, this is the
Expand Down
45 changes: 26 additions & 19 deletions alembic/runtime/environment.py
Expand Up @@ -12,6 +12,7 @@
from typing import TYPE_CHECKING
from typing import Union

from .migration import _ProxyTransaction
from .migration import MigrationContext
from .. import util
from ..operations import Operations
Expand All @@ -23,13 +24,17 @@
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.schema import MetaData

from .migration import _ProxyTransaction
from ..config import Config
from ..ddl import DefaultImpl
from ..operations.ops import MigrateOperation
from ..script.base import ScriptDirectory

_RevNumber = Optional[Union[str, Tuple[str, ...]]]

ProcessRevisionDirectiveFn = Callable[
[MigrationContext, Tuple[str, str], List["MigrateOperation"]], None
]


class EnvironmentContext(util.ModuleClsProxy):

Expand Down Expand Up @@ -109,7 +114,7 @@ def my_function(rev, context):
"""

def __init__(
self, config: "Config", script: "ScriptDirectory", **kw: Any
self, config: Config, script: ScriptDirectory, **kw: Any
) -> None:
r"""Construct a new :class:`.EnvironmentContext`.
Expand All @@ -124,7 +129,7 @@ def __init__(
self.script = script
self.context_opts = kw

def __enter__(self) -> "EnvironmentContext":
def __enter__(self) -> EnvironmentContext:
"""Establish a context which provides a
:class:`.EnvironmentContext` object to
env.py scripts.
Expand Down Expand Up @@ -265,13 +270,13 @@ def get_tag_argument(self) -> Optional[str]:

@overload
def get_x_argument( # type:ignore[misc]
self, as_dictionary: "Literal[False]" = ...
self, as_dictionary: Literal[False] = ...
) -> List[str]:
...

@overload
def get_x_argument( # type:ignore[misc]
self, as_dictionary: "Literal[True]" = ...
self, as_dictionary: Literal[True] = ...
) -> Dict[str, str]:
...

Expand Down Expand Up @@ -326,32 +331,34 @@ def get_x_argument(

def configure(
self,
connection: Optional["Connection"] = None,
connection: Optional[Connection] = None,
url: Optional[str] = None,
dialect_name: Optional[str] = None,
dialect_opts: Optional[dict] = None,
dialect_opts: Optional[Dict[str, Any]] = None,
transactional_ddl: Optional[bool] = None,
transaction_per_migration: bool = False,
output_buffer: Optional[TextIO] = None,
starting_rev: Optional[str] = None,
tag: Optional[str] = None,
template_args: Optional[dict] = None,
template_args: Optional[Dict[str, Any]] = None,
render_as_batch: bool = False,
target_metadata: Optional["MetaData"] = None,
include_name: Optional[Callable] = None,
include_object: Optional[Callable] = None,
target_metadata: Optional[MetaData] = None,
include_name: Optional[Callable[..., bool]] = None,
include_object: Optional[Callable[..., bool]] = None,
include_schemas: bool = False,
process_revision_directives: Optional[Callable] = None,
process_revision_directives: Optional[
ProcessRevisionDirectiveFn
] = None,
compare_type: bool = False,
compare_server_default: bool = False,
render_item: Optional[Callable] = None,
render_item: Optional[Callable[..., bool]] = None,
literal_binds: bool = False,
upgrade_token: str = "upgrades",
downgrade_token: str = "downgrades",
alembic_module_prefix: str = "op.",
sqlalchemy_module_prefix: str = "sa.",
user_module_prefix: Optional[str] = None,
on_version_apply: Optional[Callable] = None,
on_version_apply: Optional[Callable[..., None]] = None,
**kw: Any,
) -> None:
"""Configure a :class:`.MigrationContext` within this
Expand Down Expand Up @@ -859,7 +866,7 @@ def run_migrations(self, **kw: Any) -> None:

def execute(
self,
sql: Union["ClauseElement", str],
sql: Union[ClauseElement, str],
execution_options: Optional[dict] = None,
) -> None:
"""Execute the given SQL using the current change context.
Expand Down Expand Up @@ -888,7 +895,7 @@ def static_output(self, text: str) -> None:

def begin_transaction(
self,
) -> Union["_ProxyTransaction", ContextManager]:
) -> Union[_ProxyTransaction, ContextManager]:
"""Return a context manager that will
enclose an operation within a "transaction",
as defined by the environment's offline
Expand Down Expand Up @@ -934,7 +941,7 @@ def begin_transaction(

return self.get_context().begin_transaction()

def get_context(self) -> "MigrationContext":
def get_context(self) -> MigrationContext:
"""Return the current :class:`.MigrationContext` object.
If :meth:`.EnvironmentContext.configure` has not been
Expand All @@ -946,7 +953,7 @@ def get_context(self) -> "MigrationContext":
raise Exception("No context has been configured yet.")
return self._migration_context

def get_bind(self) -> "Connection":
def get_bind(self) -> Connection:
"""Return the current 'bind'.
In "online" mode, this is the
Expand All @@ -959,5 +966,5 @@ def get_bind(self) -> "Connection":
"""
return self.get_context().bind # type: ignore[return-value]

def get_impl(self) -> "DefaultImpl":
def get_impl(self) -> DefaultImpl:
return self.get_context().impl

0 comments on commit cfe87f5

Please sign in to comment.