summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRobin Thomas <robin.thomas@livestream.com>2016-04-14 02:13:12 -0400
committerRobin Thomas <robin.thomas@livestream.com>2016-04-14 02:13:12 -0400
commit524b73d7ad2e4db62989ae54500babaa2c83b126 (patch)
treecaec2e7857775308bc38fe97c655ff96fa7db3f5
parent1e81462f070c873387d95c67310fb2dfc33a4e67 (diff)
downloadsqlalchemy-524b73d7ad2e4db62989ae54500babaa2c83b126.tar.gz
added ON CONFLICT support for UniqueConstratin, PrimaryKeyConstraint,
and Index objects as conflict targets.
-rw-r--r--lib/sqlalchemy/dialects/postgresql/on_conflict.py110
-rw-r--r--test/dialect/postgresql/test_compiler.py29
-rw-r--r--test/dialect/postgresql/test_on_conflict.py44
3 files changed, 163 insertions, 20 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/on_conflict.py b/lib/sqlalchemy/dialects/postgresql/on_conflict.py
index e69ca582c..c3032e0cc 100644
--- a/lib/sqlalchemy/dialects/postgresql/on_conflict.py
+++ b/lib/sqlalchemy/dialects/postgresql/on_conflict.py
@@ -1,6 +1,9 @@
from ...sql.expression import ClauseElement, ColumnClause, ColumnElement
from ...ext.compiler import compiles
from ...exc import CompileError
+from ...schema import UniqueConstraint, PrimaryKeyConstraint, Index
+
+from collections import Iterable
__all__ = ('DoUpdate', 'DoNothing')
@@ -24,44 +27,112 @@ def resolve_on_conflict_option(option_value, crud_columns):
if str(option_value) == 'nothing':
return DoNothing()
-def resolve_columnish_arg(arg):
- for col in (arg if isinstance(arg, (list, tuple)) else (arg,)):
- if not isinstance(col, (ColumnClause, str)):
- raise ValueError("column arguments must be ColumnClause objects or str object with column name: %r" % col)
- return tuple(arg) if isinstance(arg, (list, tuple)) else (arg,)
-
class OnConflictAction(ClauseElement):
def __init__(self, conflict_target):
super(OnConflictAction, self).__init__()
- if not isinstance(conflict_target, ConflictTarget):
- conflict_target = ConflictTarget(conflict_target)
self.conflict_target = conflict_target
class DoUpdate(OnConflictAction):
def __init__(self, conflict_target):
- super(DoUpdate, self).__init__(conflict_target)
+ super(DoUpdate, self).__init__(ConflictTarget(conflict_target))
if not self.conflict_target.contents:
raise ValueError("conflict_target may not be None or empty for DoUpdate")
self.values_to_set = {}
def set_with_excluded(self, *columns):
- for col in resolve_columnish_arg(columns):
- self.values_to_set[col] = _EXCLUDED
+ for col in columns:
+ if not isinstance(col, (ColumnClause, str)):
+ raise ValueError("column arguments must be ColumnClause objects or str object with column name: %r" % col)
+ self.values_to_set[col] = _EXCLUDED
return self
class DoNothing(OnConflictAction):
- def __init__(self, conflict_target=[]):
- super(DoNothing, self).__init__(conflict_target)
+ def __init__(self, conflict_target=None):
+ super(DoNothing, self).__init__(ConflictTarget(conflict_target) if conflict_target else None)
class ConflictTarget(ClauseElement):
+ """
+ A ConflictTarget represents the targeted constraint that will be used to determine
+ when a row proposed for insertion is in conflict and should be handled as specified
+ in the OnConflictAction.
+
+ A target can be one of the following:
+
+ - A column or list of columns, either column objects or strings, that together
+ represent a unique or primary key constraint on the table. The compiler
+ will produce a list like `(col1, col2, col3)` as the conflict target SQL clause.
+
+ - A single PrimaryKeyConstraint or UniqueConstraint object representing the constraint
+ used to detect the conflict. If the object has a :attr:`.name` attribute,
+ the compiler will produce `ON CONSTRAINT constraint_name` as the conflict target
+ SQL clause. If the constraint lacks a `.name` attribute, a list of its
+ constituent columns, like `(col1, col2, col3)` will be used.
+
+ - An single :class:`Index` object representing the index used to detect the conflict.
+ Use this in place of the Constraint objects mentioned above if you require
+ the clauses of a conflict target specific to index definitions -- collation,
+ opclass used to detect conflict, and WHERE clauses for partial indexes.
+ """
def __init__(self, contents):
- self.contents = resolve_columnish_arg(contents)
+ if isinstance(contents, (str, ColumnClause)):
+ self.contents = (contents,)
+ elif isinstance(contents, (list, tuple)):
+ if not contents:
+ raise ValueError("list of column arguments cannot be empty")
+ for c in contents:
+ if not isinstance(c, (str, ColumnClause)):
+ raise ValueError("column arguments must be ColumnClause objects or str object with column name: %r" % c)
+ self.contents = tuple(contents)
+ elif isinstance(contents, (PrimaryKeyConstraint, UniqueConstraint, Index)):
+ self.contents = contents
+ else:
+ raise ValueError(
+ "ConflictTarget contents must be single Column/str, "
+ "sequence of Column/str; or a PrimaryKeyConsraint, UniqueConstraint, or Index")
@compiles(ConflictTarget)
def compile_conflict_target(conflict_target, compiler, **kw):
- if not conflict_target.contents:
- return ''
- return "(" + (", ".join(compiler.preparer.format_column(i) for i in conflict_target.contents)) + ")"
+ target = conflict_target.contents
+ if isinstance(target, (PrimaryKeyConstraint, UniqueConstraint)):
+ fmt_cnst = None
+ if target.name is not None:
+ fmt_cnst = compiler.preparer.format_constraint(target)
+ if fmt_cnst is not None:
+ return "ON CONSTRAINT %s" % fmt_cnst
+ else:
+ return "(" + (", ".join(compiler.preparer.format_column(i) for i in target.columns.values())) + ")"
+ if isinstance(target, (str, ColumnClause)):
+ return "(" + compiler.preparer.format_column(target) + ")"
+ if isinstance(target, (list, tuple)):
+ return "(" + (", ".join(compiler.preparer.format_column(i) for i in target)) + ")"
+ if isinstance(target, Index):
+ # columns required first.
+ ops = target.dialect_options["postgresql"]["ops"]
+ text = "(%s)" \
+ % (
+ ', '.join([
+ compiler.process(
+ expr.self_group()
+ if not isinstance(expr, ColumnClause)
+ else expr,
+ include_table=False, literal_binds=True) +
+ (
+ (' ' + ops[expr.key])
+ if hasattr(expr, 'key')
+ and expr.key in ops else ''
+ )
+ for expr in target.expressions
+ ])
+ )
+
+ whereclause = target.dialect_options["postgresql"]["where"]
+
+ if whereclause is not None:
+ where_compiled = compiler.process(
+ whereclause, include_table=False,
+ literal_binds=True)
+ text += " WHERE " + where_compiled
+ return text
@compiles(DoUpdate)
def compile_do_update(do_update, compiler, **kw):
@@ -88,9 +159,8 @@ def compile_do_update(do_update, compiler, **kw):
@compiles(DoNothing)
def compile_do_nothing(do_nothing, compiler, **kw):
- compiled_cf = compiler.process(do_nothing.conflict_target)
- if compiled_cf:
- return "ON CONFLICT %s DO NOTHING" % compiled_cf
+ if do_nothing.conflict_target is not None:
+ return "ON CONFLICT %s DO NOTHING" % compiler.process(do_nothing.conflict_target)
else:
return "ON CONFLICT DO NOTHING"
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index b2d29c5b3..0b43f0a9e 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -117,6 +117,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
Column('description', String(128)),
)
+ unique_constr = schema.UniqueConstraint(table1.c.name, name='uq_name')
+ goofy_index = Index('goofy_index', table1.c.name, postgresql_where=table1.c.name > 'm')
i = insert(
table1,
values=dict(
@@ -139,6 +141,33 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
'(%(myid)s, %(name)s) ON CONFLICT (myid) '
'DO UPDATE SET name = excluded.name',
dialect=dialect)
+ i = insert(
+ table1, values=dict(name='foo'),
+ postgresql_on_conflict=DoUpdate(table1.primary_key).set_with_excluded('name')
+ )
+ self.assert_compile(i,
+ 'INSERT INTO mytable (myid, name) VALUES '
+ '(%(myid)s, %(name)s) ON CONFLICT (myid) '
+ 'DO UPDATE SET name = excluded.name',
+ dialect=dialect)
+ i = insert(
+ table1, values=dict(name='foo'),
+ postgresql_on_conflict=DoUpdate(unique_constr).set_with_excluded('myid')
+ )
+ self.assert_compile(i,
+ 'INSERT INTO mytable (myid, name) VALUES '
+ '(%(myid)s, %(name)s) ON CONFLICT ON CONSTRAINT uq_name '
+ 'DO UPDATE SET myid = excluded.myid',
+ dialect=dialect)
+ i = insert(
+ table1, values=dict(name='foo'),
+ postgresql_on_conflict=DoUpdate(goofy_index).set_with_excluded('name')
+ )
+ self.assert_compile(i,
+ 'INSERT INTO mytable (myid, name) VALUES '
+ "(%(myid)s, %(name)s) ON CONFLICT (name) WHERE name > 'm' "
+ 'DO UPDATE SET name = excluded.name',
+ dialect=dialect)
def test_insert_returning(self):
dialect = postgresql.dialect()
diff --git a/test/dialect/postgresql/test_on_conflict.py b/test/dialect/postgresql/test_on_conflict.py
index 07a0c1f01..c960ead88 100644
--- a/test/dialect/postgresql/test_on_conflict.py
+++ b/test/dialect/postgresql/test_on_conflict.py
@@ -67,3 +67,47 @@ class OnConflictTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiled
.execute().fetchall(), [(1, 'name4')])
finally:
users.drop()
+
+ @testing.only_if(
+ "postgresql >= 9.5", "requires ON CONFLICT clause support")
+ def test_on_conflict_do_update_exotic_targets(self):
+ meta = MetaData(testing.db)
+ users = Table(
+ 'users', meta,
+ Column('id', Integer, primary_key=True),
+ Column('name', String(50)),
+ Column('login_email', String(50)),
+ Column('lets_index_this', String(50)),
+ schema='test_schema')
+ unique_constraint = schema.UniqueConstraint(users.c.login_email, name='uq_login_email')
+ bogus_index = schema.Index('idx_special_ops', users.c.lets_index_this, postgresql_where=users.c.lets_index_this > 'm')
+ users.create()
+ try:
+ users.insert().execute(id=1, name='name1', login_email='name1@gmail.com', lets_index_this='not')
+ users.insert().execute(id=2, name='name2', login_email='name2@gmail.com', lets_index_this='not')
+ eq_(users.select().where(users.c.id == 1)
+ .execute().fetchall(), [(1, 'name1', 'name1@gmail.com', 'not')])
+
+ # try primary key constraint: cause an upsert on unique id column
+ poc = DoUpdate(users.primary_key).set_with_excluded(users.c.name, users.c.login_email)
+ users.insert(postgresql_on_conflict=poc).execute(id=1, name='name2', login_email='name1@gmail.com', lets_index_this='not')
+ eq_(users.select().where(users.c.id == 1)
+ .execute().fetchall(), [(1, 'name2', 'name1@gmail.com', 'not')])
+
+ # try unique constraint: cause an upsert on target login_email, not id
+ poc = DoUpdate(unique_constraint).set_with_excluded(users.c.id, users.c.name, users.c.login_email)
+ # note: lets_index_this value totally ignored in SET clause.
+ users.insert(postgresql_on_conflict=poc).execute(id=42, name='nameunique', login_email='name2@gmail.com', lets_index_this='unique')
+ eq_(users.select().where(users.c.login_email == 'name2@gmail.com')
+ .execute().fetchall(), [(42, 'nameunique', 'name2@gmail.com', 'not')])
+
+ # try bogus index
+ try:
+ users.insert(
+ postgresql_on_conflict=DoUpdate(bogus_index).set_with_excluded(users.c.name, users.c.login_email)
+ ).execute(id=1, name='namebogus', login_email='bogus@gmail.com', lets_index_this='bogus')
+ raise Exception("Using bogus index should have raised exception")
+ except exc.ProgrammingError:
+ pass # expected exception
+ finally:
+ users.drop()