diff options
| -rw-r--r-- | CHANGES | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/schema.py | 38 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 14 | ||||
| -rw-r--r-- | test/orm/unitofwork.py | 39 | ||||
| -rw-r--r-- | test/sql/constraints.py | 105 |
5 files changed, 178 insertions, 20 deletions
@@ -29,6 +29,8 @@ CHANGES - cast() accepts text('something') and other non-literal operands properly [ticket:962] + - Deferrable constraints can now be defined. + - added "autocommit=True" kwarg to select() and text(), as well as generative autocommit() method on select(); for statements which modify the database through some diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 64e9d203d..83f282b24 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -608,7 +608,7 @@ class ForeignKey(SchemaItem): constraint definition. """ - def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None, ondelete=None): + def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None, ondelete=None, deferrable=None, initially=None): """Construct a new ``ForeignKey`` object. column @@ -629,6 +629,8 @@ class ForeignKey(SchemaItem): self.name = name self.onupdate = onupdate self.ondelete = ondelete + self.deferrable = deferrable + self.initially = initially def __repr__(self): return "ForeignKey(%s)" % repr(self._get_colspec()) @@ -714,7 +716,7 @@ class ForeignKey(SchemaItem): self.parent.table.constraints.remove(fk.constraint) if self.constraint is None and isinstance(self.parent.table, Table): - self.constraint = ForeignKeyConstraint([],[], use_alter=self.use_alter, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete) + self.constraint = ForeignKeyConstraint([],[], use_alter=self.use_alter, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, deferrable=self.deferrable, initially=self.initially) self.parent.table.append_constraint(self.constraint) self.constraint._append_fk(self) @@ -855,9 +857,11 @@ class Constraint(SchemaItem): list of underying columns. """ - def __init__(self, name=None): + def __init__(self, name=None, deferrable=None, initially=None): self.name = name self.columns = expression.ColumnCollection() + self.deferrable = deferrable + self.initially = initially def __contains__(self, x): return self.columns.contains_column(x) @@ -878,8 +882,8 @@ class Constraint(SchemaItem): raise NotImplementedError() class CheckConstraint(Constraint): - def __init__(self, sqltext, name=None): - super(CheckConstraint, self).__init__(name) + def __init__(self, sqltext, name=None, deferrable=None, initially=None): + super(CheckConstraint, self).__init__(name, deferrable, initially) self.sqltext = sqltext def __visit_name__(self): @@ -899,8 +903,8 @@ class CheckConstraint(Constraint): class ForeignKeyConstraint(Constraint): """Table-level foreign key constraint, represents a collection of ``ForeignKey`` objects.""" - def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, use_alter=False): - super(ForeignKeyConstraint, self).__init__(name) + def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, use_alter=False, deferrable=None, initially=None): + super(ForeignKeyConstraint, self).__init__(name, deferrable, initially) self.__colnames = columns self.__refcolnames = refcolumns self.elements = util.OrderedSet() @@ -930,7 +934,15 @@ class ForeignKeyConstraint(Constraint): class PrimaryKeyConstraint(Constraint): def __init__(self, *columns, **kwargs): - super(PrimaryKeyConstraint, self).__init__(name=kwargs.pop('name', None)) + constraint_args = dict(name=kwargs.pop('name', None), + deferrable=kwargs.pop('deferrable', None), + initially=kwargs.pop('initially', None)) + if kwargs: + raise exceptions.ArgumentError( + 'Unknown PrimaryKeyConstraint argument(s): %s' % + ', '.join([repr(x) for x in kwargs.keys()])) + + super(PrimaryKeyConstraint, self).__init__(**constraint_args) self.__colnames = list(columns) def _set_parent(self, table): @@ -959,7 +971,15 @@ class PrimaryKeyConstraint(Constraint): class UniqueConstraint(Constraint): def __init__(self, *columns, **kwargs): - super(UniqueConstraint, self).__init__(name=kwargs.pop('name', None)) + constraint_args = dict(name=kwargs.pop('name', None), + deferrable=kwargs.pop('deferrable', None), + initially=kwargs.pop('initially', None)) + if kwargs: + raise exceptions.ArgumentError( + 'Unknown UniqueConstraint argument(s): %s' % + ', '.join([repr(x) for x in kwargs.keys()])) + + super(UniqueConstraint, self).__init__(**constraint_args) self.__colnames = list(columns) def _set_parent(self, table): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 43950a9a6..02f6efce1 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -844,9 +844,11 @@ class SchemaGenerator(DDLBase): self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) self.append(" CHECK (%s)" % constraint.sqltext) + self.define_constraint_deferrability(constraint) def visit_column_check_constraint(self, constraint): self.append(" CHECK (%s)" % constraint.sqltext) + self.define_constraint_deferrability(constraint) def visit_primary_key_constraint(self, constraint): if len(constraint) == 0: @@ -856,6 +858,7 @@ class SchemaGenerator(DDLBase): self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) self.append("PRIMARY KEY ") self.append("(%s)" % ', '.join([self.preparer.quote(c, c.name) for c in constraint])) + self.define_constraint_deferrability(constraint) def visit_foreign_key_constraint(self, constraint): if constraint.use_alter and self.dialect.supports_alter: @@ -883,6 +886,7 @@ class SchemaGenerator(DDLBase): self.append(" ON DELETE %s" % constraint.ondelete) if constraint.onupdate is not None: self.append(" ON UPDATE %s" % constraint.onupdate) + self.define_constraint_deferrability(constraint) def visit_unique_constraint(self, constraint): self.append(", \n\t") @@ -890,6 +894,16 @@ class SchemaGenerator(DDLBase): self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c, c.name) for c in constraint]))) + self.define_constraint_deferrability(constraint) + + def define_constraint_deferrability(self, constraint): + if constraint.deferrable is not None: + if constraint.deferrable: + self.append(" DEFERRABLE") + else: + self.append(" NOT DEFERRABLE") + if constraint.initially is not None: + self.append(" INITIALLY %s" % constraint.initially) def visit_column(self, column): pass diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index 22c8bbe8b..ee696cd9d 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -1969,9 +1969,12 @@ class RowSwitchTest(ORMTest): assert list(sess.execute(t2.select(), mapper=T1)) == [(1, 'some other t2', 2)] class TransactionTest(ORMTest): - """This is in fact a core test, but currently the only known way - to make COMMIT repeatably fail is on postgresql with deferrable FKs""" - __only_on__ = 'postgres' + __unsupported_on__ = ('mysql', 'mssql') + + # sqlite doesn't have deferrable constraints, but it allows them to + # be specified. it'll raise immediately post-INSERT, instead of at + # COMMIT. either way, this test should pass. + def define_tables(self, metadata): global t1, T1, t2, T2 @@ -1979,17 +1982,24 @@ class TransactionTest(ORMTest): t1 = Table('t1', metadata, Column('id', Integer, primary_key=True)) - + t2 = Table('t2', metadata, Column('id', Integer, primary_key=True), - Column('t1_id', Integer)) - deferred_constraint = DDL("ALTER TABLE t2 ADD CONSTRAINT t2_t1_id_fk FOREIGN KEY (t1_id) "\ - "REFERENCES t1 (id) DEFERRABLE INITIALLY DEFERRED") - deferred_constraint.execute_at('after-create', t2) - + Column('t1_id', Integer, + ForeignKey('t1.id', deferrable=True, initially='deferred') + )) + + # deferred_constraint = \ + # DDL("ALTER TABLE t2 ADD CONSTRAINT t2_t1_id_fk FOREIGN KEY (t1_id) " + # "REFERENCES t1 (id) DEFERRABLE INITIALLY DEFERRED") + # deferred_constraint.execute_at('after-create', t2) + # t1.create() + # t2.create() + # t2.append_constraint(ForeignKeyConstraint(['t1_id'], ['t1.id'])) + class T1(fixtures.Base): pass - + class T2(fixtures.Base): pass @@ -1999,8 +2009,11 @@ class TransactionTest(ORMTest): def test_close_transaction_on_commit_fail(self): Session = sessionmaker(autoflush=False, transactional=False) sess = Session() - + + # with a deferred constraint, this fails at COMMIT time instead + # of at INSERT time. sess.save(T2(t1_id=123)) + try: sess.flush() assert False @@ -2008,5 +2021,9 @@ class TransactionTest(ORMTest): # Flush needs to rollback also when commit fails assert sess.transaction is None + # todo: on 8.3 at least, the failed commit seems to close the cursor? + # needs investigation. leaving in the DDL above now to help verify + # that the new deferrable support on FK isn't involved in this issue. + t1.bind.engine.dispose() if __name__ == "__main__": testenv.main() diff --git a/test/sql/constraints.py b/test/sql/constraints.py index 142f1ffba..29fffa751 100644 --- a/test/sql/constraints.py +++ b/test/sql/constraints.py @@ -2,6 +2,7 @@ import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy import exceptions from testlib import * +from testlib import config, engines class ConstraintTest(AssertMixin): @@ -201,5 +202,109 @@ class ConstraintTest(AssertMixin): ss = events.select().execute().fetchall() +class ConstraintCompilationTest(AssertMixin): + class accum(object): + def __init__(self): + self.statements = [] + def __call__(self, sql, *a, **kw): + self.statements.append(sql) + def __contains__(self, substring): + for s in self.statements: + if substring in s: + return True + return False + def __str__(self): + return '\n'.join([repr(x) for x in self.statements]) + def clear(self): + del self.statements[:] + + def setUp(self): + self.sql = self.accum() + opts = config.db_opts.copy() + opts['strategy'] = 'mock' + opts['executor'] = self.sql + self.engine = engines.testing_engine(options=opts) + + + def _test_deferrable(self, constraint_factory): + meta = MetaData(self.engine) + t = Table('tbl', meta, + Column('a', Integer), + Column('b', Integer), + constraint_factory(deferrable=True)) + t.create() + assert 'DEFERRABLE' in self.sql, self.sql + assert 'NOT DEFERRABLE' not in self.sql, self.sql + self.sql.clear() + meta.clear() + + t = Table('tbl', meta, + Column('a', Integer), + Column('b', Integer), + constraint_factory(deferrable=False)) + t.create() + assert 'NOT DEFERRABLE' in self.sql + self.sql.clear() + meta.clear() + + t = Table('tbl', meta, + Column('a', Integer), + Column('b', Integer), + constraint_factory(deferrable=True, initially='IMMEDIATE')) + t.create() + assert 'NOT DEFERRABLE' not in self.sql + assert 'INITIALLY IMMEDIATE' in self.sql + self.sql.clear() + meta.clear() + + t = Table('tbl', meta, + Column('a', Integer), + Column('b', Integer), + constraint_factory(deferrable=True, initially='DEFERRED')) + t.create() + + assert 'NOT DEFERRABLE' not in self.sql + assert 'INITIALLY DEFERRED' in self.sql, self.sql + + def test_deferrable_pk(self): + factory = lambda **kw: PrimaryKeyConstraint('a', **kw) + self._test_deferrable(factory) + + def test_deferrable_table_fk(self): + factory = lambda **kw: ForeignKeyConstraint(['b'], ['tbl.a'], **kw) + self._test_deferrable(factory) + + def test_deferrable_column_fk(self): + meta = MetaData(self.engine) + t = Table('tbl', meta, + Column('a', Integer), + Column('b', Integer, + ForeignKey('tbl.a', deferrable=True, + initially='DEFERRED'))) + t.create() + assert 'DEFERRABLE' in self.sql, self.sql + assert 'INITIALLY DEFERRED' in self.sql, self.sql + + def test_deferrable_unique(self): + factory = lambda **kw: UniqueConstraint('b', **kw) + self._test_deferrable(factory) + + def test_deferrable_table_check(self): + factory = lambda **kw: CheckConstraint('a < b', **kw) + self._test_deferrable(factory) + + def test_deferrable_column_check(self): + meta = MetaData(self.engine) + t = Table('tbl', meta, + Column('a', Integer), + Column('b', Integer, + CheckConstraint('a < b', + deferrable=True, + initially='DEFERRED'))) + t.create() + assert 'DEFERRABLE' in self.sql, self.sql + assert 'INITIALLY DEFERRED' in self.sql, self.sql + + if __name__ == "__main__": testenv.main() |
