summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2019-10-17 13:09:24 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2019-10-20 20:49:03 -0400
commited553fffd65a063d6dbdb3770d1fa0124bd55e23 (patch)
tree59ab8a457b3ed82cb7647b7da1b94b4ce2a815e1
parent528782d1c356445f17cea857ef0974e074c51d60 (diff)
downloadsqlalchemy-ed553fffd65a063d6dbdb3770d1fa0124bd55e23.tar.gz
Implement facade for pytest parametrize, fixtures, classlevel
Add factilities to implement pytest.mark.parametrize and pytest.fixtures patterns, which largely resemble things we are already doing. Ensure a facade is used, so that the test suite remains independent of py.test, but also tailors the functions to the more limited scope in which we are using them. Additionally, create a class-based version that works from the same facade. Several old polymorphic tests as well as two of the sql test are refactored to use the new features. Change-Id: I6ef8af1dafff92534313016944d447f9439856cf References: #4896
-rw-r--r--lib/sqlalchemy/testing/__init__.py2
-rw-r--r--lib/sqlalchemy/testing/config.py72
-rw-r--r--lib/sqlalchemy/testing/plugin/plugin_base.py50
-rw-r--r--lib/sqlalchemy/testing/plugin/pytestplugin.py151
-rw-r--r--test/aaa_profiling/test_memusage.py2
-rw-r--r--test/orm/inheritance/test_abc_polymorphic.py154
-rw-r--r--test/orm/inheritance/test_assorted_poly.py241
-rw-r--r--test/orm/inheritance/test_magazine.py228
-rw-r--r--test/orm/inheritance/test_poly_persistence.py228
-rw-r--r--test/orm/test_descriptor.py12
-rw-r--r--test/sql/test_operators.py315
-rw-r--r--test/sql/test_types.py433
12 files changed, 1041 insertions, 847 deletions
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
index 2b8158fbb..4f28461e3 100644
--- a/lib/sqlalchemy/testing/__init__.py
+++ b/lib/sqlalchemy/testing/__init__.py
@@ -32,7 +32,9 @@ from .assertions import ne_ # noqa
from .assertions import not_in_ # noqa
from .assertions import startswith_ # noqa
from .assertions import uses_deprecated # noqa
+from .config import combinations # noqa
from .config import db # noqa
+from .config import fixture # noqa
from .config import requirements as requires # noqa
from .exclusions import _is_excluded # noqa
from .exclusions import _server_version # noqa
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
index f94c5b308..87bbc6a0f 100644
--- a/lib/sqlalchemy/testing/config.py
+++ b/lib/sqlalchemy/testing/config.py
@@ -6,7 +6,6 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import collections
-from unittest import SkipTest as _skip_test_exception
requirements = None
db = None
@@ -17,6 +16,75 @@ test_schema = None
test_schema_2 = None
_current = None
+_fixture_functions = None # installed by plugin_base
+
+
+def combinations(*comb, **kw):
+ r"""Deliver multiple versions of a test based on positional combinations.
+
+ This is a facade over pytest.mark.parametrize.
+
+
+ :param \*comb: argument combinations. These are tuples that will be passed
+ positionally to the decorated function.
+
+ :param argnames: optional list of argument names. These are the names
+ of the arguments in the test function that correspond to the entries
+ in each argument tuple. pytest.mark.parametrize requires this, however
+ the combinations function will derive it automatically if not present
+ by using ``inspect.getfullargspec(fn).args[1:]``. Note this assumes the
+ first argument is "self" which is discarded.
+
+ :param id\_: optional id template. This is a string template that
+ describes how the "id" for each parameter set should be defined, if any.
+ The number of characters in the template should match the number of
+ entries in each argument tuple. Each character describes how the
+ corresponding entry in the argument tuple should be handled, as far as
+ whether or not it is included in the arguments passed to the function, as
+ well as if it is included in the tokens used to create the id of the
+ parameter set.
+
+ If omitted, the argment combinations are passed to parametrize as is. If
+ passed, each argument combination is turned into a pytest.param() object,
+ mapping the elements of the argument tuple to produce an id based on a
+ character value in the same position within the string template using the
+ following scheme::
+
+ i - the given argument is a string that is part of the id only, don't
+ pass it as an argument
+
+ n - the given argument should be passed and it should be added to the
+ id by calling the .__name__ attribute
+
+ r - the given argument should be passed and it should be added to the
+ id by calling repr()
+
+ s- the given argument should be passed and it should be added to the
+ id by calling str()
+
+ e.g.::
+
+ @testing.combinations(
+ (operator.eq, "eq"),
+ (operator.ne, "ne"),
+ (operator.gt, "gt"),
+ (operator.lt, "lt"),
+ id_="na"
+ )
+ def test_operator(self, opfunc, name):
+ pass
+
+ The above combination will call ``.__name__`` on the first member of
+ each tuple and use that as the "id" to pytest.param().
+
+
+ """
+ return _fixture_functions.combinations(*comb, **kw)
+
+
+def fixture(*arg, **kw):
+ return _fixture_functions.fixture(*arg, **kw)
+
class Config(object):
def __init__(self, db, db_opts, options, file_config):
@@ -94,4 +162,4 @@ class Config(object):
def skip_test(msg):
- raise _skip_test_exception(msg)
+ raise _fixture_functions.skip_test_exception(msg)
diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py
index 859d1d779..a2f969a66 100644
--- a/lib/sqlalchemy/testing/plugin/plugin_base.py
+++ b/lib/sqlalchemy/testing/plugin/plugin_base.py
@@ -16,6 +16,7 @@ is py.test.
from __future__ import absolute_import
+import abc
import re
import sys
@@ -24,8 +25,15 @@ py3k = sys.version_info >= (3, 0)
if py3k:
import configparser
+
+ ABC = abc.ABC
else:
import ConfigParser as configparser
+ import collections as collections_abc # noqa
+
+ class ABC(object):
+ __metaclass__ = abc.ABCMeta
+
# late imports
fixtures = None
@@ -238,14 +246,6 @@ def set_coverage_flag(value):
options.has_coverage = value
-_skip_test_exception = None
-
-
-def set_skip_test(exc):
- global _skip_test_exception
- _skip_test_exception = exc
-
-
def post_begin():
"""things to set up later, once we know coverage is running."""
# Lazy setup of other options (post coverage)
@@ -331,10 +331,10 @@ def _monkeypatch_cdecimal(options, file_config):
@post
-def _init_skiptest(options, file_config):
+def _init_symbols(options, file_config):
from sqlalchemy.testing import config
- config._skip_test_exception = _skip_test_exception
+ config._fixture_functions = _fixture_fn_class()
@post
@@ -486,10 +486,10 @@ def _setup_profiling(options, file_config):
)
-def want_class(cls):
+def want_class(name, cls):
if not issubclass(cls, fixtures.TestBase):
return False
- elif cls.__name__.startswith("_"):
+ elif name.startswith("_"):
return False
elif (
config.options.backend_only
@@ -711,3 +711,29 @@ def _do_skips(cls):
def _setup_config(config_obj, ctx):
config._current.push(config_obj, testing)
+
+
+class FixtureFunctions(ABC):
+ @abc.abstractmethod
+ def skip_test_exception(self, *arg, **kw):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def combinations(self, *args, **kw):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def param_ident(self, *args, **kw):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def fixture(self, fn):
+ raise NotImplementedError()
+
+
+_fixture_fn_class = None
+
+
+def set_fixture_functions(fixture_fn_class):
+ global _fixture_fn_class
+ _fixture_fn_class = fixture_fn_class
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
index e0335c135..5d91db5d7 100644
--- a/lib/sqlalchemy/testing/plugin/pytestplugin.py
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -8,7 +8,11 @@ except ImportError:
import argparse
import collections
import inspect
+import itertools
+import operator
import os
+import re
+import sys
import pytest
@@ -87,7 +91,7 @@ def pytest_configure(config):
bool(getattr(config.option, "cov_source", False))
)
- plugin_base.set_skip_test(pytest.skip.Exception)
+ plugin_base.set_fixture_functions(PytestFixtureFunctions)
def pytest_sessionstart(session):
@@ -132,6 +136,7 @@ def pytest_collection_modifyitems(session, config, items):
rebuilt_items = collections.defaultdict(
lambda: collections.defaultdict(list)
)
+
items[:] = [
item
for item in items
@@ -173,21 +178,63 @@ def pytest_collection_modifyitems(session, config, items):
def pytest_pycollect_makeitem(collector, name, obj):
- if inspect.isclass(obj) and plugin_base.want_class(obj):
- return pytest.Class(name, parent=collector)
+
+ if inspect.isclass(obj) and plugin_base.want_class(name, obj):
+ return [
+ pytest.Class(parametrize_cls.__name__, parent=collector)
+ for parametrize_cls in _parametrize_cls(collector.module, obj)
+ ]
elif (
inspect.isfunction(obj)
and isinstance(collector, pytest.Instance)
and plugin_base.want_method(collector.cls, obj)
):
- return pytest.Function(name, parent=collector)
+ # None means, fall back to default logic, which includes
+ # method-level parametrize
+ return None
else:
+ # empty list means skip this item
return []
_current_class = None
+def _parametrize_cls(module, cls):
+ """implement a class-based version of pytest parametrize."""
+
+ if "_sa_parametrize" not in cls.__dict__:
+ return [cls]
+
+ _sa_parametrize = cls._sa_parametrize
+ classes = []
+ for full_param_set in itertools.product(
+ *[params for argname, params in _sa_parametrize]
+ ):
+ cls_variables = {}
+
+ for argname, param in zip(
+ [_sa_param[0] for _sa_param in _sa_parametrize], full_param_set
+ ):
+ if not argname:
+ raise TypeError("need argnames for class-based combinations")
+ argname_split = re.split(r",\s*", argname)
+ for arg, val in zip(argname_split, param.values):
+ cls_variables[arg] = val
+ parametrized_name = "_".join(
+ # token is a string, but in py2k py.test is giving us a unicode,
+ # so call str() on it.
+ str(re.sub(r"\W", "", token))
+ for param in full_param_set
+ for token in param.id.split("-")
+ )
+ name = "%s_%s" % (cls.__name__, parametrized_name)
+ newcls = type.__new__(type, name, (cls,), cls_variables)
+ setattr(module, name, newcls)
+ classes.append(newcls)
+ return classes
+
+
def pytest_runtest_setup(item):
# here we seem to get called only based on what we collected
# in pytest_collection_modifyitems. So to do class-based stuff
@@ -239,3 +286,99 @@ def class_setup(item):
def class_teardown(item):
plugin_base.stop_test_class(item.cls)
+
+
+def getargspec(fn):
+ if sys.version_info.major == 3:
+ return inspect.getfullargspec(fn)
+ else:
+ return inspect.getargspec(fn)
+
+
+class PytestFixtureFunctions(plugin_base.FixtureFunctions):
+ def skip_test_exception(self, *arg, **kw):
+ return pytest.skip.Exception(*arg, **kw)
+
+ _combination_id_fns = {
+ "i": lambda obj: obj,
+ "r": repr,
+ "s": str,
+ "n": operator.attrgetter("__name__"),
+ }
+
+ def combinations(self, *arg_sets, **kw):
+ """facade for pytest.mark.paramtrize.
+
+ Automatically derives argument names from the callable which in our
+ case is always a method on a class with positional arguments.
+
+ ids for parameter sets are derived using an optional template.
+
+ """
+
+ if sys.version_info.major == 3:
+ if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
+ arg_sets = list(arg_sets[0])
+ else:
+ if len(arg_sets) == 1 and hasattr(arg_sets[0], "next"):
+ arg_sets = list(arg_sets[0])
+
+ argnames = kw.pop("argnames", None)
+
+ id_ = kw.pop("id_", None)
+
+ if id_:
+ _combination_id_fns = self._combination_id_fns
+
+ # because itemgetter is not consistent for one argument vs.
+ # multiple, make it multiple in all cases and use a slice
+ # to omit the first argument
+ _arg_getter = operator.itemgetter(
+ 0,
+ *[
+ idx
+ for idx, char in enumerate(id_)
+ if char in ("n", "r", "s", "a")
+ ]
+ )
+ fns = [
+ (operator.itemgetter(idx), _combination_id_fns[char])
+ for idx, char in enumerate(id_)
+ if char in _combination_id_fns
+ ]
+ arg_sets = [
+ pytest.param(
+ *_arg_getter(arg)[1:],
+ id="-".join(
+ comb_fn(getter(arg)) for getter, comb_fn in fns
+ )
+ )
+ for arg in arg_sets
+ ]
+ else:
+ # ensure using pytest.param so that even a 1-arg paramset
+ # still needs to be a tuple. otherwise paramtrize tries to
+ # interpret a single arg differently than tuple arg
+ arg_sets = [pytest.param(*arg) for arg in arg_sets]
+
+ def decorate(fn):
+ if inspect.isclass(fn):
+ if "_sa_parametrize" not in fn.__dict__:
+ fn._sa_parametrize = []
+ fn._sa_parametrize.append((argnames, arg_sets))
+ return fn
+ else:
+ if argnames is None:
+ _argnames = getargspec(fn).args[1:]
+ else:
+ _argnames = argnames
+ return pytest.mark.parametrize(_argnames, arg_sets)(fn)
+
+ return decorate
+
+ def param_ident(self, *parameters):
+ ident = parameters[0]
+ return pytest.param(*parameters[1:], id=ident)
+
+ def fixture(self, fn):
+ return pytest.fixture(fn)
diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py
index cbfbc63ee..431e53b1b 100644
--- a/test/aaa_profiling/test_memusage.py
+++ b/test/aaa_profiling/test_memusage.py
@@ -921,7 +921,7 @@ class MemUsageWBackendTest(EnsureZeroed):
metadata.drop_all()
assert_no_mappers()
- @testing.expect_deprecated
+ @testing.uses_deprecated()
@testing.provide_metadata
def test_key_fallback_result(self):
e = self.engine
diff --git a/test/orm/inheritance/test_abc_polymorphic.py b/test/orm/inheritance/test_abc_polymorphic.py
index f430e761f..cf06c9e26 100644
--- a/test/orm/inheritance/test_abc_polymorphic.py
+++ b/test/orm/inheritance/test_abc_polymorphic.py
@@ -1,13 +1,13 @@
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import String
+from sqlalchemy import testing
from sqlalchemy.orm import create_session
from sqlalchemy.orm import mapper
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
-from sqlalchemy.testing.util import function_named
class ABCTest(fixtures.MappedTest):
@@ -36,91 +36,85 @@ class ABCTest(fixtures.MappedTest):
Column("cdata", String(30)),
)
- def _make_test(fetchtype):
- def test_roundtrip(self):
- class A(fixtures.ComparableEntity):
- pass
+ @testing.combinations(("union",), ("none",))
+ def test_abc_poly_roundtrip(self, fetchtype):
+ class A(fixtures.ComparableEntity):
+ pass
- class B(A):
- pass
+ class B(A):
+ pass
- class C(B):
- pass
+ class C(B):
+ pass
- if fetchtype == "union":
- abc = a.outerjoin(b).outerjoin(c)
- bc = a.join(b).outerjoin(c)
- else:
- abc = bc = None
+ if fetchtype == "union":
+ abc = a.outerjoin(b).outerjoin(c)
+ bc = a.join(b).outerjoin(c)
+ else:
+ abc = bc = None
- mapper(
- A,
- a,
- with_polymorphic=("*", abc),
- polymorphic_on=a.c.type,
- polymorphic_identity="a",
- )
- mapper(
- B,
- b,
- with_polymorphic=("*", bc),
- inherits=A,
- polymorphic_identity="b",
- )
- mapper(C, c, inherits=B, polymorphic_identity="c")
-
- a1 = A(adata="a1")
- b1 = B(bdata="b1", adata="b1")
- b2 = B(bdata="b2", adata="b2")
- b3 = B(bdata="b3", adata="b3")
- c1 = C(cdata="c1", bdata="c1", adata="c1")
- c2 = C(cdata="c2", bdata="c2", adata="c2")
- c3 = C(cdata="c2", bdata="c2", adata="c2")
-
- sess = create_session()
- for x in (a1, b1, b2, b3, c1, c2, c3):
- sess.add(x)
- sess.flush()
- sess.expunge_all()
+ mapper(
+ A,
+ a,
+ with_polymorphic=("*", abc),
+ polymorphic_on=a.c.type,
+ polymorphic_identity="a",
+ )
+ mapper(
+ B,
+ b,
+ with_polymorphic=("*", bc),
+ inherits=A,
+ polymorphic_identity="b",
+ )
+ mapper(C, c, inherits=B, polymorphic_identity="c")
- # for obj in sess.query(A).all():
- # print obj
- eq_(
- [
- A(adata="a1"),
- B(bdata="b1", adata="b1"),
- B(bdata="b2", adata="b2"),
- B(bdata="b3", adata="b3"),
- C(cdata="c1", bdata="c1", adata="c1"),
- C(cdata="c2", bdata="c2", adata="c2"),
- C(cdata="c2", bdata="c2", adata="c2"),
- ],
- sess.query(A).order_by(A.id).all(),
- )
+ a1 = A(adata="a1")
+ b1 = B(bdata="b1", adata="b1")
+ b2 = B(bdata="b2", adata="b2")
+ b3 = B(bdata="b3", adata="b3")
+ c1 = C(cdata="c1", bdata="c1", adata="c1")
+ c2 = C(cdata="c2", bdata="c2", adata="c2")
+ c3 = C(cdata="c2", bdata="c2", adata="c2")
- eq_(
- [
- B(bdata="b1", adata="b1"),
- B(bdata="b2", adata="b2"),
- B(bdata="b3", adata="b3"),
- C(cdata="c1", bdata="c1", adata="c1"),
- C(cdata="c2", bdata="c2", adata="c2"),
- C(cdata="c2", bdata="c2", adata="c2"),
- ],
- sess.query(B).order_by(A.id).all(),
- )
+ sess = create_session()
+ for x in (a1, b1, b2, b3, c1, c2, c3):
+ sess.add(x)
+ sess.flush()
+ sess.expunge_all()
- eq_(
- [
- C(cdata="c1", bdata="c1", adata="c1"),
- C(cdata="c2", bdata="c2", adata="c2"),
- C(cdata="c2", bdata="c2", adata="c2"),
- ],
- sess.query(C).order_by(A.id).all(),
- )
+ # for obj in sess.query(A).all():
+ # print obj
+ eq_(
+ [
+ A(adata="a1"),
+ B(bdata="b1", adata="b1"),
+ B(bdata="b2", adata="b2"),
+ B(bdata="b3", adata="b3"),
+ C(cdata="c1", bdata="c1", adata="c1"),
+ C(cdata="c2", bdata="c2", adata="c2"),
+ C(cdata="c2", bdata="c2", adata="c2"),
+ ],
+ sess.query(A).order_by(A.id).all(),
+ )
- test_roundtrip = function_named(test_roundtrip, "test_%s" % fetchtype)
- return test_roundtrip
+ eq_(
+ [
+ B(bdata="b1", adata="b1"),
+ B(bdata="b2", adata="b2"),
+ B(bdata="b3", adata="b3"),
+ C(cdata="c1", bdata="c1", adata="c1"),
+ C(cdata="c2", bdata="c2", adata="c2"),
+ C(cdata="c2", bdata="c2", adata="c2"),
+ ],
+ sess.query(B).order_by(A.id).all(),
+ )
- test_union = _make_test("union")
- test_none = _make_test("none")
+ eq_(
+ [
+ C(cdata="c1", bdata="c1", adata="c1"),
+ C(cdata="c2", bdata="c2", adata="c2"),
+ C(cdata="c2", bdata="c2", adata="c2"),
+ ],
+ sess.query(C).order_by(A.id).all(),
+ )
diff --git a/test/orm/inheritance/test_assorted_poly.py b/test/orm/inheritance/test_assorted_poly.py
index 2f8677f8b..ecab0a497 100644
--- a/test/orm/inheritance/test_assorted_poly.py
+++ b/test/orm/inheritance/test_assorted_poly.py
@@ -7,7 +7,6 @@ from sqlalchemy import exists
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Integer
-from sqlalchemy import MetaData
from sqlalchemy import select
from sqlalchemy import Sequence
from sqlalchemy import String
@@ -15,7 +14,6 @@ from sqlalchemy import testing
from sqlalchemy import Unicode
from sqlalchemy import util
from sqlalchemy.orm import class_mapper
-from sqlalchemy.orm import clear_mappers
from sqlalchemy.orm import column_property
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import create_session
@@ -23,7 +21,6 @@ from sqlalchemy.orm import join
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import mapper
from sqlalchemy.orm import polymorphic_union
-from sqlalchemy.orm import Query
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.orm import with_polymorphic
@@ -33,15 +30,6 @@ from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
-from sqlalchemy.testing.util import function_named
-
-
-class AttrSettable(object):
- def __init__(self, **kwargs):
- [setattr(self, k, v) for k, v in kwargs.items()]
-
- def __repr__(self):
- return self.__class__.__name__ + "(%s)" % (hex(id(self)))
class RelationshipTest1(fixtures.MappedTest):
@@ -84,17 +72,17 @@ class RelationshipTest1(fixtures.MappedTest):
Column("manager_name", String(50)),
)
- def teardown(self):
- people.update(values={people.c.manager_id: None}).execute()
- super(RelationshipTest1, self).teardown()
-
- def test_parent_refs_descendant(self):
- class Person(AttrSettable):
+ @classmethod
+ def setup_classes(cls):
+ class Person(cls.Comparable):
pass
class Manager(Person):
pass
+ def test_parent_refs_descendant(self):
+ Person, Manager = self.classes("Person", "Manager")
+
mapper(
Person,
people,
@@ -132,11 +120,7 @@ class RelationshipTest1(fixtures.MappedTest):
assert p.manager is m
def test_descendant_refs_parent(self):
- class Person(AttrSettable):
- pass
-
- class Manager(Person):
- pass
+ Person, Manager = self.classes("Person", "Manager")
mapper(Person, people)
mapper(
@@ -212,31 +196,22 @@ class RelationshipTest2(fixtures.MappedTest):
Column("data", String(30)),
)
- def test_relationshiponsubclass_j1_nodata(self):
- self._do_test("join1", False)
-
- def test_relationshiponsubclass_j2_nodata(self):
- self._do_test("join2", False)
-
- def test_relationshiponsubclass_j1_data(self):
- self._do_test("join1", True)
-
- def test_relationshiponsubclass_j2_data(self):
- self._do_test("join2", True)
-
- def test_relationshiponsubclass_j3_nodata(self):
- self._do_test("join3", False)
-
- def test_relationshiponsubclass_j3_data(self):
- self._do_test("join3", True)
-
- def _do_test(self, jointype="join1", usedata=False):
- class Person(AttrSettable):
+ @classmethod
+ def setup_classes(cls):
+ class Person(cls.Comparable):
pass
class Manager(Person):
pass
+ @testing.combinations(
+ ("join1",), ("join2",), ("join3",), argnames="jointype"
+ )
+ @testing.combinations(
+ ("usedata", True), ("nodata", False), id_="ia", argnames="usedata"
+ )
+ def test_relationshiponsubclass(self, jointype, usedata):
+ Person, Manager = self.classes("Person", "Manager")
if jointype == "join1":
poly_union = polymorphic_union(
{
@@ -382,21 +357,20 @@ class RelationshipTest3(fixtures.MappedTest):
Column("data", String(30)),
)
-
-def _generate_test(jointype="join1", usedata=False):
- def _do_test(self):
- class Person(AttrSettable):
+ @classmethod
+ def setup_classes(cls):
+ class Person(cls.Comparable):
pass
class Manager(Person):
pass
- if usedata:
-
- class Data(object):
- def __init__(self, data):
- self.data = data
+ class Data(cls.Comparable):
+ def __init__(self, data):
+ self.data = data
+ def _setup_mappings(self, jointype, usedata):
+ Person, Manager, Data = self.classes("Person", "Manager", "Data")
if jointype == "join1":
poly_union = polymorphic_union(
{
@@ -427,6 +401,8 @@ def _generate_test(jointype="join1", usedata=False):
poly_union = people.outerjoin(managers)
elif jointype == "join4":
poly_union = None
+ else:
+ assert False
if usedata:
mapper(Data, data)
@@ -475,6 +451,16 @@ def _generate_test(jointype="join1", usedata=False):
polymorphic_identity="manager",
)
+ @testing.combinations(
+ ("join1",), ("join2",), ("join3",), ("join4",), argnames="jointype"
+ )
+ @testing.combinations(
+ ("usedata", True), ("nodata", False), id_="ia", argnames="usedata"
+ )
+ def test_relationship_on_base_class(self, jointype, usedata):
+ self._setup_mappings(jointype, usedata)
+ Person, Manager, Data = self.classes("Person", "Manager", "Data")
+
sess = create_session()
p = Person(name="person1")
p2 = Person(name="person2")
@@ -502,20 +488,6 @@ def _generate_test(jointype="join1", usedata=False):
assert p.data.data == "ps data"
assert m.data.data == "ms data"
- do_test = function_named(
- _do_test,
- "test_relationship_on_base_class_%s_%s"
- % (jointype, data and "nodata" or "data"),
- )
- return do_test
-
-
-for jointype in ["join1", "join2", "join3", "join4"]:
- for data in (True, False):
- _fn = _generate_test(jointype, data)
- setattr(RelationshipTest3, _fn.__name__, _fn)
-del _fn
-
class RelationshipTest4(fixtures.MappedTest):
@classmethod
@@ -853,13 +825,17 @@ class RelationshipTest6(fixtures.MappedTest):
Column("status", String(30)),
)
- def test_basic(self):
- class Person(AttrSettable):
+ @classmethod
+ def setup_classes(cls):
+ class Person(cls.Comparable):
pass
class Manager(Person):
pass
+ def test_basic(self):
+ Person, Manager = self.classes("Person", "Manager")
+
mapper(Person, people)
mapper(
@@ -1128,9 +1104,9 @@ class RelationshipTest8(fixtures.MappedTest):
)
-class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
+class GenerativeTest(fixtures.MappedTest, AssertsExecutionResults):
@classmethod
- def setup_class(cls):
+ def define_tables(cls, metadata):
# cars---owned by--- people (abstract) --- has a --- status
# | ^ ^ |
# | | | |
@@ -1138,10 +1114,8 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
# | |
# +--------------------------------------- has a ------+
- global metadata, status, people, engineers, managers, cars
- metadata = MetaData(testing.db)
# table definitions
- status = Table(
+ Table(
"status",
metadata,
Column(
@@ -1153,7 +1127,7 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
Column("name", String(20)),
)
- people = Table(
+ Table(
"people",
metadata,
Column(
@@ -1171,7 +1145,7 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
Column("name", String(50)),
)
- engineers = Table(
+ Table(
"engineers",
metadata,
Column(
@@ -1183,7 +1157,7 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
Column("field", String(30)),
)
- managers = Table(
+ Table(
"managers",
metadata,
Column(
@@ -1195,7 +1169,7 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
Column("category", String(70)),
)
- cars = Table(
+ Table(
"cars",
metadata,
Column(
@@ -1218,52 +1192,31 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
),
)
- metadata.create_all()
-
@classmethod
- def teardown_class(cls):
- metadata.drop_all()
-
- def teardown(self):
- clear_mappers()
- for t in reversed(metadata.sorted_tables):
- t.delete().execute()
-
- def test_join_to(self):
- # class definitions
- class PersistentObject(object):
- def __init__(self, **kwargs):
- for key, value in kwargs.items():
- setattr(self, key, value)
-
- class Status(PersistentObject):
- def __repr__(self):
- return "Status %s" % self.name
+ def setup_classes(cls):
+ class Status(cls.Comparable):
+ pass
- class Person(PersistentObject):
- def __repr__(self):
- return "Ordinary person %s" % self.name
+ class Person(cls.Comparable):
+ pass
class Engineer(Person):
- def __repr__(self):
- return "Engineer %s, field %s, status %s" % (
- self.name,
- self.field,
- self.status,
- )
+ pass
class Manager(Person):
- def __repr__(self):
- return "Manager %s, category %s, status %s" % (
- self.name,
- self.category,
- self.status,
- )
+ pass
- class Car(PersistentObject):
- def __repr__(self):
- return "Car number %d" % self.car_id
+ class Car(cls.Comparable):
+ pass
+ @classmethod
+ def setup_mappers(cls):
+ status, people, engineers, managers, cars = cls.tables(
+ "status", "people", "engineers", "managers", "cars"
+ )
+ Status, Person, Engineer, Manager, Car = cls.classes(
+ "Status", "Person", "Engineer", "Manager", "Car"
+ )
# create a union that represents both types of joins.
employee_join = polymorphic_union(
{
@@ -1283,7 +1236,7 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
polymorphic_identity="person",
properties={"status": relationship(status_mapper)},
)
- engineer_mapper = mapper(
+ mapper(
Engineer,
engineers,
inherits=person_mapper,
@@ -1304,6 +1257,11 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
},
)
+ @classmethod
+ def insert_data(cls):
+ Status, Person, Engineer, Manager, Car = cls.classes(
+ "Status", "Person", "Engineer", "Manager", "Car"
+ )
session = create_session()
active = Status(name="active")
@@ -1332,7 +1290,7 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
session.flush()
# get E4
- engineer4 = session.query(engineer_mapper).filter_by(name="E4").one()
+ engineer4 = session.query(Engineer).filter_by(name="E4").one()
# create 2 cars for E4, one active and one dead
car1 = Car(employee=engineer4, status=active)
@@ -1341,9 +1299,11 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
session.add(car2)
session.flush()
- # this particular adapt used to cause a recursion overflow;
- # added here for testing
- Query(Person)._adapt_clause(employee_join, False, False)
+ def test_join_to_q_person(self):
+ Status, Person, Engineer, Manager, Car = self.classes(
+ "Status", "Person", "Engineer", "Manager", "Car"
+ )
+ session = create_session()
r = (
session.query(Person)
@@ -1353,31 +1313,52 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
.order_by(Person.person_id)
)
eq_(
- str(list(r)),
- "[Manager M2, category YYYYYYYYY, status "
- "Status active, Engineer E2, field X, "
- "status Status active]",
+ list(r),
+ [
+ Manager(
+ name="M2",
+ category="YYYYYYYYY",
+ status=Status(name="active"),
+ ),
+ Engineer(name="E2", field="X", status=Status(name="active")),
+ ],
+ )
+
+ def test_join_to_q_engineer(self):
+ Status, Person, Engineer, Manager, Car = self.classes(
+ "Status", "Person", "Engineer", "Manager", "Car"
)
+ session = create_session()
r = (
session.query(Engineer)
.join("status")
.filter(
Person.name.in_(["E2", "E3", "E4", "M4", "M2", "M1"])
- & (status.c.name == "active")
+ & (Status.name == "active")
)
.order_by(Person.name)
)
eq_(
- str(list(r)),
- "[Engineer E2, field X, status Status "
- "active, Engineer E3, field X, status "
- "Status active]",
+ list(r),
+ [
+ Engineer(name="E2", field="X", status=Status(name="active")),
+ Engineer(name="E3", field="X", status=Status(name="active")),
+ ],
)
+ def test_join_to_q_person_car(self):
+ Status, Person, Engineer, Manager, Car = self.classes(
+ "Status", "Person", "Engineer", "Manager", "Car"
+ )
+ session = create_session()
r = session.query(Person).filter(
exists([1], Car.owner == Person.person_id)
)
- eq_(str(list(r)), "[Engineer E4, field X, status Status dead]")
+
+ eq_(
+ list(r),
+ [Engineer(name="E4", field="X", status=Status(name="dead"))],
+ )
class MultiLevelTest(fixtures.MappedTest):
diff --git a/test/orm/inheritance/test_magazine.py b/test/orm/inheritance/test_magazine.py
index 1abfb9032..228cb1273 100644
--- a/test/orm/inheritance/test_magazine.py
+++ b/test/orm/inheritance/test_magazine.py
@@ -1,115 +1,54 @@
+"""A legacy test for a particular somewhat complicated mapping."""
+
from sqlalchemy import CHAR
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import String
+from sqlalchemy import testing
from sqlalchemy import Text
from sqlalchemy.orm import backref
-from sqlalchemy.orm import create_session
from sqlalchemy.orm import mapper
from sqlalchemy.orm import polymorphic_union
from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
+from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
-from sqlalchemy.testing.util import function_named
-
-
-class BaseObject(object):
- def __init__(self, *args, **kwargs):
- for key, value in kwargs.items():
- setattr(self, key, value)
-
-
-class Publication(BaseObject):
- pass
-
-
-class Issue(BaseObject):
- pass
-
-
-class Location(BaseObject):
- def __repr__(self):
- return "%s(%s, %s)" % (
- self.__class__.__name__,
- str(getattr(self, "issue_id", None)),
- repr(str(self._name.name)),
- )
-
- def _get_name(self):
- return self._name
-
- def _set_name(self, name):
- session = create_session()
- s = (
- session.query(LocationName)
- .filter(LocationName.name == name)
- .first()
- )
- session.expunge_all()
- if s is not None:
- self._name = s
-
- return
-
- found = False
- for i in session.new:
- if isinstance(i, LocationName) and i.name == name:
- self._name = i
- found = True
- break
-
- if found is False:
- self._name = LocationName(name=name)
-
- name = property(_get_name, _set_name)
-
-
-class LocationName(BaseObject):
- def __repr__(self):
- return "%s()" % (self.__class__.__name__)
-
-
-class PageSize(BaseObject):
- def __repr__(self):
- return "%s(%sx%s, %s)" % (
- self.__class__.__name__,
- self.width,
- self.height,
- self.name,
- )
+class MagazineTest(fixtures.MappedTest):
+ @classmethod
+ def setup_classes(cls):
+ Base = cls.Comparable
+ class Publication(Base):
+ pass
-class Magazine(BaseObject):
- def __repr__(self):
- return "%s(%s, %s)" % (
- self.__class__.__name__,
- repr(self.location),
- repr(self.size),
- )
+ class Issue(Base):
+ pass
+ class Location(Base):
+ pass
-class Page(BaseObject):
- def __repr__(self):
- return "%s(%s)" % (self.__class__.__name__, str(self.page_no))
+ class LocationName(Base):
+ pass
+ class PageSize(Base):
+ pass
-class MagazinePage(Page):
- def __repr__(self):
- return "%s(%s, %s)" % (
- self.__class__.__name__,
- str(self.page_no),
- repr(self.magazine),
- )
+ class Magazine(Base):
+ pass
+ class Page(Base):
+ pass
-class ClassifiedPage(MagazinePage):
- pass
+ class MagazinePage(Page):
+ pass
+ class ClassifiedPage(MagazinePage):
+ pass
-class MagazineTest(fixtures.MappedTest):
@classmethod
def define_tables(cls, metadata):
Table(
@@ -198,9 +137,65 @@ class MagazineTest(fixtures.MappedTest):
Column("name", String(45), default=""),
)
+ def _generate_data(self):
+ (
+ Publication,
+ Issue,
+ Location,
+ LocationName,
+ PageSize,
+ Magazine,
+ Page,
+ MagazinePage,
+ ClassifiedPage,
+ ) = self.classes(
+ "Publication",
+ "Issue",
+ "Location",
+ "LocationName",
+ "PageSize",
+ "Magazine",
+ "Page",
+ "MagazinePage",
+ "ClassifiedPage",
+ )
+ london = LocationName(name="London")
+ pub = Publication(name="Test")
+ issue = Issue(issue=46, publication=pub)
+ location = Location(ref="ABC", name=london, issue=issue)
+
+ page_size = PageSize(name="A4", width=210, height=297)
-def _generate_round_trip_test(use_unions=False, use_joins=False):
- def test_roundtrip(self):
+ magazine = Magazine(location=location, size=page_size)
+
+ ClassifiedPage(magazine=magazine, page_no=1)
+ MagazinePage(magazine=magazine, page_no=2)
+ ClassifiedPage(magazine=magazine, page_no=3)
+
+ return pub
+
+ def _setup_mapping(self, use_unions, use_joins):
+ (
+ Publication,
+ Issue,
+ Location,
+ LocationName,
+ PageSize,
+ Magazine,
+ Page,
+ MagazinePage,
+ ClassifiedPage,
+ ) = self.classes(
+ "Publication",
+ "Issue",
+ "Location",
+ "LocationName",
+ "PageSize",
+ "Magazine",
+ "Page",
+ "MagazinePage",
+ "ClassifiedPage",
+ )
mapper(Publication, self.tables.publication)
mapper(
@@ -228,7 +223,7 @@ def _generate_round_trip_test(use_unions=False, use_joins=False):
cascade="all, delete-orphan",
),
),
- "_name": relationship(LocationName),
+ "name": relationship(LocationName),
},
)
@@ -354,42 +349,29 @@ def _generate_round_trip_test(use_unions=False, use_joins=False):
primary_key=[self.tables.page.c.id],
)
- session = create_session()
-
- pub = Publication(name="Test")
- issue = Issue(issue=46, publication=pub)
- location = Location(ref="ABC", name="London", issue=issue)
+ @testing.combinations(
+ ("unions", True, False),
+ ("joins", False, True),
+ ("plain", False, False),
+ id_="iaa",
+ )
+ def test_magazine_round_trip(self, use_unions, use_joins):
+ self._setup_mapping(use_unions, use_joins)
- page_size = PageSize(name="A4", width=210, height=297)
+ Publication = self.classes.Publication
- magazine = Magazine(location=location, size=page_size)
+ session = Session()
- page = ClassifiedPage(magazine=magazine, page_no=1)
- page2 = MagazinePage(magazine=magazine, page_no=2)
- page3 = ClassifiedPage(magazine=magazine, page_no=3)
+ pub = self._generate_data()
session.add(pub)
+ session.commit()
+ session.close()
- session.flush()
- print([x for x in session])
- session.expunge_all()
-
- session.flush()
- session.expunge_all()
p = session.query(Publication).filter(Publication.name == "Test").one()
- print(p.issues[0].locations[0].magazine.pages)
- print([page, page2, page3])
- assert repr(p.issues[0].locations[0].magazine.pages) == repr(
- [page, page2, page3]
- ), repr(p.issues[0].locations[0].magazine.pages)
-
- test_roundtrip = function_named(
- test_roundtrip,
- "test_%s"
- % (not use_union and (use_joins and "joins" or "select") or "unions"),
- )
- setattr(MagazineTest, test_roundtrip.__name__, test_roundtrip)
-
-
-for (use_union, use_join) in [(True, False), (False, True), (False, False)]:
- _generate_round_trip_test(use_union, use_join)
+ test_pub = self._generate_data()
+ eq_(p, test_pub)
+ eq_(
+ p.issues[0].locations[0].magazine.pages,
+ test_pub.issues[0].locations[0].magazine.pages,
+ )
diff --git a/test/orm/inheritance/test_poly_persistence.py b/test/orm/inheritance/test_poly_persistence.py
index 1cef654cd..508cb9965 100644
--- a/test/orm/inheritance/test_poly_persistence.py
+++ b/test/orm/inheritance/test_poly_persistence.py
@@ -2,9 +2,7 @@
from sqlalchemy import exc as sa_exc
from sqlalchemy import ForeignKey
-from sqlalchemy import func
from sqlalchemy import Integer
-from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import testing
@@ -12,12 +10,12 @@ from sqlalchemy.orm import create_session
from sqlalchemy.orm import mapper
from sqlalchemy.orm import polymorphic_union
from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing.schema import Column
-from sqlalchemy.testing.util import function_named
class Person(fixtures.ComparableEntity):
@@ -115,8 +113,6 @@ class PolymorphTest(fixtures.MappedTest):
Column("golf_swing", String(30)),
)
- metadata.create_all()
-
class InsertOrderTest(PolymorphTest):
def test_insert_order(self):
@@ -198,28 +194,41 @@ class InsertOrderTest(PolymorphTest):
eq_(session.query(Company).get(c.company_id), c)
+@testing.combinations(
+ ("lazy", True), ("nonlazy", False), argnames="lazy_relationship", id_="ia"
+)
+@testing.combinations(
+ ("redefine", True),
+ ("noredefine", False),
+ argnames="redefine_colprop",
+ id_="ia",
+)
+@testing.combinations(
+ ("unions", True),
+ ("unions", False),
+ ("joins", False),
+ ("auto", False),
+ ("none", False),
+ argnames="with_polymorphic,include_base",
+ id_="rr",
+)
class RoundTripTest(PolymorphTest):
- pass
-
-
-def _generate_round_trip_test(
- include_base, lazy_relationship, redefine_colprop, with_polymorphic
-):
- """generates a round trip test.
-
- include_base - whether or not to include the base 'person' type in
- the union.
+ lazy_relationship = None
+ include_base = None
+ redefine_colprop = None
+ with_polymorphic = None
- lazy_relationship - whether or not the Company relationship to
- People is lazy or eager.
+ run_inserts = "once"
+ run_deletes = None
+ run_setup_mappers = "once"
- redefine_colprop - if we redefine the 'name' column to be
- 'people_name' on the base Person class
-
- use_literal_join - primary join condition is explicitly specified
- """
+ @classmethod
+ def setup_mappers(cls):
+ include_base = cls.include_base
+ lazy_relationship = cls.lazy_relationship
+ redefine_colprop = cls.redefine_colprop
+ with_polymorphic = cls.with_polymorphic
- def test_roundtrip(self):
if with_polymorphic == "unions":
if include_base:
person_join = polymorphic_union(
@@ -308,6 +317,11 @@ def _generate_round_trip_test(
},
)
+ @classmethod
+ def insert_data(cls):
+ redefine_colprop = cls.redefine_colprop
+ include_base = cls.include_base
+
if redefine_colprop:
person_attribute_name = "person_name"
else:
@@ -342,15 +356,48 @@ def _generate_round_trip_test(
),
]
- dilbert = employees[1]
-
- session = create_session()
+ session = Session()
c = Company(name="company1")
c.employees = employees
session.add(c)
- session.flush()
- session.expunge_all()
+ session.commit()
+
+ @testing.fixture
+ def get_dilbert(self):
+ def run(session):
+ if self.redefine_colprop:
+ person_attribute_name = "person_name"
+ else:
+ person_attribute_name = "name"
+
+ dilbert = (
+ session.query(Engineer)
+ .filter_by(**{person_attribute_name: "dilbert"})
+ .one()
+ )
+ return dilbert
+
+ return run
+
+ def test_lazy_load(self):
+ lazy_relationship = self.lazy_relationship
+ with_polymorphic = self.with_polymorphic
+
+ if self.redefine_colprop:
+ person_attribute_name = "person_name"
+ else:
+ person_attribute_name = "name"
+
+ session = create_session()
+
+ dilbert = (
+ session.query(Engineer)
+ .filter_by(**{person_attribute_name: "dilbert"})
+ .one()
+ )
+ employees = session.query(Person).order_by(Person.person_id).all()
+ company = session.query(Company).first()
eq_(session.query(Person).get(dilbert.person_id), dilbert)
session.expunge_all()
@@ -364,20 +411,29 @@ def _generate_round_trip_test(
session.expunge_all()
def go():
- cc = session.query(Company).get(c.company_id)
+ cc = session.query(Company).get(company.company_id)
eq_(cc.employees, employees)
if not lazy_relationship:
if with_polymorphic != "none":
self.assert_sql_count(testing.db, go, 1)
else:
- self.assert_sql_count(testing.db, go, 5)
+ self.assert_sql_count(testing.db, go, 2)
else:
if with_polymorphic != "none":
self.assert_sql_count(testing.db, go, 2)
else:
- self.assert_sql_count(testing.db, go, 6)
+ self.assert_sql_count(testing.db, go, 3)
+
+ def test_baseclass_lookup(self, get_dilbert):
+ session = Session()
+ dilbert = get_dilbert(session)
+
+ if self.redefine_colprop:
+ person_attribute_name = "person_name"
+ else:
+ person_attribute_name = "name"
# test selecting from the query, using the base
# mapped table (people) as the selection criterion.
@@ -390,12 +446,14 @@ def _generate_round_trip_test(
dilbert,
)
- assert (
- session.query(Person)
- .filter(getattr(Person, person_attribute_name) == "dilbert")
- .first()
- .person_id
- )
+ def test_subclass_lookup(self, get_dilbert):
+ session = Session()
+ dilbert = get_dilbert(session)
+
+ if self.redefine_colprop:
+ person_attribute_name = "person_name"
+ else:
+ person_attribute_name = "name"
eq_(
session.query(Engineer)
@@ -404,6 +462,10 @@ def _generate_round_trip_test(
dilbert,
)
+ def test_baseclass_base_alias_filter(self, get_dilbert):
+ session = Session()
+ dilbert = get_dilbert(session)
+
# test selecting from the query, joining against
# an alias of the base "people" table. test that
# the "palias" alias does *not* get sucked up
@@ -419,6 +481,13 @@ def _generate_round_trip_test(
)
.first(),
)
+
+ def test_subclass_base_alias_filter(self, get_dilbert):
+ session = Session()
+ dilbert = get_dilbert(session)
+
+ palias = people.alias("palias")
+
is_(
dilbert,
session.query(Engineer)
@@ -428,6 +497,11 @@ def _generate_round_trip_test(
)
.first(),
)
+
+ def test_baseclass_sub_table_filter(self, get_dilbert):
+ session = Session()
+ dilbert = get_dilbert(session)
+
is_(
dilbert,
session.query(Person)
@@ -437,6 +511,11 @@ def _generate_round_trip_test(
)
.first(),
)
+
+ def test_subclass_getitem(self, get_dilbert):
+ session = Session()
+ dilbert = get_dilbert(session)
+
is_(
dilbert,
session.query(Engineer).filter(
@@ -444,17 +523,16 @@ def _generate_round_trip_test(
)[0],
)
- session.flush()
- session.expunge_all()
+ def test_primary_table_only_for_requery(self):
- def go():
- session.query(Person).filter(
- getattr(Person, person_attribute_name) == "dilbert"
- ).first()
+ session = Session()
- self.assert_sql_count(testing.db, go, 1)
- session.expunge_all()
- dilbert = (
+ if self.redefine_colprop:
+ person_attribute_name = "person_name"
+ else:
+ person_attribute_name = "name"
+
+ dilbert = ( # noqa
session.query(Person)
.filter(getattr(Person, person_attribute_name) == "dilbert")
.first()
@@ -471,7 +549,14 @@ def _generate_round_trip_test(
self.assert_sql_count(testing.db, go, 1)
- # test standalone orphans
+ def test_standalone_orphans(self):
+ if self.redefine_colprop:
+ person_attribute_name = "person_name"
+ else:
+ person_attribute_name = "name"
+
+ session = Session()
+
daboss = Boss(
status="BBB",
manager_name="boss",
@@ -480,52 +565,3 @@ def _generate_round_trip_test(
)
session.add(daboss)
assert_raises(sa_exc.DBAPIError, session.flush)
-
- c = session.query(Company).first()
- daboss.company = c
- manager_list = [e for e in c.employees if isinstance(e, Manager)]
- session.flush()
- session.expunge_all()
-
- eq_(
- session.query(Manager).order_by(Manager.person_id).all(),
- manager_list,
- )
- c = session.query(Company).first()
-
- session.delete(c)
- session.flush()
-
- eq_(select([func.count("*")]).select_from(people).scalar(), 0)
-
- test_roundtrip = function_named(
- test_roundtrip,
- "test_%s%s%s_%s"
- % (
- (lazy_relationship and "lazy" or "eager"),
- (include_base and "_inclbase" or ""),
- (redefine_colprop and "_redefcol" or ""),
- with_polymorphic,
- ),
- )
- setattr(RoundTripTest, test_roundtrip.__name__, test_roundtrip)
-
-
-for lazy_relationship in [True, False]:
- for redefine_colprop in [True, False]:
- for with_polymorphic_ in ["unions", "joins", "auto", "none"]:
- if with_polymorphic_ == "unions":
- for include_base in [True, False]:
- _generate_round_trip_test(
- include_base,
- lazy_relationship,
- redefine_colprop,
- with_polymorphic_,
- )
- else:
- _generate_round_trip_test(
- False,
- lazy_relationship,
- redefine_colprop,
- with_polymorphic_,
- )
diff --git a/test/orm/test_descriptor.py b/test/orm/test_descriptor.py
index 1baa82d3d..7b530b928 100644
--- a/test/orm/test_descriptor.py
+++ b/test/orm/test_descriptor.py
@@ -13,7 +13,7 @@ from sqlalchemy.testing import fixtures
from sqlalchemy.util import partial
-class TestDescriptor(descriptor_props.DescriptorProperty):
+class MockDescriptor(descriptor_props.DescriptorProperty):
def __init__(
self, cls, key, descriptor=None, doc=None, comparator_factory=None
):
@@ -40,7 +40,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest):
def test_fixture(self):
Foo = self._fixture()
- d = TestDescriptor(Foo, "foo")
+ d = MockDescriptor(Foo, "foo")
d.instrument_class(Foo.__mapper__)
assert Foo.foo
@@ -50,7 +50,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest):
prop = property(lambda self: None)
Foo.foo = prop
- d = TestDescriptor(Foo, "foo")
+ d = MockDescriptor(Foo, "foo")
d.instrument_class(Foo.__mapper__)
assert Foo().foo is None
@@ -68,7 +68,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest):
prop = myprop(lambda self: None)
Foo.foo = prop
- d = TestDescriptor(Foo, "foo")
+ d = MockDescriptor(Foo, "foo")
d.instrument_class(Foo.__mapper__)
assert Foo().foo is None
@@ -95,7 +95,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest):
return column("foo") == func.upper(other)
Foo = self._fixture()
- d = TestDescriptor(Foo, "foo", comparator_factory=Comparator)
+ d = MockDescriptor(Foo, "foo", comparator_factory=Comparator)
d.instrument_class(Foo.__mapper__)
eq_(Foo.foo.method1(), "method1")
eq_(Foo.foo.method2("x"), "method2")
@@ -119,7 +119,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest):
prop = mapper._props["_name"]
return Comparator(prop, mapper)
- d = TestDescriptor(Foo, "foo", comparator_factory=comparator_factory)
+ d = MockDescriptor(Foo, "foo", comparator_factory=comparator_factory)
d.instrument_class(Foo.__mapper__)
eq_(str(Foo.foo == "ed"), "foobar(foo.name) = foobar(:foobar_1)")
diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py
index 66fe18598..637f1f8a5 100644
--- a/test/sql/test_operators.py
+++ b/test/sql/test_operators.py
@@ -73,12 +73,41 @@ class LoopOperate(operators.ColumnOperators):
class DefaultColumnComparatorTest(fixtures.TestBase):
- def _do_scalar_test(self, operator, compare_to):
+ @testing.combinations((operators.desc_op, desc), (operators.asc_op, asc))
+ def test_scalar(self, operator, compare_to):
left = column("left")
assert left.comparator.operate(operator).compare(compare_to(left))
self._loop_test(operator)
- def _do_operate_test(self, operator, right=column("right")):
+ right_column = column("right")
+
+ @testing.combinations(
+ (operators.add, right_column),
+ (operators.is_, None),
+ (operators.isnot, None),
+ (operators.is_, null()),
+ (operators.is_, true()),
+ (operators.is_, false()),
+ (operators.eq, True),
+ (operators.ne, True),
+ (operators.is_distinct_from, True),
+ (operators.is_distinct_from, False),
+ (operators.is_distinct_from, None),
+ (operators.isnot_distinct_from, True),
+ (operators.is_, True),
+ (operators.isnot, True),
+ (operators.is_, False),
+ (operators.isnot, False),
+ (operators.like_op, right_column),
+ (operators.notlike_op, right_column),
+ (operators.ilike_op, right_column),
+ (operators.notilike_op, right_column),
+ (operators.is_, right_column),
+ (operators.isnot, right_column),
+ (operators.concat_op, right_column),
+ id_="ns",
+ )
+ def test_operate(self, operator, right):
left = column("left")
assert left.comparator.operate(operator, right).compare(
@@ -109,84 +138,13 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
loop = LoopOperate()
is_(operator(loop, *arg), operator)
- def test_desc(self):
- self._do_scalar_test(operators.desc_op, desc)
-
- def test_asc(self):
- self._do_scalar_test(operators.asc_op, asc)
-
- def test_plus(self):
- self._do_operate_test(operators.add)
-
- def test_is_null(self):
- self._do_operate_test(operators.is_, None)
-
- def test_isnot_null(self):
- self._do_operate_test(operators.isnot, None)
-
- def test_is_null_const(self):
- self._do_operate_test(operators.is_, null())
-
- def test_is_true_const(self):
- self._do_operate_test(operators.is_, true())
-
- def test_is_false_const(self):
- self._do_operate_test(operators.is_, false())
-
- def test_equals_true(self):
- self._do_operate_test(operators.eq, True)
-
- def test_notequals_true(self):
- self._do_operate_test(operators.ne, True)
-
- def test_is_distinct_from_true(self):
- self._do_operate_test(operators.is_distinct_from, True)
-
- def test_is_distinct_from_false(self):
- self._do_operate_test(operators.is_distinct_from, False)
-
- def test_is_distinct_from_null(self):
- self._do_operate_test(operators.is_distinct_from, None)
-
- def test_isnot_distinct_from_true(self):
- self._do_operate_test(operators.isnot_distinct_from, True)
-
- def test_is_true(self):
- self._do_operate_test(operators.is_, True)
-
- def test_isnot_true(self):
- self._do_operate_test(operators.isnot, True)
-
- def test_is_false(self):
- self._do_operate_test(operators.is_, False)
-
- def test_isnot_false(self):
- self._do_operate_test(operators.isnot, False)
-
- def test_like(self):
- self._do_operate_test(operators.like_op)
-
- def test_notlike(self):
- self._do_operate_test(operators.notlike_op)
-
- def test_ilike(self):
- self._do_operate_test(operators.ilike_op)
-
- def test_notilike(self):
- self._do_operate_test(operators.notilike_op)
-
- def test_is(self):
- self._do_operate_test(operators.is_)
-
- def test_isnot(self):
- self._do_operate_test(operators.isnot)
-
def test_no_getitem(self):
assert_raises_message(
NotImplementedError,
"Operator 'getitem' is not supported on this expression",
- self._do_operate_test,
+ self.test_operate,
operators.getitem,
+ column("right"),
)
assert_raises_message(
NotImplementedError,
@@ -274,9 +232,6 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
collate(left, right)
)
- def test_concat(self):
- self._do_operate_test(operators.concat_op)
-
def test_default_adapt(self):
class TypeOne(TypeEngine):
pass
@@ -329,7 +284,8 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "default"
- def _factorial_fixture(self):
+ @testing.fixture
+ def factorial(self):
class MyInteger(Integer):
class comparator_factory(Integer.Comparator):
def factorial(self):
@@ -355,24 +311,24 @@ class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
return MyInteger
- def test_factorial(self):
- col = column("somecol", self._factorial_fixture())
+ def test_factorial(self, factorial):
+ col = column("somecol", factorial())
self.assert_compile(col.factorial(), "somecol !")
- def test_double_factorial(self):
- col = column("somecol", self._factorial_fixture())
+ def test_double_factorial(self, factorial):
+ col = column("somecol", factorial())
self.assert_compile(col.factorial().factorial(), "somecol ! !")
- def test_factorial_prefix(self):
- col = column("somecol", self._factorial_fixture())
+ def test_factorial_prefix(self, factorial):
+ col = column("somecol", factorial())
self.assert_compile(col.factorial_prefix(), "!! somecol")
- def test_factorial_invert(self):
- col = column("somecol", self._factorial_fixture())
+ def test_factorial_invert(self, factorial):
+ col = column("somecol", factorial())
self.assert_compile(~col, "!!! somecol")
- def test_double_factorial_invert(self):
- col = column("somecol", self._factorial_fixture())
+ def test_double_factorial_invert(self, factorial):
+ col = column("somecol", factorial())
self.assert_compile(~(~col), "!!! (!!! somecol)")
def test_unary_no_ops(self):
@@ -1845,7 +1801,15 @@ class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
table1 = table("mytable", column("myid", Integer))
- def _test_math_op(self, py_op, sql_op):
+ @testing.combinations(
+ ("add", operator.add, "+"),
+ ("mul", operator.mul, "*"),
+ ("sub", operator.sub, "-"),
+ ("div", operator.truediv if util.py3k else operator.div, "/"),
+ ("mod", operator.mod, "%"),
+ id_="iaa",
+ )
+ def test_math_op(self, py_op, sql_op):
for (lhs, rhs, res) in (
(5, self.table1.c.myid, ":myid_1 %s mytable.myid"),
(5, literal(5), ":param_1 %s :param_2"),
@@ -1862,24 +1826,6 @@ class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
):
self.assert_compile(py_op(lhs, rhs), res % sql_op)
- def test_math_op_add(self):
- self._test_math_op(operator.add, "+")
-
- def test_math_op_mul(self):
- self._test_math_op(operator.mul, "*")
-
- def test_math_op_sub(self):
- self._test_math_op(operator.sub, "-")
-
- def test_math_op_div(self):
- if util.py3k:
- self._test_math_op(operator.truediv, "/")
- else:
- self._test_math_op(operator.div, "/")
-
- def test_math_op_mod(self):
- self._test_math_op(operator.mod, "%")
-
class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "default"
@@ -1898,7 +1844,16 @@ class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
clause = tuple_(1, 2, 3)
eq_(str(clause), str(util.pickle.loads(util.pickle.dumps(clause))))
- def _test_comparison_op(self, py_op, fwd_op, rev_op):
+ @testing.combinations(
+ (operator.lt, "<", ">"),
+ (operator.gt, ">", "<"),
+ (operator.eq, "=", "="),
+ (operator.ne, "!=", "!="),
+ (operator.le, "<=", ">="),
+ (operator.ge, ">=", "<="),
+ id_="naa",
+ )
+ def test_comparison_op(self, py_op, fwd_op, rev_op):
dt = datetime.datetime(2012, 5, 10, 15, 27, 18)
for (lhs, rhs, l_sql, r_sql) in (
("a", self.table1.c.myid, ":myid_1", "mytable.myid"),
@@ -1935,24 +1890,6 @@ class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
+ "'",
)
- def test_comparison_operators_lt(self):
- self._test_comparison_op(operator.lt, "<", ">"),
-
- def test_comparison_operators_gt(self):
- self._test_comparison_op(operator.gt, ">", "<")
-
- def test_comparison_operators_eq(self):
- self._test_comparison_op(operator.eq, "=", "=")
-
- def test_comparison_operators_ne(self):
- self._test_comparison_op(operator.ne, "!=", "!=")
-
- def test_comparison_operators_le(self):
- self._test_comparison_op(operator.le, "<=", ">=")
-
- def test_comparison_operators_ge(self):
- self._test_comparison_op(operator.ge, ">=", "<=")
-
class NonZeroTest(fixtures.TestBase):
def _raises(self, expr):
@@ -2690,38 +2627,39 @@ class CustomOpTest(fixtures.TestBase):
assert operators.is_comparison(op1)
assert not operators.is_comparison(op2)
- def test_return_types(self):
+ @testing.combinations(
+ (sqltypes.NULLTYPE,),
+ (Integer(),),
+ (ARRAY(String),),
+ (String(50),),
+ (Boolean(),),
+ (DateTime(),),
+ (sqltypes.JSON(),),
+ (postgresql.ARRAY(Integer),),
+ (sqltypes.Numeric(5, 2),),
+ id_="r",
+ )
+ def test_return_types(self, typ):
some_return_type = sqltypes.DECIMAL()
- for typ in [
- sqltypes.NULLTYPE,
- Integer(),
- ARRAY(String),
- String(50),
- Boolean(),
- DateTime(),
- sqltypes.JSON(),
- postgresql.ARRAY(Integer),
- sqltypes.Numeric(5, 2),
- ]:
- c = column("x", typ)
- expr = c.op("$", is_comparison=True)(None)
- is_(expr.type, sqltypes.BOOLEANTYPE)
+ c = column("x", typ)
+ expr = c.op("$", is_comparison=True)(None)
+ is_(expr.type, sqltypes.BOOLEANTYPE)
- c = column("x", typ)
- expr = c.bool_op("$")(None)
- is_(expr.type, sqltypes.BOOLEANTYPE)
+ c = column("x", typ)
+ expr = c.bool_op("$")(None)
+ is_(expr.type, sqltypes.BOOLEANTYPE)
- expr = c.op("$")(None)
- is_(expr.type, typ)
+ expr = c.op("$")(None)
+ is_(expr.type, typ)
- expr = c.op("$", return_type=some_return_type)(None)
- is_(expr.type, some_return_type)
+ expr = c.op("$", return_type=some_return_type)(None)
+ is_(expr.type, some_return_type)
- expr = c.op("$", is_comparison=True, return_type=some_return_type)(
- None
- )
- is_(expr.type, some_return_type)
+ expr = c.op("$", is_comparison=True, return_type=some_return_type)(
+ None
+ )
+ is_(expr.type, some_return_type)
class TupleTypingTest(fixtures.TestBase):
@@ -2756,7 +2694,8 @@ class TupleTypingTest(fixtures.TestBase):
class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "default"
- def _fixture(self):
+ @testing.fixture
+ def t_fixture(self):
m = MetaData()
t = Table(
@@ -2767,8 +2706,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
)
return t
- def test_any_array(self):
- t = self._fixture()
+ def test_any_array(self, t_fixture):
+ t = t_fixture
self.assert_compile(
5 == any_(t.c.arrval),
@@ -2776,8 +2715,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
checkparams={"param_1": 5},
)
- def test_any_array_method(self):
- t = self._fixture()
+ def test_any_array_method(self, t_fixture):
+ t = t_fixture
self.assert_compile(
5 == t.c.arrval.any_(),
@@ -2785,8 +2724,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
checkparams={"param_1": 5},
)
- def test_all_array(self):
- t = self._fixture()
+ def test_all_array(self, t_fixture):
+ t = t_fixture
self.assert_compile(
5 == all_(t.c.arrval),
@@ -2794,8 +2733,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
checkparams={"param_1": 5},
)
- def test_all_array_method(self):
- t = self._fixture()
+ def test_all_array_method(self, t_fixture):
+ t = t_fixture
self.assert_compile(
5 == t.c.arrval.all_(),
@@ -2803,8 +2742,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
checkparams={"param_1": 5},
)
- def test_any_comparator_array(self):
- t = self._fixture()
+ def test_any_comparator_array(self, t_fixture):
+ t = t_fixture
self.assert_compile(
5 > any_(t.c.arrval),
@@ -2812,8 +2751,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
checkparams={"param_1": 5},
)
- def test_all_comparator_array(self):
- t = self._fixture()
+ def test_all_comparator_array(self, t_fixture):
+ t = t_fixture
self.assert_compile(
5 > all_(t.c.arrval),
@@ -2821,8 +2760,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
checkparams={"param_1": 5},
)
- def test_any_comparator_array_wexpr(self):
- t = self._fixture()
+ def test_any_comparator_array_wexpr(self, t_fixture):
+ t = t_fixture
self.assert_compile(
t.c.data > any_(t.c.arrval),
@@ -2830,8 +2769,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
checkparams={},
)
- def test_all_comparator_array_wexpr(self):
- t = self._fixture()
+ def test_all_comparator_array_wexpr(self, t_fixture):
+ t = t_fixture
self.assert_compile(
t.c.data > all_(t.c.arrval),
@@ -2839,8 +2778,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
checkparams={},
)
- def test_illegal_ops(self):
- t = self._fixture()
+ def test_illegal_ops(self, t_fixture):
+ t = t_fixture
assert_raises_message(
exc.ArgumentError,
@@ -2856,8 +2795,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
t.c.data + all_(t.c.arrval), "tab1.data + ALL (tab1.arrval)"
)
- def test_any_array_comparator_accessor(self):
- t = self._fixture()
+ def test_any_array_comparator_accessor(self, t_fixture):
+ t = t_fixture
self.assert_compile(
t.c.arrval.any(5, operator.gt),
@@ -2865,8 +2804,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
checkparams={"param_1": 5},
)
- def test_all_array_comparator_accessor(self):
- t = self._fixture()
+ def test_all_array_comparator_accessor(self, t_fixture):
+ t = t_fixture
self.assert_compile(
t.c.arrval.all(5, operator.gt),
@@ -2874,8 +2813,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
checkparams={"param_1": 5},
)
- def test_any_array_expression(self):
- t = self._fixture()
+ def test_any_array_expression(self, t_fixture):
+ t = t_fixture
self.assert_compile(
5 == any_(t.c.arrval[5:6] + postgresql.array([3, 4])),
@@ -2891,8 +2830,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
dialect="postgresql",
)
- def test_all_array_expression(self):
- t = self._fixture()
+ def test_all_array_expression(self, t_fixture):
+ t = t_fixture
self.assert_compile(
5 == all_(t.c.arrval[5:6] + postgresql.array([3, 4])),
@@ -2908,8 +2847,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
dialect="postgresql",
)
- def test_any_subq(self):
- t = self._fixture()
+ def test_any_subq(self, t_fixture):
+ t = t_fixture
self.assert_compile(
5
@@ -2919,8 +2858,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
checkparams={"data_1": 10, "param_1": 5},
)
- def test_any_subq_method(self):
- t = self._fixture()
+ def test_any_subq_method(self, t_fixture):
+ t = t_fixture
self.assert_compile(
5
@@ -2933,8 +2872,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
checkparams={"data_1": 10, "param_1": 5},
)
- def test_all_subq(self):
- t = self._fixture()
+ def test_all_subq(self, t_fixture):
+ t = t_fixture
self.assert_compile(
5
@@ -2944,8 +2883,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
checkparams={"data_1": 10, "param_1": 5},
)
- def test_all_subq_method(self):
- t = self._fixture()
+ def test_all_subq_method(self, t_fixture):
+ t = t_fixture
self.assert_compile(
5
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index 7bf83b461..2ffdd83b7 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -5,6 +5,7 @@ import importlib
import operator
import os
+import sqlalchemy as sa
from sqlalchemy import and_
from sqlalchemy import ARRAY
from sqlalchemy import BigInteger
@@ -87,42 +88,83 @@ from sqlalchemy.testing.util import round_decimal
from sqlalchemy.util import OrderedDict
-class AdaptTest(fixtures.TestBase):
- def _all_dialect_modules(self):
- return [
- importlib.import_module("sqlalchemy.dialects.%s" % d)
- for d in dialects.__all__
- if not d.startswith("_")
- ]
+def _all_dialect_modules():
+ return [
+ importlib.import_module("sqlalchemy.dialects.%s" % d)
+ for d in dialects.__all__
+ if not d.startswith("_")
+ ]
- def _all_dialects(self):
- return [d.base.dialect() for d in self._all_dialect_modules()]
- def _types_for_mod(self, mod):
- for key in dir(mod):
- typ = getattr(mod, key)
- if not isinstance(typ, type) or not issubclass(
- typ, types.TypeEngine
- ):
- continue
- yield typ
+def _all_dialects():
+ return [d.base.dialect() for d in _all_dialect_modules()]
- def _all_types(self):
- for typ in self._types_for_mod(types):
- yield typ
- for dialect in self._all_dialect_modules():
- for typ in self._types_for_mod(dialect):
- yield typ
- def test_uppercase_importable(self):
- import sqlalchemy as sa
+def _types_for_mod(mod):
+ for key in dir(mod):
+ typ = getattr(mod, key)
+ if not isinstance(typ, type) or not issubclass(typ, types.TypeEngine):
+ continue
+ yield typ
- for typ in self._types_for_mod(types):
- if typ.__name__ == typ.__name__.upper():
- assert getattr(sa, typ.__name__) is typ
- assert typ.__name__ in types.__all__
- def test_uppercase_rendering(self):
+def _all_types(omit_special_types=False):
+ seen = set()
+ for typ in _types_for_mod(types):
+ if omit_special_types and typ in (
+ types.TypeDecorator,
+ types.TypeEngine,
+ types.Variant,
+ ):
+ continue
+
+ if typ in seen:
+ continue
+ seen.add(typ)
+ yield typ
+ for dialect in _all_dialect_modules():
+ for typ in _types_for_mod(dialect):
+ if typ in seen:
+ continue
+ seen.add(typ)
+ yield typ
+
+
+class AdaptTest(fixtures.TestBase):
+ @testing.combinations(((t,) for t in _types_for_mod(types)), id_="n")
+ def test_uppercase_importable(self, typ):
+ if typ.__name__ == typ.__name__.upper():
+ assert getattr(sa, typ.__name__) is typ
+ assert typ.__name__ in types.__all__
+
+ @testing.combinations(
+ ((d.name, d) for d in _all_dialects()), argnames="dialect", id_="ia"
+ )
+ @testing.combinations(
+ (REAL(), "REAL"),
+ (FLOAT(), "FLOAT"),
+ (NUMERIC(), "NUMERIC"),
+ (DECIMAL(), "DECIMAL"),
+ (INTEGER(), "INTEGER"),
+ (SMALLINT(), "SMALLINT"),
+ (TIMESTAMP(), ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE")),
+ (DATETIME(), "DATETIME"),
+ (DATE(), "DATE"),
+ (TIME(), ("TIME", "TIME WITHOUT TIME ZONE")),
+ (CLOB(), "CLOB"),
+ (VARCHAR(10), ("VARCHAR(10)", "VARCHAR(10 CHAR)")),
+ (
+ NVARCHAR(10),
+ ("NVARCHAR(10)", "NATIONAL VARCHAR(10)", "NVARCHAR2(10)"),
+ ),
+ (CHAR(), "CHAR"),
+ (NCHAR(), ("NCHAR", "NATIONAL CHAR")),
+ (BLOB(), ("BLOB", "BLOB SUB_TYPE 0")),
+ (BOOLEAN(), ("BOOLEAN", "BOOL", "INTEGER")),
+ argnames="type_, expected",
+ id_="ra",
+ )
+ def test_uppercase_rendering(self, dialect, type_, expected):
"""Test that uppercase types from types.py always render as their
type.
@@ -133,51 +175,48 @@ class AdaptTest(fixtures.TestBase):
"""
- for dialect in self._all_dialects():
- for type_, expected in (
- (REAL, "REAL"),
- (FLOAT, "FLOAT"),
- (NUMERIC, "NUMERIC"),
- (DECIMAL, "DECIMAL"),
- (INTEGER, "INTEGER"),
- (SMALLINT, "SMALLINT"),
- (TIMESTAMP, ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE")),
- (DATETIME, "DATETIME"),
- (DATE, "DATE"),
- (TIME, ("TIME", "TIME WITHOUT TIME ZONE")),
- (CLOB, "CLOB"),
- (VARCHAR(10), ("VARCHAR(10)", "VARCHAR(10 CHAR)")),
- (
- NVARCHAR(10),
- ("NVARCHAR(10)", "NATIONAL VARCHAR(10)", "NVARCHAR2(10)"),
- ),
- (CHAR, "CHAR"),
- (NCHAR, ("NCHAR", "NATIONAL CHAR")),
- (BLOB, ("BLOB", "BLOB SUB_TYPE 0")),
- (BOOLEAN, ("BOOLEAN", "BOOL", "INTEGER")),
- ):
- if isinstance(expected, str):
- expected = (expected,)
+ if isinstance(expected, str):
+ expected = (expected,)
- try:
- compiled = types.to_instance(type_).compile(
- dialect=dialect
- )
- except NotImplementedError:
- continue
+ try:
+ compiled = type_.compile(dialect=dialect)
+ except NotImplementedError:
+ return
- assert compiled in expected, (
- "%r matches none of %r for dialect %s"
- % (compiled, expected, dialect.name)
- )
+ assert compiled in expected, "%r matches none of %r for dialect %s" % (
+ compiled,
+ expected,
+ dialect.name,
+ )
- assert str(types.to_instance(type_)) in expected, (
- "default str() of type %r not expected, %r"
- % (type_, expected)
- )
+ assert (
+ str(types.to_instance(type_)) in expected
+ ), "default str() of type %r not expected, %r" % (type_, expected)
+
+ def _adaptions():
+ for typ in _all_types(omit_special_types=True):
+
+ # up adapt from LowerCase to UPPERCASE,
+ # as well as to all non-sqltypes
+ up_adaptions = [typ] + typ.__subclasses__()
+ yield "%s.%s" % (
+ typ.__module__,
+ typ.__name__,
+ ), False, typ, up_adaptions
+ for subcl in typ.__subclasses__():
+ if (
+ subcl is not typ
+ and typ is not TypeDecorator
+ and "sqlalchemy" in subcl.__module__
+ ):
+ yield "%s.%s" % (
+ subcl.__module__,
+ subcl.__name__,
+ ), True, subcl, [typ]
@testing.uses_deprecated(".*Binary.*")
- def test_adapt_method(self):
+ @testing.combinations(_adaptions(), id_="iaaa")
+ def test_adapt_method(self, is_down_adaption, typ, target_adaptions):
"""ensure all types have a working adapt() method,
which creates a distinct copy.
@@ -190,67 +229,44 @@ class AdaptTest(fixtures.TestBase):
"""
- def adaptions():
- for typ in self._all_types():
- # up adapt from LowerCase to UPPERCASE,
- # as well as to all non-sqltypes
- up_adaptions = [typ] + typ.__subclasses__()
- yield False, typ, up_adaptions
- for subcl in typ.__subclasses__():
- if (
- subcl is not typ
- and typ is not TypeDecorator
- and "sqlalchemy" in subcl.__module__
- ):
- yield True, subcl, [typ]
-
- for is_down_adaption, typ, target_adaptions in adaptions():
- if typ in (types.TypeDecorator, types.TypeEngine, types.Variant):
+ if issubclass(typ, ARRAY):
+ t1 = typ(String)
+ else:
+ t1 = typ()
+ for cls in target_adaptions:
+ if (is_down_adaption and issubclass(typ, sqltypes.Emulated)) or (
+ not is_down_adaption and issubclass(cls, sqltypes.Emulated)
+ ):
continue
- elif issubclass(typ, ARRAY):
- t1 = typ(String)
- else:
- t1 = typ()
- for cls in target_adaptions:
- if (
- is_down_adaption and issubclass(typ, sqltypes.Emulated)
- ) or (
- not is_down_adaption and issubclass(cls, sqltypes.Emulated)
- ):
- continue
- if cls.__module__.startswith("test"):
+ # print("ADAPT %s -> %s" % (t1.__class__, cls))
+ t2 = t1.adapt(cls)
+ assert t1 is not t2
+
+ if is_down_adaption:
+ t2, t1 = t1, t2
+
+ for k in t1.__dict__:
+ if k in (
+ "impl",
+ "_is_oracle_number",
+ "_create_events",
+ "create_constraint",
+ "inherit_schema",
+ "schema",
+ "metadata",
+ "name",
+ ):
continue
+ # assert each value was copied, or that
+ # the adapted type has a more specific
+ # value than the original (i.e. SQL Server
+ # applies precision=24 for REAL)
+ assert (
+ getattr(t2, k) == t1.__dict__[k] or t1.__dict__[k] is None
+ )
- # print("ADAPT %s -> %s" % (t1.__class__, cls))
- t2 = t1.adapt(cls)
- assert t1 is not t2
-
- if is_down_adaption:
- t2, t1 = t1, t2
-
- for k in t1.__dict__:
- if k in (
- "impl",
- "_is_oracle_number",
- "_create_events",
- "create_constraint",
- "inherit_schema",
- "schema",
- "metadata",
- "name",
- ):
- continue
- # assert each value was copied, or that
- # the adapted type has a more specific
- # value than the original (i.e. SQL Server
- # applies precision=24 for REAL)
- assert (
- getattr(t2, k) == t1.__dict__[k]
- or t1.__dict__[k] is None
- )
-
- eq_(t1.evaluates_none().should_evaluate_none, True)
+ eq_(t1.evaluates_none().should_evaluate_none, True)
def test_python_type(self):
eq_(types.Integer().python_type, int)
@@ -270,15 +286,13 @@ class AdaptTest(fixtures.TestBase):
)
@testing.uses_deprecated()
- def test_repr(self):
- for typ in self._all_types():
- if typ in (types.TypeDecorator, types.TypeEngine, types.Variant):
- continue
- elif issubclass(typ, ARRAY):
- t1 = typ(String)
- else:
- t1 = typ()
- repr(t1)
+ @testing.combinations(*[(t,) for t in _all_types(omit_special_types=True)])
+ def test_repr(self, typ):
+ if issubclass(typ, ARRAY):
+ t1 = typ(String)
+ else:
+ t1 = typ()
+ repr(t1)
def test_adapt_constructor_copy_override_kw(self):
"""test that adapt() can accept kw args that override
@@ -299,27 +313,30 @@ class AdaptTest(fixtures.TestBase):
class TypeAffinityTest(fixtures.TestBase):
- def test_type_affinity(self):
- for type_, affin in [
- (String(), String),
- (VARCHAR(), String),
- (Date(), Date),
- (LargeBinary(), types._Binary),
- ]:
- eq_(type_._type_affinity, affin)
-
- for t1, t2, comp in [
- (Integer(), SmallInteger(), True),
- (Integer(), String(), False),
- (Integer(), Integer(), True),
- (Text(), String(), True),
- (Text(), Unicode(), True),
- (LargeBinary(), Integer(), False),
- (LargeBinary(), PickleType(), True),
- (PickleType(), LargeBinary(), True),
- (PickleType(), PickleType(), True),
- ]:
- eq_(t1._compare_type_affinity(t2), comp, "%s %s" % (t1, t2))
+ @testing.combinations(
+ (String(), String),
+ (VARCHAR(), String),
+ (Date(), Date),
+ (LargeBinary(), types._Binary),
+ id_="rn",
+ )
+ def test_type_affinity(self, type_, affin):
+ eq_(type_._type_affinity, affin)
+
+ @testing.combinations(
+ (Integer(), SmallInteger(), True),
+ (Integer(), String(), False),
+ (Integer(), Integer(), True),
+ (Text(), String(), True),
+ (Text(), Unicode(), True),
+ (LargeBinary(), Integer(), False),
+ (LargeBinary(), PickleType(), True),
+ (PickleType(), LargeBinary(), True),
+ (PickleType(), PickleType(), True),
+ id_="rra",
+ )
+ def test_compare_type_affinity(self, t1, t2, comp):
+ eq_(t1._compare_type_affinity(t2), comp, "%s %s" % (t1, t2))
def test_decorator_doesnt_cache(self):
from sqlalchemy.dialects import postgresql
@@ -340,30 +357,32 @@ class TypeAffinityTest(fixtures.TestBase):
class PickleTypesTest(fixtures.TestBase):
- def test_pickle_types(self):
+ @testing.combinations(
+ ("Boo", Boolean()),
+ ("Str", String()),
+ ("Tex", Text()),
+ ("Uni", Unicode()),
+ ("Int", Integer()),
+ ("Sma", SmallInteger()),
+ ("Big", BigInteger()),
+ ("Num", Numeric()),
+ ("Flo", Float()),
+ ("Dat", DateTime()),
+ ("Dat", Date()),
+ ("Tim", Time()),
+ ("Lar", LargeBinary()),
+ ("Pic", PickleType()),
+ ("Int", Interval()),
+ id_="ar",
+ )
+ def test_pickle_types(self, name, type_):
+ column_type = Column(name, type_)
+ meta = MetaData()
+ Table("foo", meta, column_type)
+
for loads, dumps in picklers():
- column_types = [
- Column("Boo", Boolean()),
- Column("Str", String()),
- Column("Tex", Text()),
- Column("Uni", Unicode()),
- Column("Int", Integer()),
- Column("Sma", SmallInteger()),
- Column("Big", BigInteger()),
- Column("Num", Numeric()),
- Column("Flo", Float()),
- Column("Dat", DateTime()),
- Column("Dat", Date()),
- Column("Tim", Time()),
- Column("Lar", LargeBinary()),
- Column("Pic", PickleType()),
- Column("Int", Interval()),
- ]
- for column_type in column_types:
- meta = MetaData()
- Table("foo", meta, column_type)
- loads(dumps(column_type))
- loads(dumps(meta))
+ loads(dumps(column_type))
+ loads(dumps(meta))
class _UserDefinedTypeFixture(object):
@@ -2414,19 +2433,19 @@ class ExpressionTest(
expr = column("foo", CHAR) == "asdf"
eq_(expr.right.type.__class__, CHAR)
- def test_actual_literal_adapters(self):
- for data, expected in [
- (5, Integer),
- (2.65, Float),
- (True, Boolean),
- (decimal.Decimal("2.65"), Numeric),
- (datetime.date(2015, 7, 20), Date),
- (datetime.time(10, 15, 20), Time),
- (datetime.datetime(2015, 7, 20, 10, 15, 20), DateTime),
- (datetime.timedelta(seconds=5), Interval),
- (None, types.NullType),
- ]:
- is_(literal(data).type.__class__, expected)
+ @testing.combinations(
+ (5, Integer),
+ (2.65, Float),
+ (True, Boolean),
+ (decimal.Decimal("2.65"), Numeric),
+ (datetime.date(2015, 7, 20), Date),
+ (datetime.time(10, 15, 20), Time),
+ (datetime.datetime(2015, 7, 20, 10, 15, 20), DateTime),
+ (datetime.timedelta(seconds=5), Interval),
+ (None, types.NullType),
+ )
+ def test_actual_literal_adapters(self, data, expected):
+ is_(literal(data).type.__class__, expected)
def test_typedec_operator_adapt(self):
expr = test_table.c.bvalue + "hi"
@@ -2592,18 +2611,22 @@ class ExpressionTest(
expr = column("bar", types.Interval) * column("foo", types.Numeric)
eq_(expr.type._type_affinity, types.Interval)
- def test_numerics_coercion(self):
-
- for op in (operator.add, operator.mul, operator.truediv, operator.sub):
- for other in (Numeric(10, 2), Integer):
- expr = op(
- column("bar", types.Numeric(10, 2)), column("foo", other)
- )
- assert isinstance(expr.type, types.Numeric)
- expr = op(
- column("foo", other), column("bar", types.Numeric(10, 2))
- )
- assert isinstance(expr.type, types.Numeric)
+ @testing.combinations(
+ (operator.add,),
+ (operator.mul,),
+ (operator.truediv,),
+ (operator.sub,),
+ argnames="op",
+ id_="n",
+ )
+ @testing.combinations(
+ (Numeric(10, 2),), (Integer(),), argnames="other", id_="r"
+ )
+ def test_numerics_coercion(self, op, other):
+ expr = op(column("bar", types.Numeric(10, 2)), column("foo", other))
+ assert isinstance(expr.type, types.Numeric)
+ expr = op(column("foo", other), column("bar", types.Numeric(10, 2)))
+ assert isinstance(expr.type, types.Numeric)
def test_asdecimal_int_to_numeric(self):
expr = column("a", Integer) * column("b", Numeric(asdecimal=False))