Skip to content

Commit

Permalink
Merge "Updated typing for self_group()" into rel_2_0
Browse files Browse the repository at this point in the history
  • Loading branch information
zzzeek authored and Gerrit Code Review committed May 5, 2024
2 parents 2f335f2 + 24c73f9 commit 79474c3
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 33 deletions.
59 changes: 42 additions & 17 deletions lib/sqlalchemy/sql/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
from ..util import HasMemoized_ro_memoized_attribute
from ..util import TypingOnly
from ..util.typing import Literal
from ..util.typing import ParamSpec
from ..util.typing import Self

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -1429,13 +1430,11 @@ def _non_anon_label(self) -> Optional[str]:
_alt_names: Sequence[str] = ()

@overload
def self_group(
self: ColumnElement[_T], against: Optional[OperatorType] = None
) -> ColumnElement[_T]: ...
def self_group(self, against: None = None) -> ColumnElement[_T]: ...

@overload
def self_group(
self: ColumnElement[Any], against: Optional[OperatorType] = None
self, against: Optional[OperatorType] = None
) -> ColumnElement[Any]: ...

def self_group(
Expand Down Expand Up @@ -2581,7 +2580,9 @@ def comparator(self):
# be using this method.
return self.type.comparator_factory(self) # type: ignore

def self_group(self, against=None):
def self_group(
self, against: Optional[OperatorType] = None
) -> Union[Self, Grouping[Any]]:
if against is operators.in_op:
return Grouping(self)
else:
Expand Down Expand Up @@ -2786,7 +2787,9 @@ def append(self, clause):
def _from_objects(self) -> List[FromClause]:
return list(itertools.chain(*[c._from_objects for c in self.clauses]))

def self_group(self, against=None):
def self_group(
self, against: Optional[OperatorType] = None
) -> Union[Self, Grouping[Any]]:
if self.group and operators.is_precedent(self.operator, against):
return Grouping(self)
else:
Expand All @@ -2809,7 +2812,9 @@ class OperatorExpression(ColumnElement[_T]):
def is_comparison(self):
return operators.is_comparison(self.operator)

def self_group(self, against=None):
def self_group(
self, against: Optional[OperatorType] = None
) -> Union[Self, Grouping[_T]]:
if (
self.group
and operators.is_precedent(self.operator, against)
Expand Down Expand Up @@ -3169,7 +3174,9 @@ def or_(
def _select_iterable(self) -> _SelectIterable:
return (self,)

def self_group(self, against=None):
def self_group(
self, against: Optional[OperatorType] = None
) -> Union[Self, Grouping[bool]]:
if not self.clauses:
return self
else:
Expand Down Expand Up @@ -3252,7 +3259,7 @@ def _bind_param(self, operator, obj, type_=None, expanding=False):
]
)

def self_group(self, against=None):
def self_group(self, against: Optional[OperatorType] = None) -> Self:
# Tuple is parenthesized by definition.
return self

Expand Down Expand Up @@ -3485,7 +3492,9 @@ def typed_expression(self):
def wrapped_column_expression(self):
return self.clause

def self_group(self, against=None):
def self_group(
self, against: Optional[OperatorType] = None
) -> TypeCoerce[_T]:
grouped = self.clause.self_group(against=against)
if grouped is not self.clause:
return TypeCoerce(grouped, self.type)
Expand Down Expand Up @@ -3700,7 +3709,9 @@ def _negate(self):
else:
return ClauseElement._negate(self)

def self_group(self, against=None):
def self_group(
self, against: Optional[OperatorType] = None
) -> Union[Self, Grouping[_T]]:
if self.operator and operators.is_precedent(self.operator, against):
return Grouping(self)
else:
Expand Down Expand Up @@ -3787,7 +3798,7 @@ def __init__(self, element, operator, negate):
def wrapped_column_expression(self):
return self.element

def self_group(self, against=None):
def self_group(self, against: Optional[OperatorType] = None) -> Self:
return self

def _negate(self):
Expand Down Expand Up @@ -3987,8 +3998,8 @@ def __init__(self, start, stop, step, _name=None):
)
self.type = type_api.NULLTYPE

def self_group(self, against=None):
assert against is operator.getitem
def self_group(self, against: Optional[OperatorType] = None) -> Self:
assert against is operator.getitem # type: ignore[comparison-overlap]
return self


Expand All @@ -4006,7 +4017,7 @@ class GroupedElement(DQLDMLClauseElement):

element: ClauseElement

def self_group(self, against=None):
def self_group(self, against: Optional[OperatorType] = None) -> Self:
return self

def _ungroup(self):
Expand Down Expand Up @@ -4070,6 +4081,12 @@ def __setstate__(self, state):
self.element = state["element"]
self.type = state["type"]

if TYPE_CHECKING:

def self_group(
self, against: Optional[OperatorType] = None
) -> Self: ...


class _OverrideBinds(Grouping[_T]):
"""used by cache_key->_apply_params_to_element to allow compilation /
Expand Down Expand Up @@ -4570,6 +4587,9 @@ def _make_proxy(
return c.key, c


_PS = ParamSpec("_PS")


class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]):
"""Represents a column label (AS).
Expand Down Expand Up @@ -4667,13 +4687,18 @@ def _order_by_label_element(self):
def element(self) -> ColumnElement[_T]:
return self._element.self_group(against=operators.as_)

def self_group(self, against=None):
def self_group(self, against: Optional[OperatorType] = None) -> Label[_T]:
return self._apply_to_inner(self._element.self_group, against=against)

def _negate(self):
return self._apply_to_inner(self._element._negate)

def _apply_to_inner(self, fn, *arg, **kw):
def _apply_to_inner(
self,
fn: Callable[_PS, ColumnElement[_T]],
*arg: _PS.args,
**kw: _PS.kwargs,
) -> Label[_T]:
sub_element = fn(*arg, **kw)
if sub_element is not self._element:
return Label(self.name, sub_element, type_=self.type)
Expand Down
35 changes: 19 additions & 16 deletions lib/sqlalchemy/sql/selectable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1242,7 +1242,6 @@ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
def self_group(
self, against: Optional[OperatorType] = None
) -> FromGrouping:
...
return FromGrouping(self)

@util.preload_module("sqlalchemy.sql.util")
Expand Down Expand Up @@ -2889,6 +2888,12 @@ def __getstate__(self) -> Dict[str, FromClause]:
def __setstate__(self, state: Dict[str, FromClause]) -> None:
self.element = state["element"]

if TYPE_CHECKING:

def self_group(
self, against: Optional[OperatorType] = None
) -> Self: ...


class NamedFromGrouping(FromGrouping, NamedFromClause):
"""represent a grouping of a named FROM clause
Expand All @@ -2899,6 +2904,12 @@ class NamedFromGrouping(FromGrouping, NamedFromClause):

inherit_cache = True

if TYPE_CHECKING:

def self_group(
self, against: Optional[OperatorType] = None
) -> Self: ...


class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
"""Represents a minimal "table" construct.
Expand Down Expand Up @@ -3312,6 +3323,12 @@ def _column_types(self) -> List[TypeEngine[Any]]:
def __clause_element__(self) -> ScalarValues:
return self

if TYPE_CHECKING:

def self_group(
self, against: Optional[OperatorType] = None
) -> Self: ...


class SelectBase(
roles.SelectStatementRole,
Expand Down Expand Up @@ -3684,7 +3701,6 @@ def select_statement(self) -> _SB:
return self.element

def self_group(self, against: Optional[OperatorType] = None) -> Self:
...
return self

if TYPE_CHECKING:
Expand Down Expand Up @@ -6324,7 +6340,6 @@ def _needs_parens_for_grouping(self) -> bool:
def self_group(
self, against: Optional[OperatorType] = None
) -> Union[SelectStatementGrouping[Self], Self]:
...
"""Return a 'grouping' construct as per the
:class:`_expression.ClauseElement` specification.
Expand Down Expand Up @@ -6516,19 +6531,7 @@ def where(self, crit: _ColumnExpressionArgument[bool]) -> Self:
self.element = cast("Select[Any]", self.element).where(crit)
return self

@overload
def self_group(
self: ScalarSelect[Any], against: Optional[OperatorType] = None
) -> ScalarSelect[Any]: ...

@overload
def self_group(
self: ColumnElement[Any], against: Optional[OperatorType] = None
) -> ColumnElement[Any]: ...

def self_group(
self, against: Optional[OperatorType] = None
) -> ColumnElement[Any]:
def self_group(self, against: Optional[OperatorType] = None) -> Self:
return self

if TYPE_CHECKING:
Expand Down
5 changes: 5 additions & 0 deletions test/typing/plain_files/sql/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,8 @@ class A(Base):
# op functions
t1 = operators.eq(A.id, 1)
select().where(t1)

# EXPECTED_TYPE: BinaryExpression[Any]
reveal_type(col.op("->>")("field"))
# EXPECTED_TYPE: Union[BinaryExpression[Any], Grouping[Any]]
reveal_type(col.op("->>")("field").self_group())

0 comments on commit 79474c3

Please sign in to comment.