diff options
Diffstat (limited to 'test/dialect/postgresql/test_query.py')
-rw-r--r-- | test/dialect/postgresql/test_query.py | 161 |
1 files changed, 141 insertions, 20 deletions
diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py index b488b146c..fdce643f8 100644 --- a/test/dialect/postgresql/test_query.py +++ b/test/dialect/postgresql/test_query.py @@ -14,6 +14,7 @@ from sqlalchemy import Float from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Integer +from sqlalchemy import JSON from sqlalchemy import literal from sqlalchemy import literal_column from sqlalchemy import MetaData @@ -29,6 +30,7 @@ from sqlalchemy import true from sqlalchemy import tuple_ from sqlalchemy.dialects import postgresql from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.sql.expression import type_coerce from sqlalchemy.testing import assert_raises from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import AssertsExecutionResults @@ -40,6 +42,17 @@ from sqlalchemy.testing.assertsql import CursorSQL from sqlalchemy.testing.assertsql import DialectSQL +class FunctionTypingTest(fixtures.TestBase, AssertsExecutionResults): + __only_on__ = "postgresql" + __backend__ = True + + def test_count_star(self, connection): + eq_(connection.scalar(func.count("*")), 1) + + def test_count_int(self, connection): + eq_(connection.scalar(func.count(1)), 1) + + class InsertTest(fixtures.TestBase, AssertsExecutionResults): __only_on__ = "postgresql" @@ -956,23 +969,42 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL): ], ) + def _strs_render_bind_casts(self, connection): + + return ( + connection.dialect._bind_typing_render_casts + and String().dialect_impl(connection.dialect).render_bind_cast + ) + @testing.requires.pyformat_paramstyle - def test_expression_pyformat(self): + def test_expression_pyformat(self, connection): matchtable = self.tables.matchtable - self.assert_compile( - matchtable.c.title.match("somstr"), - "matchtable.title @@ to_tsquery(%(title_1)s" ")", - ) + + if self._strs_render_bind_casts(connection): + self.assert_compile( + matchtable.c.title.match("somstr"), + "matchtable.title @@ to_tsquery(%(title_1)s::VARCHAR(200))", + ) + else: + self.assert_compile( + matchtable.c.title.match("somstr"), + "matchtable.title @@ to_tsquery(%(title_1)s)", + ) @testing.requires.format_paramstyle - def test_expression_positional(self): + def test_expression_positional(self, connection): matchtable = self.tables.matchtable - self.assert_compile( - matchtable.c.title.match("somstr"), - # note we assume current tested DBAPIs use emulated setinputsizes - # here, the cast is not strictly necessary - "matchtable.title @@ to_tsquery(%s::VARCHAR(200))", - ) + + if self._strs_render_bind_casts(connection): + self.assert_compile( + matchtable.c.title.match("somstr"), + "matchtable.title @@ to_tsquery(%s::VARCHAR(200))", + ) + else: + self.assert_compile( + matchtable.c.title.match("somstr"), + "matchtable.title @@ to_tsquery(%s)", + ) def test_simple_match(self, connection): matchtable = self.tables.matchtable @@ -1551,17 +1583,106 @@ class TableValuedRoundTripTest(fixtures.TestBase): [(14, 1), (41, 2), (7, 3), (54, 4), (9, 5), (49, 6)], ) - @testing.only_on( - "postgresql+psycopg2", - "I cannot get this to run at all on other drivers, " - "even selecting from a table", + @testing.combinations( + ( + type_coerce, + testing.fails("fails on all drivers"), + ), + ( + cast, + testing.fails("fails on all drivers"), + ), + ( + None, + testing.fails_on_everything_except( + ["postgresql+psycopg2"], + "I cannot get this to run at all on other drivers, " + "even selecting from a table", + ), + ), + argnames="cast_fn", ) - def test_render_derived_quoting(self, connection): + def test_render_derived_quoting_text(self, connection, cast_fn): + + value = ( + '[{"CaseSensitive":1,"the % value":"foo"}, ' + '{"CaseSensitive":"2","the % value":"bar"}]' + ) + + if cast_fn: + value = cast_fn(value, JSON) + fn = ( - func.json_to_recordset( # noqa - '[{"CaseSensitive":1,"the % value":"foo"}, ' - '{"CaseSensitive":"2","the % value":"bar"}]' + func.json_to_recordset(value) + .table_valued( + column("CaseSensitive", Integer), column("the % value", String) ) + .render_derived(with_types=True) + ) + + stmt = select(fn.c.CaseSensitive, fn.c["the % value"]) + + eq_(connection.execute(stmt).all(), [(1, "foo"), (2, "bar")]) + + @testing.combinations( + ( + type_coerce, + testing.fails("fails on all drivers"), + ), + ( + cast, + testing.fails("fails on all drivers"), + ), + ( + None, + testing.fails("Fails on all drivers"), + ), + argnames="cast_fn", + ) + def test_render_derived_quoting_text_to_json(self, connection, cast_fn): + + value = ( + '[{"CaseSensitive":1,"the % value":"foo"}, ' + '{"CaseSensitive":"2","the % value":"bar"}]' + ) + + if cast_fn: + value = cast_fn(value, JSON) + + # why wont this work?!?!? + # should be exactly json_to_recordset(to_json('string'::text)) + # + fn = ( + func.json_to_recordset(func.to_json(value)) + .table_valued( + column("CaseSensitive", Integer), column("the % value", String) + ) + .render_derived(with_types=True) + ) + + stmt = select(fn.c.CaseSensitive, fn.c["the % value"]) + + eq_(connection.execute(stmt).all(), [(1, "foo"), (2, "bar")]) + + @testing.combinations( + (type_coerce,), + (cast,), + (None, testing.fails("Fails on all PG backends")), + argnames="cast_fn", + ) + def test_render_derived_quoting_straight_json(self, connection, cast_fn): + # these all work + + value = [ + {"CaseSensitive": 1, "the % value": "foo"}, + {"CaseSensitive": "2", "the % value": "bar"}, + ] + + if cast_fn: + value = cast_fn(value, JSON) + + fn = ( + func.json_to_recordset(value) # noqa .table_valued( column("CaseSensitive", Integer), column("the % value", String) ) |