Skip to content

Commit

Permalink
Merge "Some improvements to the cache key generation speed" into main
Browse files Browse the repository at this point in the history
  • Loading branch information
zzzeek authored and Gerrit Code Review committed May 17, 2024
2 parents e69ca16 + 6fbf001 commit 83bd285
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 78 deletions.
47 changes: 37 additions & 10 deletions lib/sqlalchemy/sql/_util_cy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ def _is_compiled() -> bool:

# END GENERATED CYTHON IMPORT

if cython.compiled:
from cython.cimports.sqlalchemy.util._collections_cy import _get_id
else:
_get_id = id


@cython.cclass
class prefix_anon_map(Dict[str, str]):
Expand Down Expand Up @@ -67,7 +72,7 @@ def __missing__(self, key: str, /) -> str:
class anon_map(
Dict[
Union[int, str, "Literal[CacheConst.NO_CACHE]"],
Union[Literal[True], str],
Union[int, Literal[True]],
]
):
"""A map that creates new keys for missing key access.
Expand All @@ -90,19 +95,41 @@ def __cinit__(self): # type: ignore[no-untyped-def]
else:
_index: int = 0 # type: ignore[no-redef]

def get_anon(self, obj: object, /) -> Tuple[str, bool]:
@cython.cfunc # type:ignore[misc]
@cython.inline # type:ignore[misc]
def _add_missing(
self: anon_map, key: Union[int, str, "Literal[CacheConst.NO_CACHE]"], /
) -> int:
val: int = self._index
self._index += 1
self_dict: dict = self # type: ignore[type-arg]
self_dict[key] = val
return val

def get_anon(self: anon_map, obj: object, /) -> Tuple[int, bool]:
self_dict: dict = self # type: ignore[type-arg]

idself = id(obj)
idself: int = _get_id(obj)
if idself in self_dict:
return self_dict[idself], True
else:
return self.__missing__(idself), False
return self._add_missing(idself), False

def __missing__(self, key: Union[int, str], /) -> str:
val: str
self_dict: dict = self # type: ignore[type-arg]
if cython.compiled:

self_dict[key] = val = str(self._index)
self._index += 1
return val
def __getitem__(
self: anon_map,
key: Union[int, str, "Literal[CacheConst.NO_CACHE]"],
/,
) -> Union[int, Literal[True]]:
self_dict: dict = self # type: ignore[type-arg]

if key in self_dict:
return self_dict[key] # type:ignore[no-any-return]
else:
return self._add_missing(key) # type:ignore[no-any-return]

def __missing__(
self: anon_map, key: Union[int, str, "Literal[CacheConst.NO_CACHE]"], /
) -> int:
return self._add_missing(key) # type:ignore[no-any-return]
11 changes: 7 additions & 4 deletions lib/sqlalchemy/sql/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

from operator import itemgetter
import typing
from typing import Any
from typing import Callable
Expand Down Expand Up @@ -103,14 +104,16 @@ def _gen_annotations_cache_key(
else value
),
)
for key, value in [
(key, self._annotations[key])
for key in sorted(self._annotations)
]
for key, value in sorted(
self._annotations.items(), key=_get_item0
)
),
)


_get_item0 = itemgetter(0)


class SupportsWrappingAnnotations(SupportsAnnotations):
__slots__ = ()

Expand Down
6 changes: 3 additions & 3 deletions lib/sqlalchemy/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3683,7 +3683,7 @@ def visit_bindparam(
bind_expression_template=wrapped,
**kwargs,
)
return "(%s)" % ret
return f"({ret})"

return wrapped

Expand All @@ -3702,7 +3702,7 @@ def visit_bindparam(
bindparam, within_columns_clause=True, **kwargs
)
if bindparam.expanding:
ret = "(%s)" % ret
ret = f"({ret})"
return ret

name = self._truncate_bindparam(bindparam)
Expand Down Expand Up @@ -3799,7 +3799,7 @@ def visit_bindparam(
)

if bindparam.expanding:
ret = "(%s)" % ret
ret = f"({ret})"

return ret

Expand Down
103 changes: 61 additions & 42 deletions lib/sqlalchemy/sql/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -1736,9 +1736,8 @@ def _anon_label(
seed = seed + "_"

if isinstance(seed, _anonymous_label):
return _anonymous_label.safe_construct(
hash_value, "", enclosing_label=seed
)
# NOTE: the space after the hash is required
return _anonymous_label(f"{seed}%({hash_value} )s")

return _anonymous_label.safe_construct(hash_value, seed or "anon")

Expand Down Expand Up @@ -1941,12 +1940,12 @@ class BindParameter(roles.InElementRole, KeyedColumnElement[_T]):
]

key: str
_anon_map_key: Optional[str] = None
type: TypeEngine[_T]
value: Optional[_T]

_is_crud = False
_is_bind_parameter = True
_key_is_anon = False

# bindparam implements its own _gen_cache_key() method however
# we check subclasses for this flag, else no cache key is generated
Expand Down Expand Up @@ -1977,22 +1976,24 @@ def __init__(
key = quoted_name.construct(key, quote)

if unique:
self.key = _anonymous_label.safe_construct(
id(self),
(
key
if key is not None
and not isinstance(key, _anonymous_label)
else "param"
),
sanitize_key=True,
self.key, self._anon_map_key = (
_anonymous_label.safe_construct_with_key(
id(self),
(
key
if key is not None
and not isinstance(key, _anonymous_label)
else "param"
),
sanitize_key=True,
)
)
self._key_is_anon = True
elif key:
self.key = key
else:
self.key = _anonymous_label.safe_construct(id(self), "param")
self._key_is_anon = True
self.key, self._anon_map_key = (
_anonymous_label.safe_construct_with_key(id(self), "param")
)

# identifying key that won't change across
# clones, used to identify the bind's logical
Expand Down Expand Up @@ -2081,7 +2082,7 @@ def effective_value(self) -> Optional[_T]:
else:
return self.value

def render_literal_execute(self) -> BindParameter[_T]:
def render_literal_execute(self) -> Self:
"""Produce a copy of this bound parameter that will enable the
:paramref:`_sql.BindParameter.literal_execute` flag.
Expand All @@ -2102,7 +2103,7 @@ def render_literal_execute(self) -> BindParameter[_T]:
:ref:`engine_thirdparty_caching`
"""
c = ClauseElement._clone(self)
c: Self = ClauseElement._clone(self)
c.literal_execute = True
return c

Expand All @@ -2115,12 +2116,12 @@ def _negate_in_binary(self, negated_op, original_op):
return self

def _with_binary_element_type(self, type_):
c = ClauseElement._clone(self)
c: Self = ClauseElement._clone(self) # type: ignore[assignment]
c.type = type_
return c

def _clone(self, maintain_key: bool = False, **kw: Any) -> Self:
c = ClauseElement._clone(self, **kw)
c: Self = ClauseElement._clone(self, **kw)
# ensure all the BindParameter objects stay in cloned set.
# in #7823, we changed "clone" so that a clone only keeps a reference
# to the "original" element, since for column correspondence, that's
Expand All @@ -2131,7 +2132,7 @@ def _clone(self, maintain_key: bool = False, **kw: Any) -> Self:
# forward.
c._cloned_set.update(self._cloned_set)
if not maintain_key and self.unique:
c.key = _anonymous_label.safe_construct(
c.key, c._anon_map_key = _anonymous_label.safe_construct_with_key(
id(c), c._orig_key or "param", sanitize_key=True
)
return c
Expand All @@ -2155,15 +2156,21 @@ def _gen_cache_key(self, anon_map, bindparams):
id_,
self.__class__,
self.type._static_cache_key,
self.key % anon_map if self._key_is_anon else self.key,
(
anon_map[self._anon_map_key]
if self._anon_map_key is not None
else self.key
),
self.literal_execute,
)

def _convert_to_unique(self):
if not self.unique:
self.unique = True
self.key = _anonymous_label.safe_construct(
id(self), self._orig_key or "param", sanitize_key=True
self.key, self._anon_map_key = (
_anonymous_label.safe_construct_with_key(
id(self), self._orig_key or "param", sanitize_key=True
)
)

def __getstate__(self):
Expand All @@ -2179,9 +2186,10 @@ def __getstate__(self):

def __setstate__(self, state):
if state.get("unique", False):
state["key"] = _anonymous_label.safe_construct(
anon_and_key = _anonymous_label.safe_construct_with_key(
id(self), state.get("_orig_key", "param"), sanitize_key=True
)
state["key"], state["_anon_map_key"] = anon_and_key
self.__dict__.update(state)

def __repr__(self):
Expand Down Expand Up @@ -4939,10 +4947,12 @@ def _gen_tq_label(
return None
elif t is not None and is_named_from_clause(t):
if has_schema_attr(t) and t.schema:
label = t.schema.replace(".", "_") + "_" + t.name + "_" + name
label = (
t.schema.replace(".", "_") + "_" + t.name + ("_" + name)
)
else:
assert not TYPE_CHECKING or isinstance(t, NamedFromClause)
label = t.name + "_" + name
label = t.name + ("_" + name)

# propagate name quoting rules for labels.
if is_quoted_name(name) and name.quote is not None:
Expand All @@ -4969,7 +4979,7 @@ def _gen_tq_label(
_label = label
counter = 1
while _label in t.c:
_label = label + "_" + str(counter)
_label = label + f"_{counter}"
counter += 1
label = _label

Expand Down Expand Up @@ -5370,6 +5380,7 @@ class conv(_truncated_label):
# _truncated_identifier() sequence in a custom
# compiler
_generated_label = _truncated_label
_anonymous_label_escape = re.compile(r"[%\(\) \$]+")


class _anonymous_label(_truncated_label):
Expand All @@ -5378,29 +5389,37 @@ class _anonymous_label(_truncated_label):

__slots__ = ()

@classmethod
def safe_construct_with_key(
cls, seed: int, body: str, sanitize_key: bool = False
) -> typing_Tuple[_anonymous_label, str]:
# need to escape chars that interfere with format
# strings in any case, issue #8724
body = _anonymous_label_escape.sub("_", body)

if sanitize_key:
# sanitize_key is then an extra step used by BindParameter
body = body.strip("_")

key = f"{seed} {body.replace('%', '%%')}"
label = _anonymous_label(f"%({key})s")
return label, key

@classmethod
def safe_construct(
cls,
seed: int,
body: str,
enclosing_label: Optional[str] = None,
sanitize_key: bool = False,
cls, seed: int, body: str, sanitize_key: bool = False
) -> _anonymous_label:
# need to escape chars that interfere with format
# strings in any case, issue #8724
body = re.sub(r"[%\(\) \$]+", "_", body)
body = _anonymous_label_escape.sub("_", body)

if sanitize_key:
# sanitize_key is then an extra step used by BindParameter
body = body.strip("_")

label = "%%(%d %s)s" % (seed, body.replace("%", "%%"))
if enclosing_label:
label = "%s%s" % (enclosing_label, label)
return _anonymous_label(f"%({seed} {body.replace('%', '%%')})s")

return _anonymous_label(label)

def __add__(self, other):
def __add__(self, other: str) -> _anonymous_label:
if "%" in other and not isinstance(other, _anonymous_label):
other = str(other).replace("%", "%%")
else:
Expand All @@ -5413,7 +5432,7 @@ def __add__(self, other):
)
)

def __radd__(self, other):
def __radd__(self, other: str) -> _anonymous_label:
if "%" in other and not isinstance(other, _anonymous_label):
other = str(other).replace("%", "%%")
else:
Expand All @@ -5426,7 +5445,7 @@ def __radd__(self, other):
)
)

def apply_map(self, map_):
def apply_map(self, map_: Mapping[str, Any]) -> str:
if self.quote is not None:
# preserve quoting only if necessary
return quoted_name(self % map_, self.quote)
Expand Down
8 changes: 8 additions & 0 deletions lib/sqlalchemy/util/_collections_cy.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# util/_collections_cy.pxd
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php

cdef unsigned long long _get_id(item: object)

0 comments on commit 83bd285

Please sign in to comment.