diff options
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/__init__.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/json.py (renamed from lib/sqlalchemy/dialects/postgresql/pgjson.py) | 74 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/psycopg2.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 4 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_types.py | 107 |
5 files changed, 126 insertions, 67 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 728f1629f..cfe1ebce0 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -15,7 +15,7 @@ from .base import \ TSVECTOR from .constraints import ExcludeConstraint from .hstore import HSTORE, hstore -from .pgjson import JSON +from .json import JSON from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \ TSTZRANGE diff --git a/lib/sqlalchemy/dialects/postgresql/pgjson.py b/lib/sqlalchemy/dialects/postgresql/json.py index a29d0bbcc..5b8ad68f5 100644 --- a/lib/sqlalchemy/dialects/postgresql/pgjson.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -3,16 +3,16 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +from __future__ import absolute_import import json -from .base import ARRAY, ischema_names +from .base import ischema_names from ... import types as sqltypes -from ...sql import functions as sqlfunc from ...sql.operators import custom_op from ... import util -__all__ = ('JSON', 'json') +__all__ = ('JSON', ) class JSON(sqltypes.TypeEngine): @@ -39,21 +39,25 @@ class JSON(sqltypes.TypeEngine): * Index operations returning text (required for text comparison or casting):: - data_table.c.data.get_item_as_text('some key') == 'some value' + data_table.c.data.astext['some key'] == 'some value' * Path index operations:: - data_table.c.data.get_path("{key_1, key_2, ..., key_n}") + data_table.c.data[('key_1', 'key_2', ..., 'key_n')] * Path index operations returning text (required for text comparison or casting):: - data_table.c.data.get_path("{key_1, key_2, ..., key_n}") == 'some value' + data_table.c.data.astext[('key_1', 'key_2', ..., 'key_n')] == 'some value' - Please be aware that when used with the SQLAlchemy ORM, you will need to - replace the JSON object present on an attribute with a new object in order - for any changes to be properly persisted. + The :class:`.JSON` type, when used with the SQLAlchemy ORM, does not detect + in-place mutations to the structure. In order to detect these, the + :mod:`sqlalchemy.ext.mutable` extension must be used. This extension will + allow "in-place" changes to the datastructure to produce events which + will be detected by the unit of work. See the example at :class:`.HSTORE` + for a simple example involving a dictionary. .. versionadded:: 0.9 + """ __visit_name__ = 'JSON' @@ -71,31 +75,35 @@ class JSON(sqltypes.TypeEngine): class comparator_factory(sqltypes.Concatenable.Comparator): """Define comparison operations for :class:`.JSON`.""" + class _astext(object): + def __init__(self, parent): + self.parent = parent + + def __getitem__(self, other): + return self.parent.expr._get_item(other, True) + + def _get_item(self, other, astext): + if hasattr(other, '__iter__') and \ + not isinstance(other, util.string_types): + op = "#>" + other = "{%s}" % (", ".join(util.text_type(elem) for elem in other)) + else: + op = "->" + + if astext: + op += ">" + + # ops: ->, ->>, #>, #>> + return self.expr.op(op, precedence=5)(other) + def __getitem__(self, other): - """Text expression. Get the value at a given key.""" - # I'm choosing to return text here so the result can be cast, - # compared with strings, etc. - # - # The only downside to this is that you cannot dereference more - # than one level deep in json structures, though comparator - # support for multi-level dereference is lacking anyhow. - return self.expr.op('->', precedence=5)(other) - - def get_item_as_text(self, other): - """Text expression. Get the value at the given key as text. Use - this when you need to cast the type of the returned value.""" - return self.expr.op('->>', precedence=5)(other) - - def get_path(self, other): - """Text expression. Get the value at a given path. Paths are of - the form {key_1, key_2, ..., key_n}.""" - return self.expr.op('#>', precedence=5)(other) - - def get_path_as_text(self, other): - """Text expression. Get the value at a given path, as text. - Paths are of the form {key_1, key_2, ..., key_n}. Use this when - you need to cast the type of the returned value.""" - return self.expr.op('#>>', precedence=5)(other) + """Get the value at a given key.""" + + return self._get_item(other, False) + + @property + def astext(self): + return self._astext(self) def _adapt_expression(self, op, other_comparator): if isinstance(op, custom_op): diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 4a9248e5f..ceb04b580 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -179,7 +179,7 @@ from .base import PGDialect, PGCompiler, \ ENUM, ARRAY, _DECIMAL_TYPES, _FLOAT_TYPES,\ _INT_TYPES from .hstore import HSTORE -from .pgjson import JSON +from .json import JSON logger = logging.getLogger('sqlalchemy.dialects.postgresql') @@ -236,9 +236,7 @@ class _PGHStore(HSTORE): class _PGJSON(JSON): - # I've omitted the bind processor here because the method of serializing - # involves registering specific types to auto-serialize, and the adapter - # just a thin wrapper over json.dumps. + def result_processor(self, dialect, coltype): if dialect._has_native_json: return None diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 045056b42..69e365bd3 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1753,6 +1753,10 @@ class Cast(ColumnElement): """ self.type = type_api.to_instance(totype) self.clause = _literal_as_binds(clause, None) + if isinstance(self.clause, BindParameter) and self.clause.type._isnull: + self.clause = self.clause._clone() + self.clause.type = self.type + self.typeclause = TypeClause(self.type) def _copy_internals(self, clone=_clone, **kw): diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 19df131fd..5da2520f3 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -10,7 +10,8 @@ from sqlalchemy import Table, Column, select, MetaData, text, Integer, \ PrimaryKeyConstraint, DateTime, tuple_, Float, BigInteger, \ func, literal_column, literal, bindparam, cast, extract, \ SmallInteger, Enum, REAL, update, insert, Index, delete, \ - and_, Date, TypeDecorator, Time, Unicode, Interval, or_, Text + and_, Date, TypeDecorator, Time, Unicode, Interval, or_, Text, \ + type_coerce from sqlalchemy.orm import Session, mapper, aliased from sqlalchemy import exc, schema, types from sqlalchemy.dialects.postgresql import base as postgresql @@ -23,6 +24,8 @@ from sqlalchemy.testing.util import round_decimal from sqlalchemy.sql import table, column, operators import logging import re +from sqlalchemy import inspect +from sqlalchemy import event class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults): __only_on__ = 'postgresql' @@ -965,14 +968,7 @@ class UUIDTest(fixtures.TestBase): -class HStoreTest(fixtures.TestBase): - def _assert_sql(self, construct, expected): - dialect = postgresql.dialect() - compiled = str(construct.compile(dialect=dialect)) - compiled = re.sub(r'\s+', ' ', compiled) - expected = re.sub(r'\s+', ' ', expected) - eq_(compiled, expected) - +class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): def setup(self): metadata = MetaData() self.test_table = Table('test_table', metadata, @@ -983,7 +979,7 @@ class HStoreTest(fixtures.TestBase): def _test_where(self, whereclause, expected): stmt = select([self.test_table]).where(whereclause) - self._assert_sql( + self.assert_compile( stmt, "SELECT test_table.id, test_table.hash FROM test_table " "WHERE %s" % expected @@ -991,7 +987,7 @@ class HStoreTest(fixtures.TestBase): def _test_cols(self, colclause, expected, from_=True): stmt = select([colclause]) - self._assert_sql( + self.assert_compile( stmt, ( "SELECT %s" + @@ -1292,7 +1288,6 @@ class HStoreRoundTripTest(fixtures.TablesTest): return engine def test_reflect(self): - from sqlalchemy import inspect insp = inspect(testing.db) cols = insp.get_columns('data_table') assert isinstance(cols[2]['type'], HSTORE) @@ -1666,13 +1661,7 @@ class DateTimeTZRangeTests(_RangeTypeMixin, fixtures.TablesTest): return self.extras.DateTimeTZRange(*self.tstzs()) -class JSONTest(fixtures.TestBase): - def _assert_sql(self, construct, expected): - dialect = postgresql.dialect() - compiled = str(construct.compile(dialect=dialect)) - compiled = re.sub(r'\s+', ' ', compiled) - expected = re.sub(r'\s+', ' ', expected) - eq_(compiled, expected) +class JSONTest(AssertsCompiledSQL, fixtures.TestBase): def setup(self): metadata = MetaData() @@ -1684,7 +1673,7 @@ class JSONTest(fixtures.TestBase): def _test_where(self, whereclause, expected): stmt = select([self.test_table]).where(whereclause) - self._assert_sql( + self.assert_compile( stmt, "SELECT test_table.id, test_table.test_column FROM test_table " "WHERE %s" % expected @@ -1692,7 +1681,7 @@ class JSONTest(fixtures.TestBase): def _test_cols(self, colclause, expected, from_=True): stmt = select([colclause]) - self._assert_sql( + self.assert_compile( stmt, ( "SELECT %s" + @@ -1730,19 +1719,19 @@ class JSONTest(fixtures.TestBase): def test_where_path(self): self._test_where( - self.jsoncol.get_path('{"foo", 1}') == None, + self.jsoncol[("foo", 1)] == None, "(test_table.test_column #> %(test_column_1)s) IS NULL" ) def test_where_getitem_as_text(self): self._test_where( - self.jsoncol.get_item_as_text('bar') == None, + self.jsoncol.astext['bar'] == None, "(test_table.test_column ->> %(test_column_1)s) IS NULL" ) def test_where_path_as_text(self): self._test_where( - self.jsoncol.get_path_as_text('{"foo", 1}') == None, + self.jsoncol.astext[("foo", 1)] == None, "(test_table.test_column #>> %(test_column_1)s) IS NULL" ) @@ -1755,7 +1744,7 @@ class JSONTest(fixtures.TestBase): class JSONRoundTripTest(fixtures.TablesTest): - __only_on__ = 'postgresql' + __only_on__ = ('postgresql >= 9.3',) @classmethod def define_tables(cls, metadata): @@ -1792,14 +1781,20 @@ class JSONRoundTripTest(fixtures.TablesTest): def _non_native_engine(self): if testing.against("postgresql+psycopg2"): + from psycopg2.extras import register_default_json engine = engines.testing_engine() + @event.listens_for(engine, "connect") + def connect(dbapi_connection, connection_record): + engine.dialect._has_native_json = False + def pass_(value): + return value + register_default_json(dbapi_connection, loads=pass_) else: engine = testing.db engine.connect() return engine def test_reflect(self): - from sqlalchemy import inspect insp = inspect(testing.db) cols = insp.get_columns('data_table') assert isinstance(cols[2]['type'], JSON) @@ -1830,7 +1825,7 @@ class JSONRoundTripTest(fixtures.TablesTest): data_table = self.tables.data_table result = engine.execute( select([data_table.c.data]).where( - data_table.c.data.get_path_as_text('{k1}') == 'r3v1' + data_table.c.data.astext[('k1',)] == 'r3v1' ) ).first() eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},)) @@ -1840,7 +1835,7 @@ class JSONRoundTripTest(fixtures.TablesTest): self._fixture_data(engine) data_table = self.tables.data_table result = engine.execute( - select([data_table.c.data.get_item_as_text('k1')]) + select([data_table.c.data.astext['k1']]) ).first() assert isinstance(result[0], basestring) @@ -1848,7 +1843,61 @@ class JSONRoundTripTest(fixtures.TablesTest): data_table = self.tables.data_table result = engine.execute( select([data_table.c.data]).where( - data_table.c.data.get_item_as_text('k1') == 'r3v1' + data_table.c.data.astext['k1'] == 'r3v1' ) ).first() eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},)) + + def _test_fixed_round_trip(self, engine): + s = select([ + cast( + { + "key": "value", + "key2": {"k1": "v1", "k2": "v2"} + }, + JSON + ) + ]) + eq_( + engine.scalar(s), + { + "key": "value", + "key2": {"k1": "v1", "k2": "v2"} + }, + ) + + def test_fixed_round_trip_python(self): + engine = self._non_native_engine() + self._test_fixed_round_trip(engine) + + @testing.only_on("postgresql+psycopg2") + def test_fixed_round_trip_native(self): + engine = testing.db + self._test_fixed_round_trip(engine) + + def _test_unicode_round_trip(self, engine): + s = select([ + cast( + { + util.u('réveillé'): util.u('réveillé'), + "data": {"k1": util.u('drôle')} + }, + JSON + ) + ]) + eq_( + engine.scalar(s), + { + util.u('réveillé'): util.u('réveillé'), + "data": {"k1": util.u('drôle')} + }, + ) + + def test_unicode_round_trip_python(self): + engine = self._non_native_engine() + self._test_unicode_round_trip(engine) + + @testing.only_on("postgresql+psycopg2") + def test_unicode_round_trip_native(self): + engine = testing.db + self._test_unicode_round_trip(engine) |
