summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/postgresql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py32
-rw-r--r--lib/sqlalchemy/dialects/postgresql/dml.py21
2 files changed, 33 insertions, 20 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 3c33d9ee8..3a458ebed 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -2147,22 +2147,28 @@ class PGCompiler(compiler.SQLCompiler):
cols = insert_statement.table.c
for c in cols:
col_key = c.key
+
if col_key in set_parameters:
value = set_parameters.pop(col_key)
- if coercions._is_literal(value):
- value = elements.BindParameter(None, value, type_=c.type)
+ elif c in set_parameters:
+ value = set_parameters.pop(c)
+ else:
+ continue
- else:
- if (
- isinstance(value, elements.BindParameter)
- and value.type._isnull
- ):
- value = value._clone()
- value.type = c.type
- value_text = self.process(value.self_group(), use_schema=False)
-
- key_text = self.preparer.quote(col_key)
- action_set_ops.append("%s = %s" % (key_text, value_text))
+ if coercions._is_literal(value):
+ value = elements.BindParameter(None, value, type_=c.type)
+
+ else:
+ if (
+ isinstance(value, elements.BindParameter)
+ and value.type._isnull
+ ):
+ value = value._clone()
+ value.type = c.type
+ value_text = self.process(value.self_group(), use_schema=False)
+
+ key_text = self.preparer.quote(col_key)
+ action_set_ops.append("%s = %s" % (key_text, value_text))
# check for names that don't match columns
if set_parameters:
diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py
index 50fd09528..78cad974f 100644
--- a/lib/sqlalchemy/dialects/postgresql/dml.py
+++ b/lib/sqlalchemy/dialects/postgresql/dml.py
@@ -7,6 +7,8 @@
from . import ext
from ... import util
+from ...sql import coercions
+from ...sql import roles
from ...sql import schema
from ...sql.base import _generative
from ...sql.dml import Insert as StandardInsert
@@ -77,12 +79,16 @@ class Insert(StandardInsert):
conditional target index.
:param set\_:
- Required argument. A dictionary or other mapping object
- with column names as keys and expressions or literals as values,
- specifying the ``SET`` actions to take.
- If the target :class:`_schema.Column` specifies a ".
- key" attribute distinct
- from the column name, that key should be used.
+ A dictionary or other mapping object
+ where the keys are either names of columns in the target table,
+ or :class:`_schema.Column` objects or other ORM-mapped columns
+ matching that of the target table, and expressions or literals
+ as values, specifying the ``SET`` actions to take.
+
+ .. versionadded:: 1.4 The
+ :paramref:`_postgresql.Insert.on_conflict_do_update.set_`
+ parameter supports :class:`_schema.Column` objects from the target
+ :class:`_schema.Table` as keys.
.. warning:: This dictionary does **not** take into account
Python-specified default UPDATE values or generation functions,
@@ -229,6 +235,7 @@ class OnConflictDoUpdate(OnConflictClause):
if not isinstance(set_, dict) or not set_:
raise ValueError("set parameter must be a non-empty dictionary")
self.update_values_to_set = [
- (key, value) for key, value in set_.items()
+ (coercions.expect(roles.DMLColumnRole, key), value)
+ for key, value in set_.items()
]
self.update_whereclause = where