summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2013-12-17 17:46:09 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2013-12-17 17:46:09 -0500
commitfec03c88d659bf9a0b102dd328afac1ba3dc7f23 (patch)
tree86dc3507dcaa89ed0fe76f7c7774f20c91e65ab3
parent653fcb892bd4680c97491cad70b86987db270208 (diff)
downloadsqlalchemy-fec03c88d659bf9a0b102dd328afac1ba3dc7f23.tar.gz
- make the json serializer and deserializer per-dialect, so that we are
compatible with psycopg2's per-connection/cursor approach. add round trip tests for both native and non-native.
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py5
-rw-r--r--lib/sqlalchemy/dialects/postgresql/json.py33
-rw-r--r--lib/sqlalchemy/dialects/postgresql/psycopg2.py5
-rw-r--r--test/dialect/postgresql/test_types.py91
4 files changed, 97 insertions, 37 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index a7f838009..3edc28fed 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -1439,9 +1439,12 @@ class PGDialect(default.DefaultDialect):
_backslash_escapes = True
- def __init__(self, isolation_level=None, **kwargs):
+ def __init__(self, isolation_level=None, json_serializer=None,
+ json_deserializer=None, **kwargs):
default.DefaultDialect.__init__(self, **kwargs)
self.isolation_level = isolation_level
+ self._json_deserializer = json_deserializer
+ self._json_serializer = json_serializer
def initialize(self, connection):
super(PGDialect, self).initialize(connection)
diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py
index 5b8ad68f5..7ba8b1abe 100644
--- a/lib/sqlalchemy/dialects/postgresql/json.py
+++ b/lib/sqlalchemy/dialects/postgresql/json.py
@@ -56,22 +56,25 @@ class JSON(sqltypes.TypeEngine):
will be detected by the unit of work. See the example at :class:`.HSTORE`
for a simple example involving a dictionary.
+ Custom serializers and deserializers are specified at the dialect level,
+ that is using :func:`.create_engine`. The reason for this is that when
+ using psycopg2, the DBAPI only allows serializers at the per-cursor
+ or per-connection level. E.g.::
+
+ engine = create_engine("postgresql://scott:tiger@localhost/test",
+ json_serializer=my_serialize_fn,
+ json_deserializer=my_deserialize_fn
+ )
+
+ When using the psycopg2 dialect, the json_deserializer is registered
+ against the database using ``psycopg2.extras.register_default_json``.
+
.. versionadded:: 0.9
"""
__visit_name__ = 'JSON'
- def __init__(self, json_serializer=None, json_deserializer=None):
- if json_serializer:
- self.json_serializer = json_serializer
- else:
- self.json_serializer = json.dumps
- if json_deserializer:
- self.json_deserializer = json_deserializer
- else:
- self.json_deserializer = json.loads
-
class comparator_factory(sqltypes.Concatenable.Comparator):
"""Define comparison operations for :class:`.JSON`."""
@@ -113,23 +116,25 @@ class JSON(sqltypes.TypeEngine):
_adapt_expression(self, op, other_comparator)
def bind_processor(self, dialect):
+ json_serializer = dialect._json_serializer or json.dumps
if util.py2k:
encoding = dialect.encoding
def process(value):
- return self.json_serializer(value).encode(encoding)
+ return json_serializer(value).encode(encoding)
else:
def process(value):
- return self.json_serializer(value)
+ return json_serializer(value)
return process
def result_processor(self, dialect, coltype):
+ json_deserializer = dialect._json_deserializer or json.loads
if util.py2k:
encoding = dialect.encoding
def process(value):
- return self.json_deserializer(value.decode(encoding))
+ return json_deserializer(value.decode(encoding))
else:
def process(value):
- return self.json_deserializer(value)
+ return json_deserializer(value)
return process
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
index ceb04b580..f5da8a711 100644
--- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
@@ -428,6 +428,11 @@ class PGDialect_psycopg2(PGDialect):
array_oid=array_oid)
fns.append(on_connect)
+ if self.dbapi and self._json_deserializer:
+ def on_connect(conn):
+ extras.register_default_json(conn, loads=self._json_deserializer)
+ fns.append(on_connect)
+
if fns:
def on_connect(conn):
for fn in fns:
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index 062c708a5..bcb3e1ebb 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -998,9 +998,8 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
)
def test_bind_serialize_default(self):
- from sqlalchemy.engine import default
- dialect = default.DefaultDialect()
+ dialect = postgresql.dialect()
proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
eq_(
proc(util.OrderedDict([("key1", "value1"), ("key2", "value2")])),
@@ -1008,9 +1007,7 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
)
def test_bind_serialize_with_slashes_and_quotes(self):
- from sqlalchemy.engine import default
-
- dialect = default.DefaultDialect()
+ dialect = postgresql.dialect()
proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
eq_(
proc({'\\"a': '\\"1'}),
@@ -1018,9 +1015,7 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
)
def test_parse_error(self):
- from sqlalchemy.engine import default
-
- dialect = default.DefaultDialect()
+ dialect = postgresql.dialect()
proc = self.test_table.c.hash.type._cached_result_processor(
dialect, None)
assert_raises_message(
@@ -1033,9 +1028,7 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
)
def test_result_deserialize_default(self):
- from sqlalchemy.engine import default
-
- dialect = default.DefaultDialect()
+ dialect = postgresql.dialect()
proc = self.test_table.c.hash.type._cached_result_processor(
dialect, None)
eq_(
@@ -1044,9 +1037,7 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
)
def test_result_deserialize_with_slashes_and_quotes(self):
- from sqlalchemy.engine import default
-
- dialect = default.DefaultDialect()
+ dialect = postgresql.dialect()
proc = self.test_table.c.hash.type._cached_result_processor(
dialect, None)
eq_(
@@ -1693,9 +1684,7 @@ class JSONTest(AssertsCompiledSQL, fixtures.TestBase):
)
def test_bind_serialize_default(self):
- from sqlalchemy.engine import default
-
- dialect = default.DefaultDialect()
+ dialect = postgresql.dialect()
proc = self.test_table.c.test_column.type._cached_bind_processor(dialect)
eq_(
proc({"A": [1, 2, 3, True, False]}),
@@ -1703,9 +1692,7 @@ class JSONTest(AssertsCompiledSQL, fixtures.TestBase):
)
def test_result_deserialize_default(self):
- from sqlalchemy.engine import default
-
- dialect = default.DefaultDialect()
+ dialect = postgresql.dialect()
proc = self.test_table.c.test_column.type._cached_result_processor(
dialect, None)
eq_(
@@ -1782,16 +1769,26 @@ class JSONRoundTripTest(fixtures.TablesTest):
)
self._assert_data([{"k1": "r1v1", "k2": "r1v2"}])
- def _non_native_engine(self):
+ def _non_native_engine(self, json_serializer=None, json_deserializer=None):
+ if json_serializer is not None or json_deserializer is not None:
+ options = {
+ "json_serializer": json_serializer,
+ "json_deserializer": json_deserializer
+ }
+ else:
+ options = {}
+
if testing.against("postgresql+psycopg2"):
from psycopg2.extras import register_default_json
- engine = engines.testing_engine()
+ engine = engines.testing_engine(options=options)
@event.listens_for(engine, "connect")
def connect(dbapi_connection, connection_record):
engine.dialect._has_native_json = False
def pass_(value):
return value
register_default_json(dbapi_connection, loads=pass_)
+ elif options:
+ engine = engines.testing_engine(options=options)
else:
engine = testing.db
engine.connect()
@@ -1811,6 +1808,56 @@ class JSONRoundTripTest(fixtures.TablesTest):
engine = self._non_native_engine()
self._test_insert(engine)
+
+ def _test_custom_serialize_deserialize(self, native):
+ import json
+ def loads(value):
+ value = json.loads(value)
+ value['x'] = value['x'] + '_loads'
+ return value
+
+ def dumps(value):
+ value = dict(value)
+ value['x'] = 'dumps_y'
+ return json.dumps(value)
+
+ if native:
+ engine = engines.testing_engine(options=dict(
+ json_serializer=dumps,
+ json_deserializer=loads
+ ))
+ else:
+ engine = self._non_native_engine(
+ json_serializer=dumps,
+ json_deserializer=loads
+ )
+
+ s = select([
+ cast(
+ {
+ "key": "value",
+ "x": "q"
+ },
+ JSON
+ )
+ ])
+ eq_(
+ engine.scalar(s),
+ {
+ "key": "value",
+ "x": "dumps_y_loads"
+ },
+ )
+
+ @testing.only_on("postgresql+psycopg2")
+ def test_custom_native(self):
+ self._test_custom_serialize_deserialize(True)
+
+ @testing.only_on("postgresql+psycopg2")
+ def test_custom_python(self):
+ self._test_custom_serialize_deserialize(False)
+
+
@testing.only_on("postgresql+psycopg2")
def test_criterion_native(self):
engine = testing.db