summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_13/4715.rst6
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py12
-rw-r--r--test/dialect/mysql/test_on_duplicate.py15
3 files changed, 27 insertions, 6 deletions
diff --git a/doc/build/changelog/unreleased_13/4715.rst b/doc/build/changelog/unreleased_13/4715.rst
new file mode 100644
index 000000000..190adc1e1
--- /dev/null
+++ b/doc/build/changelog/unreleased_13/4715.rst
@@ -0,0 +1,6 @@
+.. change::
+ :tags: bug, mysql
+ :tickets: 4715
+
+ Fixed bug where MySQL ON DUPLICATE KEY UPDATE would not accommodate setting
+ a column to the value NULL. Pull request courtesy Lukáš Banič. \ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index ad5ab288c..9aa80a5f4 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1232,15 +1232,15 @@ class MySQLCompiler(compiler.SQLCompiler):
c for c in self.statement.table.c if c.key not in ordered_keys
]
else:
- # traverse in table column order
cols = self.statement.table.c
clauses = []
- for column in cols:
- val = on_duplicate.update.get(column.key)
- if val is None:
- continue
- elif coercions._is_literal(val):
+ # traverses through all table columns to preserve table column order
+ for column in (col for col in cols if col.key in on_duplicate.update):
+
+ val = on_duplicate.update[column.key]
+
+ if coercions._is_literal(val):
val = elements.BindParameter(None, val, type_=column.type)
value_text = self.process(val.self_group(), use_schema=False)
elif isinstance(val, elements.BindParameter) and val.type._isnull:
diff --git a/test/dialect/mysql/test_on_duplicate.py b/test/dialect/mysql/test_on_duplicate.py
index 0c6f47929..077e5ba98 100644
--- a/test/dialect/mysql/test_on_duplicate.py
+++ b/test/dialect/mysql/test_on_duplicate.py
@@ -62,6 +62,21 @@ class OnDuplicateTest(fixtures.TablesTest):
[(1, "ab", "bz", False)],
)
+ def test_on_duplicate_key_update_null(self):
+ foos = self.tables.foos
+ with testing.db.connect() as conn:
+ conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
+ stmt = insert(foos).values(
+ [dict(id=1, bar="ab"), dict(id=2, bar="b")]
+ )
+ stmt = stmt.on_duplicate_key_update(updated_once=None)
+ result = conn.execute(stmt)
+ eq_(result.inserted_primary_key, [2])
+ eq_(
+ conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
+ [(1, "b", "bz", None)],
+ )
+
def test_on_duplicate_key_update_preserve_order(self):
foos = self.tables.foos
with testing.db.connect() as conn: