summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2019-09-27 17:32:10 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2019-09-30 10:10:58 -0400
commit6ddb62a8ba66b19afd41b967911ce5982250856e (patch)
treeb98eba226618f2dff301366ac7e384f052feb69e
parent2c34d2503a17316cae3282192405b9b9d60df6fe (diff)
downloadsqlalchemy-6ddb62a8ba66b19afd41b967911ce5982250856e.tar.gz
Simplify _ColumnEntity, related
In the interests of making Query much more lightweight up front, rework the calculations done at the top when the entities are constructed to be much less inolved. Use the new coercion system for _ColumnEntity and stop accepting plain strings, this will need to emit a deprecation warning in 1.3.x. Use annotations and other techniques to reduce the decisionmaking and complexity of Query. For the use case of subquery(), .statement, etc. we would like to do minimal work in order to get the columns clause. Change-Id: I7e459bbd3bb10ec71235f75ef4f3b0a969bec590
-rw-r--r--lib/sqlalchemy/orm/attributes.py20
-rw-r--r--lib/sqlalchemy/orm/descriptor_props.py21
-rw-r--r--lib/sqlalchemy/orm/interfaces.py12
-rw-r--r--lib/sqlalchemy/orm/properties.py3
-rw-r--r--lib/sqlalchemy/orm/query.py317
-rw-r--r--lib/sqlalchemy/orm/relationships.py4
-rw-r--r--lib/sqlalchemy/orm/util.py9
-rw-r--r--lib/sqlalchemy/sql/coercions.py17
-rw-r--r--lib/sqlalchemy/sql/util.py16
-rw-r--r--test/aaa_profiling/test_orm.py4
-rw-r--r--test/orm/inheritance/test_single.py2
-rw-r--r--test/orm/test_composites.py16
-rw-r--r--test/orm/test_froms.py14
-rw-r--r--test/orm/test_query.py30
-rw-r--r--test/orm/test_subquery_relations.py3
-rw-r--r--test/orm/test_utils.py19
-rw-r--r--test/profiles.txt48
17 files changed, 294 insertions, 261 deletions
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 9d404e00d..117dd4cea 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -168,18 +168,18 @@ class QueryableAttribute(
"""
return inspection.inspect(self._parententity)
- @property
+ @util.memoized_property
def expression(self):
- return self.comparator.__clause_element__()
-
- def __clause_element__(self):
- return self.comparator.__clause_element__()
+ return self.comparator.__clause_element__()._annotate(
+ {"orm_key": self.key}
+ )
- def _query_clause_element(self):
- """like __clause_element__(), but called specifically
- by :class:`.Query` to allow special behavior."""
+ @property
+ def _annotations(self):
+ return self.__clause_element__()._annotations
- return self.comparator._query_clause_element()
+ def __clause_element__(self):
+ return self.expression
def _bulk_update_tuples(self, value):
"""Return setter tuples for a bulk UPDATE."""
@@ -207,7 +207,7 @@ class QueryableAttribute(
)
def label(self, name):
- return self._query_clause_element().label(name)
+ return self.__clause_element__().label(name)
def operate(self, op, *other, **kwargs):
return op(self.comparator, *other, **kwargs)
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index 28b3bc5db..075638fed 100644
--- a/lib/sqlalchemy/orm/descriptor_props.py
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -413,19 +413,26 @@ class CompositeProperty(DescriptorProperty):
__hash__ = None
- @property
+ @util.memoized_property
def clauses(self):
- return self.__clause_element__()
-
- def __clause_element__(self):
return expression.ClauseList(
group=False, *self._comparable_elements
)
- def _query_clause_element(self):
- return CompositeProperty.CompositeBundle(
- self.prop, self.__clause_element__()
+ def __clause_element__(self):
+ return self.expression
+
+ @util.memoized_property
+ def expression(self):
+ clauses = self.clauses._annotate(
+ {
+ "bundle": True,
+ "parententity": self._parententity,
+ "parentmapper": self._parententity,
+ "orm_key": self.prop.key,
+ }
)
+ return CompositeProperty.CompositeBundle(self.prop, clauses)
def _bulk_update_tuples(self, value):
if value is None:
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 5098a55ce..d6bdfb924 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -363,6 +363,7 @@ class PropComparator(operators.ColumnOperators):
__slots__ = "prop", "property", "_parententity", "_adapt_to_entity"
def __init__(self, prop, parentmapper, adapt_to_entity=None):
+ # type: (MapperProperty, Mapper, Optional(AliasedInsp))
self.prop = self.property = prop
self._parententity = adapt_to_entity or parentmapper
self._adapt_to_entity = adapt_to_entity
@@ -370,10 +371,15 @@ class PropComparator(operators.ColumnOperators):
def __clause_element__(self):
raise NotImplementedError("%r" % self)
- def _query_clause_element(self):
- return self.__clause_element__()
-
def _bulk_update_tuples(self, value):
+ # type: (ColumnOperators) -> List[tuple[ColumnOperators, Any]]
+ """Receive a SQL expression that represents a value in the SET
+ clause of an UPDATE statement.
+
+ Return a tuple that can be passed to a :class:`.Update` construct.
+
+ """
+
return [(self.__clause_element__(), value)]
def adapt_to_entity(self, adapt_to_entity):
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index e2c10e50a..f804d6eed 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -292,7 +292,7 @@ class ColumnProperty(StrategizedProperty):
def _memoized_method___clause_element__(self):
if self.adapter:
- return self.adapter(self.prop.columns[0])
+ return self.adapter(self.prop.columns[0], self.prop.key)
else:
# no adapter, so we aren't aliased
# assert self._parententity is self._parentmapper
@@ -300,6 +300,7 @@ class ColumnProperty(StrategizedProperty):
{
"parententity": self._parententity,
"parentmapper": self._parententity,
+ "orm_key": self.prop.key,
}
)
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 37bd77f63..3d08dce22 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -114,7 +114,6 @@ class Query(Generative):
_from_obj = ()
_join_entities = ()
_select_from_entity = None
- _mapper_adapter_map = {}
_filter_aliases = ()
_from_obj_alias = None
_joinpath = _joinpoint = util.immutabledict()
@@ -177,61 +176,23 @@ class Query(Generative):
self._primary_entity = None
self._has_mapper_entities = False
- # 1. don't run util.to_list() or _set_entity_selectables
- # if no entities were passed - major performance bottleneck
- # from lazy loader implementation when it seeks to use Query
- # class for an identity lookup, causes test_orm.py to fail
- # with thousands of extra function calls, see issue #4228
- # for why this use had to be added
- # 2. can't use classmethod on Query because session.query_cls
- # is an arbitrary callable in some user recipes, not
- # necessarily a class, so we don't have the class available.
- # see issue #4256
- # 3. can't do "if entities is not None" because we usually get here
- # from session.query() which takes in *entities.
- # 4. can't do "if entities" because users make use of undocumented
- # to_list() behavior here and they pass clause expressions that
- # can't be evaluated as boolean. See issue #4269.
- # 5. the empty tuple is a singleton in cPython, take advantage of this
- # so that we can skip for the empty "*entities" case without using
- # any Python overloadable operators.
- #
if entities is not ():
for ent in util.to_list(entities):
entity_wrapper(self, ent)
- self._set_entity_selectables(self._entities)
-
- def _set_entity_selectables(self, entities):
- self._mapper_adapter_map = d = self._mapper_adapter_map.copy()
-
- for ent in entities:
- for entity in ent.entities:
- if entity not in d:
- ext_info = inspect(entity)
- if (
- not ext_info.is_aliased_class
- and ext_info.mapper.with_polymorphic
- ):
- if (
- ext_info.mapper.persist_selectable
- not in self._polymorphic_adapters
- ):
- self._mapper_loads_polymorphically_with(
- ext_info.mapper,
- sql_util.ColumnAdapter(
- ext_info.selectable,
- ext_info.mapper._equivalent_columns,
- ),
- )
- aliased_adapter = None
- elif ext_info.is_aliased_class:
- aliased_adapter = ext_info._adapter
- else:
- aliased_adapter = None
-
- d[entity] = (ext_info, aliased_adapter)
- ent.setup_entity(*d[entity])
+ def _setup_query_adapters(self, entity, ext_info):
+ if not ext_info.is_aliased_class and ext_info.mapper.with_polymorphic:
+ if (
+ ext_info.mapper.persist_selectable
+ not in self._polymorphic_adapters
+ ):
+ self._mapper_loads_polymorphically_with(
+ ext_info.mapper,
+ sql_util.ColumnAdapter(
+ ext_info.selectable,
+ ext_info.mapper._equivalent_columns,
+ ),
+ )
def _mapper_loads_polymorphically_with(self, mapper, adapter):
for m2 in mapper._with_polymorphic_mappers or [mapper]:
@@ -1162,8 +1123,7 @@ class Query(Generative):
entity = aliased(entity, alias)
self._entities = list(self._entities)
- m = _MapperEntity(self, entity)
- self._set_entity_selectables([m])
+ _MapperEntity(self, entity)
@_generative
def with_session(self, session):
@@ -1455,12 +1415,9 @@ class Query(Generative):
of result columns to be returned."""
self._entities = list(self._entities)
- l = len(self._entities)
+
for c in column:
_ColumnEntity(self, c)
- # _ColumnEntity may add many entities if the
- # given arg is a FROM clause
- self._set_entity_selectables(self._entities[l:])
@util.pending_deprecation(
"0.7",
@@ -2464,9 +2421,13 @@ class Query(Generative):
)
else:
# add a new element to the self._from_obj list
-
if use_entity_index is not None:
- # why doesn't this work as .entity_zero_or_selectable?
+ # make use of _MapperEntity selectable, which is usually
+ # entity_zero.selectable, but if with_polymorphic() were used
+ # might be distinct
+ assert isinstance(
+ self._entities[use_entity_index], _MapperEntity
+ )
left_clause = self._entities[use_entity_index].selectable
else:
left_clause = left
@@ -3529,7 +3490,7 @@ class Query(Generative):
# we get just "SELECT 1" without any entities.
return sql.exists(
self.enable_eagerloads(False)
- .add_columns("1")
+ .add_columns(sql.literal_column("1"))
.with_labels()
.statement.with_only_columns([1])
)
@@ -4029,10 +3990,10 @@ class Query(Generative):
"""
- search = set(self._mapper_adapter_map.values())
+ search = set(context.single_inh_entities.values())
if (
self._select_from_entity
- and self._select_from_entity not in self._mapper_adapter_map
+ and self._select_from_entity not in context.single_inh_entities
):
insp = inspect(self._select_from_entity)
if insp.is_aliased_class:
@@ -4110,23 +4071,27 @@ class _MapperEntity(_QueryEntity):
self.entities = [entity]
self.expr = entity
- supports_single_entity = True
-
- use_id_for_hash = True
+ ext_info = self.entity_zero = inspect(entity)
- def setup_entity(self, ext_info, aliased_adapter):
self.mapper = ext_info.mapper
- self.aliased_adapter = aliased_adapter
+
+ if ext_info.is_aliased_class:
+ self._label_name = ext_info.name
+ else:
+ self._label_name = self.mapper.class_.__name__
+
self.selectable = ext_info.selectable
self.is_aliased_class = ext_info.is_aliased_class
self._with_polymorphic = ext_info.with_polymorphic_mappers
self._polymorphic_discriminator = ext_info.polymorphic_on
- self.entity_zero = ext_info
- if ext_info.is_aliased_class:
- self._label_name = self.entity_zero.name
- else:
- self._label_name = self.mapper.class_.__name__
- self.path = self.entity_zero._path_registry
+ self.path = ext_info._path_registry
+
+ if ext_info.mapper.with_polymorphic:
+ query._setup_query_adapters(entity, ext_info)
+
+ supports_single_entity = True
+
+ use_id_for_hash = True
def set_with_polymorphic(
self, query, cls_or_mappers, selectable, polymorphic_on
@@ -4185,7 +4150,7 @@ class _MapperEntity(_QueryEntity):
if query._polymorphic_adapters:
adapter = query._polymorphic_adapters.get(self.mapper, None)
else:
- adapter = self.aliased_adapter
+ adapter = self.entity_zero._adapter
if adapter:
if query._from_obj_alias:
@@ -4235,6 +4200,14 @@ class _MapperEntity(_QueryEntity):
def setup_context(self, query, context):
adapter = self._get_entity_clauses(query, context)
+ single_table_crit = self.mapper._single_table_criterion
+ if single_table_crit is not None:
+ ext_info = self.entity_zero
+ context.single_inh_entities[ext_info] = (
+ ext_info,
+ ext_info._adapter if ext_info.is_aliased_class else None,
+ )
+
# if self._adapted_selectable is None:
context.froms += (self.selectable,)
@@ -4352,7 +4325,9 @@ class Bundle(InspectionAttr):
return cloned
def __clause_element__(self):
- return expression.ClauseList(group=False, *self.exprs)
+ return expression.ClauseList(group=False, *self.exprs)._annotate(
+ {"bundle": True}
+ )
@property
def clauses(self):
@@ -4386,8 +4361,19 @@ class Bundle(InspectionAttr):
class _BundleEntity(_QueryEntity):
use_id_for_hash = False
- def __init__(self, query, bundle, setup_entities=True):
- query._entities.append(self)
+ def __init__(self, query, expr, setup_entities=True, parent_bundle=None):
+ if parent_bundle:
+ parent_bundle._entities.append(self)
+ else:
+ query._entities.append(self)
+
+ if isinstance(
+ expr, (attributes.QueryableAttribute, interfaces.PropComparator)
+ ):
+ bundle = expr.__clause_element__()
+ else:
+ bundle = expr
+
self.bundle = self.expr = bundle
self.type = type(bundle)
self._label_name = bundle.name
@@ -4396,9 +4382,9 @@ class _BundleEntity(_QueryEntity):
if setup_entities:
for expr in bundle.exprs:
if isinstance(expr, Bundle):
- _BundleEntity(self, expr)
+ _BundleEntity(query, expr, parent_bundle=self)
else:
- _ColumnEntity(self, expr)
+ _ColumnEntity(query, expr, parent_bundle=self)
self.supports_single_entity = self.bundle.single_entity
@@ -4448,18 +4434,19 @@ class _BundleEntity(_QueryEntity):
else:
return None
- def adapt_to_selectable(self, query, sel):
- c = _BundleEntity(query, self.bundle, setup_entities=False)
+ def adapt_to_selectable(self, query, sel, parent_bundle=None):
+ c = _BundleEntity(
+ query,
+ self.bundle,
+ setup_entities=False,
+ parent_bundle=parent_bundle,
+ )
# c._label_name = self._label_name
# c.entity_zero = self.entity_zero
# c.entities = self.entities
for ent in self._entities:
- ent.adapt_to_selectable(c, sel)
-
- def setup_entity(self, ext_info, aliased_adapter):
- for ent in self._entities:
- ent.setup_entity(ext_info, aliased_adapter)
+ ent.adapt_to_selectable(query, sel, parent_bundle=c)
def setup_context(self, query, context):
for ent in self._entities:
@@ -4481,76 +4468,52 @@ class _BundleEntity(_QueryEntity):
class _ColumnEntity(_QueryEntity):
"""Column/expression based entity."""
- def __init__(self, query, column, namespace=None):
- self.expr = column
+ froms = frozenset()
+
+ def __init__(self, query, column, namespace=None, parent_bundle=None):
+ self.expr = expr = column
self.namespace = namespace
- search_entities = True
- check_column = False
-
- if isinstance(column, util.string_types):
- column = sql.literal_column(column)
- self._label_name = column.name
- search_entities = False
- check_column = True
- _entity = None
- elif isinstance(
- column, (attributes.QueryableAttribute, interfaces.PropComparator)
- ):
- _entity = getattr(column, "_parententity", None)
- if _entity is not None:
- search_entities = False
- self._label_name = column.key
- column = column._query_clause_element()
- check_column = True
- if isinstance(column, Bundle):
- _BundleEntity(query, column)
- return
+ _label_name = None
- if not isinstance(column, sql.ColumnElement):
- if hasattr(column, "_select_iterable"):
- # break out an object like Table into
- # individual columns
- for c in column._select_iterable:
- if c is column:
- break
- _ColumnEntity(query, c, namespace=column)
- else:
- return
+ column = coercions.expect(roles.ColumnsClauseRole, column)
- raise sa_exc.InvalidRequestError(
- "SQL expression, column, or mapped entity "
- "expected - got '%r'" % (column,)
- )
- elif not check_column:
+ annotations = column._annotations
+
+ if annotations.get("bundle", False):
+ _BundleEntity(query, expr, parent_bundle=parent_bundle)
+ return
+
+ orm_expr = False
+
+ if "parententity" in annotations:
+ _entity = annotations["parententity"]
+ self._label_name = _label_name = annotations.get("orm_key", None)
+ orm_expr = True
+
+ if hasattr(column, "_select_iterable"):
+ # break out an object like Table into
+ # individual columns
+ for c in column._select_iterable:
+ if c is column:
+ break
+ _ColumnEntity(query, c, namespace=column)
+ else:
+ return
+
+ if _label_name is None:
self._label_name = getattr(column, "key", None)
- search_entities = True
self.type = type_ = column.type
self.use_id_for_hash = not type_.hashable
- # If the Column is unnamed, give it a
- # label() so that mutable column expressions
- # can be located in the result even
- # if the expression's identity has been changed
- # due to adaption.
-
- if not column._label and not getattr(column, "is_literal", False):
- column = column.label(self._label_name)
-
- query._entities.append(self)
+ if parent_bundle:
+ parent_bundle._entities.append(self)
+ else:
+ query._entities.append(self)
self.column = column
- self.froms = set()
-
- # look for ORM entities represented within the
- # given expression. Try to count only entities
- # for columns whose FROM object is in the actual list
- # of FROMs for the overall expression - this helps
- # subqueries which were built from ORM constructs from
- # leaking out their entities into the main select construct
- self.actual_froms = set(column._from_objects)
- if not search_entities:
+ if orm_expr:
self.entity_zero = _entity
if _entity:
self.entities = [_entity]
@@ -4559,21 +4522,20 @@ class _ColumnEntity(_QueryEntity):
self.entities = []
self.mapper = None
else:
- all_elements = [
- elem
- for elem in sql_util.surface_column_elements(
- column, include_scalar_selects=False
- )
- if "parententity" in elem._annotations
- ]
- self.entities = util.unique_list(
- [elem._annotations["parententity"] for elem in all_elements]
+ entity = sql_util.extract_first_column_annotation(
+ column, "parententity"
)
+ if entity:
+ self.entities = [entity]
+ else:
+ self.entities = []
+
if self.entities:
self.entity_zero = self.entities[0]
self.mapper = self.entity_zero.mapper
+
elif self.namespace is not None:
self.entity_zero = self.namespace
self.mapper = None
@@ -4581,6 +4543,9 @@ class _ColumnEntity(_QueryEntity):
self.entity_zero = None
self.mapper = None
+ if self.entities and self.entity_zero.mapper.with_polymorphic:
+ query._setup_query_adapters(self.entity_zero, self.entity_zero)
+
supports_single_entity = False
def _deep_entity_zero(self):
@@ -4603,24 +4568,21 @@ class _ColumnEntity(_QueryEntity):
def entity_zero_or_selectable(self):
if self.entity_zero is not None:
return self.entity_zero
- elif self.actual_froms:
- return list(self.actual_froms)[0]
+ elif self.column._from_objects:
+ return self.column._from_objects[0]
else:
return None
- def adapt_to_selectable(self, query, sel):
- c = _ColumnEntity(query, sel.corresponding_column(self.column))
+ def adapt_to_selectable(self, query, sel, parent_bundle=None):
+ c = _ColumnEntity(
+ query,
+ sel.corresponding_column(self.column),
+ parent_bundle=parent_bundle,
+ )
c._label_name = self._label_name
c.entity_zero = self.entity_zero
c.entities = self.entities
- def setup_entity(self, ext_info, aliased_adapter):
- if "selectable" not in self.__dict__:
- self.selectable = ext_info.selectable
-
- if self.actual_froms.intersection(ext_info.selectable._from_objects):
- self.froms.add(ext_info.selectable)
-
def corresponds_to(self, entity):
if self.entity_zero is None:
return False
@@ -4651,13 +4613,32 @@ class _ColumnEntity(_QueryEntity):
def setup_context(self, query, context):
column = query._adapt_clause(self.column, False, True)
+ ezero = self.entity_zero
+
+ if self.mapper:
+ single_table_crit = self.mapper._single_table_criterion
+ if single_table_crit is not None:
+ context.single_inh_entities[ezero] = (
+ ezero,
+ ezero._adapter if ezero.is_aliased_class else None,
+ )
if column._annotations:
# annotated columns perform more slowly in compiler and
# result due to the __eq__() method, so use deannotated
column = column._deannotate()
- context.froms += tuple(self.froms)
+ if ezero is not None:
+ # use entity_zero as the from if we have it. this is necessary
+ # for polymorpic scenarios where our FROM is based on ORM entity,
+ # not the FROM of the column. but also, don't use it if our column
+ # doesn't actually have any FROMs that line up, such as when its
+ # a scalar subquery.
+ if set(self.column._from_objects).intersection(
+ ezero.selectable._from_objects
+ ):
+ context.froms += (ezero.selectable,)
+
context.primary_columns.append(column)
context.attributes[("fetch_column", self)] = column
@@ -4697,6 +4678,7 @@ class QueryContext(object):
"partials",
"post_load_paths",
"identity_token",
+ "single_inh_entities",
)
def __init__(self, query):
@@ -4731,6 +4713,7 @@ class QueryContext(object):
self.secondary_columns = []
self.eager_order_by = []
self.eager_joins = {}
+ self.single_inh_entities = {}
self.create_eager_joins = []
self.propagate_options = set(
o for o in query._with_options if o.propagate_to_loaders
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 63ec21099..731947cba 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -2319,11 +2319,11 @@ class JoinCondition(object):
"""
self.primaryjoin = _deep_deannotate(
- self.primaryjoin, values=("parententity",)
+ self.primaryjoin, values=("parententity", "orm_key")
)
if self.secondaryjoin is not None:
self.secondaryjoin = _deep_deannotate(
- self.secondaryjoin, values=("parententity",)
+ self.secondaryjoin, values=("parententity", "orm_key")
)
def _determine_joins(self):
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 4b4fa4052..747ec7e65 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -668,10 +668,11 @@ class AliasedInsp(InspectionAttr):
state["represents_outer_join"],
)
- def _adapt_element(self, elem):
- return self._adapter.traverse(elem)._annotate(
- {"parententity": self, "parentmapper": self.mapper}
- )
+ def _adapt_element(self, elem, key=None):
+ d = {"parententity": self, "parentmapper": self.mapper}
+ if key:
+ d["orm_key"] = key
+ return self._adapter.traverse(elem)._annotate(d)
def _entity_for_mapper(self, mapper):
self_poly = self.with_polymorphic_mappers
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index a7a856bba..95aee0468 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -57,7 +57,7 @@ def expect(role, element, **kw):
else:
resolved = element
- if issubclass(resolved.__class__, impl._role_class):
+ if impl._role_class in resolved.__class__.__mro__:
if impl._post_coercion:
resolved = impl._post_coercion(resolved, **kw)
return resolved
@@ -102,13 +102,16 @@ class RoleImpl(object):
def _resolve_for_clause_element(self, element, argname=None, **kw):
original_element = element
- is_clause_element = False
+ is_clause_element = hasattr(element, "__clause_element__")
- while hasattr(element, "__clause_element__") and not isinstance(
- element, (elements.ClauseElement, schema.SchemaItem)
- ):
- element = element.__clause_element__()
- is_clause_element = True
+ if is_clause_element:
+ while not isinstance(
+ element, (elements.ClauseElement, schema.SchemaItem)
+ ):
+ try:
+ element = element.__clause_element__()
+ except AttributeError:
+ break
if not is_clause_element:
if self._use_inspection:
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index fe83b163c..3c7f904de 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -364,23 +364,19 @@ def surface_selectables_only(clause):
stack.append(elem.table)
-def surface_column_elements(clause, include_scalar_selects=True):
- """traverse and yield only outer-exposed column elements, such as would
- be addressable in the WHERE clause of a SELECT if this element were
- in the columns clause."""
+def extract_first_column_annotation(column, annotation_name):
+ filter_ = (FromGrouping, SelectBase)
- filter_ = (FromGrouping,)
- if not include_scalar_selects:
- filter_ += (SelectBase,)
-
- stack = deque([clause])
+ stack = deque([column])
while stack:
elem = stack.popleft()
- yield elem
+ if annotation_name in elem._annotations:
+ return elem._annotations[annotation_name]
for sub in elem.get_children():
if isinstance(sub, filter_):
continue
stack.append(sub)
+ return None
def selectables_overlap(left, right):
diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py
index 4e52a7778..632f55937 100644
--- a/test/aaa_profiling/test_orm.py
+++ b/test/aaa_profiling/test_orm.py
@@ -560,6 +560,10 @@ class QueryTest(fixtures.MappedTest):
self._fixture()
sess = Session()
+ # warm up cache
+ for attr in [Parent.data1, Parent.data2, Parent.data3, Parent.data4]:
+ attr.__clause_element__()
+
@profiling.function_call_count()
def go():
for i in range(10):
diff --git a/test/orm/inheritance/test_single.py b/test/orm/inheritance/test_single.py
index 7b8d413a4..d0db76b21 100644
--- a/test/orm/inheritance/test_single.py
+++ b/test/orm/inheritance/test_single.py
@@ -287,7 +287,7 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest):
self.assert_compile(
sess.query(literal("1")).select_from(a1),
- "SELECT :param_1 AS param_1 FROM employees AS employees_1 "
+ "SELECT :param_1 AS anon_1 FROM employees AS employees_1 "
"WHERE employees_1.type IN (:type_1, :type_2)",
)
diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py
index 7247c859a..0d679e6db 100644
--- a/test/orm/test_composites.py
+++ b/test/orm/test_composites.py
@@ -20,6 +20,8 @@ from sqlalchemy.testing.schema import Table
class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
+ __dialect__ = "default"
+
@classmethod
def define_tables(cls, metadata):
Table(
@@ -311,6 +313,20 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
[(Point(3, 4), Point(5, 6))],
)
+ def test_cols_as_core_clauseelement(self):
+ Edge = self.classes.Edge
+ Point = self.classes.Point
+
+ start, end = Edge.start, Edge.end
+
+ stmt = select([start, end]).where(start == Point(3, 4))
+ self.assert_compile(
+ stmt,
+ "SELECT edges.x1, edges.y1, edges.x2, edges.y2 "
+ "FROM edges WHERE edges.x1 = :x1_1 AND edges.y1 = :y1_1",
+ checkparams={"x1_1": 3, "y1_1": 4},
+ )
+
def test_query_cols_labeled(self):
Edge = self.classes.Edge
Point = self.classes.Point
diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py
index 498b68057..efa45affa 100644
--- a/test/orm/test_froms.py
+++ b/test/orm/test_froms.py
@@ -7,7 +7,6 @@ from sqlalchemy import exc as sa_exc
from sqlalchemy import exists
from sqlalchemy import ForeignKey
from sqlalchemy import func
-from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import literal_column
from sqlalchemy import select
@@ -2211,7 +2210,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
sess.expunge_all()
assert_raises(
- sa_exc.InvalidRequestError, sess.query(User).add_column, object()
+ sa_exc.ArgumentError, sess.query(User).add_column, object()
)
def test_add_multi_columns(self):
@@ -2270,7 +2269,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
.order_by(User.id)
)
q = sess.query(User)
- result = q.add_column("count").from_statement(s).all()
+ result = q.add_column(s.selected_columns.count).from_statement(s).all()
assert result == expected
def test_raw_columns(self):
@@ -2315,7 +2314,10 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
)
q = create_session().query(User)
result = (
- q.add_column("count").add_column("concat").from_statement(s).all()
+ q.add_column(s.selected_columns.count)
+ .add_column(s.selected_columns.concat)
+ .from_statement(s)
+ .all()
)
assert result == expected
@@ -2399,7 +2401,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
]:
q = s.query(crit)
mzero = q._entity_zero()
- is_(mzero.persist_selectable, q._query_entity_zero().selectable)
+ is_(mzero, q._query_entity_zero().entity_zero)
q = q.join(j)
self.assert_compile(q, exp)
@@ -2429,7 +2431,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
]:
q = s.query(crit)
mzero = q._entity_zero()
- is_(inspect(mzero).selectable, q._query_entity_zero().selectable)
+ is_(mzero, q._query_entity_zero().entity_zero)
q = q.join(j)
self.assert_compile(q, exp)
diff --git a/test/orm/test_query.py b/test/orm/test_query.py
index 4dff6fe56..bcd13e6e2 100644
--- a/test/orm/test_query.py
+++ b/test/orm/test_query.py
@@ -1005,14 +1005,14 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL):
s = create_session()
q = s.query(User)
- assert_raises(sa_exc.InvalidRequestError, q.add_column, object())
+ assert_raises(sa_exc.ArgumentError, q.add_column, object())
def test_invalid_column_tuple(self):
User = self.classes.User
s = create_session()
q = s.query(User)
- assert_raises(sa_exc.InvalidRequestError, q.add_column, (1, 1))
+ assert_raises(sa_exc.ArgumentError, q.add_column, (1, 1))
def test_distinct(self):
"""test that a distinct() call is not valid before 'clauseelement'
@@ -2449,6 +2449,9 @@ class ComparatorTest(QueryTest):
def __clause_element__(self):
return self.expr
+ # this use case isn't exactly needed in this form, however it tests
+ # that we resolve for multiple __clause_element__() calls as is needed
+ # by systems like composites
sess = Session()
eq_(
sess.query(Comparator(User.id))
@@ -3398,11 +3401,11 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL):
q3,
"SELECT anon_1.users_id AS anon_1_users_id, "
"anon_1.users_name AS anon_1_users_name, "
- "anon_1.param_1 AS anon_1_param_1 "
- "FROM (SELECT users.id AS users_id, users.name AS "
- "users_name, :param_1 AS param_1 "
- "FROM users UNION SELECT users.id AS users_id, "
- "users.name AS users_name, 'y' FROM users) AS anon_1",
+ "anon_1.anon_2 AS anon_1_anon_2 FROM "
+ "(SELECT users.id AS users_id, users.name AS users_name, "
+ ":param_1 AS anon_2 FROM users "
+ "UNION SELECT users.id AS users_id, users.name AS users_name, "
+ "'y' FROM users) AS anon_1",
)
def test_union_literal_expressions_results(self):
@@ -3410,7 +3413,8 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL):
s = Session()
- q1 = s.query(User, literal("x"))
+ x_literal = literal("x")
+ q1 = s.query(User, x_literal)
q2 = s.query(User, literal_column("'y'"))
q3 = q1.union(q2)
@@ -3421,7 +3425,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL):
eq_([x["name"] for x in q6.column_descriptions], ["User", "foo"])
for q in (
- q3.order_by(User.id, text("anon_1_param_1")),
+ q3.order_by(User.id, x_literal),
q6.order_by(User.id, "foo"),
):
eq_(
@@ -4231,12 +4235,14 @@ class TextTest(QueryTest, AssertsCompiledSQL):
User = self.classes.User
s = create_session()
- assert_raises(
- sa_exc.InvalidRequestError, s.query, User.id, text("users.name")
+
+ self.assert_compile(
+ s.query(User.id, text("users.name")),
+ "SELECT users.id AS users_id, users.name FROM users",
)
eq_(
- s.query(User.id, "name").order_by(User.id).all(),
+ s.query(User.id, literal_column("name")).order_by(User.id).all(),
[(7, "jack"), (8, "ed"), (9, "fred"), (10, "chuck")],
)
diff --git a/test/orm/test_subquery_relations.py b/test/orm/test_subquery_relations.py
index b32b6547f..03e17d291 100644
--- a/test/orm/test_subquery_relations.py
+++ b/test/orm/test_subquery_relations.py
@@ -3,6 +3,7 @@ from sqlalchemy import bindparam
from sqlalchemy import ForeignKey
from sqlalchemy import inspect
from sqlalchemy import Integer
+from sqlalchemy import literal_column
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import testing
@@ -792,7 +793,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
sess = create_session()
self.assert_compile(
- sess.query(User, "1"),
+ sess.query(User, literal_column("1")),
"SELECT users.id AS users_id, users.name AS users_name, "
"1 FROM users",
)
diff --git a/test/orm/test_utils.py b/test/orm/test_utils.py
index e47fc3f26..4bc2a5c88 100644
--- a/test/orm/test_utils.py
+++ b/test/orm/test_utils.py
@@ -210,7 +210,24 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL):
eq_(str(alias.x + 1), "point_1.x + :x_1")
eq_(str(alias.x_alone + 1), "point_1.x + :x_1")
- is_(Point.x_alone.__clause_element__(), Point.x.__clause_element__())
+ point_mapper = inspect(Point)
+
+ eq_(
+ Point.x_alone._annotations,
+ {
+ "parententity": point_mapper,
+ "parentmapper": point_mapper,
+ "orm_key": "x_alone",
+ },
+ )
+ eq_(
+ Point.x._annotations,
+ {
+ "parententity": point_mapper,
+ "parentmapper": point_mapper,
+ "orm_key": "x",
+ },
+ )
eq_(str(alias.x_alone == alias.x), "point_1.x = point_1.x")
diff --git a/test/profiles.txt b/test/profiles.txt
index 0750ab767..e4b99ba68 100644
--- a/test/profiles.txt
+++ b/test/profiles.txt
@@ -13,22 +13,22 @@
# TEST: test.aaa_profiling.test_compiler.CompileTest.test_insert
-test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mssql_pyodbc_dbapiunicode_cextensions 70,70
-test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mssql_pyodbc_dbapiunicode_nocextensions 70,70
-test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_mysqldb_dbapiunicode_cextensions 70,70,70,70
-test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_mysqldb_dbapiunicode_nocextensions 70,70,70,70
-test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_pymysql_dbapiunicode_cextensions 70,70
-test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_pymysql_dbapiunicode_nocextensions 70,70
-test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_oracle_cx_oracle_dbapiunicode_cextensions 70,70
-test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_oracle_cx_oracle_dbapiunicode_nocextensions 70,70
+test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mssql_pyodbc_dbapiunicode_cextensions 66
+test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mssql_pyodbc_dbapiunicode_nocextensions 66
+test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_mysqldb_dbapiunicode_cextensions 66
+test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_mysqldb_dbapiunicode_nocextensions 66
+test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_pymysql_dbapiunicode_cextensions 66
+test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_pymysql_dbapiunicode_nocextensions 66
+test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_oracle_cx_oracle_dbapiunicode_cextensions 66
+test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_oracle_cx_oracle_dbapiunicode_nocextensions 66
test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_postgresql_psycopg2_dbapiunicode_cextensions 67
test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 67
test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_sqlite_pysqlite_dbapiunicode_cextensions 67
test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 67
-test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_mysqldb_dbapiunicode_cextensions 73,73,73,73
-test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_mysqldb_dbapiunicode_nocextensions 73,73,73,73
-test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_pymysql_dbapiunicode_cextensions 73,73
-test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_pymysql_dbapiunicode_nocextensions 73,73
+test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_mysqldb_dbapiunicode_cextensions 73
+test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_mysqldb_dbapiunicode_nocextensions 73
+test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_pymysql_dbapiunicode_cextensions 73
+test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_pymysql_dbapiunicode_nocextensions 73
test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_oracle_cx_oracle_dbapiunicode_cextensions 73
test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_oracle_cx_oracle_dbapiunicode_nocextensions 73
test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_postgresql_psycopg2_dbapiunicode_cextensions 72
@@ -523,24 +523,14 @@ test.aaa_profiling.test_orm.MergeTest.test_merge_no_load 3.7_sqlite_pysqlite_dba
# TEST: test.aaa_profiling.test_orm.QueryTest.test_query_cols
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_mssql_pyodbc_dbapiunicode_cextensions 6230
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_mssql_pyodbc_dbapiunicode_nocextensions 6780
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_mysql_mysqldb_dbapiunicode_cextensions 6290
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_mysql_mysqldb_dbapiunicode_nocextensions 6840
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_oracle_cx_oracle_dbapiunicode_cextensions 6360
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_oracle_cx_oracle_dbapiunicode_nocextensions 8190
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_postgresql_psycopg2_dbapiunicode_cextensions 6100
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 6641
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_sqlite_pysqlite_dbapiunicode_cextensions 6035
+test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_postgresql_psycopg2_dbapiunicode_cextensions 5900
+test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 6441
+test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_sqlite_pysqlite_dbapiunicode_cextensions 5900
test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 6585
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_mysql_mysqldb_dbapiunicode_cextensions 6483
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_mysql_mysqldb_dbapiunicode_nocextensions 7143
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_oracle_cx_oracle_dbapiunicode_cextensions 6473
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_oracle_cx_oracle_dbapiunicode_nocextensions 7043
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_postgresql_psycopg2_dbapiunicode_cextensions 6464
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 7034
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_sqlite_pysqlite_dbapiunicode_cextensions 6326
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_sqlite_pysqlite_dbapiunicode_nocextensions 6806
+test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_postgresql_psycopg2_dbapiunicode_cextensions 6138
+test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 6800
+test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_sqlite_pysqlite_dbapiunicode_cextensions 6226
+test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_sqlite_pysqlite_dbapiunicode_nocextensions 6506
# TEST: test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results