diff options
| author | Gord Thompson <gord@gordthompson.com> | 2021-02-13 14:43:21 -0700 |
|---|---|---|
| committer | Gord Thompson <gord@gordthompson.com> | 2021-02-15 11:16:38 -0700 |
| commit | 857adaaf867df54d4a023cf19f618fdf1d0f60c9 (patch) | |
| tree | 6079150053100063a2b865c7b2c9702dbcab7e3a | |
| parent | d642946939416ea2870cf6c6479dcddad795b622 (diff) | |
| download | sqlalchemy-857adaaf867df54d4a023cf19f618fdf1d0f60c9.tar.gz | |
Accept ColumnCollection in update_on_conflict(set_=
Fixes: #5939
Change-Id: I21d7125765028e2a98d5ef4c32d8e7e457aa2d12
| -rw-r--r-- | doc/build/changelog/unreleased_14/5939.rst | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mysql/dml.py | 16 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/dml.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/sqlite/dml.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/functions.py | 2 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_on_conflict.py | 11 | ||||
| -rw-r--r-- | test/dialect/test_sqlite.py | 11 |
7 files changed, 64 insertions, 11 deletions
diff --git a/doc/build/changelog/unreleased_14/5939.rst b/doc/build/changelog/unreleased_14/5939.rst new file mode 100644 index 000000000..2552cb2c1 --- /dev/null +++ b/doc/build/changelog/unreleased_14/5939.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: sql, usecase, postgresql, sqlite + :tickets: 5939 + + Enhance ``set_`` keyword of :class:`.OnConflictDoUpdate` to accept a + :class:`.ColumnCollection`, such as the ``.c.`` collection from a + :class:`Selectable`, or the ``.excluded`` contextual object. diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py index 6c50dcca9..d57a89090 100644 --- a/lib/sqlalchemy/dialects/mysql/dml.py +++ b/lib/sqlalchemy/dialects/mysql/dml.py @@ -2,6 +2,7 @@ from ... import exc from ... import util from ...sql.base import _exclusive_against from ...sql.base import _generative +from ...sql.base import ColumnCollection from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement from ...sql.expression import alias @@ -145,6 +146,17 @@ class OnDuplicateClause(ClauseElement): self._parameter_ordering = [key for key, value in update] update = dict(update) - if not update or not isinstance(update, dict): - raise ValueError("update parameter must be a non-empty dictionary") + if isinstance(update, dict): + if not update: + raise ValueError( + "update parameter dictionary must not be empty" + ) + elif isinstance(update, ColumnCollection): + update = dict(update) + else: + raise ValueError( + "update parameter must be a non-empty dictionary " + "or a ColumnCollection such as the `.c.` collection " + "of a Table object" + ) self.update = update diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py index bff61e173..b6f5cdf7e 100644 --- a/lib/sqlalchemy/dialects/postgresql/dml.py +++ b/lib/sqlalchemy/dialects/postgresql/dml.py @@ -12,6 +12,7 @@ from ...sql import roles from ...sql import schema from ...sql.base import _exclusive_against from ...sql.base import _generative +from ...sql.base import ColumnCollection from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement from ...sql.expression import alias @@ -243,8 +244,17 @@ class OnConflictDoUpdate(OnConflictClause): "but not both, must be specified unless DO NOTHING" ) - if not isinstance(set_, dict) or not set_: - raise ValueError("set parameter must be a non-empty dictionary") + if isinstance(set_, dict): + if not set_: + raise ValueError("set parameter dictionary must not be empty") + elif isinstance(set_, ColumnCollection): + set_ = dict(set_) + else: + raise ValueError( + "set parameter must be a non-empty dictionary " + "or a ColumnCollection such as the `.c.` collection " + "of a Table object" + ) self.update_values_to_set = [ (coercions.expect(roles.DMLColumnRole, key), value) for key, value in set_.items() diff --git a/lib/sqlalchemy/dialects/sqlite/dml.py b/lib/sqlalchemy/dialects/sqlite/dml.py index be32781c7..4cb819960 100644 --- a/lib/sqlalchemy/dialects/sqlite/dml.py +++ b/lib/sqlalchemy/dialects/sqlite/dml.py @@ -9,6 +9,7 @@ from ...sql import coercions from ...sql import roles from ...sql.base import _exclusive_against from ...sql.base import _generative +from ...sql.base import ColumnCollection from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement from ...sql.expression import alias @@ -169,8 +170,17 @@ class OnConflictDoUpdate(OnConflictClause): index_where=index_where, ) - if not isinstance(set_, dict) or not set_: - raise ValueError("set parameter must be a non-empty dictionary") + if isinstance(set_, dict): + if not set_: + raise ValueError("set parameter dictionary must not be empty") + elif isinstance(set_, ColumnCollection): + set_ = dict(set_) + else: + raise ValueError( + "set parameter must be a non-empty dictionary " + "or a ColumnCollection such as the `.c.` collection " + "of a Table object" + ) self.update_values_to_set = [ (coercions.expect(roles.DMLColumnRole, key), value) for key, value in set_.items() diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 641715327..40af73d7a 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -163,7 +163,7 @@ class FunctionElement(Executable, ColumnElement, FromClause, Generative): return ScalarFunctionColumn(self, name, type_) def table_valued(self, *expr, **kw): - """Return a :class:`_sql.TableValuedAlias` representation of this + r"""Return a :class:`_sql.TableValuedAlias` representation of this :class:`_functions.FunctionElement` with table-valued expressions added. e.g.:: diff --git a/test/dialect/postgresql/test_on_conflict.py b/test/dialect/postgresql/test_on_conflict.py index 4e96cc6a2..489084de7 100644 --- a/test/dialect/postgresql/test_on_conflict.py +++ b/test/dialect/postgresql/test_on_conflict.py @@ -176,14 +176,21 @@ class OnConflictTest(fixtures.TablesTest): [(1, "name1")], ) - def test_on_conflict_do_update_one(self, connection): + @testing.combinations( + ("with_dict", True), + ("issue_5939", False), + id_="ia", + argnames="with_dict", + ) + def test_on_conflict_do_update_one(self, connection, with_dict): users = self.tables.users connection.execute(users.insert(), dict(id=1, name="name1")) i = insert(users) i = i.on_conflict_do_update( - index_elements=[users.c.id], set_=dict(name=i.excluded.name) + index_elements=[users.c.id], + set_=dict(name=i.excluded.name) if with_dict else i.excluded, ) result = connection.execute(i, dict(id=1, name="name1")) diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index ad169eebf..aee97e8c6 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -2810,7 +2810,13 @@ class OnConflictTest(fixtures.TablesTest): [(1, "name1")], ) - def test_on_conflict_do_update_one(self, connection): + @testing.combinations( + ("with_dict", True), + ("issue_5939", False), + id_="ia", + argnames="with_dict", + ) + def test_on_conflict_do_update_one(self, connection, with_dict): users = self.tables.users conn = connection @@ -2818,7 +2824,8 @@ class OnConflictTest(fixtures.TablesTest): i = insert(users) i = i.on_conflict_do_update( - index_elements=[users.c.id], set_=dict(name=i.excluded.name) + index_elements=[users.c.id], + set_=dict(name=i.excluded.name) if with_dict else i.excluded, ) result = conn.execute(i, dict(id=1, name="name1")) |
