summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2019-08-30 22:23:44 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2019-08-30 22:23:44 +0000
commitb83c41c44bad0b166ad9a2355d10641b0310e2fe (patch)
tree9870ea0f1195da751a2c08d33288f74cd3c663e8
parent520f8579d1785e6f906947ff103aaa8db8330621 (diff)
parentf6c9b20a04d183d86078252048563b14e27fb6d2 (diff)
downloadsqlalchemy-b83c41c44bad0b166ad9a2355d10641b0310e2fe.tar.gz
Merge "Annotate session-bind-lookup entity in Query-produced selectables"
-rw-r--r--doc/build/changelog/unreleased_14/4829.rst12
-rw-r--r--lib/sqlalchemy/orm/query.py77
-rw-r--r--lib/sqlalchemy/sql/annotation.py74
-rw-r--r--lib/sqlalchemy/sql/elements.py40
-rw-r--r--lib/sqlalchemy/sql/selectable.py2
-rw-r--r--test/orm/test_query.py110
-rw-r--r--test/sql/test_selectable.py66
7 files changed, 312 insertions, 69 deletions
diff --git a/doc/build/changelog/unreleased_14/4829.rst b/doc/build/changelog/unreleased_14/4829.rst
new file mode 100644
index 000000000..93c582fa2
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/4829.rst
@@ -0,0 +1,12 @@
+.. change::
+ :tags: bug, orm
+ :tickets: 4829
+
+ Added new entity-targeting capabilities to the :class:`.Query` object to
+ help with the case where the :class:`.Session` is using a bind dictionary
+ against mapped classes, rather than a single bind, and the :class:`.Query`
+ is against a Core statement that was ultimately generated from a method
+ such as :meth:`.Query.subquery`; a deep search is performed to locate
+ any ORM entity related to the query in order to locate a mapper if
+ one is not otherwise present.
+
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 936929703..d4ff35d2e 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -384,6 +384,25 @@ class Query(object):
else self._query_entity_zero().entity_zero
)
+ def _deep_entity_zero(self):
+ """Return a 'deep' entity; this is any entity we can find associated
+ with the first entity / column experssion. this is used only for
+ session.get_bind().
+
+ """
+
+ if (
+ self._select_from_entity is not None
+ and not self._select_from_entity.is_clause_element
+ ):
+ return self._select_from_entity.mapper
+ for ent in self._entities:
+ ezero = ent._deep_entity_zero()
+ if ezero is not None:
+ return ezero.mapper
+ else:
+ return None
+
@property
def _mapper_entities(self):
for ent in self._entities:
@@ -394,13 +413,7 @@ class Query(object):
return self._joinpoint.get("_joinpoint_entity", self._entity_zero())
def _bind_mapper(self):
- ezero = self._entity_zero()
- if ezero is not None:
- insp = inspect(ezero)
- if not insp.is_clause_element:
- return insp.mapper
-
- return None
+ return self._deep_entity_zero()
def _only_full_mapper_zero(self, methname):
if self._entities != [self._primary_entity]:
@@ -3900,6 +3913,12 @@ class Query(object):
else:
context.statement = self._simple_statement(context)
+ if for_statement:
+ ezero = self._mapper_zero()
+ if ezero is not None:
+ context.statement = context.statement._annotate(
+ {"deepentity": ezero}
+ )
return context
def _compound_eager_statement(self, context):
@@ -4161,6 +4180,9 @@ class _MapperEntity(_QueryEntity):
def entity_zero_or_selectable(self):
return self.entity_zero
+ def _deep_entity_zero(self):
+ return self.entity_zero
+
def corresponds_to(self, entity):
return _entity_corresponds_to(self.entity_zero, entity)
@@ -4430,6 +4452,14 @@ class _BundleEntity(_QueryEntity):
else:
return None
+ def _deep_entity_zero(self):
+ for ent in self._entities:
+ ezero = ent._deep_entity_zero()
+ if ezero is not None:
+ return ezero
+ else:
+ return None
+
def adapt_to_selectable(self, query, sel):
c = _BundleEntity(query, self.bundle, setup_entities=False)
# c._label_name = self._label_name
@@ -4530,7 +4560,7 @@ class _ColumnEntity(_QueryEntity):
# of FROMs for the overall expression - this helps
# subqueries which were built from ORM constructs from
# leaking out their entities into the main select construct
- self.actual_froms = actual_froms = set(column._from_objects)
+ self.actual_froms = set(column._from_objects)
if not search_entities:
self.entity_zero = _entity
@@ -4540,7 +4570,6 @@ class _ColumnEntity(_QueryEntity):
else:
self.entities = []
self.mapper = None
- self._from_entities = set(self.entities)
else:
all_elements = [
elem
@@ -4551,21 +4580,9 @@ class _ColumnEntity(_QueryEntity):
]
self.entities = util.unique_list(
- [
- elem._annotations["parententity"]
- for elem in all_elements
- if "parententity" in elem._annotations
- ]
+ [elem._annotations["parententity"] for elem in all_elements]
)
- self._from_entities = set(
- [
- elem._annotations["parententity"]
- for elem in all_elements
- if "parententity" in elem._annotations
- and actual_froms.intersection(elem._from_objects)
- ]
- )
if self.entities:
self.entity_zero = self.entities[0]
self.mapper = self.entity_zero.mapper
@@ -4578,6 +4595,22 @@ class _ColumnEntity(_QueryEntity):
supports_single_entity = False
+ def _deep_entity_zero(self):
+ if self.mapper is not None:
+ return self.mapper
+
+ else:
+ for obj in visitors.iterate(
+ self.column,
+ {"column_tables": True, "column_collections": False},
+ ):
+ if "parententity" in obj._annotations:
+ return obj._annotations["parententity"]
+ elif "deepentity" in obj._annotations:
+ return obj._annotations["deepentity"]
+ else:
+ return None
+
@property
def entity_zero_or_selectable(self):
if self.entity_zero is not None:
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index 7fc9245ab..a0264845e 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -15,8 +15,80 @@ from . import operators
from .. import util
+class SupportsCloneAnnotations(object):
+ _annotations = util.immutabledict()
+
+ def _annotate(self, values):
+ """return a copy of this ClauseElement with annotations
+ updated by the given dictionary.
+
+ """
+ new = self._clone()
+ new._annotations = new._annotations.union(values)
+ return new
+
+ def _with_annotations(self, values):
+ """return a copy of this ClauseElement with annotations
+ replaced by the given dictionary.
+
+ """
+ new = self._clone()
+ new._annotations = util.immutabledict(values)
+ return new
+
+ def _deannotate(self, values=None, clone=False):
+ """return a copy of this :class:`.ClauseElement` with annotations
+ removed.
+
+ :param values: optional tuple of individual values
+ to remove.
+
+ """
+ if clone or self._annotations:
+ # clone is used when we are also copying
+ # the expression for a deep deannotation
+ new = self._clone()
+ new._annotations = {}
+ return new
+ else:
+ return self
+
+
+class SupportsWrappingAnnotations(object):
+ def _annotate(self, values):
+ """return a copy of this ClauseElement with annotations
+ updated by the given dictionary.
+
+ """
+ return Annotated(self, values)
+
+ def _with_annotations(self, values):
+ """return a copy of this ClauseElement with annotations
+ replaced by the given dictionary.
+
+ """
+ return Annotated(self, values)
+
+ def _deannotate(self, values=None, clone=False):
+ """return a copy of this :class:`.ClauseElement` with annotations
+ removed.
+
+ :param values: optional tuple of individual values
+ to remove.
+
+ """
+ if clone:
+ # clone is used when we are also copying
+ # the expression for a deep deannotation
+ return self._clone()
+ else:
+ # if no clone, since we have no annotations we return
+ # self
+ return self
+
+
class Annotated(object):
- """clones a ClauseElement and applies an 'annotations' dictionary.
+ """clones a SupportsAnnotated and applies an 'annotations' dictionary.
Unlike regular clones, this clone also mimics __hash__() and
__cmp__() of the original element so that it takes its place
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 42e7522ae..bc6f51b8c 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -22,6 +22,7 @@ from . import operators
from . import roles
from . import type_api
from .annotation import Annotated
+from .annotation import SupportsWrappingAnnotations
from .base import _clone
from .base import _generative
from .base import Executable
@@ -161,7 +162,7 @@ def not_(clause):
@inspection._self_inspects
-class ClauseElement(roles.SQLRole, Visitable):
+class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
"""Base class for elements of a programmatically constructed SQL
expression.
@@ -276,37 +277,6 @@ class ClauseElement(roles.SQLRole, Visitable):
d.pop("_is_clone_of", None)
return d
- def _annotate(self, values):
- """return a copy of this ClauseElement with annotations
- updated by the given dictionary.
-
- """
- return Annotated(self, values)
-
- def _with_annotations(self, values):
- """return a copy of this ClauseElement with annotations
- replaced by the given dictionary.
-
- """
- return Annotated(self, values)
-
- def _deannotate(self, values=None, clone=False):
- """return a copy of this :class:`.ClauseElement` with annotations
- removed.
-
- :param values: optional tuple of individual values
- to remove.
-
- """
- if clone:
- # clone is used when we are also copying
- # the expression for a deep deannotation
- return self._clone()
- else:
- # if no clone, since we have no annotations we return
- # self
- return self
-
def _execute_on_connection(self, connection, multiparams, params):
if self.supports_execution:
return connection._execute_clauseelement(self, multiparams, params)
@@ -4230,6 +4200,12 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
self._memoized_property.expire_instance(self)
self.__dict__["table"] = table
+ def get_children(self, column_tables=False, **kw):
+ if column_tables and self.table is not None:
+ return [self.table]
+ else:
+ return []
+
table = property(_get_table, _set_table)
def _cache_key(self, **kw):
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 03dbcd449..97c49f8fc 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -19,6 +19,7 @@ from . import operators
from . import roles
from . import type_api
from .annotation import Annotated
+from .annotation import SupportsCloneAnnotations
from .base import _clone
from .base import _cloned_difference
from .base import _cloned_intersection
@@ -2068,6 +2069,7 @@ class SelectBase(
roles.InElementRole,
HasCTE,
Executable,
+ SupportsCloneAnnotations,
Selectable,
):
"""Base class for SELECT statements.
diff --git a/test/orm/test_query.py b/test/orm/test_query.py
index f5283ff44..4dff6fe56 100644
--- a/test/orm/test_query.py
+++ b/test/orm/test_query.py
@@ -139,6 +139,23 @@ class RowTupleTest(QueryTest):
assert row.id == 7
assert row.uname == "jack"
+ def test_deep_entity(self):
+ users, User = (self.tables.users, self.classes.User)
+
+ mapper(User, users)
+
+ sess = create_session()
+ bundle = Bundle("b1", User.id, User.name)
+ subq1 = sess.query(User.id).subquery()
+ subq2 = sess.query(bundle).subquery()
+ cte = sess.query(User.id).cte()
+ ex = sess.query(User).exists()
+
+ is_(sess.query(subq1)._deep_entity_zero(), inspect(User))
+ is_(sess.query(subq2)._deep_entity_zero(), inspect(User))
+ is_(sess.query(cte)._deep_entity_zero(), inspect(User))
+ is_(sess.query(ex)._deep_entity_zero(), inspect(User))
+
def test_column_metadata(self):
users, Address, addresses, User = (
self.tables.users,
@@ -156,6 +173,8 @@ class RowTupleTest(QueryTest):
fn = func.count(User.id)
name_label = User.name.label("uname")
bundle = Bundle("b1", User.id, User.name)
+ subq1 = sess.query(User.id).subquery()
+ subq2 = sess.query(bundle).subquery()
cte = sess.query(User.id).cte()
for q, asserted in [
(
@@ -276,6 +295,30 @@ class RowTupleTest(QueryTest):
],
),
(
+ sess.query(subq1.c.id),
+ [
+ {
+ "aliased": False,
+ "expr": subq1.c.id,
+ "type": subq1.c.id.type,
+ "name": "id",
+ "entity": None,
+ }
+ ],
+ ),
+ (
+ sess.query(subq2.c.id),
+ [
+ {
+ "aliased": False,
+ "expr": subq2.c.id,
+ "type": subq2.c.id.type,
+ "name": "id",
+ "entity": None,
+ }
+ ],
+ ),
+ (
sess.query(users),
[
{
@@ -5518,12 +5561,15 @@ class BooleanEvalTest(fixtures.TestBase, testing.AssertsCompiledSQL):
class SessionBindTest(QueryTest):
@contextlib.contextmanager
- def _assert_bind_args(self, session):
+ def _assert_bind_args(self, session, expect_mapped_bind=True):
get_bind = mock.Mock(side_effect=session.get_bind)
with mock.patch.object(session, "get_bind", get_bind):
yield
for call_ in get_bind.mock_calls:
- is_(call_[1][0], inspect(self.classes.User))
+ if expect_mapped_bind:
+ is_(call_[1][0], inspect(self.classes.User))
+ else:
+ is_(call_[1][0], None)
is_not_(call_[2]["clause"], None)
def test_single_entity_q(self):
@@ -5532,12 +5578,43 @@ class SessionBindTest(QueryTest):
with self._assert_bind_args(session):
session.query(User).all()
+ def test_aliased_entity_q(self):
+ User = self.classes.User
+ u = aliased(User)
+ session = Session()
+ with self._assert_bind_args(session):
+ session.query(u).all()
+
def test_sql_expr_entity_q(self):
User = self.classes.User
session = Session()
with self._assert_bind_args(session):
session.query(User.id).all()
+ def test_sql_expr_subquery_from_entity(self):
+ User = self.classes.User
+ session = Session()
+ with self._assert_bind_args(session):
+ subq = session.query(User.id).subquery()
+ session.query(subq).all()
+
+ def test_sql_expr_cte_from_entity(self):
+ User = self.classes.User
+ session = Session()
+ with self._assert_bind_args(session):
+ cte = session.query(User.id).cte()
+ subq = session.query(cte).subquery()
+ session.query(subq).all()
+
+ def test_sql_expr_bundle_cte_from_entity(self):
+ User = self.classes.User
+ session = Session()
+ with self._assert_bind_args(session):
+ cte = session.query(User.id, User.name).cte()
+ subq = session.query(cte).subquery()
+ bundle = Bundle(subq.c.id, subq.c.name)
+ session.query(bundle).all()
+
def test_count(self):
User = self.classes.User
session = Session()
@@ -5594,6 +5671,35 @@ class SessionBindTest(QueryTest):
with self._assert_bind_args(session):
session.query(func.max(User.score)).scalar()
+ def test_plain_table(self):
+ User = self.classes.User
+
+ session = Session()
+ with self._assert_bind_args(session, expect_mapped_bind=False):
+ session.query(inspect(User).local_table).all()
+
+ def test_plain_table_from_self(self):
+ User = self.classes.User
+
+ session = Session()
+ with self._assert_bind_args(session, expect_mapped_bind=False):
+ session.query(inspect(User).local_table).from_self().all()
+
+ def test_plain_table_count(self):
+ User = self.classes.User
+
+ session = Session()
+ with self._assert_bind_args(session, expect_mapped_bind=False):
+ session.query(inspect(User).local_table).count()
+
+ def test_plain_table_select_from(self):
+ User = self.classes.User
+
+ table = inspect(User).local_table
+ session = Session()
+ with self._assert_bind_args(session, expect_mapped_bind=False):
+ session.query(table).select_from(table).all()
+
@testing.requires.nested_aggregates
def test_column_property_select(self):
User = self.classes.User
diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py
index 189436192..c54f27c23 100644
--- a/test/sql/test_selectable.py
+++ b/test/sql/test_selectable.py
@@ -41,7 +41,10 @@ from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import AssertsExecutionResults
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import in_
from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_not_
+from sqlalchemy.testing import ne_
metadata = MetaData()
@@ -2196,12 +2199,21 @@ class AnnotationsTest(fixtures.TestBase):
t = table("t", column("x"))
a = t.alias()
+
+ for obj in [t, t.c.x, a, t.c.x > 1, (t.c.x > 1).label(None)]:
+ annot = obj._annotate({})
+ eq_(set([obj]), set([annot]))
+
+ def test_clone_annotations_dont_hash(self):
+ t = table("t", column("x"))
+
s = t.select()
+ a = t.alias()
s2 = a.select()
- for obj in [t, t.c.x, a, s, s2, t.c.x > 1, (t.c.x > 1).label(None)]:
+ for obj in [s, s2]:
annot = obj._annotate({})
- eq_(set([obj]), set([annot]))
+ ne_(set([obj]), set([annot]))
def test_compare(self):
t = table("t", column("x"), column("y"))
@@ -2423,7 +2435,7 @@ class AnnotationsTest(fixtures.TestBase):
expected,
)
- def test_deannotate(self):
+ def test_deannotate_wrapping(self):
table1 = table("table1", column("col1"), column("col2"))
bin_ = table1.c.col1 == bindparam("foo", value=None)
@@ -2433,7 +2445,7 @@ class AnnotationsTest(fixtures.TestBase):
b4 = sql_util._deep_deannotate(bin_)
for elem in (b2._annotations, b2.left._annotations):
- assert "_orm_adapt" in elem
+ in_("_orm_adapt", elem)
for elem in (
b3._annotations,
@@ -2441,17 +2453,47 @@ class AnnotationsTest(fixtures.TestBase):
b4._annotations,
b4.left._annotations,
):
- assert elem == {}
+ eq_(elem, {})
- assert b2.left is not bin_.left
- assert b3.left is not b2.left and b2.left is not bin_.left
- assert b4.left is bin_.left # since column is immutable
+ is_not_(b2.left, bin_.left)
+ is_not_(b3.left, b2.left)
+ is_not_(b2.left, bin_.left)
+ is_(b4.left, bin_.left) # since column is immutable
# deannotate copies the element
- assert (
- bin_.right is not b2.right
- and b2.right is not b3.right
- and b3.right is not b4.right
+ is_not_(bin_.right, b2.right)
+ is_not_(b2.right, b3.right)
+ is_not_(b3.right, b4.right)
+
+ def test_deannotate_clone(self):
+ table1 = table("table1", column("col1"), column("col2"))
+
+ subq = (
+ select([table1])
+ .where(table1.c.col1 == bindparam("foo"))
+ .subquery()
)
+ stmt = select([subq])
+
+ s2 = sql_util._deep_annotate(stmt, {"_orm_adapt": True})
+ s3 = sql_util._deep_deannotate(s2)
+ s4 = sql_util._deep_deannotate(s3)
+
+ eq_(stmt._annotations, {})
+ eq_(subq._annotations, {})
+
+ eq_(s2._annotations, {"_orm_adapt": True})
+ eq_(s3._annotations, {})
+ eq_(s4._annotations, {})
+
+ # select._raw_columns[0] is the subq object
+ eq_(s2._raw_columns[0]._annotations, {"_orm_adapt": True})
+ eq_(s3._raw_columns[0]._annotations, {})
+ eq_(s4._raw_columns[0]._annotations, {})
+
+ is_not_(s3, s2)
+ is_not_(s4, s3) # deep deannotate makes a clone unconditionally
+
+ is_(s3._deannotate(), s3) # regular deannotate returns same object
def test_annotate_unique_traversal(self):
"""test that items are copied only once during