summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/__init__.py3
-rw-r--r--lib/sqlalchemy/sql/annotation.py9
-rw-r--r--lib/sqlalchemy/sql/base.py7
-rw-r--r--lib/sqlalchemy/sql/compiler.py54
-rw-r--r--lib/sqlalchemy/sql/ddl.py3
-rw-r--r--lib/sqlalchemy/sql/dml.py9
-rw-r--r--lib/sqlalchemy/sql/elements.py138
-rw-r--r--lib/sqlalchemy/sql/functions.py39
-rw-r--r--lib/sqlalchemy/sql/schema.py2
-rw-r--r--lib/sqlalchemy/sql/selectable.py32
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py24
-rw-r--r--lib/sqlalchemy/sql/traversals.py129
-rw-r--r--lib/sqlalchemy/sql/type_api.py7
-rw-r--r--lib/sqlalchemy/sql/visitors.py19
14 files changed, 338 insertions, 137 deletions
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
index 78de80734..a25c1b083 100644
--- a/lib/sqlalchemy/sql/__init__.py
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -100,6 +100,7 @@ def __go(lcls):
from .elements import AnnotatedColumnElement
from .elements import ClauseList # noqa
from .selectable import AnnotatedFromClause # noqa
+ from .traversals import _preconfigure_traversals
from . import base
from . import coercions
@@ -122,6 +123,8 @@ def __go(lcls):
_prepare_annotations(FromClause, AnnotatedFromClause)
_prepare_annotations(ClauseList, Annotated)
+ _preconfigure_traversals(ClauseElement)
+
_sa_util.preloaded.import_prefix("sqlalchemy.sql")
from . import naming # noqa
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index 08ed121d3..8a0d6ec28 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -338,6 +338,15 @@ def _new_annotation_type(cls, base_cls):
anno_cls._traverse_internals = list(cls._traverse_internals) + [
("_annotations", InternalTraversal.dp_annotations_key)
]
+ elif cls.__dict__.get("inherit_cache", False):
+ anno_cls._traverse_internals = list(cls._traverse_internals) + [
+ ("_annotations", InternalTraversal.dp_annotations_key)
+ ]
+
+ # some classes include this even if they have traverse_internals
+ # e.g. BindParameter, add it if present.
+ if cls.__dict__.get("inherit_cache", False):
+ anno_cls.inherit_cache = True
anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators)
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 5dd3b519a..5f2ce8f14 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -624,19 +624,14 @@ class Executable(Generative):
_bind = None
_with_options = ()
_with_context_options = ()
- _cache_enable = True
_executable_traverse_internals = [
("_with_options", ExtendedInternalTraversal.dp_has_cache_key_list),
("_with_context_options", ExtendedInternalTraversal.dp_plain_obj),
- ("_cache_enable", ExtendedInternalTraversal.dp_plain_obj),
+ ("_propagate_attrs", ExtendedInternalTraversal.dp_propagate_attrs),
]
@_generative
- def _disable_caching(self):
- self._cache_enable = HasCacheKey()
-
- @_generative
def options(self, *options):
"""Apply options to this statement.
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 2519438d1..61178291a 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -373,6 +373,8 @@ class Compiled(object):
_cached_metadata = None
+ _result_columns = None
+
schema_translate_map = None
execution_options = util.immutabledict()
@@ -433,7 +435,6 @@ class Compiled(object):
self,
dialect,
statement,
- bind=None,
schema_translate_map=None,
render_schema_translate=False,
compile_kwargs=util.immutabledict(),
@@ -463,7 +464,6 @@ class Compiled(object):
"""
self.dialect = dialect
- self.bind = bind
self.preparer = self.dialect.identifier_preparer
if schema_translate_map:
self.schema_translate_map = schema_translate_map
@@ -527,24 +527,6 @@ class Compiled(object):
"""Return the bind params for this compiled object."""
return self.construct_params()
- def execute(self, *multiparams, **params):
- """Execute this compiled object."""
-
- e = self.bind
- if e is None:
- raise exc.UnboundExecutionError(
- "This Compiled object is not bound to any Engine "
- "or Connection.",
- code="2afi",
- )
- return e._execute_compiled(self, multiparams, params)
-
- def scalar(self, *multiparams, **params):
- """Execute this compiled object and return the result's
- scalar value."""
-
- return self.execute(*multiparams, **params).scalar()
-
class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)):
"""Produces DDL specification for TypeEngine objects."""
@@ -687,6 +669,13 @@ class SQLCompiler(Compiled):
insert_prefetch = update_prefetch = ()
+ _cache_key_bind_match = None
+ """a mapping that will relate the BindParameter object we compile
+ to those that are part of the extracted collection of parameters
+ in the cache key, if we were given a cache key.
+
+ """
+
def __init__(
self,
dialect,
@@ -717,6 +706,9 @@ class SQLCompiler(Compiled):
self.cache_key = cache_key
+ if cache_key:
+ self._cache_key_bind_match = {b: b for b in cache_key[1]}
+
# compile INSERT/UPDATE defaults/sequences inlined (no pre-
# execute)
self.inline = inline or getattr(statement, "_inline", False)
@@ -875,8 +867,9 @@ class SQLCompiler(Compiled):
replace_context=err,
)
+ ckbm = self._cache_key_bind_match
resolved_extracted = {
- b.key: extracted
+ ckbm[b]: extracted
for b, extracted in zip(orig_extracted, extracted_parameters)
}
else:
@@ -907,7 +900,7 @@ class SQLCompiler(Compiled):
else:
if resolved_extracted:
value_param = resolved_extracted.get(
- bindparam.key, bindparam
+ bindparam, bindparam
)
else:
value_param = bindparam
@@ -936,9 +929,7 @@ class SQLCompiler(Compiled):
)
if resolved_extracted:
- value_param = resolved_extracted.get(
- bindparam.key, bindparam
- )
+ value_param = resolved_extracted.get(bindparam, bindparam)
else:
value_param = bindparam
@@ -2021,6 +2012,19 @@ class SQLCompiler(Compiled):
)
self.binds[bindparam.key] = self.binds[name] = bindparam
+
+ # if we are given a cache key that we're going to match against,
+ # relate the bindparam here to one that is most likely present
+ # in the "extracted params" portion of the cache key. this is used
+ # to set up a positional mapping that is used to determine the
+ # correct parameters for a subsequent use of this compiled with
+ # a different set of parameter values. here, we accommodate for
+ # parameters that may have been cloned both before and after the cache
+ # key was been generated.
+ ckbm = self._cache_key_bind_match
+ if ckbm:
+ ckbm.update({bp: bindparam for bp in bindparam._cloned_set})
+
if bindparam.isoutparam:
self.has_out_parameters = True
diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py
index 569030651..d3730b124 100644
--- a/lib/sqlalchemy/sql/ddl.py
+++ b/lib/sqlalchemy/sql/ddl.py
@@ -28,6 +28,9 @@ class _DDLCompiles(ClauseElement):
return dialect.ddl_compiler(dialect, self, **kw)
+ def _compile_w_cache(self, *arg, **kw):
+ raise NotImplementedError()
+
class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
"""Base class for DDL expression constructs.
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index a82641d77..50b2a935a 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -641,7 +641,7 @@ class ValuesBase(UpdateBase):
if self._preserve_parameter_order:
arg = [
(
- k,
+ coercions.expect(roles.DMLColumnRole, k),
coercions.expect(
roles.ExpressionElementRole,
v,
@@ -654,7 +654,7 @@ class ValuesBase(UpdateBase):
self._ordered_values = arg
else:
arg = {
- k: coercions.expect(
+ coercions.expect(roles.DMLColumnRole, k): coercions.expect(
roles.ExpressionElementRole,
v,
type_=NullType(),
@@ -772,6 +772,7 @@ class Insert(ValuesBase):
]
+ HasPrefixes._has_prefixes_traverse_internals
+ DialectKWArgs._dialect_kwargs_traverse_internals
+ + Executable._executable_traverse_internals
)
@ValuesBase._constructor_20_deprecations(
@@ -997,6 +998,7 @@ class Update(DMLWhereBase, ValuesBase):
]
+ HasPrefixes._has_prefixes_traverse_internals
+ DialectKWArgs._dialect_kwargs_traverse_internals
+ + Executable._executable_traverse_internals
)
@ValuesBase._constructor_20_deprecations(
@@ -1187,7 +1189,7 @@ class Update(DMLWhereBase, ValuesBase):
)
arg = [
(
- k,
+ coercions.expect(roles.DMLColumnRole, k),
coercions.expect(
roles.ExpressionElementRole,
v,
@@ -1238,6 +1240,7 @@ class Delete(DMLWhereBase, UpdateBase):
]
+ HasPrefixes._has_prefixes_traverse_internals
+ DialectKWArgs._dialect_kwargs_traverse_internals
+ + Executable._executable_traverse_internals
)
@ValuesBase._constructor_20_deprecations(
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 8e1b623a7..60c816ee6 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -381,6 +381,7 @@ class ClauseElement(
try:
traverse_internals = self._traverse_internals
except AttributeError:
+ # user-defined classes may not have a _traverse_internals
return
for attrname, obj, meth in _copy_internals.run_generated_dispatch(
@@ -410,6 +411,7 @@ class ClauseElement(
try:
traverse_internals = self._traverse_internals
except AttributeError:
+ # user-defined classes may not have a _traverse_internals
return []
return itertools.chain.from_iterable(
@@ -516,10 +518,62 @@ class ClauseElement(
dialect = bind.dialect
elif self.bind:
dialect = self.bind.dialect
- bind = self.bind
else:
dialect = default.StrCompileDialect()
- return self._compiler(dialect, bind=bind, **kw)
+
+ return self._compiler(dialect, **kw)
+
+ def _compile_w_cache(
+ self,
+ dialect,
+ compiled_cache=None,
+ column_keys=None,
+ inline=False,
+ schema_translate_map=None,
+ **kw
+ ):
+ if compiled_cache is not None:
+ elem_cache_key = self._generate_cache_key()
+ else:
+ elem_cache_key = None
+
+ cache_hit = False
+
+ if elem_cache_key:
+ cache_key, extracted_params = elem_cache_key
+ key = (
+ dialect,
+ cache_key,
+ tuple(column_keys),
+ bool(schema_translate_map),
+ inline,
+ )
+ compiled_sql = compiled_cache.get(key)
+
+ if compiled_sql is None:
+ compiled_sql = self._compiler(
+ dialect,
+ cache_key=elem_cache_key,
+ column_keys=column_keys,
+ inline=inline,
+ schema_translate_map=schema_translate_map,
+ **kw
+ )
+ compiled_cache[key] = compiled_sql
+ else:
+ cache_hit = True
+ else:
+ extracted_params = None
+ compiled_sql = self._compiler(
+ dialect,
+ cache_key=elem_cache_key,
+ column_keys=column_keys,
+ inline=inline,
+ schema_translate_map=schema_translate_map,
+ **kw
+ )
+
+ return compiled_sql, extracted_params, cache_hit
def _compiler(self, dialect, **kw):
"""Return a compiler appropriate for this ClauseElement, given a
@@ -1035,6 +1089,10 @@ class BindParameter(roles.InElementRole, ColumnElement):
_is_bind_parameter = True
_key_is_anon = False
+ # bindparam implements its own _gen_cache_key() method however
+ # we check subclasses for this flag, else no cache key is generated
+ inherit_cache = True
+
def __init__(
self,
key,
@@ -1396,6 +1454,13 @@ class BindParameter(roles.InElementRole, ColumnElement):
return c
def _gen_cache_key(self, anon_map, bindparams):
+ _gen_cache_ok = self.__class__.__dict__.get("inherit_cache", False)
+
+ if not _gen_cache_ok:
+ if anon_map is not None:
+ anon_map[NO_CACHE] = True
+ return None
+
idself = id(self)
if idself in anon_map:
return (anon_map[idself], self.__class__)
@@ -2082,6 +2147,7 @@ class ClauseList(
roles.InElementRole,
roles.OrderByRole,
roles.ColumnsClauseRole,
+ roles.DMLColumnRole,
ClauseElement,
):
"""Describe a list of clauses, separated by an operator.
@@ -2174,6 +2240,7 @@ class ClauseList(
class BooleanClauseList(ClauseList, ColumnElement):
__visit_name__ = "clauselist"
+ inherit_cache = True
_tuple_values = False
@@ -3428,6 +3495,8 @@ class CollectionAggregate(UnaryExpression):
class AsBoolean(WrapsColumnExpression, UnaryExpression):
+ inherit_cache = True
+
def __init__(self, element, operator, negate):
self.element = element
self.type = type_api.BOOLEANTYPE
@@ -3474,6 +3543,7 @@ class BinaryExpression(ColumnElement):
("operator", InternalTraversal.dp_operator),
("negate", InternalTraversal.dp_operator),
("modifiers", InternalTraversal.dp_plain_dict),
+ ("type", InternalTraversal.dp_type,), # affects JSON CAST operators
]
_is_implicitly_boolean = True
@@ -3482,41 +3552,6 @@ class BinaryExpression(ColumnElement):
"""
- def _gen_cache_key(self, anon_map, bindparams):
- # inlined for performance
-
- idself = id(self)
-
- if idself in anon_map:
- return (anon_map[idself], self.__class__)
- else:
- # inline of
- # id_ = anon_map[idself]
- anon_map[idself] = id_ = str(anon_map.index)
- anon_map.index += 1
-
- if self._cache_key_traversal is NO_CACHE:
- anon_map[NO_CACHE] = True
- return None
-
- result = (id_, self.__class__)
-
- return result + (
- ("left", self.left._gen_cache_key(anon_map, bindparams)),
- ("right", self.right._gen_cache_key(anon_map, bindparams)),
- ("operator", self.operator),
- ("negate", self.negate),
- (
- "modifiers",
- tuple(
- (key, self.modifiers[key])
- for key in sorted(self.modifiers)
- )
- if self.modifiers
- else None,
- ),
- )
-
def __init__(
self, left, right, operator, type_=None, negate=None, modifiers=None
):
@@ -3587,15 +3622,30 @@ class Slice(ColumnElement):
__visit_name__ = "slice"
_traverse_internals = [
- ("start", InternalTraversal.dp_plain_obj),
- ("stop", InternalTraversal.dp_plain_obj),
- ("step", InternalTraversal.dp_plain_obj),
+ ("start", InternalTraversal.dp_clauseelement),
+ ("stop", InternalTraversal.dp_clauseelement),
+ ("step", InternalTraversal.dp_clauseelement),
]
- def __init__(self, start, stop, step):
- self.start = start
- self.stop = stop
- self.step = step
+ def __init__(self, start, stop, step, _name=None):
+ self.start = coercions.expect(
+ roles.ExpressionElementRole,
+ start,
+ name=_name,
+ type_=type_api.INTEGERTYPE,
+ )
+ self.stop = coercions.expect(
+ roles.ExpressionElementRole,
+ stop,
+ name=_name,
+ type_=type_api.INTEGERTYPE,
+ )
+ self.step = coercions.expect(
+ roles.ExpressionElementRole,
+ step,
+ name=_name,
+ type_=type_api.INTEGERTYPE,
+ )
self.type = type_api.NULLTYPE
def self_group(self, against=None):
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 6b1172eba..7b723f371 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -744,6 +744,7 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)):
coerce_arguments = True
_register = False
+ inherit_cache = True
def __init__(self, *args, **kwargs):
parsed_args = kwargs.pop("_parsed_args", None)
@@ -808,6 +809,8 @@ class next_value(GenericFunction):
class AnsiFunction(GenericFunction):
+ inherit_cache = True
+
def __init__(self, *args, **kwargs):
GenericFunction.__init__(self, *args, **kwargs)
@@ -815,6 +818,8 @@ class AnsiFunction(GenericFunction):
class ReturnTypeFromArgs(GenericFunction):
"""Define a function whose return type is the same as its arguments."""
+ inherit_cache = True
+
def __init__(self, *args, **kwargs):
args = [
coercions.expect(
@@ -832,30 +837,34 @@ class ReturnTypeFromArgs(GenericFunction):
class coalesce(ReturnTypeFromArgs):
_has_args = True
+ inherit_cache = True
class max(ReturnTypeFromArgs): # noqa
- pass
+ inherit_cache = True
class min(ReturnTypeFromArgs): # noqa
- pass
+ inherit_cache = True
class sum(ReturnTypeFromArgs): # noqa
- pass
+ inherit_cache = True
class now(GenericFunction): # noqa
type = sqltypes.DateTime
+ inherit_cache = True
class concat(GenericFunction):
type = sqltypes.String
+ inherit_cache = True
class char_length(GenericFunction):
type = sqltypes.Integer
+ inherit_cache = True
def __init__(self, arg, **kwargs):
GenericFunction.__init__(self, arg, **kwargs)
@@ -863,6 +872,7 @@ class char_length(GenericFunction):
class random(GenericFunction):
_has_args = True
+ inherit_cache = True
class count(GenericFunction):
@@ -887,6 +897,7 @@ class count(GenericFunction):
"""
type = sqltypes.Integer
+ inherit_cache = True
def __init__(self, expression=None, **kwargs):
if expression is None:
@@ -896,38 +907,47 @@ class count(GenericFunction):
class current_date(AnsiFunction):
type = sqltypes.Date
+ inherit_cache = True
class current_time(AnsiFunction):
type = sqltypes.Time
+ inherit_cache = True
class current_timestamp(AnsiFunction):
type = sqltypes.DateTime
+ inherit_cache = True
class current_user(AnsiFunction):
type = sqltypes.String
+ inherit_cache = True
class localtime(AnsiFunction):
type = sqltypes.DateTime
+ inherit_cache = True
class localtimestamp(AnsiFunction):
type = sqltypes.DateTime
+ inherit_cache = True
class session_user(AnsiFunction):
type = sqltypes.String
+ inherit_cache = True
class sysdate(AnsiFunction):
type = sqltypes.DateTime
+ inherit_cache = True
class user(AnsiFunction):
type = sqltypes.String
+ inherit_cache = True
class array_agg(GenericFunction):
@@ -951,6 +971,7 @@ class array_agg(GenericFunction):
"""
type = sqltypes.ARRAY
+ inherit_cache = True
def __init__(self, *args, **kwargs):
args = [
@@ -978,6 +999,7 @@ class OrderedSetAgg(GenericFunction):
:meth:`.FunctionElement.within_group` method."""
array_for_multi_clause = False
+ inherit_cache = True
def within_group_type(self, within_group):
func_clauses = self.clause_expr.element
@@ -1000,6 +1022,8 @@ class mode(OrderedSetAgg):
"""
+ inherit_cache = True
+
class percentile_cont(OrderedSetAgg):
"""implement the ``percentile_cont`` ordered-set aggregate function.
@@ -1016,6 +1040,7 @@ class percentile_cont(OrderedSetAgg):
"""
array_for_multi_clause = True
+ inherit_cache = True
class percentile_disc(OrderedSetAgg):
@@ -1033,6 +1058,7 @@ class percentile_disc(OrderedSetAgg):
"""
array_for_multi_clause = True
+ inherit_cache = True
class rank(GenericFunction):
@@ -1048,6 +1074,7 @@ class rank(GenericFunction):
"""
type = sqltypes.Integer()
+ inherit_cache = True
class dense_rank(GenericFunction):
@@ -1063,6 +1090,7 @@ class dense_rank(GenericFunction):
"""
type = sqltypes.Integer()
+ inherit_cache = True
class percent_rank(GenericFunction):
@@ -1078,6 +1106,7 @@ class percent_rank(GenericFunction):
"""
type = sqltypes.Numeric()
+ inherit_cache = True
class cume_dist(GenericFunction):
@@ -1093,6 +1122,7 @@ class cume_dist(GenericFunction):
"""
type = sqltypes.Numeric()
+ inherit_cache = True
class cube(GenericFunction):
@@ -1109,6 +1139,7 @@ class cube(GenericFunction):
"""
_has_args = True
+ inherit_cache = True
class rollup(GenericFunction):
@@ -1125,6 +1156,7 @@ class rollup(GenericFunction):
"""
_has_args = True
+ inherit_cache = True
class grouping_sets(GenericFunction):
@@ -1158,3 +1190,4 @@ class grouping_sets(GenericFunction):
"""
_has_args = True
+ inherit_cache = True
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index ee411174c..29ca81d26 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -1013,6 +1013,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
__visit_name__ = "column"
+ inherit_cache = True
+
def __init__(self, *args, **kwargs):
r"""
Construct a new ``Column`` object.
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index a95fc561a..54f293967 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -61,6 +61,8 @@ if util.TYPE_CHECKING:
class _OffsetLimitParam(BindParameter):
+ inherit_cache = True
+
@property
def _limit_offset_value(self):
return self.effective_value
@@ -1426,6 +1428,8 @@ class Alias(roles.DMLTableRole, AliasedReturnsRows):
__visit_name__ = "alias"
+ inherit_cache = True
+
@classmethod
def _factory(cls, selectable, name=None, flat=False):
"""Return an :class:`_expression.Alias` object.
@@ -1500,6 +1504,8 @@ class Lateral(AliasedReturnsRows):
__visit_name__ = "lateral"
_is_lateral = True
+ inherit_cache = True
+
@classmethod
def _factory(cls, selectable, name=None):
"""Return a :class:`_expression.Lateral` object.
@@ -1626,7 +1632,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows):
AliasedReturnsRows._traverse_internals
+ [
("_cte_alias", InternalTraversal.dp_clauseelement),
- ("_restates", InternalTraversal.dp_clauseelement_unordered_set),
+ ("_restates", InternalTraversal.dp_clauseelement_list),
("recursive", InternalTraversal.dp_boolean),
]
+ HasPrefixes._has_prefixes_traverse_internals
@@ -1651,7 +1657,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows):
name=None,
recursive=False,
_cte_alias=None,
- _restates=frozenset(),
+ _restates=(),
_prefixes=None,
_suffixes=None,
):
@@ -1692,7 +1698,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows):
self.element.union(other),
name=self.name,
recursive=self.recursive,
- _restates=self._restates.union([self]),
+ _restates=self._restates + (self,),
_prefixes=self._prefixes,
_suffixes=self._suffixes,
)
@@ -1702,7 +1708,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows):
self.element.union_all(other),
name=self.name,
recursive=self.recursive,
- _restates=self._restates.union([self]),
+ _restates=self._restates + (self,),
_prefixes=self._prefixes,
_suffixes=self._suffixes,
)
@@ -1918,6 +1924,8 @@ class Subquery(AliasedReturnsRows):
_is_subquery = True
+ inherit_cache = True
+
@classmethod
def _factory(cls, selectable, name=None):
"""Return a :class:`.Subquery` object.
@@ -3783,15 +3791,15 @@ class Select(
("_group_by_clauses", InternalTraversal.dp_clauseelement_list,),
("_setup_joins", InternalTraversal.dp_setup_join_tuple,),
("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple,),
- ("_correlate", InternalTraversal.dp_clauseelement_unordered_set),
- (
- "_correlate_except",
- InternalTraversal.dp_clauseelement_unordered_set,
- ),
+ ("_correlate", InternalTraversal.dp_clauseelement_list),
+ ("_correlate_except", InternalTraversal.dp_clauseelement_list,),
+ ("_limit_clause", InternalTraversal.dp_clauseelement),
+ ("_offset_clause", InternalTraversal.dp_clauseelement),
("_for_update_arg", InternalTraversal.dp_clauseelement),
("_distinct", InternalTraversal.dp_boolean),
("_distinct_on", InternalTraversal.dp_clauseelement_list),
("_label_style", InternalTraversal.dp_plain_obj),
+ ("_is_future", InternalTraversal.dp_boolean),
]
+ HasPrefixes._has_prefixes_traverse_internals
+ HasSuffixes._has_suffixes_traverse_internals
@@ -4522,7 +4530,7 @@ class Select(
if fromclauses and fromclauses[0] is None:
self._correlate = ()
else:
- self._correlate = set(self._correlate).union(
+ self._correlate = self._correlate + tuple(
coercions.expect(roles.FromClauseRole, f) for f in fromclauses
)
@@ -4560,7 +4568,7 @@ class Select(
if fromclauses and fromclauses[0] is None:
self._correlate_except = ()
else:
- self._correlate_except = set(self._correlate_except or ()).union(
+ self._correlate_except = (self._correlate_except or ()) + tuple(
coercions.expect(roles.FromClauseRole, f) for f in fromclauses
)
@@ -4866,6 +4874,7 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping):
_from_objects = []
_is_from_container = True
_is_implicitly_boolean = False
+ inherit_cache = True
def __init__(self, element):
self.element = element
@@ -4899,6 +4908,7 @@ class Exists(UnaryExpression):
"""
_from_objects = []
+ inherit_cache = True
def __init__(self, *args, **kwargs):
"""Construct a new :class:`_expression.Exists` against an existing
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 732b775f6..9cd9d5058 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -2616,26 +2616,10 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
return_type = self.type
if self.type.zero_indexes:
index = slice(index.start + 1, index.stop + 1, index.step)
- index = Slice(
- coercions.expect(
- roles.ExpressionElementRole,
- index.start,
- name=self.expr.key,
- type_=type_api.INTEGERTYPE,
- ),
- coercions.expect(
- roles.ExpressionElementRole,
- index.stop,
- name=self.expr.key,
- type_=type_api.INTEGERTYPE,
- ),
- coercions.expect(
- roles.ExpressionElementRole,
- index.step,
- name=self.expr.key,
- type_=type_api.INTEGERTYPE,
- ),
+ slice_ = Slice(
+ index.start, index.stop, index.step, _name=self.expr.key
)
+ return operators.getitem, slice_, return_type
else:
if self.type.zero_indexes:
index += 1
@@ -2647,7 +2631,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
self.type.__class__, **adapt_kw
)
- return operators.getitem, index, return_type
+ return operators.getitem, index, return_type
def contains(self, *arg, **kw):
raise NotImplementedError(
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 68281f33d..ed0bfa27a 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -19,6 +19,7 @@ NO_CACHE = util.symbol("no_cache")
CACHE_IN_PLACE = util.symbol("cache_in_place")
CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key")
STATIC_CACHE_KEY = util.symbol("static_cache_key")
+PROPAGATE_ATTRS = util.symbol("propagate_attrs")
ANON_NAME = util.symbol("anon_name")
@@ -31,10 +32,74 @@ def compare(obj1, obj2, **kw):
return strategy.compare(obj1, obj2, **kw)
+def _preconfigure_traversals(target_hierarchy):
+
+ stack = [target_hierarchy]
+ while stack:
+ cls = stack.pop()
+ stack.extend(cls.__subclasses__())
+
+ if hasattr(cls, "_traverse_internals"):
+ cls._generate_cache_attrs()
+ _copy_internals.generate_dispatch(
+ cls,
+ cls._traverse_internals,
+ "_generated_copy_internals_traversal",
+ )
+ _get_children.generate_dispatch(
+ cls,
+ cls._traverse_internals,
+ "_generated_get_children_traversal",
+ )
+
+
class HasCacheKey(object):
_cache_key_traversal = NO_CACHE
__slots__ = ()
+ @classmethod
+ def _generate_cache_attrs(cls):
+ """generate cache key dispatcher for a new class.
+
+ This sets the _generated_cache_key_traversal attribute once called
+ so should only be called once per class.
+
+ """
+ inherit = cls.__dict__.get("inherit_cache", False)
+
+ if inherit:
+ _cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
+ if _cache_key_traversal is None:
+ try:
+ _cache_key_traversal = cls._traverse_internals
+ except AttributeError:
+ cls._generated_cache_key_traversal = NO_CACHE
+ return NO_CACHE
+
+ # TODO: wouldn't we instead get this from our superclass?
+ # also, our superclass may not have this yet, but in any case,
+ # we'd generate for the superclass that has it. this is a little
+ # more complicated, so for the moment this is a little less
+ # efficient on startup but simpler.
+ return _cache_key_traversal_visitor.generate_dispatch(
+ cls, _cache_key_traversal, "_generated_cache_key_traversal"
+ )
+ else:
+ _cache_key_traversal = cls.__dict__.get(
+ "_cache_key_traversal", None
+ )
+ if _cache_key_traversal is None:
+ _cache_key_traversal = cls.__dict__.get(
+ "_traverse_internals", None
+ )
+ if _cache_key_traversal is None:
+ cls._generated_cache_key_traversal = NO_CACHE
+ return NO_CACHE
+
+ return _cache_key_traversal_visitor.generate_dispatch(
+ cls, _cache_key_traversal, "_generated_cache_key_traversal"
+ )
+
@util.preload_module("sqlalchemy.sql.elements")
def _gen_cache_key(self, anon_map, bindparams):
"""return an optional cache key.
@@ -72,14 +137,18 @@ class HasCacheKey(object):
else:
id_ = None
- _cache_key_traversal = self._cache_key_traversal
- if _cache_key_traversal is None:
- try:
- _cache_key_traversal = self._traverse_internals
- except AttributeError:
- _cache_key_traversal = NO_CACHE
+ try:
+ dispatcher = self.__class__.__dict__[
+ "_generated_cache_key_traversal"
+ ]
+ except KeyError:
+ # most of the dispatchers are generated up front
+ # in sqlalchemy/sql/__init__.py ->
+ # traversals.py-> _preconfigure_traversals().
+ # this block will generate any remaining dispatchers.
+ dispatcher = self.__class__._generate_cache_attrs()
- if _cache_key_traversal is NO_CACHE:
+ if dispatcher is NO_CACHE:
if anon_map is not None:
anon_map[NO_CACHE] = True
return None
@@ -87,19 +156,13 @@ class HasCacheKey(object):
result = (id_, self.__class__)
# inline of _cache_key_traversal_visitor.run_generated_dispatch()
- try:
- dispatcher = self.__class__.__dict__[
- "_generated_cache_key_traversal"
- ]
- except KeyError:
- dispatcher = _cache_key_traversal_visitor.generate_dispatch(
- self, _cache_key_traversal, "_generated_cache_key_traversal"
- )
for attrname, obj, meth in dispatcher(
self, _cache_key_traversal_visitor
):
if obj is not None:
+ # TODO: see if C code can help here as Python lacks an
+ # efficient switch construct
if meth is CACHE_IN_PLACE:
# cache in place is always going to be a Python
# tuple, dict, list, etc. so we can do a boolean check
@@ -116,6 +179,15 @@ class HasCacheKey(object):
attrname,
obj._gen_cache_key(anon_map, bindparams),
)
+ elif meth is PROPAGATE_ATTRS:
+ if obj:
+ result += (
+ attrname,
+ obj["compile_state_plugin"],
+ obj["plugin_subject"]._gen_cache_key(
+ anon_map, bindparams
+ ),
+ )
elif meth is InternalTraversal.dp_annotations_key:
# obj is here is the _annotations dict. however,
# we want to use the memoized cache key version of it.
@@ -332,6 +404,8 @@ class _CacheKey(ExtendedInternalTraversal):
visit_type = STATIC_CACHE_KEY
visit_anon_name = ANON_NAME
+ visit_propagate_attrs = PROPAGATE_ATTRS
+
def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
@@ -445,10 +519,16 @@ class _CacheKey(ExtendedInternalTraversal):
def visit_setup_join_tuple(
self, attrname, obj, parent, anon_map, bindparams
):
+ is_legacy = "legacy" in attrname
+
return tuple(
(
- target._gen_cache_key(anon_map, bindparams),
- onclause._gen_cache_key(anon_map, bindparams)
+ target
+ if is_legacy and isinstance(target, str)
+ else target._gen_cache_key(anon_map, bindparams),
+ onclause
+ if is_legacy and isinstance(onclause, str)
+ else onclause._gen_cache_key(anon_map, bindparams)
if onclause is not None
else None,
from_._gen_cache_key(anon_map, bindparams)
@@ -711,6 +791,11 @@ class _CopyInternals(InternalTraversal):
for sequence in element
]
+ def visit_propagate_attrs(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return element
+
_copy_internals = _CopyInternals()
@@ -782,6 +867,9 @@ class _GetChildren(InternalTraversal):
def visit_dml_multi_values(self, element, **kw):
return ()
+ def visit_propagate_attrs(self, element, **kw):
+ return ()
+
_get_children = _GetChildren()
@@ -916,6 +1004,13 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
):
return COMPARE_FAILED
+ def visit_propagate_attrs(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return self.compare_inner(
+ left.get("plugin_subject", None), right.get("plugin_subject", None)
+ )
+
def visit_has_cache_key_list(
self, attrname, left_parent, left, right_parent, right, **kw
):
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index ccda21e11..fe3634bad 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -555,7 +555,12 @@ class TypeEngine(Traversible):
def _static_cache_key(self):
names = util.get_cls_kwargs(self.__class__)
return (self.__class__,) + tuple(
- (k, self.__dict__[k])
+ (
+ k,
+ self.__dict__[k]._static_cache_key
+ if isinstance(self.__dict__[k], TypeEngine)
+ else self.__dict__[k],
+ )
for k in names
if k in self.__dict__ and not k.startswith("_")
)
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 5de68f504..904702003 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -217,18 +217,23 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
try:
dispatcher = target.__class__.__dict__[generate_dispatcher_name]
except KeyError:
+ # most of the dispatchers are generated up front
+ # in sqlalchemy/sql/__init__.py ->
+ # traversals.py-> _preconfigure_traversals().
+ # this block will generate any remaining dispatchers.
dispatcher = self.generate_dispatch(
- target, internal_dispatch, generate_dispatcher_name
+ target.__class__, internal_dispatch, generate_dispatcher_name
)
return dispatcher(target, self)
def generate_dispatch(
- self, target, internal_dispatch, generate_dispatcher_name
+ self, target_cls, internal_dispatch, generate_dispatcher_name
):
dispatcher = _generate_dispatcher(
self, internal_dispatch, generate_dispatcher_name
)
- setattr(target.__class__, generate_dispatcher_name, dispatcher)
+ # assert isinstance(target_cls, type)
+ setattr(target_cls, generate_dispatcher_name, dispatcher)
return dispatcher
dp_has_cache_key = symbol("HC")
@@ -263,10 +268,6 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
"""
- dp_clauseelement_unordered_set = symbol("CU")
- """Visit an unordered set of :class:`_expression.ClauseElement`
- objects. """
-
dp_fromclause_ordered_set = symbol("CO")
"""Visit an ordered set of :class:`_expression.FromClause` objects. """
@@ -414,6 +415,10 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
"""
+ dp_propagate_attrs = symbol("PA")
+ """Visit the propagate attrs dict. this hardcodes to the particular
+ elements we care about right now."""
+
class ExtendedInternalTraversal(InternalTraversal):
"""defines additional symbols that are useful in caching applications.