summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2021-01-16 12:39:51 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2021-01-16 18:44:21 -0500
commit8860117c9655a4bdeafebab1c6ef12c6a6198e66 (patch)
tree7fc8743f78b6d4f1ae183265abec76e11560232c
parent6137d223be8e596fb2d7c78623ab22162db8ea6e (diff)
downloadsqlalchemy-8860117c9655a4bdeafebab1c6ef12c6a6198e66.tar.gz
introduce generalized decorator to prevent invalid method calls
This introduces the ``_exclusive_against()`` utility decorator that can be used to prevent repeated invocations of methods that typically should only be called once. An informative error message is now raised for a selected set of DML methods (currently all part of :class:`_dml.Insert` constructs) if they are called a second time, which would implicitly cancel out the previous setting. The methods altered include: :class:`_sqlite.Insert.on_conflict_do_update`, :class:`_sqlite.Insert.on_conflict_do_nothing` (SQLite), :class:`_postgresql.Insert.on_conflict_do_update`, :class:`_postgresql.Insert.on_conflict_do_nothing` (PostgreSQL), :class:`_mysql.Insert.on_duplicate_key_update` (MySQL) Fixes: #5169 Change-Id: I9278fa87cd3470dcf296ff96bb0fb17a3236d49d
-rw-r--r--lib/sqlalchemy/dialects/mysql/dml.py8
-rw-r--r--lib/sqlalchemy/dialects/postgresql/dml.py11
-rw-r--r--lib/sqlalchemy/dialects/sqlite/dml.py11
-rw-r--r--lib/sqlalchemy/sql/base.py25
-rw-r--r--lib/sqlalchemy/sql/dml.py30
-rw-r--r--test/dialect/mysql/test_compiler.py16
-rw-r--r--test/dialect/postgresql/test_compiler.py21
-rw-r--r--test/dialect/test_sqlite.py21
-rw-r--r--test/sql/test_update.py2
9 files changed, 131 insertions, 14 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py
index 9f8177c59..6c50dcca9 100644
--- a/lib/sqlalchemy/dialects/mysql/dml.py
+++ b/lib/sqlalchemy/dialects/mysql/dml.py
@@ -1,5 +1,6 @@
from ... import exc
from ... import util
+from ...sql.base import _exclusive_against
from ...sql.base import _generative
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
@@ -49,6 +50,13 @@ class Insert(StandardInsert):
return alias(self.table, name="inserted")
@_generative
+ @_exclusive_against(
+ "_post_values_clause",
+ msgs={
+ "_post_values_clause": "This Insert construct already "
+ "has an ON DUPLICATE KEY clause present"
+ },
+ )
def on_duplicate_key_update(self, *args, **kw):
r"""
Specifies the ON DUPLICATE KEY UPDATE clause.
diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py
index 76dfafd04..bff61e173 100644
--- a/lib/sqlalchemy/dialects/postgresql/dml.py
+++ b/lib/sqlalchemy/dialects/postgresql/dml.py
@@ -10,6 +10,7 @@ from ... import util
from ...sql import coercions
from ...sql import roles
from ...sql import schema
+from ...sql.base import _exclusive_against
from ...sql.base import _generative
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
@@ -50,7 +51,16 @@ class Insert(StandardInsert):
"""
return alias(self.table, name="excluded").columns
+ _on_conflict_exclusive = _exclusive_against(
+ "_post_values_clause",
+ msgs={
+ "_post_values_clause": "This Insert construct already has "
+ "an ON CONFLICT clause established"
+ },
+ )
+
@_generative
+ @_on_conflict_exclusive
def on_conflict_do_update(
self,
constraint=None,
@@ -117,6 +127,7 @@ class Insert(StandardInsert):
)
@_generative
+ @_on_conflict_exclusive
def on_conflict_do_nothing(
self, constraint=None, index_elements=None, index_where=None
):
diff --git a/lib/sqlalchemy/dialects/sqlite/dml.py b/lib/sqlalchemy/dialects/sqlite/dml.py
index 9c8f10f7b..be32781c7 100644
--- a/lib/sqlalchemy/dialects/sqlite/dml.py
+++ b/lib/sqlalchemy/dialects/sqlite/dml.py
@@ -7,6 +7,7 @@
from ... import util
from ...sql import coercions
from ...sql import roles
+from ...sql.base import _exclusive_against
from ...sql.base import _generative
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
@@ -46,7 +47,16 @@ class Insert(StandardInsert):
"""
return alias(self.table, name="excluded").columns
+ _on_conflict_exclusive = _exclusive_against(
+ "_post_values_clause",
+ msgs={
+ "_post_values_clause": "This Insert construct already has "
+ "an ON CONFLICT clause established"
+ },
+ )
+
@_generative
+ @_on_conflict_exclusive
def on_conflict_do_update(
self,
index_elements=None,
@@ -99,6 +109,7 @@ class Insert(StandardInsert):
)
@_generative
+ @_on_conflict_exclusive
def on_conflict_do_nothing(self, index_elements=None, index_where=None):
"""
Specifies a DO NOTHING action for ON CONFLICT clause.
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 550111020..220bbb115 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -102,6 +102,31 @@ def _generative(fn):
return decorated
+def _exclusive_against(*names, **kw):
+ msgs = kw.pop("msgs", {})
+
+ defaults = kw.pop("defaults", {})
+
+ getters = [
+ (name, operator.attrgetter(name), defaults.get(name, None))
+ for name in names
+ ]
+
+ @util.decorator
+ def check(fn, self, *args, **kw):
+ for name, getter, default_ in getters:
+ if getter(self) is not default_:
+ msg = msgs.get(
+ name,
+ "Method %s() has already been invoked on this %s construct"
+ % (fn.__name__, self.__class__),
+ )
+ raise exc.InvalidRequestError(msg)
+ return fn(self, *args, **kw)
+
+ return check
+
+
def _clone(element, **kw):
return element._clone()
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index c402de121..3f492a490 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -14,6 +14,7 @@ from . import coercions
from . import roles
from . import util as sql_util
from .base import _entity_namespace_key
+from .base import _exclusive_against
from .base import _from_objects
from .base import _generative
from .base import ColumnCollection
@@ -495,6 +496,15 @@ class ValuesBase(UpdateBase):
self._setup_prefixes(prefixes)
@_generative
+ @_exclusive_against(
+ "_select_names",
+ "_ordered_values",
+ msgs={
+ "_select_names": "This construct already inserts from a SELECT",
+ "_ordered_values": "This statement already has ordered "
+ "values present",
+ },
+ )
def values(self, *args, **kwargs):
r"""Specify a fixed VALUES clause for an INSERT statement, or the SET
clause for an UPDATE.
@@ -607,15 +617,6 @@ class ValuesBase(UpdateBase):
"""
- if self._select_names:
- raise exc.InvalidRequestError(
- "This construct already inserts from a SELECT"
- )
- elif self._ordered_values:
- raise exc.ArgumentError(
- "This statement already has ordered values present"
- )
-
if args:
# positional case. this is currently expensive. we don't
# yet have positional-only args so we have to check the length.
@@ -699,6 +700,13 @@ class ValuesBase(UpdateBase):
self._values = util.immutabledict(arg)
@_generative
+ @_exclusive_against(
+ "_returning",
+ msgs={
+ "_returning": "RETURNING is already configured on this statement"
+ },
+ defaults={"_returning": _returning},
+ )
def return_defaults(self, *cols):
"""Make use of a :term:`RETURNING` clause for the purpose
of fetching server-side expressions and defaults.
@@ -783,10 +791,6 @@ class ValuesBase(UpdateBase):
:attr:`_engine.CursorResult.inserted_primary_key_rows`
"""
- if self._returning:
- raise exc.InvalidRequestError(
- "RETURNING is already configured on this statement"
- )
self._return_defaults = cols or True
diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py
index 7fd24e8b5..84646d380 100644
--- a/test/dialect/mysql/test_compiler.py
+++ b/test/dialect/mysql/test_compiler.py
@@ -1000,6 +1000,22 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
Column("baz", String(10)),
)
+ def test_no_call_twice(self):
+ stmt = insert(self.table).values(
+ [{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}]
+ )
+ stmt = stmt.on_duplicate_key_update(
+ bar=stmt.inserted.bar, baz=stmt.inserted.baz
+ )
+ with testing.expect_raises_message(
+ exc.InvalidRequestError,
+ "This Insert construct already has an "
+ "ON DUPLICATE KEY clause present",
+ ):
+ stmt = stmt.on_duplicate_key_update(
+ bar=stmt.inserted.bar, baz=stmt.inserted.baz
+ )
+
def test_from_values(self):
stmt = insert(self.table).values(
[{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}]
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index b3a0b9bbd..eb39091ae 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -1842,6 +1842,27 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL):
"goofy_index", table1.c.name, postgresql_where=table1.c.name > "m"
)
+ def test_on_conflict_do_no_call_twice(self):
+ users = self.table1
+
+ for stmt in (
+ insert(users).on_conflict_do_nothing(),
+ insert(users).on_conflict_do_update(
+ index_elements=[users.c.myid], set_=dict(name="foo")
+ ),
+ ):
+ for meth in (
+ stmt.on_conflict_do_nothing,
+ stmt.on_conflict_do_update,
+ ):
+
+ with testing.expect_raises_message(
+ exc.InvalidRequestError,
+ "This Insert construct already has an "
+ "ON CONFLICT clause established",
+ ):
+ meth()
+
def test_do_nothing_no_target(self):
i = insert(
diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py
index 0500a20bd..23ceb88b3 100644
--- a/test/dialect/test_sqlite.py
+++ b/test/dialect/test_sqlite.py
@@ -2717,6 +2717,27 @@ class OnConflictTest(fixtures.TablesTest):
ValueError, insert(self.tables.users).on_conflict_do_update
)
+ def test_on_conflict_do_no_call_twice(self):
+ users = self.tables.users
+
+ for stmt in (
+ insert(users).on_conflict_do_nothing(),
+ insert(users).on_conflict_do_update(
+ index_elements=[users.c.id], set_=dict(name="foo")
+ ),
+ ):
+ for meth in (
+ stmt.on_conflict_do_nothing,
+ stmt.on_conflict_do_update,
+ ):
+
+ with testing.expect_raises_message(
+ exc.InvalidRequestError,
+ "This Insert construct already has an "
+ "ON CONFLICT clause established",
+ ):
+ meth()
+
def test_on_conflict_do_nothing(self, connection):
users = self.tables.users
diff --git a/test/sql/test_update.py b/test/sql/test_update.py
index 946a01651..26b0f6217 100644
--- a/test/sql/test_update.py
+++ b/test/sql/test_update.py
@@ -672,7 +672,7 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
stmt = table1.update().ordered_values(("myid", 1), ("name", "d1"))
assert_raises_message(
- exc.ArgumentError,
+ exc.InvalidRequestError,
"This statement already has ordered values present",
stmt.values,
{"myid": 2, "name": "d2"},