diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-02-01 19:00:07 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-02-01 19:03:08 -0500 |
commit | b2189da65019ed2f44e77933a122619489319c5a (patch) | |
tree | 1e3bcd44586eccfeebd9bd56b4da9606edd07769 | |
parent | 02dbcfa88a8390ee2af4f76e47bca8f205ddeee5 (diff) | |
download | sqlalchemy-b2189da65019ed2f44e77933a122619489319c5a.tar.gz |
- Repaired support for Postgresql UUID types in conjunction with
the ARRAY type when using psycopg2. The psycopg2 dialect now
employs use of the psycopg2.extras.register_uuid() hook
so that UUID values are always passed to/from the DBAPI as
UUID() objects. The :paramref:`.UUID.as_uuid` flag is still
honored, except with psycopg2 we need to convert returned
UUID objects back into strings when this is disabled.
fixes #2940
-rw-r--r-- | doc/build/changelog/changelog_09.rst | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/psycopg2.py | 38 | ||||
-rw-r--r-- | test/dialect/postgresql/test_types.py | 34 |
3 files changed, 75 insertions, 9 deletions
diff --git a/doc/build/changelog/changelog_09.rst b/doc/build/changelog/changelog_09.rst index 2af1cd35f..10d003f09 100644 --- a/doc/build/changelog/changelog_09.rst +++ b/doc/build/changelog/changelog_09.rst @@ -16,6 +16,18 @@ .. change:: :tags: bug, postgresql + :tickets: 2940 + + Repaired support for Postgresql UUID types in conjunction with + the ARRAY type when using psycopg2. The psycopg2 dialect now + employs use of the psycopg2.extras.register_uuid() hook + so that UUID values are always passed to/from the DBAPI as + UUID() objects. The :paramref:`.UUID.as_uuid` flag is still + honored, except with psycopg2 we need to convert returned + UUID objects back into strings when this is disabled. + + .. change:: + :tags: bug, postgresql :pullreq: github:145 Added support for the :class:`postgresql.JSONB` datatype when diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 26e45fed2..4f1e04f20 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -312,10 +312,15 @@ from ... import types as sqltypes from .base import PGDialect, PGCompiler, \ PGIdentifierPreparer, PGExecutionContext, \ ENUM, ARRAY, _DECIMAL_TYPES, _FLOAT_TYPES,\ - _INT_TYPES + _INT_TYPES, UUID from .hstore import HSTORE from .json import JSON, JSONB +try: + from uuid import UUID as _python_UUID +except ImportError: + _python_UUID = None + logger = logging.getLogger('sqlalchemy.dialects.postgresql') @@ -388,6 +393,26 @@ class _PGJSONB(JSONB): else: return super(_PGJSONB, self).result_processor(dialect, coltype) + +class _PGUUID(UUID): + def bind_processor(self, dialect): + if not self.as_uuid and dialect.use_native_uuid: + nonetype = type(None) + + def process(value): + if value is not None: + value = _python_UUID(value) + return value + return process + + def result_processor(self, dialect, coltype): + if not self.as_uuid and dialect.use_native_uuid: + def process(value): + if value is not None: + value = str(value) + return value + return process + # When we're handed literal SQL, ensure it's a SELECT query. Since # 8.3, combining cursors and "FOR UPDATE" has been fine. SERVER_SIDE_CURSOR_RE = re.compile( @@ -488,18 +513,20 @@ class PGDialect_psycopg2(PGDialect): sqltypes.Enum: _PGEnum, # needs force_unicode HSTORE: _PGHStore, JSON: _PGJSON, - JSONB: _PGJSONB + JSONB: _PGJSONB, + UUID: _PGUUID } ) def __init__(self, server_side_cursors=False, use_native_unicode=True, client_encoding=None, - use_native_hstore=True, + use_native_hstore=True, use_native_uuid=True, **kwargs): PGDialect.__init__(self, **kwargs) self.server_side_cursors = server_side_cursors self.use_native_unicode = use_native_unicode self.use_native_hstore = use_native_hstore + self.use_native_uuid = use_native_uuid self.supports_unicode_binds = use_native_unicode self.client_encoding = client_encoding if self.dbapi and hasattr(self.dbapi, '__version__'): @@ -575,6 +602,11 @@ class PGDialect_psycopg2(PGDialect): self.set_isolation_level(conn, self.isolation_level) fns.append(on_connect) + if self.dbapi and self.use_native_uuid: + def on_connect(conn): + extras.register_uuid(None, conn) + fns.append(on_connect) + if self.dbapi and self.use_native_unicode: def on_connect(conn): extensions.register_type(extensions.UNICODE, conn) diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 866bc7d54..36f4fdc3f 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -1035,7 +1035,7 @@ class UUIDTest(fixtures.TestBase): import uuid self._test_round_trip( Table('utable', MetaData(), - Column('data', postgresql.UUID()) + Column('data', postgresql.UUID(as_uuid=False)) ), str(uuid.uuid4()), str(uuid.uuid4()) @@ -1057,13 +1057,32 @@ class UUIDTest(fixtures.TestBase): ) @testing.fails_on('postgresql+zxjdbc', - 'column "data" is of type uuid[] but expression is of type character varying') + 'column "data" is of type uuid[] but ' + 'expression is of type character varying') @testing.fails_on('postgresql+pg8000', 'No support for UUID type') def test_uuid_array(self): import uuid self._test_round_trip( - Table('utable', MetaData(), - Column('data', postgresql.ARRAY(postgresql.UUID())) + Table( + 'utable', MetaData(), + Column('data', postgresql.ARRAY(postgresql.UUID(as_uuid=True))) + ), + [uuid.uuid4(), uuid.uuid4()], + [uuid.uuid4(), uuid.uuid4()], + ) + + @testing.fails_on('postgresql+zxjdbc', + 'column "data" is of type uuid[] but ' + 'expression is of type character varying') + @testing.fails_on('postgresql+pg8000', 'No support for UUID type') + def test_uuid_string_array(self): + import uuid + self._test_round_trip( + Table( + 'utable', MetaData(), + Column( + 'data', + postgresql.ARRAY(postgresql.UUID(as_uuid=False))) ), [str(uuid.uuid4()), str(uuid.uuid4())], [str(uuid.uuid4()), str(uuid.uuid4())], @@ -1088,7 +1107,7 @@ class UUIDTest(fixtures.TestBase): def teardown(self): self.conn.close() - def _test_round_trip(self, utable, value1, value2): + def _test_round_trip(self, utable, value1, value2, exp_value2=None): utable.create(self.conn) self.conn.execute(utable.insert(), {'data': value1}) self.conn.execute(utable.insert(), {'data': value2}) @@ -1096,7 +1115,10 @@ class UUIDTest(fixtures.TestBase): select([utable.c.data]). where(utable.c.data != value1) ) - eq_(r.fetchone()[0], value2) + if exp_value2: + eq_(r.fetchone()[0], exp_value2) + else: + eq_(r.fetchone()[0], value2) eq_(r.fetchone(), None) |