summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2020-12-12 04:04:53 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2020-12-12 04:04:53 +0000
commitd55971119aa48590416193b8b0e0f54aa0e97c82 (patch)
tree0ee97b9ae789cc5e5a4c327e1b9189ad2fbb1974
parenta66ef01e052d8f64b4b9bf90745a8ce84ff86109 (diff)
parented20e2f95f52a072d0c6b09af095b4cda0436d38 (diff)
downloadsqlalchemy-d55971119aa48590416193b8b0e0f54aa0e97c82.tar.gz
Merge "Fixes for lambda expressions and relationship loaders"
-rw-r--r--doc/build/changelog/unreleased_14/5763.rst9
-rw-r--r--doc/build/changelog/unreleased_14/5764.rst9
-rw-r--r--lib/sqlalchemy/orm/query.py2
-rw-r--r--lib/sqlalchemy/orm/session.py2
-rw-r--r--lib/sqlalchemy/orm/strategies.py13
-rw-r--r--lib/sqlalchemy/sql/base.py20
-rw-r--r--lib/sqlalchemy/sql/lambdas.py7
-rw-r--r--test/orm/test_events.py142
-rw-r--r--test/sql/test_lambdas.py45
-rw-r--r--test/sql/test_utils.py29
10 files changed, 258 insertions, 20 deletions
diff --git a/doc/build/changelog/unreleased_14/5763.rst b/doc/build/changelog/unreleased_14/5763.rst
new file mode 100644
index 000000000..e395b6fcf
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/5763.rst
@@ -0,0 +1,9 @@
+.. change::
+ :tags: bug, orm
+ :tickets: 5763
+
+ Fixed bug in lambda SQL feature, used by ORM
+ :meth:`_orm.with_loader_criteria` as well as available generally in the SQL
+ expression language, where assigning a boolean value True/False to a
+ variable would cause the query-time expression calculation to fail, as it
+ would produce a SQL expression not compatible with a bound value. \ No newline at end of file
diff --git a/doc/build/changelog/unreleased_14/5764.rst b/doc/build/changelog/unreleased_14/5764.rst
new file mode 100644
index 000000000..29753fafe
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/5764.rst
@@ -0,0 +1,9 @@
+.. change::
+ :tags: orm, bug
+ :tickets: 5764
+
+ Fixed issue where the :attr:`_orm.ORMExecuteState.is_relationship_load`
+ attribute would not be set correctly for many lazy loads, all
+ selectinloads, etc. The flag is essential in order to test if options
+ should be added to statements or if they would already have been propagated
+ via relationship loads. \ No newline at end of file
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index d7a2cb409..334283bb9 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -1296,7 +1296,6 @@ class Query(
self._set_select_from([fromclause], set_entity_from)
self._compile_options += {
"_enable_single_crit": False,
- "_statement": None,
}
# this enables clause adaptation for non-ORM
@@ -2620,7 +2619,6 @@ class Query(
roles.SelectStatementRole, statement, apply_propagate_attrs=self
)
self._statement = statement
- self._compile_options += {"_statement": statement}
def first(self):
"""Return the first result of this ``Query`` or
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index f6943cc5f..7b5fa2c73 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -275,7 +275,7 @@ class ORMExecuteState(util.MemoizedSlots):
if not self.is_select:
return None
opts = self.statement._compile_options
- if isinstance(opts, context.ORMCompileState.default_compile_options):
+ if opts.isinstance(context.ORMCompileState.default_compile_options):
return opts
else:
return None
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 7f7bab682..98c57149d 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -939,9 +939,14 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
)
stmt += lambda stmt: stmt.options(*opts)
- stmt += lambda stmt: stmt._update_compile_options(
- {"_current_path": effective_path}
- )
+ else:
+ # this path is used if there are not already any options
+ # in the query, but an event may want to add them
+ effective_path = state.mapper._path_registry[self.parent_property]
+
+ stmt += lambda stmt: stmt._update_compile_options(
+ {"_current_path": effective_path}
+ )
if use_get:
if self._raise_on_sql:
@@ -2732,6 +2737,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
orm_util.Bundle("pk", *pk_cols), effective_entity
)
.apply_labels()
+ ._set_compile_options(ORMCompileState.default_compile_options)
._set_propagate_attrs(
{
"compile_state_plugin": "orm",
@@ -2769,7 +2775,6 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
q = q.add_criteria(
lambda q: q.filter(in_expr.in_(sql.bindparam("primary_keys")))
)
-
# a test which exercises what these comments talk about is
# test_selectin_relations.py -> test_twolevel_selectin_w_polymorphic
#
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index ff44ab27c..5178a7ab1 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -553,6 +553,14 @@ class _MetaOptions(type):
def __add__(self, other):
o1 = self()
+
+ if set(other).difference(self._cache_attrs):
+ raise TypeError(
+ "dictionary contains attributes not covered by "
+ "Options class %s: %r"
+ % (self, set(other).difference(self._cache_attrs))
+ )
+
o1.__dict__.update(other)
return o1
@@ -566,6 +574,14 @@ class Options(util.with_metaclass(_MetaOptions)):
def __add__(self, other):
o1 = self.__class__.__new__(self.__class__)
o1.__dict__.update(self.__dict__)
+
+ if set(other).difference(self._cache_attrs):
+ raise TypeError(
+ "dictionary contains attributes not covered by "
+ "Options class %s: %r"
+ % (self, set(other).difference(self._cache_attrs))
+ )
+
o1.__dict__.update(other)
return o1
@@ -589,6 +605,10 @@ class Options(util.with_metaclass(_MetaOptions)):
),
)
+ @classmethod
+ def isinstance(cls, klass):
+ return issubclass(cls, klass)
+
@hybridmethod
def add_to_element(self, name, value):
return self + {name: getattr(self, name) + value}
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py
index 676152781..aafdda4ce 100644
--- a/lib/sqlalchemy/sql/lambdas.py
+++ b/lib/sqlalchemy/sql/lambdas.py
@@ -1021,7 +1021,12 @@ class PyWrapper(ColumnOperators):
def __getattribute__(self, key):
if key.startswith("_sa_"):
return object.__getattribute__(self, key[4:])
- elif key in ("__clause_element__", "operate", "reverse_operate"):
+ elif key in (
+ "__clause_element__",
+ "operate",
+ "reverse_operate",
+ "__class__",
+ ):
return object.__getattribute__(self, key)
if key.startswith("__"):
diff --git a/test/orm/test_events.py b/test/orm/test_events.py
index bc72d2f21..a046ba34c 100644
--- a/test/orm/test_events.py
+++ b/test/orm/test_events.py
@@ -21,8 +21,10 @@ from sqlalchemy.orm import Mapper
from sqlalchemy.orm import mapper
from sqlalchemy.orm import query
from sqlalchemy.orm import relationship
+from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import subqueryload
from sqlalchemy.orm.mapper import _mapper_registry
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
@@ -168,14 +170,10 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
},
)
- def test_flags(self):
- User, Address = self.classes("User", "Address")
-
- sess = Session(testing.db, future=True)
-
+ def _flag_fixture(self, session):
canary = Mock()
- @event.listens_for(sess, "do_orm_execute")
+ @event.listens_for(session, "do_orm_execute")
def do_orm_execute(ctx):
if not ctx.is_select:
@@ -197,17 +195,21 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
else None,
)
- u1 = sess.execute(select(User).filter_by(id=7)).scalar_one()
+ return canary
- u1.addresses
+ def test_select_flags(self):
+ User, Address = self.classes("User", "Address")
+
+ sess = Session(testing.db, future=True)
+
+ canary = self._flag_fixture(sess)
+
+ u1 = sess.execute(select(User).filter_by(id=7)).scalar_one()
sess.expire(u1)
eq_(u1.name, "jack")
- sess.execute(delete(User).filter_by(id=18))
- sess.execute(update(User).filter_by(id=18).values(name="eighteen"))
-
eq_(
canary.mock_calls,
[
@@ -226,18 +228,134 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
is_delete=False,
is_orm_statement=True,
is_relationship_load=False,
+ is_column_load=True,
+ lazy_loaded_from=None,
+ ),
+ ],
+ )
+
+ def test_lazyload_flags(self):
+ User, Address = self.classes("User", "Address")
+
+ sess = Session(testing.db, future=True)
+
+ canary = self._flag_fixture(sess)
+
+ u1 = sess.execute(select(User).filter_by(id=7)).scalar_one()
+
+ u1.addresses
+
+ eq_(
+ canary.mock_calls,
+ [
+ call.options(
+ is_select=True,
+ is_update=False,
+ is_delete=False,
+ is_orm_statement=True,
+ is_relationship_load=False,
+ is_column_load=False,
+ lazy_loaded_from=None,
+ ),
+ call.options(
+ is_select=True,
+ is_update=False,
+ is_delete=False,
+ is_orm_statement=True,
+ is_relationship_load=True,
is_column_load=False,
lazy_loaded_from=u1._sa_instance_state,
),
+ ],
+ )
+
+ def test_selectinload_flags(self):
+ User, Address = self.classes("User", "Address")
+
+ sess = Session(testing.db, future=True)
+
+ canary = self._flag_fixture(sess)
+
+ u1 = sess.execute(
+ select(User).filter_by(id=7).options(selectinload(User.addresses))
+ ).scalar_one()
+
+ assert "addresses" in u1.__dict__
+
+ eq_(
+ canary.mock_calls,
+ [
+ call.options(
+ is_select=True,
+ is_update=False,
+ is_delete=False,
+ is_orm_statement=True,
+ is_relationship_load=False,
+ is_column_load=False,
+ lazy_loaded_from=None,
+ ),
+ call.options(
+ is_select=True,
+ is_update=False,
+ is_delete=False,
+ is_orm_statement=True,
+ is_relationship_load=True,
+ is_column_load=False,
+ lazy_loaded_from=None,
+ ),
+ ],
+ )
+
+ def test_subqueryload_flags(self):
+ User, Address = self.classes("User", "Address")
+
+ sess = Session(testing.db, future=True)
+
+ canary = self._flag_fixture(sess)
+
+ u1 = sess.execute(
+ select(User).filter_by(id=7).options(subqueryload(User.addresses))
+ ).scalar_one()
+
+ assert "addresses" in u1.__dict__
+
+ eq_(
+ canary.mock_calls,
+ [
call.options(
is_select=True,
is_update=False,
is_delete=False,
is_orm_statement=True,
is_relationship_load=False,
- is_column_load=True,
+ is_column_load=False,
+ lazy_loaded_from=None,
+ ),
+ call.options(
+ is_select=True,
+ is_update=False,
+ is_delete=False,
+ is_orm_statement=True,
+ is_relationship_load=True,
+ is_column_load=False,
lazy_loaded_from=None,
),
+ ],
+ )
+
+ def test_update_delete_flags(self):
+ User, Address = self.classes("User", "Address")
+
+ sess = Session(testing.db, future=True)
+
+ canary = self._flag_fixture(sess)
+
+ sess.execute(delete(User).filter_by(id=18))
+ sess.execute(update(User).filter_by(id=18).values(name="eighteen"))
+
+ eq_(
+ canary.mock_calls,
+ [
call.options(
is_select=False,
is_update=False,
diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py
index c283e804e..a70dc0511 100644
--- a/test/sql/test_lambdas.py
+++ b/test/sql/test_lambdas.py
@@ -22,6 +22,7 @@ from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import ne_
from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.types import Boolean
from sqlalchemy.types import Integer
from sqlalchemy.types import String
@@ -77,6 +78,41 @@ class DeferredLambdaTest(
checkparams={"global_x_1": 10, "global_y_1": 9},
)
+ def test_boolean_constants(self):
+ t1 = table("t1", column("q"), column("p"))
+
+ def go():
+ xy = True
+ stmt = select(t1).where(lambda: t1.c.q == xy)
+ return stmt
+
+ self.assert_compile(
+ go(), "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :xy_1"
+ )
+
+ def test_execute_boolean(self, boolean_table_fixture, connection):
+ boolean_data = boolean_table_fixture
+
+ connection.execute(
+ boolean_data.insert(),
+ [{"id": 1, "data": True}, {"id": 2, "data": False}],
+ )
+
+ xy = True
+
+ def go():
+ stmt = select(lambda: boolean_data.c.id).where(
+ lambda: boolean_data.c.data == xy
+ )
+ return connection.execute(stmt)
+
+ result = go()
+ eq_(result.all(), [(1,)])
+
+ xy = False
+ result = go()
+ eq_(result.all(), [(2,)])
+
def test_stale_checker_embedded(self):
def go(x):
@@ -761,6 +797,15 @@ class DeferredLambdaTest(
)
return users, addresses
+ @testing.metadata_fixture()
+ def boolean_table_fixture(self, metadata):
+ return Table(
+ "boolean_data",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", Boolean),
+ )
+
def test_adapt_select(self, user_address_fixture):
users, addresses = user_address_fixture
diff --git a/test/sql/test_utils.py b/test/sql/test_utils.py
index a4b76f35d..24a149ece 100644
--- a/test/sql/test_utils.py
+++ b/test/sql/test_utils.py
@@ -15,6 +15,7 @@ from sqlalchemy.sql import util as sql_util
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
@@ -57,6 +58,34 @@ class MiscTest(fixtures.TestBase):
{common, calias, subset_select},
)
+ def test_incompatible_options_add_clslevel(self):
+ class opt1(sql_base.CacheableOptions):
+ _cache_key_traversal = []
+ foo = "bar"
+
+ with expect_raises_message(
+ TypeError,
+ "dictionary contains attributes not covered by "
+ "Options class .*opt1.* .*'bar'.*",
+ ):
+ o1 = opt1
+
+ o1 += {"foo": "f", "bar": "b"}
+
+ def test_incompatible_options_add_instancelevel(self):
+ class opt1(sql_base.CacheableOptions):
+ _cache_key_traversal = []
+ foo = "bar"
+
+ o1 = opt1(foo="bat")
+
+ with expect_raises_message(
+ TypeError,
+ "dictionary contains attributes not covered by "
+ "Options class .*opt1.* .*'bar'.*",
+ ):
+ o1 += {"foo": "f", "bar": "b"}
+
def test_options_merge(self):
class opt1(sql_base.CacheableOptions):
_cache_key_traversal = []