diff options
author | Robin Thomas <robin.thomas@livestream.com> | 2016-04-14 02:13:12 -0400 |
---|---|---|
committer | Robin Thomas <robin.thomas@livestream.com> | 2016-04-14 02:13:12 -0400 |
commit | 524b73d7ad2e4db62989ae54500babaa2c83b126 (patch) | |
tree | caec2e7857775308bc38fe97c655ff96fa7db3f5 | |
parent | 1e81462f070c873387d95c67310fb2dfc33a4e67 (diff) | |
download | sqlalchemy-524b73d7ad2e4db62989ae54500babaa2c83b126.tar.gz |
added ON CONFLICT support for UniqueConstratin, PrimaryKeyConstraint,
and Index objects as conflict targets.
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/on_conflict.py | 110 | ||||
-rw-r--r-- | test/dialect/postgresql/test_compiler.py | 29 | ||||
-rw-r--r-- | test/dialect/postgresql/test_on_conflict.py | 44 |
3 files changed, 163 insertions, 20 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/on_conflict.py b/lib/sqlalchemy/dialects/postgresql/on_conflict.py index e69ca582c..c3032e0cc 100644 --- a/lib/sqlalchemy/dialects/postgresql/on_conflict.py +++ b/lib/sqlalchemy/dialects/postgresql/on_conflict.py @@ -1,6 +1,9 @@ from ...sql.expression import ClauseElement, ColumnClause, ColumnElement from ...ext.compiler import compiles from ...exc import CompileError +from ...schema import UniqueConstraint, PrimaryKeyConstraint, Index + +from collections import Iterable __all__ = ('DoUpdate', 'DoNothing') @@ -24,44 +27,112 @@ def resolve_on_conflict_option(option_value, crud_columns): if str(option_value) == 'nothing': return DoNothing() -def resolve_columnish_arg(arg): - for col in (arg if isinstance(arg, (list, tuple)) else (arg,)): - if not isinstance(col, (ColumnClause, str)): - raise ValueError("column arguments must be ColumnClause objects or str object with column name: %r" % col) - return tuple(arg) if isinstance(arg, (list, tuple)) else (arg,) - class OnConflictAction(ClauseElement): def __init__(self, conflict_target): super(OnConflictAction, self).__init__() - if not isinstance(conflict_target, ConflictTarget): - conflict_target = ConflictTarget(conflict_target) self.conflict_target = conflict_target class DoUpdate(OnConflictAction): def __init__(self, conflict_target): - super(DoUpdate, self).__init__(conflict_target) + super(DoUpdate, self).__init__(ConflictTarget(conflict_target)) if not self.conflict_target.contents: raise ValueError("conflict_target may not be None or empty for DoUpdate") self.values_to_set = {} def set_with_excluded(self, *columns): - for col in resolve_columnish_arg(columns): - self.values_to_set[col] = _EXCLUDED + for col in columns: + if not isinstance(col, (ColumnClause, str)): + raise ValueError("column arguments must be ColumnClause objects or str object with column name: %r" % col) + self.values_to_set[col] = _EXCLUDED return self class DoNothing(OnConflictAction): - def __init__(self, conflict_target=[]): - super(DoNothing, self).__init__(conflict_target) + def __init__(self, conflict_target=None): + super(DoNothing, self).__init__(ConflictTarget(conflict_target) if conflict_target else None) class ConflictTarget(ClauseElement): + """ + A ConflictTarget represents the targeted constraint that will be used to determine + when a row proposed for insertion is in conflict and should be handled as specified + in the OnConflictAction. + + A target can be one of the following: + + - A column or list of columns, either column objects or strings, that together + represent a unique or primary key constraint on the table. The compiler + will produce a list like `(col1, col2, col3)` as the conflict target SQL clause. + + - A single PrimaryKeyConstraint or UniqueConstraint object representing the constraint + used to detect the conflict. If the object has a :attr:`.name` attribute, + the compiler will produce `ON CONSTRAINT constraint_name` as the conflict target + SQL clause. If the constraint lacks a `.name` attribute, a list of its + constituent columns, like `(col1, col2, col3)` will be used. + + - An single :class:`Index` object representing the index used to detect the conflict. + Use this in place of the Constraint objects mentioned above if you require + the clauses of a conflict target specific to index definitions -- collation, + opclass used to detect conflict, and WHERE clauses for partial indexes. + """ def __init__(self, contents): - self.contents = resolve_columnish_arg(contents) + if isinstance(contents, (str, ColumnClause)): + self.contents = (contents,) + elif isinstance(contents, (list, tuple)): + if not contents: + raise ValueError("list of column arguments cannot be empty") + for c in contents: + if not isinstance(c, (str, ColumnClause)): + raise ValueError("column arguments must be ColumnClause objects or str object with column name: %r" % c) + self.contents = tuple(contents) + elif isinstance(contents, (PrimaryKeyConstraint, UniqueConstraint, Index)): + self.contents = contents + else: + raise ValueError( + "ConflictTarget contents must be single Column/str, " + "sequence of Column/str; or a PrimaryKeyConsraint, UniqueConstraint, or Index") @compiles(ConflictTarget) def compile_conflict_target(conflict_target, compiler, **kw): - if not conflict_target.contents: - return '' - return "(" + (", ".join(compiler.preparer.format_column(i) for i in conflict_target.contents)) + ")" + target = conflict_target.contents + if isinstance(target, (PrimaryKeyConstraint, UniqueConstraint)): + fmt_cnst = None + if target.name is not None: + fmt_cnst = compiler.preparer.format_constraint(target) + if fmt_cnst is not None: + return "ON CONSTRAINT %s" % fmt_cnst + else: + return "(" + (", ".join(compiler.preparer.format_column(i) for i in target.columns.values())) + ")" + if isinstance(target, (str, ColumnClause)): + return "(" + compiler.preparer.format_column(target) + ")" + if isinstance(target, (list, tuple)): + return "(" + (", ".join(compiler.preparer.format_column(i) for i in target)) + ")" + if isinstance(target, Index): + # columns required first. + ops = target.dialect_options["postgresql"]["ops"] + text = "(%s)" \ + % ( + ', '.join([ + compiler.process( + expr.self_group() + if not isinstance(expr, ColumnClause) + else expr, + include_table=False, literal_binds=True) + + ( + (' ' + ops[expr.key]) + if hasattr(expr, 'key') + and expr.key in ops else '' + ) + for expr in target.expressions + ]) + ) + + whereclause = target.dialect_options["postgresql"]["where"] + + if whereclause is not None: + where_compiled = compiler.process( + whereclause, include_table=False, + literal_binds=True) + text += " WHERE " + where_compiled + return text @compiles(DoUpdate) def compile_do_update(do_update, compiler, **kw): @@ -88,9 +159,8 @@ def compile_do_update(do_update, compiler, **kw): @compiles(DoNothing) def compile_do_nothing(do_nothing, compiler, **kw): - compiled_cf = compiler.process(do_nothing.conflict_target) - if compiled_cf: - return "ON CONFLICT %s DO NOTHING" % compiled_cf + if do_nothing.conflict_target is not None: + return "ON CONFLICT %s DO NOTHING" % compiler.process(do_nothing.conflict_target) else: return "ON CONFLICT DO NOTHING" diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index b2d29c5b3..0b43f0a9e 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -117,6 +117,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): Column('description', String(128)), ) + unique_constr = schema.UniqueConstraint(table1.c.name, name='uq_name') + goofy_index = Index('goofy_index', table1.c.name, postgresql_where=table1.c.name > 'm') i = insert( table1, values=dict( @@ -139,6 +141,33 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): '(%(myid)s, %(name)s) ON CONFLICT (myid) ' 'DO UPDATE SET name = excluded.name', dialect=dialect) + i = insert( + table1, values=dict(name='foo'), + postgresql_on_conflict=DoUpdate(table1.primary_key).set_with_excluded('name') + ) + self.assert_compile(i, + 'INSERT INTO mytable (myid, name) VALUES ' + '(%(myid)s, %(name)s) ON CONFLICT (myid) ' + 'DO UPDATE SET name = excluded.name', + dialect=dialect) + i = insert( + table1, values=dict(name='foo'), + postgresql_on_conflict=DoUpdate(unique_constr).set_with_excluded('myid') + ) + self.assert_compile(i, + 'INSERT INTO mytable (myid, name) VALUES ' + '(%(myid)s, %(name)s) ON CONFLICT ON CONSTRAINT uq_name ' + 'DO UPDATE SET myid = excluded.myid', + dialect=dialect) + i = insert( + table1, values=dict(name='foo'), + postgresql_on_conflict=DoUpdate(goofy_index).set_with_excluded('name') + ) + self.assert_compile(i, + 'INSERT INTO mytable (myid, name) VALUES ' + "(%(myid)s, %(name)s) ON CONFLICT (name) WHERE name > 'm' " + 'DO UPDATE SET name = excluded.name', + dialect=dialect) def test_insert_returning(self): dialect = postgresql.dialect() diff --git a/test/dialect/postgresql/test_on_conflict.py b/test/dialect/postgresql/test_on_conflict.py index 07a0c1f01..c960ead88 100644 --- a/test/dialect/postgresql/test_on_conflict.py +++ b/test/dialect/postgresql/test_on_conflict.py @@ -67,3 +67,47 @@ class OnConflictTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiled .execute().fetchall(), [(1, 'name4')]) finally: users.drop() + + @testing.only_if( + "postgresql >= 9.5", "requires ON CONFLICT clause support") + def test_on_conflict_do_update_exotic_targets(self): + meta = MetaData(testing.db) + users = Table( + 'users', meta, + Column('id', Integer, primary_key=True), + Column('name', String(50)), + Column('login_email', String(50)), + Column('lets_index_this', String(50)), + schema='test_schema') + unique_constraint = schema.UniqueConstraint(users.c.login_email, name='uq_login_email') + bogus_index = schema.Index('idx_special_ops', users.c.lets_index_this, postgresql_where=users.c.lets_index_this > 'm') + users.create() + try: + users.insert().execute(id=1, name='name1', login_email='name1@gmail.com', lets_index_this='not') + users.insert().execute(id=2, name='name2', login_email='name2@gmail.com', lets_index_this='not') + eq_(users.select().where(users.c.id == 1) + .execute().fetchall(), [(1, 'name1', 'name1@gmail.com', 'not')]) + + # try primary key constraint: cause an upsert on unique id column + poc = DoUpdate(users.primary_key).set_with_excluded(users.c.name, users.c.login_email) + users.insert(postgresql_on_conflict=poc).execute(id=1, name='name2', login_email='name1@gmail.com', lets_index_this='not') + eq_(users.select().where(users.c.id == 1) + .execute().fetchall(), [(1, 'name2', 'name1@gmail.com', 'not')]) + + # try unique constraint: cause an upsert on target login_email, not id + poc = DoUpdate(unique_constraint).set_with_excluded(users.c.id, users.c.name, users.c.login_email) + # note: lets_index_this value totally ignored in SET clause. + users.insert(postgresql_on_conflict=poc).execute(id=42, name='nameunique', login_email='name2@gmail.com', lets_index_this='unique') + eq_(users.select().where(users.c.login_email == 'name2@gmail.com') + .execute().fetchall(), [(42, 'nameunique', 'name2@gmail.com', 'not')]) + + # try bogus index + try: + users.insert( + postgresql_on_conflict=DoUpdate(bogus_index).set_with_excluded(users.c.name, users.c.login_email) + ).execute(id=1, name='namebogus', login_email='bogus@gmail.com', lets_index_this='bogus') + raise Exception("Using bogus index should have raised exception") + except exc.ProgrammingError: + pass # expected exception + finally: + users.drop() |