diff options
Diffstat (limited to 'test/dialect/postgresql')
| -rw-r--r-- | test/dialect/postgresql/test_dialect.py | 291 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_on_conflict.py | 894 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_query.py | 220 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_reflection.py | 180 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_types.py | 2 |
5 files changed, 796 insertions, 791 deletions
diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 5cea604d6..3bd8e9da0 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -36,6 +36,7 @@ from sqlalchemy.dialects.postgresql.psycopg2 import EXECUTEMANY_VALUES from sqlalchemy.engine import cursor as _cursor from sqlalchemy.engine import engine_from_config from sqlalchemy.engine import url +from sqlalchemy.testing import config from sqlalchemy.testing import engines from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ @@ -51,7 +52,7 @@ from sqlalchemy.testing.assertions import eq_regex from sqlalchemy.testing.assertions import ne_ from sqlalchemy.util import u from sqlalchemy.util import ue -from ...engine import test_execute +from ...engine import test_deprecations if True: from sqlalchemy.dialects.postgresql.psycopg2 import ( @@ -195,6 +196,20 @@ class ExecuteManyMode(object): options = None + @config.fixture() + def connection(self): + eng = engines.testing_engine(options=self.options) + + conn = eng.connect() + trans = conn.begin() + try: + yield conn + finally: + if trans.is_active: + trans.rollback() + conn.close() + eng.dispose() + @classmethod def define_tables(cls, metadata): Table( @@ -213,20 +228,12 @@ class ExecuteManyMode(object): Column(ue("\u6e2c\u8a66"), Integer), ) - def setup(self): - super(ExecuteManyMode, self).setup() - self.engine = engines.testing_engine(options=self.options) - - def teardown(self): - self.engine.dispose() - super(ExecuteManyMode, self).teardown() - - def test_insert(self): + def test_insert(self, connection): from psycopg2 import extras - values_page_size = self.engine.dialect.executemany_values_page_size - batch_page_size = self.engine.dialect.executemany_batch_page_size - if self.engine.dialect.executemany_mode & EXECUTEMANY_VALUES: + values_page_size = connection.dialect.executemany_values_page_size + batch_page_size = connection.dialect.executemany_batch_page_size + if connection.dialect.executemany_mode & EXECUTEMANY_VALUES: meth = extras.execute_values stmt = "INSERT INTO data (x, y) VALUES %s" expected_kwargs = { @@ -234,7 +241,7 @@ class ExecuteManyMode(object): "page_size": values_page_size, "fetch": False, } - elif self.engine.dialect.executemany_mode & EXECUTEMANY_BATCH: + elif connection.dialect.executemany_mode & EXECUTEMANY_BATCH: meth = extras.execute_batch stmt = "INSERT INTO data (x, y) VALUES (%(x)s, %(y)s)" expected_kwargs = {"page_size": batch_page_size} @@ -244,24 +251,23 @@ class ExecuteManyMode(object): with mock.patch.object( extras, meth.__name__, side_effect=meth ) as mock_exec: - with self.engine.connect() as conn: - conn.execute( - self.tables.data.insert(), - [ - {"x": "x1", "y": "y1"}, - {"x": "x2", "y": "y2"}, - {"x": "x3", "y": "y3"}, - ], - ) + connection.execute( + self.tables.data.insert(), + [ + {"x": "x1", "y": "y1"}, + {"x": "x2", "y": "y2"}, + {"x": "x3", "y": "y3"}, + ], + ) - eq_( - conn.execute(select(self.tables.data)).fetchall(), - [ - (1, "x1", "y1", 5), - (2, "x2", "y2", 5), - (3, "x3", "y3", 5), - ], - ) + eq_( + connection.execute(select(self.tables.data)).fetchall(), + [ + (1, "x1", "y1", 5), + (2, "x2", "y2", 5), + (3, "x3", "y3", 5), + ], + ) eq_( mock_exec.mock_calls, [ @@ -278,14 +284,13 @@ class ExecuteManyMode(object): ], ) - def test_insert_no_page_size(self): + def test_insert_no_page_size(self, connection): from psycopg2 import extras - values_page_size = self.engine.dialect.executemany_values_page_size - batch_page_size = self.engine.dialect.executemany_batch_page_size + values_page_size = connection.dialect.executemany_values_page_size + batch_page_size = connection.dialect.executemany_batch_page_size - eng = self.engine - if eng.dialect.executemany_mode & EXECUTEMANY_VALUES: + if connection.dialect.executemany_mode & EXECUTEMANY_VALUES: meth = extras.execute_values stmt = "INSERT INTO data (x, y) VALUES %s" expected_kwargs = { @@ -293,7 +298,7 @@ class ExecuteManyMode(object): "page_size": values_page_size, "fetch": False, } - elif eng.dialect.executemany_mode & EXECUTEMANY_BATCH: + elif connection.dialect.executemany_mode & EXECUTEMANY_BATCH: meth = extras.execute_batch stmt = "INSERT INTO data (x, y) VALUES (%(x)s, %(y)s)" expected_kwargs = {"page_size": batch_page_size} @@ -303,15 +308,14 @@ class ExecuteManyMode(object): with mock.patch.object( extras, meth.__name__, side_effect=meth ) as mock_exec: - with eng.connect() as conn: - conn.execute( - self.tables.data.insert(), - [ - {"x": "x1", "y": "y1"}, - {"x": "x2", "y": "y2"}, - {"x": "x3", "y": "y3"}, - ], - ) + connection.execute( + self.tables.data.insert(), + [ + {"x": "x1", "y": "y1"}, + {"x": "x2", "y": "y2"}, + {"x": "x3", "y": "y3"}, + ], + ) eq_( mock_exec.mock_calls, @@ -356,7 +360,7 @@ class ExecuteManyMode(object): with mock.patch.object( extras, meth.__name__, side_effect=meth ) as mock_exec: - with eng.connect() as conn: + with eng.begin() as conn: conn.execute( self.tables.data.insert(), [ @@ -398,11 +402,10 @@ class ExecuteManyMode(object): eq_(connection.execute(table.select()).all(), [(1, 1), (2, 2), (3, 3)]) - def test_update_fallback(self): + def test_update_fallback(self, connection): from psycopg2 import extras - batch_page_size = self.engine.dialect.executemany_batch_page_size - eng = self.engine + batch_page_size = connection.dialect.executemany_batch_page_size meth = extras.execute_batch stmt = "UPDATE data SET y=%(yval)s WHERE data.x = %(xval)s" expected_kwargs = {"page_size": batch_page_size} @@ -410,18 +413,17 @@ class ExecuteManyMode(object): with mock.patch.object( extras, meth.__name__, side_effect=meth ) as mock_exec: - with eng.connect() as conn: - conn.execute( - self.tables.data.update() - .where(self.tables.data.c.x == bindparam("xval")) - .values(y=bindparam("yval")), - [ - {"xval": "x1", "yval": "y5"}, - {"xval": "x3", "yval": "y6"}, - ], - ) + connection.execute( + self.tables.data.update() + .where(self.tables.data.c.x == bindparam("xval")) + .values(y=bindparam("yval")), + [ + {"xval": "x1", "yval": "y5"}, + {"xval": "x3", "yval": "y6"}, + ], + ) - if eng.dialect.executemany_mode & EXECUTEMANY_BATCH: + if connection.dialect.executemany_mode & EXECUTEMANY_BATCH: eq_( mock_exec.mock_calls, [ @@ -439,36 +441,34 @@ class ExecuteManyMode(object): else: eq_(mock_exec.mock_calls, []) - def test_not_sane_rowcount(self): - self.engine.connect().close() - if self.engine.dialect.executemany_mode & EXECUTEMANY_BATCH: - assert not self.engine.dialect.supports_sane_multi_rowcount + def test_not_sane_rowcount(self, connection): + if connection.dialect.executemany_mode & EXECUTEMANY_BATCH: + assert not connection.dialect.supports_sane_multi_rowcount else: - assert self.engine.dialect.supports_sane_multi_rowcount + assert connection.dialect.supports_sane_multi_rowcount - def test_update(self): - with self.engine.connect() as conn: - conn.execute( - self.tables.data.insert(), - [ - {"x": "x1", "y": "y1"}, - {"x": "x2", "y": "y2"}, - {"x": "x3", "y": "y3"}, - ], - ) + def test_update(self, connection): + connection.execute( + self.tables.data.insert(), + [ + {"x": "x1", "y": "y1"}, + {"x": "x2", "y": "y2"}, + {"x": "x3", "y": "y3"}, + ], + ) - conn.execute( - self.tables.data.update() - .where(self.tables.data.c.x == bindparam("xval")) - .values(y=bindparam("yval")), - [{"xval": "x1", "yval": "y5"}, {"xval": "x3", "yval": "y6"}], - ) - eq_( - conn.execute( - select(self.tables.data).order_by(self.tables.data.c.id) - ).fetchall(), - [(1, "x1", "y5", 5), (2, "x2", "y2", 5), (3, "x3", "y6", 5)], - ) + connection.execute( + self.tables.data.update() + .where(self.tables.data.c.x == bindparam("xval")) + .values(y=bindparam("yval")), + [{"xval": "x1", "yval": "y5"}, {"xval": "x3", "yval": "y6"}], + ) + eq_( + connection.execute( + select(self.tables.data).order_by(self.tables.data.c.id) + ).fetchall(), + [(1, "x1", "y5", 5), (2, "x2", "y2", 5), (3, "x3", "y6", 5)], + ) class ExecutemanyBatchModeTest(ExecuteManyMode, fixtures.TablesTest): @@ -578,7 +578,7 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest): [(pk,) for pk in range(1 + first_pk, total_rows + first_pk)], ) - def test_insert_w_newlines(self): + def test_insert_w_newlines(self, connection): from psycopg2 import extras t = self.tables.data @@ -606,15 +606,14 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest): extras, "execute_values", side_effect=meth ) as mock_exec: - with self.engine.connect() as conn: - conn.execute( - ins, - [ - {"id": 1, "y": "y1", "z": 1}, - {"id": 2, "y": "y2", "z": 2}, - {"id": 3, "y": "y3", "z": 3}, - ], - ) + connection.execute( + ins, + [ + {"id": 1, "y": "y1", "z": 1}, + {"id": 2, "y": "y2", "z": 2}, + {"id": 3, "y": "y3", "z": 3}, + ], + ) eq_( mock_exec.mock_calls, @@ -629,12 +628,12 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest): ), template="(%(id)s, (SELECT 5 \nFROM data), %(y)s, %(z)s)", fetch=False, - page_size=conn.dialect.executemany_values_page_size, + page_size=connection.dialect.executemany_values_page_size, ) ], ) - def test_insert_modified_by_event(self): + def test_insert_modified_by_event(self, connection): from psycopg2 import extras t = self.tables.data @@ -664,33 +663,33 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest): extras, "execute_batch", side_effect=meth ) as mock_batch: - with self.engine.connect() as conn: - - # create an event hook that will change the statement to - # something else, meaning the dialect has to detect that - # insert_single_values_expr is no longer useful - @event.listens_for(conn, "before_cursor_execute", retval=True) - def before_cursor_execute( - conn, cursor, statement, parameters, context, executemany - ): - statement = ( - "INSERT INTO data (id, y, z) VALUES " - "(%(id)s, %(y)s, %(z)s)" - ) - return statement, parameters - - conn.execute( - ins, - [ - {"id": 1, "y": "y1", "z": 1}, - {"id": 2, "y": "y2", "z": 2}, - {"id": 3, "y": "y3", "z": 3}, - ], + # create an event hook that will change the statement to + # something else, meaning the dialect has to detect that + # insert_single_values_expr is no longer useful + @event.listens_for( + connection, "before_cursor_execute", retval=True + ) + def before_cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): + statement = ( + "INSERT INTO data (id, y, z) VALUES " + "(%(id)s, %(y)s, %(z)s)" ) + return statement, parameters + + connection.execute( + ins, + [ + {"id": 1, "y": "y1", "z": 1}, + {"id": 2, "y": "y2", "z": 2}, + {"id": 3, "y": "y3", "z": 3}, + ], + ) eq_(mock_values.mock_calls, []) - if self.engine.dialect.executemany_mode & EXECUTEMANY_BATCH: + if connection.dialect.executemany_mode & EXECUTEMANY_BATCH: eq_( mock_batch.mock_calls, [ @@ -727,10 +726,10 @@ class ExecutemanyFlagOptionsTest(fixtures.TablesTest): ("values_only", EXECUTEMANY_VALUES), ("values_plus_batch", EXECUTEMANY_VALUES_PLUS_BATCH), ]: - self.engine = engines.testing_engine( + connection = engines.testing_engine( options={"executemany_mode": opt} ) - is_(self.engine.dialect.executemany_mode, expected) + is_(connection.dialect.executemany_mode, expected) def test_executemany_wrong_flag_options(self): for opt in [1, True, "batch_insert"]: @@ -1082,7 +1081,7 @@ $$ LANGUAGE plpgsql; t.create(connection, checkfirst=True) @testing.provide_metadata - def test_schema_roundtrips(self): + def test_schema_roundtrips(self, connection): meta = self.metadata users = Table( "users", @@ -1091,33 +1090,37 @@ $$ LANGUAGE plpgsql; Column("name", String(50)), schema="test_schema", ) - users.create() - users.insert().execute(id=1, name="name1") - users.insert().execute(id=2, name="name2") - users.insert().execute(id=3, name="name3") - users.insert().execute(id=4, name="name4") + users.create(connection) + connection.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=2, name="name2")) + connection.execute(users.insert(), dict(id=3, name="name3")) + connection.execute(users.insert(), dict(id=4, name="name4")) eq_( - users.select().where(users.c.name == "name2").execute().fetchall(), + connection.execute( + users.select().where(users.c.name == "name2") + ).fetchall(), [(2, "name2")], ) eq_( - users.select(use_labels=True) - .where(users.c.name == "name2") - .execute() - .fetchall(), + connection.execute( + users.select().apply_labels().where(users.c.name == "name2") + ).fetchall(), [(2, "name2")], ) - users.delete().where(users.c.id == 3).execute() + connection.execute(users.delete().where(users.c.id == 3)) eq_( - users.select().where(users.c.name == "name3").execute().fetchall(), + connection.execute( + users.select().where(users.c.name == "name3") + ).fetchall(), [], ) - users.update().where(users.c.name == "name4").execute(name="newname") + connection.execute( + users.update().where(users.c.name == "name4"), dict(name="newname") + ) eq_( - users.select(use_labels=True) - .where(users.c.id == 4) - .execute() - .fetchall(), + connection.execute( + users.select().apply_labels().where(users.c.id == 4) + ).fetchall(), [(4, "newname")], ) @@ -1233,7 +1236,7 @@ $$ LANGUAGE plpgsql; ne_(conn.connection.status, STATUS_IN_TRANSACTION) -class AutocommitTextTest(test_execute.AutocommitTextTest): +class AutocommitTextTest(test_deprecations.AutocommitTextTest): __only_on__ = "postgresql" def test_grant(self): diff --git a/test/dialect/postgresql/test_on_conflict.py b/test/dialect/postgresql/test_on_conflict.py index 760487842..4e96cc6a2 100644 --- a/test/dialect/postgresql/test_on_conflict.py +++ b/test/dialect/postgresql/test_on_conflict.py @@ -99,28 +99,29 @@ class OnConflictTest(fixtures.TablesTest): ValueError, insert(self.tables.users).on_conflict_do_update ) - def test_on_conflict_do_nothing(self): + def test_on_conflict_do_nothing(self, connection): users = self.tables.users - with testing.db.connect() as conn: - result = conn.execute( - insert(users).on_conflict_do_nothing(), - dict(id=1, name="name1"), - ) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) - - result = conn.execute( - insert(users).on_conflict_do_nothing(), - dict(id=1, name="name2"), - ) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) - - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name1")], - ) + result = connection.execute( + insert(users).on_conflict_do_nothing(), + dict(id=1, name="name1"), + ) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) + + result = connection.execute( + insert(users).on_conflict_do_nothing(), + dict(id=1, name="name2"), + ) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) + + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name1")], + ) def test_on_conflict_do_nothing_connectionless(self, connection): users = self.tables.users_xtra @@ -147,95 +148,99 @@ class OnConflictTest(fixtures.TablesTest): ) @testing.provide_metadata - def test_on_conflict_do_nothing_target(self): + def test_on_conflict_do_nothing_target(self, connection): users = self.tables.users - with testing.db.connect() as conn: - result = conn.execute( - insert(users).on_conflict_do_nothing( - index_elements=users.primary_key.columns - ), - dict(id=1, name="name1"), - ) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) - - result = conn.execute( - insert(users).on_conflict_do_nothing( - index_elements=users.primary_key.columns - ), - dict(id=1, name="name2"), - ) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) - - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name1")], - ) - - def test_on_conflict_do_update_one(self): + result = connection.execute( + insert(users).on_conflict_do_nothing( + index_elements=users.primary_key.columns + ), + dict(id=1, name="name1"), + ) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) + + result = connection.execute( + insert(users).on_conflict_do_nothing( + index_elements=users.primary_key.columns + ), + dict(id=1, name="name2"), + ) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) + + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name1")], + ) + + def test_on_conflict_do_update_one(self, connection): users = self.tables.users - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=1, name="name1")) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=[users.c.id], set_=dict(name=i.excluded.name) - ) - result = conn.execute(i, dict(id=1, name="name1")) + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.id], set_=dict(name=i.excluded.name) + ) + result = connection.execute(i, dict(id=1, name="name1")) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name1")], - ) + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name1")], + ) - def test_on_conflict_do_update_schema(self): + def test_on_conflict_do_update_schema(self, connection): users = self.tables.get("%s.users_schema" % config.test_schema) - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=1, name="name1")) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=[users.c.id], set_=dict(name=i.excluded.name) - ) - result = conn.execute(i, dict(id=1, name="name1")) + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.id], set_=dict(name=i.excluded.name) + ) + result = connection.execute(i, dict(id=1, name="name1")) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name1")], - ) + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name1")], + ) - def test_on_conflict_do_update_column_as_key_set(self): + def test_on_conflict_do_update_column_as_key_set(self, connection): users = self.tables.users - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=1, name="name1")) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=[users.c.id], - set_={users.c.name: i.excluded.name}, - ) - result = conn.execute(i, dict(id=1, name="name1")) + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.id], + set_={users.c.name: i.excluded.name}, + ) + result = connection.execute(i, dict(id=1, name="name1")) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name1")], - ) + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name1")], + ) - def test_on_conflict_do_update_clauseelem_as_key_set(self): + def test_on_conflict_do_update_clauseelem_as_key_set(self, connection): users = self.tables.users class MyElem(object): @@ -245,162 +250,165 @@ class OnConflictTest(fixtures.TablesTest): def __clause_element__(self): return self.expr - with testing.db.connect() as conn: - conn.execute( - users.insert(), - {"id": 1, "name": "name1"}, - ) + connection.execute( + users.insert(), + {"id": 1, "name": "name1"}, + ) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=[users.c.id], - set_={MyElem(users.c.name): i.excluded.name}, - ).values({MyElem(users.c.id): 1, MyElem(users.c.name): "name1"}) - result = conn.execute(i) + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.id], + set_={MyElem(users.c.name): i.excluded.name}, + ).values({MyElem(users.c.id): 1, MyElem(users.c.name): "name1"}) + result = connection.execute(i) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name1")], - ) + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name1")], + ) - def test_on_conflict_do_update_column_as_key_set_schema(self): + def test_on_conflict_do_update_column_as_key_set_schema(self, connection): users = self.tables.get("%s.users_schema" % config.test_schema) - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=1, name="name1")) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=[users.c.id], - set_={users.c.name: i.excluded.name}, - ) - result = conn.execute(i, dict(id=1, name="name1")) + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.id], + set_={users.c.name: i.excluded.name}, + ) + result = connection.execute(i, dict(id=1, name="name1")) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name1")], - ) + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name1")], + ) - def test_on_conflict_do_update_two(self): + def test_on_conflict_do_update_two(self, connection): users = self.tables.users - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=1, name="name1")) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=[users.c.id], - set_=dict(id=i.excluded.id, name=i.excluded.name), - ) + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.id], + set_=dict(id=i.excluded.id, name=i.excluded.name), + ) - result = conn.execute(i, dict(id=1, name="name2")) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) + result = connection.execute(i, dict(id=1, name="name2")) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name2")], - ) + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name2")], + ) - def test_on_conflict_do_update_three(self): + def test_on_conflict_do_update_three(self, connection): users = self.tables.users - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=1, name="name1")) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=users.primary_key.columns, - set_=dict(name=i.excluded.name), - ) - result = conn.execute(i, dict(id=1, name="name3")) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) + i = insert(users) + i = i.on_conflict_do_update( + index_elements=users.primary_key.columns, + set_=dict(name=i.excluded.name), + ) + result = connection.execute(i, dict(id=1, name="name3")) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name3")], - ) + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name3")], + ) - def test_on_conflict_do_update_four(self): + def test_on_conflict_do_update_four(self, connection): users = self.tables.users - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=1, name="name1")) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=users.primary_key.columns, - set_=dict(id=i.excluded.id, name=i.excluded.name), - ).values(id=1, name="name4") + i = insert(users) + i = i.on_conflict_do_update( + index_elements=users.primary_key.columns, + set_=dict(id=i.excluded.id, name=i.excluded.name), + ).values(id=1, name="name4") - result = conn.execute(i) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) + result = connection.execute(i) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name4")], - ) + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name4")], + ) - def test_on_conflict_do_update_five(self): + def test_on_conflict_do_update_five(self, connection): users = self.tables.users - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=1, name="name1")) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=users.primary_key.columns, - set_=dict(id=10, name="I'm a name"), - ).values(id=1, name="name4") + i = insert(users) + i = i.on_conflict_do_update( + index_elements=users.primary_key.columns, + set_=dict(id=10, name="I'm a name"), + ).values(id=1, name="name4") - result = conn.execute(i) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) + result = connection.execute(i) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) - eq_( - conn.execute( - users.select().where(users.c.id == 10) - ).fetchall(), - [(10, "I'm a name")], - ) + eq_( + connection.execute( + users.select().where(users.c.id == 10) + ).fetchall(), + [(10, "I'm a name")], + ) - def test_on_conflict_do_update_multivalues(self): + def test_on_conflict_do_update_multivalues(self, connection): users = self.tables.users - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name="name1")) - conn.execute(users.insert(), dict(id=2, name="name2")) - - i = insert(users) - i = i.on_conflict_do_update( - index_elements=users.primary_key.columns, - set_=dict(name="updated"), - where=(i.excluded.name != "name12"), - ).values( - [ - dict(id=1, name="name11"), - dict(id=2, name="name12"), - dict(id=3, name="name13"), - dict(id=4, name="name14"), - ] - ) - - result = conn.execute(i) - eq_(result.inserted_primary_key, (None,)) - eq_(result.returned_defaults, None) - - eq_( - conn.execute(users.select().order_by(users.c.id)).fetchall(), - [(1, "updated"), (2, "name2"), (3, "name13"), (4, "name14")], - ) + connection.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=2, name="name2")) + + i = insert(users) + i = i.on_conflict_do_update( + index_elements=users.primary_key.columns, + set_=dict(name="updated"), + where=(i.excluded.name != "name12"), + ).values( + [ + dict(id=1, name="name11"), + dict(id=2, name="name12"), + dict(id=3, name="name13"), + dict(id=4, name="name14"), + ] + ) + + result = connection.execute(i) + eq_(result.inserted_primary_key, (None,)) + eq_(result.returned_defaults, None) + + eq_( + connection.execute(users.select().order_by(users.c.id)).fetchall(), + [(1, "updated"), (2, "name2"), (3, "name13"), (4, "name14")], + ) def _exotic_targets_fixture(self, conn): users = self.tables.users_xtra @@ -429,260 +437,250 @@ class OnConflictTest(fixtures.TablesTest): [(1, "name1", "name1@gmail.com", "not")], ) - def test_on_conflict_do_update_exotic_targets_two(self): + def test_on_conflict_do_update_exotic_targets_two(self, connection): users = self.tables.users_xtra - with testing.db.connect() as conn: - self._exotic_targets_fixture(conn) - # try primary key constraint: cause an upsert on unique id column - i = insert(users) - i = i.on_conflict_do_update( - index_elements=users.primary_key.columns, - set_=dict( - name=i.excluded.name, login_email=i.excluded.login_email - ), - ) - result = conn.execute( - i, - dict( - id=1, - name="name2", - login_email="name1@gmail.com", - lets_index_this="not", - ), - ) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) - - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name2", "name1@gmail.com", "not")], - ) - - def test_on_conflict_do_update_exotic_targets_three(self): + self._exotic_targets_fixture(connection) + # try primary key constraint: cause an upsert on unique id column + i = insert(users) + i = i.on_conflict_do_update( + index_elements=users.primary_key.columns, + set_=dict( + name=i.excluded.name, login_email=i.excluded.login_email + ), + ) + result = connection.execute( + i, + dict( + id=1, + name="name2", + login_email="name1@gmail.com", + lets_index_this="not", + ), + ) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) + + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name2", "name1@gmail.com", "not")], + ) + + def test_on_conflict_do_update_exotic_targets_three(self, connection): users = self.tables.users_xtra - with testing.db.connect() as conn: - self._exotic_targets_fixture(conn) - # try unique constraint: cause an upsert on target - # login_email, not id - i = insert(users) - i = i.on_conflict_do_update( - constraint=self.unique_constraint, - set_=dict( - id=i.excluded.id, - name=i.excluded.name, - login_email=i.excluded.login_email, - ), - ) - # note: lets_index_this value totally ignored in SET clause. - result = conn.execute( - i, - dict( - id=42, - name="nameunique", - login_email="name2@gmail.com", - lets_index_this="unique", - ), - ) - eq_(result.inserted_primary_key, (42,)) - eq_(result.returned_defaults, None) - - eq_( - conn.execute( - users.select().where( - users.c.login_email == "name2@gmail.com" - ) - ).fetchall(), - [(42, "nameunique", "name2@gmail.com", "not")], - ) - - def test_on_conflict_do_update_exotic_targets_four(self): + self._exotic_targets_fixture(connection) + # try unique constraint: cause an upsert on target + # login_email, not id + i = insert(users) + i = i.on_conflict_do_update( + constraint=self.unique_constraint, + set_=dict( + id=i.excluded.id, + name=i.excluded.name, + login_email=i.excluded.login_email, + ), + ) + # note: lets_index_this value totally ignored in SET clause. + result = connection.execute( + i, + dict( + id=42, + name="nameunique", + login_email="name2@gmail.com", + lets_index_this="unique", + ), + ) + eq_(result.inserted_primary_key, (42,)) + eq_(result.returned_defaults, None) + + eq_( + connection.execute( + users.select().where(users.c.login_email == "name2@gmail.com") + ).fetchall(), + [(42, "nameunique", "name2@gmail.com", "not")], + ) + + def test_on_conflict_do_update_exotic_targets_four(self, connection): users = self.tables.users_xtra - with testing.db.connect() as conn: - self._exotic_targets_fixture(conn) - # try unique constraint by name: cause an - # upsert on target login_email, not id - i = insert(users) - i = i.on_conflict_do_update( - constraint=self.unique_constraint.name, - set_=dict( - id=i.excluded.id, - name=i.excluded.name, - login_email=i.excluded.login_email, - ), - ) - # note: lets_index_this value totally ignored in SET clause. - - result = conn.execute( - i, - dict( - id=43, - name="nameunique2", - login_email="name2@gmail.com", - lets_index_this="unique", - ), - ) - eq_(result.inserted_primary_key, (43,)) - eq_(result.returned_defaults, None) - - eq_( - conn.execute( - users.select().where( - users.c.login_email == "name2@gmail.com" - ) - ).fetchall(), - [(43, "nameunique2", "name2@gmail.com", "not")], - ) - - def test_on_conflict_do_update_exotic_targets_four_no_pk(self): + self._exotic_targets_fixture(connection) + # try unique constraint by name: cause an + # upsert on target login_email, not id + i = insert(users) + i = i.on_conflict_do_update( + constraint=self.unique_constraint.name, + set_=dict( + id=i.excluded.id, + name=i.excluded.name, + login_email=i.excluded.login_email, + ), + ) + # note: lets_index_this value totally ignored in SET clause. + + result = connection.execute( + i, + dict( + id=43, + name="nameunique2", + login_email="name2@gmail.com", + lets_index_this="unique", + ), + ) + eq_(result.inserted_primary_key, (43,)) + eq_(result.returned_defaults, None) + + eq_( + connection.execute( + users.select().where(users.c.login_email == "name2@gmail.com") + ).fetchall(), + [(43, "nameunique2", "name2@gmail.com", "not")], + ) + + def test_on_conflict_do_update_exotic_targets_four_no_pk(self, connection): users = self.tables.users_xtra - with testing.db.connect() as conn: - self._exotic_targets_fixture(conn) - # try unique constraint by name: cause an - # upsert on target login_email, not id - i = insert(users) - i = i.on_conflict_do_update( - index_elements=[users.c.login_email], - set_=dict( - id=i.excluded.id, - name=i.excluded.name, - login_email=i.excluded.login_email, - ), - ) - - result = conn.execute( - i, dict(name="name3", login_email="name1@gmail.com") - ) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, (1,)) - - eq_( - conn.execute(users.select().order_by(users.c.id)).fetchall(), - [ - (1, "name3", "name1@gmail.com", "not"), - (2, "name2", "name2@gmail.com", "not"), - ], - ) - - def test_on_conflict_do_update_exotic_targets_five(self): + self._exotic_targets_fixture(connection) + # try unique constraint by name: cause an + # upsert on target login_email, not id + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.login_email], + set_=dict( + id=i.excluded.id, + name=i.excluded.name, + login_email=i.excluded.login_email, + ), + ) + + result = connection.execute( + i, dict(name="name3", login_email="name1@gmail.com") + ) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, (1,)) + + eq_( + connection.execute(users.select().order_by(users.c.id)).fetchall(), + [ + (1, "name3", "name1@gmail.com", "not"), + (2, "name2", "name2@gmail.com", "not"), + ], + ) + + def test_on_conflict_do_update_exotic_targets_five(self, connection): users = self.tables.users_xtra - with testing.db.connect() as conn: - self._exotic_targets_fixture(conn) - # try bogus index - i = insert(users) - i = i.on_conflict_do_update( - index_elements=self.bogus_index.columns, - index_where=self.bogus_index.dialect_options["postgresql"][ - "where" - ], - set_=dict( - name=i.excluded.name, login_email=i.excluded.login_email - ), - ) - - assert_raises( - exc.ProgrammingError, - conn.execute, - i, - dict( - id=1, - name="namebogus", - login_email="bogus@gmail.com", - lets_index_this="bogus", - ), - ) - - def test_on_conflict_do_update_exotic_targets_six(self): + self._exotic_targets_fixture(connection) + # try bogus index + i = insert(users) + i = i.on_conflict_do_update( + index_elements=self.bogus_index.columns, + index_where=self.bogus_index.dialect_options["postgresql"][ + "where" + ], + set_=dict( + name=i.excluded.name, login_email=i.excluded.login_email + ), + ) + + assert_raises( + exc.ProgrammingError, + connection.execute, + i, + dict( + id=1, + name="namebogus", + login_email="bogus@gmail.com", + lets_index_this="bogus", + ), + ) + + def test_on_conflict_do_update_exotic_targets_six(self, connection): users = self.tables.users_xtra - with testing.db.connect() as conn: - conn.execute( - insert(users), + connection.execute( + insert(users), + dict( + id=1, + name="name1", + login_email="mail1@gmail.com", + lets_index_this="unique_name", + ), + ) + + i = insert(users) + i = i.on_conflict_do_update( + index_elements=self.unique_partial_index.columns, + index_where=self.unique_partial_index.dialect_options[ + "postgresql" + ]["where"], + set_=dict( + name=i.excluded.name, login_email=i.excluded.login_email + ), + ) + + connection.execute( + i, + [ dict( - id=1, name="name1", - login_email="mail1@gmail.com", + login_email="mail2@gmail.com", lets_index_this="unique_name", - ), - ) - - i = insert(users) - i = i.on_conflict_do_update( - index_elements=self.unique_partial_index.columns, - index_where=self.unique_partial_index.dialect_options[ - "postgresql" - ]["where"], - set_=dict( - name=i.excluded.name, login_email=i.excluded.login_email - ), - ) - - conn.execute( - i, - [ - dict( - name="name1", - login_email="mail2@gmail.com", - lets_index_this="unique_name", - ) - ], - ) - - eq_( - conn.execute(users.select()).fetchall(), - [(1, "name1", "mail2@gmail.com", "unique_name")], - ) - - def test_on_conflict_do_update_no_row_actually_affected(self): + ) + ], + ) + + eq_( + connection.execute(users.select()).fetchall(), + [(1, "name1", "mail2@gmail.com", "unique_name")], + ) + + def test_on_conflict_do_update_no_row_actually_affected(self, connection): users = self.tables.users_xtra - with testing.db.connect() as conn: - self._exotic_targets_fixture(conn) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=[users.c.login_email], - set_=dict(name="new_name"), - where=(i.excluded.name == "other_name"), - ) - result = conn.execute( - i, dict(name="name2", login_email="name1@gmail.com") - ) - - eq_(result.returned_defaults, None) - eq_(result.inserted_primary_key, None) - - eq_( - conn.execute(users.select()).fetchall(), - [ - (1, "name1", "name1@gmail.com", "not"), - (2, "name2", "name2@gmail.com", "not"), - ], - ) - - def test_on_conflict_do_update_special_types_in_set(self): + self._exotic_targets_fixture(connection) + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.login_email], + set_=dict(name="new_name"), + where=(i.excluded.name == "other_name"), + ) + result = connection.execute( + i, dict(name="name2", login_email="name1@gmail.com") + ) + + eq_(result.returned_defaults, None) + eq_(result.inserted_primary_key, None) + + eq_( + connection.execute(users.select()).fetchall(), + [ + (1, "name1", "name1@gmail.com", "not"), + (2, "name2", "name2@gmail.com", "not"), + ], + ) + + def test_on_conflict_do_update_special_types_in_set(self, connection): bind_targets = self.tables.bind_targets - with testing.db.connect() as conn: - i = insert(bind_targets) - conn.execute(i, {"id": 1, "data": "initial data"}) - - eq_( - conn.scalar(sql.select(bind_targets.c.data)), - "initial data processed", - ) - - i = insert(bind_targets) - i = i.on_conflict_do_update( - index_elements=[bind_targets.c.id], - set_=dict(data="new updated data"), - ) - conn.execute(i, {"id": 1, "data": "new inserted data"}) - - eq_( - conn.scalar(sql.select(bind_targets.c.data)), - "new updated data processed", - ) + i = insert(bind_targets) + connection.execute(i, {"id": 1, "data": "initial data"}) + + eq_( + connection.scalar(sql.select(bind_targets.c.data)), + "initial data processed", + ) + + i = insert(bind_targets) + i = i.on_conflict_do_update( + index_elements=[bind_targets.c.id], + set_=dict(data="new updated data"), + ) + connection.execute(i, {"id": 1, "data": "new inserted data"}) + + eq_( + connection.scalar(sql.select(bind_targets.c.data)), + "new updated data processed", + ) diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py index c959acf35..94af168ee 100644 --- a/test/dialect/postgresql/test_query.py +++ b/test/dialect/postgresql/test_query.py @@ -35,30 +35,32 @@ from sqlalchemy.testing.assertsql import CursorSQL from sqlalchemy.testing.assertsql import DialectSQL -matchtable = cattable = None - - class InsertTest(fixtures.TestBase, AssertsExecutionResults): __only_on__ = "postgresql" __backend__ = True - @classmethod - def setup_class(cls): - cls.metadata = MetaData(testing.db) + def setup(self): + self.metadata = MetaData() def teardown(self): - self.metadata.drop_all() - self.metadata.clear() + with testing.db.begin() as conn: + self.metadata.drop_all(conn) + + @testing.combinations((False,), (True,)) + def test_foreignkey_missing_insert(self, implicit_returning): + engine = engines.testing_engine( + options={"implicit_returning": implicit_returning} + ) - def test_foreignkey_missing_insert(self): Table("t1", self.metadata, Column("id", Integer, primary_key=True)) t2 = Table( "t2", self.metadata, Column("id", Integer, ForeignKey("t1.id"), primary_key=True), ) - self.metadata.create_all() + + self.metadata.create_all(engine) # want to ensure that "null value in column "id" violates not- # null constraint" is raised (IntegrityError on psycoopg2, but @@ -67,19 +69,13 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): # the latter corresponds to autoincrement behavior, which is not # the case here due to the foreign key. - for eng in [ - engines.testing_engine(options={"implicit_returning": False}), - engines.testing_engine(options={"implicit_returning": True}), - ]: - with expect_warnings( - ".*has no Python-side or server-side default.*" - ): - with eng.connect() as conn: - assert_raises( - (exc.IntegrityError, exc.ProgrammingError), - conn.execute, - t2.insert(), - ) + with expect_warnings(".*has no Python-side or server-side default.*"): + with engine.begin() as conn: + assert_raises( + (exc.IntegrityError, exc.ProgrammingError), + conn.execute, + t2.insert(), + ) def test_sequence_insert(self): table = Table( @@ -88,7 +84,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): Column("id", Integer, Sequence("my_seq"), primary_key=True), Column("data", String(30)), ) - self.metadata.create_all() + self.metadata.create_all(testing.db) self._assert_data_with_sequence(table, "my_seq") @testing.requires.returning @@ -99,7 +95,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): Column("id", Integer, Sequence("my_seq"), primary_key=True), Column("data", String(30)), ) - self.metadata.create_all() + self.metadata.create_all(testing.db) self._assert_data_with_sequence_returning(table, "my_seq") def test_opt_sequence_insert(self): @@ -114,7 +110,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): ), Column("data", String(30)), ) - self.metadata.create_all() + self.metadata.create_all(testing.db) self._assert_data_autoincrement(table) @testing.requires.returning @@ -130,7 +126,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): ), Column("data", String(30)), ) - self.metadata.create_all() + self.metadata.create_all(testing.db) self._assert_data_autoincrement_returning(table) def test_autoincrement_insert(self): @@ -140,7 +136,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): Column("id", Integer, primary_key=True), Column("data", String(30)), ) - self.metadata.create_all() + self.metadata.create_all(testing.db) self._assert_data_autoincrement(table) @testing.requires.returning @@ -151,7 +147,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): Column("id", Integer, primary_key=True), Column("data", String(30)), ) - self.metadata.create_all() + self.metadata.create_all(testing.db) self._assert_data_autoincrement_returning(table) def test_noautoincrement_insert(self): @@ -161,7 +157,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): Column("id", Integer, primary_key=True, autoincrement=False), Column("data", String(30)), ) - self.metadata.create_all() + self.metadata.create_all(testing.db) self._assert_data_noautoincrement(table) def _assert_data_autoincrement(self, table): @@ -169,7 +165,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): with self.sql_execution_asserter(engine) as asserter: - with engine.connect() as conn: + with engine.begin() as conn: # execute with explicit id r = conn.execute(table.insert(), {"id": 30, "data": "d1"}) @@ -226,7 +222,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): ), ) - with engine.connect() as conn: + with engine.begin() as conn: eq_( conn.execute(table.select()).fetchall(), [ @@ -250,7 +246,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): table = Table(table.name, m2, autoload_with=engine) with self.sql_execution_asserter(engine) as asserter: - with engine.connect() as conn: + with engine.begin() as conn: conn.execute(table.insert(), {"id": 30, "data": "d1"}) r = conn.execute(table.insert(), {"data": "d2"}) eq_(r.inserted_primary_key, (5,)) @@ -288,7 +284,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): "INSERT INTO testtable (data) VALUES (:data)", [{"data": "d8"}] ), ) - with engine.connect() as conn: + with engine.begin() as conn: eq_( conn.execute(table.select()).fetchall(), [ @@ -308,7 +304,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): engine = engines.testing_engine(options={"implicit_returning": True}) with self.sql_execution_asserter(engine) as asserter: - with engine.connect() as conn: + with engine.begin() as conn: # execute with explicit id @@ -367,7 +363,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): ), ) - with engine.connect() as conn: + with engine.begin() as conn: eq_( conn.execute(table.select()).fetchall(), [ @@ -390,7 +386,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): table = Table(table.name, m2, autoload_with=engine) with self.sql_execution_asserter(engine) as asserter: - with engine.connect() as conn: + with engine.begin() as conn: conn.execute(table.insert(), {"id": 30, "data": "d1"}) r = conn.execute(table.insert(), {"data": "d2"}) eq_(r.inserted_primary_key, (5,)) @@ -430,7 +426,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): ), ) - with engine.connect() as conn: + with engine.begin() as conn: eq_( conn.execute(table.select()).fetchall(), [ @@ -450,7 +446,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): engine = engines.testing_engine(options={"implicit_returning": False}) with self.sql_execution_asserter(engine) as asserter: - with engine.connect() as conn: + with engine.begin() as conn: conn.execute(table.insert(), {"id": 30, "data": "d1"}) conn.execute(table.insert(), {"data": "d2"}) conn.execute( @@ -491,7 +487,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): [{"data": "d8"}], ), ) - with engine.connect() as conn: + with engine.begin() as conn: eq_( conn.execute(table.select()).fetchall(), [ @@ -513,7 +509,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): engine = engines.testing_engine(options={"implicit_returning": True}) with self.sql_execution_asserter(engine) as asserter: - with engine.connect() as conn: + with engine.begin() as conn: conn.execute(table.insert(), {"id": 30, "data": "d1"}) conn.execute(table.insert(), {"data": "d2"}) conn.execute( @@ -555,7 +551,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): ), ) - with engine.connect() as conn: + with engine.begin() as conn: eq_( conn.execute(table.select()).fetchall(), [ @@ -578,9 +574,12 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): # turning off the cache because we are checking for compile-time # warnings - with engine.connect().execution_options(compiled_cache=None) as conn: + engine = engine.execution_options(compiled_cache=None) + + with engine.begin() as conn: conn.execute(table.insert(), {"id": 30, "data": "d1"}) + with engine.begin() as conn: with expect_warnings( ".*has no Python-side or server-side default.*" ): @@ -590,6 +589,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): table.insert(), {"data": "d2"}, ) + + with engine.begin() as conn: with expect_warnings( ".*has no Python-side or server-side default.*" ): @@ -599,6 +600,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): table.insert(), [{"data": "d2"}, {"data": "d3"}], ) + + with engine.begin() as conn: with expect_warnings( ".*has no Python-side or server-side default.*" ): @@ -608,6 +611,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): table.insert(), {"data": "d2"}, ) + + with engine.begin() as conn: with expect_warnings( ".*has no Python-side or server-side default.*" ): @@ -618,6 +623,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): [{"data": "d2"}, {"data": "d3"}], ) + with engine.begin() as conn: conn.execute( table.insert(), [{"id": 31, "data": "d2"}, {"id": 32, "data": "d3"}], @@ -634,9 +640,10 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): m2 = MetaData() table = Table(table.name, m2, autoload_with=engine) - with engine.connect() as conn: + with engine.begin() as conn: conn.execute(table.insert(), {"id": 30, "data": "d1"}) + with engine.begin() as conn: with expect_warnings( ".*has no Python-side or server-side default.*" ): @@ -646,6 +653,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): table.insert(), {"data": "d2"}, ) + + with engine.begin() as conn: with expect_warnings( ".*has no Python-side or server-side default.*" ): @@ -655,6 +664,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): table.insert(), [{"data": "d2"}, {"data": "d3"}], ) + + with engine.begin() as conn: conn.execute( table.insert(), [{"id": 31, "data": "d2"}, {"id": 32, "data": "d3"}], @@ -666,36 +677,40 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): ) -class MatchTest(fixtures.TestBase, AssertsCompiledSQL): +class MatchTest(fixtures.TablesTest, AssertsCompiledSQL): __only_on__ = "postgresql >= 8.3" __backend__ = True @classmethod - def setup_class(cls): - global metadata, cattable, matchtable - metadata = MetaData(testing.db) - cattable = Table( + def define_tables(cls, metadata): + Table( "cattable", metadata, Column("id", Integer, primary_key=True), Column("description", String(50)), ) - matchtable = Table( + Table( "matchtable", metadata, Column("id", Integer, primary_key=True), Column("title", String(200)), Column("category_id", Integer, ForeignKey("cattable.id")), ) - metadata.create_all() - cattable.insert().execute( + + @classmethod + def insert_data(cls, connection): + cattable, matchtable = cls.tables("cattable", "matchtable") + + connection.execute( + cattable.insert(), [ {"id": 1, "description": "Python"}, {"id": 2, "description": "Ruby"}, - ] + ], ) - matchtable.insert().execute( + connection.execute( + matchtable.insert(), [ { "id": 1, @@ -714,15 +729,12 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): "category_id": 1, }, {"id": 5, "title": "Python in a Nutshell", "category_id": 1}, - ] + ], ) - @classmethod - def teardown_class(cls): - metadata.drop_all() - @testing.requires.pyformat_paramstyle def test_expression_pyformat(self): + matchtable = self.tables.matchtable self.assert_compile( matchtable.c.title.match("somstr"), "matchtable.title @@ to_tsquery(%(title_1)s" ")", @@ -730,51 +742,47 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): @testing.requires.format_paramstyle def test_expression_positional(self): + matchtable = self.tables.matchtable self.assert_compile( matchtable.c.title.match("somstr"), "matchtable.title @@ to_tsquery(%s)", ) - def test_simple_match(self): - results = ( + def test_simple_match(self, connection): + matchtable = self.tables.matchtable + results = connection.execute( matchtable.select() .where(matchtable.c.title.match("python")) .order_by(matchtable.c.id) - .execute() - .fetchall() - ) + ).fetchall() eq_([2, 5], [r.id for r in results]) - def test_not_match(self): - results = ( + def test_not_match(self, connection): + matchtable = self.tables.matchtable + results = connection.execute( matchtable.select() .where(~matchtable.c.title.match("python")) .order_by(matchtable.c.id) - .execute() - .fetchall() - ) + ).fetchall() eq_([1, 3, 4], [r.id for r in results]) - def test_simple_match_with_apostrophe(self): - results = ( - matchtable.select() - .where(matchtable.c.title.match("Matz's")) - .execute() - .fetchall() - ) + def test_simple_match_with_apostrophe(self, connection): + matchtable = self.tables.matchtable + results = connection.execute( + matchtable.select().where(matchtable.c.title.match("Matz's")) + ).fetchall() eq_([3], [r.id for r in results]) - def test_simple_derivative_match(self): - results = ( - matchtable.select() - .where(matchtable.c.title.match("nutshells")) - .execute() - .fetchall() - ) + def test_simple_derivative_match(self, connection): + matchtable = self.tables.matchtable + results = connection.execute( + matchtable.select().where(matchtable.c.title.match("nutshells")) + ).fetchall() eq_([5], [r.id for r in results]) - def test_or_match(self): - results1 = ( + def test_or_match(self, connection): + matchtable = self.tables.matchtable + results1 = connection.execute( matchtable.select() .where( or_( @@ -783,42 +791,36 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): ) ) .order_by(matchtable.c.id) - .execute() - .fetchall() - ) + ).fetchall() eq_([3, 5], [r.id for r in results1]) - results2 = ( + results2 = connection.execute( matchtable.select() .where(matchtable.c.title.match("nutshells | rubies")) .order_by(matchtable.c.id) - .execute() - .fetchall() - ) + ).fetchall() eq_([3, 5], [r.id for r in results2]) - def test_and_match(self): - results1 = ( - matchtable.select() - .where( + def test_and_match(self, connection): + matchtable = self.tables.matchtable + results1 = connection.execute( + matchtable.select().where( and_( matchtable.c.title.match("python"), matchtable.c.title.match("nutshells"), ) ) - .execute() - .fetchall() - ) + ).fetchall() eq_([5], [r.id for r in results1]) - results2 = ( - matchtable.select() - .where(matchtable.c.title.match("python & nutshells")) - .execute() - .fetchall() - ) + results2 = connection.execute( + matchtable.select().where( + matchtable.c.title.match("python & nutshells") + ) + ).fetchall() eq_([5], [r.id for r in results2]) - def test_match_across_joins(self): - results = ( + def test_match_across_joins(self, connection): + cattable, matchtable = self.tables("cattable", "matchtable") + results = connection.execute( matchtable.select() .where( and_( @@ -830,9 +832,7 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): ) ) .order_by(matchtable.c.id) - .execute() - .fetchall() - ) + ).fetchall() eq_([1, 3, 5], [r.id for r in results]) diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index 4de4d88e3..824f6cd36 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -291,63 +291,64 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): @classmethod def setup_class(cls): - con = testing.db.connect() - for ddl in [ - 'CREATE SCHEMA "SomeSchema"', - "CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42", - "CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0", - "CREATE TYPE testtype AS ENUM ('test')", - "CREATE DOMAIN enumdomain AS testtype", - "CREATE DOMAIN arraydomain AS INTEGER[]", - 'CREATE DOMAIN "SomeSchema"."Quoted.Domain" INTEGER DEFAULT 0', - ]: - try: - con.exec_driver_sql(ddl) - except exc.DBAPIError as e: - if "already exists" not in str(e): - raise e - con.exec_driver_sql( - "CREATE TABLE testtable (question integer, answer " "testdomain)" - ) - con.exec_driver_sql( - "CREATE TABLE test_schema.testtable(question " - "integer, answer test_schema.testdomain, anything " - "integer)" - ) - con.exec_driver_sql( - "CREATE TABLE crosschema (question integer, answer " - "test_schema.testdomain)" - ) + with testing.db.begin() as con: + for ddl in [ + 'CREATE SCHEMA "SomeSchema"', + "CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42", + "CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0", + "CREATE TYPE testtype AS ENUM ('test')", + "CREATE DOMAIN enumdomain AS testtype", + "CREATE DOMAIN arraydomain AS INTEGER[]", + 'CREATE DOMAIN "SomeSchema"."Quoted.Domain" INTEGER DEFAULT 0', + ]: + try: + con.exec_driver_sql(ddl) + except exc.DBAPIError as e: + if "already exists" not in str(e): + raise e + con.exec_driver_sql( + "CREATE TABLE testtable (question integer, answer " + "testdomain)" + ) + con.exec_driver_sql( + "CREATE TABLE test_schema.testtable(question " + "integer, answer test_schema.testdomain, anything " + "integer)" + ) + con.exec_driver_sql( + "CREATE TABLE crosschema (question integer, answer " + "test_schema.testdomain)" + ) - con.exec_driver_sql( - "CREATE TABLE enum_test (id integer, data enumdomain)" - ) + con.exec_driver_sql( + "CREATE TABLE enum_test (id integer, data enumdomain)" + ) - con.exec_driver_sql( - "CREATE TABLE array_test (id integer, data arraydomain)" - ) + con.exec_driver_sql( + "CREATE TABLE array_test (id integer, data arraydomain)" + ) - con.exec_driver_sql( - "CREATE TABLE quote_test " - '(id integer, data "SomeSchema"."Quoted.Domain")' - ) + con.exec_driver_sql( + "CREATE TABLE quote_test " + '(id integer, data "SomeSchema"."Quoted.Domain")' + ) @classmethod def teardown_class(cls): - con = testing.db.connect() - con.exec_driver_sql("DROP TABLE testtable") - con.exec_driver_sql("DROP TABLE test_schema.testtable") - con.exec_driver_sql("DROP TABLE crosschema") - con.exec_driver_sql("DROP TABLE quote_test") - con.exec_driver_sql("DROP DOMAIN testdomain") - con.exec_driver_sql("DROP DOMAIN test_schema.testdomain") - con.exec_driver_sql("DROP TABLE enum_test") - con.exec_driver_sql("DROP DOMAIN enumdomain") - con.exec_driver_sql("DROP TYPE testtype") - con.exec_driver_sql("DROP TABLE array_test") - con.exec_driver_sql("DROP DOMAIN arraydomain") - con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"') - con.exec_driver_sql('DROP SCHEMA "SomeSchema"') + with testing.db.begin() as con: + con.exec_driver_sql("DROP TABLE testtable") + con.exec_driver_sql("DROP TABLE test_schema.testtable") + con.exec_driver_sql("DROP TABLE crosschema") + con.exec_driver_sql("DROP TABLE quote_test") + con.exec_driver_sql("DROP DOMAIN testdomain") + con.exec_driver_sql("DROP DOMAIN test_schema.testdomain") + con.exec_driver_sql("DROP TABLE enum_test") + con.exec_driver_sql("DROP DOMAIN enumdomain") + con.exec_driver_sql("DROP TYPE testtype") + con.exec_driver_sql("DROP TABLE array_test") + con.exec_driver_sql("DROP DOMAIN arraydomain") + con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"') + con.exec_driver_sql('DROP SCHEMA "SomeSchema"') def test_table_is_reflected(self): metadata = MetaData() @@ -486,7 +487,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("id", Integer, primary_key=True), Column("ref", Integer, ForeignKey("subject.id$")), ) - meta1.create_all() + meta1.create_all(testing.db) meta2 = MetaData() subject = Table("subject", meta2, autoload_with=testing.db) referer = Table("referer", meta2, autoload_with=testing.db) @@ -523,9 +524,11 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): with testing.db.begin() as conn: r = conn.execute(t2.insert()) eq_(r.inserted_primary_key, (1,)) - testing.db.connect().execution_options( - autocommit=True - ).exec_driver_sql("alter table t_id_seq rename to foobar_id_seq") + + with testing.db.begin() as conn: + conn.exec_driver_sql( + "alter table t_id_seq rename to foobar_id_seq" + ) m3 = MetaData() t3 = Table("t", m3, autoload_with=testing.db, implicit_returning=False) eq_( @@ -545,10 +548,12 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("id", Integer, primary_key=True), Column("x", Integer), ) - metadata.create_all() - testing.db.connect().execution_options( - autocommit=True - ).exec_driver_sql("alter table t alter column id type varchar(50)") + metadata.create_all(testing.db) + + with testing.db.begin() as conn: + conn.exec_driver_sql( + "alter table t alter column id type varchar(50)" + ) m2 = MetaData() t2 = Table("t", m2, autoload_with=testing.db) eq_(t2.c.id.autoincrement, False) @@ -558,10 +563,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): def test_renamed_pk_reflection(self): metadata = self.metadata Table("t", metadata, Column("id", Integer, primary_key=True)) - metadata.create_all() - testing.db.connect().execution_options( - autocommit=True - ).exec_driver_sql("alter table t rename id to t_id") + metadata.create_all(testing.db) + with testing.db.begin() as conn: + conn.exec_driver_sql("alter table t rename id to t_id") m2 = MetaData() t2 = Table("t", m2, autoload_with=testing.db) eq_([c.name for c in t2.primary_key], ["t_id"]) @@ -936,13 +940,13 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("name", String(20), index=True), Column("aname", String(20)), ) - metadata.create_all() - with testing.db.connect() as c: - c.exec_driver_sql("create index idx1 on party ((id || name))") - c.exec_driver_sql( + metadata.create_all(testing.db) + with testing.db.begin() as conn: + conn.exec_driver_sql("create index idx1 on party ((id || name))") + conn.exec_driver_sql( "create unique index idx2 on party (id) where name = 'test'" ) - c.exec_driver_sql( + conn.exec_driver_sql( """ create index idx3 on party using btree (lower(name::text), lower(aname::text)) @@ -1029,7 +1033,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("aname", String(20)), ) - with testing.db.connect() as conn: + with testing.db.begin() as conn: t1.create(conn) @@ -1109,18 +1113,19 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("id", Integer, primary_key=True), Column("x", Integer), ) - metadata.create_all() - conn = testing.db.connect().execution_options(autocommit=True) - conn.exec_driver_sql("CREATE INDEX idx1 ON t (x)") - conn.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y") + metadata.create_all(testing.db) + with testing.db.begin() as conn: + conn.exec_driver_sql("CREATE INDEX idx1 ON t (x)") + conn.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y") - ind = testing.db.dialect.get_indexes(conn, "t", None) - expected = [{"name": "idx1", "unique": False, "column_names": ["y"]}] - if testing.requires.index_reflects_included_columns.enabled: - expected[0]["include_columns"] = [] + ind = testing.db.dialect.get_indexes(conn, "t", None) + expected = [ + {"name": "idx1", "unique": False, "column_names": ["y"]} + ] + if testing.requires.index_reflects_included_columns.enabled: + expected[0]["include_columns"] = [] - eq_(ind, expected) - conn.close() + eq_(ind, expected) @testing.fails_if("postgresql < 8.2", "reloptions not supported") @testing.provide_metadata @@ -1135,9 +1140,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("id", Integer, primary_key=True), Column("x", Integer), ) - metadata.create_all() + metadata.create_all(testing.db) - with testing.db.connect().execution_options(autocommit=True) as conn: + with testing.db.begin() as conn: conn.exec_driver_sql( "CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)" ) @@ -1177,8 +1182,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("id", Integer, primary_key=True), Column("x", ARRAY(Integer)), ) - metadata.create_all() - with testing.db.connect().execution_options(autocommit=True) as conn: + metadata.create_all(testing.db) + with testing.db.begin() as conn: conn.exec_driver_sql("CREATE INDEX idx1 ON t USING gin (x)") ind = testing.db.dialect.get_indexes(conn, "t", None) @@ -1215,7 +1220,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("name", String(20)), ) metadata.create_all() - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.exec_driver_sql("CREATE INDEX idx1 ON t (x) INCLUDE (name)") # prior to #5205, this would return: @@ -1312,8 +1317,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): eq_(fk, fk_ref[fk["name"]]) @testing.provide_metadata - def test_inspect_enums_schema(self): - conn = testing.db.connect() + def test_inspect_enums_schema(self, connection): enum_type = postgresql.ENUM( "sad", "ok", @@ -1322,8 +1326,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): schema="test_schema", metadata=self.metadata, ) - enum_type.create(conn) - inspector = inspect(conn) + enum_type.create(connection) + inspector = inspect(connection) eq_( inspector.get_enums("test_schema"), [ diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index e7174f234..ae7a65a3a 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -206,7 +206,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ), schema=symbol_name, ) - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn = conn.execution_options( schema_translate_map={symbol_name: testing.config.test_schema} ) |
