diff --git a/alembic/command.py b/alembic/command.py index bbff75d1..162b3d0c 100644 --- a/alembic/command.py +++ b/alembic/command.py @@ -1,7 +1,6 @@ from __future__ import annotations import os -from typing import Callable from typing import List from typing import Optional from typing import TYPE_CHECKING @@ -15,6 +14,7 @@ if TYPE_CHECKING: from alembic.config import Config from alembic.script.base import Script + from .runtime.environment import ProcessRevisionDirectiveFn def list_templates(config): @@ -124,7 +124,7 @@ def revision( version_path: Optional[str] = None, rev_id: Optional[str] = None, depends_on: Optional[str] = None, - process_revision_directives: Callable = None, + process_revision_directives: Optional[ProcessRevisionDirectiveFn] = None, ) -> Union[Optional["Script"], List[Optional["Script"]]]: """Create a new revision file. @@ -243,9 +243,9 @@ def retrieve_migrations(rev, context): def merge( config: "Config", revisions: str, - message: str = None, - branch_label: str = None, - rev_id: str = None, + message: Optional[str] = None, + branch_label: Optional[str] = None, + rev_id: Optional[str] = None, ) -> Optional["Script"]: """Merge two revisions together. Creates a new migration file. diff --git a/alembic/config.py b/alembic/config.py index dcfb9288..8464407d 100644 --- a/alembic/config.py +++ b/alembic/config.py @@ -99,7 +99,7 @@ def __init__( stdout: TextIO = sys.stdout, cmd_opts: Optional[Namespace] = None, config_args: util.immutabledict = util.immutabledict(), - attributes: dict = None, + attributes: Optional[dict] = None, ) -> None: """Construct a new :class:`.Config`""" self.config_file_name = file_ diff --git a/alembic/context.pyi b/alembic/context.pyi index a2e53994..9871fadd 100644 --- a/alembic/context.pyi +++ b/alembic/context.pyi @@ -19,13 +19,13 @@ if TYPE_CHECKING: from sqlalchemy.sql.schema import MetaData from .config import Config + from .operations import MigrateOperation from .runtime.migration import _ProxyTransaction from .runtime.migration import MigrationContext from .script import ScriptDirectory - ### end imports ### -def begin_transaction() -> Union["_ProxyTransaction", ContextManager]: +def begin_transaction() -> Union[_ProxyTransaction, ContextManager]: """Return a context manager that will enclose an operation within a "transaction", as defined by the environment's offline @@ -75,29 +75,33 @@ def configure( connection: Optional[Connection] = None, url: Optional[str] = None, dialect_name: Optional[str] = None, - dialect_opts: Optional[dict] = None, + dialect_opts: Optional[Dict[str, Any]] = None, transactional_ddl: Optional[bool] = None, transaction_per_migration: bool = False, output_buffer: Optional[TextIO] = None, starting_rev: Optional[str] = None, tag: Optional[str] = None, - template_args: Optional[dict] = None, + template_args: Optional[Dict[str, Any]] = None, render_as_batch: bool = False, target_metadata: Optional[MetaData] = None, - include_name: Optional[Callable] = None, - include_object: Optional[Callable] = None, + include_name: Optional[Callable[..., bool]] = None, + include_object: Optional[Callable[..., bool]] = None, include_schemas: bool = False, - process_revision_directives: Optional[Callable] = None, + process_revision_directives: Optional[ + Callable[ + [MigrationContext, Tuple[str, str], List[MigrateOperation]], None + ] + ] = None, compare_type: bool = False, compare_server_default: bool = False, - render_item: Optional[Callable] = None, + render_item: Optional[Callable[..., bool]] = None, literal_binds: bool = False, upgrade_token: str = "upgrades", downgrade_token: str = "downgrades", alembic_module_prefix: str = "op.", sqlalchemy_module_prefix: str = "sa.", user_module_prefix: Optional[str] = None, - on_version_apply: Optional[Callable] = None, + on_version_apply: Optional[Callable[..., None]] = None, **kw: Any, ) -> None: """Configure a :class:`.MigrationContext` within this diff --git a/alembic/op.pyi b/alembic/op.pyi index 490d7146..4e80a00a 100644 --- a/alembic/op.pyi +++ b/alembic/op.pyi @@ -35,8 +35,8 @@ if TYPE_CHECKING: from .operations.ops import BatchOperations from .operations.ops import MigrateOperation + from .runtime.migration import MigrationContext from .util.sqla_compat import _literal_bindparam - ### end imports ### def add_column( @@ -1082,13 +1082,13 @@ def get_bind() -> Connection: """ -def get_context(): +def get_context() -> MigrationContext: """Return the :class:`.MigrationContext` object that's currently in use. """ -def implementation_for(op_cls: Any) -> Callable: +def implementation_for(op_cls: Any) -> Callable[..., Any]: """Register an implementation for a given :class:`.MigrateOperation`. This is part of the operation extensibility API. @@ -1101,7 +1101,7 @@ def implementation_for(op_cls: Any) -> Callable: def inline_literal( value: Union[str, int], type_: None = None -) -> "_literal_bindparam": +) -> _literal_bindparam: """Produce an 'inline literal' expression, suitable for using in an INSERT, UPDATE, or DELETE statement. @@ -1152,7 +1152,7 @@ def invoke(operation: MigrateOperation) -> Any: def register_operation( name: str, sourcename: Optional[str] = None -) -> Callable: +) -> Callable[..., Any]: """Register a new operation for this class. This method is normally used to add new operations diff --git a/alembic/operations/base.py b/alembic/operations/base.py index 535dff0f..2178998a 100644 --- a/alembic/operations/base.py +++ b/alembic/operations/base.py @@ -25,6 +25,7 @@ from ..util.compat import formatannotation_fwdref from ..util.compat import inspect_formatargspec from ..util.compat import inspect_getfullargspec +from ..util.sqla_compat import _literal_bindparam NoneType = type(None) @@ -39,7 +40,6 @@ from .ops import MigrateOperation from ..ddl import DefaultImpl from ..runtime.migration import MigrationContext - from ..util.sqla_compat import _literal_bindparam __all__ = ("Operations", "BatchOperations") @@ -80,8 +80,8 @@ class Operations(util.ModuleClsProxy): def __init__( self, - migration_context: "MigrationContext", - impl: Optional["BatchOperationsImpl"] = None, + migration_context: MigrationContext, + impl: Optional[BatchOperationsImpl] = None, ) -> None: """Construct a new :class:`.Operations` @@ -100,7 +100,7 @@ def __init__( @classmethod def register_operation( cls, name: str, sourcename: Optional[str] = None - ) -> Callable: + ) -> Callable[..., Any]: """Register a new operation for this class. This method is normally used to add new operations @@ -188,7 +188,7 @@ def %(name)s%(args)s: return register @classmethod - def implementation_for(cls, op_cls: Any) -> Callable: + def implementation_for(cls, op_cls: Any) -> Callable[..., Any]: """Register an implementation for a given :class:`.MigrateOperation`. This is part of the operation extensibility API. @@ -208,8 +208,8 @@ def decorate(fn): @classmethod @contextmanager def context( - cls, migration_context: "MigrationContext" - ) -> Iterator["Operations"]: + cls, migration_context: MigrationContext + ) -> Iterator[Operations]: op = Operations(migration_context) op._install_proxy() yield op @@ -382,7 +382,7 @@ def batch_alter_table( yield batch_op impl.flush() - def get_context(self): + def get_context(self) -> MigrationContext: """Return the :class:`.MigrationContext` object that's currently in use. @@ -390,7 +390,7 @@ def get_context(self): return self.migration_context - def invoke(self, operation: "MigrateOperation") -> Any: + def invoke(self, operation: MigrateOperation) -> Any: """Given a :class:`.MigrateOperation`, invoke it in terms of this :class:`.Operations` instance. @@ -400,7 +400,7 @@ def invoke(self, operation: "MigrateOperation") -> Any: ) return fn(self, operation) - def f(self, name: str) -> "conv": + def f(self, name: str) -> conv: """Indicate a string name that has already had a naming convention applied to it. @@ -440,7 +440,7 @@ def f(self, name: str) -> "conv": def inline_literal( self, value: Union[str, int], type_: None = None - ) -> "_literal_bindparam": + ) -> _literal_bindparam: r"""Produce an 'inline literal' expression, suitable for using in an INSERT, UPDATE, or DELETE statement. @@ -484,7 +484,7 @@ def inline_literal( """ return sqla_compat._literal_bindparam(None, value, type_=type_) - def get_bind(self) -> "Connection": + def get_bind(self) -> Connection: """Return the current 'bind'. Under normal circumstances, this is the diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py index 3cec5b1c..6dbbcc31 100644 --- a/alembic/runtime/environment.py +++ b/alembic/runtime/environment.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING from typing import Union +from .migration import _ProxyTransaction from .migration import MigrationContext from .. import util from ..operations import Operations @@ -23,13 +24,17 @@ from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.schema import MetaData - from .migration import _ProxyTransaction from ..config import Config from ..ddl import DefaultImpl + from ..operations.ops import MigrateOperation from ..script.base import ScriptDirectory _RevNumber = Optional[Union[str, Tuple[str, ...]]] +ProcessRevisionDirectiveFn = Callable[ + [MigrationContext, Tuple[str, str], List["MigrateOperation"]], None +] + class EnvironmentContext(util.ModuleClsProxy): @@ -109,7 +114,7 @@ def my_function(rev, context): """ def __init__( - self, config: "Config", script: "ScriptDirectory", **kw: Any + self, config: Config, script: ScriptDirectory, **kw: Any ) -> None: r"""Construct a new :class:`.EnvironmentContext`. @@ -124,7 +129,7 @@ def __init__( self.script = script self.context_opts = kw - def __enter__(self) -> "EnvironmentContext": + def __enter__(self) -> EnvironmentContext: """Establish a context which provides a :class:`.EnvironmentContext` object to env.py scripts. @@ -265,13 +270,13 @@ def get_tag_argument(self) -> Optional[str]: @overload def get_x_argument( # type:ignore[misc] - self, as_dictionary: "Literal[False]" = ... + self, as_dictionary: Literal[False] = ... ) -> List[str]: ... @overload def get_x_argument( # type:ignore[misc] - self, as_dictionary: "Literal[True]" = ... + self, as_dictionary: Literal[True] = ... ) -> Dict[str, str]: ... @@ -326,32 +331,34 @@ def get_x_argument( def configure( self, - connection: Optional["Connection"] = None, + connection: Optional[Connection] = None, url: Optional[str] = None, dialect_name: Optional[str] = None, - dialect_opts: Optional[dict] = None, + dialect_opts: Optional[Dict[str, Any]] = None, transactional_ddl: Optional[bool] = None, transaction_per_migration: bool = False, output_buffer: Optional[TextIO] = None, starting_rev: Optional[str] = None, tag: Optional[str] = None, - template_args: Optional[dict] = None, + template_args: Optional[Dict[str, Any]] = None, render_as_batch: bool = False, - target_metadata: Optional["MetaData"] = None, - include_name: Optional[Callable] = None, - include_object: Optional[Callable] = None, + target_metadata: Optional[MetaData] = None, + include_name: Optional[Callable[..., bool]] = None, + include_object: Optional[Callable[..., bool]] = None, include_schemas: bool = False, - process_revision_directives: Optional[Callable] = None, + process_revision_directives: Optional[ + ProcessRevisionDirectiveFn + ] = None, compare_type: bool = False, compare_server_default: bool = False, - render_item: Optional[Callable] = None, + render_item: Optional[Callable[..., bool]] = None, literal_binds: bool = False, upgrade_token: str = "upgrades", downgrade_token: str = "downgrades", alembic_module_prefix: str = "op.", sqlalchemy_module_prefix: str = "sa.", user_module_prefix: Optional[str] = None, - on_version_apply: Optional[Callable] = None, + on_version_apply: Optional[Callable[..., None]] = None, **kw: Any, ) -> None: """Configure a :class:`.MigrationContext` within this @@ -859,7 +866,7 @@ def run_migrations(self, **kw: Any) -> None: def execute( self, - sql: Union["ClauseElement", str], + sql: Union[ClauseElement, str], execution_options: Optional[dict] = None, ) -> None: """Execute the given SQL using the current change context. @@ -888,7 +895,7 @@ def static_output(self, text: str) -> None: def begin_transaction( self, - ) -> Union["_ProxyTransaction", ContextManager]: + ) -> Union[_ProxyTransaction, ContextManager]: """Return a context manager that will enclose an operation within a "transaction", as defined by the environment's offline @@ -934,7 +941,7 @@ def begin_transaction( return self.get_context().begin_transaction() - def get_context(self) -> "MigrationContext": + def get_context(self) -> MigrationContext: """Return the current :class:`.MigrationContext` object. If :meth:`.EnvironmentContext.configure` has not been @@ -946,7 +953,7 @@ def get_context(self) -> "MigrationContext": raise Exception("No context has been configured yet.") return self._migration_context - def get_bind(self) -> "Connection": + def get_bind(self) -> Connection: """Return the current 'bind'. In "online" mode, this is the @@ -959,5 +966,5 @@ def get_bind(self) -> "Connection": """ return self.get_context().bind # type: ignore[return-value] - def get_impl(self) -> "DefaultImpl": + def get_impl(self) -> DefaultImpl: return self.get_context().impl diff --git a/alembic/util/langhelpers.py b/alembic/util/langhelpers.py index b6ceb0cd..ff2687ce 100644 --- a/alembic/util/langhelpers.py +++ b/alembic/util/langhelpers.py @@ -198,7 +198,7 @@ def to_tuple(x: Any, default: tuple) -> tuple: @overload -def to_tuple(x: None, default: _T = None) -> _T: +def to_tuple(x: None, default: Optional[_T] = None) -> _T: ... diff --git a/docs/build/unreleased/1110.rst b/docs/build/unreleased/1110.rst new file mode 100644 index 00000000..fe9cfffe --- /dev/null +++ b/docs/build/unreleased/1110.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, typing + :tickets: 1110 + + Fixed typing issue where :paramref:`.revision.process_revision_directives` + was not fully typed; additionally ensured all ``Callable`` and ``Dict`` + arguments to :meth:`.EnvironmentContext.configure` include parameters in + the typing declaration. + + Additionally updated the codebase for Mypy 0.990 compliance. \ No newline at end of file diff --git a/tools/write_pyi.py b/tools/write_pyi.py index 52fac3c1..e5112fdb 100644 --- a/tools/write_pyi.py +++ b/tools/write_pyi.py @@ -125,7 +125,7 @@ def generate_pyi_for_proxy( def _generate_stub_for_attr(cls, name, printer, env): try: annotations = typing.get_type_hints(cls, env) - except NameError as e: + except NameError: annotations = cls.__annotations__ type_ = annotations.get(name, "Any") if isinstance(type_, str) and type_[0] in "'\"": @@ -155,10 +155,7 @@ def _formatannotation(annotation, base_module=None): if getattr(annotation, "__module__", None) == "typing": retval = repr(annotation).replace("typing.", "") elif isinstance(annotation, type): - if annotation.__module__ in ("builtins", base_module): - retval = annotation.__qualname__ - else: - retval = annotation.__module__ + "." + annotation.__qualname__ + retval = annotation.__qualname__ else: retval = annotation @@ -184,6 +181,7 @@ def {name}{argspec}: '''{fn.__doc__}''' """ ) + printer.write_indented_block(func_text)