diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-12-17 17:46:09 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-12-17 17:46:09 -0500 |
| commit | fec03c88d659bf9a0b102dd328afac1ba3dc7f23 (patch) | |
| tree | 86dc3507dcaa89ed0fe76f7c7774f20c91e65ab3 | |
| parent | 653fcb892bd4680c97491cad70b86987db270208 (diff) | |
| download | sqlalchemy-fec03c88d659bf9a0b102dd328afac1ba3dc7f23.tar.gz | |
- make the json serializer and deserializer per-dialect, so that we are
compatible with psycopg2's per-connection/cursor approach. add round trip tests for
both native and non-native.
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/json.py | 33 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/psycopg2.py | 5 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_types.py | 91 |
4 files changed, 97 insertions, 37 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index a7f838009..3edc28fed 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1439,9 +1439,12 @@ class PGDialect(default.DefaultDialect): _backslash_escapes = True - def __init__(self, isolation_level=None, **kwargs): + def __init__(self, isolation_level=None, json_serializer=None, + json_deserializer=None, **kwargs): default.DefaultDialect.__init__(self, **kwargs) self.isolation_level = isolation_level + self._json_deserializer = json_deserializer + self._json_serializer = json_serializer def initialize(self, connection): super(PGDialect, self).initialize(connection) diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index 5b8ad68f5..7ba8b1abe 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -56,22 +56,25 @@ class JSON(sqltypes.TypeEngine): will be detected by the unit of work. See the example at :class:`.HSTORE` for a simple example involving a dictionary. + Custom serializers and deserializers are specified at the dialect level, + that is using :func:`.create_engine`. The reason for this is that when + using psycopg2, the DBAPI only allows serializers at the per-cursor + or per-connection level. E.g.:: + + engine = create_engine("postgresql://scott:tiger@localhost/test", + json_serializer=my_serialize_fn, + json_deserializer=my_deserialize_fn + ) + + When using the psycopg2 dialect, the json_deserializer is registered + against the database using ``psycopg2.extras.register_default_json``. + .. versionadded:: 0.9 """ __visit_name__ = 'JSON' - def __init__(self, json_serializer=None, json_deserializer=None): - if json_serializer: - self.json_serializer = json_serializer - else: - self.json_serializer = json.dumps - if json_deserializer: - self.json_deserializer = json_deserializer - else: - self.json_deserializer = json.loads - class comparator_factory(sqltypes.Concatenable.Comparator): """Define comparison operations for :class:`.JSON`.""" @@ -113,23 +116,25 @@ class JSON(sqltypes.TypeEngine): _adapt_expression(self, op, other_comparator) def bind_processor(self, dialect): + json_serializer = dialect._json_serializer or json.dumps if util.py2k: encoding = dialect.encoding def process(value): - return self.json_serializer(value).encode(encoding) + return json_serializer(value).encode(encoding) else: def process(value): - return self.json_serializer(value) + return json_serializer(value) return process def result_processor(self, dialect, coltype): + json_deserializer = dialect._json_deserializer or json.loads if util.py2k: encoding = dialect.encoding def process(value): - return self.json_deserializer(value.decode(encoding)) + return json_deserializer(value.decode(encoding)) else: def process(value): - return self.json_deserializer(value) + return json_deserializer(value) return process diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index ceb04b580..f5da8a711 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -428,6 +428,11 @@ class PGDialect_psycopg2(PGDialect): array_oid=array_oid) fns.append(on_connect) + if self.dbapi and self._json_deserializer: + def on_connect(conn): + extras.register_default_json(conn, loads=self._json_deserializer) + fns.append(on_connect) + if fns: def on_connect(conn): for fn in fns: diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 062c708a5..bcb3e1ebb 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -998,9 +998,8 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): ) def test_bind_serialize_default(self): - from sqlalchemy.engine import default - dialect = default.DefaultDialect() + dialect = postgresql.dialect() proc = self.test_table.c.hash.type._cached_bind_processor(dialect) eq_( proc(util.OrderedDict([("key1", "value1"), ("key2", "value2")])), @@ -1008,9 +1007,7 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): ) def test_bind_serialize_with_slashes_and_quotes(self): - from sqlalchemy.engine import default - - dialect = default.DefaultDialect() + dialect = postgresql.dialect() proc = self.test_table.c.hash.type._cached_bind_processor(dialect) eq_( proc({'\\"a': '\\"1'}), @@ -1018,9 +1015,7 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): ) def test_parse_error(self): - from sqlalchemy.engine import default - - dialect = default.DefaultDialect() + dialect = postgresql.dialect() proc = self.test_table.c.hash.type._cached_result_processor( dialect, None) assert_raises_message( @@ -1033,9 +1028,7 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): ) def test_result_deserialize_default(self): - from sqlalchemy.engine import default - - dialect = default.DefaultDialect() + dialect = postgresql.dialect() proc = self.test_table.c.hash.type._cached_result_processor( dialect, None) eq_( @@ -1044,9 +1037,7 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): ) def test_result_deserialize_with_slashes_and_quotes(self): - from sqlalchemy.engine import default - - dialect = default.DefaultDialect() + dialect = postgresql.dialect() proc = self.test_table.c.hash.type._cached_result_processor( dialect, None) eq_( @@ -1693,9 +1684,7 @@ class JSONTest(AssertsCompiledSQL, fixtures.TestBase): ) def test_bind_serialize_default(self): - from sqlalchemy.engine import default - - dialect = default.DefaultDialect() + dialect = postgresql.dialect() proc = self.test_table.c.test_column.type._cached_bind_processor(dialect) eq_( proc({"A": [1, 2, 3, True, False]}), @@ -1703,9 +1692,7 @@ class JSONTest(AssertsCompiledSQL, fixtures.TestBase): ) def test_result_deserialize_default(self): - from sqlalchemy.engine import default - - dialect = default.DefaultDialect() + dialect = postgresql.dialect() proc = self.test_table.c.test_column.type._cached_result_processor( dialect, None) eq_( @@ -1782,16 +1769,26 @@ class JSONRoundTripTest(fixtures.TablesTest): ) self._assert_data([{"k1": "r1v1", "k2": "r1v2"}]) - def _non_native_engine(self): + def _non_native_engine(self, json_serializer=None, json_deserializer=None): + if json_serializer is not None or json_deserializer is not None: + options = { + "json_serializer": json_serializer, + "json_deserializer": json_deserializer + } + else: + options = {} + if testing.against("postgresql+psycopg2"): from psycopg2.extras import register_default_json - engine = engines.testing_engine() + engine = engines.testing_engine(options=options) @event.listens_for(engine, "connect") def connect(dbapi_connection, connection_record): engine.dialect._has_native_json = False def pass_(value): return value register_default_json(dbapi_connection, loads=pass_) + elif options: + engine = engines.testing_engine(options=options) else: engine = testing.db engine.connect() @@ -1811,6 +1808,56 @@ class JSONRoundTripTest(fixtures.TablesTest): engine = self._non_native_engine() self._test_insert(engine) + + def _test_custom_serialize_deserialize(self, native): + import json + def loads(value): + value = json.loads(value) + value['x'] = value['x'] + '_loads' + return value + + def dumps(value): + value = dict(value) + value['x'] = 'dumps_y' + return json.dumps(value) + + if native: + engine = engines.testing_engine(options=dict( + json_serializer=dumps, + json_deserializer=loads + )) + else: + engine = self._non_native_engine( + json_serializer=dumps, + json_deserializer=loads + ) + + s = select([ + cast( + { + "key": "value", + "x": "q" + }, + JSON + ) + ]) + eq_( + engine.scalar(s), + { + "key": "value", + "x": "dumps_y_loads" + }, + ) + + @testing.only_on("postgresql+psycopg2") + def test_custom_native(self): + self._test_custom_serialize_deserialize(True) + + @testing.only_on("postgresql+psycopg2") + def test_custom_python(self): + self._test_custom_serialize_deserialize(False) + + @testing.only_on("postgresql+psycopg2") def test_criterion_native(self): engine = testing.db |
