Skip to content

Commit

Permalink
Merge "Add overload for ColumnCollection.get(col, default)" into rel_2_0
Browse files Browse the repository at this point in the history
  • Loading branch information
zzzeek authored and Gerrit Code Review committed May 5, 2024
2 parents cbc2b9c + e2d4385 commit 2f335f2
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 31 deletions.
2 changes: 1 addition & 1 deletion lib/sqlalchemy/orm/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ def declarative_scan(
supercls_mapper = class_mapper(decl_scan.inherits, False)

colname = column.name if column.name is not None else key
column = self.column = supercls_mapper.local_table.c.get( # type: ignore # noqa: E501
column = self.column = supercls_mapper.local_table.c.get( # type: ignore[assignment] # noqa: E501
colname, column
)

Expand Down
30 changes: 17 additions & 13 deletions lib/sqlalchemy/sql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@
from .elements import ClauseList
from .elements import ColumnClause # noqa
from .elements import ColumnElement
from .elements import KeyedColumnElement
from .elements import NamedColumn
from .elements import SQLCoreOperations
from .elements import TextClause
Expand Down Expand Up @@ -1354,7 +1353,7 @@ class _SentinelColumnCharacterization(NamedTuple):
_COLKEY = TypeVar("_COLKEY", Union[None, str], str)

_COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True)
_COL = TypeVar("_COL", bound="KeyedColumnElement[Any]")
_COL = TypeVar("_COL", bound="ColumnElement[Any]")


class _ColumnMetrics(Generic[_COL_co]):
Expand Down Expand Up @@ -1642,9 +1641,15 @@ def compare(self, other: ColumnCollection[Any, Any]) -> bool:
def __eq__(self, other: Any) -> bool:
return self.compare(other)

@overload
def get(self, key: str, default: None = None) -> Optional[_COL_co]: ...

@overload
def get(self, key: str, default: _COL) -> Union[_COL_co, _COL]: ...

def get(
self, key: str, default: Optional[_COL_co] = None
) -> Optional[_COL_co]:
self, key: str, default: Optional[_COL] = None
) -> Optional[Union[_COL_co, _COL]]:
"""Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object
based on a string key name from this
:class:`_expression.ColumnCollection`."""
Expand Down Expand Up @@ -1925,16 +1930,15 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
"""

def add(
self, column: ColumnElement[Any], key: Optional[str] = None
def add( # type: ignore[override]
self, column: _NAMEDCOL, key: Optional[str] = None
) -> None:
named_column = cast(_NAMEDCOL, column)
if key is not None and named_column.key != key:
if key is not None and column.key != key:
raise exc.ArgumentError(
"DedupeColumnCollection requires columns be under "
"the same key as their .key"
)
key = named_column.key
key = column.key

if key is None:
raise exc.ArgumentError(
Expand All @@ -1944,17 +1948,17 @@ def add(
if key in self._index:
existing = self._index[key][1]

if existing is named_column:
if existing is column:
return

self.replace(named_column)
self.replace(column)

# pop out memoized proxy_set as this
# operation may very well be occurring
# in a _make_proxy operation
util.memoized_property.reset(named_column, "proxy_set")
util.memoized_property.reset(column, "proxy_set")
else:
self._append_new_column(key, named_column)
self._append_new_column(key, column)

def _append_new_column(self, key: str, named_column: _NAMEDCOL) -> None:
l = len(self._collection)
Expand Down
10 changes: 10 additions & 0 deletions test/sql/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sqlalchemy.sql import column
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql import roles
from sqlalchemy.sql import table
from sqlalchemy.sql import util as sql_util
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
Expand Down Expand Up @@ -174,3 +175,12 @@ def test_unwrap_order_by(self, expr, expected):

for a, b in zip_longest(unwrapped, expected):
assert a is not None and a.compare(b)

def test_column_collection_get(self):
col_id = column("id", Integer)
col_alt = column("alt", Integer)
table1 = table("mytable", col_id)

is_(table1.columns.get("id"), col_id)
is_(table1.columns.get("alt"), None)
is_(table1.columns.get("alt", col_alt), col_alt)
37 changes: 37 additions & 0 deletions test/typing/plain_files/sql/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Any

from sqlalchemy import column
from sqlalchemy import ColumnElement
from sqlalchemy import Integer
from sqlalchemy import literal
from sqlalchemy import table


def test_col_accessors() -> None:
t = table("t", column("a"), column("b"), column("c"))

t.c.a
t.c["a"]

t.c[2]
t.c[0, 1]
t.c[0, 1, "b", "c"]
t.c[(0, 1, "b", "c")]

t.c[:-1]
t.c[0:2]


def test_col_get() -> None:
col_id = column("id", Integer)
col_alt = column("alt", Integer)
tbl = table("mytable", col_id)

# EXPECTED_TYPE: Union[ColumnClause[Any], None]
reveal_type(tbl.c.get("id"))
# EXPECTED_TYPE: Union[ColumnClause[Any], None]
reveal_type(tbl.c.get("id", None))
# EXPECTED_TYPE: Union[ColumnClause[Any], ColumnClause[int]]
reveal_type(tbl.c.get("alt", col_alt))
col: ColumnElement[Any] = tbl.c.get("foo", literal("bar"))
print(col)
17 changes: 0 additions & 17 deletions test/typing/plain_files/sql/selectables.py

This file was deleted.

0 comments on commit 2f335f2

Please sign in to comment.