summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_13/4285.rst15
-rw-r--r--lib/sqlalchemy/orm/mapper.py19
-rw-r--r--lib/sqlalchemy/orm/persistence.py11
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py22
-rw-r--r--lib/sqlalchemy/sql/type_api.py14
-rw-r--r--test/orm/test_mapper.py2
-rw-r--r--test/orm/test_naturalpks.py72
-rw-r--r--test/orm/test_unitofwork.py43
-rw-r--r--test/sql/test_types.py39
9 files changed, 231 insertions, 6 deletions
diff --git a/doc/build/changelog/unreleased_13/4285.rst b/doc/build/changelog/unreleased_13/4285.rst
new file mode 100644
index 000000000..1049a5882
--- /dev/null
+++ b/doc/build/changelog/unreleased_13/4285.rst
@@ -0,0 +1,15 @@
+.. change::
+ :tags: usecase, orm
+ :tickets: 4285
+
+ Added support for the use of an :class:`.Enum` datatype using Python
+ pep-435 enumeration objects as values for use as a primary key column
+ mapped by the ORM. As these values are not inherently sortable, as
+ required by the ORM for primary keys, a new
+ :attr:`.TypeEngine.sort_key_function` attribute is added to the typing
+ system which allows any SQL type to implement a sorting for Python objects
+ of its type which is consulted by the unit of work. The :class:`.Enum`
+ type then defines this using the database value of a given enumeration.
+ The sorting scheme can be also be redefined by passing a callable to the
+ :paramref:`.Enum.sort_key_function` parameter. Pull request courtesy
+ Nicolas Caniart.
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 5e8d25647..07fd9f3fb 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -2749,6 +2749,25 @@ class Mapper(InspectionAttr):
return identity_key[1]
@_memoized_configured_property
+ def _persistent_sortkey_fn(self):
+ key_fns = [col.type.sort_key_function for col in self.primary_key]
+
+ if set(key_fns).difference([None]):
+
+ def key(state):
+ return tuple(
+ key_fn(val) if key_fn is not None else val
+ for key_fn, val in zip(key_fns, state.key[1])
+ )
+
+ else:
+
+ def key(state):
+ return state.key[1]
+
+ return key
+
+ @_memoized_configured_property
def _identity_key_props(self):
return [self._columntoproperty[col] for col in self.primary_key]
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index fb25d2405..68052dfdd 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -196,7 +196,7 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
# if batch=false, call _save_obj separately for each object
if not single and not base_mapper.batch:
- for state in _sort_states(states):
+ for state in _sort_states(base_mapper, states):
save_obj(base_mapper, [state], uowtransaction, single=True)
return
@@ -1607,7 +1607,7 @@ def _connections_for_states(base_mapper, uowtransaction, states):
connection = uowtransaction.transaction.connection(base_mapper)
connection_callable = None
- for state in _sort_states(states):
+ for state in _sort_states(base_mapper, states):
if connection_callable:
connection = connection_callable(base_mapper, state.obj())
@@ -1625,12 +1625,15 @@ def _cached_connection_dict(base_mapper):
)
-def _sort_states(states):
+def _sort_states(mapper, states):
pending = set(states)
persistent = set(s for s in pending if s.key is not None)
pending.difference_update(persistent)
+
try:
- persistent_sorted = sorted(persistent, key=lambda q: q.key[1])
+ persistent_sorted = sorted(
+ persistent, key=mapper._persistent_sortkey_fn
+ )
except TypeError as err:
raise sa_exc.InvalidRequestError(
"Could not sort objects by primary key; primary key "
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 631352ceb..fd15d7c79 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -20,6 +20,7 @@ from . import operators
from . import roles
from . import type_api
from .base import _bind_or_error
+from .base import NO_ARG
from .base import SchemaEventTarget
from .elements import _defer_name
from .elements import quoted_name
@@ -1356,6 +1357,19 @@ class Enum(Emulated, String, SchemaType):
.. versionadded:: 1.2.3
+ :param sort_key_function: a Python callable which may be used as the
+ "key" argument in the Python ``sorted()`` built-in. The SQLAlchemy
+ ORM requires that primary key columns which are mapped must
+ be sortable in some way. When using an unsortable enumeration
+ object such as a Python 3 ``Enum`` object, this parameter may be
+ used to set a default sort key function for the objects. By
+ default, the database value of the enumeration is used as the
+ sorting function.
+
+ .. versionadded:: 1.3.8
+
+
+
"""
self._enum_init(enums, kw)
@@ -1377,6 +1391,7 @@ class Enum(Emulated, String, SchemaType):
self.native_enum = kw.pop("native_enum", True)
self.create_constraint = kw.pop("create_constraint", True)
self.values_callable = kw.pop("values_callable", None)
+ self._sort_key_function = kw.pop("sort_key_function", NO_ARG)
values, objects = self._parse_into_values(enums, kw)
self._setup_for_values(values, objects, kw)
@@ -1450,6 +1465,13 @@ class Enum(Emulated, String, SchemaType):
)
@property
+ def sort_key_function(self):
+ if self._sort_key_function is NO_ARG:
+ return self._db_value_for_elem
+ else:
+ return self._sort_key_function
+
+ @property
def native(self):
return self.native_enum
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index 9838f0d5a..11407ad2e 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -135,6 +135,16 @@ class TypeEngine(Visitable):
"""
+ sort_key_function = None
+ """A sorting function that can be passed as the key to sorted.
+
+ The default value of ``None`` indicates that the values stored by
+ this type are self-sorting.
+
+ .. versionadded:: 1.3.8
+
+ """
+
should_evaluate_none = False
"""If True, the Python constant ``None`` is considered to be handled
explicitly by this type.
@@ -1354,6 +1364,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
"""
return self.impl.compare_values(x, y)
+ @property
+ def sort_key_function(self):
+ return self.impl.sort_key_function
+
def __repr__(self):
return util.generic_repr(self, to_inspect=self.impl)
diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py
index ceec344d9..93346b32f 100644
--- a/test/orm/test_mapper.py
+++ b/test/orm/test_mapper.py
@@ -346,7 +346,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
states[4].insert_order = DontCompareMeToString(1)
states[2].insert_order = DontCompareMeToString(3)
eq_(
- _sort_states(states),
+ _sort_states(m, states),
[states[4], states[3], states[0], states[1], states[2]],
)
diff --git a/test/orm/test_naturalpks.py b/test/orm/test_naturalpks.py
index 6108a28c4..9a25a618d 100644
--- a/test/orm/test_naturalpks.py
+++ b/test/orm/test_naturalpks.py
@@ -3,11 +3,15 @@ Primary key changing capabilities and passive/non-passive cascading updates.
"""
+import itertools
+
import sqlalchemy as sa
+from sqlalchemy import bindparam
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import testing
+from sqlalchemy import TypeDecorator
from sqlalchemy.orm import create_session
from sqlalchemy.orm import mapper
from sqlalchemy.orm import relationship
@@ -1754,6 +1758,74 @@ class JoinedInheritanceTest(fixtures.MappedTest):
)
+class UnsortablePKTest(fixtures.MappedTest):
+ """Test integration with TypeEngine.sort_key_function"""
+
+ class HashableDict(dict):
+ def __hash__(self):
+ return hash((self["x"], self["y"]))
+
+ @classmethod
+ def define_tables(cls, metadata):
+ class MyUnsortable(TypeDecorator):
+ impl = String(10)
+
+ def process_bind_param(self, value, dialect):
+ return "%s,%s" % (value["x"], value["y"])
+
+ def process_result_value(self, value, dialect):
+ rec = value.split(",")
+ return cls.HashableDict({"x": rec[0], "y": rec[1]})
+
+ def sort_key_function(self, value):
+ return (value["x"], value["y"])
+
+ Table(
+ "data",
+ metadata,
+ Column("info", MyUnsortable(), primary_key=True),
+ Column("int_value", Integer),
+ )
+
+ @classmethod
+ def setup_classes(cls):
+ class Data(cls.Comparable):
+ pass
+
+ @classmethod
+ def setup_mappers(cls):
+ mapper(cls.classes.Data, cls.tables.data)
+
+ def test_updates_sorted(self):
+ Data = self.classes.Data
+ s = Session()
+
+ s.add_all(
+ [
+ Data(info=self.HashableDict(x="a", y="b")),
+ Data(info=self.HashableDict(x="a", y="a")),
+ Data(info=self.HashableDict(x="b", y="b")),
+ Data(info=self.HashableDict(x="b", y="a")),
+ ]
+ )
+ s.commit()
+
+ aa, ab, ba, bb = s.query(Data).order_by(Data.info).all()
+
+ counter = itertools.count()
+ ab.int_value = bindparam(key=None, callable_=lambda: next(counter))
+ ba.int_value = bindparam(key=None, callable_=lambda: next(counter))
+ bb.int_value = bindparam(key=None, callable_=lambda: next(counter))
+ aa.int_value = bindparam(key=None, callable_=lambda: next(counter))
+
+ s.commit()
+
+ eq_(
+ s.query(Data.int_value).order_by(Data.info).all(),
+ [(0,), (1,), (2,), (3,)],
+ )
+
+
class JoinedInheritancePKOnFKTest(fixtures.MappedTest):
"""Test cascades of pk->non-pk/fk on joined table inh."""
diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py
index 13c5907a4..6185c4a51 100644
--- a/test/orm/test_unitofwork.py
+++ b/test/orm/test_unitofwork.py
@@ -15,12 +15,14 @@ from sqlalchemy import literal_column
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import testing
+from sqlalchemy.inspection import inspect
from sqlalchemy.orm import column_property
from sqlalchemy.orm import create_session
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm import mapper
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
+from sqlalchemy.orm.persistence import _sort_states
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
@@ -3398,6 +3400,7 @@ class EnsurePKSortableTest(fixtures.MappedTest):
two = MySortableEnum("two", 2)
three = MyNotSortableEnum("three", 3)
four = MyNotSortableEnum("four", 4)
+ five = MyNotSortableEnum("five", 5)
@classmethod
def define_tables(cls, metadata):
@@ -3411,10 +3414,25 @@ class EnsurePKSortableTest(fixtures.MappedTest):
Table(
"t2",
metadata,
- Column("id", Enum(cls.MyNotSortableEnum), primary_key=True),
+ Column(
+ "id",
+ Enum(cls.MyNotSortableEnum, sort_key_function=None),
+ primary_key=True,
+ ),
Column("data", String(10)),
)
+ Table(
+ "t3",
+ metadata,
+ Column("id", Enum(cls.MyNotSortableEnum), primary_key=True),
+ Column("value", Integer),
+ )
+
+ @staticmethod
+ def sort_enum_key_value(value):
+ return value.value
+
@classmethod
def setup_classes(cls):
class T1(cls.Basic):
@@ -3423,10 +3441,15 @@ class EnsurePKSortableTest(fixtures.MappedTest):
class T2(cls.Basic):
pass
+ class T3(cls.Basic):
+ def __str__(self):
+ return "T3(id={})".format(self.id)
+
@classmethod
def setup_mappers(cls):
mapper(cls.classes.T1, cls.tables.t1)
mapper(cls.classes.T2, cls.tables.t2)
+ mapper(cls.classes.T3, cls.tables.t3)
def test_exception_persistent_flush_py3k(self):
s = Session()
@@ -3459,3 +3482,21 @@ class EnsurePKSortableTest(fixtures.MappedTest):
a.data = "bar"
b.data = "foo"
s.commit()
+
+ def test_pep435_custom_sort_key(self):
+ s = Session()
+
+ a = self.classes.T3(id=self.three, value=1)
+ b = self.classes.T3(id=self.four, value=2)
+ s.add_all([a, b])
+ s.commit()
+
+ c = self.classes.T3(id=self.five, value=0)
+ s.add(c)
+
+ states = [o._sa_instance_state for o in [b, a, c]]
+ eq_(
+ _sort_states(inspect(self.classes.T3), states),
+ # pending come first, then "four" < "three"
+ [o._sa_instance_state for o in [c, b, a]],
+ )
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index a5c9313f8..e3d2134b7 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -1658,6 +1658,45 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
[(1, "two"), (2, "two"), (3, "one")],
)
+ def test_pep435_default_sort_key(self):
+ one, two, a_member, b_member = (
+ self.one,
+ self.two,
+ self.a_member,
+ self.b_member,
+ )
+ typ = Enum(self.SomeEnum)
+
+ is_(typ.sort_key_function.__func__, typ._db_value_for_elem.__func__)
+
+ eq_(
+ sorted([two, one, a_member, b_member], key=typ.sort_key_function),
+ [a_member, b_member, one, two],
+ )
+
+ def test_pep435_custom_sort_key(self):
+ one, two, a_member, b_member = (
+ self.one,
+ self.two,
+ self.a_member,
+ self.b_member,
+ )
+
+ def sort_enum_key_value(value):
+ return str(value.value)
+
+ typ = Enum(self.SomeEnum, sort_key_function=sort_enum_key_value)
+ is_(typ.sort_key_function, sort_enum_key_value)
+
+ eq_(
+ sorted([two, one, a_member, b_member], key=typ.sort_key_function),
+ [one, two, a_member, b_member],
+ )
+
+ def test_pep435_no_sort_key(self):
+ typ = Enum(self.SomeEnum, sort_key_function=None)
+ is_(typ.sort_key_function, None)
+
def test_pep435_enum_round_trip(self):
stdlib_enum_table = self.tables["stdlib_enum_table"]