diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 24 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/provision.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/requirements.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_types.py | 79 |
4 files changed, 88 insertions, 31 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 2ad88e8b0..32e3372c0 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -996,12 +996,26 @@ class DATETIMEOFFSET(sqltypes.TypeEngine): self.precision = precision -class _StringType(object): +class _UnicodeLiteral(object): + def literal_processor(self, dialect): + def process(value): + + value = value.replace("'", "''") + + if dialect.identifier_preparer._double_percents: + value = value.replace("%", "%%") + + return "N'%s'" % value + + return process + - """Base for MSSQL string types.""" +class _MSUnicode(_UnicodeLiteral, sqltypes.Unicode): + pass - def __init__(self, collation=None): - super(_StringType, self).__init__(collation=collation) + +class _MSUnicodeText(_UnicodeLiteral, sqltypes.UnicodeText): + pass class TIMESTAMP(sqltypes._Binary): @@ -2117,6 +2131,8 @@ class MSDialect(default.DefaultDialect): sqltypes.DateTime: _MSDateTime, sqltypes.Date: _MSDate, sqltypes.Time: TIME, + sqltypes.Unicode: _MSUnicode, + sqltypes.UnicodeText: _MSUnicodeText, } engine_config_types = default.DefaultDialect.engine_config_types.union( diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index 88dc28528..70ace0511 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -207,15 +207,12 @@ def _mysql_create_db(cfg, eng, ident): except Exception: pass - # using utf8mb4 we are getting collation errors on UNIONS: - # test/orm/inheritance/test_polymorphic_rel.py" - # 1271, u"Illegal mix of collations for operation 'UNION'" - conn.execute("CREATE DATABASE %s CHARACTER SET utf8mb3" % ident) + conn.execute("CREATE DATABASE %s CHARACTER SET utf8mb4" % ident) conn.execute( - "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb3" % ident + "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb4" % ident ) conn.execute( - "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb3" % ident + "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb4" % ident ) diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index a17d26edb..3a2161740 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -733,6 +733,13 @@ class SuiteRequirements(Requirements): return exclusions.open() @property + def expressions_against_unbounded_text(self): + """target database supports use of an unbounded textual field in a + WHERE clause.""" + + return exclusions.open() + + @property def selectone(self): """target driver must support the literal statement 'select 1'""" return exclusions.open() diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index ff8db5897..4791671f3 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -38,6 +38,8 @@ from ...util import u class _LiteralRoundTripFixture(object): + supports_whereclause = True + @testing.provide_metadata def _literal_round_trip(self, type_, input_, output, filter_=None): """test literal rendering """ @@ -49,33 +51,47 @@ class _LiteralRoundTripFixture(object): t = Table("t", self.metadata, Column("x", type_)) t.create() - for value in input_: - ins = ( - t.insert() - .values(x=literal(value)) - .compile( - dialect=testing.db.dialect, - compile_kwargs=dict(literal_binds=True), + with testing.db.connect() as conn: + for value in input_: + ins = ( + t.insert() + .values(x=literal(value)) + .compile( + dialect=testing.db.dialect, + compile_kwargs=dict(literal_binds=True), + ) ) - ) - testing.db.execute(ins) + conn.execute(ins) + + if self.supports_whereclause: + stmt = t.select().where(t.c.x == literal(value)) + else: + stmt = t.select() - for row in t.select().execute(): - value = row[0] - if filter_ is not None: - value = filter_(value) - assert value in output + stmt = stmt.compile( + dialect=testing.db.dialect, + compile_kwargs=dict(literal_binds=True), + ) + for row in conn.execute(stmt): + value = row[0] + if filter_ is not None: + value = filter_(value) + assert value in output class _UnicodeFixture(_LiteralRoundTripFixture): __requires__ = ("unicode_data",) data = u( - "Alors vous imaginez ma surprise, au lever du jour, " - "quand une drôle de petite voix m’a réveillé. Elle " - "disait: « S’il vous plaît… dessine-moi un mouton! »" + "Alors vous imaginez ma 🐍 surprise, au lever du jour, " + "quand une drôle de petite 🐍 voix m’a réveillé. Elle " + "disait: « S’il vous plaît… dessine-moi 🐍 un mouton! »" ) + @property + def supports_whereclause(self): + return config.requirements.expressions_against_unbounded_text.enabled + @classmethod def define_tables(cls, metadata): Table( @@ -122,6 +138,11 @@ class _UnicodeFixture(_LiteralRoundTripFixture): def test_literal(self): self._literal_round_trip(self.datatype, [self.data], [self.data]) + def test_literal_non_ascii(self): + self._literal_round_trip( + self.datatype, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")] + ) + class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest): __requires__ = ("unicode_data",) @@ -149,6 +170,10 @@ class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest): __requires__ = ("text_type",) __backend__ = True + @property + def supports_whereclause(self): + return config.requirements.expressions_against_unbounded_text.enabled + @classmethod def define_tables(cls, metadata): Table( @@ -177,6 +202,11 @@ class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest): def test_literal(self): self._literal_round_trip(Text, ["some text"], ["some text"]) + def test_literal_non_ascii(self): + self._literal_round_trip( + Text, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")] + ) + def test_literal_quoting(self): data = """some 'text' hey "hi there" that's text""" self._literal_round_trip(Text, [data], [data]) @@ -202,8 +232,15 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): foo.drop(config.db) def test_literal(self): + # note that in Python 3, this invokes the Unicode + # datatype for the literal part because all strings are unicode self._literal_round_trip(String(40), ["some text"], ["some text"]) + def test_literal_non_ascii(self): + self._literal_round_trip( + String(40), [util.u("réve🐍 illé")], [util.u("réve🐍 illé")] + ) + def test_literal_quoting(self): data = """some 'text' hey "hi there" that's text""" self._literal_round_trip(String(40), [data], [data]) @@ -864,8 +901,8 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): { "name": "r1", "data": { - util.u("réveillé"): util.u("réveillé"), - "data": {"k1": util.u("drôle")}, + util.u("réve🐍 illé"): util.u("réve🐍 illé"), + "data": {"k1": util.u("drôl🐍e")}, }, }, ) @@ -873,8 +910,8 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): eq_( conn.scalar(select([self.tables.data_table.c.data])), { - util.u("réveillé"): util.u("réveillé"), - "data": {"k1": util.u("drôle")}, + util.u("réve🐍 illé"): util.u("réve🐍 illé"), + "data": {"k1": util.u("drôl🐍e")}, }, ) |
