diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/__init__.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/annotation.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 54 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/ddl.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/dml.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 138 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/functions.py | 39 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 32 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 24 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 129 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/type_api.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 19 |
14 files changed, 338 insertions, 137 deletions
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 78de80734..a25c1b083 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -100,6 +100,7 @@ def __go(lcls): from .elements import AnnotatedColumnElement from .elements import ClauseList # noqa from .selectable import AnnotatedFromClause # noqa + from .traversals import _preconfigure_traversals from . import base from . import coercions @@ -122,6 +123,8 @@ def __go(lcls): _prepare_annotations(FromClause, AnnotatedFromClause) _prepare_annotations(ClauseList, Annotated) + _preconfigure_traversals(ClauseElement) + _sa_util.preloaded.import_prefix("sqlalchemy.sql") from . import naming # noqa diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 08ed121d3..8a0d6ec28 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -338,6 +338,15 @@ def _new_annotation_type(cls, base_cls): anno_cls._traverse_internals = list(cls._traverse_internals) + [ ("_annotations", InternalTraversal.dp_annotations_key) ] + elif cls.__dict__.get("inherit_cache", False): + anno_cls._traverse_internals = list(cls._traverse_internals) + [ + ("_annotations", InternalTraversal.dp_annotations_key) + ] + + # some classes include this even if they have traverse_internals + # e.g. BindParameter, add it if present. + if cls.__dict__.get("inherit_cache", False): + anno_cls.inherit_cache = True anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 5dd3b519a..5f2ce8f14 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -624,19 +624,14 @@ class Executable(Generative): _bind = None _with_options = () _with_context_options = () - _cache_enable = True _executable_traverse_internals = [ ("_with_options", ExtendedInternalTraversal.dp_has_cache_key_list), ("_with_context_options", ExtendedInternalTraversal.dp_plain_obj), - ("_cache_enable", ExtendedInternalTraversal.dp_plain_obj), + ("_propagate_attrs", ExtendedInternalTraversal.dp_propagate_attrs), ] @_generative - def _disable_caching(self): - self._cache_enable = HasCacheKey() - - @_generative def options(self, *options): """Apply options to this statement. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 2519438d1..61178291a 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -373,6 +373,8 @@ class Compiled(object): _cached_metadata = None + _result_columns = None + schema_translate_map = None execution_options = util.immutabledict() @@ -433,7 +435,6 @@ class Compiled(object): self, dialect, statement, - bind=None, schema_translate_map=None, render_schema_translate=False, compile_kwargs=util.immutabledict(), @@ -463,7 +464,6 @@ class Compiled(object): """ self.dialect = dialect - self.bind = bind self.preparer = self.dialect.identifier_preparer if schema_translate_map: self.schema_translate_map = schema_translate_map @@ -527,24 +527,6 @@ class Compiled(object): """Return the bind params for this compiled object.""" return self.construct_params() - def execute(self, *multiparams, **params): - """Execute this compiled object.""" - - e = self.bind - if e is None: - raise exc.UnboundExecutionError( - "This Compiled object is not bound to any Engine " - "or Connection.", - code="2afi", - ) - return e._execute_compiled(self, multiparams, params) - - def scalar(self, *multiparams, **params): - """Execute this compiled object and return the result's - scalar value.""" - - return self.execute(*multiparams, **params).scalar() - class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)): """Produces DDL specification for TypeEngine objects.""" @@ -687,6 +669,13 @@ class SQLCompiler(Compiled): insert_prefetch = update_prefetch = () + _cache_key_bind_match = None + """a mapping that will relate the BindParameter object we compile + to those that are part of the extracted collection of parameters + in the cache key, if we were given a cache key. + + """ + def __init__( self, dialect, @@ -717,6 +706,9 @@ class SQLCompiler(Compiled): self.cache_key = cache_key + if cache_key: + self._cache_key_bind_match = {b: b for b in cache_key[1]} + # compile INSERT/UPDATE defaults/sequences inlined (no pre- # execute) self.inline = inline or getattr(statement, "_inline", False) @@ -875,8 +867,9 @@ class SQLCompiler(Compiled): replace_context=err, ) + ckbm = self._cache_key_bind_match resolved_extracted = { - b.key: extracted + ckbm[b]: extracted for b, extracted in zip(orig_extracted, extracted_parameters) } else: @@ -907,7 +900,7 @@ class SQLCompiler(Compiled): else: if resolved_extracted: value_param = resolved_extracted.get( - bindparam.key, bindparam + bindparam, bindparam ) else: value_param = bindparam @@ -936,9 +929,7 @@ class SQLCompiler(Compiled): ) if resolved_extracted: - value_param = resolved_extracted.get( - bindparam.key, bindparam - ) + value_param = resolved_extracted.get(bindparam, bindparam) else: value_param = bindparam @@ -2021,6 +2012,19 @@ class SQLCompiler(Compiled): ) self.binds[bindparam.key] = self.binds[name] = bindparam + + # if we are given a cache key that we're going to match against, + # relate the bindparam here to one that is most likely present + # in the "extracted params" portion of the cache key. this is used + # to set up a positional mapping that is used to determine the + # correct parameters for a subsequent use of this compiled with + # a different set of parameter values. here, we accommodate for + # parameters that may have been cloned both before and after the cache + # key was been generated. + ckbm = self._cache_key_bind_match + if ckbm: + ckbm.update({bp: bindparam for bp in bindparam._cloned_set}) + if bindparam.isoutparam: self.has_out_parameters = True diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 569030651..d3730b124 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -28,6 +28,9 @@ class _DDLCompiles(ClauseElement): return dialect.ddl_compiler(dialect, self, **kw) + def _compile_w_cache(self, *arg, **kw): + raise NotImplementedError() + class DDLElement(roles.DDLRole, Executable, _DDLCompiles): """Base class for DDL expression constructs. diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index a82641d77..50b2a935a 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -641,7 +641,7 @@ class ValuesBase(UpdateBase): if self._preserve_parameter_order: arg = [ ( - k, + coercions.expect(roles.DMLColumnRole, k), coercions.expect( roles.ExpressionElementRole, v, @@ -654,7 +654,7 @@ class ValuesBase(UpdateBase): self._ordered_values = arg else: arg = { - k: coercions.expect( + coercions.expect(roles.DMLColumnRole, k): coercions.expect( roles.ExpressionElementRole, v, type_=NullType(), @@ -772,6 +772,7 @@ class Insert(ValuesBase): ] + HasPrefixes._has_prefixes_traverse_internals + DialectKWArgs._dialect_kwargs_traverse_internals + + Executable._executable_traverse_internals ) @ValuesBase._constructor_20_deprecations( @@ -997,6 +998,7 @@ class Update(DMLWhereBase, ValuesBase): ] + HasPrefixes._has_prefixes_traverse_internals + DialectKWArgs._dialect_kwargs_traverse_internals + + Executable._executable_traverse_internals ) @ValuesBase._constructor_20_deprecations( @@ -1187,7 +1189,7 @@ class Update(DMLWhereBase, ValuesBase): ) arg = [ ( - k, + coercions.expect(roles.DMLColumnRole, k), coercions.expect( roles.ExpressionElementRole, v, @@ -1238,6 +1240,7 @@ class Delete(DMLWhereBase, UpdateBase): ] + HasPrefixes._has_prefixes_traverse_internals + DialectKWArgs._dialect_kwargs_traverse_internals + + Executable._executable_traverse_internals ) @ValuesBase._constructor_20_deprecations( diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 8e1b623a7..60c816ee6 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -381,6 +381,7 @@ class ClauseElement( try: traverse_internals = self._traverse_internals except AttributeError: + # user-defined classes may not have a _traverse_internals return for attrname, obj, meth in _copy_internals.run_generated_dispatch( @@ -410,6 +411,7 @@ class ClauseElement( try: traverse_internals = self._traverse_internals except AttributeError: + # user-defined classes may not have a _traverse_internals return [] return itertools.chain.from_iterable( @@ -516,10 +518,62 @@ class ClauseElement( dialect = bind.dialect elif self.bind: dialect = self.bind.dialect - bind = self.bind else: dialect = default.StrCompileDialect() - return self._compiler(dialect, bind=bind, **kw) + + return self._compiler(dialect, **kw) + + def _compile_w_cache( + self, + dialect, + compiled_cache=None, + column_keys=None, + inline=False, + schema_translate_map=None, + **kw + ): + if compiled_cache is not None: + elem_cache_key = self._generate_cache_key() + else: + elem_cache_key = None + + cache_hit = False + + if elem_cache_key: + cache_key, extracted_params = elem_cache_key + key = ( + dialect, + cache_key, + tuple(column_keys), + bool(schema_translate_map), + inline, + ) + compiled_sql = compiled_cache.get(key) + + if compiled_sql is None: + compiled_sql = self._compiler( + dialect, + cache_key=elem_cache_key, + column_keys=column_keys, + inline=inline, + schema_translate_map=schema_translate_map, + **kw + ) + compiled_cache[key] = compiled_sql + else: + cache_hit = True + else: + extracted_params = None + compiled_sql = self._compiler( + dialect, + cache_key=elem_cache_key, + column_keys=column_keys, + inline=inline, + schema_translate_map=schema_translate_map, + **kw + ) + + return compiled_sql, extracted_params, cache_hit def _compiler(self, dialect, **kw): """Return a compiler appropriate for this ClauseElement, given a @@ -1035,6 +1089,10 @@ class BindParameter(roles.InElementRole, ColumnElement): _is_bind_parameter = True _key_is_anon = False + # bindparam implements its own _gen_cache_key() method however + # we check subclasses for this flag, else no cache key is generated + inherit_cache = True + def __init__( self, key, @@ -1396,6 +1454,13 @@ class BindParameter(roles.InElementRole, ColumnElement): return c def _gen_cache_key(self, anon_map, bindparams): + _gen_cache_ok = self.__class__.__dict__.get("inherit_cache", False) + + if not _gen_cache_ok: + if anon_map is not None: + anon_map[NO_CACHE] = True + return None + idself = id(self) if idself in anon_map: return (anon_map[idself], self.__class__) @@ -2082,6 +2147,7 @@ class ClauseList( roles.InElementRole, roles.OrderByRole, roles.ColumnsClauseRole, + roles.DMLColumnRole, ClauseElement, ): """Describe a list of clauses, separated by an operator. @@ -2174,6 +2240,7 @@ class ClauseList( class BooleanClauseList(ClauseList, ColumnElement): __visit_name__ = "clauselist" + inherit_cache = True _tuple_values = False @@ -3428,6 +3495,8 @@ class CollectionAggregate(UnaryExpression): class AsBoolean(WrapsColumnExpression, UnaryExpression): + inherit_cache = True + def __init__(self, element, operator, negate): self.element = element self.type = type_api.BOOLEANTYPE @@ -3474,6 +3543,7 @@ class BinaryExpression(ColumnElement): ("operator", InternalTraversal.dp_operator), ("negate", InternalTraversal.dp_operator), ("modifiers", InternalTraversal.dp_plain_dict), + ("type", InternalTraversal.dp_type,), # affects JSON CAST operators ] _is_implicitly_boolean = True @@ -3482,41 +3552,6 @@ class BinaryExpression(ColumnElement): """ - def _gen_cache_key(self, anon_map, bindparams): - # inlined for performance - - idself = id(self) - - if idself in anon_map: - return (anon_map[idself], self.__class__) - else: - # inline of - # id_ = anon_map[idself] - anon_map[idself] = id_ = str(anon_map.index) - anon_map.index += 1 - - if self._cache_key_traversal is NO_CACHE: - anon_map[NO_CACHE] = True - return None - - result = (id_, self.__class__) - - return result + ( - ("left", self.left._gen_cache_key(anon_map, bindparams)), - ("right", self.right._gen_cache_key(anon_map, bindparams)), - ("operator", self.operator), - ("negate", self.negate), - ( - "modifiers", - tuple( - (key, self.modifiers[key]) - for key in sorted(self.modifiers) - ) - if self.modifiers - else None, - ), - ) - def __init__( self, left, right, operator, type_=None, negate=None, modifiers=None ): @@ -3587,15 +3622,30 @@ class Slice(ColumnElement): __visit_name__ = "slice" _traverse_internals = [ - ("start", InternalTraversal.dp_plain_obj), - ("stop", InternalTraversal.dp_plain_obj), - ("step", InternalTraversal.dp_plain_obj), + ("start", InternalTraversal.dp_clauseelement), + ("stop", InternalTraversal.dp_clauseelement), + ("step", InternalTraversal.dp_clauseelement), ] - def __init__(self, start, stop, step): - self.start = start - self.stop = stop - self.step = step + def __init__(self, start, stop, step, _name=None): + self.start = coercions.expect( + roles.ExpressionElementRole, + start, + name=_name, + type_=type_api.INTEGERTYPE, + ) + self.stop = coercions.expect( + roles.ExpressionElementRole, + stop, + name=_name, + type_=type_api.INTEGERTYPE, + ) + self.step = coercions.expect( + roles.ExpressionElementRole, + step, + name=_name, + type_=type_api.INTEGERTYPE, + ) self.type = type_api.NULLTYPE def self_group(self, against=None): diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 6b1172eba..7b723f371 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -744,6 +744,7 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)): coerce_arguments = True _register = False + inherit_cache = True def __init__(self, *args, **kwargs): parsed_args = kwargs.pop("_parsed_args", None) @@ -808,6 +809,8 @@ class next_value(GenericFunction): class AnsiFunction(GenericFunction): + inherit_cache = True + def __init__(self, *args, **kwargs): GenericFunction.__init__(self, *args, **kwargs) @@ -815,6 +818,8 @@ class AnsiFunction(GenericFunction): class ReturnTypeFromArgs(GenericFunction): """Define a function whose return type is the same as its arguments.""" + inherit_cache = True + def __init__(self, *args, **kwargs): args = [ coercions.expect( @@ -832,30 +837,34 @@ class ReturnTypeFromArgs(GenericFunction): class coalesce(ReturnTypeFromArgs): _has_args = True + inherit_cache = True class max(ReturnTypeFromArgs): # noqa - pass + inherit_cache = True class min(ReturnTypeFromArgs): # noqa - pass + inherit_cache = True class sum(ReturnTypeFromArgs): # noqa - pass + inherit_cache = True class now(GenericFunction): # noqa type = sqltypes.DateTime + inherit_cache = True class concat(GenericFunction): type = sqltypes.String + inherit_cache = True class char_length(GenericFunction): type = sqltypes.Integer + inherit_cache = True def __init__(self, arg, **kwargs): GenericFunction.__init__(self, arg, **kwargs) @@ -863,6 +872,7 @@ class char_length(GenericFunction): class random(GenericFunction): _has_args = True + inherit_cache = True class count(GenericFunction): @@ -887,6 +897,7 @@ class count(GenericFunction): """ type = sqltypes.Integer + inherit_cache = True def __init__(self, expression=None, **kwargs): if expression is None: @@ -896,38 +907,47 @@ class count(GenericFunction): class current_date(AnsiFunction): type = sqltypes.Date + inherit_cache = True class current_time(AnsiFunction): type = sqltypes.Time + inherit_cache = True class current_timestamp(AnsiFunction): type = sqltypes.DateTime + inherit_cache = True class current_user(AnsiFunction): type = sqltypes.String + inherit_cache = True class localtime(AnsiFunction): type = sqltypes.DateTime + inherit_cache = True class localtimestamp(AnsiFunction): type = sqltypes.DateTime + inherit_cache = True class session_user(AnsiFunction): type = sqltypes.String + inherit_cache = True class sysdate(AnsiFunction): type = sqltypes.DateTime + inherit_cache = True class user(AnsiFunction): type = sqltypes.String + inherit_cache = True class array_agg(GenericFunction): @@ -951,6 +971,7 @@ class array_agg(GenericFunction): """ type = sqltypes.ARRAY + inherit_cache = True def __init__(self, *args, **kwargs): args = [ @@ -978,6 +999,7 @@ class OrderedSetAgg(GenericFunction): :meth:`.FunctionElement.within_group` method.""" array_for_multi_clause = False + inherit_cache = True def within_group_type(self, within_group): func_clauses = self.clause_expr.element @@ -1000,6 +1022,8 @@ class mode(OrderedSetAgg): """ + inherit_cache = True + class percentile_cont(OrderedSetAgg): """implement the ``percentile_cont`` ordered-set aggregate function. @@ -1016,6 +1040,7 @@ class percentile_cont(OrderedSetAgg): """ array_for_multi_clause = True + inherit_cache = True class percentile_disc(OrderedSetAgg): @@ -1033,6 +1058,7 @@ class percentile_disc(OrderedSetAgg): """ array_for_multi_clause = True + inherit_cache = True class rank(GenericFunction): @@ -1048,6 +1074,7 @@ class rank(GenericFunction): """ type = sqltypes.Integer() + inherit_cache = True class dense_rank(GenericFunction): @@ -1063,6 +1090,7 @@ class dense_rank(GenericFunction): """ type = sqltypes.Integer() + inherit_cache = True class percent_rank(GenericFunction): @@ -1078,6 +1106,7 @@ class percent_rank(GenericFunction): """ type = sqltypes.Numeric() + inherit_cache = True class cume_dist(GenericFunction): @@ -1093,6 +1122,7 @@ class cume_dist(GenericFunction): """ type = sqltypes.Numeric() + inherit_cache = True class cube(GenericFunction): @@ -1109,6 +1139,7 @@ class cube(GenericFunction): """ _has_args = True + inherit_cache = True class rollup(GenericFunction): @@ -1125,6 +1156,7 @@ class rollup(GenericFunction): """ _has_args = True + inherit_cache = True class grouping_sets(GenericFunction): @@ -1158,3 +1190,4 @@ class grouping_sets(GenericFunction): """ _has_args = True + inherit_cache = True diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index ee411174c..29ca81d26 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -1013,6 +1013,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): __visit_name__ = "column" + inherit_cache = True + def __init__(self, *args, **kwargs): r""" Construct a new ``Column`` object. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index a95fc561a..54f293967 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -61,6 +61,8 @@ if util.TYPE_CHECKING: class _OffsetLimitParam(BindParameter): + inherit_cache = True + @property def _limit_offset_value(self): return self.effective_value @@ -1426,6 +1428,8 @@ class Alias(roles.DMLTableRole, AliasedReturnsRows): __visit_name__ = "alias" + inherit_cache = True + @classmethod def _factory(cls, selectable, name=None, flat=False): """Return an :class:`_expression.Alias` object. @@ -1500,6 +1504,8 @@ class Lateral(AliasedReturnsRows): __visit_name__ = "lateral" _is_lateral = True + inherit_cache = True + @classmethod def _factory(cls, selectable, name=None): """Return a :class:`_expression.Lateral` object. @@ -1626,7 +1632,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows): AliasedReturnsRows._traverse_internals + [ ("_cte_alias", InternalTraversal.dp_clauseelement), - ("_restates", InternalTraversal.dp_clauseelement_unordered_set), + ("_restates", InternalTraversal.dp_clauseelement_list), ("recursive", InternalTraversal.dp_boolean), ] + HasPrefixes._has_prefixes_traverse_internals @@ -1651,7 +1657,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows): name=None, recursive=False, _cte_alias=None, - _restates=frozenset(), + _restates=(), _prefixes=None, _suffixes=None, ): @@ -1692,7 +1698,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows): self.element.union(other), name=self.name, recursive=self.recursive, - _restates=self._restates.union([self]), + _restates=self._restates + (self,), _prefixes=self._prefixes, _suffixes=self._suffixes, ) @@ -1702,7 +1708,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows): self.element.union_all(other), name=self.name, recursive=self.recursive, - _restates=self._restates.union([self]), + _restates=self._restates + (self,), _prefixes=self._prefixes, _suffixes=self._suffixes, ) @@ -1918,6 +1924,8 @@ class Subquery(AliasedReturnsRows): _is_subquery = True + inherit_cache = True + @classmethod def _factory(cls, selectable, name=None): """Return a :class:`.Subquery` object. @@ -3783,15 +3791,15 @@ class Select( ("_group_by_clauses", InternalTraversal.dp_clauseelement_list,), ("_setup_joins", InternalTraversal.dp_setup_join_tuple,), ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple,), - ("_correlate", InternalTraversal.dp_clauseelement_unordered_set), - ( - "_correlate_except", - InternalTraversal.dp_clauseelement_unordered_set, - ), + ("_correlate", InternalTraversal.dp_clauseelement_list), + ("_correlate_except", InternalTraversal.dp_clauseelement_list,), + ("_limit_clause", InternalTraversal.dp_clauseelement), + ("_offset_clause", InternalTraversal.dp_clauseelement), ("_for_update_arg", InternalTraversal.dp_clauseelement), ("_distinct", InternalTraversal.dp_boolean), ("_distinct_on", InternalTraversal.dp_clauseelement_list), ("_label_style", InternalTraversal.dp_plain_obj), + ("_is_future", InternalTraversal.dp_boolean), ] + HasPrefixes._has_prefixes_traverse_internals + HasSuffixes._has_suffixes_traverse_internals @@ -4522,7 +4530,7 @@ class Select( if fromclauses and fromclauses[0] is None: self._correlate = () else: - self._correlate = set(self._correlate).union( + self._correlate = self._correlate + tuple( coercions.expect(roles.FromClauseRole, f) for f in fromclauses ) @@ -4560,7 +4568,7 @@ class Select( if fromclauses and fromclauses[0] is None: self._correlate_except = () else: - self._correlate_except = set(self._correlate_except or ()).union( + self._correlate_except = (self._correlate_except or ()) + tuple( coercions.expect(roles.FromClauseRole, f) for f in fromclauses ) @@ -4866,6 +4874,7 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping): _from_objects = [] _is_from_container = True _is_implicitly_boolean = False + inherit_cache = True def __init__(self, element): self.element = element @@ -4899,6 +4908,7 @@ class Exists(UnaryExpression): """ _from_objects = [] + inherit_cache = True def __init__(self, *args, **kwargs): """Construct a new :class:`_expression.Exists` against an existing diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 732b775f6..9cd9d5058 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -2616,26 +2616,10 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): return_type = self.type if self.type.zero_indexes: index = slice(index.start + 1, index.stop + 1, index.step) - index = Slice( - coercions.expect( - roles.ExpressionElementRole, - index.start, - name=self.expr.key, - type_=type_api.INTEGERTYPE, - ), - coercions.expect( - roles.ExpressionElementRole, - index.stop, - name=self.expr.key, - type_=type_api.INTEGERTYPE, - ), - coercions.expect( - roles.ExpressionElementRole, - index.step, - name=self.expr.key, - type_=type_api.INTEGERTYPE, - ), + slice_ = Slice( + index.start, index.stop, index.step, _name=self.expr.key ) + return operators.getitem, slice_, return_type else: if self.type.zero_indexes: index += 1 @@ -2647,7 +2631,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): self.type.__class__, **adapt_kw ) - return operators.getitem, index, return_type + return operators.getitem, index, return_type def contains(self, *arg, **kw): raise NotImplementedError( diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 68281f33d..ed0bfa27a 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -19,6 +19,7 @@ NO_CACHE = util.symbol("no_cache") CACHE_IN_PLACE = util.symbol("cache_in_place") CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key") STATIC_CACHE_KEY = util.symbol("static_cache_key") +PROPAGATE_ATTRS = util.symbol("propagate_attrs") ANON_NAME = util.symbol("anon_name") @@ -31,10 +32,74 @@ def compare(obj1, obj2, **kw): return strategy.compare(obj1, obj2, **kw) +def _preconfigure_traversals(target_hierarchy): + + stack = [target_hierarchy] + while stack: + cls = stack.pop() + stack.extend(cls.__subclasses__()) + + if hasattr(cls, "_traverse_internals"): + cls._generate_cache_attrs() + _copy_internals.generate_dispatch( + cls, + cls._traverse_internals, + "_generated_copy_internals_traversal", + ) + _get_children.generate_dispatch( + cls, + cls._traverse_internals, + "_generated_get_children_traversal", + ) + + class HasCacheKey(object): _cache_key_traversal = NO_CACHE __slots__ = () + @classmethod + def _generate_cache_attrs(cls): + """generate cache key dispatcher for a new class. + + This sets the _generated_cache_key_traversal attribute once called + so should only be called once per class. + + """ + inherit = cls.__dict__.get("inherit_cache", False) + + if inherit: + _cache_key_traversal = getattr(cls, "_cache_key_traversal", None) + if _cache_key_traversal is None: + try: + _cache_key_traversal = cls._traverse_internals + except AttributeError: + cls._generated_cache_key_traversal = NO_CACHE + return NO_CACHE + + # TODO: wouldn't we instead get this from our superclass? + # also, our superclass may not have this yet, but in any case, + # we'd generate for the superclass that has it. this is a little + # more complicated, so for the moment this is a little less + # efficient on startup but simpler. + return _cache_key_traversal_visitor.generate_dispatch( + cls, _cache_key_traversal, "_generated_cache_key_traversal" + ) + else: + _cache_key_traversal = cls.__dict__.get( + "_cache_key_traversal", None + ) + if _cache_key_traversal is None: + _cache_key_traversal = cls.__dict__.get( + "_traverse_internals", None + ) + if _cache_key_traversal is None: + cls._generated_cache_key_traversal = NO_CACHE + return NO_CACHE + + return _cache_key_traversal_visitor.generate_dispatch( + cls, _cache_key_traversal, "_generated_cache_key_traversal" + ) + @util.preload_module("sqlalchemy.sql.elements") def _gen_cache_key(self, anon_map, bindparams): """return an optional cache key. @@ -72,14 +137,18 @@ class HasCacheKey(object): else: id_ = None - _cache_key_traversal = self._cache_key_traversal - if _cache_key_traversal is None: - try: - _cache_key_traversal = self._traverse_internals - except AttributeError: - _cache_key_traversal = NO_CACHE + try: + dispatcher = self.__class__.__dict__[ + "_generated_cache_key_traversal" + ] + except KeyError: + # most of the dispatchers are generated up front + # in sqlalchemy/sql/__init__.py -> + # traversals.py-> _preconfigure_traversals(). + # this block will generate any remaining dispatchers. + dispatcher = self.__class__._generate_cache_attrs() - if _cache_key_traversal is NO_CACHE: + if dispatcher is NO_CACHE: if anon_map is not None: anon_map[NO_CACHE] = True return None @@ -87,19 +156,13 @@ class HasCacheKey(object): result = (id_, self.__class__) # inline of _cache_key_traversal_visitor.run_generated_dispatch() - try: - dispatcher = self.__class__.__dict__[ - "_generated_cache_key_traversal" - ] - except KeyError: - dispatcher = _cache_key_traversal_visitor.generate_dispatch( - self, _cache_key_traversal, "_generated_cache_key_traversal" - ) for attrname, obj, meth in dispatcher( self, _cache_key_traversal_visitor ): if obj is not None: + # TODO: see if C code can help here as Python lacks an + # efficient switch construct if meth is CACHE_IN_PLACE: # cache in place is always going to be a Python # tuple, dict, list, etc. so we can do a boolean check @@ -116,6 +179,15 @@ class HasCacheKey(object): attrname, obj._gen_cache_key(anon_map, bindparams), ) + elif meth is PROPAGATE_ATTRS: + if obj: + result += ( + attrname, + obj["compile_state_plugin"], + obj["plugin_subject"]._gen_cache_key( + anon_map, bindparams + ), + ) elif meth is InternalTraversal.dp_annotations_key: # obj is here is the _annotations dict. however, # we want to use the memoized cache key version of it. @@ -332,6 +404,8 @@ class _CacheKey(ExtendedInternalTraversal): visit_type = STATIC_CACHE_KEY visit_anon_name = ANON_NAME + visit_propagate_attrs = PROPAGATE_ATTRS + def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) @@ -445,10 +519,16 @@ class _CacheKey(ExtendedInternalTraversal): def visit_setup_join_tuple( self, attrname, obj, parent, anon_map, bindparams ): + is_legacy = "legacy" in attrname + return tuple( ( - target._gen_cache_key(anon_map, bindparams), - onclause._gen_cache_key(anon_map, bindparams) + target + if is_legacy and isinstance(target, str) + else target._gen_cache_key(anon_map, bindparams), + onclause + if is_legacy and isinstance(onclause, str) + else onclause._gen_cache_key(anon_map, bindparams) if onclause is not None else None, from_._gen_cache_key(anon_map, bindparams) @@ -711,6 +791,11 @@ class _CopyInternals(InternalTraversal): for sequence in element ] + def visit_propagate_attrs( + self, attrname, parent, element, clone=_clone, **kw + ): + return element + _copy_internals = _CopyInternals() @@ -782,6 +867,9 @@ class _GetChildren(InternalTraversal): def visit_dml_multi_values(self, element, **kw): return () + def visit_propagate_attrs(self, element, **kw): + return () + _get_children = _GetChildren() @@ -916,6 +1004,13 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): ): return COMPARE_FAILED + def visit_propagate_attrs( + self, attrname, left_parent, left, right_parent, right, **kw + ): + return self.compare_inner( + left.get("plugin_subject", None), right.get("plugin_subject", None) + ) + def visit_has_cache_key_list( self, attrname, left_parent, left, right_parent, right, **kw ): diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index ccda21e11..fe3634bad 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -555,7 +555,12 @@ class TypeEngine(Traversible): def _static_cache_key(self): names = util.get_cls_kwargs(self.__class__) return (self.__class__,) + tuple( - (k, self.__dict__[k]) + ( + k, + self.__dict__[k]._static_cache_key + if isinstance(self.__dict__[k], TypeEngine) + else self.__dict__[k], + ) for k in names if k in self.__dict__ and not k.startswith("_") ) diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 5de68f504..904702003 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -217,18 +217,23 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): try: dispatcher = target.__class__.__dict__[generate_dispatcher_name] except KeyError: + # most of the dispatchers are generated up front + # in sqlalchemy/sql/__init__.py -> + # traversals.py-> _preconfigure_traversals(). + # this block will generate any remaining dispatchers. dispatcher = self.generate_dispatch( - target, internal_dispatch, generate_dispatcher_name + target.__class__, internal_dispatch, generate_dispatcher_name ) return dispatcher(target, self) def generate_dispatch( - self, target, internal_dispatch, generate_dispatcher_name + self, target_cls, internal_dispatch, generate_dispatcher_name ): dispatcher = _generate_dispatcher( self, internal_dispatch, generate_dispatcher_name ) - setattr(target.__class__, generate_dispatcher_name, dispatcher) + # assert isinstance(target_cls, type) + setattr(target_cls, generate_dispatcher_name, dispatcher) return dispatcher dp_has_cache_key = symbol("HC") @@ -263,10 +268,6 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): """ - dp_clauseelement_unordered_set = symbol("CU") - """Visit an unordered set of :class:`_expression.ClauseElement` - objects. """ - dp_fromclause_ordered_set = symbol("CO") """Visit an ordered set of :class:`_expression.FromClause` objects. """ @@ -414,6 +415,10 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): """ + dp_propagate_attrs = symbol("PA") + """Visit the propagate attrs dict. this hardcodes to the particular + elements we care about right now.""" + class ExtendedInternalTraversal(InternalTraversal): """defines additional symbols that are useful in caching applications. |
