diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2020-02-19 22:59:46 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2020-02-19 22:59:46 +0000 |
commit | 8c9537d37292459d348214fb8befa85d9cb64059 (patch) | |
tree | 74d9872d2a324058f0e94bf5046fc042225ffe12 | |
parent | 5ad9e9fbb25decff09104b03904cfe00a2b18916 (diff) | |
parent | 60f627cbd0d769e65353e720548efac9d8ab95d9 (diff) | |
download | sqlalchemy-8c9537d37292459d348214fb8befa85d9cb64059.tar.gz |
Merge "Replace engine.execute w/ context manager (step1)"
-rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/fixtures.py | 26 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_insert.py | 10 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_results.py | 114 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_types.py | 71 | ||||
-rw-r--r-- | test/dialect/mssql/test_query.py | 9 | ||||
-rw-r--r-- | test/dialect/mssql/test_types.py | 28 |
7 files changed, 138 insertions, 123 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 1b1bd8c5c..4339551a3 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -182,7 +182,8 @@ execution. Given this example:: Column('x', Integer)) m.create_all(engine) - engine.execute(t.insert(), {'id': 1, 'x':1}, {'id':2, 'x':2}) + with engine.begin() as conn: + conn.execute(t.insert(), {'id': 1, 'x':1}, {'id':2, 'x':2}) The above column will be created with IDENTITY, however the INSERT statement we emit is specifying explicit values. In the echo output we can see diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 62bf9fc1f..bae0cee89 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -56,6 +56,32 @@ class TestBase(object): if hasattr(self, "tearDown"): self.tearDown() + @config.fixture() + def connection(self): + conn = config.db.connect() + trans = conn.begin() + try: + yield conn + finally: + trans.rollback() + conn.close() + + # propose a replacement for @testing.provide_metadata. + # the problem with this is that TablesTest below has a ".metadata" + # attribute already which is accessed directly as part of the + # @testing.provide_metadata pattern. Might need to call this _metadata + # for it to be useful. + # @config.fixture() + # def metadata(self): + # """Provide bound MetaData for a single test, dropping afterwards.""" + # + # from . import engines + # metadata = schema.MetaData(config.db) + # try: + # yield metadata + # finally: + # engines.drop_all_tables(metadata, config.db) + class TablesTest(TestBase): diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py index 2cc8761b8..931b0ef65 100644 --- a/lib/sqlalchemy/testing/suite/test_insert.py +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -109,7 +109,8 @@ class InsertBehaviorTest(fixtures.TablesTest): else: engine = config.db - r = engine.execute(self.tables.autoinc_pk.insert(), data="some data") + with engine.begin() as conn: + r = conn.execute(self.tables.autoinc_pk.insert(), data="some data") assert r._soft_closed assert not r.closed assert r.is_insert @@ -278,9 +279,10 @@ class ReturningTest(fixtures.TablesTest): def test_explicit_returning_pk_autocommit(self): engine = config.db table = self.tables.autoinc_pk - r = engine.execute( - table.insert().returning(table.c.id), data="some data" - ) + with engine.begin() as conn: + r = conn.execute( + table.insert().returning(table.c.id), data="some data" + ) pk = r.first()[0] fetched_pk = config.db.scalar(select([table.c.id])) eq_(fetched_pk, pk) diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index 4fc0bb79d..d77d13efa 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -225,7 +225,7 @@ class ServerSideCursorsTest( def _is_server_side(self, cursor): if self.engine.dialect.driver == "psycopg2": - return cursor.name + return bool(cursor.name) elif self.engine.dialect.driver == "pymysql": sscursor = __import__("pymysql.cursors").cursors.SSCursor return isinstance(cursor, sscursor) @@ -245,43 +245,48 @@ class ServerSideCursorsTest( engines.testing_reaper.close_all() self.engine.dispose() - def test_global_string(self): - engine = self._fixture(True) - result = engine.execute("select 1") - assert self._is_server_side(result.cursor) - - def test_global_text(self): - engine = self._fixture(True) - result = engine.execute(text("select 1")) - assert self._is_server_side(result.cursor) - - def test_global_expr(self): - engine = self._fixture(True) - result = engine.execute(select([1])) - assert self._is_server_side(result.cursor) - - def test_global_off_explicit(self): - engine = self._fixture(False) - result = engine.execute(text("select 1")) - - # It should be off globally ... - - assert not self._is_server_side(result.cursor) - - def test_stmt_option(self): - engine = self._fixture(False) - - s = select([1]).execution_options(stream_results=True) - result = engine.execute(s) - - # ... but enabled for this one. - - assert self._is_server_side(result.cursor) + @testing.combinations( + ("global_string", True, "select 1", True), + ("global_text", True, text("select 1"), True), + ("global_expr", True, select([1]), True), + ("global_off_explicit", False, text("select 1"), False), + ( + "stmt_option", + False, + select([1]).execution_options(stream_results=True), + True, + ), + ( + "stmt_option_disabled", + True, + select([1]).execution_options(stream_results=False), + False, + ), + ("for_update_expr", True, select([1]).with_for_update(), True), + ("for_update_string", True, "SELECT 1 FOR UPDATE", True), + ("text_no_ss", False, text("select 42"), False), + ( + "text_ss_option", + False, + text("select 42").execution_options(stream_results=True), + True, + ), + id_="iaaa", + argnames="engine_ss_arg, statement, cursor_ss_status", + ) + def test_ss_cursor_status( + self, engine_ss_arg, statement, cursor_ss_status + ): + engine = self._fixture(engine_ss_arg) + with engine.begin() as conn: + result = conn.execute(statement) + eq_(self._is_server_side(result.cursor), cursor_ss_status) + result.close() def test_conn_option(self): engine = self._fixture(False) - # and this one + # should be enabled for this one result = ( engine.connect() .execution_options(stream_results=True) @@ -300,46 +305,21 @@ class ServerSideCursorsTest( ) assert not self._is_server_side(result.cursor) - def test_stmt_option_disabled(self): - engine = self._fixture(True) - s = select([1]).execution_options(stream_results=False) - result = engine.execute(s) - assert not self._is_server_side(result.cursor) - def test_aliases_and_ss(self): engine = self._fixture(False) s1 = select([1]).execution_options(stream_results=True).alias() - result = engine.execute(s1) - assert self._is_server_side(result.cursor) + with engine.begin() as conn: + result = conn.execute(s1) + assert self._is_server_side(result.cursor) + result.close() # s1's options shouldn't affect s2 when s2 is used as a # from_obj. s2 = select([1], from_obj=s1) - result = engine.execute(s2) - assert not self._is_server_side(result.cursor) - - def test_for_update_expr(self): - engine = self._fixture(True) - s1 = select([1]).with_for_update() - result = engine.execute(s1) - assert self._is_server_side(result.cursor) - - def test_for_update_string(self): - engine = self._fixture(True) - result = engine.execute("SELECT 1 FOR UPDATE") - assert self._is_server_side(result.cursor) - - def test_text_no_ss(self): - engine = self._fixture(False) - s = text("select 42") - result = engine.execute(s) - assert not self._is_server_side(result.cursor) - - def test_text_ss_option(self): - engine = self._fixture(False) - s = text("select 42").execution_options(stream_results=True) - result = engine.execute(s) - assert self._is_server_side(result.cursor) + with engine.begin() as conn: + result = conn.execute(s2) + assert not self._is_server_side(result.cursor) + result.close() @testing.provide_metadata def test_roundtrip(self): diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 2a5dad2d6..a334b8ebc 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -871,54 +871,51 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): # support sqlite :memory: database... data_table.create(engine, checkfirst=True) - engine.execute( - data_table.insert(), {"name": "row1", "data": data_element} - ) - - row = engine.execute(select([data_table.c.data])).first() + with engine.connect() as conn: + conn.execute( + data_table.insert(), {"name": "row1", "data": data_element} + ) + row = conn.execute(select([data_table.c.data])).first() - eq_(row, (data_element,)) - eq_(js.mock_calls, [mock.call(data_element)]) - eq_(jd.mock_calls, [mock.call(json.dumps(data_element))]) + eq_(row, (data_element,)) + eq_(js.mock_calls, [mock.call(data_element)]) + eq_(jd.mock_calls, [mock.call(json.dumps(data_element))]) - def test_round_trip_none_as_sql_null(self): + def test_round_trip_none_as_sql_null(self, connection): col = self.tables.data_table.c["nulldata"] - with config.db.connect() as conn: - conn.execute( - self.tables.data_table.insert(), {"name": "r1", "data": None} - ) + conn = connection + conn.execute( + self.tables.data_table.insert(), {"name": "r1", "data": None} + ) - eq_( - conn.scalar( - select([self.tables.data_table.c.name]).where( - col.is_(null()) - ) - ), - "r1", - ) + eq_( + conn.scalar( + select([self.tables.data_table.c.name]).where(col.is_(null())) + ), + "r1", + ) - eq_(conn.scalar(select([col])), None) + eq_(conn.scalar(select([col])), None) - def test_round_trip_json_null_as_json_null(self): + def test_round_trip_json_null_as_json_null(self, connection): col = self.tables.data_table.c["data"] - with config.db.connect() as conn: - conn.execute( - self.tables.data_table.insert(), - {"name": "r1", "data": JSON.NULL}, - ) + conn = connection + conn.execute( + self.tables.data_table.insert(), {"name": "r1", "data": JSON.NULL}, + ) - eq_( - conn.scalar( - select([self.tables.data_table.c.name]).where( - cast(col, String) == "null" - ) - ), - "r1", - ) + eq_( + conn.scalar( + select([self.tables.data_table.c.name]).where( + cast(col, String) == "null" + ) + ), + "r1", + ) - eq_(conn.scalar(select([col])), None) + eq_(conn.scalar(select([col])), None) def test_round_trip_none_as_json_null(self): col = self.tables.data_table.c["data"] diff --git a/test/dialect/mssql/test_query.py b/test/dialect/mssql/test_query.py index 718b18f5b..aa0850222 100644 --- a/test/dialect/mssql/test_query.py +++ b/test/dialect/mssql/test_query.py @@ -356,7 +356,8 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase): metadata.create_all(engine) with self.sql_execution_asserter(engine) as asserter: - engine.execute(t1.insert(), {"data": "somedata"}) + with engine.begin() as conn: + conn.execute(t1.insert(), {"data": "somedata"}) # TODO: need a dialect SQL that acts like Cursor SQL asserter.assert_( @@ -381,7 +382,8 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase): metadata.create_all(engine) with self.sql_execution_asserter(engine) as asserter: - engine.execute(t1.insert()) + with engine.begin() as conn: + conn.execute(t1.insert()) # even with pyodbc, we don't embed the scope identity on a # DEFAULT VALUES insert @@ -409,7 +411,8 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase): metadata.create_all(engine) with self.sql_execution_asserter(engine) as asserter: - engine.execute(t1.insert(), {"data": "somedata"}) + with engine.begin() as conn: + conn.execute(t1.insert(), {"data": "somedata"}) # pyodbc-specific system asserter.assert_( diff --git a/test/dialect/mssql/test_types.py b/test/dialect/mssql/test_types.py index 92d3d9e32..c95ac6e6d 100644 --- a/test/dialect/mssql/test_types.py +++ b/test/dialect/mssql/test_types.py @@ -1026,18 +1026,24 @@ class TypeRoundTripTest( ] for counter, engine in enumerate(eng): - engine.execute(tbl.insert()) - if "int_y" in tbl.c: - assert engine.scalar(select([tbl.c.int_y])) == counter + 1 - assert ( - list(engine.execute(tbl.select()).first()).count( - counter + 1 + with engine.begin() as conn: + conn.execute(tbl.insert()) + if "int_y" in tbl.c: + eq_( + conn.execute(select([tbl.c.int_y])).scalar(), + counter + 1, ) - == 1 - ) - else: - assert 1 not in list(engine.execute(tbl.select()).first()) - engine.execute(tbl.delete()) + assert ( + list(conn.execute(tbl.select()).first()).count( + counter + 1 + ) + == 1 + ) + else: + assert 1 not in list( + conn.execute(tbl.select()).first() + ) + conn.execute(tbl.delete()) class StringTest(fixtures.TestBase, AssertsCompiledSQL): |