summaryrefslogtreecommitdiff
path: root/test/dialect/postgresql/test_query.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/dialect/postgresql/test_query.py')
-rw-r--r--test/dialect/postgresql/test_query.py161
1 files changed, 141 insertions, 20 deletions
diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py
index b488b146c..fdce643f8 100644
--- a/test/dialect/postgresql/test_query.py
+++ b/test/dialect/postgresql/test_query.py
@@ -14,6 +14,7 @@ from sqlalchemy import Float
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Integer
+from sqlalchemy import JSON
from sqlalchemy import literal
from sqlalchemy import literal_column
from sqlalchemy import MetaData
@@ -29,6 +30,7 @@ from sqlalchemy import true
from sqlalchemy import tuple_
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.sql.expression import type_coerce
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import AssertsExecutionResults
@@ -40,6 +42,17 @@ from sqlalchemy.testing.assertsql import CursorSQL
from sqlalchemy.testing.assertsql import DialectSQL
+class FunctionTypingTest(fixtures.TestBase, AssertsExecutionResults):
+ __only_on__ = "postgresql"
+ __backend__ = True
+
+ def test_count_star(self, connection):
+ eq_(connection.scalar(func.count("*")), 1)
+
+ def test_count_int(self, connection):
+ eq_(connection.scalar(func.count(1)), 1)
+
+
class InsertTest(fixtures.TestBase, AssertsExecutionResults):
__only_on__ = "postgresql"
@@ -956,23 +969,42 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
],
)
+ def _strs_render_bind_casts(self, connection):
+
+ return (
+ connection.dialect._bind_typing_render_casts
+ and String().dialect_impl(connection.dialect).render_bind_cast
+ )
+
@testing.requires.pyformat_paramstyle
- def test_expression_pyformat(self):
+ def test_expression_pyformat(self, connection):
matchtable = self.tables.matchtable
- self.assert_compile(
- matchtable.c.title.match("somstr"),
- "matchtable.title @@ to_tsquery(%(title_1)s" ")",
- )
+
+ if self._strs_render_bind_casts(connection):
+ self.assert_compile(
+ matchtable.c.title.match("somstr"),
+ "matchtable.title @@ to_tsquery(%(title_1)s::VARCHAR(200))",
+ )
+ else:
+ self.assert_compile(
+ matchtable.c.title.match("somstr"),
+ "matchtable.title @@ to_tsquery(%(title_1)s)",
+ )
@testing.requires.format_paramstyle
- def test_expression_positional(self):
+ def test_expression_positional(self, connection):
matchtable = self.tables.matchtable
- self.assert_compile(
- matchtable.c.title.match("somstr"),
- # note we assume current tested DBAPIs use emulated setinputsizes
- # here, the cast is not strictly necessary
- "matchtable.title @@ to_tsquery(%s::VARCHAR(200))",
- )
+
+ if self._strs_render_bind_casts(connection):
+ self.assert_compile(
+ matchtable.c.title.match("somstr"),
+ "matchtable.title @@ to_tsquery(%s::VARCHAR(200))",
+ )
+ else:
+ self.assert_compile(
+ matchtable.c.title.match("somstr"),
+ "matchtable.title @@ to_tsquery(%s)",
+ )
def test_simple_match(self, connection):
matchtable = self.tables.matchtable
@@ -1551,17 +1583,106 @@ class TableValuedRoundTripTest(fixtures.TestBase):
[(14, 1), (41, 2), (7, 3), (54, 4), (9, 5), (49, 6)],
)
- @testing.only_on(
- "postgresql+psycopg2",
- "I cannot get this to run at all on other drivers, "
- "even selecting from a table",
+ @testing.combinations(
+ (
+ type_coerce,
+ testing.fails("fails on all drivers"),
+ ),
+ (
+ cast,
+ testing.fails("fails on all drivers"),
+ ),
+ (
+ None,
+ testing.fails_on_everything_except(
+ ["postgresql+psycopg2"],
+ "I cannot get this to run at all on other drivers, "
+ "even selecting from a table",
+ ),
+ ),
+ argnames="cast_fn",
)
- def test_render_derived_quoting(self, connection):
+ def test_render_derived_quoting_text(self, connection, cast_fn):
+
+ value = (
+ '[{"CaseSensitive":1,"the % value":"foo"}, '
+ '{"CaseSensitive":"2","the % value":"bar"}]'
+ )
+
+ if cast_fn:
+ value = cast_fn(value, JSON)
+
fn = (
- func.json_to_recordset( # noqa
- '[{"CaseSensitive":1,"the % value":"foo"}, '
- '{"CaseSensitive":"2","the % value":"bar"}]'
+ func.json_to_recordset(value)
+ .table_valued(
+ column("CaseSensitive", Integer), column("the % value", String)
)
+ .render_derived(with_types=True)
+ )
+
+ stmt = select(fn.c.CaseSensitive, fn.c["the % value"])
+
+ eq_(connection.execute(stmt).all(), [(1, "foo"), (2, "bar")])
+
+ @testing.combinations(
+ (
+ type_coerce,
+ testing.fails("fails on all drivers"),
+ ),
+ (
+ cast,
+ testing.fails("fails on all drivers"),
+ ),
+ (
+ None,
+ testing.fails("Fails on all drivers"),
+ ),
+ argnames="cast_fn",
+ )
+ def test_render_derived_quoting_text_to_json(self, connection, cast_fn):
+
+ value = (
+ '[{"CaseSensitive":1,"the % value":"foo"}, '
+ '{"CaseSensitive":"2","the % value":"bar"}]'
+ )
+
+ if cast_fn:
+ value = cast_fn(value, JSON)
+
+ # why wont this work?!?!?
+ # should be exactly json_to_recordset(to_json('string'::text))
+ #
+ fn = (
+ func.json_to_recordset(func.to_json(value))
+ .table_valued(
+ column("CaseSensitive", Integer), column("the % value", String)
+ )
+ .render_derived(with_types=True)
+ )
+
+ stmt = select(fn.c.CaseSensitive, fn.c["the % value"])
+
+ eq_(connection.execute(stmt).all(), [(1, "foo"), (2, "bar")])
+
+ @testing.combinations(
+ (type_coerce,),
+ (cast,),
+ (None, testing.fails("Fails on all PG backends")),
+ argnames="cast_fn",
+ )
+ def test_render_derived_quoting_straight_json(self, connection, cast_fn):
+ # these all work
+
+ value = [
+ {"CaseSensitive": 1, "the % value": "foo"},
+ {"CaseSensitive": "2", "the % value": "bar"},
+ ]
+
+ if cast_fn:
+ value = cast_fn(value, JSON)
+
+ fn = (
+ func.json_to_recordset(value) # noqa
.table_valued(
column("CaseSensitive", Integer), column("the % value", String)
)