summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_14/5722.rst13
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py32
-rw-r--r--lib/sqlalchemy/dialects/postgresql/dml.py21
-rw-r--r--lib/sqlalchemy/dialects/sqlite/base.py32
-rw-r--r--lib/sqlalchemy/dialects/sqlite/dml.py21
-rw-r--r--test/dialect/postgresql/test_compiler.py25
-rw-r--r--test/dialect/postgresql/test_on_conflict.py102
-rw-r--r--test/dialect/test_sqlite.py47
8 files changed, 253 insertions, 40 deletions
diff --git a/doc/build/changelog/unreleased_14/5722.rst b/doc/build/changelog/unreleased_14/5722.rst
new file mode 100644
index 000000000..e756f8eb9
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/5722.rst
@@ -0,0 +1,13 @@
+.. change::
+ :tags: bug, postgresql
+ :tickets: 5722
+ :versions: 1.4.0b2
+
+ Established support for :class:`_schema.Column` objects as well as ORM
+ instrumented attributes as keys in the ``set_`` dictionary passed to the
+ :meth:`_postgresql.Insert.on_conflict_do_update` and
+ :meth:`_sqlite.Insert.on_conflict_do_update` methods, which match to the
+ :class:`_schema.Column` objects in the ``.c`` collection of the target
+ :class:`_schema.Table`. Previously, only string column names were
+ expected; a column expression would be assumed to be an out-of-table
+ expression that would render fully along with a warning. \ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 3c33d9ee8..3a458ebed 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -2147,22 +2147,28 @@ class PGCompiler(compiler.SQLCompiler):
cols = insert_statement.table.c
for c in cols:
col_key = c.key
+
if col_key in set_parameters:
value = set_parameters.pop(col_key)
- if coercions._is_literal(value):
- value = elements.BindParameter(None, value, type_=c.type)
+ elif c in set_parameters:
+ value = set_parameters.pop(c)
+ else:
+ continue
- else:
- if (
- isinstance(value, elements.BindParameter)
- and value.type._isnull
- ):
- value = value._clone()
- value.type = c.type
- value_text = self.process(value.self_group(), use_schema=False)
-
- key_text = self.preparer.quote(col_key)
- action_set_ops.append("%s = %s" % (key_text, value_text))
+ if coercions._is_literal(value):
+ value = elements.BindParameter(None, value, type_=c.type)
+
+ else:
+ if (
+ isinstance(value, elements.BindParameter)
+ and value.type._isnull
+ ):
+ value = value._clone()
+ value.type = c.type
+ value_text = self.process(value.self_group(), use_schema=False)
+
+ key_text = self.preparer.quote(col_key)
+ action_set_ops.append("%s = %s" % (key_text, value_text))
# check for names that don't match columns
if set_parameters:
diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py
index 50fd09528..78cad974f 100644
--- a/lib/sqlalchemy/dialects/postgresql/dml.py
+++ b/lib/sqlalchemy/dialects/postgresql/dml.py
@@ -7,6 +7,8 @@
from . import ext
from ... import util
+from ...sql import coercions
+from ...sql import roles
from ...sql import schema
from ...sql.base import _generative
from ...sql.dml import Insert as StandardInsert
@@ -77,12 +79,16 @@ class Insert(StandardInsert):
conditional target index.
:param set\_:
- Required argument. A dictionary or other mapping object
- with column names as keys and expressions or literals as values,
- specifying the ``SET`` actions to take.
- If the target :class:`_schema.Column` specifies a ".
- key" attribute distinct
- from the column name, that key should be used.
+ A dictionary or other mapping object
+ where the keys are either names of columns in the target table,
+ or :class:`_schema.Column` objects or other ORM-mapped columns
+ matching that of the target table, and expressions or literals
+ as values, specifying the ``SET`` actions to take.
+
+ .. versionadded:: 1.4 The
+ :paramref:`_postgresql.Insert.on_conflict_do_update.set_`
+ parameter supports :class:`_schema.Column` objects from the target
+ :class:`_schema.Table` as keys.
.. warning:: This dictionary does **not** take into account
Python-specified default UPDATE values or generation functions,
@@ -229,6 +235,7 @@ class OnConflictDoUpdate(OnConflictClause):
if not isinstance(set_, dict) or not set_:
raise ValueError("set parameter must be a non-empty dictionary")
self.update_values_to_set = [
- (key, value) for key, value in set_.items()
+ (coercions.expect(roles.DMLColumnRole, key), value)
+ for key, value in set_.items()
]
self.update_whereclause = where
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index 404a215b6..7c1bbb18e 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -1355,22 +1355,28 @@ class SQLiteCompiler(compiler.SQLCompiler):
cols = insert_statement.table.c
for c in cols:
col_key = c.key
+
if col_key in set_parameters:
value = set_parameters.pop(col_key)
- if coercions._is_literal(value):
- value = elements.BindParameter(None, value, type_=c.type)
+ elif c in set_parameters:
+ value = set_parameters.pop(c)
+ else:
+ continue
- else:
- if (
- isinstance(value, elements.BindParameter)
- and value.type._isnull
- ):
- value = value._clone()
- value.type = c.type
- value_text = self.process(value.self_group(), use_schema=False)
-
- key_text = self.preparer.quote(col_key)
- action_set_ops.append("%s = %s" % (key_text, value_text))
+ if coercions._is_literal(value):
+ value = elements.BindParameter(None, value, type_=c.type)
+
+ else:
+ if (
+ isinstance(value, elements.BindParameter)
+ and value.type._isnull
+ ):
+ value = value._clone()
+ value.type = c.type
+ value_text = self.process(value.self_group(), use_schema=False)
+
+ key_text = self.preparer.quote(col_key)
+ action_set_ops.append("%s = %s" % (key_text, value_text))
# check for names that don't match columns
if set_parameters:
diff --git a/lib/sqlalchemy/dialects/sqlite/dml.py b/lib/sqlalchemy/dialects/sqlite/dml.py
index a4d4d560c..2d7ea6e4a 100644
--- a/lib/sqlalchemy/dialects/sqlite/dml.py
+++ b/lib/sqlalchemy/dialects/sqlite/dml.py
@@ -5,6 +5,8 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from ... import util
+from ...sql import coercions
+from ...sql import roles
from ...sql.base import _generative
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
@@ -65,12 +67,16 @@ class Insert(StandardInsert):
conditional target index.
:param set\_:
- Required argument. A dictionary or other mapping object
- with column names as keys and expressions or literals as values,
- specifying the ``SET`` actions to take.
- If the target :class:`_schema.Column` specifies a ".
- key" attribute distinct
- from the column name, that key should be used.
+ A dictionary or other mapping object
+ where the keys are either names of columns in the target table,
+ or :class:`_schema.Column` objects or other ORM-mapped columns
+ matching that of the target table, and expressions or literals
+ as values, specifying the ``SET`` actions to take.
+
+ .. versionadded:: 1.4 The
+ :paramref:`_sqlite.Insert.on_conflict_do_update.set_`
+ parameter supports :class:`_schema.Column` objects from the target
+ :class:`_schema.Table` as keys.
.. warning:: This dictionary does **not** take into account
Python-specified default UPDATE values or generation functions,
@@ -155,6 +161,7 @@ class OnConflictDoUpdate(OnConflictClause):
if not isinstance(set_, dict) or not set_:
raise ValueError("set parameter must be a non-empty dictionary")
self.update_values_to_set = [
- (key, value) for key, value in set_.items()
+ (coercions.expect(roles.DMLColumnRole, key), value)
+ for key, value in set_.items()
]
self.update_whereclause = where
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index a031c3df9..9651f7bd9 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -1864,6 +1864,31 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL):
},
)
+ def test_do_update_set_clause_column_keys(self):
+ i = insert(self.table_with_metadata).values(myid=1, name="foo")
+ i = i.on_conflict_do_update(
+ index_elements=["myid"],
+ set_=OrderedDict(
+ [
+ (self.table_with_metadata.c.name, "I'm a name"),
+ (self.table_with_metadata.c.description, None),
+ ]
+ ),
+ )
+ self.assert_compile(
+ i,
+ "INSERT INTO mytable (myid, name) VALUES "
+ "(%(myid)s, %(name)s) ON CONFLICT (myid) "
+ "DO UPDATE SET name = %(param_1)s, "
+ "description = %(param_2)s",
+ {
+ "myid": 1,
+ "name": "foo",
+ "param_1": "I'm a name",
+ "param_2": None,
+ },
+ )
+
def test_do_update_set_clause_literal(self):
i = insert(self.table_with_metadata).values(myid=1, name="foo")
i = i.on_conflict_do_update(
diff --git a/test/dialect/postgresql/test_on_conflict.py b/test/dialect/postgresql/test_on_conflict.py
index 7a9bfd75d..760487842 100644
--- a/test/dialect/postgresql/test_on_conflict.py
+++ b/test/dialect/postgresql/test_on_conflict.py
@@ -10,6 +10,7 @@ from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy.testing import config
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.assertions import assert_raises
from sqlalchemy.testing.assertions import eq_
@@ -30,6 +31,14 @@ class OnConflictTest(fixtures.TablesTest):
Column("name", String(50)),
)
+ Table(
+ "users_schema",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50)),
+ schema=config.test_schema,
+ )
+
class SpecialType(sqltypes.TypeDecorator):
impl = String
@@ -185,6 +194,99 @@ class OnConflictTest(fixtures.TablesTest):
[(1, "name1")],
)
+ def test_on_conflict_do_update_schema(self):
+ users = self.tables.get("%s.users_schema" % config.test_schema)
+
+ with testing.db.connect() as conn:
+ conn.execute(users.insert(), dict(id=1, name="name1"))
+
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=[users.c.id], set_=dict(name=i.excluded.name)
+ )
+ result = conn.execute(i, dict(id=1, name="name1"))
+
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, None)
+
+ eq_(
+ conn.execute(users.select().where(users.c.id == 1)).fetchall(),
+ [(1, "name1")],
+ )
+
+ def test_on_conflict_do_update_column_as_key_set(self):
+ users = self.tables.users
+
+ with testing.db.connect() as conn:
+ conn.execute(users.insert(), dict(id=1, name="name1"))
+
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=[users.c.id],
+ set_={users.c.name: i.excluded.name},
+ )
+ result = conn.execute(i, dict(id=1, name="name1"))
+
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, None)
+
+ eq_(
+ conn.execute(users.select().where(users.c.id == 1)).fetchall(),
+ [(1, "name1")],
+ )
+
+ def test_on_conflict_do_update_clauseelem_as_key_set(self):
+ users = self.tables.users
+
+ class MyElem(object):
+ def __init__(self, expr):
+ self.expr = expr
+
+ def __clause_element__(self):
+ return self.expr
+
+ with testing.db.connect() as conn:
+ conn.execute(
+ users.insert(),
+ {"id": 1, "name": "name1"},
+ )
+
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=[users.c.id],
+ set_={MyElem(users.c.name): i.excluded.name},
+ ).values({MyElem(users.c.id): 1, MyElem(users.c.name): "name1"})
+ result = conn.execute(i)
+
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, None)
+
+ eq_(
+ conn.execute(users.select().where(users.c.id == 1)).fetchall(),
+ [(1, "name1")],
+ )
+
+ def test_on_conflict_do_update_column_as_key_set_schema(self):
+ users = self.tables.get("%s.users_schema" % config.test_schema)
+
+ with testing.db.connect() as conn:
+ conn.execute(users.insert(), dict(id=1, name="name1"))
+
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=[users.c.id],
+ set_={users.c.name: i.excluded.name},
+ )
+ result = conn.execute(i, dict(id=1, name="name1"))
+
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, None)
+
+ eq_(
+ conn.execute(users.select().where(users.c.id == 1)).fetchall(),
+ [(1, "name1")],
+ )
+
def test_on_conflict_do_update_two(self):
users = self.tables.users
diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py
index 456bad7bd..f8b50f888 100644
--- a/test/dialect/test_sqlite.py
+++ b/test/dialect/test_sqlite.py
@@ -2921,6 +2921,53 @@ class OnConflictTest(fixtures.TablesTest):
[(10, "I'm a name")],
)
+ def test_on_conflict_do_update_column_keys(self, connection):
+ users = self.tables.users
+
+ conn = connection
+ conn.execute(users.insert(), dict(id=1, name="name1"))
+
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=users.primary_key.columns,
+ set_={users.c.id: 10, users.c.name: "I'm a name"},
+ ).values(id=1, name="name4")
+
+ result = conn.execute(i)
+ eq_(result.inserted_primary_key, (1,))
+
+ eq_(
+ conn.execute(users.select().where(users.c.id == 10)).fetchall(),
+ [(10, "I'm a name")],
+ )
+
+ def test_on_conflict_do_update_clauseelem_keys(self, connection):
+ users = self.tables.users
+
+ class MyElem(object):
+ def __init__(self, expr):
+ self.expr = expr
+
+ def __clause_element__(self):
+ return self.expr
+
+ conn = connection
+ conn.execute(users.insert(), dict(id=1, name="name1"))
+
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=users.primary_key.columns,
+ set_={MyElem(users.c.id): 10, MyElem(users.c.name): "I'm a name"},
+ ).values({MyElem(users.c.id): 1, MyElem(users.c.name): "name4"})
+
+ result = conn.execute(i)
+ eq_(result.inserted_primary_key, (1,))
+
+ eq_(
+ conn.execute(users.select().where(users.c.id == 10)).fetchall(),
+ [(10, "I'm a name")],
+ )
+
def test_on_conflict_do_update_multivalues(self, connection):
users = self.tables.users