summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/dialects/mysql/dml.py1
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py17
-rw-r--r--lib/sqlalchemy/dialects/postgresql/dml.py2
-rw-r--r--lib/sqlalchemy/ext/baked.py2
-rw-r--r--lib/sqlalchemy/ext/compiler.py2
-rw-r--r--lib/sqlalchemy/orm/attributes.py10
-rw-r--r--lib/sqlalchemy/orm/base.py1
-rw-r--r--lib/sqlalchemy/orm/interfaces.py13
-rw-r--r--lib/sqlalchemy/orm/mapper.py6
-rw-r--r--lib/sqlalchemy/orm/path_registry.py12
-rw-r--r--lib/sqlalchemy/orm/strategy_options.py38
-rw-r--r--lib/sqlalchemy/orm/util.py10
-rw-r--r--lib/sqlalchemy/sql/annotation.py69
-rw-r--r--lib/sqlalchemy/sql/base.py36
-rw-r--r--lib/sqlalchemy/sql/clause_compare.py334
-rw-r--r--lib/sqlalchemy/sql/compiler.py29
-rw-r--r--lib/sqlalchemy/sql/default_comparator.py3
-rw-r--r--lib/sqlalchemy/sql/elements.py515
-rw-r--r--lib/sqlalchemy/sql/expression.py2
-rw-r--r--lib/sqlalchemy/sql/functions.py70
-rw-r--r--lib/sqlalchemy/sql/schema.py18
-rw-r--r--lib/sqlalchemy/sql/selectable.py396
-rw-r--r--lib/sqlalchemy/sql/traversals.py768
-rw-r--r--lib/sqlalchemy/sql/type_api.py17
-rw-r--r--lib/sqlalchemy/sql/util.py2
-rw-r--r--lib/sqlalchemy/sql/visitors.py447
26 files changed, 1733 insertions, 1087 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py
index 293aa426d..b43b364fa 100644
--- a/lib/sqlalchemy/dialects/mysql/dml.py
+++ b/lib/sqlalchemy/dialects/mysql/dml.py
@@ -103,7 +103,6 @@ class Insert(StandardInsert):
inserted_alias = getattr(self, "inserted_alias", None)
self._post_values_clause = OnDuplicateClause(inserted_alias, values)
- return self
insert = public_factory(Insert, ".dialects.mysql.insert")
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 909d568a7..e94f9913c 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -1658,23 +1658,20 @@ class PGCompiler(compiler.SQLCompiler):
return "ONLY " + sqltext
def get_select_precolumns(self, select, **kw):
- if select._distinct is not False:
- if select._distinct is True:
- return "DISTINCT "
- elif isinstance(select._distinct, (list, tuple)):
+ if select._distinct or select._distinct_on:
+ if select._distinct_on:
return (
"DISTINCT ON ("
+ ", ".join(
- [self.process(col, **kw) for col in select._distinct]
+ [
+ self.process(col, **kw)
+ for col in select._distinct_on
+ ]
)
+ ") "
)
else:
- return (
- "DISTINCT ON ("
- + self.process(select._distinct, **kw)
- + ") "
- )
+ return "DISTINCT "
else:
return ""
diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py
index 4e77f5a4c..f4467976a 100644
--- a/lib/sqlalchemy/dialects/postgresql/dml.py
+++ b/lib/sqlalchemy/dialects/postgresql/dml.py
@@ -103,7 +103,6 @@ class Insert(StandardInsert):
self._post_values_clause = OnConflictDoUpdate(
constraint, index_elements, index_where, set_, where
)
- return self
@_generative
def on_conflict_do_nothing(
@@ -138,7 +137,6 @@ class Insert(StandardInsert):
self._post_values_clause = OnConflictDoNothing(
constraint, index_elements, index_where
)
- return self
insert = public_factory(Insert, ".dialects.postgresql.insert")
diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py
index d18a35a40..8e137f141 100644
--- a/lib/sqlalchemy/ext/baked.py
+++ b/lib/sqlalchemy/ext/baked.py
@@ -198,7 +198,7 @@ class BakedQuery(object):
self.spoil()
else:
for opt in options:
- cache_key = opt._generate_cache_key(cache_path)
+ cache_key = opt._generate_path_cache_key(cache_path)
if cache_key is False:
self.spoil()
elif cache_key is not None:
diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py
index 4a5a8ba9c..c2b234758 100644
--- a/lib/sqlalchemy/ext/compiler.py
+++ b/lib/sqlalchemy/ext/compiler.py
@@ -455,7 +455,7 @@ def deregister(class_):
if hasattr(class_, "_compiler_dispatcher"):
# regenerate default _compiler_dispatch
- visitors._generate_dispatch(class_)
+ visitors._generate_compiler_dispatch(class_)
# remove custom directive
del class_._compiler_dispatcher
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 83069f113..aa2986205 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -47,6 +47,8 @@ from .base import state_str
from .. import event
from .. import inspection
from .. import util
+from ..sql import base as sql_base
+from ..sql import visitors
@inspection._self_inspects
@@ -54,6 +56,7 @@ class QueryableAttribute(
interfaces._MappedAttribute,
interfaces.InspectionAttr,
interfaces.PropComparator,
+ sql_base.HasCacheKey,
):
"""Base class for :term:`descriptor` objects that intercept
attribute events on behalf of a :class:`.MapperProperty`
@@ -102,6 +105,13 @@ class QueryableAttribute(
if base[key].dispatch._active_history:
self.dispatch._active_history = True
+ _cache_key_traversal = [
+ # ("class_", visitors.ExtendedInternalTraversal.dp_plain_obj),
+ ("key", visitors.ExtendedInternalTraversal.dp_string),
+ ("_parententity", visitors.ExtendedInternalTraversal.dp_multi),
+ ("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
+ ]
+
@util.memoized_property
def _supports_population(self):
return self.impl.supports_population
diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py
index 6f8d19293..a3dea6b0e 100644
--- a/lib/sqlalchemy/orm/base.py
+++ b/lib/sqlalchemy/orm/base.py
@@ -216,7 +216,6 @@ def _assertions(*assertions):
for assertion in assertions:
assertion(self, fn.__name__)
fn(self, *args[1:], **kw)
- return self
return generate
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index e94a81fed..704ce9df7 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -36,6 +36,8 @@ from .. import inspect
from .. import inspection
from .. import util
from ..sql import operators
+from ..sql import visitors
+from ..sql.traversals import HasCacheKey
__all__ = (
@@ -54,7 +56,9 @@ __all__ = (
)
-class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots):
+class MapperProperty(
+ HasCacheKey, _MappedAttribute, InspectionAttr, util.MemoizedSlots
+):
"""Represent a particular class attribute mapped by :class:`.Mapper`.
The most common occurrences of :class:`.MapperProperty` are the
@@ -74,6 +78,11 @@ class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots):
"info",
)
+ _cache_key_traversal = [
+ ("parent", visitors.ExtendedInternalTraversal.dp_has_cache_key),
+ ("key", visitors.ExtendedInternalTraversal.dp_string),
+ ]
+
cascade = frozenset()
"""The set of 'cascade' attribute names.
@@ -647,7 +656,7 @@ class MapperOption(object):
self.process_query(query)
- def _generate_cache_key(self, path):
+ def _generate_path_cache_key(self, path):
"""Used by the "baked lazy loader" to see if this option can be cached.
The "baked lazy loader" refers to the :class:`.Query` that is
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 376ad1923..548eca58d 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -71,7 +71,7 @@ _CONFIGURE_MUTEX = util.threading.RLock()
@inspection._self_inspects
@log.class_logger
-class Mapper(InspectionAttr):
+class Mapper(sql_base.HasCacheKey, InspectionAttr):
"""Define the correlation of class attributes to database table
columns.
@@ -729,6 +729,10 @@ class Mapper(InspectionAttr):
"""
return self
+ _cache_key_traversal = [
+ ("class_", visitors.ExtendedInternalTraversal.dp_plain_obj)
+ ]
+
@property
def entity(self):
r"""Part of the inspection API.
diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py
index 2f680a3a1..585cb80bc 100644
--- a/lib/sqlalchemy/orm/path_registry.py
+++ b/lib/sqlalchemy/orm/path_registry.py
@@ -15,7 +15,8 @@ from .base import class_mapper
from .. import exc
from .. import inspection
from .. import util
-
+from ..sql import visitors
+from ..sql.traversals import HasCacheKey
log = logging.getLogger(__name__)
@@ -28,7 +29,7 @@ _WILDCARD_TOKEN = "*"
_DEFAULT_TOKEN = "_sa_default"
-class PathRegistry(object):
+class PathRegistry(HasCacheKey):
"""Represent query load paths and registry functions.
Basically represents structures like:
@@ -57,6 +58,10 @@ class PathRegistry(object):
is_token = False
is_root = False
+ _cache_key_traversal = [
+ ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key_list)
+ ]
+
def __eq__(self, other):
return other is not None and self.path == other.path
@@ -78,6 +83,9 @@ class PathRegistry(object):
def __len__(self):
return len(self.path)
+ def __hash__(self):
+ return id(self)
+
@property
def length(self):
return len(self.path)
diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py
index 26f47f616..99bbbe37c 100644
--- a/lib/sqlalchemy/orm/strategy_options.py
+++ b/lib/sqlalchemy/orm/strategy_options.py
@@ -26,11 +26,13 @@ from .. import inspect
from .. import util
from ..sql import coercions
from ..sql import roles
+from ..sql import visitors
from ..sql.base import _generative
from ..sql.base import Generative
+from ..sql.traversals import HasCacheKey
-class Load(Generative, MapperOption):
+class Load(HasCacheKey, Generative, MapperOption):
"""Represents loader options which modify the state of a
:class:`.Query` in order to affect how various mapped attributes are
loaded.
@@ -70,6 +72,17 @@ class Load(Generative, MapperOption):
"""
+ _cache_key_traversal = [
+ ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key),
+ ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
+ ("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
+ (
+ "_context_cache_key",
+ visitors.ExtendedInternalTraversal.dp_has_cache_key_tuples,
+ ),
+ ("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict),
+ ]
+
def __init__(self, entity):
insp = inspect(entity)
self.path = insp._path_registry
@@ -89,7 +102,16 @@ class Load(Generative, MapperOption):
load._of_type = None
return load
- def _generate_cache_key(self, path):
+ @property
+ def _context_cache_key(self):
+ serialized = []
+ for (key, loader_path), obj in self.context.items():
+ if key != "loader":
+ continue
+ serialized.append(loader_path + (obj,))
+ return serialized
+
+ def _generate_path_cache_key(self, path):
if path.path[0].is_aliased_class:
return False
@@ -522,9 +544,16 @@ class _UnboundLoad(Load):
self._to_bind = []
self.local_opts = {}
+ _cache_key_traversal = [
+ ("path", visitors.ExtendedInternalTraversal.dp_multi_list),
+ ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
+ ("_to_bind", visitors.ExtendedInternalTraversal.dp_has_cache_key_list),
+ ("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict),
+ ]
+
_is_chain_link = False
- def _generate_cache_key(self, path):
+ def _generate_path_cache_key(self, path):
serialized = ()
for val in self._to_bind:
for local_elem, val_elem in zip(self.path, val.path):
@@ -533,7 +562,7 @@ class _UnboundLoad(Load):
else:
opt = val._bind_loader([path.path[0]], None, None, False)
if opt:
- c_key = opt._generate_cache_key(path)
+ c_key = opt._generate_path_cache_key(path)
if c_key is False:
return False
elif c_key:
@@ -660,7 +689,6 @@ class _UnboundLoad(Load):
opt = meth(opt, all_tokens[-1], **kw)
opt._is_chain_link = False
-
return opt
def _chop_path(self, to_chop, path):
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 5f0f41e8d..c86993678 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -30,10 +30,12 @@ from .. import exc as sa_exc
from .. import inspection
from .. import sql
from .. import util
+from ..sql import base as sql_base
from ..sql import coercions
from ..sql import expression
from ..sql import roles
from ..sql import util as sql_util
+from ..sql import visitors
all_cascades = frozenset(
@@ -530,7 +532,7 @@ class AliasedClass(object):
return str(self._aliased_insp)
-class AliasedInsp(InspectionAttr):
+class AliasedInsp(sql_base.HasCacheKey, InspectionAttr):
"""Provide an inspection interface for an
:class:`.AliasedClass` object.
@@ -627,6 +629,12 @@ class AliasedInsp(InspectionAttr):
def __clause_element__(self):
return self.selectable
+ _cache_key_traversal = [
+ ("name", visitors.ExtendedInternalTraversal.dp_string),
+ ("_adapt_on_names", visitors.ExtendedInternalTraversal.dp_boolean),
+ ("selectable", visitors.ExtendedInternalTraversal.dp_clauseelement),
+ ]
+
@property
def class_(self):
"""Return the mapped class ultimately represented by this
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index a0264845e..0d995ec8a 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -12,12 +12,32 @@ associations.
"""
from . import operators
+from .base import HasCacheKey
+from .visitors import InternalTraversal
from .. import util
-class SupportsCloneAnnotations(object):
+class SupportsAnnotations(object):
+ @util.memoized_property
+ def _annotation_traversals(self):
+ return [
+ (
+ key,
+ InternalTraversal.dp_has_cache_key
+ if isinstance(value, HasCacheKey)
+ else InternalTraversal.dp_plain_obj,
+ )
+ for key, value in self._annotations.items()
+ ]
+
+
+class SupportsCloneAnnotations(SupportsAnnotations):
_annotations = util.immutabledict()
+ _traverse_internals = [
+ ("_annotations", InternalTraversal.dp_annotations_state)
+ ]
+
def _annotate(self, values):
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
@@ -25,6 +45,7 @@ class SupportsCloneAnnotations(object):
"""
new = self._clone()
new._annotations = new._annotations.union(values)
+ new.__dict__.pop("_annotation_traversals", None)
return new
def _with_annotations(self, values):
@@ -34,6 +55,7 @@ class SupportsCloneAnnotations(object):
"""
new = self._clone()
new._annotations = util.immutabledict(values)
+ new.__dict__.pop("_annotation_traversals", None)
return new
def _deannotate(self, values=None, clone=False):
@@ -49,12 +71,13 @@ class SupportsCloneAnnotations(object):
# the expression for a deep deannotation
new = self._clone()
new._annotations = {}
+ new.__dict__.pop("_annotation_traversals", None)
return new
else:
return self
-class SupportsWrappingAnnotations(object):
+class SupportsWrappingAnnotations(SupportsAnnotations):
def _annotate(self, values):
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
@@ -123,6 +146,7 @@ class Annotated(object):
def __init__(self, element, values):
self.__dict__ = element.__dict__.copy()
+ self.__dict__.pop("_annotation_traversals", None)
self.__element = element
self._annotations = values
self._hash = hash(element)
@@ -135,6 +159,7 @@ class Annotated(object):
def _with_annotations(self, values):
clone = self.__class__.__new__(self.__class__)
clone.__dict__ = self.__dict__.copy()
+ clone.__dict__.pop("_annotation_traversals", None)
clone._annotations = values
return clone
@@ -192,7 +217,17 @@ def _deep_annotate(element, annotations, exclude=None):
"""
- def clone(elem):
+ # annotated objects hack the __hash__() method so if we want to
+ # uniquely process them we have to use id()
+
+ cloned_ids = {}
+
+ def clone(elem, **kw):
+ id_ = id(elem)
+
+ if id_ in cloned_ids:
+ return cloned_ids[id_]
+
if (
exclude
and hasattr(elem, "proxy_set")
@@ -204,6 +239,7 @@ def _deep_annotate(element, annotations, exclude=None):
else:
newelem = elem
newelem._copy_internals(clone=clone)
+ cloned_ids[id_] = newelem
return newelem
if element is not None:
@@ -214,23 +250,21 @@ def _deep_annotate(element, annotations, exclude=None):
def _deep_deannotate(element, values=None):
"""Deep copy the given element, removing annotations."""
- cloned = util.column_dict()
+ cloned = {}
- def clone(elem):
- # if a values dict is given,
- # the elem must be cloned each time it appears,
- # as there may be different annotations in source
- # elements that are remaining. if totally
- # removing all annotations, can assume the same
- # slate...
- if values or elem not in cloned:
+ def clone(elem, **kw):
+ if values:
+ key = id(elem)
+ else:
+ key = elem
+
+ if key not in cloned:
newelem = elem._deannotate(values=values, clone=True)
newelem._copy_internals(clone=clone)
- if not values:
- cloned[elem] = newelem
+ cloned[key] = newelem
return newelem
else:
- return cloned[elem]
+ return cloned[key]
if element is not None:
element = clone(element)
@@ -268,6 +302,11 @@ def _new_annotation_type(cls, base_cls):
"Annotated%s" % cls.__name__, (base_cls, cls), {}
)
globals()["Annotated%s" % cls.__name__] = anno_cls
+
+ if "_traverse_internals" in cls.__dict__:
+ anno_cls._traverse_internals = list(cls._traverse_internals) + [
+ ("_annotations", InternalTraversal.dp_annotations_state)
+ ]
return anno_cls
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 7e9199bfa..d11a3a313 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -14,6 +14,7 @@ import itertools
import operator
import re
+from .traversals import HasCacheKey # noqa
from .visitors import ClauseVisitor
from .. import exc
from .. import util
@@ -38,18 +39,41 @@ class Immutable(object):
def _clone(self):
return self
+ def _copy_internals(self, **kw):
+ pass
+
+
+class HasMemoized(object):
+ def _reset_memoizations(self):
+ self._memoized_property.expire_instance(self)
+
+ def _reset_exported(self):
+ self._memoized_property.expire_instance(self)
+
+ def _copy_internals(self, **kw):
+ super(HasMemoized, self)._copy_internals(**kw)
+ self._reset_memoizations()
+
def _from_objects(*elements):
return itertools.chain(*[element._from_objects for element in elements])
def _generative(fn):
+ """non-caching _generative() decorator.
+
+ This is basically the legacy decorator that copies the object and
+ runs a method on the new copy.
+
+ """
+
@util.decorator
- def _generative(fn, *args, **kw):
+ def _generative(fn, self, *args, **kw):
"""Mark a method as generative."""
- self = args[0]._generate()
- fn(self, *args[1:], **kw)
+ self = self._generate()
+ x = fn(self, *args, **kw)
+ assert x is None, "generative methods must have no return value"
return self
decorated = _generative(fn)
@@ -357,10 +381,8 @@ class DialectKWArgs(object):
class Generative(object):
- """Allow a ClauseElement to generate itself via the
- @_generative decorator.
-
- """
+ """Provide a method-chaining pattern in conjunction with the
+ @_generative decorator."""
def _generate(self):
s = self.__class__.__new__(self.__class__)
diff --git a/lib/sqlalchemy/sql/clause_compare.py b/lib/sqlalchemy/sql/clause_compare.py
deleted file mode 100644
index 30a90348c..000000000
--- a/lib/sqlalchemy/sql/clause_compare.py
+++ /dev/null
@@ -1,334 +0,0 @@
-from collections import deque
-
-from . import operators
-from .. import util
-
-
-SKIP_TRAVERSE = util.symbol("skip_traverse")
-
-
-def compare(obj1, obj2, **kw):
- if kw.get("use_proxies", False):
- strategy = ColIdentityComparatorStrategy()
- else:
- strategy = StructureComparatorStrategy()
-
- return strategy.compare(obj1, obj2, **kw)
-
-
-class StructureComparatorStrategy(object):
- __slots__ = "compare_stack", "cache"
-
- def __init__(self):
- self.compare_stack = deque()
- self.cache = set()
-
- def compare(self, obj1, obj2, **kw):
- stack = self.compare_stack
- cache = self.cache
-
- stack.append((obj1, obj2))
-
- while stack:
- left, right = stack.popleft()
-
- if left is right:
- continue
- elif left is None or right is None:
- # we know they are different so no match
- return False
- elif (left, right) in cache:
- continue
- cache.add((left, right))
-
- visit_name = left.__visit_name__
-
- # we're not exactly looking for identical types, because
- # there are things like Column and AnnotatedColumn. So the
- # visit_name has to at least match up
- if visit_name != right.__visit_name__:
- return False
-
- meth = getattr(self, "compare_%s" % visit_name, None)
-
- if meth:
- comparison = meth(left, right, **kw)
- if comparison is False:
- return False
- elif comparison is SKIP_TRAVERSE:
- continue
-
- for c1, c2 in util.zip_longest(
- left.get_children(column_collections=False),
- right.get_children(column_collections=False),
- fillvalue=None,
- ):
- if c1 is None or c2 is None:
- # collections are different sizes, comparison fails
- return False
- stack.append((c1, c2))
-
- return True
-
- def compare_inner(self, obj1, obj2, **kw):
- stack = self.compare_stack
- try:
- self.compare_stack = deque()
- return self.compare(obj1, obj2, **kw)
- finally:
- self.compare_stack = stack
-
- def _compare_unordered_sequences(self, seq1, seq2, **kw):
- if seq1 is None:
- return seq2 is None
-
- completed = set()
- for clause in seq1:
- for other_clause in set(seq2).difference(completed):
- if self.compare_inner(clause, other_clause, **kw):
- completed.add(other_clause)
- break
- return len(completed) == len(seq1) == len(seq2)
-
- def compare_bindparam(self, left, right, **kw):
- # note the ".key" is often generated from id(self) so can't
- # be compared, as far as determining structure.
- return (
- left.type._compare_type_affinity(right.type)
- and left.value == right.value
- and left.callable == right.callable
- and left._orig_key == right._orig_key
- )
-
- def compare_clauselist(self, left, right, **kw):
- if left.operator is right.operator:
- if operators.is_associative(left.operator):
- if self._compare_unordered_sequences(
- left.clauses, right.clauses
- ):
- return SKIP_TRAVERSE
- else:
- return False
- else:
- # normal ordered traversal
- return True
- else:
- return False
-
- def compare_unary(self, left, right, **kw):
- if left.operator:
- disp = self._get_operator_dispatch(
- left.operator, "unary", "operator"
- )
- if disp is not None:
- result = disp(left, right, left.operator, **kw)
- if result is not True:
- return result
- elif left.modifier:
- disp = self._get_operator_dispatch(
- left.modifier, "unary", "modifier"
- )
- if disp is not None:
- result = disp(left, right, left.operator, **kw)
- if result is not True:
- return result
- return (
- left.operator == right.operator and left.modifier == right.modifier
- )
-
- def compare_binary(self, left, right, **kw):
- disp = self._get_operator_dispatch(left.operator, "binary", None)
- if disp:
- result = disp(left, right, left.operator, **kw)
- if result is not True:
- return result
-
- if left.operator == right.operator:
- if operators.is_commutative(left.operator):
- if (
- compare(left.left, right.left, **kw)
- and compare(left.right, right.right, **kw)
- ) or (
- compare(left.left, right.right, **kw)
- and compare(left.right, right.left, **kw)
- ):
- return SKIP_TRAVERSE
- else:
- return False
- else:
- return True
- else:
- return False
-
- def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
- # used by compare_binary, compare_unary
- attrname = "visit_%s_%s%s" % (
- operator_.__name__,
- qualifier1,
- "_" + qualifier2 if qualifier2 else "",
- )
- return getattr(self, attrname, None)
-
- def visit_function_as_comparison_op_binary(
- self, left, right, operator, **kw
- ):
- return (
- left.left_index == right.left_index
- and left.right_index == right.right_index
- )
-
- def compare_function(self, left, right, **kw):
- return left.name == right.name
-
- def compare_column(self, left, right, **kw):
- if left.table is not None:
- self.compare_stack.appendleft((left.table, right.table))
- return (
- left.key == right.key
- and left.name == right.name
- and (
- left.type._compare_type_affinity(right.type)
- if left.type is not None
- else right.type is None
- )
- and left.is_literal == right.is_literal
- )
-
- def compare_collation(self, left, right, **kw):
- return left.collation == right.collation
-
- def compare_type_coerce(self, left, right, **kw):
- return left.type._compare_type_affinity(right.type)
-
- @util.dependencies("sqlalchemy.sql.elements")
- def compare_alias(self, elements, left, right, **kw):
- return (
- left.name == right.name
- if not isinstance(left.name, elements._anonymous_label)
- else isinstance(right.name, elements._anonymous_label)
- )
-
- def compare_cte(self, elements, left, right, **kw):
- raise NotImplementedError("TODO")
-
- def compare_extract(self, left, right, **kw):
- return left.field == right.field
-
- def compare_textual_label_reference(self, left, right, **kw):
- return left.element == right.element
-
- def compare_slice(self, left, right, **kw):
- return (
- left.start == right.start
- and left.stop == right.stop
- and left.step == right.step
- )
-
- def compare_over(self, left, right, **kw):
- return left.range_ == right.range_ and left.rows == right.rows
-
- @util.dependencies("sqlalchemy.sql.elements")
- def compare_label(self, elements, left, right, **kw):
- return left._type._compare_type_affinity(right._type) and (
- left.name == right.name
- if not isinstance(left.name, elements._anonymous_label)
- else isinstance(right.name, elements._anonymous_label)
- )
-
- def compare_typeclause(self, left, right, **kw):
- return left.type._compare_type_affinity(right.type)
-
- def compare_join(self, left, right, **kw):
- return left.isouter == right.isouter and left.full == right.full
-
- def compare_table(self, left, right, **kw):
- if left.name != right.name:
- return False
-
- self.compare_stack.extendleft(
- util.zip_longest(left.columns, right.columns)
- )
-
- def compare_compound_select(self, left, right, **kw):
-
- if not self._compare_unordered_sequences(
- left.selects, right.selects, **kw
- ):
- return False
-
- if left.keyword != right.keyword:
- return False
-
- if left._for_update_arg != right._for_update_arg:
- return False
-
- if not self.compare_inner(
- left._order_by_clause, right._order_by_clause, **kw
- ):
- return False
-
- if not self.compare_inner(
- left._group_by_clause, right._group_by_clause, **kw
- ):
- return False
-
- return SKIP_TRAVERSE
-
- def compare_select(self, left, right, **kw):
- if not self._compare_unordered_sequences(
- left._correlate, right._correlate
- ):
- return False
- if not self._compare_unordered_sequences(
- left._correlate_except, right._correlate_except
- ):
- return False
-
- if not self._compare_unordered_sequences(
- left._from_obj, right._from_obj
- ):
- return False
-
- if left._for_update_arg != right._for_update_arg:
- return False
-
- return True
-
- def compare_textual_select(self, left, right, **kw):
- self.compare_stack.extendleft(
- util.zip_longest(left.column_args, right.column_args)
- )
- return left.positional == right.positional
-
-
-class ColIdentityComparatorStrategy(StructureComparatorStrategy):
- def compare_column_element(
- self, left, right, use_proxies=True, equivalents=(), **kw
- ):
- """Compare ColumnElements using proxies and equivalent collections.
-
- This is a comparison strategy specific to the ORM.
- """
-
- to_compare = (right,)
- if equivalents and right in equivalents:
- to_compare = equivalents[right].union(to_compare)
-
- for oth in to_compare:
- if use_proxies and left.shares_lineage(oth):
- return True
- elif hash(left) == hash(right):
- return True
- else:
- return False
-
- def compare_column(self, left, right, **kw):
- return self.compare_column_element(left, right, **kw)
-
- def compare_label(self, left, right, **kw):
- return self.compare_column_element(left, right, **kw)
-
- def compare_table(self, left, right, **kw):
- # tables compare on identity, since it's not really feasible to
- # compare them column by column with the above rules
- return left is right
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 5ecec7d6c..546fffc6c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -434,6 +434,27 @@ class _CompileLabel(elements.ColumnElement):
return self
+class prefix_anon_map(dict):
+ """A map that creates new keys for missing key access.
+
+ Considers keys of the form "<ident> <name>" to produce
+ new symbols "<name>_<index>", where "index" is an incrementing integer
+ corresponding to <name>.
+
+ Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
+ is otherwise usually used for this type of operation.
+
+ """
+
+ def __missing__(self, key):
+ (ident, derived) = key.split(" ", 1)
+ anonymous_counter = self.get(derived, 1)
+ self[derived] = anonymous_counter + 1
+ value = derived + "_" + str(anonymous_counter)
+ self[key] = value
+ return value
+
+
class SQLCompiler(Compiled):
"""Default implementation of :class:`.Compiled`.
@@ -574,7 +595,7 @@ class SQLCompiler(Compiled):
# a map which tracks "anonymous" identifiers that are created on
# the fly here
- self.anon_map = util.PopulateDict(self._process_anon)
+ self.anon_map = prefix_anon_map()
# a map which tracks "truncated" names based on
# dialect.label_length or dialect.max_identifier_length
@@ -1712,12 +1733,6 @@ class SQLCompiler(Compiled):
def _anonymize(self, name):
return name % self.anon_map
- def _process_anon(self, key):
- (ident, derived) = key.split(" ", 1)
- anonymous_counter = self.anon_map.get(derived, 1)
- self.anon_map[derived] = anonymous_counter + 1
- return derived + "_" + str(anonymous_counter)
-
def bindparam_string(
self,
name,
diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py
index 918f7524e..c0baa8555 100644
--- a/lib/sqlalchemy/sql/default_comparator.py
+++ b/lib/sqlalchemy/sql/default_comparator.py
@@ -178,6 +178,9 @@ def _unsupported_impl(expr, op, *arg, **kw):
def _inv_impl(expr, op, **kw):
"""See :meth:`.ColumnOperators.__inv__`."""
+
+ # undocumented element currently used by the ORM for
+ # relationship.contains()
if hasattr(expr, "negation_clause"):
return expr.negation_clause
else:
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index e6f57b8d1..ba615bc3f 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -16,23 +16,29 @@ import itertools
import operator
import re
-from . import clause_compare
from . import coercions
from . import operators
from . import roles
+from . import traversals
from . import type_api
from .annotation import Annotated
from .annotation import SupportsWrappingAnnotations
from .base import _clone
from .base import _generative
from .base import Executable
+from .base import HasCacheKey
+from .base import HasMemoized
from .base import Immutable
from .base import NO_ARG
from .base import PARSE_AUTOCOMMIT
from .coercions import _document_text_coercion
+from .traversals import _copy_internals
+from .traversals import _get_children
+from .traversals import NO_CACHE
from .visitors import cloned_traverse
+from .visitors import InternalTraversal
from .visitors import traverse
-from .visitors import Visitable
+from .visitors import Traversible
from .. import exc
from .. import inspection
from .. import util
@@ -162,7 +168,9 @@ def not_(clause):
@inspection._self_inspects
-class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
+class ClauseElement(
+ roles.SQLRole, SupportsWrappingAnnotations, HasCacheKey, Traversible
+):
"""Base class for elements of a programmatically constructed SQL
expression.
@@ -190,6 +198,13 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
_order_by_label_element = None
+ @property
+ def _cache_key_traversal(self):
+ try:
+ return self._traverse_internals
+ except AttributeError:
+ return NO_CACHE
+
def _clone(self):
"""Create a shallow copy of this ClauseElement.
@@ -221,28 +236,6 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
"""
return self
- def _cache_key(self, **kw):
- """return an optional cache key.
-
- The cache key is a tuple which can contain any series of
- objects that are hashable and also identifies
- this object uniquely within the presence of a larger SQL expression
- or statement, for the purposes of caching the resulting query.
-
- The cache key should be based on the SQL compiled structure that would
- ultimately be produced. That is, two structures that are composed in
- exactly the same way should produce the same cache key; any difference
- in the strucures that would affect the SQL string or the type handlers
- should result in a different cache key.
-
- If a structure cannot produce a useful cache key, it should raise
- NotImplementedError, which will result in the entire structure
- for which it's part of not being useful as a cache key.
-
-
- """
- raise NotImplementedError()
-
@property
def _constructor(self):
"""return the 'constructor' for this ClauseElement.
@@ -336,9 +329,9 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
(see :class:`.ColumnElement`)
"""
- return clause_compare.compare(self, other, **kw)
+ return traversals.compare(self, other, **kw)
- def _copy_internals(self, clone=_clone, **kw):
+ def _copy_internals(self, **kw):
"""Reassign internal elements to be clones of themselves.
Called during a copy-and-traverse operation on newly
@@ -349,21 +342,46 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
traversal, cloned traversal, annotations).
"""
- pass
- def get_children(self, **kwargs):
- r"""Return immediate child elements of this :class:`.ClauseElement`.
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ return
+
+ for attrname, obj, meth in _copy_internals.run_generated_dispatch(
+ self, traverse_internals, "_generated_copy_internals_traversal"
+ ):
+ if obj is not None:
+ result = meth(self, obj, **kw)
+ if result is not None:
+ setattr(self, attrname, result)
+
+ def get_children(self, omit_attrs=None, **kw):
+ r"""Return immediate child :class:`.Traversible` elements of this
+ :class:`.Traversible`.
This is used for visit traversal.
- \**kwargs may contain flags that change the collection that is
+ \**kw may contain flags that change the collection that is
returned, for example to return a subset of items in order to
cut down on larger traversals, or to return child items from a
different context (such as schema-level collections instead of
clause-level).
"""
- return []
+ result = []
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ return result
+
+ for attrname, obj, meth in _get_children.run_generated_dispatch(
+ self, traverse_internals, "_generated_get_children_traversal"
+ ):
+ if obj is None or omit_attrs and attrname in omit_attrs:
+ continue
+ result.extend(meth(obj, **kw))
+ return result
def self_group(self, against=None):
# type: (Optional[Any]) -> ClauseElement
@@ -501,6 +519,8 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
return or_(self, other)
def __invert__(self):
+ # undocumented element currently used by the ORM for
+ # relationship.contains()
if hasattr(self, "negation_clause"):
return self.negation_clause
else:
@@ -508,9 +528,7 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
def _negate(self):
return UnaryExpression(
- self.self_group(against=operators.inv),
- operator=operators.inv,
- negate=None,
+ self.self_group(against=operators.inv), operator=operators.inv
)
def __bool__(self):
@@ -731,9 +749,6 @@ class ColumnElement(
else:
return comparator_factory(self)
- def _cache_key(self, **kw):
- raise NotImplementedError(self.__class__)
-
def __getattr__(self, key):
try:
return getattr(self.comparator, key)
@@ -969,6 +984,13 @@ class BindParameter(roles.InElementRole, ColumnElement):
__visit_name__ = "bindparam"
+ _traverse_internals = [
+ ("key", InternalTraversal.dp_anon_name),
+ ("type", InternalTraversal.dp_type),
+ ("callable", InternalTraversal.dp_plain_dict),
+ ("value", InternalTraversal.dp_plain_obj),
+ ]
+
_is_crud = False
_expanding_in_types = ()
@@ -1321,26 +1343,19 @@ class BindParameter(roles.InElementRole, ColumnElement):
)
return c
- def _cache_key(self, bindparams=None, **kw):
- if bindparams is None:
- # even though _cache_key is a private method, we would like to
- # be super paranoid about this point. You can't include the
- # "value" or "callable" in the cache key, because the value is
- # not part of the structure of a statement and is likely to
- # change every time. However you cannot *throw it away* either,
- # because you can't invoke the statement without the parameter
- # values that were explicitly placed. So require that they
- # are collected here to make sure this happens.
- if self._value_required_for_cache:
- raise NotImplementedError(
- "bindparams collection argument required for _cache_key "
- "implementation. Bound parameter cache keys are not safe "
- "to use without accommodating for the value or callable "
- "within the parameter itself."
- )
- else:
- bindparams.append(self)
- return (BindParameter, self.type._cache_key, self._orig_key)
+ def _gen_cache_key(self, anon_map, bindparams):
+ if self in anon_map:
+ return (anon_map[self], self.__class__)
+
+ id_ = anon_map[self]
+ bindparams.append(self)
+
+ return (
+ id_,
+ self.__class__,
+ self.type._gen_cache_key,
+ traversals._resolve_name_for_compare(self, self.key, anon_map),
+ )
def _convert_to_unique(self):
if not self.unique:
@@ -1377,12 +1392,11 @@ class TypeClause(ClauseElement):
__visit_name__ = "typeclause"
+ _traverse_internals = [("type", InternalTraversal.dp_type)]
+
def __init__(self, type_):
self.type = type_
- def _cache_key(self, **kw):
- return (TypeClause, self.type._cache_key)
-
class TextClause(
roles.DDLConstraintColumnRole,
@@ -1419,6 +1433,11 @@ class TextClause(
__visit_name__ = "textclause"
+ _traverse_internals = [
+ ("_bindparams", InternalTraversal.dp_string_clauseelement_dict),
+ ("text", InternalTraversal.dp_string),
+ ]
+
_is_text_clause = True
_is_textual = True
@@ -1861,19 +1880,6 @@ class TextClause(
else:
return self
- def _copy_internals(self, clone=_clone, **kw):
- self._bindparams = dict(
- (b.key, clone(b, **kw)) for b in self._bindparams.values()
- )
-
- def get_children(self, **kwargs):
- return list(self._bindparams.values())
-
- def _cache_key(self, **kw):
- return (self.text,) + tuple(
- bind._cache_key for bind in self._bindparams.values()
- )
-
class Null(roles.ConstExprRole, ColumnElement):
"""Represent the NULL keyword in a SQL statement.
@@ -1885,6 +1891,8 @@ class Null(roles.ConstExprRole, ColumnElement):
__visit_name__ = "null"
+ _traverse_internals = []
+
@util.memoized_property
def type(self):
return type_api.NULLTYPE
@@ -1895,9 +1903,6 @@ class Null(roles.ConstExprRole, ColumnElement):
return Null()
- def _cache_key(self, **kw):
- return (Null,)
-
class False_(roles.ConstExprRole, ColumnElement):
"""Represent the ``false`` keyword, or equivalent, in a SQL statement.
@@ -1908,6 +1913,7 @@ class False_(roles.ConstExprRole, ColumnElement):
"""
__visit_name__ = "false"
+ _traverse_internals = []
@util.memoized_property
def type(self):
@@ -1954,9 +1960,6 @@ class False_(roles.ConstExprRole, ColumnElement):
return False_()
- def _cache_key(self, **kw):
- return (False_,)
-
class True_(roles.ConstExprRole, ColumnElement):
"""Represent the ``true`` keyword, or equivalent, in a SQL statement.
@@ -1968,6 +1971,8 @@ class True_(roles.ConstExprRole, ColumnElement):
__visit_name__ = "true"
+ _traverse_internals = []
+
@util.memoized_property
def type(self):
return type_api.BOOLEANTYPE
@@ -2020,9 +2025,6 @@ class True_(roles.ConstExprRole, ColumnElement):
return True_()
- def _cache_key(self, **kw):
- return (True_,)
-
class ClauseList(
roles.InElementRole,
@@ -2038,6 +2040,11 @@ class ClauseList(
__visit_name__ = "clauselist"
+ _traverse_internals = [
+ ("clauses", InternalTraversal.dp_clauseelement_list),
+ ("operator", InternalTraversal.dp_operator),
+ ]
+
def __init__(self, *clauses, **kwargs):
self.operator = kwargs.pop("operator", operators.comma_op)
self.group = kwargs.pop("group", True)
@@ -2082,17 +2089,6 @@ class ClauseList(
coercions.expect(self._text_converter_role, clause)
)
- def _copy_internals(self, clone=_clone, **kw):
- self.clauses = [clone(clause, **kw) for clause in self.clauses]
-
- def get_children(self, **kwargs):
- return self.clauses
-
- def _cache_key(self, **kw):
- return (ClauseList, self.operator) + tuple(
- clause._cache_key(**kw) for clause in self.clauses
- )
-
@property
def _from_objects(self):
return list(itertools.chain(*[c._from_objects for c in self.clauses]))
@@ -2115,11 +2111,6 @@ class BooleanClauseList(ClauseList, ColumnElement):
"BooleanClauseList has a private constructor"
)
- def _cache_key(self, **kw):
- return (BooleanClauseList, self.operator) + tuple(
- clause._cache_key(**kw) for clause in self.clauses
- )
-
@classmethod
def _construct(cls, operator, continue_on, skip_on, *clauses, **kw):
convert_clauses = []
@@ -2250,6 +2241,8 @@ or_ = BooleanClauseList.or_
class Tuple(ClauseList, ColumnElement):
"""Represent a SQL tuple."""
+ _traverse_internals = ClauseList._traverse_internals + []
+
def __init__(self, *clauses, **kw):
"""Return a :class:`.Tuple`.
@@ -2289,11 +2282,6 @@ class Tuple(ClauseList, ColumnElement):
def _select_iterable(self):
return (self,)
- def _cache_key(self, **kw):
- return (Tuple,) + tuple(
- clause._cache_key(**kw) for clause in self.clauses
- )
-
def _bind_param(self, operator, obj, type_=None):
return Tuple(
*[
@@ -2339,6 +2327,12 @@ class Case(ColumnElement):
__visit_name__ = "case"
+ _traverse_internals = [
+ ("value", InternalTraversal.dp_clauseelement),
+ ("whens", InternalTraversal.dp_clauseelement_tuples),
+ ("else_", InternalTraversal.dp_clauseelement),
+ ]
+
def __init__(self, whens, value=None, else_=None):
r"""Produce a ``CASE`` expression.
@@ -2501,40 +2495,6 @@ class Case(ColumnElement):
else:
self.else_ = None
- def _copy_internals(self, clone=_clone, **kw):
- if self.value is not None:
- self.value = clone(self.value, **kw)
- self.whens = [(clone(x, **kw), clone(y, **kw)) for x, y in self.whens]
- if self.else_ is not None:
- self.else_ = clone(self.else_, **kw)
-
- def get_children(self, **kwargs):
- if self.value is not None:
- yield self.value
- for x, y in self.whens:
- yield x
- yield y
- if self.else_ is not None:
- yield self.else_
-
- def _cache_key(self, **kw):
- return (
- (
- Case,
- self.value._cache_key(**kw)
- if self.value is not None
- else None,
- )
- + tuple(
- (x._cache_key(**kw), y._cache_key(**kw)) for x, y in self.whens
- )
- + (
- self.else_._cache_key(**kw)
- if self.else_ is not None
- else None,
- )
- )
-
@property
def _from_objects(self):
return list(
@@ -2603,6 +2563,11 @@ class Cast(WrapsColumnExpression, ColumnElement):
__visit_name__ = "cast"
+ _traverse_internals = [
+ ("clause", InternalTraversal.dp_clauseelement),
+ ("typeclause", InternalTraversal.dp_clauseelement),
+ ]
+
def __init__(self, expression, type_):
r"""Produce a ``CAST`` expression.
@@ -2662,20 +2627,6 @@ class Cast(WrapsColumnExpression, ColumnElement):
)
self.typeclause = TypeClause(self.type)
- def _copy_internals(self, clone=_clone, **kw):
- self.clause = clone(self.clause, **kw)
- self.typeclause = clone(self.typeclause, **kw)
-
- def get_children(self, **kwargs):
- return self.clause, self.typeclause
-
- def _cache_key(self, **kw):
- return (
- Cast,
- self.clause._cache_key(**kw),
- self.typeclause._cache_key(**kw),
- )
-
@property
def _from_objects(self):
return self.clause._from_objects
@@ -2685,7 +2636,7 @@ class Cast(WrapsColumnExpression, ColumnElement):
return self.clause
-class TypeCoerce(WrapsColumnExpression, ColumnElement):
+class TypeCoerce(HasMemoized, WrapsColumnExpression, ColumnElement):
"""Represent a Python-side type-coercion wrapper.
:class:`.TypeCoerce` supplies the :func:`.expression.type_coerce`
@@ -2705,6 +2656,13 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement):
__visit_name__ = "type_coerce"
+ _traverse_internals = [
+ ("clause", InternalTraversal.dp_clauseelement),
+ ("type", InternalTraversal.dp_type),
+ ]
+
+ _memoized_property = util.group_expirable_memoized_property()
+
def __init__(self, expression, type_):
r"""Associate a SQL expression with a particular type, without rendering
``CAST``.
@@ -2773,21 +2731,11 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement):
roles.ExpressionElementRole, expression, type_=self.type
)
- def _copy_internals(self, clone=_clone, **kw):
- self.clause = clone(self.clause, **kw)
- self.__dict__.pop("typed_expression", None)
-
- def get_children(self, **kwargs):
- return (self.clause,)
-
- def _cache_key(self, **kw):
- return (TypeCoerce, self.type._cache_key, self.clause._cache_key(**kw))
-
@property
def _from_objects(self):
return self.clause._from_objects
- @util.memoized_property
+ @_memoized_property
def typed_expression(self):
if isinstance(self.clause, BindParameter):
bp = self.clause._clone()
@@ -2806,6 +2754,11 @@ class Extract(ColumnElement):
__visit_name__ = "extract"
+ _traverse_internals = [
+ ("expr", InternalTraversal.dp_clauseelement),
+ ("field", InternalTraversal.dp_string),
+ ]
+
def __init__(self, field, expr, **kwargs):
"""Return a :class:`.Extract` construct.
@@ -2818,15 +2771,6 @@ class Extract(ColumnElement):
self.field = field
self.expr = coercions.expect(roles.ExpressionElementRole, expr)
- def _copy_internals(self, clone=_clone, **kw):
- self.expr = clone(self.expr, **kw)
-
- def get_children(self, **kwargs):
- return (self.expr,)
-
- def _cache_key(self, **kw):
- return (Extract, self.field, self.expr._cache_key(**kw))
-
@property
def _from_objects(self):
return self.expr._from_objects
@@ -2847,18 +2791,11 @@ class _label_reference(ColumnElement):
__visit_name__ = "label_reference"
+ _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+
def __init__(self, element):
self.element = element
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
-
- def _cache_key(self, **kw):
- return (_label_reference, self.element._cache_key(**kw))
-
- def get_children(self, **kwargs):
- return [self.element]
-
@property
def _from_objects(self):
return ()
@@ -2867,6 +2804,8 @@ class _label_reference(ColumnElement):
class _textual_label_reference(ColumnElement):
__visit_name__ = "textual_label_reference"
+ _traverse_internals = [("element", InternalTraversal.dp_string)]
+
def __init__(self, element):
self.element = element
@@ -2874,9 +2813,6 @@ class _textual_label_reference(ColumnElement):
def _text_clause(self):
return TextClause._create_text(self.element)
- def _cache_key(self, **kw):
- return (_textual_label_reference, self.element)
-
class UnaryExpression(ColumnElement):
"""Define a 'unary' expression.
@@ -2894,13 +2830,18 @@ class UnaryExpression(ColumnElement):
__visit_name__ = "unary"
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("operator", InternalTraversal.dp_operator),
+ ("modifier", InternalTraversal.dp_operator),
+ ]
+
def __init__(
self,
element,
operator=None,
modifier=None,
type_=None,
- negate=None,
wraps_column_expression=False,
):
self.operator = operator
@@ -2909,7 +2850,6 @@ class UnaryExpression(ColumnElement):
against=self.operator or self.modifier
)
self.type = type_api.to_instance(type_)
- self.negate = negate
self.wraps_column_expression = wraps_column_expression
@classmethod
@@ -3135,37 +3075,13 @@ class UnaryExpression(ColumnElement):
def _from_objects(self):
return self.element._from_objects
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
-
- def _cache_key(self, **kw):
- return (
- UnaryExpression,
- self.element._cache_key(**kw),
- self.operator,
- self.modifier,
- )
-
- def get_children(self, **kwargs):
- return (self.element,)
-
def _negate(self):
- if self.negate is not None:
- return UnaryExpression(
- self.element,
- operator=self.negate,
- negate=self.operator,
- modifier=self.modifier,
- type_=self.type,
- wraps_column_expression=self.wraps_column_expression,
- )
- elif self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
+ if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
return UnaryExpression(
self.self_group(against=operators.inv),
operator=operators.inv,
type_=type_api.BOOLEANTYPE,
wraps_column_expression=self.wraps_column_expression,
- negate=None,
)
else:
return ClauseElement._negate(self)
@@ -3286,15 +3202,6 @@ class AsBoolean(WrapsColumnExpression, UnaryExpression):
# type: (Optional[Any]) -> ClauseElement
return self
- def _cache_key(self, **kw):
- return (
- self.element._cache_key(**kw),
- self.type._cache_key,
- self.operator,
- self.negate,
- self.modifier,
- )
-
def _negate(self):
if isinstance(self.element, (True_, False_)):
return self.element._negate()
@@ -3318,6 +3225,14 @@ class BinaryExpression(ColumnElement):
__visit_name__ = "binary"
+ _traverse_internals = [
+ ("left", InternalTraversal.dp_clauseelement),
+ ("right", InternalTraversal.dp_clauseelement),
+ ("operator", InternalTraversal.dp_operator),
+ ("negate", InternalTraversal.dp_operator),
+ ("modifiers", InternalTraversal.dp_plain_dict),
+ ]
+
_is_implicitly_boolean = True
"""Indicates that any database will know this is a boolean expression
even if the database does not have an explicit boolean datatype.
@@ -3360,20 +3275,6 @@ class BinaryExpression(ColumnElement):
def _from_objects(self):
return self.left._from_objects + self.right._from_objects
- def _copy_internals(self, clone=_clone, **kw):
- self.left = clone(self.left, **kw)
- self.right = clone(self.right, **kw)
-
- def get_children(self, **kwargs):
- return self.left, self.right
-
- def _cache_key(self, **kw):
- return (
- BinaryExpression,
- self.left._cache_key(**kw),
- self.right._cache_key(**kw),
- )
-
def self_group(self, against=None):
# type: (Optional[Any]) -> ClauseElement
@@ -3406,6 +3307,12 @@ class Slice(ColumnElement):
__visit_name__ = "slice"
+ _traverse_internals = [
+ ("start", InternalTraversal.dp_plain_obj),
+ ("stop", InternalTraversal.dp_plain_obj),
+ ("step", InternalTraversal.dp_plain_obj),
+ ]
+
def __init__(self, start, stop, step):
self.start = start
self.stop = stop
@@ -3417,9 +3324,6 @@ class Slice(ColumnElement):
assert against is operator.getitem
return self
- def _cache_key(self, **kw):
- return (Slice, self.start, self.stop, self.step)
-
class IndexExpression(BinaryExpression):
"""Represent the class of expressions that are like an "index" operation.
@@ -3444,6 +3348,11 @@ class GroupedElement(ClauseElement):
class Grouping(GroupedElement, ColumnElement):
"""Represent a grouping within a column expression"""
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("type", InternalTraversal.dp_type),
+ ]
+
def __init__(self, element):
self.element = element
self.type = getattr(element, "type", type_api.NULLTYPE)
@@ -3460,15 +3369,6 @@ class Grouping(GroupedElement, ColumnElement):
def _label(self):
return getattr(self.element, "_label", None) or self.anon_label
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
-
- def get_children(self, **kwargs):
- return (self.element,)
-
- def _cache_key(self, **kw):
- return (Grouping, self.element._cache_key(**kw))
-
@property
def _from_objects(self):
return self.element._from_objects
@@ -3501,6 +3401,14 @@ class Over(ColumnElement):
__visit_name__ = "over"
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("order_by", InternalTraversal.dp_clauseelement),
+ ("partition_by", InternalTraversal.dp_clauseelement),
+ ("range_", InternalTraversal.dp_plain_obj),
+ ("rows", InternalTraversal.dp_plain_obj),
+ ]
+
order_by = None
partition_by = None
@@ -3667,30 +3575,6 @@ class Over(ColumnElement):
def type(self):
return self.element.type
- def get_children(self, **kwargs):
- return [
- c
- for c in (self.element, self.partition_by, self.order_by)
- if c is not None
- ]
-
- def _cache_key(self, **kw):
- return (
- (Over,)
- + tuple(
- e._cache_key(**kw) if e is not None else None
- for e in (self.element, self.partition_by, self.order_by)
- )
- + (self.range_, self.rows)
- )
-
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
- if self.partition_by is not None:
- self.partition_by = clone(self.partition_by, **kw)
- if self.order_by is not None:
- self.order_by = clone(self.order_by, **kw)
-
@property
def _from_objects(self):
return list(
@@ -3723,6 +3607,11 @@ class WithinGroup(ColumnElement):
__visit_name__ = "withingroup"
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("order_by", InternalTraversal.dp_clauseelement),
+ ]
+
order_by = None
def __init__(self, element, *order_by):
@@ -3791,25 +3680,6 @@ class WithinGroup(ColumnElement):
else:
return self.element.type
- def get_children(self, **kwargs):
- return [c for c in (self.element, self.order_by) if c is not None]
-
- def _cache_key(self, **kw):
- return (
- WithinGroup,
- self.element._cache_key(**kw)
- if self.element is not None
- else None,
- self.order_by._cache_key(**kw)
- if self.order_by is not None
- else None,
- )
-
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
- if self.order_by is not None:
- self.order_by = clone(self.order_by, **kw)
-
@property
def _from_objects(self):
return list(
@@ -3845,6 +3715,11 @@ class FunctionFilter(ColumnElement):
__visit_name__ = "funcfilter"
+ _traverse_internals = [
+ ("func", InternalTraversal.dp_clauseelement),
+ ("criterion", InternalTraversal.dp_clauseelement),
+ ]
+
criterion = None
def __init__(self, func, *criterion):
@@ -3932,23 +3807,6 @@ class FunctionFilter(ColumnElement):
def type(self):
return self.func.type
- def get_children(self, **kwargs):
- return [c for c in (self.func, self.criterion) if c is not None]
-
- def _copy_internals(self, clone=_clone, **kw):
- self.func = clone(self.func, **kw)
- if self.criterion is not None:
- self.criterion = clone(self.criterion, **kw)
-
- def _cache_key(self, **kw):
- return (
- FunctionFilter,
- self.func._cache_key(**kw),
- self.criterion._cache_key(**kw)
- if self.criterion is not None
- else None,
- )
-
@property
def _from_objects(self):
return list(
@@ -3962,7 +3820,7 @@ class FunctionFilter(ColumnElement):
)
-class Label(roles.LabeledColumnExprRole, ColumnElement):
+class Label(HasMemoized, roles.LabeledColumnExprRole, ColumnElement):
"""Represents a column label (AS).
Represent a label, as typically applied to any column-level
@@ -3972,6 +3830,14 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
__visit_name__ = "label"
+ _traverse_internals = [
+ ("name", InternalTraversal.dp_anon_name),
+ ("_type", InternalTraversal.dp_type),
+ ("_element", InternalTraversal.dp_clauseelement),
+ ]
+
+ _memoized_property = util.group_expirable_memoized_property()
+
def __init__(self, name, element, type_=None):
"""Return a :class:`Label` object for the
given :class:`.ColumnElement`.
@@ -4010,14 +3876,11 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
def __reduce__(self):
return self.__class__, (self.name, self._element, self._type)
- def _cache_key(self, **kw):
- return (Label, self.element._cache_key(**kw), self._resolve_label)
-
@util.memoized_property
def _is_implicitly_boolean(self):
return self.element._is_implicitly_boolean
- @util.memoized_property
+ @_memoized_property
def _allow_label_resolve(self):
return self.element._allow_label_resolve
@@ -4031,7 +3894,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
self._type or getattr(self._element, "type", None)
)
- @util.memoized_property
+ @_memoized_property
def element(self):
return self._element.self_group(against=operators.as_)
@@ -4057,13 +3920,9 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
def foreign_keys(self):
return self.element.foreign_keys
- def get_children(self, **kwargs):
- return (self.element,)
-
def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw):
+ self._reset_memoizations()
self._element = clone(self._element, **kw)
- self.__dict__.pop("element", None)
- self.__dict__.pop("_allow_label_resolve", None)
if anonymize_labels:
self.name = self._resolve_label = _anonymous_label(
"%%(%d %s)s"
@@ -4124,6 +3983,13 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
__visit_name__ = "column"
+ _traverse_internals = [
+ ("name", InternalTraversal.dp_string),
+ ("type", InternalTraversal.dp_type),
+ ("table", InternalTraversal.dp_clauseelement),
+ ("is_literal", InternalTraversal.dp_boolean),
+ ]
+
onupdate = default = server_default = server_onupdate = None
_is_multiparam_column = False
@@ -4254,14 +4120,6 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
table = property(_get_table, _set_table)
- def _cache_key(self, **kw):
- return (
- self.name,
- self.table.name if self.table is not None else None,
- self.is_literal,
- self.type._cache_key,
- )
-
@_memoized_property
def _from_objects(self):
t = self.table
@@ -4395,12 +4253,11 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
class CollationClause(ColumnElement):
__visit_name__ = "collation"
+ _traverse_internals = [("collation", InternalTraversal.dp_string)]
+
def __init__(self, collation):
self.collation = collation
- def _cache_key(self, **kw):
- return (CollationClause, self.collation)
-
class _IdentifiedClause(Executable, ClauseElement):
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 7ce822669..08e69f075 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -86,7 +86,6 @@ __all__ = [
from .base import _from_objects # noqa
from .base import ColumnCollection # noqa
from .base import Executable # noqa
-from .base import Generative # noqa
from .base import PARSE_AUTOCOMMIT # noqa
from .dml import Delete # noqa
from .dml import Insert # noqa
@@ -242,7 +241,6 @@ _UnaryExpression = UnaryExpression
_Case = Case
_Tuple = Tuple
_Over = Over
-_Generative = Generative
_TypeClause = TypeClause
_Extract = Extract
_Exists = Exists
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index cbc8e539f..96e64dc28 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -17,7 +17,6 @@ from . import sqltypes
from . import util as sqlutil
from .base import ColumnCollection
from .base import Executable
-from .elements import _clone
from .elements import _type_from_args
from .elements import BinaryExpression
from .elements import BindParameter
@@ -33,7 +32,8 @@ from .elements import WithinGroup
from .selectable import Alias
from .selectable import FromClause
from .selectable import Select
-from .visitors import VisitableType
+from .visitors import InternalTraversal
+from .visitors import TraversibleType
from .. import util
@@ -78,10 +78,14 @@ class FunctionElement(Executable, ColumnElement, FromClause):
"""
+ _traverse_internals = [("clause_expr", InternalTraversal.dp_clauseelement)]
+
packagenames = ()
_has_args = False
+ _memoized_property = FromClause._memoized_property
+
def __init__(self, *clauses, **kwargs):
r"""Construct a :class:`.FunctionElement`.
@@ -136,7 +140,7 @@ class FunctionElement(Executable, ColumnElement, FromClause):
col = self.label(None)
return ColumnCollection(columns=[(col.key, col)])
- @util.memoized_property
+ @_memoized_property
def clauses(self):
"""Return the underlying :class:`.ClauseList` which contains
the arguments for this :class:`.FunctionElement`.
@@ -283,17 +287,6 @@ class FunctionElement(Executable, ColumnElement, FromClause):
def _from_objects(self):
return self.clauses._from_objects
- def get_children(self, **kwargs):
- return (self.clause_expr,)
-
- def _cache_key(self, **kw):
- return (FunctionElement, self.clause_expr._cache_key(**kw))
-
- def _copy_internals(self, clone=_clone, **kw):
- self.clause_expr = clone(self.clause_expr, **kw)
- self._reset_exported()
- FunctionElement.clauses._reset(self)
-
def within_group_type(self, within_group):
"""For types that define their return type as based on the criteria
within a WITHIN GROUP (ORDER BY) expression, called by the
@@ -404,6 +397,13 @@ class FunctionElement(Executable, ColumnElement, FromClause):
class FunctionAsBinary(BinaryExpression):
+ _traverse_internals = [
+ ("sql_function", InternalTraversal.dp_clauseelement),
+ ("left_index", InternalTraversal.dp_plain_obj),
+ ("right_index", InternalTraversal.dp_plain_obj),
+ ("modifiers", InternalTraversal.dp_plain_dict),
+ ]
+
def __init__(self, fn, left_index, right_index):
self.sql_function = fn
self.left_index = left_index
@@ -431,20 +431,6 @@ class FunctionAsBinary(BinaryExpression):
def right(self, value):
self.sql_function.clauses.clauses[self.right_index - 1] = value
- def _copy_internals(self, clone=_clone, **kw):
- self.sql_function = clone(self.sql_function, **kw)
-
- def get_children(self, **kw):
- yield self.sql_function
-
- def _cache_key(self, **kw):
- return (
- FunctionAsBinary,
- self.sql_function._cache_key(**kw),
- self.left_index,
- self.right_index,
- )
-
class _FunctionGenerator(object):
"""Generate SQL function expressions.
@@ -606,6 +592,12 @@ class Function(FunctionElement):
__visit_name__ = "function"
+ _traverse_internals = FunctionElement._traverse_internals + [
+ ("packagenames", InternalTraversal.dp_plain_obj),
+ ("name", InternalTraversal.dp_string),
+ ("type", InternalTraversal.dp_type),
+ ]
+
def __init__(self, name, *clauses, **kw):
"""Construct a :class:`.Function`.
@@ -630,15 +622,8 @@ class Function(FunctionElement):
unique=True,
)
- def _cache_key(self, **kw):
- return (
- (Function,) + tuple(self.packagenames)
- if self.packagenames
- else () + (self.name, self.clause_expr._cache_key(**kw))
- )
-
-class _GenericMeta(VisitableType):
+class _GenericMeta(TraversibleType):
def __init__(cls, clsname, bases, clsdict):
if annotation.Annotated not in cls.__mro__:
cls.name = name = clsdict.get("name", clsname)
@@ -764,6 +749,10 @@ class next_value(GenericFunction):
type = sqltypes.Integer()
name = "next_value"
+ _traverse_internals = [
+ ("sequence", InternalTraversal.dp_named_ddl_element)
+ ]
+
def __init__(self, seq, **kw):
assert isinstance(
seq, schema.Sequence
@@ -771,21 +760,12 @@ class next_value(GenericFunction):
self._bind = kw.get("bind", None)
self.sequence = seq
- def _cache_key(self, **kw):
- return (next_value, self.sequence.name)
-
def compare(self, other, **kw):
return (
isinstance(other, next_value)
and self.sequence.name == other.sequence.name
)
- def get_children(self, **kwargs):
- return []
-
- def _copy_internals(self, **kw):
- pass
-
@property
def _from_objects(self):
return []
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 4e8f4a397..ee7dc61ce 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -50,6 +50,7 @@ from .elements import ColumnElement
from .elements import quoted_name
from .elements import TextClause
from .selectable import TableClause
+from .visitors import InternalTraversal
from .. import event
from .. import exc
from .. import inspection
@@ -425,6 +426,21 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
__visit_name__ = "table"
+ _traverse_internals = TableClause._traverse_internals + [
+ ("schema", InternalTraversal.dp_string)
+ ]
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ return (self,)
+
+ @util.deprecated_params(
+ useexisting=(
+ "0.7",
+ "The :paramref:`.Table.useexisting` parameter is deprecated and "
+ "will be removed in a future release. Please use "
+ ":paramref:`.Table.extend_existing`.",
+ )
+ )
def __new__(cls, *args, **kw):
if not args:
# python3k pickle seems to call this
@@ -763,6 +779,8 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
def get_children(
self, column_collections=True, schema_visitor=False, **kw
):
+ # TODO: consider that we probably don't need column_collections=True
+ # at all, it does not seem to impact anything
if not schema_visitor:
return TableClause.get_children(
self, column_collections=column_collections, **kw
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 6a7413fc0..4b3844eec 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -31,6 +31,7 @@ from .base import ColumnSet
from .base import DedupeColumnCollection
from .base import Executable
from .base import Generative
+from .base import HasMemoized
from .base import Immutable
from .coercions import _document_text_coercion
from .elements import _anonymous_label
@@ -39,11 +40,13 @@ from .elements import and_
from .elements import BindParameter
from .elements import ClauseElement
from .elements import ClauseList
+from .elements import ColumnClause
from .elements import GroupedElement
from .elements import Grouping
from .elements import literal_column
from .elements import True_
from .elements import UnaryExpression
+from .visitors import InternalTraversal
from .. import exc
from .. import util
@@ -201,6 +204,8 @@ class Selectable(ReturnsRows):
class HasPrefixes(object):
_prefixes = ()
+ _traverse_internals = [("_prefixes", InternalTraversal.dp_prefix_sequence)]
+
@_generative
@_document_text_coercion(
"expr",
@@ -252,6 +257,8 @@ class HasPrefixes(object):
class HasSuffixes(object):
_suffixes = ()
+ _traverse_internals = [("_suffixes", InternalTraversal.dp_prefix_sequence)]
+
@_generative
@_document_text_coercion(
"expr",
@@ -295,7 +302,7 @@ class HasSuffixes(object):
)
-class FromClause(roles.AnonymizedFromClauseRole, Selectable):
+class FromClause(HasMemoized, roles.AnonymizedFromClauseRole, Selectable):
"""Represent an element that can be used within the ``FROM``
clause of a ``SELECT`` statement.
@@ -529,11 +536,6 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
return getattr(self, "name", self.__class__.__name__ + " object")
- def _reset_exported(self):
- """delete memoized collections when a FromClause is cloned."""
-
- self._memoized_property.expire_instance(self)
-
def _generate_fromclause_column_proxies(self, fromclause):
fromclause._columns._populate_separate_keys(
col._make_proxy(fromclause) for col in self.c
@@ -668,6 +670,14 @@ class Join(FromClause):
__visit_name__ = "join"
+ _traverse_internals = [
+ ("left", InternalTraversal.dp_clauseelement),
+ ("right", InternalTraversal.dp_clauseelement),
+ ("onclause", InternalTraversal.dp_clauseelement),
+ ("isouter", InternalTraversal.dp_boolean),
+ ("full", InternalTraversal.dp_boolean),
+ ]
+
_is_join = True
def __init__(self, left, right, onclause=None, isouter=False, full=False):
@@ -805,25 +815,6 @@ class Join(FromClause):
self.left._refresh_for_new_column(column)
self.right._refresh_for_new_column(column)
- def _copy_internals(self, clone=_clone, **kw):
- self._reset_exported()
- self.left = clone(self.left, **kw)
- self.right = clone(self.right, **kw)
- self.onclause = clone(self.onclause, **kw)
-
- def get_children(self, **kwargs):
- return self.left, self.right, self.onclause
-
- def _cache_key(self, **kw):
- return (
- Join,
- self.isouter,
- self.full,
- self.left._cache_key(**kw),
- self.right._cache_key(**kw),
- self.onclause._cache_key(**kw),
- )
-
def _match_primaries(self, left, right):
if isinstance(left, Join):
left_right = left.right
@@ -1175,6 +1166,11 @@ class AliasedReturnsRows(FromClause):
_is_from_container = True
named_with_column = True
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("name", InternalTraversal.dp_anon_name),
+ ]
+
def __init__(self, *arg, **kw):
raise NotImplementedError(
"The %s class is not intended to be constructed "
@@ -1243,18 +1239,13 @@ class AliasedReturnsRows(FromClause):
def _copy_internals(self, clone=_clone, **kw):
element = clone(self.element, **kw)
+
+ # the element clone is usually against a Table that returns the
+ # same object. don't reset exported .c. collections and other
+ # memoized details if nothing changed
if element is not self.element:
self._reset_exported()
- self.element = element
-
- def get_children(self, column_collections=True, **kw):
- if column_collections:
- for c in self.c:
- yield c
- yield self.element
-
- def _cache_key(self, **kw):
- return (self.__class__, self.element._cache_key(**kw), self._orig_name)
+ self.element = element
@property
def _from_objects(self):
@@ -1396,6 +1387,11 @@ class TableSample(AliasedReturnsRows):
__visit_name__ = "tablesample"
+ _traverse_internals = AliasedReturnsRows._traverse_internals + [
+ ("sampling", InternalTraversal.dp_clauseelement),
+ ("seed", InternalTraversal.dp_clauseelement),
+ ]
+
@classmethod
def _factory(cls, selectable, sampling, name=None, seed=None):
"""Return a :class:`.TableSample` object.
@@ -1466,6 +1462,16 @@ class CTE(Generative, HasSuffixes, AliasedReturnsRows):
__visit_name__ = "cte"
+ _traverse_internals = (
+ AliasedReturnsRows._traverse_internals
+ + [
+ ("_cte_alias", InternalTraversal.dp_clauseelement),
+ ("_restates", InternalTraversal.dp_clauseelement_unordered_set),
+ ("recursive", InternalTraversal.dp_boolean),
+ ]
+ + HasSuffixes._traverse_internals
+ )
+
@classmethod
def _factory(cls, selectable, name=None, recursive=False):
r"""Return a new :class:`.CTE`, or Common Table Expression instance.
@@ -1495,15 +1501,13 @@ class CTE(Generative, HasSuffixes, AliasedReturnsRows):
def _copy_internals(self, clone=_clone, **kw):
super(CTE, self)._copy_internals(clone, **kw)
+ # TODO: I don't like that we can't use the traversal data here
if self._cte_alias is not None:
self._cte_alias = clone(self._cte_alias, **kw)
self._restates = frozenset(
[clone(elem, **kw) for elem in self._restates]
)
- def _cache_key(self, *arg, **kw):
- raise NotImplementedError("TODO")
-
def alias(self, name=None, flat=False):
"""Return an :class:`.Alias` of this :class:`.CTE`.
@@ -1764,6 +1768,8 @@ class Subquery(AliasedReturnsRows):
class FromGrouping(GroupedElement, FromClause):
"""Represent a grouping of a FROM clause"""
+ _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+
def __init__(self, element):
self.element = coercions.expect(roles.FromClauseRole, element)
@@ -1792,15 +1798,6 @@ class FromGrouping(GroupedElement, FromClause):
def _hide_froms(self):
return self.element._hide_froms
- def get_children(self, **kwargs):
- return (self.element,)
-
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
-
- def _cache_key(self, **kw):
- return (FromGrouping, self.element._cache_key(**kw))
-
@property
def _from_objects(self):
return self.element._from_objects
@@ -1843,6 +1840,14 @@ class TableClause(Immutable, FromClause):
__visit_name__ = "table"
+ _traverse_internals = [
+ (
+ "columns",
+ InternalTraversal.dp_fromclause_canonical_column_collection,
+ ),
+ ("name", InternalTraversal.dp_string),
+ ]
+
named_with_column = True
implicit_returning = False
@@ -1895,17 +1900,6 @@ class TableClause(Immutable, FromClause):
self._columns.add(c)
c.table = self
- def get_children(self, column_collections=True, **kwargs):
- if column_collections:
- return [c for c in self.c]
- else:
- return []
-
- def _cache_key(self, **kw):
- return (TableClause, self.name) + tuple(
- col._cache_key(**kw) for col in self._columns
- )
-
@util.dependencies("sqlalchemy.sql.dml")
def insert(self, dml, values=None, inline=False, **kwargs):
"""Generate an :func:`.insert` construct against this
@@ -1965,6 +1959,13 @@ class TableClause(Immutable, FromClause):
class ForUpdateArg(ClauseElement):
+ _traverse_internals = [
+ ("of", InternalTraversal.dp_clauseelement_list),
+ ("nowait", InternalTraversal.dp_boolean),
+ ("read", InternalTraversal.dp_boolean),
+ ("skip_locked", InternalTraversal.dp_boolean),
+ ]
+
@classmethod
def parse_legacy_select(self, arg):
"""Parse the for_update argument of :func:`.select`.
@@ -2029,19 +2030,6 @@ class ForUpdateArg(ClauseElement):
def __hash__(self):
return id(self)
- def _copy_internals(self, clone=_clone, **kw):
- if self.of is not None:
- self.of = [clone(col, **kw) for col in self.of]
-
- def _cache_key(self, **kw):
- return (
- ForUpdateArg,
- self.nowait,
- self.read,
- self.skip_locked,
- self.of._cache_key(**kw) if self.of is not None else None,
- )
-
def __init__(
self,
nowait=False,
@@ -2074,6 +2062,7 @@ class SelectBase(
roles.DMLSelectRole,
roles.CompoundElementRole,
roles.InElementRole,
+ HasMemoized,
HasCTE,
Executable,
SupportsCloneAnnotations,
@@ -2092,9 +2081,6 @@ class SelectBase(
_memoized_property = util.group_expirable_memoized_property()
- def _reset_memoizations(self):
- self._memoized_property.expire_instance(self)
-
def _generate_fromclause_column_proxies(self, fromclause):
# type: (FromClause)
raise NotImplementedError()
@@ -2339,6 +2325,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
"""
__visit_name__ = "grouping"
+ _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
_is_select_container = True
@@ -2350,9 +2337,6 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
def select_statement(self):
return self.element
- def get_children(self, **kwargs):
- return (self.element,)
-
def self_group(self, against=None):
# type: (Optional[Any]) -> FromClause
return self
@@ -2377,12 +2361,6 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
"""
return self.element.selected_columns
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
-
- def _cache_key(self, **kw):
- return (SelectStatementGrouping, self.element._cache_key(**kw))
-
@property
def _from_objects(self):
return self.element._from_objects
@@ -2758,9 +2736,6 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
def _label_resolve_dict(self):
raise NotImplementedError()
- def _copy_internals(self, clone=_clone, **kw):
- raise NotImplementedError()
-
class CompoundSelect(GenerativeSelect):
"""Forms the basis of ``UNION``, ``UNION ALL``, and other
@@ -2785,6 +2760,16 @@ class CompoundSelect(GenerativeSelect):
__visit_name__ = "compound_select"
+ _traverse_internals = [
+ ("selects", InternalTraversal.dp_clauseelement_list),
+ ("_limit_clause", InternalTraversal.dp_clauseelement),
+ ("_offset_clause", InternalTraversal.dp_clauseelement),
+ ("_order_by_clause", InternalTraversal.dp_clauseelement),
+ ("_group_by_clause", InternalTraversal.dp_clauseelement),
+ ("_for_update_arg", InternalTraversal.dp_clauseelement),
+ ("keyword", InternalTraversal.dp_string),
+ ] + SupportsCloneAnnotations._traverse_internals
+
UNION = util.symbol("UNION")
UNION_ALL = util.symbol("UNION ALL")
EXCEPT = util.symbol("EXCEPT")
@@ -3004,47 +2989,6 @@ class CompoundSelect(GenerativeSelect):
"""
return self.selects[0].selected_columns
- def _copy_internals(self, clone=_clone, **kw):
- self._reset_memoizations()
- self.selects = [clone(s, **kw) for s in self.selects]
- if hasattr(self, "_col_map"):
- del self._col_map
- for attr in (
- "_limit_clause",
- "_offset_clause",
- "_order_by_clause",
- "_group_by_clause",
- "_for_update_arg",
- ):
- if getattr(self, attr) is not None:
- setattr(self, attr, clone(getattr(self, attr), **kw))
-
- def get_children(self, **kwargs):
- return [self._order_by_clause, self._group_by_clause] + list(
- self.selects
- )
-
- def _cache_key(self, **kw):
- return (
- (CompoundSelect, self.keyword)
- + tuple(stmt._cache_key(**kw) for stmt in self.selects)
- + (
- self._order_by_clause._cache_key(**kw)
- if self._order_by_clause is not None
- else None,
- )
- + (
- self._group_by_clause._cache_key(**kw)
- if self._group_by_clause is not None
- else None,
- )
- + (
- self._for_update_arg._cache_key(**kw)
- if self._for_update_arg is not None
- else None,
- )
- )
-
def bind(self):
if self._bind:
return self._bind
@@ -3193,11 +3137,35 @@ class Select(
_hints = util.immutabledict()
_statement_hints = ()
_distinct = False
- _from_cloned = None
+ _distinct_on = ()
_correlate = ()
_correlate_except = None
_memoized_property = SelectBase._memoized_property
+ _traverse_internals = (
+ [
+ ("_from_obj", InternalTraversal.dp_fromclause_ordered_set),
+ ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+ ("_whereclause", InternalTraversal.dp_clauseelement),
+ ("_having", InternalTraversal.dp_clauseelement),
+ ("_order_by_clause", InternalTraversal.dp_clauseelement_list),
+ ("_group_by_clause", InternalTraversal.dp_clauseelement_list),
+ ("_correlate", InternalTraversal.dp_clauseelement_unordered_set),
+ (
+ "_correlate_except",
+ InternalTraversal.dp_clauseelement_unordered_set,
+ ),
+ ("_for_update_arg", InternalTraversal.dp_clauseelement),
+ ("_statement_hints", InternalTraversal.dp_statement_hint_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ("_distinct", InternalTraversal.dp_boolean),
+ ("_distinct_on", InternalTraversal.dp_clauseelement_list),
+ ]
+ + HasPrefixes._traverse_internals
+ + HasSuffixes._traverse_internals
+ + SupportsCloneAnnotations._traverse_internals
+ )
+
@util.deprecated_params(
autocommit=(
"0.6",
@@ -3416,13 +3384,14 @@ class Select(
"""
self._auto_correlate = correlate
if distinct is not False:
- if distinct is True:
- self._distinct = True
- else:
- self._distinct = [
- coercions.expect(roles.WhereHavingRole, e)
- for e in util.to_list(distinct)
- ]
+ self._distinct = True
+ if not isinstance(distinct, bool):
+ self._distinct_on = tuple(
+ [
+ coercions.expect(roles.WhereHavingRole, e)
+ for e in util.to_list(distinct)
+ ]
+ )
if from_obj is not None:
self._from_obj = util.OrderedSet(
@@ -3472,15 +3441,17 @@ class Select(
GenerativeSelect.__init__(self, **kwargs)
+ # @_memoized_property
@property
def _froms(self):
- # would love to cache this,
- # but there's just enough edge cases, particularly now that
- # declarative encourages construction of SQL expressions
- # without tables present, to just regen this each time.
+ # current roadblock to caching is two tests that test that the
+ # SELECT can be compiled to a string, then a Table is created against
+ # columns, then it can be compiled again and works. this is somewhat
+ # valid as people make select() against declarative class where
+ # columns don't have their Table yet and perhaps some operations
+ # call upon _froms and cache it too soon.
froms = []
seen = set()
- translate = self._from_cloned
for item in itertools.chain(
_from_objects(*self._raw_columns),
@@ -3493,8 +3464,6 @@ class Select(
raise exc.InvalidRequestError(
"select() construct refers to itself as a FROM"
)
- if translate and item in translate:
- item = translate[item]
if not seen.intersection(item._cloned_set):
froms.append(item)
seen.update(item._cloned_set)
@@ -3518,15 +3487,6 @@ class Select(
itertools.chain(*[_expand_cloned(f._hide_froms) for f in froms])
)
if toremove:
- # if we're maintaining clones of froms,
- # add the copies out to the toremove list. only include
- # clones that are lexical equivalents.
- if self._from_cloned:
- toremove.update(
- self._from_cloned[f]
- for f in toremove.intersection(self._from_cloned)
- if self._from_cloned[f]._is_lexical_equivalent(f)
- )
# filter out to FROM clauses not in the list,
# using a list to maintain ordering
froms = [f for f in froms if f not in toremove]
@@ -3707,7 +3667,6 @@ class Select(
return False
def _copy_internals(self, clone=_clone, **kw):
-
# Select() object has been cloned and probably adapted by the
# given clone function. Apply the cloning function to internal
# objects
@@ -3719,37 +3678,42 @@ class Select(
# as of 0.7.4 we also put the current version of _froms, which
# gets cleared on each generation. previously we were "baking"
# _froms into self._from_obj.
- self._from_cloned = from_cloned = dict(
- (f, clone(f, **kw)) for f in self._from_obj.union(self._froms)
- )
- # 3. update persistent _from_obj with the cloned versions.
- self._from_obj = util.OrderedSet(
- from_cloned[f] for f in self._from_obj
+ all_the_froms = list(
+ itertools.chain(
+ _from_objects(*self._raw_columns),
+ _from_objects(self._whereclause)
+ if self._whereclause is not None
+ else (),
+ )
)
+ new_froms = {f: clone(f, **kw) for f in all_the_froms}
+ # copy FROM collections
- # the _correlate collection is done separately, what can happen
- # here is the same item is _correlate as in _from_obj but the
- # _correlate version has an annotation on it - (specifically
- # RelationshipProperty.Comparator._criterion_exists() does
- # this). Also keep _correlate liberally open with its previous
- # contents, as this set is used for matching, not rendering.
- self._correlate = set(clone(f) for f in self._correlate).union(
- self._correlate
- )
+ self._from_obj = util.OrderedSet(
+ clone(f, **kw) for f in self._from_obj
+ ).union(f for f in new_froms.values() if isinstance(f, Join))
- # do something similar for _correlate_except - this is a more
- # unusual case but same idea applies
+ self._correlate = set(clone(f) for f in self._correlate)
if self._correlate_except:
self._correlate_except = set(
clone(f) for f in self._correlate_except
- ).union(self._correlate_except)
+ )
# 4. clone other things. The difficulty here is that Column
- # objects are not actually cloned, and refer to their original
- # .table, resulting in the wrong "from" parent after a clone
- # operation. Hence _from_cloned and _from_obj supersede what is
- # present here.
+ # objects are usually not altered by a straight clone because they
+ # are dependent on the FROM cloning we just did above in order to
+ # be targeted correctly, or a new FROM we have might be a JOIN
+ # object which doesn't have its own columns. so give the cloner a
+ # hint.
+ def replace(obj, **kw):
+ if isinstance(obj, ColumnClause) and obj.table in new_froms:
+ newelem = new_froms[obj.table].corresponding_column(obj)
+ return newelem
+
+ kw["replace"] = replace
+
+ # TODO: I'd still like to try to leverage the traversal data
self._raw_columns = [clone(c, **kw) for c in self._raw_columns]
for attr in (
"_limit_clause",
@@ -3763,67 +3727,12 @@ class Select(
if getattr(self, attr) is not None:
setattr(self, attr, clone(getattr(self, attr), **kw))
- # erase _froms collection,
- # etc.
self._reset_memoizations()
def get_children(self, **kwargs):
- """return child elements as per the ClauseElement specification."""
-
- return (
- self._raw_columns
- + list(self._froms)
- + [
- x
- for x in (
- self._whereclause,
- self._having,
- self._order_by_clause,
- self._group_by_clause,
- )
- if x is not None
- ]
- )
-
- def _cache_key(self, **kw):
- return (
- (Select,)
- + ("raw_columns",)
- + tuple(elem._cache_key(**kw) for elem in self._raw_columns)
- + ("elements",)
- + tuple(
- elem._cache_key(**kw) if elem is not None else None
- for elem in (
- self._whereclause,
- self._having,
- self._order_by_clause,
- self._group_by_clause,
- )
- )
- + ("from_obj",)
- + tuple(elem._cache_key(**kw) for elem in self._from_obj)
- + ("correlate",)
- + tuple(
- elem._cache_key(**kw)
- for elem in (
- self._correlate if self._correlate is not None else ()
- )
- )
- + ("correlate_except",)
- + tuple(
- elem._cache_key(**kw)
- for elem in (
- self._correlate_except
- if self._correlate_except is not None
- else ()
- )
- )
- + ("for_update",),
- (
- self._for_update_arg._cache_key(**kw)
- if self._for_update_arg is not None
- else None,
- ),
+ # TODO: define "get_children" traversal items separately?
+ return self._froms + super(Select, self).get_children(
+ omit_attrs=["_from_obj", "_correlate", "_correlate_except"]
)
@_generative
@@ -3987,10 +3896,8 @@ class Select(
"""
if expr:
expr = [coercions.expect(roles.ByOfRole, e) for e in expr]
- if isinstance(self._distinct, list):
- self._distinct = self._distinct + expr
- else:
- self._distinct = expr
+ self._distinct = True
+ self._distinct_on = self._distinct_on + tuple(expr)
else:
self._distinct = True
@@ -4489,6 +4396,11 @@ class TextualSelect(SelectBase):
__visit_name__ = "textual_select"
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("column_args", InternalTraversal.dp_clauseelement_list),
+ ] + SupportsCloneAnnotations._traverse_internals
+
_is_textual = True
def __init__(self, text, columns, positional=False):
@@ -4534,18 +4446,6 @@ class TextualSelect(SelectBase):
c._make_proxy(fromclause) for c in self.column_args
)
- def _copy_internals(self, clone=_clone, **kw):
- self._reset_memoizations()
- self.element = clone(self.element, **kw)
-
- def get_children(self, **kw):
- return [self.element]
-
- def _cache_key(self, **kw):
- return (TextualSelect, self.element._cache_key(**kw)) + tuple(
- col._cache_key(**kw) for col in self.column_args
- )
-
def _scalar_type(self):
return self.column_args[0].type
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
new file mode 100644
index 000000000..c0782ce48
--- /dev/null
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -0,0 +1,768 @@
+from collections import deque
+from collections import namedtuple
+
+from . import operators
+from .visitors import ExtendedInternalTraversal
+from .visitors import InternalTraversal
+from .. import inspect
+from .. import util
+
+SKIP_TRAVERSE = util.symbol("skip_traverse")
+COMPARE_FAILED = False
+COMPARE_SUCCEEDED = True
+NO_CACHE = util.symbol("no_cache")
+
+
+def compare(obj1, obj2, **kw):
+ if kw.get("use_proxies", False):
+ strategy = ColIdentityComparatorStrategy()
+ else:
+ strategy = TraversalComparatorStrategy()
+
+ return strategy.compare(obj1, obj2, **kw)
+
+
+class HasCacheKey(object):
+ _cache_key_traversal = NO_CACHE
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ """return an optional cache key.
+
+ The cache key is a tuple which can contain any series of
+ objects that are hashable and also identifies
+ this object uniquely within the presence of a larger SQL expression
+ or statement, for the purposes of caching the resulting query.
+
+ The cache key should be based on the SQL compiled structure that would
+ ultimately be produced. That is, two structures that are composed in
+ exactly the same way should produce the same cache key; any difference
+ in the strucures that would affect the SQL string or the type handlers
+ should result in a different cache key.
+
+ If a structure cannot produce a useful cache key, it should raise
+ NotImplementedError, which will result in the entire structure
+ for which it's part of not being useful as a cache key.
+
+
+ """
+
+ if self in anon_map:
+ return (anon_map[self], self.__class__)
+
+ id_ = anon_map[self]
+
+ if self._cache_key_traversal is NO_CACHE:
+ anon_map[NO_CACHE] = True
+ return None
+
+ result = (id_, self.__class__)
+
+ for attrname, obj, meth in _cache_key_traversal.run_generated_dispatch(
+ self, self._cache_key_traversal, "_generated_cache_key_traversal"
+ ):
+ if obj is not None:
+ result += meth(attrname, obj, self, anon_map, bindparams)
+ return result
+
+ def _generate_cache_key(self):
+ """return a cache key.
+
+ The cache key is a tuple which can contain any series of
+ objects that are hashable and also identifies
+ this object uniquely within the presence of a larger SQL expression
+ or statement, for the purposes of caching the resulting query.
+
+ The cache key should be based on the SQL compiled structure that would
+ ultimately be produced. That is, two structures that are composed in
+ exactly the same way should produce the same cache key; any difference
+ in the strucures that would affect the SQL string or the type handlers
+ should result in a different cache key.
+
+ The cache key returned by this method is an instance of
+ :class:`.CacheKey`, which consists of a tuple representing the
+ cache key, as well as a list of :class:`.BindParameter` objects
+ which are extracted from the expression. While two expressions
+ that produce identical cache key tuples will themselves generate
+ identical SQL strings, the list of :class:`.BindParameter` objects
+ indicates the bound values which may have different values in
+ each one; these bound parameters must be consulted in order to
+ execute the statement with the correct parameters.
+
+ a :class:`.ClauseElement` structure that does not implement
+ a :meth:`._gen_cache_key` method and does not implement a
+ :attr:`.traverse_internals` attribute will not be cacheable; when
+ such an element is embedded into a larger structure, this method
+ will return None, indicating no cache key is available.
+
+ """
+ bindparams = []
+
+ _anon_map = anon_map()
+ key = self._gen_cache_key(_anon_map, bindparams)
+ if NO_CACHE in _anon_map:
+ return None
+ else:
+ return CacheKey(key, bindparams)
+
+
+class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
+ def __hash__(self):
+ return hash(self.key)
+
+ def __eq__(self, other):
+ return self.key == other.key
+
+
+def _clone(element, **kw):
+ return element._clone()
+
+
+class _CacheKey(ExtendedInternalTraversal):
+ def visit_has_cache_key(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj._gen_cache_key(anon_map, bindparams))
+
+ def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
+ return self.visit_has_cache_key(
+ attrname, inspect(obj), parent, anon_map, bindparams
+ )
+
+ def visit_clauseelement(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj._gen_cache_key(anon_map, bindparams))
+
+ def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
+ return (
+ attrname,
+ obj._gen_cache_key(anon_map, bindparams)
+ if isinstance(obj, HasCacheKey)
+ else obj,
+ )
+
+ def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams):
+ return (
+ attrname,
+ tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ if isinstance(elem, HasCacheKey)
+ else elem
+ for elem in obj
+ ),
+ )
+
+ def visit_has_cache_key_tuples(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ for elem in tup_elem
+ )
+ for tup_elem in obj
+ ),
+ )
+
+ def visit_has_cache_key_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
+ )
+
+ def visit_inspectable_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return self.visit_has_cache_key_list(
+ attrname, [inspect(o) for o in obj], parent, anon_map, bindparams
+ )
+
+ def visit_clauseelement_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
+ )
+
+ def visit_clauseelement_tuples(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return self.visit_has_cache_key_tuples(
+ attrname, obj, parent, anon_map, bindparams
+ )
+
+ def visit_anon_name(self, attrname, obj, parent, anon_map, bindparams):
+ from . import elements
+
+ name = obj
+ if isinstance(name, elements._anonymous_label):
+ name = name.apply_map(anon_map)
+
+ return (attrname, name)
+
+ def visit_fromclause_ordered_set(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
+ )
+
+ def visit_clauseelement_unordered_set(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ cache_keys = [
+ elem._gen_cache_key(anon_map, bindparams) for elem in obj
+ ]
+ return (
+ attrname,
+ tuple(
+ sorted(cache_keys)
+ ), # cache keys all start with (id_, class)
+ )
+
+ def visit_named_ddl_element(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (attrname, obj.name)
+
+ def visit_prefix_sequence(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (clause._gen_cache_key(anon_map, bindparams), strval)
+ for clause, strval in obj
+ ),
+ )
+
+ def visit_statement_hint_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (attrname, obj)
+
+ def visit_table_hint_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ clause._gen_cache_key(anon_map, bindparams),
+ dialect_name,
+ text,
+ )
+ for (clause, dialect_name), text in obj.items()
+ ),
+ )
+
+ def visit_type(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj._gen_cache_key)
+
+ def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, tuple((key, obj[key]) for key in sorted(obj)))
+
+ def visit_string_clauseelement_dict(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (key, obj[key]._gen_cache_key(anon_map, bindparams))
+ for key in sorted(obj)
+ ),
+ )
+
+ def visit_string_multi_dict(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ key,
+ value._gen_cache_key(anon_map, bindparams)
+ if isinstance(value, HasCacheKey)
+ else value,
+ )
+ for key, value in [(key, obj[key]) for key in sorted(obj)]
+ ),
+ )
+
+ def visit_string(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj)
+
+ def visit_boolean(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj)
+
+ def visit_operator(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj)
+
+ def visit_plain_obj(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj)
+
+ def visit_fromclause_canonical_column_collection(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(col._gen_cache_key(anon_map, bindparams) for col in obj),
+ )
+
+ def visit_annotations_state(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ key,
+ self.dispatch(sym)(
+ key, obj[key], obj, anon_map, bindparams
+ ),
+ )
+ for key, sym in parent._annotation_traversals
+ ),
+ )
+
+ def visit_unknown_structure(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ anon_map[NO_CACHE] = True
+ return ()
+
+
+_cache_key_traversal = _CacheKey()
+
+
+class _CopyInternals(InternalTraversal):
+ """Generate a _copy_internals internal traversal dispatch for classes
+ with a _traverse_internals collection."""
+
+ def visit_clauseelement(self, parent, element, clone=_clone, **kw):
+ return clone(element, **kw)
+
+ def visit_clauseelement_list(self, parent, element, clone=_clone, **kw):
+ return [clone(clause, **kw) for clause in element]
+
+ def visit_clauseelement_tuples(self, parent, element, clone=_clone, **kw):
+ return [
+ tuple(clone(tup_elem, **kw) for tup_elem in elem)
+ for elem in element
+ ]
+
+ def visit_string_clauseelement_dict(
+ self, parent, element, clone=_clone, **kw
+ ):
+ return dict(
+ (key, clone(value, **kw)) for key, value in element.items()
+ )
+
+
+_copy_internals = _CopyInternals()
+
+
+class _GetChildren(InternalTraversal):
+ """Generate a _children_traversal internal traversal dispatch for classes
+ with a _traverse_internals collection."""
+
+ def visit_has_cache_key(self, element, **kw):
+ return (element,)
+
+ def visit_clauseelement(self, element, **kw):
+ return (element,)
+
+ def visit_clauseelement_list(self, element, **kw):
+ return tuple(element)
+
+ def visit_clauseelement_tuples(self, element, **kw):
+ tup = ()
+ for elem in element:
+ tup += elem
+ return tup
+
+ def visit_fromclause_canonical_column_collection(self, element, **kw):
+ if kw.get("column_collections", False):
+ return tuple(element)
+ else:
+ return ()
+
+ def visit_string_clauseelement_dict(self, element, **kw):
+ return tuple(element.values())
+
+ def visit_fromclause_ordered_set(self, element, **kw):
+ return tuple(element)
+
+ def visit_clauseelement_unordered_set(self, element, **kw):
+ return tuple(element)
+
+
+_get_children = _GetChildren()
+
+
+@util.dependencies("sqlalchemy.sql.elements")
+def _resolve_name_for_compare(elements, element, name, anon_map, **kw):
+ if isinstance(name, elements._anonymous_label):
+ name = name.apply_map(anon_map)
+
+ return name
+
+
+class anon_map(dict):
+ """A map that creates new keys for missing key access.
+
+ Produces an incrementing sequence given a series of unique keys.
+
+ This is similar to the compiler prefix_anon_map class although simpler.
+
+ Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
+ is otherwise usually used for this type of operation.
+
+ """
+
+ def __init__(self):
+ self.index = 0
+
+ def __missing__(self, key):
+ self[key] = val = str(self.index)
+ self.index += 1
+ return val
+
+
+class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
+ __slots__ = "stack", "cache", "anon_map"
+
+ def __init__(self):
+ self.stack = deque()
+ self.cache = set()
+
+ def _memoized_attr_anon_map(self):
+ return (anon_map(), anon_map())
+
+ def compare(self, obj1, obj2, **kw):
+ stack = self.stack
+ cache = self.cache
+
+ compare_annotations = kw.get("compare_annotations", False)
+
+ stack.append((obj1, obj2))
+
+ while stack:
+ left, right = stack.popleft()
+
+ if left is right:
+ continue
+ elif left is None or right is None:
+ # we know they are different so no match
+ return False
+ elif (left, right) in cache:
+ continue
+ cache.add((left, right))
+
+ visit_name = left.__visit_name__
+ if visit_name != right.__visit_name__:
+ return False
+
+ meth = getattr(self, "compare_%s" % visit_name, None)
+
+ if meth:
+ attributes_compared = meth(left, right, **kw)
+ if attributes_compared is COMPARE_FAILED:
+ return False
+ elif attributes_compared is SKIP_TRAVERSE:
+ continue
+
+ # attributes_compared is returned as a list of attribute
+ # names that were "handled" by the comparison method above.
+ # remaining attribute names in the _traverse_internals
+ # will be compared.
+ else:
+ attributes_compared = ()
+
+ for (
+ (left_attrname, left_visit_sym),
+ (right_attrname, right_visit_sym),
+ ) in util.zip_longest(
+ left._traverse_internals,
+ right._traverse_internals,
+ fillvalue=(None, None),
+ ):
+ if (
+ left_attrname != right_attrname
+ or left_visit_sym is not right_visit_sym
+ ):
+ if not compare_annotations and (
+ (
+ left_visit_sym
+ is InternalTraversal.dp_annotations_state,
+ )
+ or (
+ right_visit_sym
+ is InternalTraversal.dp_annotations_state,
+ )
+ ):
+ continue
+
+ return False
+ elif left_attrname in attributes_compared:
+ continue
+
+ dispatch = self.dispatch(left_visit_sym)
+ left_child = getattr(left, left_attrname)
+ right_child = getattr(right, right_attrname)
+ if left_child is None:
+ if right_child is not None:
+ return False
+ else:
+ continue
+
+ comparison = dispatch(
+ left, left_child, right, right_child, **kw
+ )
+ if comparison is COMPARE_FAILED:
+ return False
+
+ return True
+
+ def compare_inner(self, obj1, obj2, **kw):
+ comparator = self.__class__()
+ return comparator.compare(obj1, obj2, **kw)
+
+ def visit_has_cache_key(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key(
+ self.anon_map[1], []
+ ):
+ return COMPARE_FAILED
+
+ def visit_clauseelement(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ self.stack.append((left, right))
+
+ def visit_fromclause_canonical_column_collection(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for lcol, rcol in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((lcol, rcol))
+
+ def visit_fromclause_derived_column_collection(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ pass
+
+ def visit_string_clauseelement_dict(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for lstr, rstr in util.zip_longest(
+ sorted(left), sorted(right), fillvalue=None
+ ):
+ if lstr != rstr:
+ return COMPARE_FAILED
+ self.stack.append((left[lstr], right[rstr]))
+
+ def visit_annotations_state(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ if not kw.get("compare_annotations", False):
+ return
+
+ for (lstr, lmeth), (rstr, rmeth) in util.zip_longest(
+ left_parent._annotation_traversals,
+ right_parent._annotation_traversals,
+ fillvalue=(None, None),
+ ):
+ if lstr != rstr or (lmeth is not rmeth):
+ return COMPARE_FAILED
+
+ dispatch = self.dispatch(lmeth)
+ left_child = left[lstr]
+ right_child = right[rstr]
+ if left_child is None:
+ if right_child is not None:
+ return False
+ else:
+ continue
+
+ comparison = dispatch(None, left_child, None, right_child, **kw)
+ if comparison is COMPARE_FAILED:
+ return comparison
+
+ def visit_clauseelement_tuples(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for ltup, rtup in util.zip_longest(left, right, fillvalue=None):
+ if ltup is None or rtup is None:
+ return COMPARE_FAILED
+
+ for l, r in util.zip_longest(ltup, rtup, fillvalue=None):
+ self.stack.append((l, r))
+
+ def visit_clauseelement_list(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((l, r))
+
+ def _compare_unordered_sequences(self, seq1, seq2, **kw):
+ if seq1 is None:
+ return seq2 is None
+
+ completed = set()
+ for clause in seq1:
+ for other_clause in set(seq2).difference(completed):
+ if self.compare_inner(clause, other_clause, **kw):
+ completed.add(other_clause)
+ break
+ return len(completed) == len(seq1) == len(seq2)
+
+ def visit_clauseelement_unordered_set(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ return self._compare_unordered_sequences(left, right, **kw)
+
+ def visit_fromclause_ordered_set(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((l, r))
+
+ def visit_string(self, left_parent, left, right_parent, right, **kw):
+ return left == right
+
+ def visit_anon_name(self, left_parent, left, right_parent, right, **kw):
+ return _resolve_name_for_compare(
+ left_parent, left, self.anon_map[0], **kw
+ ) == _resolve_name_for_compare(
+ right_parent, right, self.anon_map[1], **kw
+ )
+
+ def visit_boolean(self, left_parent, left, right_parent, right, **kw):
+ return left == right
+
+ def visit_operator(self, left_parent, left, right_parent, right, **kw):
+ return left is right
+
+ def visit_type(self, left_parent, left, right_parent, right, **kw):
+ return left._compare_type_affinity(right)
+
+ def visit_plain_dict(self, left_parent, left, right_parent, right, **kw):
+ return left == right
+
+ def visit_plain_obj(self, left_parent, left, right_parent, right, **kw):
+ return left == right
+
+ def visit_named_ddl_element(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ if left is None:
+ if right is not None:
+ return COMPARE_FAILED
+
+ return left.name == right.name
+
+ def visit_prefix_sequence(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for (l_clause, l_str), (r_clause, r_str) in util.zip_longest(
+ left, right, fillvalue=(None, None)
+ ):
+ if l_str != r_str:
+ return COMPARE_FAILED
+ else:
+ self.stack.append((l_clause, r_clause))
+
+ def visit_table_hint_list(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1]))
+ right_keys = sorted(
+ right, key=lambda elem: (elem[0].fullname, elem[1])
+ )
+ for (ltable, ldialect), (rtable, rdialect) in util.zip_longest(
+ left_keys, right_keys, fillvalue=(None, None)
+ ):
+ if ldialect != rdialect:
+ return COMPARE_FAILED
+ elif left[(ltable, ldialect)] != right[(rtable, rdialect)]:
+ return COMPARE_FAILED
+ else:
+ self.stack.append((ltable, rtable))
+
+ def visit_statement_hint_list(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
+ def visit_unknown_structure(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ raise NotImplementedError()
+
+ def compare_clauselist(self, left, right, **kw):
+ if left.operator is right.operator:
+ if operators.is_associative(left.operator):
+ if self._compare_unordered_sequences(
+ left.clauses, right.clauses, **kw
+ ):
+ return ["operator", "clauses"]
+ else:
+ return COMPARE_FAILED
+ else:
+ return ["operator"]
+ else:
+ return COMPARE_FAILED
+
+ def compare_binary(self, left, right, **kw):
+ if left.operator == right.operator:
+ if operators.is_commutative(left.operator):
+ if (
+ compare(left.left, right.left, **kw)
+ and compare(left.right, right.right, **kw)
+ ) or (
+ compare(left.left, right.right, **kw)
+ and compare(left.right, right.left, **kw)
+ ):
+ return ["operator", "negate", "left", "right"]
+ else:
+ return COMPARE_FAILED
+ else:
+ return ["operator", "negate"]
+ else:
+ return COMPARE_FAILED
+
+
+class ColIdentityComparatorStrategy(TraversalComparatorStrategy):
+ def compare_column_element(
+ self, left, right, use_proxies=True, equivalents=(), **kw
+ ):
+ """Compare ColumnElements using proxies and equivalent collections.
+
+ This is a comparison strategy specific to the ORM.
+ """
+
+ to_compare = (right,)
+ if equivalents and right in equivalents:
+ to_compare = equivalents[right].union(to_compare)
+
+ for oth in to_compare:
+ if use_proxies and left.shares_lineage(oth):
+ return SKIP_TRAVERSE
+ elif hash(left) == hash(right):
+ return SKIP_TRAVERSE
+ else:
+ return COMPARE_FAILED
+
+ def compare_column(self, left, right, **kw):
+ return self.compare_column_element(left, right, **kw)
+
+ def compare_label(self, left, right, **kw):
+ return self.compare_column_element(left, right, **kw)
+
+ def compare_table(self, left, right, **kw):
+ # tables compare on identity, since it's not really feasible to
+ # compare them column by column with the above rules
+ return SKIP_TRAVERSE if left is right else COMPARE_FAILED
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index 9c5f5dd47..d09bb28bb 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -12,8 +12,8 @@
from . import operators
from .base import SchemaEventTarget
-from .visitors import Visitable
-from .visitors import VisitableType
+from .visitors import Traversible
+from .visitors import TraversibleType
from .. import exc
from .. import util
@@ -28,7 +28,7 @@ INDEXABLE = None
_resolve_value_to_type = None
-class TypeEngine(Visitable):
+class TypeEngine(Traversible):
"""The ultimate base class for all SQL datatypes.
Common subclasses of :class:`.TypeEngine` include
@@ -535,8 +535,13 @@ class TypeEngine(Visitable):
return dialect.type_descriptor(self)
@util.memoized_property
- def _cache_key(self):
- return util.constructor_key(self, self.__class__)
+ def _gen_cache_key(self):
+ names = util.get_cls_kwargs(self.__class__)
+ return (self.__class__,) + tuple(
+ (k, self.__dict__[k])
+ for k in names
+ if k in self.__dict__ and not k.startswith("_")
+ )
def adapt(self, cls, **kw):
"""Produce an "adapted" form of this type, given an "impl" class
@@ -617,7 +622,7 @@ class TypeEngine(Visitable):
return util.generic_repr(self)
-class VisitableCheckKWArg(util.EnsureKWArgType, VisitableType):
+class VisitableCheckKWArg(util.EnsureKWArgType, TraversibleType):
pass
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index e109852a2..8539f4845 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -734,7 +734,7 @@ def criterion_as_pairs(
return pairs
-class ClauseAdapter(visitors.ReplacingCloningVisitor):
+class ClauseAdapter(visitors.ReplacingExternalTraversal):
"""Clones and modifies clauses based on column correspondence.
E.g.::
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 7b2ac285a..8c06eb8af 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -28,14 +28,10 @@ import operator
from .. import exc
from .. import util
-
+from ..util import langhelpers
+from ..util import symbol
__all__ = [
- "VisitableType",
- "Visitable",
- "ClauseVisitor",
- "CloningVisitor",
- "ReplacingCloningVisitor",
"iterate",
"iterate_depthfirst",
"traverse_using",
@@ -43,85 +39,382 @@ __all__ = [
"traverse_depthfirst",
"cloned_traverse",
"replacement_traverse",
+ "Traversible",
+ "TraversibleType",
+ "ExternalTraversal",
+ "InternalTraversal",
]
-class VisitableType(type):
- """Metaclass which assigns a ``_compiler_dispatch`` method to classes
- having a ``__visit_name__`` attribute.
+def _generate_compiler_dispatch(cls):
+ """Generate a _compiler_dispatch() external traversal on classes with a
+ __visit_name__ attribute.
+
+ """
+ visit_name = cls.__visit_name__
+
+ if isinstance(visit_name, util.compat.string_types):
+ # There is an optimization opportunity here because the
+ # the string name of the class's __visit_name__ is known at
+ # this early stage (import time) so it can be pre-constructed.
+ getter = operator.attrgetter("visit_%s" % visit_name)
+
+ def _compiler_dispatch(self, visitor, **kw):
+ try:
+ meth = getter(visitor)
+ except AttributeError:
+ raise exc.UnsupportedCompilationError(visitor, cls)
+ else:
+ return meth(self, **kw)
+
+ else:
+ # The optimization opportunity is lost for this case because the
+ # __visit_name__ is not yet a string. As a result, the visit
+ # string has to be recalculated with each compilation.
+ def _compiler_dispatch(self, visitor, **kw):
+ visit_attr = "visit_%s" % self.__visit_name__
+ try:
+ meth = getattr(visitor, visit_attr)
+ except AttributeError:
+ raise exc.UnsupportedCompilationError(visitor, cls)
+ else:
+ return meth(self, **kw)
+
+ _compiler_dispatch.__doc__ = """Look for an attribute named "visit_"
+ + self.__visit_name__ on the visitor, and call it with the same
+ kw params.
+ """
+ cls._compiler_dispatch = _compiler_dispatch
+
+
+class TraversibleType(type):
+ """Metaclass which assigns dispatch attributes to various kinds of
+ "visitable" classes.
- The ``_compiler_dispatch`` attribute becomes an instance method which
- looks approximately like the following::
+ Attributes include:
- def _compiler_dispatch (self, visitor, **kw):
- '''Look for an attribute named "visit_" + self.__visit_name__
- on the visitor, and call it with the same kw params.'''
- visit_attr = 'visit_%s' % self.__visit_name__
- return getattr(visitor, visit_attr)(self, **kw)
+ * The ``_compiler_dispatch`` method, corresponding to ``__visit_name__``.
+ This is called "external traversal" because the caller of each visit()
+ method is responsible for sub-traversing the inner elements of each
+ object. This is appropriate for string compilers and other traversals
+ that need to call upon the inner elements in a specific pattern.
- Classes having no ``__visit_name__`` attribute will remain unaffected.
+ * internal traversal collections ``_children_traversal``,
+ ``_cache_key_traversal``, ``_copy_internals_traversal``, generated from
+ an optional ``_traverse_internals`` collection of symbols which comes
+ from the :class:`.InternalTraversal` list of symbols. This is called
+ "internal traversal" MARKMARK
"""
def __init__(cls, clsname, bases, clsdict):
- if clsname != "Visitable" and hasattr(cls, "__visit_name__"):
- _generate_dispatch(cls)
+ if clsname != "Traversible":
+ if "__visit_name__" in clsdict:
+ _generate_compiler_dispatch(cls)
+
+ super(TraversibleType, cls).__init__(clsname, bases, clsdict)
- super(VisitableType, cls).__init__(clsname, bases, clsdict)
+class Traversible(util.with_metaclass(TraversibleType)):
+ """Base class for visitable objects, applies the
+ :class:`.visitors.TraversibleType` metaclass.
-def _generate_dispatch(cls):
- """Return an optimized visit dispatch function for the cls
- for use by the compiler.
"""
- if "__visit_name__" in cls.__dict__:
- visit_name = cls.__visit_name__
- if isinstance(visit_name, util.compat.string_types):
- # There is an optimization opportunity here because the
- # the string name of the class's __visit_name__ is known at
- # this early stage (import time) so it can be pre-constructed.
- getter = operator.attrgetter("visit_%s" % visit_name)
- def _compiler_dispatch(self, visitor, **kw):
- try:
- meth = getter(visitor)
- except AttributeError:
- raise exc.UnsupportedCompilationError(visitor, cls)
- else:
- return meth(self, **kw)
+class _InternalTraversalType(type):
+ def __init__(cls, clsname, bases, clsdict):
+ if cls.__name__ in ("InternalTraversal", "ExtendedInternalTraversal"):
+ lookup = {}
+ for key, sym in clsdict.items():
+ if key.startswith("dp_"):
+ visit_key = key.replace("dp_", "visit_")
+ sym_name = sym.name
+ assert sym_name not in lookup, sym_name
+ lookup[sym] = lookup[sym_name] = visit_key
+ if hasattr(cls, "_dispatch_lookup"):
+ lookup.update(cls._dispatch_lookup)
+ cls._dispatch_lookup = lookup
+
+ super(_InternalTraversalType, cls).__init__(clsname, bases, clsdict)
+
+
+def _generate_dispatcher(visitor, internal_dispatch, method_name):
+ names = []
+ for attrname, visit_sym in internal_dispatch:
+ meth = visitor.dispatch(visit_sym)
+ if meth:
+ visit_name = ExtendedInternalTraversal._dispatch_lookup[visit_sym]
+ names.append((attrname, visit_name))
+
+ code = (
+ (" return [\n")
+ + (
+ ", \n".join(
+ " (%r, self.%s, visitor.%s)"
+ % (attrname, attrname, visit_name)
+ for attrname, visit_name in names
+ )
+ )
+ + ("\n ]\n")
+ )
+ meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
+ # print(meth_text)
+ return langhelpers._exec_code_in_env(meth_text, {}, method_name)
- else:
- # The optimization opportunity is lost for this case because the
- # __visit_name__ is not yet a string. As a result, the visit
- # string has to be recalculated with each compilation.
- def _compiler_dispatch(self, visitor, **kw):
- visit_attr = "visit_%s" % self.__visit_name__
- try:
- meth = getattr(visitor, visit_attr)
- except AttributeError:
- raise exc.UnsupportedCompilationError(visitor, cls)
- else:
- return meth(self, **kw)
-
- _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + self.__visit_name__
- on the visitor, and call it with the same kw params.
- """
- cls._compiler_dispatch = _compiler_dispatch
-
-
-class Visitable(util.with_metaclass(VisitableType, object)):
- """Base class for visitable objects, applies the
- :class:`.visitors.VisitableType` metaclass.
- The :class:`.Visitable` class is essentially at the base of the
- :class:`.ClauseElement` hierarchy.
+class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
+ r"""Defines visitor symbols used for internal traversal.
+
+ The :class:`.InternalTraversal` class is used in two ways. One is that
+ it can serve as the superclass for an object that implements the
+ various visit methods of the class. The other is that the symbols
+ themselves of :class:`.InternalTraversal` are used within
+ the ``_traverse_internals`` collection. Such as, the :class:`.Case`
+ object defines ``_travserse_internals`` as ::
+
+ _traverse_internals = [
+ ("value", InternalTraversal.dp_clauseelement),
+ ("whens", InternalTraversal.dp_clauseelement_tuples),
+ ("else_", InternalTraversal.dp_clauseelement),
+ ]
+
+ Above, the :class:`.Case` class indicates its internal state as the
+ attribtues named ``value``, ``whens``, and ``else\_``. They each
+ link to an :class:`.InternalTraversal` method which indicates the type
+ of datastructure referred towards.
+
+ Using the ``_traverse_internals`` structure, objects of type
+ :class:`.InternalTraversible` will have the following methods automatically
+ implemented:
+
+ * :meth:`.Traversible.get_children`
+
+ * :meth:`.Traversible._copy_internals`
+
+ * :meth:`.Traversible._gen_cache_key`
+
+ Subclasses can also implement these methods directly, particularly for the
+ :meth:`.Traversible._copy_internals` method, when special steps
+ are needed.
+
+ .. versionadded:: 1.4
"""
+ def dispatch(self, visit_symbol):
+ """Given a method from :class:`.InternalTraversal`, return the
+ corresponding method on a subclass.
-class ClauseVisitor(object):
- """Base class for visitor objects which can traverse using
+ """
+ name = self._dispatch_lookup[visit_symbol]
+ return getattr(self, name, None)
+
+ def run_generated_dispatch(
+ self, target, internal_dispatch, generate_dispatcher_name
+ ):
+ try:
+ dispatcher = target.__class__.__dict__[generate_dispatcher_name]
+ except KeyError:
+ dispatcher = _generate_dispatcher(
+ self, internal_dispatch, generate_dispatcher_name
+ )
+ setattr(target.__class__, generate_dispatcher_name, dispatcher)
+ return dispatcher(target, self)
+
+ dp_has_cache_key = symbol("HC")
+ """Visit a :class:`.HasCacheKey` object."""
+
+ dp_clauseelement = symbol("CE")
+ """Visit a :class:`.ClauseElement` object."""
+
+ dp_fromclause_canonical_column_collection = symbol("FC")
+ """Visit a :class:`.FromClause` object in the context of the
+ ``columns`` attribute.
+
+ The column collection is "canonical", meaning it is the originally
+ defined location of the :class:`.ColumnClause` objects. Right now
+ this means that the object being visited is a :class:`.TableClause`
+ or :class:`.Table` object only.
+
+ """
+
+ dp_clauseelement_tuples = symbol("CT")
+ """Visit a list of tuples which contain :class:`.ClauseElement`
+ objects.
+
+ """
+
+ dp_clauseelement_list = symbol("CL")
+ """Visit a list of :class:`.ClauseElement` objects.
+
+ """
+
+ dp_clauseelement_unordered_set = symbol("CU")
+ """Visit an unordered set of :class:`.ClauseElement` objects. """
+
+ dp_fromclause_ordered_set = symbol("CO")
+ """Visit an ordered set of :class:`.FromClause` objects. """
+
+ dp_string = symbol("S")
+ """Visit a plain string value.
+
+ Examples include table and column names, bound parameter keys, special
+ keywords such as "UNION", "UNION ALL".
+
+ The string value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_anon_name = symbol("AN")
+ """Visit a potentially "anonymized" string value.
+
+ The string value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_boolean = symbol("B")
+ """Visit a boolean value.
+
+ The boolean value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_operator = symbol("O")
+ """Visit an operator.
+
+ The operator is a function from the :mod:`sqlalchemy.sql.operators`
+ module.
+
+ The operator value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_type = symbol("T")
+ """Visit a :class:`.TypeEngine` object
+
+ The type object is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_plain_dict = symbol("PD")
+ """Visit a dictionary with string keys.
+
+ The keys of the dictionary should be strings, the values should
+ be immutable and hashable. The dictionary is considered to be
+ significant for cache key generation.
+
+ """
+
+ dp_string_clauseelement_dict = symbol("CD")
+ """Visit a dictionary of string keys to :class:`.ClauseElement`
+ objects.
+
+ """
+
+ dp_string_multi_dict = symbol("MD")
+ """Visit a dictionary of string keys to values which may either be
+ plain immutable/hashable or :class:`.HasCacheKey` objects.
+
+ """
+
+ dp_plain_obj = symbol("PO")
+ """Visit a plain python object.
+
+ The value should be immutable and hashable, such as an integer.
+ The value is considered to be significant for cache key generation.
+
+ """
+
+ dp_annotations_state = symbol("A")
+ """Visit the state of the :class:`.Annotatated` version of an object.
+
+ """
+
+ dp_named_ddl_element = symbol("DD")
+ """Visit a simple named DDL element.
+
+ The current object used by this method is the :class:`.Sequence`.
+
+ The object is only considered to be important for cache key generation
+ as far as its name, but not any other aspects of it.
+
+ """
+
+ dp_prefix_sequence = symbol("PS")
+ """Visit the sequence represented by :class:`.HasPrefixes`
+ or :class:`.HasSuffixes`.
+
+ """
+
+ dp_table_hint_list = symbol("TH")
+ """Visit the ``_hints`` collection of a :class:`.Select` object.
+
+ """
+
+ dp_statement_hint_list = symbol("SH")
+ """Visit the ``_statement_hints`` collection of a :class:`.Select`
+ object.
+
+ """
+
+ dp_unknown_structure = symbol("UK")
+ """Visit an unknown structure.
+
+ """
+
+
+class ExtendedInternalTraversal(InternalTraversal):
+ """defines additional symbols that are useful in caching applications.
+
+ Traversals for :class:`.ClauseElement` objects only need to use
+ those symbols present in :class:`.InternalTraversal`. However, for
+ additional caching use cases within the ORM, symbols dealing with the
+ :class:`.HasCacheKey` class are added here.
+
+ """
+
+ dp_ignore = symbol("IG")
+ """Specify an object that should be ignored entirely.
+
+ This currently applies function call argument caching where some
+ arguments should not be considered to be part of a cache key.
+
+ """
+
+ dp_inspectable = symbol("IS")
+ """Visit an inspectable object where the return value is a HasCacheKey`
+ object."""
+
+ dp_multi = symbol("M")
+ """Visit an object that may be a :class:`.HasCacheKey` or may be a
+ plain hashable object."""
+
+ dp_multi_list = symbol("MT")
+ """Visit a tuple containing elements that may be :class:`.HasCacheKey` or
+ may be a plain hashable object."""
+
+ dp_has_cache_key_tuples = symbol("HT")
+ """Visit a list of tuples which contain :class:`.HasCacheKey`
+ objects.
+
+ """
+
+ dp_has_cache_key_list = symbol("HL")
+ """Visit a list of :class:`.HasCacheKey` objects."""
+
+ dp_inspectable_list = symbol("IL")
+ """Visit a list of inspectable objects which upon inspection are
+ HasCacheKey objects."""
+
+
+class ExternalTraversal(object):
+ """Base class for visitor objects which can traverse externally using
the :func:`.visitors.traverse` function.
Direct usage of the :func:`.visitors.traverse` function is usually
@@ -178,7 +471,7 @@ class ClauseVisitor(object):
return self
-class CloningVisitor(ClauseVisitor):
+class CloningExternalTraversal(ExternalTraversal):
"""Base class for visitor objects which can traverse using
the :func:`.visitors.cloned_traverse` function.
@@ -203,7 +496,7 @@ class CloningVisitor(ClauseVisitor):
)
-class ReplacingCloningVisitor(CloningVisitor):
+class ReplacingExternalTraversal(CloningExternalTraversal):
"""Base class for visitor objects which can traverse using
the :func:`.visitors.replacement_traverse` function.
@@ -233,6 +526,14 @@ class ReplacingCloningVisitor(CloningVisitor):
return replacement_traverse(obj, self.__traverse_options__, replace)
+# backwards compatibility
+Visitable = Traversible
+VisitableType = TraversibleType
+ClauseVisitor = ExternalTraversal
+CloningVisitor = CloningExternalTraversal
+ReplacingCloningVisitor = ReplacingExternalTraversal
+
+
def iterate(obj, opts):
r"""traverse the given expression structure, returning an iterator.
@@ -405,11 +706,18 @@ def cloned_traverse(obj, opts, visitors):
cloned = {}
stop_on = set(opts.get("stop_on", []))
- def clone(elem):
+ def clone(elem, **kw):
if elem in stop_on:
return elem
else:
if id(elem) not in cloned:
+
+ if "replace" in kw:
+ newelem = kw["replace"](elem)
+ if newelem is not None:
+ cloned[id(elem)] = newelem
+ return newelem
+
cloned[id(elem)] = newelem = elem._clone()
newelem._copy_internals(clone=clone)
meth = visitors.get(newelem.__visit_name__, None)
@@ -461,7 +769,14 @@ def replacement_traverse(obj, opts, replace):
stop_on.add(id(newelem))
return newelem
else:
+
if elem not in cloned:
+ if "replace" in kw:
+ newelem = kw["replace"](elem)
+ if newelem is not None:
+ cloned[elem] = newelem
+ return newelem
+
cloned[elem] = newelem = elem._clone()
newelem._copy_internals(clone=clone, **kw)
return cloned[elem]