summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGord Thompson <gord@gordthompson.com>2021-02-13 14:43:21 -0700
committerGord Thompson <gord@gordthompson.com>2021-02-15 11:16:38 -0700
commit857adaaf867df54d4a023cf19f618fdf1d0f60c9 (patch)
tree6079150053100063a2b865c7b2c9702dbcab7e3a
parentd642946939416ea2870cf6c6479dcddad795b622 (diff)
downloadsqlalchemy-857adaaf867df54d4a023cf19f618fdf1d0f60c9.tar.gz
Accept ColumnCollection in update_on_conflict(set_=
Fixes: #5939 Change-Id: I21d7125765028e2a98d5ef4c32d8e7e457aa2d12
-rw-r--r--doc/build/changelog/unreleased_14/5939.rst7
-rw-r--r--lib/sqlalchemy/dialects/mysql/dml.py16
-rw-r--r--lib/sqlalchemy/dialects/postgresql/dml.py14
-rw-r--r--lib/sqlalchemy/dialects/sqlite/dml.py14
-rw-r--r--lib/sqlalchemy/sql/functions.py2
-rw-r--r--test/dialect/postgresql/test_on_conflict.py11
-rw-r--r--test/dialect/test_sqlite.py11
7 files changed, 64 insertions, 11 deletions
diff --git a/doc/build/changelog/unreleased_14/5939.rst b/doc/build/changelog/unreleased_14/5939.rst
new file mode 100644
index 000000000..2552cb2c1
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/5939.rst
@@ -0,0 +1,7 @@
+.. change::
+ :tags: sql, usecase, postgresql, sqlite
+ :tickets: 5939
+
+ Enhance ``set_`` keyword of :class:`.OnConflictDoUpdate` to accept a
+ :class:`.ColumnCollection`, such as the ``.c.`` collection from a
+ :class:`Selectable`, or the ``.excluded`` contextual object.
diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py
index 6c50dcca9..d57a89090 100644
--- a/lib/sqlalchemy/dialects/mysql/dml.py
+++ b/lib/sqlalchemy/dialects/mysql/dml.py
@@ -2,6 +2,7 @@ from ... import exc
from ... import util
from ...sql.base import _exclusive_against
from ...sql.base import _generative
+from ...sql.base import ColumnCollection
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
from ...sql.expression import alias
@@ -145,6 +146,17 @@ class OnDuplicateClause(ClauseElement):
self._parameter_ordering = [key for key, value in update]
update = dict(update)
- if not update or not isinstance(update, dict):
- raise ValueError("update parameter must be a non-empty dictionary")
+ if isinstance(update, dict):
+ if not update:
+ raise ValueError(
+ "update parameter dictionary must not be empty"
+ )
+ elif isinstance(update, ColumnCollection):
+ update = dict(update)
+ else:
+ raise ValueError(
+ "update parameter must be a non-empty dictionary "
+ "or a ColumnCollection such as the `.c.` collection "
+ "of a Table object"
+ )
self.update = update
diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py
index bff61e173..b6f5cdf7e 100644
--- a/lib/sqlalchemy/dialects/postgresql/dml.py
+++ b/lib/sqlalchemy/dialects/postgresql/dml.py
@@ -12,6 +12,7 @@ from ...sql import roles
from ...sql import schema
from ...sql.base import _exclusive_against
from ...sql.base import _generative
+from ...sql.base import ColumnCollection
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
from ...sql.expression import alias
@@ -243,8 +244,17 @@ class OnConflictDoUpdate(OnConflictClause):
"but not both, must be specified unless DO NOTHING"
)
- if not isinstance(set_, dict) or not set_:
- raise ValueError("set parameter must be a non-empty dictionary")
+ if isinstance(set_, dict):
+ if not set_:
+ raise ValueError("set parameter dictionary must not be empty")
+ elif isinstance(set_, ColumnCollection):
+ set_ = dict(set_)
+ else:
+ raise ValueError(
+ "set parameter must be a non-empty dictionary "
+ "or a ColumnCollection such as the `.c.` collection "
+ "of a Table object"
+ )
self.update_values_to_set = [
(coercions.expect(roles.DMLColumnRole, key), value)
for key, value in set_.items()
diff --git a/lib/sqlalchemy/dialects/sqlite/dml.py b/lib/sqlalchemy/dialects/sqlite/dml.py
index be32781c7..4cb819960 100644
--- a/lib/sqlalchemy/dialects/sqlite/dml.py
+++ b/lib/sqlalchemy/dialects/sqlite/dml.py
@@ -9,6 +9,7 @@ from ...sql import coercions
from ...sql import roles
from ...sql.base import _exclusive_against
from ...sql.base import _generative
+from ...sql.base import ColumnCollection
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
from ...sql.expression import alias
@@ -169,8 +170,17 @@ class OnConflictDoUpdate(OnConflictClause):
index_where=index_where,
)
- if not isinstance(set_, dict) or not set_:
- raise ValueError("set parameter must be a non-empty dictionary")
+ if isinstance(set_, dict):
+ if not set_:
+ raise ValueError("set parameter dictionary must not be empty")
+ elif isinstance(set_, ColumnCollection):
+ set_ = dict(set_)
+ else:
+ raise ValueError(
+ "set parameter must be a non-empty dictionary "
+ "or a ColumnCollection such as the `.c.` collection "
+ "of a Table object"
+ )
self.update_values_to_set = [
(coercions.expect(roles.DMLColumnRole, key), value)
for key, value in set_.items()
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 641715327..40af73d7a 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -163,7 +163,7 @@ class FunctionElement(Executable, ColumnElement, FromClause, Generative):
return ScalarFunctionColumn(self, name, type_)
def table_valued(self, *expr, **kw):
- """Return a :class:`_sql.TableValuedAlias` representation of this
+ r"""Return a :class:`_sql.TableValuedAlias` representation of this
:class:`_functions.FunctionElement` with table-valued expressions added.
e.g.::
diff --git a/test/dialect/postgresql/test_on_conflict.py b/test/dialect/postgresql/test_on_conflict.py
index 4e96cc6a2..489084de7 100644
--- a/test/dialect/postgresql/test_on_conflict.py
+++ b/test/dialect/postgresql/test_on_conflict.py
@@ -176,14 +176,21 @@ class OnConflictTest(fixtures.TablesTest):
[(1, "name1")],
)
- def test_on_conflict_do_update_one(self, connection):
+ @testing.combinations(
+ ("with_dict", True),
+ ("issue_5939", False),
+ id_="ia",
+ argnames="with_dict",
+ )
+ def test_on_conflict_do_update_one(self, connection, with_dict):
users = self.tables.users
connection.execute(users.insert(), dict(id=1, name="name1"))
i = insert(users)
i = i.on_conflict_do_update(
- index_elements=[users.c.id], set_=dict(name=i.excluded.name)
+ index_elements=[users.c.id],
+ set_=dict(name=i.excluded.name) if with_dict else i.excluded,
)
result = connection.execute(i, dict(id=1, name="name1"))
diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py
index ad169eebf..aee97e8c6 100644
--- a/test/dialect/test_sqlite.py
+++ b/test/dialect/test_sqlite.py
@@ -2810,7 +2810,13 @@ class OnConflictTest(fixtures.TablesTest):
[(1, "name1")],
)
- def test_on_conflict_do_update_one(self, connection):
+ @testing.combinations(
+ ("with_dict", True),
+ ("issue_5939", False),
+ id_="ia",
+ argnames="with_dict",
+ )
+ def test_on_conflict_do_update_one(self, connection, with_dict):
users = self.tables.users
conn = connection
@@ -2818,7 +2824,8 @@ class OnConflictTest(fixtures.TablesTest):
i = insert(users)
i = i.on_conflict_do_update(
- index_elements=[users.c.id], set_=dict(name=i.excluded.name)
+ index_elements=[users.c.id],
+ set_=dict(name=i.excluded.name) if with_dict else i.excluded,
)
result = conn.execute(i, dict(id=1, name="name1"))