summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-04-29 19:46:43 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2020-05-31 21:41:52 -0400
commit4ecd352a9fbb9dbac7b428fe0f098f665c1f0cb1 (patch)
tree323868c9f18fffdbfef6168622010c7d19367b12 /lib/sqlalchemy/sql
parentcbfa1363d7201848a56e7209146e81b9c51aa8af (diff)
downloadsqlalchemy-4ecd352a9fbb9dbac7b428fe0f098f665c1f0cb1.tar.gz
Improve rendering of core statements w/ ORM elements
This patch contains a variety of ORM and expression layer tweaks to support ORM constructs in select() statements, without the 1.3.x requiremnt in Query that a full _compile_context() + new select() is needed in order to get a working statement object. Includes such tweaks as the ability to implement aliased class of an aliased class, as we are looking to fully support ACs against subqueries, as well as the ability to access anonymously-labeled ColumnProperty expressions within subqueries by naming the ".key" of the label after the property key. Some tuning to query.join() as well as ORMJoin internals to allow things to work more smoothly. Change-Id: Id810f485c5f7ed971529489b84694e02a3356d6d
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/base.py32
-rw-r--r--lib/sqlalchemy/sql/elements.py23
-rw-r--r--lib/sqlalchemy/sql/selectable.py32
-rw-r--r--lib/sqlalchemy/sql/traversals.py18
-rw-r--r--lib/sqlalchemy/sql/util.py7
-rw-r--r--lib/sqlalchemy/sql/visitors.py11
6 files changed, 92 insertions, 31 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 6415d4b37..f14319089 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -522,7 +522,12 @@ class _MetaOptions(type):
def __init__(cls, classname, bases, dict_):
cls._cache_attrs = tuple(
- sorted(d for d in dict_ if not d.startswith("__"))
+ sorted(
+ d
+ for d in dict_
+ if not d.startswith("__")
+ and d not in ("_cache_key_traversal",)
+ )
)
type.__init__(cls, classname, bases, dict_)
@@ -561,6 +566,31 @@ class Options(util.with_metaclass(_MetaOptions)):
def _state_dict(cls):
return cls._state_dict_const
+ @classmethod
+ def safe_merge(cls, other):
+ d = other._state_dict()
+
+ # only support a merge with another object of our class
+ # and which does not have attrs that we dont. otherwise
+ # we risk having state that might not be part of our cache
+ # key strategy
+
+ if (
+ cls is not other.__class__
+ and other._cache_attrs
+ and set(other._cache_attrs).difference(cls._cache_attrs)
+ ):
+ raise TypeError(
+ "other element %r is not empty, is not of type %s, "
+ "and contains attributes not covered here %r"
+ % (
+ other,
+ cls,
+ set(other._cache_attrs).difference(cls._cache_attrs),
+ )
+ )
+ return cls + d
+
class CacheableOptions(Options, HasCacheKey):
@hybridmethod
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 287e53724..fa2888a23 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -878,6 +878,7 @@ class ColumnElement(
key = self._proxy_key
else:
key = name
+
co = ColumnClause(
coercions.expect(roles.TruncatedLabelRole, name)
if name_is_truncatable
@@ -885,6 +886,7 @@ class ColumnElement(
type_=getattr(self, "type", None),
_selectable=selectable,
)
+
co._propagate_attrs = selectable._propagate_attrs
co._proxies = [self]
if selectable._is_clone_of is not None:
@@ -1284,6 +1286,7 @@ class BindParameter(roles.InElementRole, ColumnElement):
"""
+
if required is NO_ARG:
required = value is NO_ARG and callable_ is None
if value is NO_ARG:
@@ -1302,6 +1305,7 @@ class BindParameter(roles.InElementRole, ColumnElement):
id(self),
re.sub(r"[%\(\) \$]+", "_", key).strip("_")
if key is not None
+ and not isinstance(key, _anonymous_label)
else "param",
)
)
@@ -4182,16 +4186,27 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
return self.element._from_objects
def _make_proxy(self, selectable, name=None, **kw):
+ name = self.name if not name else name
+
key, e = self.element._make_proxy(
selectable,
- name=name if name else self.name,
+ name=name,
disallow_is_literal=True,
+ name_is_truncatable=isinstance(name, _truncated_label),
)
+ # TODO: want to remove this assertion at some point. all
+ # _make_proxy() implementations will give us back the key that
+ # is our "name" in the first place. based on this we can
+ # safely return our "self.key" as the key here, to support a new
+ # case where the key and name are separate.
+ assert key == self.name
+
e._propagate_attrs = selectable._propagate_attrs
e._proxies.append(self)
if self._type is not None:
e.type = self._type
- return key, e
+
+ return self.key, e
class ColumnClause(
@@ -4240,7 +4255,7 @@ class ColumnClause(
__visit_name__ = "column"
_traverse_internals = [
- ("name", InternalTraversal.dp_string),
+ ("name", InternalTraversal.dp_anon_name),
("type", InternalTraversal.dp_type),
("table", InternalTraversal.dp_clauseelement),
("is_literal", InternalTraversal.dp_boolean),
@@ -4410,10 +4425,8 @@ class ColumnClause(
def _gen_label(self, name, dedupe_on_key=True):
t = self.table
-
if self.is_literal:
return None
-
elif t is not None and t.named_with_column:
if getattr(t, "schema", None):
label = t.schema.replace(".", "_") + "_" + t.name + "_" + name
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 170e016a5..d6845e05f 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -3451,8 +3451,8 @@ class SelectState(util.MemoizedSlots, CompileState):
self.columns_plus_names = statement._generate_columns_plus_names(True)
def _get_froms(self, statement):
- froms = []
seen = set()
+ froms = []
for item in itertools.chain(
itertools.chain.from_iterable(
@@ -3474,6 +3474,16 @@ class SelectState(util.MemoizedSlots, CompileState):
froms.append(item)
seen.update(item._cloned_set)
+ toremove = set(
+ itertools.chain.from_iterable(
+ [_expand_cloned(f._hide_froms) for f in froms]
+ )
+ )
+ if toremove:
+ # filter out to FROM clauses not in the list,
+ # using a list to maintain ordering
+ froms = [f for f in froms if f not in toremove]
+
return froms
def _get_display_froms(
@@ -3490,16 +3500,6 @@ class SelectState(util.MemoizedSlots, CompileState):
froms = self.froms
- toremove = set(
- itertools.chain.from_iterable(
- [_expand_cloned(f._hide_froms) for f in froms]
- )
- )
- if toremove:
- # filter out to FROM clauses not in the list,
- # using a list to maintain ordering
- froms = [f for f in froms if f not in toremove]
-
if self.statement._correlate:
to_correlate = self.statement._correlate
if to_correlate:
@@ -3557,7 +3557,7 @@ class SelectState(util.MemoizedSlots, CompileState):
def _memoized_attr__label_resolve_dict(self):
with_cols = dict(
(c._resolve_label or c._label or c.key, c)
- for c in _select_iterables(self.statement._raw_columns)
+ for c in self.statement._exported_columns_iterator()
if c._allow_label_resolve
)
only_froms = dict(
@@ -3578,6 +3578,10 @@ class SelectState(util.MemoizedSlots, CompileState):
else:
return None
+ @classmethod
+ def exported_columns_iterator(cls, statement):
+ return _select_iterables(statement._raw_columns)
+
def _setup_joins(self, args):
for (right, onclause, left, flags) in args:
isouter = flags["isouter"]
@@ -4599,7 +4603,7 @@ class Select(
pa = None
collection = []
- for c in _select_iterables(self._raw_columns):
+ for c in self._exported_columns_iterator():
# we use key_label since this name is intended for targeting
# within the ColumnCollection only, it's not related to SQL
# rendering which always uses column name for SQL label names
@@ -4630,7 +4634,7 @@ class Select(
return self
def _generate_columns_plus_names(self, anon_for_dupe_key):
- cols = _select_iterables(self._raw_columns)
+ cols = self._exported_columns_iterator()
# when use_labels is on:
# in all cases == if we see the same label name, use _label_anon_label
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index a38088a27..388097e45 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -18,6 +18,7 @@ NO_CACHE = util.symbol("no_cache")
CACHE_IN_PLACE = util.symbol("cache_in_place")
CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key")
STATIC_CACHE_KEY = util.symbol("static_cache_key")
+ANON_NAME = util.symbol("anon_name")
def compare(obj1, obj2, **kw):
@@ -33,6 +34,7 @@ class HasCacheKey(object):
_cache_key_traversal = NO_CACHE
__slots__ = ()
+ @util.preload_module("sqlalchemy.sql.elements")
def _gen_cache_key(self, anon_map, bindparams):
"""return an optional cache key.
@@ -54,6 +56,8 @@ class HasCacheKey(object):
"""
+ elements = util.preloaded.sql_elements
+
idself = id(self)
if anon_map is not None:
@@ -102,6 +106,10 @@ class HasCacheKey(object):
result += (attrname, obj)
elif meth is STATIC_CACHE_KEY:
result += (attrname, obj._static_cache_key)
+ elif meth is ANON_NAME:
+ if elements._anonymous_label in obj.__class__.__mro__:
+ obj = obj.apply_map(anon_map)
+ result += (attrname, obj)
elif meth is CALL_GEN_CACHE_KEY:
result += (
attrname,
@@ -321,6 +329,7 @@ class _CacheKey(ExtendedInternalTraversal):
) = visit_operator = visit_plain_obj = CACHE_IN_PLACE
visit_statement_hint_list = CACHE_IN_PLACE
visit_type = STATIC_CACHE_KEY
+ visit_anon_name = ANON_NAME
def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
@@ -387,15 +396,6 @@ class _CacheKey(ExtendedInternalTraversal):
attrname, obj, parent, anon_map, bindparams
)
- def visit_anon_name(self, attrname, obj, parent, anon_map, bindparams):
- from . import elements
-
- name = obj
- if isinstance(name, elements._anonymous_label):
- name = name.apply_map(anon_map)
-
- return (attrname, name)
-
def visit_fromclause_ordered_set(
self, attrname, obj, parent, anon_map, bindparams
):
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 377aa4fe0..e8726000b 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -822,9 +822,14 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
# is another join or selectable that contains a table which our
# selectable derives from, that we want to process
return None
+
elif not isinstance(col, ColumnElement):
return None
- elif self.include_fn and not self.include_fn(col):
+
+ if "adapt_column" in col._annotations:
+ col = col._annotations["adapt_column"]
+
+ if self.include_fn and not self.include_fn(col):
return None
elif self.exclude_fn and self.exclude_fn(col):
return None
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 683f545dd..5de68f504 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -50,6 +50,13 @@ def _generate_compiler_dispatch(cls):
"""
visit_name = cls.__visit_name__
+ if "_compiler_dispatch" in cls.__dict__:
+ # class has a fixed _compiler_dispatch() method.
+ # copy it to "original" so that we can get it back if
+ # sqlalchemy.ext.compiles overrides it.
+ cls._original_compiler_dispatch = cls._compiler_dispatch
+ return
+
if not isinstance(visit_name, util.compat.string_types):
raise exc.InvalidRequestError(
"__visit_name__ on class %s must be a string at the class level"
@@ -76,7 +83,9 @@ def _generate_compiler_dispatch(cls):
+ self.__visit_name__ on the visitor, and call it with the same
kw params.
"""
- cls._compiler_dispatch = _compiler_dispatch
+ cls._compiler_dispatch = (
+ cls._original_compiler_dispatch
+ ) = _compiler_dispatch
class TraversibleType(type):