summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2008-05-09 16:34:10 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2008-05-09 16:34:10 +0000
commit4a6afd469fad170868554bf28578849bf3dfd5dd (patch)
treeb396edc33d567ae19dd244e87137296450467725 /test
parent46b7c9dc57a38d5b9e44a4723dad2ad8ec57baca (diff)
downloadsqlalchemy-4a6afd469fad170868554bf28578849bf3dfd5dd.tar.gz
r4695 merged to trunk; trunk now becomes 0.5.
0.4 development continues at /sqlalchemy/branches/rel_0_4
Diffstat (limited to 'test')
-rw-r--r--test/base/dependency.py2
-rw-r--r--test/base/except.py9
-rw-r--r--test/base/utils.py480
-rw-r--r--test/dialect/firebird.py2
-rw-r--r--test/dialect/maxdb.py6
-rwxr-xr-xtest/dialect/mssql.py4
-rw-r--r--test/dialect/mysql.py6
-rw-r--r--test/dialect/oracle.py6
-rw-r--r--test/dialect/postgres.py12
-rw-r--r--test/dialect/sqlite.py14
-rw-r--r--test/engine/bind.py21
-rw-r--r--test/engine/ddlevents.py14
-rw-r--r--test/engine/execute.py90
-rw-r--r--test/engine/metadata.py13
-rw-r--r--test/engine/parseconnect.py16
-rw-r--r--test/engine/pool.py17
-rw-r--r--test/engine/reconnect.py51
-rw-r--r--test/engine/reflection.py283
-rw-r--r--test/engine/transaction.py151
-rw-r--r--test/ext/activemapper.py357
-rw-r--r--test/ext/alltests.py3
-rw-r--r--test/ext/assignmapper.py83
-rw-r--r--test/ext/declarative.py6
-rw-r--r--test/orm/alltests.py10
-rw-r--r--test/orm/association.py8
-rw-r--r--test/orm/assorted_eager.py90
-rw-r--r--test/orm/attributes.py426
-rw-r--r--test/orm/cascade.py11
-rw-r--r--test/orm/collection.py6
-rw-r--r--test/orm/compile.py6
-rw-r--r--test/orm/cycles.py32
-rw-r--r--test/orm/deprecations.py394
-rw-r--r--test/orm/dynamic.py18
-rw-r--r--test/orm/eager_relations.py161
-rw-r--r--test/orm/entity.py127
-rw-r--r--test/orm/expire.py120
-rw-r--r--test/orm/extendedattr.py303
-rw-r--r--test/orm/generative.py54
-rw-r--r--test/orm/inheritance/abc_inheritance.py2
-rw-r--r--test/orm/inheritance/abc_polymorphic.py9
-rw-r--r--test/orm/inheritance/basic.py17
-rw-r--r--test/orm/inheritance/concrete.py4
-rw-r--r--test/orm/inheritance/poly_linked_list.py7
-rw-r--r--test/orm/inheritance/polymorph.py116
-rw-r--r--test/orm/inheritance/polymorph2.py28
-rw-r--r--test/orm/inheritance/query.py263
-rw-r--r--test/orm/inheritance/single.py6
-rw-r--r--test/orm/instrumentation.py745
-rw-r--r--test/orm/lazy_relations.py41
-rw-r--r--test/orm/manytomany.py11
-rw-r--r--test/orm/mapper.py364
-rw-r--r--test/orm/merge.py155
-rw-r--r--test/orm/naturalpks.py21
-rw-r--r--test/orm/onetoone.py26
-rw-r--r--test/orm/pickled.py3
-rw-r--r--test/orm/query.py728
-rw-r--r--test/orm/relationships.py83
-rw-r--r--test/orm/scoping.py171
-rw-r--r--test/orm/selectable.py4
-rw-r--r--test/orm/session.py420
-rw-r--r--test/orm/sessioncontext.py48
-rw-r--r--test/orm/sharding/shard.py6
-rw-r--r--test/orm/transaction.py360
-rw-r--r--test/orm/unitofwork.py104
-rw-r--r--test/orm/utils.py208
-rw-r--r--test/perf/masseagerload.py1
-rw-r--r--test/profiling/compiler.py4
-rw-r--r--test/profiling/zoomark.py4
-rw-r--r--test/sql/case_statement.py4
-rw-r--r--test/sql/columns.py8
-rw-r--r--test/sql/constraints.py10
-rw-r--r--test/sql/defaults.py6
-rw-r--r--test/sql/functions.py2
-rw-r--r--test/sql/generative.py214
-rw-r--r--test/sql/query.py6
-rw-r--r--test/sql/quote.py27
-rw-r--r--test/sql/select.py36
-rwxr-xr-xtest/sql/selectable.py12
-rw-r--r--test/sql/testtypes.py91
-rw-r--r--test/testlib/__init__.py22
-rw-r--r--test/testlib/compat.py18
-rw-r--r--test/testlib/engines.py2
-rw-r--r--test/testlib/filters.py4
-rw-r--r--test/testlib/fixtures.py72
-rw-r--r--test/testlib/profiling.py5
-rw-r--r--test/testlib/requires.py32
-rw-r--r--test/testlib/schema.py2
-rw-r--r--test/testlib/tables.py5
-rw-r--r--test/testlib/testing.py134
89 files changed, 5562 insertions, 2521 deletions
diff --git a/test/base/dependency.py b/test/base/dependency.py
index 25d34ffd3..f891bc92e 100644
--- a/test/base/dependency.py
+++ b/test/base/dependency.py
@@ -1,7 +1,7 @@
import testenv; testenv.configure_for_tests()
import sqlalchemy.topological as topological
from sqlalchemy import util
-from testlib import *
+from testlib import TestBase
class DependencySortTest(TestBase):
diff --git a/test/base/except.py b/test/base/except.py
index 84b84793c..cbbb941c6 100644
--- a/test/base/except.py
+++ b/test/base/except.py
@@ -1,9 +1,8 @@
"""Tests exceptions and DB-API exception wrapping."""
import testenv; testenv.configure_for_tests()
-import sys, unittest
+import unittest
import exceptions as stdlib_exceptions
-from sqlalchemy import exceptions as sa_exceptions
-from testlib import *
+from sqlalchemy import exc as sa_exceptions
class Error(stdlib_exceptions.StandardError):
@@ -48,10 +47,10 @@ class WrapTest(unittest.TestCase):
# subclasses of sqlalchemy.exceptions.DBAPIError
try:
raise sa_exceptions.DBAPIError.instance(
- '', [], sa_exceptions.AssertionError())
+ '', [], sa_exceptions.ArgumentError())
except sa_exceptions.DBAPIError, e:
self.assert_(e.__class__ is sa_exceptions.DBAPIError)
- except sa_exceptions.AssertionError:
+ except sa_exceptions.ArgumentError:
self.assert_(False)
def test_db_error_keyboard_interrupt(self):
diff --git a/test/base/utils.py b/test/base/utils.py
index a00338f5f..070ffb583 100644
--- a/test/base/utils.py
+++ b/test/base/utils.py
@@ -1,8 +1,9 @@
import testenv; testenv.configure_for_tests()
-import unittest
-from sqlalchemy import util, sql, exceptions
-from testlib import *
-from testlib import sorted
+import threading, unittest
+from sqlalchemy import util, sql, exc
+from testlib import TestBase
+from testlib.testing import eq_, is_, ne_
+from testlib.compat import frozenset, set, sorted
class OrderedDictTest(TestBase):
def test_odict(self):
@@ -12,40 +13,37 @@ class OrderedDictTest(TestBase):
o['snack'] = 'attack'
o['c'] = 3
- self.assert_(o.keys() == ['a', 'b', 'snack', 'c'])
- self.assert_(o.values() == [1, 2, 'attack', 3])
+ eq_(o.keys(), ['a', 'b', 'snack', 'c'])
+ eq_(o.values(), [1, 2, 'attack', 3])
o.pop('snack')
- self.assert_(o.keys() == ['a', 'b', 'c'])
- self.assert_(o.values() == [1, 2, 3])
+ eq_(o.keys(), ['a', 'b', 'c'])
+ eq_(o.values(), [1, 2, 3])
o2 = util.OrderedDict(d=4)
o2['e'] = 5
- self.assert_(o2.keys() == ['d', 'e'])
- self.assert_(o2.values() == [4, 5])
+ eq_(o2.keys(), ['d', 'e'])
+ eq_(o2.values(), [4, 5])
o.update(o2)
- self.assert_(o.keys() == ['a', 'b', 'c', 'd', 'e'])
- self.assert_(o.values() == [1, 2, 3, 4, 5])
+ eq_(o.keys(), ['a', 'b', 'c', 'd', 'e'])
+ eq_(o.values(), [1, 2, 3, 4, 5])
o.setdefault('c', 'zzz')
o.setdefault('f', 6)
- self.assert_(o.keys() == ['a', 'b', 'c', 'd', 'e', 'f'])
- self.assert_(o.values() == [1, 2, 3, 4, 5, 6])
+ eq_(o.keys(), ['a', 'b', 'c', 'd', 'e', 'f'])
+ eq_(o.values(), [1, 2, 3, 4, 5, 6])
class OrderedSetTest(TestBase):
def test_mutators_against_iter(self):
# testing a set modified against an iterator
o = util.OrderedSet([3,2, 4, 5])
- self.assertEquals(o.difference(iter([3,4])),
- util.OrderedSet([2,5]))
- self.assertEquals(o.intersection(iter([3,4, 6])),
- util.OrderedSet([3, 4]))
- self.assertEquals(o.union(iter([3,4, 6])),
- util.OrderedSet([2, 3, 4, 5, 6]))
+ eq_(o.difference(iter([3,4])), util.OrderedSet([2,5]))
+ eq_(o.intersection(iter([3,4, 6])), util.OrderedSet([3, 4]))
+ eq_(o.union(iter([3,4, 6])), util.OrderedSet([2, 3, 4, 5, 6]))
class ColumnCollectionTest(TestBase):
def test_in(self):
@@ -59,8 +57,8 @@ class ColumnCollectionTest(TestBase):
try:
cc['col1'] in cc
assert False
- except exceptions.ArgumentError, e:
- assert str(e) == "__contains__ requires a string argument"
+ except exc.ArgumentError, e:
+ eq_(str(e), "__contains__ requires a string argument")
def test_compare(self):
cc1 = sql.ColumnCollection()
@@ -90,11 +88,11 @@ class ArgSingletonTest(unittest.TestCase):
m3 = MyClass(3, 4)
assert m1 is m3
assert m2 is not m3
- assert len(util.ArgSingleton.instances) == 2
+ eq_(len(util.ArgSingleton.instances), 2)
m1 = m2 = m3 = None
MyClass.dispose(MyClass)
- assert len(util.ArgSingleton.instances) == 0
+ eq_(len(util.ArgSingleton.instances), 0)
class ImmutableSubclass(str):
@@ -140,7 +138,7 @@ class IdentitySetTest(unittest.TestCase):
def assert_eq(self, identityset, expected_iterable):
expected = sorted([id(o) for o in expected_iterable])
found = sorted([id(o) for o in identityset])
- self.assertEquals(found, expected)
+ eq_(found, expected)
def test_init(self):
ids = util.IdentitySet([1,2,3,2,1])
@@ -184,32 +182,35 @@ class IdentitySetTest(unittest.TestCase):
ids.remove(o1)
self.assertRaises(KeyError, ids.remove, o1)
- self.assert_(ids.copy() == ids)
- self.assert_(ids != None)
- self.assert_(not(ids == None))
- self.assert_(ids != IdentitySet([o1,o2,o3]))
+ eq_(ids.copy(), ids)
+
+ # explicit __eq__ and __ne__ tests
+ assert ids != None
+ assert not(ids == None)
+
+ ne_(ids, IdentitySet([o1,o2,o3]))
ids.clear()
- self.assert_(o1 not in ids)
+ assert o1 not in ids
ids.add(o2)
- self.assert_(o2 in ids)
- self.assert_(ids.pop() == o2)
+ assert o2 in ids
+ eq_(ids.pop(), o2)
ids.add(o1)
- self.assert_(len(ids) == 1)
+ eq_(len(ids), 1)
isuper = IdentitySet([o1,o2])
- self.assert_(ids < isuper)
- self.assert_(ids.issubset(isuper))
- self.assert_(isuper.issuperset(ids))
- self.assert_(isuper > ids)
-
- self.assert_(ids.union(isuper) == isuper)
- self.assert_(ids | isuper == isuper)
- self.assert_(isuper - ids == IdentitySet([o2]))
- self.assert_(isuper.difference(ids) == IdentitySet([o2]))
- self.assert_(ids.intersection(isuper) == IdentitySet([o1]))
- self.assert_(ids & isuper == IdentitySet([o1]))
- self.assert_(ids.symmetric_difference(isuper) == IdentitySet([o2]))
- self.assert_(ids ^ isuper == IdentitySet([o2]))
+ assert ids < isuper
+ assert ids.issubset(isuper)
+ assert isuper.issuperset(ids)
+ assert isuper > ids
+
+ eq_(ids.union(isuper), isuper)
+ eq_(ids | isuper, isuper)
+ eq_(isuper - ids, IdentitySet([o2]))
+ eq_(isuper.difference(ids), IdentitySet([o2]))
+ eq_(ids.intersection(isuper), IdentitySet([o1]))
+ eq_(ids & isuper, IdentitySet([o1]))
+ eq_(ids.symmetric_difference(isuper), IdentitySet([o2]))
+ eq_(ids ^ isuper, IdentitySet([o2]))
ids.update(isuper)
ids |= isuper
@@ -223,16 +224,16 @@ class IdentitySetTest(unittest.TestCase):
ids.update('foobar')
try:
ids |= 'foobar'
- self.assert_(False)
+ assert False
except TypeError:
- self.assert_(True)
+ assert True
try:
s = set([o1,o2])
s |= ids
- self.assert_(False)
+ assert False
except TypeError:
- self.assert_(True)
+ assert True
self.assertRaises(TypeError, cmp, ids)
self.assertRaises(TypeError, hash, ids)
@@ -243,8 +244,8 @@ class IdentitySetTest(unittest.TestCase):
s1 = set([1,2,3])
s2 = set([3,4,5])
- self.assertEquals(os1 - os2, util.IdentitySet([1, 2]))
- self.assertEquals(os2 - os1, util.IdentitySet([4, 5]))
+ eq_(os1 - os2, util.IdentitySet([1, 2]))
+ eq_(os2 - os1, util.IdentitySet([4, 5]))
self.assertRaises(TypeError, lambda: os1 - s2)
self.assertRaises(TypeError, lambda: os1 - [3, 4, 5])
self.assertRaises(TypeError, lambda: s1 - os2)
@@ -256,7 +257,7 @@ class DictlikeIteritemsTest(unittest.TestCase):
def _ok(self, instance):
iterator = util.dictlike_iteritems(instance)
- self.assertEquals(set(iterator), self.baseline)
+ eq_(set(iterator), self.baseline)
def _notok(self, instance):
self.assertRaises(TypeError,
@@ -322,6 +323,33 @@ class DictlikeIteritemsTest(unittest.TestCase):
self._notok(duck6())
+class DuckTypeCollectionTest(TestBase):
+ def test_sets(self):
+ import sets
+ class SetLike(object):
+ def add(self):
+ pass
+
+ class ForcedSet(list):
+ __emulates__ = set
+
+ for type_ in (set,
+ sets.Set,
+ util.Set,
+ SetLike,
+ ForcedSet):
+ eq_(util.duck_type_collection(type_), util.Set)
+ instance = type_()
+ eq_(util.duck_type_collection(instance), util.Set)
+
+ for type_ in (frozenset,
+ sets.ImmutableSet,
+ util.FrozenSet):
+ is_(util.duck_type_collection(type_), None)
+ instance = type_()
+ is_(util.duck_type_collection(instance), None)
+
+
class ArgInspectionTest(TestBase):
def test_get_cls_kwargs(self):
class A(object):
@@ -359,7 +387,7 @@ class ArgInspectionTest(TestBase):
pass
def test(cls, *expected):
- self.assertEquals(set(util.get_cls_kwargs(cls)), set(expected))
+ eq_(set(util.get_cls_kwargs(cls)), set(expected))
test(A, 'a')
test(A1, 'a1')
@@ -382,7 +410,7 @@ class ArgInspectionTest(TestBase):
def f4(**foo): pass
def test(fn, *expected):
- self.assertEquals(set(util.get_func_kwargs(fn)), set(expected))
+ eq_(set(util.get_func_kwargs(fn)), set(expected))
test(f1)
test(f2, 'foo')
@@ -419,7 +447,336 @@ class SymbolTest(TestBase):
assert rt is sym1
assert rt is sym2
+class WeakIdentityMappingTest(TestBase):
+ class Data(object):
+ pass
+
+ def _some_data(self, some=20):
+ return [self.Data() for _ in xrange(some)]
+
+ def _fixture(self, some=20):
+ data = self._some_data()
+ wim = util.WeakIdentityMapping()
+ for idx, obj in enumerate(data):
+ wim[obj] = idx
+ return data, wim
+
+ def test_delitem(self):
+ data, wim = self._fixture()
+ needle = data[-1]
+
+ assert needle in wim
+ assert id(needle) in wim.by_id
+ eq_(wim[needle], wim.by_id[id(needle)])
+
+ del wim[needle]
+
+ assert needle not in wim
+ assert id(needle) not in wim.by_id
+ eq_(len(wim), (len(data) - 1))
+
+ data.remove(needle)
+
+ assert needle not in wim
+ assert id(needle) not in wim.by_id
+ eq_(len(wim), len(data))
+
+ def test_setitem(self):
+ data, wim = self._fixture()
+
+ o1, oid1 = data[-1], id(data[-1])
+
+ assert o1 in wim
+ assert oid1 in wim.by_id
+ eq_(wim[o1], wim.by_id[oid1])
+ id_keys = set(wim.by_id.keys())
+
+ wim[o1] = 1234
+ assert o1 in wim
+ assert oid1 in wim.by_id
+ eq_(wim[o1], wim.by_id[oid1])
+ eq_(set(wim.by_id.keys()), id_keys)
+
+ o2 = self.Data()
+ oid2 = id(o2)
+
+ wim[o2] = 5678
+ assert o2 in wim
+ assert oid2 in wim.by_id
+ eq_(wim[o2], wim.by_id[oid2])
+
+ def test_pop(self):
+ data, wim = self._fixture()
+ needle = data[-1]
+
+ needle = data.pop()
+ assert needle in wim
+ assert id(needle) in wim.by_id
+ eq_(wim[needle], wim.by_id[id(needle)])
+ eq_(len(wim), (len(data) + 1))
+
+ wim.pop(needle)
+ assert needle not in wim
+ assert id(needle) not in wim.by_id
+ eq_(len(wim), len(data))
+
+ def test_pop_default(self):
+ data, wim = self._fixture()
+ needle = data[-1]
+
+ value = wim[needle]
+ x = wim.pop(needle, 123)
+ ne_(x, 123)
+ eq_(x, value)
+ assert needle not in wim
+ assert id(needle) not in wim.by_id
+ eq_(len(data), (len(wim) + 1))
+
+ n2 = self.Data()
+ y = wim.pop(n2, 456)
+ eq_(y, 456)
+ assert n2 not in wim
+ assert id(n2) not in wim.by_id
+ eq_(len(data), (len(wim) + 1))
+
+ def test_popitem(self):
+ data, wim = self._fixture()
+ (needle, idx) = wim.popitem()
+
+ assert needle in data
+ eq_(len(data), (len(wim) + 1))
+ assert id(needle) not in wim.by_id
+
+ def test_setdefault(self):
+ data, wim = self._fixture()
+
+ o1 = self.Data()
+ oid1 = id(o1)
+
+ assert o1 not in wim
+
+ res1 = wim.setdefault(o1, 123)
+ assert o1 in wim
+ assert oid1 in wim.by_id
+ eq_(res1, 123)
+ id_keys = set(wim.by_id.keys())
+
+ res2 = wim.setdefault(o1, 456)
+ assert o1 in wim
+ assert oid1 in wim.by_id
+ eq_(res2, 123)
+ assert set(wim.by_id.keys()) == id_keys
+
+ del wim[o1]
+ assert o1 not in wim
+ assert oid1 not in wim.by_id
+ ne_(set(wim.by_id.keys()), id_keys)
+
+ res3 = wim.setdefault(o1, 789)
+ assert o1 in wim
+ assert oid1 in wim.by_id
+ eq_(res3, 789)
+ eq_(set(wim.by_id.keys()), id_keys)
+
+ def test_clear(self):
+ data, wim = self._fixture()
+
+ assert len(data) == len(wim) == len(wim.by_id)
+ wim.clear()
+
+ eq_(wim, {})
+ eq_(wim.by_id, {})
+
+ def test_update(self):
+ data, wim = self._fixture()
+ self.assertRaises(NotImplementedError, wim.update)
+
+ def test_weak_clear(self):
+ data, wim = self._fixture()
+
+ assert len(data) == len(wim) == len(wim.by_id)
+
+ del data[:]
+ eq_(wim, {})
+ eq_(wim.by_id, {})
+ eq_(wim._weakrefs, {})
+
+ def test_weak_single(self):
+ data, wim = self._fixture()
+
+ assert len(data) == len(wim) == len(wim.by_id)
+
+ oid = id(data[0])
+ del data[0]
+
+ assert len(data) == len(wim) == len(wim.by_id)
+ assert oid not in wim.by_id
+
+ def test_weak_threadhop(self):
+ data, wim = self._fixture()
+ data = set(data)
+
+ cv = threading.Condition()
+
+ def empty(obj):
+ cv.acquire()
+ obj.clear()
+ cv.notify()
+ cv.release()
+
+ th = threading.Thread(target=empty, args=(data,))
+
+ cv.acquire()
+ th.start()
+ cv.wait()
+ cv.release()
+
+ eq_(wim, {})
+ eq_(wim.by_id, {})
+ eq_(wim._weakrefs, {})
+
+
+class TestFormatArgspec(TestBase):
+ def test_specs(self):
+ def test(fn, wanted, grouped=None):
+ if grouped is None:
+ parsed = util.format_argspec_plus(fn)
+ else:
+ parsed = util.format_argspec_plus(fn, grouped=grouped)
+ eq_(parsed, wanted)
+
+ test(lambda: None,
+ {'args': '()', 'self_arg': None,
+ 'apply_kw': '()', 'apply_pos': '()' })
+
+ test(lambda: None,
+ {'args': '', 'self_arg': None,
+ 'apply_kw': '', 'apply_pos': '' },
+ grouped=False)
+
+ test(lambda self: None,
+ {'args': '(self)', 'self_arg': 'self',
+ 'apply_kw': '(self)', 'apply_pos': '(self)' })
+
+ test(lambda self: None,
+ {'args': 'self', 'self_arg': 'self',
+ 'apply_kw': 'self', 'apply_pos': 'self' },
+ grouped=False)
+
+ test(lambda *a: None,
+ {'args': '(*a)', 'self_arg': None,
+ 'apply_kw': '(*a)', 'apply_pos': '(*a)' })
+
+ test(lambda **kw: None,
+ {'args': '(**kw)', 'self_arg': None,
+ 'apply_kw': '(**kw)', 'apply_pos': '(**kw)' })
+
+ test(lambda *a, **kw: None,
+ {'args': '(*a, **kw)', 'self_arg': None,
+ 'apply_kw': '(*a, **kw)', 'apply_pos': '(*a, **kw)' })
+
+ test(lambda a, *b: None,
+ {'args': '(a, *b)', 'self_arg': 'a',
+ 'apply_kw': '(a, *b)', 'apply_pos': '(a, *b)' })
+
+ test(lambda a, **b: None,
+ {'args': '(a, **b)', 'self_arg': 'a',
+ 'apply_kw': '(a, **b)', 'apply_pos': '(a, **b)' })
+
+ test(lambda a, *b, **c: None,
+ {'args': '(a, *b, **c)', 'self_arg': 'a',
+ 'apply_kw': '(a, *b, **c)', 'apply_pos': '(a, *b, **c)' })
+
+ test(lambda a, b=1, **c: None,
+ {'args': '(a, b=1, **c)', 'self_arg': 'a',
+ 'apply_kw': '(a, b=b, **c)', 'apply_pos': '(a, b, **c)' })
+
+ test(lambda a=1, b=2: None,
+ {'args': '(a=1, b=2)', 'self_arg': 'a',
+ 'apply_kw': '(a=a, b=b)', 'apply_pos': '(a, b)' })
+
+ test(lambda a=1, b=2: None,
+ {'args': 'a=1, b=2', 'self_arg': 'a',
+ 'apply_kw': 'a=a, b=b', 'apply_pos': 'a, b' },
+ grouped=False)
+
+ def test_init_grouped(self):
+ object_spec = {
+ 'args': '(self)', 'self_arg': 'self',
+ 'apply_pos': '(self)', 'apply_kw': '(self)'}
+ wrapper_spec = {
+ 'args': '(self, *args, **kwargs)', 'self_arg': 'self',
+ 'apply_pos': '(self, *args, **kwargs)',
+ 'apply_kw': '(self, *args, **kwargs)'}
+ custom_spec = {
+ 'args': '(slef, a=123)', 'self_arg': 'slef', # yes, slef
+ 'apply_pos': '(slef, a)', 'apply_kw': '(slef, a=a)'}
+
+ self._test_init(None, object_spec, wrapper_spec, custom_spec)
+ self._test_init(True, object_spec, wrapper_spec, custom_spec)
+
+ def test_init_bare(self):
+ object_spec = {
+ 'args': 'self', 'self_arg': 'self',
+ 'apply_pos': 'self', 'apply_kw': 'self'}
+ wrapper_spec = {
+ 'args': 'self, *args, **kwargs', 'self_arg': 'self',
+ 'apply_pos': 'self, *args, **kwargs',
+ 'apply_kw': 'self, *args, **kwargs'}
+ custom_spec = {
+ 'args': 'slef, a=123', 'self_arg': 'slef', # yes, slef
+ 'apply_pos': 'slef, a', 'apply_kw': 'slef, a=a'}
+
+ self._test_init(False, object_spec, wrapper_spec, custom_spec)
+
+ def _test_init(self, grouped, object_spec, wrapper_spec, custom_spec):
+ def test(fn, wanted):
+ if grouped is None:
+ parsed = util.format_argspec_init(fn)
+ else:
+ parsed = util.format_argspec_init(fn, grouped=grouped)
+ eq_(parsed, wanted)
+
+ class O(object): pass
+
+ test(O.__init__, object_spec)
+
+ class O(object):
+ def __init__(self):
+ pass
+
+ test(O.__init__, object_spec)
+
+ class O(object):
+ def __init__(slef, a=123):
+ pass
+
+ test(O.__init__, custom_spec)
+
+ class O(list): pass
+
+ test(O.__init__, wrapper_spec)
+
+ class O(list):
+ def __init__(self, *args, **kwargs):
+ pass
+
+ test(O.__init__, wrapper_spec)
+
+ class O(list):
+ def __init__(self):
+ pass
+
+ test(O.__init__, object_spec)
+
+ class O(list):
+ def __init__(slef, a=123):
+ pass
+
+ test(O.__init__, custom_spec)
+
class AsInterfaceTest(TestBase):
+
class Something(object):
def _ignoreme(self): pass
def foo(self): pass
@@ -442,9 +799,9 @@ class AsInterfaceTest(TestBase):
cls=self.Something, required=('foo'))
obj = self.Something()
- self.assertEqual(obj, util.as_interface(obj, cls=self.Something))
- self.assertEqual(obj, util.as_interface(obj, methods=('foo',)))
- self.assertEqual(
+ eq_(obj, util.as_interface(obj, cls=self.Something))
+ eq_(obj, util.as_interface(obj, methods=('foo',)))
+ eq_(
obj, util.as_interface(obj, cls=self.Something,
required=('outofband',)))
partial = self.Partial()
@@ -453,12 +810,11 @@ class AsInterfaceTest(TestBase):
slotted.bar = lambda self: 123
for obj in partial, slotted:
- self.assertEqual(obj, util.as_interface(obj, cls=self.Something))
+ eq_(obj, util.as_interface(obj, cls=self.Something))
self.assertRaises(TypeError, util.as_interface, obj,
methods=('foo'))
- self.assertEqual(obj, util.as_interface(obj, methods=('bar',)))
- self.assertEqual(
- obj, util.as_interface(obj, cls=self.Something,
+ eq_(obj, util.as_interface(obj, methods=('bar',)))
+ eq_(obj, util.as_interface(obj, cls=self.Something,
required=('bar',)))
self.assertRaises(TypeError, util.as_interface, obj,
cls=self.Something, required=('foo',))
diff --git a/test/dialect/firebird.py b/test/dialect/firebird.py
index f929443fd..da6cc6970 100644
--- a/test/dialect/firebird.py
+++ b/test/dialect/firebird.py
@@ -1,7 +1,7 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.databases import firebird
-from sqlalchemy.exceptions import ProgrammingError
+from sqlalchemy.exc import ProgrammingError
from sqlalchemy.sql import table, column
from testlib import *
diff --git a/test/dialect/maxdb.py b/test/dialect/maxdb.py
index 0a35f5470..f0bcd00e1 100644
--- a/test/dialect/maxdb.py
+++ b/test/dialect/maxdb.py
@@ -3,7 +3,7 @@
import testenv; testenv.configure_for_tests()
import StringIO, sys
from sqlalchemy import *
-from sqlalchemy import exceptions, sql
+from sqlalchemy import exc, sql
from sqlalchemy.util import Decimal
from sqlalchemy.databases import maxdb
from testlib import *
@@ -53,7 +53,7 @@ class ReflectionTest(TestBase, AssertsExecutionResults):
finally:
try:
testing.db.execute("DROP TABLE dectest")
- except exceptions.DatabaseError:
+ except exc.DatabaseError:
pass
def test_decimal_fixed_serial(self):
@@ -165,7 +165,7 @@ class ReflectionTest(TestBase, AssertsExecutionResults):
finally:
try:
testing.db.execute("DROP TABLE assorted")
- except exceptions.DatabaseError:
+ except exc.DatabaseError:
pass
class DBAPITest(TestBase, AssertsExecutionResults):
diff --git a/test/dialect/mssql.py b/test/dialect/mssql.py
index b5d7f1641..c3ce338df 100755
--- a/test/dialect/mssql.py
+++ b/test/dialect/mssql.py
@@ -2,7 +2,7 @@ import testenv; testenv.configure_for_tests()
import re
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc
from sqlalchemy.sql import table, column
from sqlalchemy.databases import mssql
from testlib import *
@@ -210,7 +210,7 @@ class QueryTest(TestBase):
r = users.select(limit=3, offset=2,
order_by=[users.c.user_id]).execute().fetchall()
assert False # InvalidRequestError should have been raised
- except exceptions.InvalidRequestError:
+ except exc.InvalidRequestError:
pass
finally:
metadata.drop_all()
diff --git a/test/dialect/mysql.py b/test/dialect/mysql.py
index 00478908e..923658b01 100644
--- a/test/dialect/mysql.py
+++ b/test/dialect/mysql.py
@@ -1,7 +1,7 @@
import testenv; testenv.configure_for_tests()
import sets
from sqlalchemy import *
-from sqlalchemy import sql, exceptions
+from sqlalchemy import sql, exc
from sqlalchemy.databases import mysql
from testlib import *
@@ -537,13 +537,13 @@ class TypesTest(TestBase, AssertsExecutionResults):
try:
enum_table.insert().execute(e1=None, e2=None, e3=None, e4=None)
self.assert_(False)
- except exceptions.SQLError:
+ except exc.SQLError:
self.assert_(True)
try:
enum_table.insert().execute(e1='c', e2='c', e3='c', e4='c')
self.assert_(False)
- except exceptions.InvalidRequestError:
+ except exc.InvalidRequestError:
self.assert_(True)
enum_table.insert().execute()
diff --git a/test/dialect/oracle.py b/test/dialect/oracle.py
index cdd575dd3..24353152d 100644
--- a/test/dialect/oracle.py
+++ b/test/dialect/oracle.py
@@ -120,10 +120,10 @@ AND mytable.myid = myothertable.otherid(+)",
query = table1.outerjoin(table2, table1.c.myid==table2.c.otherid).outerjoin(table3, table3.c.userid==table2.c.otherid)
self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable LEFT OUTER JOIN myothertable ON mytable.myid = myothertable.otherid LEFT OUTER JOIN thirdtable ON thirdtable.userid = myothertable.otherid")
- self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid(+) AND thirdtable.userid(+) = myothertable.otherid", dialect=oracle.dialect(use_ansi=False))
+ self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE thirdtable.userid(+) = myothertable.otherid AND mytable.myid = myothertable.otherid(+)", dialect=oracle.dialect(use_ansi=False))
query = table1.join(table2, table1.c.myid==table2.c.otherid).join(table3, table3.c.userid==table2.c.otherid)
- self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid AND thirdtable.userid = myothertable.otherid", dialect=oracle.dialect(use_ansi=False))
+ self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE thirdtable.userid = myothertable.otherid AND mytable.myid = myothertable.otherid", dialect=oracle.dialect(use_ansi=False))
query = table1.join(table2, table1.c.myid==table2.c.otherid).outerjoin(table3, table3.c.userid==table2.c.otherid)
self.assert_compile(query.select().order_by(table1.oid_column).limit(10).offset(5), "SELECT myid, name, description, otherid, othername, userid, \
@@ -131,7 +131,7 @@ otherstuff FROM (SELECT mytable.myid AS myid, mytable.name AS name, \
mytable.description AS description, myothertable.otherid AS otherid, \
myothertable.othername AS othername, thirdtable.userid AS userid, \
thirdtable.otherstuff AS otherstuff, ROW_NUMBER() OVER (ORDER BY mytable.rowid) AS ora_rn \
-FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid AND thirdtable.userid(+) = myothertable.otherid) \
+FROM mytable, myothertable, thirdtable WHERE thirdtable.userid(+) = myothertable.otherid AND mytable.myid = myothertable.otherid) \
WHERE ora_rn>5 AND ora_rn<=15", dialect=oracle.dialect(use_ansi=False))
def test_alias_outer_join(self):
diff --git a/test/dialect/postgres.py b/test/dialect/postgres.py
index 90cc0a477..3e5c200e4 100644
--- a/test/dialect/postgres.py
+++ b/test/dialect/postgres.py
@@ -2,7 +2,7 @@ import testenv; testenv.configure_for_tests()
import datetime
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc
from sqlalchemy.databases import postgres
from sqlalchemy.engine.strategies import MockEngineStrategy
from testlib import *
@@ -332,12 +332,12 @@ class InsertTest(TestBase, AssertsExecutionResults):
try:
table.insert().execute({'data':'d2'})
assert False
- except exceptions.IntegrityError, e:
+ except exc.IntegrityError, e:
assert "violates not-null constraint" in str(e)
try:
table.insert().execute({'data':'d2'}, {'data':'d3'})
assert False
- except exceptions.IntegrityError, e:
+ except exc.IntegrityError, e:
assert "violates not-null constraint" in str(e)
table.insert().execute({'id':31, 'data':'d2'}, {'id':32, 'data':'d3'})
@@ -359,12 +359,12 @@ class InsertTest(TestBase, AssertsExecutionResults):
try:
table.insert().execute({'data':'d2'})
assert False
- except exceptions.IntegrityError, e:
+ except exc.IntegrityError, e:
assert "violates not-null constraint" in str(e)
try:
table.insert().execute({'data':'d2'}, {'data':'d3'})
assert False
- except exceptions.IntegrityError, e:
+ except exc.IntegrityError, e:
assert "violates not-null constraint" in str(e)
table.insert().execute({'id':31, 'data':'d2'}, {'id':32, 'data':'d3'})
@@ -387,7 +387,7 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
try:
con.execute('CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42')
con.execute('CREATE DOMAIN alt_schema.testdomain INTEGER DEFAULT 0')
- except exceptions.SQLError, e:
+ except exc.SQLError, e:
if not "already exists" in str(e):
raise e
con.execute('CREATE TABLE testtable (question integer, answer testdomain)')
diff --git a/test/dialect/sqlite.py b/test/dialect/sqlite.py
index 585a853d2..4cde5fc33 100644
--- a/test/dialect/sqlite.py
+++ b/test/dialect/sqlite.py
@@ -3,7 +3,7 @@
import testenv; testenv.configure_for_tests()
import datetime
from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc
from sqlalchemy.databases import sqlite
from testlib import *
@@ -34,11 +34,11 @@ class TestTypes(TestBase, AssertsExecutionResults):
@testing.uses_deprecated('Using String type with no length')
def test_type_reflection(self):
# (ask_for, roundtripped_as_if_different)
- specs = [( String(), sqlite.SLText(), ),
+ specs = [( String(), sqlite.SLString(), ),
( String(1), sqlite.SLString(1), ),
( String(3), sqlite.SLString(3), ),
( Text(), sqlite.SLText(), ),
- ( Unicode(), sqlite.SLText(), ),
+ ( Unicode(), sqlite.SLString(), ),
( Unicode(1), sqlite.SLString(1), ),
( Unicode(3), sqlite.SLString(3), ),
( UnicodeText(), sqlite.SLText(), ),
@@ -94,7 +94,7 @@ class TestTypes(TestBase, AssertsExecutionResults):
for table in rt, rv:
for i, reflected in enumerate(table.c):
print reflected.type, type(expected[i])
- assert isinstance(reflected.type, type(expected[i]))
+ assert isinstance(reflected.type, type(expected[i])), type(expected[i])
finally:
db.execute('DROP VIEW types_v')
finally:
@@ -212,7 +212,7 @@ class DialectTest(TestBase, AssertsExecutionResults):
except:
try:
cx.execute('DROP TABLE tempy')
- except exceptions.DBAPIError:
+ except exc.DBAPIError:
pass
raise
@@ -247,7 +247,7 @@ class InsertTest(TestBase, AssertsExecutionResults):
@testing.exclude('sqlite', '<', (3, 4))
def test_empty_insert_pk2(self):
self.assertRaises(
- exceptions.DBAPIError,
+ exc.DBAPIError,
self._test_empty_insert,
Table('b', MetaData(testing.db),
Column('x', Integer, primary_key=True),
@@ -256,7 +256,7 @@ class InsertTest(TestBase, AssertsExecutionResults):
@testing.exclude('sqlite', '<', (3, 4))
def test_empty_insert_pk3(self):
self.assertRaises(
- exceptions.DBAPIError,
+ exc.DBAPIError,
self._test_empty_insert,
Table('c', MetaData(testing.db),
Column('x', Integer, primary_key=True),
diff --git a/test/engine/bind.py b/test/engine/bind.py
index b59cd284a..300a4eae6 100644
--- a/test/engine/bind.py
+++ b/test/engine/bind.py
@@ -2,9 +2,10 @@
including the deprecated versions of these arguments"""
import testenv; testenv.configure_for_tests()
-from sqlalchemy import *
-from sqlalchemy import engine, exceptions
-from testlib import *
+from sqlalchemy import engine, exc
+from sqlalchemy import MetaData, ThreadLocalMetaData
+from testlib.sa import Table, Column, Integer, String, func, Sequence, text
+from testlib import TestBase, testing
class BindTest(TestBase):
@@ -41,7 +42,7 @@ class BindTest(TestBase):
try:
meth()
assert False
- except exceptions.UnboundExecutionError, e:
+ except exc.UnboundExecutionError, e:
self.assertEquals(
str(e),
"The MetaData "
@@ -59,7 +60,7 @@ class BindTest(TestBase):
try:
meth()
assert False
- except exceptions.UnboundExecutionError, e:
+ except exc.UnboundExecutionError, e:
self.assertEquals(
str(e),
"The Table 'test_table' "
@@ -71,6 +72,10 @@ class BindTest(TestBase):
@testing.future
def test_create_drop_err2(self):
+ metadata = MetaData()
+ table = Table('test_table', metadata,
+ Column('foo', Integer))
+
for meth in [
table.exists,
table.create,
@@ -79,7 +84,7 @@ class BindTest(TestBase):
try:
meth()
assert False
- except exceptions.UnboundExecutionError, e:
+ except exc.UnboundExecutionError, e:
self.assertEquals(
str(e),
"The Table 'test_table' "
@@ -201,7 +206,7 @@ class BindTest(TestBase):
assert e.bind is None
e.execute()
assert False
- except exceptions.UnboundExecutionError, e:
+ except exc.UnboundExecutionError, e:
assert str(e).endswith(
'is not bound and does not support direct '
'execution. Supply this statement to a Connection or '
@@ -248,7 +253,7 @@ class BindTest(TestBase):
try:
sess.flush()
assert False
- except exceptions.InvalidRequestError, e:
+ except exc.InvalidRequestError, e:
assert str(e).startswith("Could not locate any Engine or Connection bound to mapper")
finally:
if isinstance(bind, engine.Connection):
diff --git a/test/engine/ddlevents.py b/test/engine/ddlevents.py
index 258c61412..117ee1219 100644
--- a/test/engine/ddlevents.py
+++ b/test/engine/ddlevents.py
@@ -1,9 +1,9 @@
import testenv; testenv.configure_for_tests()
-from sqlalchemy import *
-from sqlalchemy import exceptions
from sqlalchemy.schema import DDL
-import sqlalchemy
-from testlib import *
+from sqlalchemy import create_engine
+from testlib.sa import MetaData, Table, Column, Integer, String
+import testlib.sa as tsa
+from testlib import TestBase, testing
class DDLEventTest(TestBase):
@@ -294,7 +294,7 @@ class DDLExecutionTest(TestBase):
try:
r = eval(py)
assert False
- except exceptions.UnboundExecutionError:
+ except tsa.exc.UnboundExecutionError:
pass
for bind in engine, cx:
@@ -310,7 +310,7 @@ class DDLTest(TestBase):
engine = create_engine(testing.db.name + '://',
strategy='mock', executor=executor)
engine.dialect.identifier_preparer = \
- sqlalchemy.sql.compiler.IdentifierPreparer(engine.dialect)
+ tsa.sql.compiler.IdentifierPreparer(engine.dialect)
return engine
def test_tokens(self):
@@ -324,7 +324,7 @@ class DDLTest(TestBase):
ddl = DDL('%(schema)s-%(table)s-%(fullname)s')
self.assertEquals(ddl._expand(sane_alone, bind), '-t-t')
- self.assertEquals(ddl._expand(sane_schema, bind), '"s"-t-s.t')
+ self.assertEquals(ddl._expand(sane_schema, bind), 's-t-s.t')
self.assertEquals(ddl._expand(insane_alone, bind), '-"t t"-"t t"')
self.assertEquals(ddl._expand(insane_schema, bind),
'"s s"-"t t"-"s s"."t t"')
diff --git a/test/engine/execute.py b/test/engine/execute.py
index 260a05e27..36a6bc317 100644
--- a/test/engine/execute.py
+++ b/test/engine/execute.py
@@ -1,8 +1,13 @@
import testenv; testenv.configure_for_tests()
-from sqlalchemy import *
-from sqlalchemy import exceptions
-from testlib import *
+import re
+from sqlalchemy.interfaces import ConnectionProxy
+from testlib.sa import MetaData, Table, Column, Integer, String, INT, \
+ VARCHAR, func
+import testlib.sa as tsa
+from testlib import TestBase, testing, engines
+
+users, metadata = None, None
class ExecuteTest(TestBase):
def setUpAll(self):
global users, metadata
@@ -70,8 +75,85 @@ class ExecuteTest(TestBase):
try:
conn.execute("osdjafioajwoejoasfjdoifjowejfoawejqoijwef")
assert False
- except exceptions.DBAPIError:
+ except tsa.exc.DBAPIError:
assert True
+class ProxyConnectionTest(TestBase):
+ def test_proxy(self):
+
+ stmts = []
+ cursor_stmts = []
+
+ class MyProxy(ConnectionProxy):
+ def execute(self, conn, execute, clauseelement, *multiparams, **params):
+ stmts.append(
+ (str(clauseelement), params,multiparams)
+ )
+ return execute(clauseelement, *multiparams, **params)
+
+ def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
+ cursor_stmts.append(
+ (statement, parameters, None)
+ )
+ return execute(cursor, statement, parameters, context)
+
+ def assert_stmts(expected, received):
+ for stmt, params, posn in expected:
+ if not received:
+ assert False
+ while received:
+ teststmt, testparams, testmultiparams = received.pop(0)
+ teststmt = re.compile(r'[\n\t ]+', re.M).sub(' ', teststmt).strip()
+ if teststmt.startswith(stmt) and (testparams==params or testparams==posn):
+ break
+
+ for engine in (
+ engines.testing_engine(options=dict(proxy=MyProxy())),
+ engines.testing_engine(options=dict(proxy=MyProxy(), strategy='threadlocal'))
+ ):
+ m = MetaData(engine)
+
+ t1 = Table('t1', m, Column('c1', Integer, primary_key=True), Column('c2', String(50), default=func.lower('Foo'), primary_key=True))
+
+ m.create_all()
+ try:
+ t1.insert().execute(c1=5, c2='some data')
+ t1.insert().execute(c1=6)
+ assert engine.execute("select * from t1").fetchall() == [(5, 'some data'), (6, 'foo')]
+ finally:
+ m.drop_all()
+
+ engine.dispose()
+
+ compiled = [
+ ("CREATE TABLE t1", {}, None),
+ ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, None),
+ ("INSERT INTO t1 (c1, c2)", {'c1': 6}, None),
+ ("select * from t1", {}, None),
+ ("DROP TABLE t1", {}, None)
+ ]
+
+ if engine.dialect.preexecute_pk_sequences:
+ cursor = [
+ ("CREATE TABLE t1", {}, None),
+ ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, [5, 'some data']),
+ ("SELECT lower", {'lower_2':'Foo'}, ['Foo']),
+ ("INSERT INTO t1 (c1, c2)", {'c2': 'foo', 'c1': 6}, [6, 'foo']),
+ ("select * from t1", {}, None),
+ ("DROP TABLE t1", {}, None)
+ ]
+ else:
+ cursor = [
+ ("CREATE TABLE t1", {}, None),
+ ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, [5, 'some data']),
+ ("INSERT INTO t1 (c1, c2)", {'c1': 6, "lower_2":"Foo"}, [6, "Foo"]), # bind param name 'lower_2' might be incorrect
+ ("select * from t1", {}, None),
+ ("DROP TABLE t1", {}, None)
+ ]
+
+ assert_stmts(compiled, stmts)
+ assert_stmts(cursor, cursor_stmts)
+
+
if __name__ == "__main__":
testenv.main()
diff --git a/test/engine/metadata.py b/test/engine/metadata.py
index 22cdaafee..90f8a00a8 100644
--- a/test/engine/metadata.py
+++ b/test/engine/metadata.py
@@ -1,8 +1,11 @@
import testenv; testenv.configure_for_tests()
-from sqlalchemy import *
-from sqlalchemy import exceptions
-from testlib import *
import pickle
+from sqlalchemy import MetaData
+from testlib.sa import Table, Column, Integer, String, UniqueConstraint, \
+ CheckConstraint, ForeignKey
+import testlib.sa as tsa
+from testlib import TestBase, ComparesTables, testing
+
class MetaDataTest(TestBase, ComparesTables):
def test_metadata_connect(self):
@@ -30,7 +33,7 @@ class MetaDataTest(TestBase, ComparesTables):
t2 = Table('table1', metadata, Column('col1', Integer, primary_key=True),
Column('col2', String(20)))
assert False
- except exceptions.InvalidRequestError, e:
+ except tsa.exc.InvalidRequestError, e:
assert str(e) == "Table 'table1' is already defined for this MetaData instance. Specify 'useexisting=True' to redefine options and columns on an existing Table object."
finally:
metadata.drop_all()
@@ -109,7 +112,7 @@ class MetaDataTest(TestBase, ComparesTables):
meta.drop_all(testing.db)
def test_nonexistent(self):
- self.assertRaises(exceptions.NoSuchTableError, Table,
+ self.assertRaises(tsa.exc.NoSuchTableError, Table,
'fake_table',
MetaData(testing.db), autoload=True)
diff --git a/test/engine/parseconnect.py b/test/engine/parseconnect.py
index 117c3ed4b..1f7d09c9d 100644
--- a/test/engine/parseconnect.py
+++ b/test/engine/parseconnect.py
@@ -1,9 +1,9 @@
import testenv; testenv.configure_for_tests()
import ConfigParser, StringIO
-from sqlalchemy import *
-from sqlalchemy import exceptions, pool, engine
import sqlalchemy.engine.url as url
-from testlib import *
+from sqlalchemy import create_engine, engine_from_config
+import testlib.sa as tsa
+from testlib import TestBase
class ParseConnectTest(TestBase):
@@ -92,10 +92,10 @@ pool_timeout=10
}
prefixed = dict(ini.items('prefixed'))
- self.assert_(engine._coerce_config(prefixed, 'sqlalchemy.') == expected)
+ self.assert_(tsa.engine._coerce_config(prefixed, 'sqlalchemy.') == expected)
plain = dict(ini.items('plain'))
- self.assert_(engine._coerce_config(plain, '') == expected)
+ self.assert_(tsa.engine._coerce_config(plain, '') == expected)
def test_engine_from_config(self):
dbapi = MockDBAPI()
@@ -181,7 +181,7 @@ pool_timeout=10
try:
c = e.connect()
assert False
- except exceptions.DBAPIError:
+ except tsa.exc.DBAPIError:
assert True
def test_urlattr(self):
@@ -200,11 +200,11 @@ pool_timeout=10
assert e.pool._recycle == 50
# these args work for QueuePool
- e = create_engine('postgres://', max_overflow=8, pool_timeout=60, poolclass=pool.QueuePool, module=MockDBAPI())
+ e = create_engine('postgres://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.QueuePool, module=MockDBAPI())
try:
# but not SingletonThreadPool
- e = create_engine('sqlite://', max_overflow=8, pool_timeout=60, poolclass=pool.SingletonThreadPool)
+ e = create_engine('sqlite://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.SingletonThreadPool)
assert False
except TypeError:
assert True
diff --git a/test/engine/pool.py b/test/engine/pool.py
index 75cb08e3c..f2b74a45a 100644
--- a/test/engine/pool.py
+++ b/test/engine/pool.py
@@ -1,9 +1,8 @@
import testenv; testenv.configure_for_tests()
-import threading, thread, time, gc
-import sqlalchemy.pool as pool
-import sqlalchemy.interfaces as interfaces
-import sqlalchemy.exceptions as exceptions
-from testlib import *
+import threading, time, gc
+from sqlalchemy import pool
+import testlib.sa as tsa
+from testlib import TestBase
mcid = 1
@@ -127,7 +126,7 @@ class PoolTest(TestBase):
try:
c4 = p.connect()
assert False
- except exceptions.TimeoutError, e:
+ except tsa.exc.TimeoutError, e:
assert int(time.time() - now) == 2
def test_timeout_race(self):
@@ -145,7 +144,7 @@ class PoolTest(TestBase):
now = time.time()
try:
c1 = p.connect()
- except exceptions.TimeoutError, e:
+ except tsa.exc.TimeoutError, e:
timeouts.append(int(time.time()) - now)
continue
time.sleep(4)
@@ -181,7 +180,7 @@ class PoolTest(TestBase):
peaks.append(p.overflow())
con.close()
del con
- except exceptions.TimeoutError:
+ except tsa.exc.TimeoutError:
pass
threads = []
for i in xrange(thread_count):
@@ -444,7 +443,7 @@ class PoolTest(TestBase):
# con can be None if invalidated
assert record is not None
self.checked_in.append(con)
- class ListenAll(interfaces.PoolListener, InstrumentingListener):
+ class ListenAll(tsa.interfaces.PoolListener, InstrumentingListener):
pass
class ListenConnect(InstrumentingListener):
def connect(self, con, record):
diff --git a/test/engine/reconnect.py b/test/engine/reconnect.py
index d0d037a34..1539d80e0 100644
--- a/test/engine/reconnect.py
+++ b/test/engine/reconnect.py
@@ -1,7 +1,8 @@
import testenv; testenv.configure_for_tests()
-import sys, weakref
-from sqlalchemy import create_engine, exceptions, select, MetaData, Table, Column, Integer, String
-from testlib import *
+import weakref
+from testlib.sa import select, MetaData, Table, Column, Integer, String
+import testlib.sa as tsa
+from testlib import TestBase, testing, engines
class MockDisconnect(Exception):
@@ -43,13 +44,14 @@ class MockCursor(object):
def close(self):
pass
+db, dbapi = None, None
class MockReconnectTest(TestBase):
def setUp(self):
global db, dbapi
dbapi = MockDBAPI()
# create engine using our current dburi
- db = create_engine('postgres://foo:bar@localhost/test', module=dbapi)
+ db = tsa.create_engine('postgres://foo:bar@localhost/test', module=dbapi)
# monkeypatch disconnect checker
db.dialect.is_disconnect = lambda e: isinstance(e, MockDisconnect)
@@ -80,7 +82,7 @@ class MockReconnectTest(TestBase):
try:
conn.execute(select([1]))
assert False
- except exceptions.DBAPIError:
+ except tsa.exc.DBAPIError:
pass
# assert was invalidated
@@ -108,7 +110,7 @@ class MockReconnectTest(TestBase):
try:
conn.execute(select([1]))
assert False
- except exceptions.DBAPIError:
+ except tsa.exc.DBAPIError:
pass
# assert was invalidated
@@ -120,7 +122,7 @@ class MockReconnectTest(TestBase):
try:
conn.execute(select([1]))
assert False
- except exceptions.InvalidRequestError, e:
+ except tsa.exc.InvalidRequestError, e:
assert str(e) == "Can't reconnect until invalid transaction is rolled back"
assert trans.is_active
@@ -128,7 +130,7 @@ class MockReconnectTest(TestBase):
try:
trans.commit()
assert False
- except exceptions.InvalidRequestError, e:
+ except tsa.exc.InvalidRequestError, e:
assert str(e) == "Can't reconnect until invalid transaction is rolled back"
assert trans.is_active
@@ -154,7 +156,7 @@ class MockReconnectTest(TestBase):
try:
conn.execute(select([1]))
assert False
- except exceptions.DBAPIError:
+ except tsa.exc.DBAPIError:
pass
assert not conn.closed
@@ -168,7 +170,7 @@ class MockReconnectTest(TestBase):
assert not conn.invalidated
assert len(dbapi.connections) == 1
-
+engine = None
class RealReconnectTest(TestBase):
def setUp(self):
global engine
@@ -188,7 +190,7 @@ class RealReconnectTest(TestBase):
try:
conn.execute(select([1]))
assert False
- except exceptions.DBAPIError, e:
+ except tsa.exc.DBAPIError, e:
if not e.connection_invalidated:
raise
@@ -204,7 +206,7 @@ class RealReconnectTest(TestBase):
try:
conn.execute(select([1]))
assert False
- except exceptions.DBAPIError, e:
+ except tsa.exc.DBAPIError, e:
if not e.connection_invalidated:
raise
assert conn.invalidated
@@ -212,7 +214,7 @@ class RealReconnectTest(TestBase):
assert not conn.invalidated
conn.close()
-
+
def test_close(self):
conn = engine.connect()
self.assertEquals(conn.execute(select([1])).scalar(), 1)
@@ -223,7 +225,7 @@ class RealReconnectTest(TestBase):
try:
conn.execute(select([1]))
assert False
- except exceptions.DBAPIError, e:
+ except tsa.exc.DBAPIError, e:
if not e.connection_invalidated:
raise
@@ -244,7 +246,7 @@ class RealReconnectTest(TestBase):
try:
conn.execute(select([1]))
assert False
- except exceptions.DBAPIError, e:
+ except tsa.exc.DBAPIError, e:
if not e.connection_invalidated:
raise
@@ -255,7 +257,7 @@ class RealReconnectTest(TestBase):
try:
conn.execute(select([1]))
assert False
- except exceptions.InvalidRequestError, e:
+ except tsa.exc.InvalidRequestError, e:
assert str(e) == "Can't reconnect until invalid transaction is rolled back"
assert trans.is_active
@@ -263,7 +265,7 @@ class RealReconnectTest(TestBase):
try:
trans.commit()
assert False
- except exceptions.InvalidRequestError, e:
+ except tsa.exc.InvalidRequestError, e:
assert str(e) == "Can't reconnect until invalid transaction is rolled back"
assert trans.is_active
@@ -275,6 +277,7 @@ class RealReconnectTest(TestBase):
self.assertEquals(conn.execute(select([1])).scalar(), 1)
assert not conn.invalidated
+meta, table, engine = None, None, None
class InvalidateDuringResultTest(TestBase):
def setUp(self):
global meta, table, engine
@@ -287,28 +290,28 @@ class InvalidateDuringResultTest(TestBase):
table.insert().execute(
[{'id':i, 'name':'row %d' % i} for i in range(1, 100)]
)
-
+
def tearDown(self):
meta.drop_all()
engine.dispose()
-
- @testing.fails_on('mysql')
+
+ @testing.fails_on('mysql')
def test_invalidate_on_results(self):
conn = engine.connect()
-
+
result = conn.execute("select * from sometable")
for x in xrange(20):
result.fetchone()
-
+
engine.test_shutdown()
try:
result.fetchone()
assert False
- except exceptions.DBAPIError, e:
+ except tsa.exc.DBAPIError, e:
if not e.connection_invalidated:
raise
assert conn.invalidated
-
+
if __name__ == '__main__':
testenv.main()
diff --git a/test/engine/reflection.py b/test/engine/reflection.py
index 2ace3306a..64c8489ed 100644
--- a/test/engine/reflection.py
+++ b/test/engine/reflection.py
@@ -1,12 +1,13 @@
import testenv; testenv.configure_for_tests()
import StringIO, unicodedata
-from sqlalchemy import *
-from sqlalchemy import exceptions
-from sqlalchemy import types as sqltypes
-from testlib import *
-from testlib import engines
+import sqlalchemy as sa
+from testlib.sa import MetaData, Table, Column
+from testlib import TestBase, ComparesTables, testing, engines, sa as tsa
+from testlib.compat import set
+metadata, users = None, None
+
class ReflectionTest(TestBase, ComparesTables):
@testing.exclude('mysql', '<', (4, 1, 1))
@@ -14,35 +15,38 @@ class ReflectionTest(TestBase, ComparesTables):
meta = MetaData(testing.db)
users = Table('engine_users', meta,
- Column('user_id', INT, primary_key=True),
- Column('user_name', VARCHAR(20), nullable=False),
- Column('test1', CHAR(5), nullable=False),
- Column('test2', Float(5), nullable=False),
- Column('test3', Text),
- Column('test4', Numeric, nullable = False),
- Column('test5', DateTime),
- Column('parent_user_id', Integer, ForeignKey('engine_users.user_id')),
- Column('test6', DateTime, nullable=False),
- Column('test7', Text),
- Column('test8', Binary),
- Column('test_passivedefault2', Integer, PassiveDefault("5")),
- Column('test9', Binary(100)),
- Column('test_numeric', Numeric()),
+ Column('user_id', sa.INT, primary_key=True),
+ Column('user_name', sa.VARCHAR(20), nullable=False),
+ Column('test1', sa.CHAR(5), nullable=False),
+ Column('test2', sa.Float(5), nullable=False),
+ Column('test3', sa.Text),
+ Column('test4', sa.Numeric, nullable = False),
+ Column('test5', sa.DateTime),
+ Column('parent_user_id', sa.Integer,
+ sa.ForeignKey('engine_users.user_id')),
+ Column('test6', sa.DateTime, nullable=False),
+ Column('test7', sa.Text),
+ Column('test8', sa.Binary),
+ Column('test_passivedefault2', sa.Integer, sa.PassiveDefault("5")),
+ Column('test9', sa.Binary(100)),
+ Column('test_numeric', sa.Numeric()),
test_needs_fk=True,
)
addresses = Table('engine_email_addresses', meta,
- Column('address_id', Integer, primary_key = True),
- Column('remote_user_id', Integer, ForeignKey(users.c.user_id)),
- Column('email_address', String(20)),
+ Column('address_id', sa.Integer, primary_key = True),
+ Column('remote_user_id', sa.Integer, sa.ForeignKey(users.c.user_id)),
+ Column('email_address', sa.String(20)),
test_needs_fk=True,
)
meta.create_all()
try:
meta2 = MetaData()
- reflected_users = Table('engine_users', meta2, autoload=True, autoload_with=testing.db)
- reflected_addresses = Table('engine_email_addresses', meta2, autoload=True, autoload_with=testing.db)
+ reflected_users = Table('engine_users', meta2, autoload=True,
+ autoload_with=testing.db)
+ reflected_addresses = Table('engine_email_addresses', meta2,
+ autoload=True, autoload_with=testing.db)
self.assert_tables_equal(users, reflected_users)
self.assert_tables_equal(addresses, reflected_addresses)
finally:
@@ -51,22 +55,25 @@ class ReflectionTest(TestBase, ComparesTables):
def test_include_columns(self):
meta = MetaData(testing.db)
- foo = Table('foo', meta, *[Column(n, String(30)) for n in ['a', 'b', 'c', 'd', 'e', 'f']])
+ foo = Table('foo', meta, *[Column(n, sa.String(30))
+ for n in ['a', 'b', 'c', 'd', 'e', 'f']])
meta.create_all()
try:
meta2 = MetaData(testing.db)
- foo = Table('foo', meta2, autoload=True, include_columns=['b', 'f', 'e'])
+ foo = Table('foo', meta2, autoload=True,
+ include_columns=['b', 'f', 'e'])
# test that cols come back in original order
self.assertEquals([c.name for c in foo.c], ['b', 'e', 'f'])
for c in ('b', 'f', 'e'):
assert c in foo.c
for c in ('a', 'c', 'd'):
assert c not in foo.c
-
+
# test against a table which is already reflected
meta3 = MetaData(testing.db)
foo = Table('foo', meta3, autoload=True)
- foo = Table('foo', meta3, include_columns=['b', 'f', 'e'], useexisting=True)
+ foo = Table('foo', meta3, include_columns=['b', 'f', 'e'],
+ useexisting=True)
self.assertEquals([c.name for c in foo.c], ['b', 'e', 'f'])
for c in ('b', 'f', 'e'):
assert c in foo.c
@@ -79,7 +86,7 @@ class ReflectionTest(TestBase, ComparesTables):
def test_unknown_types(self):
meta = MetaData(testing.db)
t = Table("test", meta,
- Column('foo', DateTime))
+ Column('foo', sa.DateTime))
import sys
dialect_module = sys.modules[testing.db.dialect.__module__]
@@ -100,14 +107,14 @@ class ReflectionTest(TestBase, ComparesTables):
m2 = MetaData(testing.db)
t2 = Table("test", m2, autoload=True)
assert False
- except exceptions.SAWarning:
+ except tsa.exc.SAWarning:
assert True
@testing.emits_warning('Did not recognize type')
def warns():
m3 = MetaData(testing.db)
t3 = Table("test", m3, autoload=True)
- assert t3.c.foo.type.__class__ == sqltypes.NullType
+ assert t3.c.foo.type.__class__ == sa.types.NullType
finally:
dialect_module.ischema_names = ischema_names
@@ -117,9 +124,9 @@ class ReflectionTest(TestBase, ComparesTables):
meta = MetaData(testing.db)
table = Table(
'override_test', meta,
- Column('col1', Integer, primary_key=True),
- Column('col2', String(20)),
- Column('col3', Numeric)
+ Column('col1', sa.Integer, primary_key=True),
+ Column('col2', sa.String(20)),
+ Column('col3', sa.Numeric)
)
table.create()
@@ -127,12 +134,12 @@ class ReflectionTest(TestBase, ComparesTables):
try:
table = Table(
'override_test', meta2,
- Column('col2', Unicode()),
- Column('col4', String(30)), autoload=True)
+ Column('col2', sa.Unicode()),
+ Column('col4', sa.String(30)), autoload=True)
- self.assert_(isinstance(table.c.col1.type, Integer))
- self.assert_(isinstance(table.c.col2.type, Unicode))
- self.assert_(isinstance(table.c.col4.type, String))
+ self.assert_(isinstance(table.c.col1.type, sa.Integer))
+ self.assert_(isinstance(table.c.col2.type, sa.Unicode))
+ self.assert_(isinstance(table.c.col4.type, sa.String))
finally:
table.drop()
@@ -142,18 +149,19 @@ class ReflectionTest(TestBase, ComparesTables):
meta = MetaData(testing.db)
users = Table('users', meta,
- Column('id', Integer, primary_key=True),
- Column('name', String(30)))
+ Column('id', sa.Integer, primary_key=True),
+ Column('name', sa.String(30)))
addresses = Table('addresses', meta,
- Column('id', Integer, primary_key=True),
- Column('street', String(30)))
+ Column('id', sa.Integer, primary_key=True),
+ Column('street', sa.String(30)))
meta.create_all()
try:
meta2 = MetaData(testing.db)
a2 = Table('addresses', meta2,
- Column('id', Integer, ForeignKey('users.id'), primary_key=True),
+ Column('id', sa.Integer,
+ sa.ForeignKey('users.id'), primary_key=True),
autoload=True)
u2 = Table('users', meta2, autoload=True)
@@ -164,7 +172,8 @@ class ReflectionTest(TestBase, ComparesTables):
meta3 = MetaData(testing.db)
u3 = Table('users', meta3, autoload=True)
a3 = Table('addresses', meta3,
- Column('id', Integer, ForeignKey('users.id'), primary_key=True),
+ Column('id', sa.Integer, sa.ForeignKey('users.id'),
+ primary_key=True),
autoload=True)
assert list(a3.primary_key) == [a3.c.id]
@@ -180,18 +189,18 @@ class ReflectionTest(TestBase, ComparesTables):
meta = MetaData(testing.db)
users = Table('users', meta,
- Column('id', Integer, primary_key=True),
- Column('name', String(30)))
+ Column('id', sa.Integer, primary_key=True),
+ Column('name', sa.String(30)))
addresses = Table('addresses', meta,
- Column('id', Integer, primary_key=True),
- Column('street', String(30)),
- Column('user_id', Integer))
+ Column('id', sa.Integer, primary_key=True),
+ Column('street', sa.String(30)),
+ Column('user_id', sa.Integer))
meta.create_all()
try:
meta2 = MetaData(testing.db)
a2 = Table('addresses', meta2,
- Column('user_id', Integer, ForeignKey('users.id')),
+ Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
autoload=True)
u2 = Table('users', meta2, autoload=True)
@@ -205,19 +214,19 @@ class ReflectionTest(TestBase, ComparesTables):
meta3 = MetaData(testing.db)
u3 = Table('users', meta3, autoload=True)
a3 = Table('addresses', meta3,
- Column('user_id', Integer, ForeignKey('users.id')),
+ Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
autoload=True)
assert u3.join(a3).onclause == u3.c.id==a3.c.user_id
meta4 = MetaData(testing.db)
u4 = Table('users', meta4,
- Column('id', Integer, key='u_id', primary_key=True),
+ Column('id', sa.Integer, key='u_id', primary_key=True),
autoload=True)
a4 = Table('addresses', meta4,
- Column('id', Integer, key='street', primary_key=True),
- Column('street', String(30), key='user_id'),
- Column('user_id', Integer, ForeignKey('users.u_id'),
+ Column('id', sa.Integer, key='street', primary_key=True),
+ Column('street', sa.String(30), key='user_id'),
+ Column('user_id', sa.Integer, sa.ForeignKey('users.u_id'),
key='id'),
autoload=True)
@@ -237,19 +246,19 @@ class ReflectionTest(TestBase, ComparesTables):
meta = MetaData(testing.db)
users = Table('users', meta,
- Column('id', Integer, primary_key=True),
- Column('name', String(30)),
+ Column('id', sa.Integer, primary_key=True),
+ Column('name', sa.String(30)),
test_needs_fk=True)
addresses = Table('addresses', meta,
- Column('id', Integer,primary_key=True),
- Column('user_id', Integer, ForeignKey('users.id')),
+ Column('id', sa.Integer, primary_key=True),
+ Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
test_needs_fk=True)
meta.create_all()
try:
meta2 = MetaData(testing.db)
a2 = Table('addresses', meta2,
- Column('user_id',Integer, ForeignKey('users.id')),
+ Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
autoload=True)
u2 = Table('users', meta2, autoload=True)
@@ -263,11 +272,11 @@ class ReflectionTest(TestBase, ComparesTables):
meta2 = MetaData(testing.db)
u2 = Table('users', meta2,
- Column('id', Integer, primary_key=True),
+ Column('id', sa.Integer, primary_key=True),
autoload=True)
a2 = Table('addresses', meta2,
- Column('id', Integer, primary_key=True),
- Column('user_id',Integer, ForeignKey('users.id')),
+ Column('id', sa.Integer, primary_key=True),
+ Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
autoload=True)
assert len(a2.foreign_keys) == 1
@@ -279,31 +288,31 @@ class ReflectionTest(TestBase, ComparesTables):
assert u2.join(a2).onclause == u2.c.id==a2.c.user_id
finally:
meta.drop_all()
-
+
def test_use_existing(self):
meta = MetaData(testing.db)
users = Table('users', meta,
- Column('id', Integer, primary_key=True),
- Column('name', String(30)),
+ Column('id', sa.Integer, primary_key=True),
+ Column('name', sa.String(30)),
test_needs_fk=True)
addresses = Table('addresses', meta,
- Column('id', Integer,primary_key=True),
- Column('user_id', Integer, ForeignKey('users.id')),
- Column('data', String(100)),
+ Column('id', sa.Integer,primary_key=True),
+ Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
+ Column('data', sa.String(100)),
test_needs_fk=True)
meta.create_all()
try:
meta2 = MetaData(testing.db)
- addresses = Table('addresses', meta2, Column('data', Unicode), autoload=True)
+ addresses = Table('addresses', meta2, Column('data', sa.Unicode), autoload=True)
try:
- users = Table('users', meta2, Column('name', Unicode), autoload=True)
+ users = Table('users', meta2, Column('name', sa.Unicode), autoload=True)
assert False
- except exceptions.InvalidRequestError, err:
+ except tsa.exc.InvalidRequestError, err:
assert str(err) == "Table 'users' is already defined for this MetaData instance. Specify 'useexisting=True' to redefine options and columns on an existing Table object."
- users = Table('users', meta2, Column('name', Unicode), autoload=True, useexisting=True)
- assert isinstance(users.c.name.type, Unicode)
+ users = Table('users', meta2, Column('name', sa.Unicode), autoload=True, useexisting=True)
+ assert isinstance(users.c.name.type, sa.Unicode)
assert not users.quote
@@ -328,8 +337,8 @@ class ReflectionTest(TestBase, ComparesTables):
try:
metadata = MetaData(bind=testing.db)
book = Table('book', metadata, autoload=True)
- assert book.c.id in book.primary_key
- assert book.c.series not in book.primary_key
+ assert book.primary_key.contains_column(book.c.id)
+ assert not book.primary_key.contains_column(book.c.series)
assert len(book.primary_key) == 1
finally:
testing.db.execute("drop table book")
@@ -337,14 +346,14 @@ class ReflectionTest(TestBase, ComparesTables):
def test_fk_error(self):
metadata = MetaData(testing.db)
slots_table = Table('slots', metadata,
- Column('slot_id', Integer, primary_key=True),
- Column('pkg_id', Integer, ForeignKey('pkgs.pkg_id')),
- Column('slot', String(128)),
+ Column('slot_id', sa.Integer, primary_key=True),
+ Column('pkg_id', sa.Integer, sa.ForeignKey('pkgs.pkg_id')),
+ Column('slot', sa.String(128)),
)
try:
metadata.create_all()
assert False
- except exceptions.InvalidRequestError, err:
+ except tsa.exc.InvalidRequestError, err:
assert str(err) == "Could not find table 'pkgs' with which to generate a foreign key"
def test_composite_pks(self):
@@ -363,9 +372,9 @@ class ReflectionTest(TestBase, ComparesTables):
try:
metadata = MetaData(bind=testing.db)
book = Table('book', metadata, autoload=True)
- assert book.c.id in book.primary_key
- assert book.c.isbn in book.primary_key
- assert book.c.series not in book.primary_key
+ assert book.primary_key.contains_column(book.c.id)
+ assert book.primary_key.contains_column(book.c.isbn)
+ assert not book.primary_key.contains_column(book.c.series)
assert len(book.primary_key) == 2
finally:
testing.db.execute("drop table book")
@@ -377,20 +386,20 @@ class ReflectionTest(TestBase, ComparesTables):
meta = MetaData(testing.db)
multi = Table(
'multi', meta,
- Column('multi_id', Integer, primary_key=True),
- Column('multi_rev', Integer, primary_key=True),
- Column('multi_hoho', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('val', String(100)),
+ Column('multi_id', sa.Integer, primary_key=True),
+ Column('multi_rev', sa.Integer, primary_key=True),
+ Column('multi_hoho', sa.Integer, primary_key=True),
+ Column('name', sa.String(50), nullable=False),
+ Column('val', sa.String(100)),
test_needs_fk=True,
)
multi2 = Table('multi2', meta,
- Column('id', Integer, primary_key=True),
- Column('foo', Integer),
- Column('bar', Integer),
- Column('lala', Integer),
- Column('data', String(50)),
- ForeignKeyConstraint(['foo', 'bar', 'lala'], ['multi.multi_id', 'multi.multi_rev', 'multi.multi_hoho']),
+ Column('id', sa.Integer, primary_key=True),
+ Column('foo', sa.Integer),
+ Column('bar', sa.Integer),
+ Column('lala', sa.Integer),
+ Column('data', sa.String(50)),
+ sa.ForeignKeyConstraint(['foo', 'bar', 'lala'], ['multi.multi_id', 'multi.multi_rev', 'multi.multi_hoho']),
test_needs_fk=True,
)
meta.create_all()
@@ -401,8 +410,8 @@ class ReflectionTest(TestBase, ComparesTables):
table2 = Table('multi2', meta2, autoload=True, autoload_with=testing.db)
self.assert_tables_equal(multi, table)
self.assert_tables_equal(multi2, table2)
- j = join(table, table2)
- self.assert_(and_(table.c.multi_id==table2.c.foo, table.c.multi_rev==table2.c.bar, table.c.multi_hoho==table2.c.lala).compare(j.onclause))
+ j = sa.join(table, table2)
+ self.assert_(sa.and_(table.c.multi_id==table2.c.foo, table.c.multi_rev==table2.c.bar, table.c.multi_hoho==table2.c.lala).compare(j.onclause))
finally:
meta.drop_all()
@@ -412,10 +421,10 @@ class ReflectionTest(TestBase, ComparesTables):
# check a table that uses an SQL reserved name doesn't cause an error
meta = MetaData(testing.db)
table_a = Table('select', meta,
- Column('not', Integer, primary_key=True),
- Column('from', String(12), nullable=False),
- UniqueConstraint('from', name='when'))
- Index('where', table_a.c['from'])
+ Column('not', sa.Integer, primary_key=True),
+ Column('from', sa.String(12), nullable=False),
+ sa.UniqueConstraint('from', name='when'))
+ sa.Index('where', table_a.c['from'])
# There's currently no way to calculate identifier case normalization
# in isolation, so...
@@ -426,17 +435,17 @@ class ReflectionTest(TestBase, ComparesTables):
quoter = meta.bind.dialect.identifier_preparer.quote_identifier
table_b = Table('false', meta,
- Column('create', Integer, primary_key=True),
- Column('true', Integer, ForeignKey('select.not')),
- CheckConstraint('%s <> 1' % quoter(check_col),
+ Column('create', sa.Integer, primary_key=True),
+ Column('true', sa.Integer, sa.ForeignKey('select.not')),
+ sa.CheckConstraint('%s <> 1' % quoter(check_col),
name='limit'))
table_c = Table('is', meta,
- Column('or', Integer, nullable=False, primary_key=True),
- Column('join', Integer, nullable=False, primary_key=True),
- PrimaryKeyConstraint('or', 'join', name='to'))
+ Column('or', sa.Integer, nullable=False, primary_key=True),
+ Column('join', sa.Integer, nullable=False, primary_key=True),
+ sa.PrimaryKeyConstraint('or', 'join', name='to'))
- index_c = Index('else', table_c.c.join)
+ index_c = sa.Index('else', table_c.c.join)
meta.create_all()
@@ -462,7 +471,7 @@ class ReflectionTest(TestBase, ComparesTables):
baseline = MetaData(testing.db)
for name in names:
- Table(name, baseline, Column('id', Integer, primary_key=True))
+ Table(name, baseline, Column('id', sa.Integer, primary_key=True))
baseline.create_all()
try:
@@ -484,7 +493,7 @@ class ReflectionTest(TestBase, ComparesTables):
try:
m4.reflect(only=['rt_a', 'rt_f'])
self.assert_(False)
- except exceptions.InvalidRequestError, e:
+ except tsa.exc.InvalidRequestError, e:
self.assert_(e.args[0].endswith('(rt_f)'))
m5 = MetaData(testing.db)
@@ -501,7 +510,7 @@ class ReflectionTest(TestBase, ComparesTables):
try:
m8 = MetaData(reflect=True)
self.assert_(False)
- except exceptions.ArgumentError, e:
+ except tsa.exc.ArgumentError, e:
self.assert_(
e.args[0] ==
"A bind must be supplied in conjunction with reflect=True")
@@ -521,27 +530,27 @@ class CreateDropTest(TestBase):
global metadata, users
metadata = MetaData()
users = Table('users', metadata,
- Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key=True),
- Column('user_name', String(40)),
+ Column('user_id', sa.Integer, sa.Sequence('user_id_seq', optional=True), primary_key=True),
+ Column('user_name', sa.String(40)),
)
addresses = Table('email_addresses', metadata,
- Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True),
- Column('user_id', Integer, ForeignKey(users.c.user_id)),
- Column('email_address', String(40)),
+ Column('address_id', sa.Integer, sa.Sequence('address_id_seq', optional=True), primary_key = True),
+ Column('user_id', sa.Integer, sa.ForeignKey(users.c.user_id)),
+ Column('email_address', sa.String(40)),
)
orders = Table('orders', metadata,
- Column('order_id', Integer, Sequence('order_id_seq', optional=True), primary_key = True),
- Column('user_id', Integer, ForeignKey(users.c.user_id)),
- Column('description', String(50)),
- Column('isopen', Integer),
+ Column('order_id', sa.Integer, sa.Sequence('order_id_seq', optional=True), primary_key = True),
+ Column('user_id', sa.Integer, sa.ForeignKey(users.c.user_id)),
+ Column('description', sa.String(50)),
+ Column('isopen', sa.Integer),
)
orderitems = Table('items', metadata,
- Column('item_id', INT, Sequence('items_id_seq', optional=True), primary_key = True),
- Column('order_id', INT, ForeignKey("orders")),
- Column('item_name', VARCHAR(50)),
+ Column('item_id', sa.INT, sa.Sequence('items_id_seq', optional=True), primary_key = True),
+ Column('order_id', sa.INT, sa.ForeignKey("orders")),
+ Column('item_name', sa.VARCHAR(50)),
)
def test_sorter( self ):
@@ -590,10 +599,10 @@ class SchemaManipulationTest(TestBase):
def test_append_constraint_unique(self):
meta = MetaData()
- users = Table('users', meta, Column('id', Integer))
- addresses = Table('addresses', meta, Column('id', Integer), Column('user_id', Integer))
+ users = Table('users', meta, Column('id', sa.Integer))
+ addresses = Table('addresses', meta, Column('id', sa.Integer), Column('user_id', sa.Integer))
- fk = ForeignKeyConstraint(['user_id'],[users.c.id])
+ fk = sa.ForeignKeyConstraint(['user_id'],[users.c.id])
addresses.append_constraint(fk)
addresses.append_constraint(fk)
@@ -616,7 +625,7 @@ class UnicodeReflectionTest(TestBase):
names = set([u'plain', u'Unit\u00e9ble', u'\u6e2c\u8a66'])
for name in names:
- Table(name, metadata, Column('id', Integer, Sequence(name + "_id_seq"), primary_key=True))
+ Table(name, metadata, Column('id', sa.Integer, sa.Sequence(name + "_id_seq"), primary_key=True))
metadata.create_all()
reflected = set(bind.table_names())
@@ -642,18 +651,18 @@ class SchemaTest(TestBase):
def test_iteration(self):
metadata = MetaData()
table1 = Table('table1', metadata,
- Column('col1', Integer, primary_key=True),
+ Column('col1', sa.Integer, primary_key=True),
schema='someschema')
table2 = Table('table2', metadata,
- Column('col1', Integer, primary_key=True),
- Column('col2', Integer, ForeignKey('someschema.table1.col1')),
+ Column('col1', sa.Integer, primary_key=True),
+ Column('col2', sa.Integer, sa.ForeignKey('someschema.table1.col1')),
schema='someschema')
# ensure this doesnt crash
print [t for t in metadata.table_iterator()]
buf = StringIO.StringIO()
def foo(s, p=None):
buf.write(s)
- gen = create_engine(testing.db.name + "://", strategy="mock", executor=foo)
+ gen = sa.create_engine(testing.db.name + "://", strategy="mock", executor=foo)
gen = gen.dialect.schemagenerator(gen.dialect, gen)
gen.traverse(table1)
gen.traverse(table2)
@@ -681,12 +690,12 @@ class SchemaTest(TestBase):
metadata = MetaData(engine)
table1 = Table('table1', metadata,
- Column('col1', Integer, primary_key=True),
+ Column('col1', sa.Integer, primary_key=True),
schema=schema)
table2 = Table('table2', metadata,
- Column('col1', Integer, primary_key=True),
- Column('col2', Integer,
- ForeignKey('%s.table1.col1' % schema)),
+ Column('col1', sa.Integer, primary_key=True),
+ Column('col2', sa.Integer,
+ sa.ForeignKey('%s.table1.col1' % schema)),
schema=schema)
try:
metadata.create_all()
@@ -704,8 +713,8 @@ class HasSequenceTest(TestBase):
global metadata, users
metadata = MetaData()
users = Table('users', metadata,
- Column('user_id', Integer, Sequence('user_id_seq'), primary_key=True),
- Column('user_name', String(40)),
+ Column('user_id', sa.Integer, sa.Sequence('user_id_seq'), primary_key=True),
+ Column('user_name', sa.String(40)),
)
@testing.unsupported('sqlite', 'mysql', 'mssql', 'access', 'sybase')
diff --git a/test/engine/transaction.py b/test/engine/transaction.py
index edae14da2..1cb6ba7a1 100644
--- a/test/engine/transaction.py
+++ b/test/engine/transaction.py
@@ -1,11 +1,11 @@
import testenv; testenv.configure_for_tests()
import sys, time, threading
-
-from sqlalchemy import *
-from sqlalchemy.orm import *
-from testlib import *
+from testlib.sa import create_engine, MetaData, Table, Column, INT, VARCHAR, \
+ Sequence, select, Integer, String, func, text
+from testlib import TestBase, testing
+users, metadata = None, None
class TransactionTest(TestBase):
def setUpAll(self):
global users, metadata
@@ -22,7 +22,7 @@ class TransactionTest(TestBase):
def tearDownAll(self):
users.drop(testing.db)
- def testcommits(self):
+ def test_commits(self):
connection = testing.db.connect()
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name='user1')
@@ -38,7 +38,7 @@ class TransactionTest(TestBase):
assert len(result.fetchall()) == 3
transaction.commit()
- def testrollback(self):
+ def test_rollback(self):
"""test a basic rollback"""
connection = testing.db.connect()
transaction = connection.begin()
@@ -51,7 +51,7 @@ class TransactionTest(TestBase):
assert len(result.fetchall()) == 0
connection.close()
- def testraise(self):
+ def test_raise(self):
connection = testing.db.connect()
transaction = connection.begin()
@@ -70,7 +70,7 @@ class TransactionTest(TestBase):
connection.close()
@testing.exclude('mysql', '<', (5, 0, 3))
- def testnestedrollback(self):
+ def test_nested_rollback(self):
connection = testing.db.connect()
try:
@@ -100,7 +100,7 @@ class TransactionTest(TestBase):
@testing.exclude('mysql', '<', (5, 0, 3))
- def testnesting(self):
+ def test_nesting(self):
connection = testing.db.connect()
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name='user1')
@@ -118,7 +118,7 @@ class TransactionTest(TestBase):
connection.close()
@testing.exclude('mysql', '<', (5, 0, 3))
- def testclose(self):
+ def test_close(self):
connection = testing.db.connect()
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name='user1')
@@ -139,7 +139,7 @@ class TransactionTest(TestBase):
connection.close()
@testing.exclude('mysql', '<', (5, 0, 3))
- def testclose2(self):
+ def test_close2(self):
connection = testing.db.connect()
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name='user1')
@@ -159,10 +159,8 @@ class TransactionTest(TestBase):
assert len(result.fetchall()) == 0
connection.close()
-
- @testing.unsupported('sqlite', 'mssql', 'sybase', 'access')
- @testing.exclude('mysql', '<', (5, 0, 3))
- def testnestedsubtransactionrollback(self):
+ @testing.requires.savepoints
+ def test_nested_subtransaction_rollback(self):
connection = testing.db.connect()
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name='user1')
@@ -178,9 +176,8 @@ class TransactionTest(TestBase):
)
connection.close()
- @testing.unsupported('sqlite', 'mssql', 'sybase', 'access')
- @testing.exclude('mysql', '<', (5, 0, 3))
- def testnestedsubtransactioncommit(self):
+ @testing.requires.savepoints
+ def test_nested_subtransaction_commit(self):
connection = testing.db.connect()
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name='user1')
@@ -196,9 +193,8 @@ class TransactionTest(TestBase):
)
connection.close()
- @testing.unsupported('sqlite', 'mssql', 'sybase', 'access')
- @testing.exclude('mysql', '<', (5, 0, 3))
- def testrollbacktosubtransaction(self):
+ @testing.requires.savepoints
+ def test_rollback_to_subtransaction(self):
connection = testing.db.connect()
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name='user1')
@@ -216,10 +212,8 @@ class TransactionTest(TestBase):
)
connection.close()
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
- def testtwophasetransaction(self):
+ @testing.requires.two_phase_transactions
+ def test_two_phase_transaction(self):
connection = testing.db.connect()
transaction = connection.begin_twophase()
@@ -246,10 +240,9 @@ class TransactionTest(TestBase):
)
connection.close()
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
- def testmixedtwophasetransaction(self):
+ @testing.requires.two_phase_transactions
+ @testing.requires.savepoints
+ def test_mixed_two_phase_transaction(self):
connection = testing.db.connect()
transaction = connection.begin_twophase()
@@ -281,11 +274,9 @@ class TransactionTest(TestBase):
)
connection.close()
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- # fixme: see if this is still true and/or can be convert to fails_on()
- @testing.unsupported('mysql')
- def testtwophaserecover(self):
+ @testing.requires.two_phase_transactions
+ @testing.fails_on('mysql')
+ def test_two_phase_recover(self):
# MySQL recovery doesn't currently seem to work correctly
# Prepared transactions disappear when connections are closed and even
# when they aren't it doesn't seem possible to use the recovery id.
@@ -316,10 +307,8 @@ class TransactionTest(TestBase):
)
connection2.close()
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
- def testmultipletwophase(self):
+ @testing.requires.two_phase_transactions
+ def test_multiple_two_phase(self):
conn = testing.db.connect()
xa = conn.begin_twophase()
@@ -355,7 +344,7 @@ class AutoRollbackTest(TestBase):
metadata.drop_all(testing.db)
@testing.unsupported('sqlite')
- def testrollback_deadlock(self):
+ def test_rollback_deadlock(self):
"""test that returning connections to the pool clears any object locks."""
conn1 = testing.db.connect()
conn2 = testing.db.connect()
@@ -375,12 +364,13 @@ class AutoRollbackTest(TestBase):
users.drop(conn2)
conn2.close()
+foo = None
class ExplicitAutoCommitTest(TestBase):
- """test the 'autocommit' flag on select() and text() objects.
-
+ """test the 'autocommit' flag on select() and text() objects.
+
Requires Postgres so that we may define a custom function which modifies the database.
"""
-
+
__only_on__ = 'postgres'
def setUpAll(self):
@@ -392,13 +382,13 @@ class ExplicitAutoCommitTest(TestBase):
def tearDown(self):
foo.delete().execute()
-
+
def tearDownAll(self):
testing.db.execute("drop function insert_foo(varchar)")
metadata.drop_all()
-
+
def test_control(self):
- # test that not using autocommit does not commit
+ # test that not using autocommit does not commit
conn1 = testing.db.connect()
conn2 = testing.db.connect()
@@ -412,44 +402,45 @@ class ExplicitAutoCommitTest(TestBase):
trans.commit()
assert conn2.execute(select([foo.c.data])).fetchall() == [('data1',), ('moredata',)]
-
+
conn1.close()
conn2.close()
-
+
def test_explicit_compiled(self):
conn1 = testing.db.connect()
conn2 = testing.db.connect()
-
+
conn1.execute(select([func.insert_foo('data1')], autocommit=True))
assert conn2.execute(select([foo.c.data])).fetchall() == [('data1',)]
conn1.execute(select([func.insert_foo('data2')]).autocommit())
assert conn2.execute(select([foo.c.data])).fetchall() == [('data1',), ('data2',)]
-
+
conn1.close()
conn2.close()
-
+
def test_explicit_text(self):
conn1 = testing.db.connect()
conn2 = testing.db.connect()
-
+
conn1.execute(text("select insert_foo('moredata')", autocommit=True))
assert conn2.execute(select([foo.c.data])).fetchall() == [('moredata',)]
-
+
conn1.close()
conn2.close()
def test_implicit_text(self):
conn1 = testing.db.connect()
conn2 = testing.db.connect()
-
+
conn1.execute(text("insert into foo (data) values ('implicitdata')"))
assert conn2.execute(select([foo.c.data])).fetchall() == [('implicitdata',)]
-
+
conn1.close()
conn2.close()
-
-
+
+
+tlengine = None
class TLTransactionTest(TestBase):
def setUpAll(self):
global users, metadata, tlengine
@@ -502,7 +493,7 @@ class TLTransactionTest(TestBase):
finally:
external_connection.close()
- def testrollback(self):
+ def test_rollback(self):
"""test a basic rollback"""
tlengine.begin()
tlengine.execute(users.insert(), user_id=1, user_name='user1')
@@ -517,7 +508,7 @@ class TLTransactionTest(TestBase):
finally:
external_connection.close()
- def testcommit(self):
+ def test_commit(self):
"""test a basic commit"""
tlengine.begin()
tlengine.execute(users.insert(), user_id=1, user_name='user1')
@@ -532,7 +523,7 @@ class TLTransactionTest(TestBase):
finally:
external_connection.close()
- def testcommits(self):
+ def test_commits(self):
assert tlengine.connect().execute("select count(1) from query_users").scalar() == 0
connection = tlengine.contextual_connect()
@@ -551,7 +542,7 @@ class TLTransactionTest(TestBase):
assert len(l) == 3, "expected 3 got %d" % len(l)
transaction.commit()
- def testrollback_off_conn(self):
+ def test_rollback_off_conn(self):
# test that a TLTransaction opened off a TLConnection allows that
# TLConnection to be aware of the transactional context
conn = tlengine.contextual_connect()
@@ -568,7 +559,7 @@ class TLTransactionTest(TestBase):
finally:
external_connection.close()
- def testmorerollback_off_conn(self):
+ def test_morerollback_off_conn(self):
# test that an existing TLConnection automatically takes place in a TLTransaction
# opened on a second TLConnection
conn = tlengine.contextual_connect()
@@ -586,7 +577,7 @@ class TLTransactionTest(TestBase):
finally:
external_connection.close()
- def testcommit_off_conn(self):
+ def test_commit_off_connection(self):
conn = tlengine.contextual_connect()
trans = conn.begin()
conn.execute(users.insert(), user_id=1, user_name='user1')
@@ -603,7 +594,7 @@ class TLTransactionTest(TestBase):
@testing.unsupported('sqlite')
@testing.exclude('mysql', '<', (5, 0, 3))
- def testnesting(self):
+ def test_nesting(self):
"""tests nesting of transactions"""
external_connection = tlengine.connect()
self.assert_(external_connection.connection is not tlengine.contextual_connect().connection)
@@ -622,7 +613,7 @@ class TLTransactionTest(TestBase):
external_connection.close()
@testing.exclude('mysql', '<', (5, 0, 3))
- def testmixednesting(self):
+ def test_mixed_nesting(self):
"""tests nesting of transactions off the TLEngine directly inside of
tranasctions off the connection from the TLEngine"""
external_connection = tlengine.connect()
@@ -651,7 +642,7 @@ class TLTransactionTest(TestBase):
external_connection.close()
@testing.exclude('mysql', '<', (5, 0, 3))
- def testmoremixednesting(self):
+ def test_more_mixed_nesting(self):
"""tests nesting of transactions off the connection from the TLEngine
inside of tranasctions off thbe TLEngine directly."""
external_connection = tlengine.connect()
@@ -674,24 +665,9 @@ class TLTransactionTest(TestBase):
finally:
external_connection.close()
- @testing.exclude('mysql', '<', (5, 0, 3))
- def testsessionnesting(self):
- class User(object):
- pass
- try:
- mapper(User, users)
-
- sess = create_session(bind=tlengine)
- tlengine.begin()
- u = User()
- sess.save(u)
- sess.flush()
- tlengine.commit()
- finally:
- clear_mappers()
- def testconnections(self):
+ def test_connections(self):
"""tests that contextual_connect is threadlocal"""
c1 = tlengine.contextual_connect()
c2 = tlengine.contextual_connect()
@@ -699,10 +675,8 @@ class TLTransactionTest(TestBase):
c2.close()
assert c1.connection.connection is not None
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
- def testtwophasetransaction(self):
+ @testing.requires.two_phase_transactions
+ def test_two_phase_transaction(self):
tlengine.begin_twophase()
tlengine.execute(users.insert(), user_id=1, user_name='user1')
tlengine.prepare()
@@ -726,6 +700,7 @@ class TLTransactionTest(TestBase):
[(1,),(2,)]
)
+counters = None
class ForUpdateTest(TestBase):
def setUpAll(self):
global counters, metadata
@@ -770,7 +745,7 @@ class ForUpdateTest(TestBase):
@testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access')
- def testqueued_update(self):
+ def test_queued_update(self):
"""Test SELECT FOR UPDATE with concurrent modifications.
Runs concurrent modifications on a single row in the users table,
@@ -832,7 +807,7 @@ class ForUpdateTest(TestBase):
return errors
@testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access')
- def testqueued_select(self):
+ def test_queued_select(self):
"""Simple SELECT FOR UPDATE conflict test"""
errors = self._threaded_overlap(2, [(1,2,3),(3,4,5)])
@@ -842,7 +817,7 @@ class ForUpdateTest(TestBase):
@testing.unsupported('sqlite', 'mysql', 'mssql', 'firebird',
'sybase', 'access')
- def testnowait_select(self):
+ def test_nowait_select(self):
"""Simple SELECT FOR UPDATE NOWAIT conflict test"""
errors = self._threaded_overlap(2, [(1,2,3),(3,4,5)],
diff --git a/test/ext/activemapper.py b/test/ext/activemapper.py
deleted file mode 100644
index fa112c3b3..000000000
--- a/test/ext/activemapper.py
+++ /dev/null
@@ -1,357 +0,0 @@
-import testenv; testenv.configure_for_tests()
-from datetime import datetime
-
-from sqlalchemy.ext.activemapper import ActiveMapper, column, one_to_many, one_to_one, many_to_many, objectstore
-from sqlalchemy import and_, or_, exceptions
-from sqlalchemy import ForeignKey, String, Integer, DateTime, Table, Column
-from sqlalchemy.orm import clear_mappers, backref, create_session, class_mapper
-import sqlalchemy.ext.activemapper as activemapper
-import sqlalchemy
-from testlib import *
-
-
-class testcase(TestBase):
- def setUpAll(self):
- clear_mappers()
- objectstore.clear()
- global Person, Preferences, Address
-
- class Person(ActiveMapper):
- class mapping:
- __version_id_col__ = 'row_version'
- full_name = column(String(128))
- first_name = column(String(128))
- middle_name = column(String(128))
- last_name = column(String(128))
- birth_date = column(DateTime)
- ssn = column(String(128))
- gender = column(String(128))
- home_phone = column(String(128))
- cell_phone = column(String(128))
- work_phone = column(String(128))
- row_version = column(Integer, default=0)
- prefs_id = column(Integer, foreign_key=ForeignKey('preferences.id'))
- addresses = one_to_many('Address', colname='person_id', backref='person', order_by=['state', 'city', 'postal_code'])
- preferences = one_to_one('Preferences', colname='pref_id', backref='person')
-
- def __str__(self):
- s = '%s\n' % self.full_name
- s += ' * birthdate: %s\n' % (self.birth_date or 'not provided')
- s += ' * fave color: %s\n' % (self.preferences.favorite_color or 'Unknown')
- s += ' * personality: %s\n' % (self.preferences.personality_type or 'Unknown')
-
- for address in self.addresses:
- s += ' * address: %s\n' % address.address_1
- s += ' %s, %s %s\n' % (address.city, address.state, address.postal_code)
-
- return s
-
- class Preferences(ActiveMapper):
- class mapping:
- __table__ = 'preferences'
- favorite_color = column(String(128))
- personality_type = column(String(128))
-
- class Address(ActiveMapper):
- class mapping:
- # note that in other objects, the 'id' primary key is
- # automatically added -- if you specify a primary key,
- # then ActiveMapper will not add an integer primary key
- # for you.
- id = column(Integer, primary_key=True)
- type = column(String(128))
- address_1 = column(String(128))
- city = column(String(128))
- state = column(String(128))
- postal_code = column(String(128))
- person_id = column(Integer, foreign_key=ForeignKey('person.id'))
-
- activemapper.metadata.bind = testing.db
- activemapper.create_tables()
-
- def tearDownAll(self):
- clear_mappers()
- activemapper.drop_tables()
-
- def tearDown(self):
- for t in activemapper.metadata.table_iterator(reverse=True):
- t.delete().execute()
-
- def create_person_one(self):
- # create a person
- p1 = Person(
- full_name='Jonathan LaCour',
- birth_date=datetime(1979, 10, 12),
- preferences=Preferences(
- favorite_color='Green',
- personality_type='ENTP'
- ),
- addresses=[
- Address(
- address_1='123 Some Great Road.',
- city='Atlanta',
- state='GA',
- postal_code='30338'
- ),
- Address(
- address_1='435 Franklin Road.',
- city='Atlanta',
- state='GA',
- postal_code='30342'
- )
- ]
- )
- return p1
-
-
- def create_person_two(self):
- p2 = Person(
- full_name='Lacey LaCour',
- addresses=[
- Address(
- address_1='123 Some Great Road.',
- city='Atlanta',
- state='GA',
- postal_code='30338'
- ),
- Address(
- address_1='200 Main Street',
- city='Roswell',
- state='GA',
- postal_code='30075'
- )
- ]
- )
- # I don't like that I have to do this... and putting
- # a "self.preferences = Preferences()" into the __init__
- # of Person also doens't seem to fix this
- p2.preferences = Preferences()
-
- return p2
-
-
- def test_create(self):
- p1 = self.create_person_one()
- objectstore.flush()
- objectstore.clear()
-
- results = Person.query.all()
-
- self.assertEquals(len(results), 1)
-
- person = results[0]
- self.assertEquals(person.id, p1.id)
- self.assertEquals(len(person.addresses), 2)
- self.assertEquals(person.addresses[0].postal_code, '30338')
-
- @testing.unsupported('mysql')
- def test_update(self):
- p1 = self.create_person_one()
- objectstore.flush()
- objectstore.clear()
-
- person = Person.query.first()
- person.gender = 'F'
- objectstore.flush()
- objectstore.clear()
- self.assertEquals(person.row_version, 2)
-
- person = Person.query.first()
- person.gender = 'M'
- objectstore.flush()
- objectstore.clear()
- self.assertEquals(person.row_version, 3)
-
- #TODO: check that a concurrent modification raises exception
- p1 = Person.query.first()
- s1 = objectstore()
- s2 = create_session()
- objectstore.registry.set(s2)
- p2 = Person.query.first()
- p1.first_name = "jack"
- p2.first_name = "ed"
- objectstore.flush()
- try:
- objectstore.registry.set(s1)
- objectstore.flush()
- # Only dialects with a sane rowcount can detect the ConcurrentModificationError
- if testing.db.dialect.supports_sane_rowcount:
- assert False
- except exceptions.ConcurrentModificationError:
- pass
-
-
- def test_delete(self):
- p1 = self.create_person_one()
-
- objectstore.flush()
- objectstore.clear()
-
- results = Person.query.all()
- self.assertEquals(len(results), 1)
-
- objectstore.delete(results[0])
- objectstore.flush()
- objectstore.clear()
-
- results = Person.query.all()
- self.assertEquals(len(results), 0)
-
-
- def test_multiple(self):
- p1 = self.create_person_one()
- p2 = self.create_person_two()
-
- objectstore.flush()
- objectstore.clear()
-
- # select and make sure we get back two results
- people = Person.query.all()
- self.assertEquals(len(people), 2)
-
- # make sure that our backwards relationships work
- self.assertEquals(people[0].addresses[0].person.id, p1.id)
- self.assertEquals(people[1].addresses[0].person.id, p2.id)
-
- # try a more complex select
- results = Person.query.filter(
- or_(
- and_(
- Address.c.person_id == Person.c.id,
- Address.c.postal_code.like('30075')
- ),
- and_(
- Person.c.prefs_id == Preferences.c.id,
- Preferences.c.favorite_color == 'Green'
- )
- )
- ).all()
- self.assertEquals(len(results), 2)
-
-
- def test_oneway_backref(self):
- # FIXME: I don't know why, but it seems that my backwards relationship
- # on preferences still ends up being a list even though I pass
- # in uselist=False...
- # FIXED: the backref is a new PropertyLoader which needs its own "uselist".
- # uses a function which I dont think existed when you first wrote ActiveMapper.
- p1 = self.create_person_one()
- self.assertEquals(p1.preferences.person, p1)
- objectstore.flush()
- objectstore.delete(p1)
-
- objectstore.flush()
- objectstore.clear()
-
-
- def test_select_by(self):
- # FIXME: either I don't understand select_by, or it doesn't work.
- # FIXED (as good as we can for now): yup....everyone thinks it works that way....it only
- # generates joins for keyword arguments, not ColumnClause args. would need a new layer of
- # "MapperClause" objects to use properties in expressions. (MB)
-
- p1 = self.create_person_one()
- p2 = self.create_person_two()
-
- objectstore.flush()
- objectstore.clear()
-
- results = Person.query.join('addresses').filter(
- Address.c.postal_code.like('30075')
- ).all()
- self.assertEquals(len(results), 1)
-
- self.assertEquals(Person.query.count(), 2)
-
-class testmanytomany(TestBase):
- def setUpAll(self):
- clear_mappers()
- objectstore.clear()
- global secondarytable, foo, baz
- secondarytable = Table("secondarytable",
- activemapper.metadata,
- Column("foo_id", Integer, ForeignKey("foo.id"),primary_key=True),
- Column("baz_id", Integer, ForeignKey("baz.id"),primary_key=True))
-
- class foo(activemapper.ActiveMapper):
- class mapping:
- name = column(String(30))
-# bazrel = many_to_many('baz', secondarytable, backref='foorel')
-
- class baz(activemapper.ActiveMapper):
- class mapping:
- name = column(String(30))
- foorel = many_to_many("foo", secondarytable, backref='bazrel')
-
- activemapper.metadata.bind = testing.db
- activemapper.create_tables()
-
- # Create a couple of activemapper objects
- def create_objects(self):
- return foo(name='foo1'), baz(name='baz1')
-
- def tearDownAll(self):
- clear_mappers()
- activemapper.drop_tables()
- objectstore.clear()
- def testbasic(self):
- # Set up activemapper objects
- foo1, baz1 = self.create_objects()
-
- objectstore.flush()
- objectstore.clear()
-
- foo1 = foo.query.filter_by(name='foo1').one()
- baz1 = baz.query.filter_by(name='baz1').one()
-
- # Just checking ...
- assert (foo1.name == 'foo1')
- assert (baz1.name == 'baz1')
-
- # Diagnostics ...
- # import sys
- # sys.stderr.write("\nbazrel missing from dir(foo1):\n%s\n" % dir(foo1))
- # sys.stderr.write("\nbazrel in foo1 relations:\n%s\n" % foo1.relations)
-
- # Optimistically based on activemapper one_to_many test, try to append
- # baz1 to foo1.bazrel - (AttributeError: 'foo' object has no attribute 'bazrel')
- foo1.bazrel.append(baz1)
- assert (foo1.bazrel == [baz1])
-
-class testselfreferential(TestBase):
- def setUpAll(self):
- clear_mappers()
- objectstore.clear()
- global TreeNode
- class TreeNode(activemapper.ActiveMapper):
- class mapping:
- id = column(Integer, primary_key=True)
- name = column(String(30))
- parent_id = column(Integer, foreign_key=ForeignKey('treenode.id'))
- children = one_to_many('TreeNode', colname='id', backref='parent')
-
- activemapper.metadata.bind = testing.db
- activemapper.create_tables()
- def tearDownAll(self):
- clear_mappers()
- activemapper.drop_tables()
-
- def testbasic(self):
- t = TreeNode(name='node1')
- t.children.append(TreeNode(name='node2'))
- t.children.append(TreeNode(name='node3'))
- objectstore.flush()
- objectstore.clear()
-
- t = TreeNode.query.filter_by(name='node1').one()
- assert (t.name == 'node1')
- assert (t.children[0].name == 'node2')
- assert (t.children[1].name == 'node3')
- assert (t.children[1].parent is t)
-
- objectstore.clear()
- t = TreeNode.query.filter_by(name='node3').one()
- assert (t.parent is TreeNode.query.filter_by(name='node1').one())
-
-if __name__ == '__main__':
- testenv.main()
diff --git a/test/ext/alltests.py b/test/ext/alltests.py
index d5db4d01e..1b6dc53d2 100644
--- a/test/ext/alltests.py
+++ b/test/ext/alltests.py
@@ -2,8 +2,7 @@ import testenv; testenv.configure_for_tests()
import doctest, sys, unittest
def suite():
- unittest_modules = ['ext.activemapper',
- 'ext.assignmapper',
+ unittest_modules = [
'ext.declarative',
'ext.orderinglist',
'ext.associationproxy']
diff --git a/test/ext/assignmapper.py b/test/ext/assignmapper.py
deleted file mode 100644
index 1cb2ca375..000000000
--- a/test/ext/assignmapper.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import testenv; testenv.configure_for_tests()
-from sqlalchemy import *
-from sqlalchemy import exceptions
-from sqlalchemy.orm import create_session, clear_mappers, relation, class_mapper
-from sqlalchemy.ext.assignmapper import assign_mapper
-from sqlalchemy.ext.sessioncontext import SessionContext
-from testlib import *
-
-
-class AssignMapperTest(TestBase):
- def setUpAll(self):
- global metadata, table, table2
- metadata = MetaData(testing.db)
- table = Table('sometable', metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(30)))
- table2 = Table('someothertable', metadata,
- Column('id', Integer, primary_key=True),
- Column('someid', None, ForeignKey('sometable.id'))
- )
- metadata.create_all()
-
- @testing.uses_deprecated('SessionContext', 'assign_mapper')
- def setUp(self):
- global SomeObject, SomeOtherObject, ctx
- class SomeObject(object):pass
- class SomeOtherObject(object):pass
-
- ctx = SessionContext(create_session)
- assign_mapper(ctx, SomeObject, table, properties={
- 'options':relation(SomeOtherObject)
- })
- assign_mapper(ctx, SomeOtherObject, table2)
-
- s = SomeObject()
- s.id = 1
- s.data = 'hello'
- sso = SomeOtherObject()
- s.options.append(sso)
- ctx.current.flush()
- ctx.current.clear()
-
- def tearDownAll(self):
- metadata.drop_all()
-
- def tearDown(self):
- for table in metadata.table_iterator(reverse=True):
- table.delete().execute()
- clear_mappers()
-
- @testing.uses_deprecated('assign_mapper')
- def test_override_attributes(self):
-
- sso = SomeOtherObject.query().first()
-
- assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id
-
- s2 = SomeObject(someid=12)
- s3 = SomeOtherObject(someid=123, bogus=345)
-
- class ValidatedOtherObject(object):pass
- assign_mapper(ctx, ValidatedOtherObject, table2, validate=True)
-
- v1 = ValidatedOtherObject(someid=12)
- try:
- v2 = ValidatedOtherObject(someid=12, bogus=345)
- assert False
- except exceptions.ArgumentError:
- pass
-
- @testing.uses_deprecated('assign_mapper')
- def test_dont_clobber_methods(self):
- class MyClass(object):
- def expunge(self):
- return "an expunge !"
-
- assign_mapper(ctx, MyClass, table2)
-
- assert MyClass().expunge() == "an expunge !"
-
-
-if __name__ == '__main__':
- testenv.main()
diff --git a/test/ext/declarative.py b/test/ext/declarative.py
index ab07627dd..4c4f9b012 100644
--- a/test/ext/declarative.py
+++ b/test/ext/declarative.py
@@ -5,7 +5,7 @@ from sqlalchemy.orm import *
from sqlalchemy.orm.interfaces import MapperExtension
from sqlalchemy.ext.declarative import declarative_base, declared_synonym, \
synonym_for, comparable_using
-from sqlalchemy import exceptions
+from sqlalchemy import exc
from testlib.fixtures import Base as Fixture
from testlib import *
@@ -94,7 +94,7 @@ class DeclarativeTest(TestBase, AssertsExecutionResults):
id = Column(Integer, primary_key=True)
foo = column_property(User.id==5)
- self.assertRaises(exceptions.InvalidRequestError, go)
+ self.assertRaises(exc.InvalidRequestError, go)
def test_add_prop(self):
class User(Base, Fixture):
@@ -183,7 +183,7 @@ class DeclarativeTest(TestBase, AssertsExecutionResults):
name = Column('name', String(50))
assert False
self.assertRaisesMessage(
- exceptions.ArgumentError,
+ exc.ArgumentError,
"Mapper Mapper|User|users could not assemble any primary key",
define)
diff --git a/test/orm/alltests.py b/test/orm/alltests.py
index 73406c00d..77745aea1 100644
--- a/test/orm/alltests.py
+++ b/test/orm/alltests.py
@@ -6,7 +6,9 @@ import sharding.alltests as sharding
def suite():
modules_to_test = (
- 'orm.attributes',
+ 'orm.attributes',
+ 'orm.extendedattr',
+ 'orm.instrumentation',
'orm.query',
'orm.lazy_relations',
'orm.eager_relations',
@@ -19,15 +21,17 @@ def suite():
'orm.assorted_eager',
'orm.naturalpks',
- 'orm.sessioncontext',
'orm.unitofwork',
'orm.session',
+ 'orm.transaction',
+ 'orm.scoping',
'orm.cascade',
'orm.relationships',
'orm.association',
'orm.merge',
'orm.pickled',
'orm.memusage',
+ 'orm.utils',
'orm.cycles',
@@ -36,6 +40,8 @@ def suite():
'orm.manytomany',
'orm.onetoone',
'orm.dynamic',
+
+ 'orm.deprecations',
)
alltests = unittest.TestSuite()
for name in modules_to_test:
diff --git a/test/orm/association.py b/test/orm/association.py
index 65d702538..1115849d2 100644
--- a/test/orm/association.py
+++ b/test/orm/association.py
@@ -5,7 +5,6 @@ from sqlalchemy.orm import *
from testlib import *
class AssociationTest(TestBase):
- @testing.uses_deprecated('association option')
def setUpAll(self):
global items, item_keywords, keywords, metadata, Item, Keyword, KeywordAssociation
metadata = MetaData(testing.db)
@@ -46,7 +45,7 @@ class AssociationTest(TestBase):
'keyword':relation(Keyword, lazy=False)
}, primary_key=[item_keywords.c.item_id, item_keywords.c.keyword_id], order_by=[item_keywords.c.data])
mapper(Item, items, properties={
- 'keywords' : relation(KeywordAssociation, association=Keyword)
+ 'keywords' : relation(KeywordAssociation, cascade="all, delete-orphan")
})
def tearDown(self):
@@ -123,7 +122,6 @@ class AssociationTest(TestBase):
print loaded
self.assert_(saved == loaded)
- @testing.uses_deprecated('association option')
def testdelete(self):
sess = create_session()
item1 = Item('item1')
@@ -185,7 +183,7 @@ in self.c ]
mapper(Originals, table_originals, order_by=Originals.order,
properties={
- 'people': relation(IsAuthor, association=People),
+ 'people': relation(IsAuthor, cascade="all, delete-orphan"),
'authors': relation(People, secondary=table_isauthor, backref='written',
primaryjoin=and_(table_originals.c.ID==table_isauthor.c.OriginalsID,
table_isauthor.c.Kind=='A')),
@@ -193,7 +191,7 @@ in self.c ]
'date': table_originals.c.Date,
})
mapper(People, table_people, order_by=People.order, properties= {
- 'originals': relation(IsAuthor, association=Originals),
+ 'originals': relation(IsAuthor, cascade="all, delete-orphan"),
'name': table_people.c.Name,
'country': table_people.c.Country,
})
diff --git a/test/orm/assorted_eager.py b/test/orm/assorted_eager.py
index af3fcbc7b..731a9f916 100644
--- a/test/orm/assorted_eager.py
+++ b/test/orm/assorted_eager.py
@@ -4,7 +4,6 @@ import testenv; testenv.configure_for_tests()
import random, datetime
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext
from testlib import *
from testlib import fixtures
@@ -125,15 +124,6 @@ class EagerTest(TestBase, AssertsExecutionResults):
print result
assert result == [u'1 Some Category', u'3 Some Category']
- @testing.uses_deprecated('//select')
- def test_withouteagerload_deprecated(self):
- s = create_session()
- l=s.query(Test).select ( and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)),
- from_obj=[tests.outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))])
- result = ["%d %s" % ( t.id,t.category.name ) for t in l]
- print result
- assert result == [u'1 Some Category', u'3 Some Category']
-
def test_witheagerload(self):
"""test that an eagerload locates the correct "from" clause with
which to attach to, when presented with a query that already has a complicated from clause."""
@@ -152,17 +142,6 @@ class EagerTest(TestBase, AssertsExecutionResults):
print result
assert result == [u'1 Some Category', u'3 Some Category']
- @testing.uses_deprecated('//select')
- def test_witheagerload_deprecated(self):
- """As test_witheagerload, but via select()."""
- s = create_session()
- q=s.query(Test).options(eagerload('category'))
- l=q.select ( and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)),
- from_obj=[tests.outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))])
- result = ["%d %s" % ( t.id,t.category.name ) for t in l]
- print result
- assert result == [u'1 Some Category', u'3 Some Category']
-
def test_dslish(self):
"""test the same as witheagerload except using generative"""
s = create_session()
@@ -188,16 +167,6 @@ class EagerTest(TestBase, AssertsExecutionResults):
print result
assert result == [u'3 Some Category']
- @testing.unsupported('sybase')
- @testing.uses_deprecated('//select', '//join_to')
- def test_withoutouterjoin_literal_deprecated(self):
- s = create_session()
- q=s.query(Test).options(eagerload('category'))
- l=q.select( (tests.c.owner_id==1) & ('options.someoption is null or options.someoption=%s' % false) & q.join_to('owner_option') )
- result = ["%d %s" % ( t.id,t.category.name ) for t in l]
- print result
- assert result == [u'3 Some Category']
-
def test_withoutouterjoin(self):
s = create_session()
q=s.query(Test).options(eagerload('category'))
@@ -206,15 +175,6 @@ class EagerTest(TestBase, AssertsExecutionResults):
print result
assert result == [u'3 Some Category']
- @testing.uses_deprecated('//select', '//join_to', '//join_via')
- def test_withoutouterjoin_deprecated(self):
- s = create_session()
- q=s.query(Test).options(eagerload('category'))
- l=q.select( (tests.c.owner_id==1) & ((options.c.someoption==None) | (options.c.someoption==False)) & q.join_to('owner_option') )
- result = ["%d %s" % ( t.id,t.category.name ) for t in l]
- print result
- assert result == [u'3 Some Category']
-
class EagerTest2(TestBase, AssertsExecutionResults):
def setUpAll(self):
global metadata, middle, left, right
@@ -389,7 +349,7 @@ class EagerTest4(ORMTest):
sess.flush()
q = sess.query(Department)
- q = q.join('employees').filter(Employee.c.name.startswith('J')).distinct().order_by([desc(Department.c.name)])
+ q = q.join('employees').filter(Employee.name.startswith('J')).distinct().order_by([desc(Department.name)])
assert q.count() == 2
assert q[0] is d2
@@ -543,12 +503,11 @@ class EagerTest6(ORMTest):
x.inheritedParts
class EagerTest7(ORMTest):
- @testing.uses_deprecated('SessionContext')
def define_tables(self, metadata):
global companies_table, addresses_table, invoice_table, phones_table, items_table, ctx
global Company, Address, Phone, Item,Invoice
- ctx = SessionContext(create_session)
+ ctx = scoped_session(create_session)
companies_table = Table('companies', metadata,
Column('company_id', Integer, Sequence('company_id_seq', optional=True), primary_key = True),
@@ -606,20 +565,19 @@ class EagerTest7(ORMTest):
def __repr__(self):
return "Item: " + repr(getattr(self, 'item_id', None)) + " " + repr(getattr(self, 'invoice_id', None)) + " " + repr(self.code) + " " + repr(self.qty)
- @testing.uses_deprecated('SessionContext')
def testone(self):
"""tests eager load of a many-to-one attached to a one-to-many. this testcase illustrated
the bug, which is that when the single Company is loaded, no further processing of the rows
occurred in order to load the Company's second Address object."""
mapper(Address, addresses_table, properties={
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
mapper(Company, companies_table, properties={
'addresses' : relation(Address, lazy=False),
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
mapper(Invoice, invoice_table, properties={
'company': relation(Company, lazy=False, )
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
c1 = Company()
c1.company_name = 'company 1'
@@ -633,18 +591,18 @@ class EagerTest7(ORMTest):
i1.date = datetime.datetime.now()
i1.company = c1
- ctx.current.flush()
+ ctx.flush()
company_id = c1.company_id
invoice_id = i1.invoice_id
- ctx.current.clear()
+ ctx.clear()
- c = ctx.current.query(Company).get(company_id)
+ c = ctx.query(Company).get(company_id)
- ctx.current.clear()
+ ctx.clear()
- i = ctx.current.query(Invoice).get(invoice_id)
+ i = ctx.query(Invoice).get(invoice_id)
print repr(c)
print repr(i.company)
@@ -653,24 +611,24 @@ class EagerTest7(ORMTest):
def testtwo(self):
"""this is the original testcase that includes various complicating factors"""
- mapper(Phone, phones_table, extension=ctx.mapper_extension)
+ mapper(Phone, phones_table, extension=ctx.extension)
mapper(Address, addresses_table, properties={
'phones': relation(Phone, lazy=False, backref='address')
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
mapper(Company, companies_table, properties={
'addresses' : relation(Address, lazy=False, backref='company'),
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
- mapper(Item, items_table, extension=ctx.mapper_extension)
+ mapper(Item, items_table, extension=ctx.extension)
mapper(Invoice, invoice_table, properties={
'items': relation(Item, lazy=False, backref='invoice'),
'company': relation(Company, lazy=False, backref='invoices')
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
- ctx.current.clear()
+ ctx.clear()
c1 = Company()
c1.company_name = 'company 1'
@@ -705,13 +663,13 @@ class EagerTest7(ORMTest):
c1.addresses.append(a2)
- ctx.current.flush()
+ ctx.flush()
company_id = c1.company_id
- ctx.current.clear()
+ ctx.clear()
- a = ctx.current.query(Company).get(company_id)
+ a = ctx.query(Company).get(company_id)
print repr(a)
# set up an invoice
@@ -734,18 +692,18 @@ class EagerTest7(ORMTest):
item3.qty = 3
item3.invoice = i1
- ctx.current.flush()
+ ctx.flush()
invoice_id = i1.invoice_id
- ctx.current.clear()
+ ctx.clear()
- c = ctx.current.query(Company).get(company_id)
+ c = ctx.query(Company).get(company_id)
print repr(c)
- ctx.current.clear()
+ ctx.clear()
- i = ctx.current.query(Invoice).get(invoice_id)
+ i = ctx.query(Invoice).get(invoice_id)
assert repr(i.company) == repr(c), repr(i.company) + " does not match " + repr(c)
diff --git a/test/orm/attributes.py b/test/orm/attributes.py
index caa129e5e..3883cdcd1 100644
--- a/test/orm/attributes.py
+++ b/test/orm/attributes.py
@@ -2,18 +2,24 @@ import testenv; testenv.configure_for_tests()
import pickle
import sqlalchemy.orm.attributes as attributes
from sqlalchemy.orm.collections import collection
-from sqlalchemy import exceptions
+from sqlalchemy.orm.interfaces import AttributeExtension
+from sqlalchemy import exc as sa_exc
from testlib import *
from testlib import fixtures
-ROLLBACK_SUPPORTED=False
-
-# these test classes defined at the module
-# level to support pickling
-class MyTest(object):pass
-class MyTest2(object):pass
+# global for pickling tests
+MyTest = None
+MyTest2 = None
class AttributesTest(TestBase):
+ def setUp(self):
+ global MyTest, MyTest2
+ class MyTest(object): pass
+ class MyTest2(object): pass
+
+ def tearDown(self):
+ global MyTest, MyTest2
+ MyTest, MyTest2 = None, None
def test_basic(self):
class User(object):pass
@@ -29,7 +35,7 @@ class AttributesTest(TestBase):
u.email_address = 'lala@123.com'
self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
- u._state.commit_all()
+ attributes.instance_state(u).commit_all()
self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
u.user_name = 'heythere'
@@ -99,31 +105,33 @@ class AttributesTest(TestBase):
class Foo(object):pass
data = {'a':'this is a', 'b':12}
- def loader(instance, keys):
+ def loader(state, keys):
for k in keys:
- instance.__dict__[k] = data[k]
+ state.dict[k] = data[k]
return attributes.ATTR_WAS_SET
- attributes.register_class(Foo, deferred_scalar_loader=loader)
+ attributes.register_class(Foo)
+ manager = attributes.manager_of_class(Foo)
+ manager.deferred_scalar_loader = loader
attributes.register_attribute(Foo, 'a', uselist=False, useobject=False)
attributes.register_attribute(Foo, 'b', uselist=False, useobject=False)
f = Foo()
- f._state.expire_attributes(None)
+ attributes.instance_state(f).expire_attributes(None)
self.assertEquals(f.a, "this is a")
self.assertEquals(f.b, 12)
f.a = "this is some new a"
- f._state.expire_attributes(None)
+ attributes.instance_state(f).expire_attributes(None)
self.assertEquals(f.a, "this is a")
self.assertEquals(f.b, 12)
- f._state.expire_attributes(None)
+ attributes.instance_state(f).expire_attributes(None)
f.a = "this is another new a"
self.assertEquals(f.a, "this is another new a")
self.assertEquals(f.b, 12)
- f._state.expire_attributes(None)
+ attributes.instance_state(f).expire_attributes(None)
self.assertEquals(f.a, "this is a")
self.assertEquals(f.b, 12)
@@ -131,23 +139,25 @@ class AttributesTest(TestBase):
self.assertEquals(f.a, None)
self.assertEquals(f.b, 12)
- f._state.commit_all()
+ attributes.instance_state(f).commit_all()
self.assertEquals(f.a, None)
self.assertEquals(f.b, 12)
def test_deferred_pickleable(self):
data = {'a':'this is a', 'b':12}
- def loader(instance, keys):
+ def loader(state, keys):
for k in keys:
- instance.__dict__[k] = data[k]
+ state.dict[k] = data[k]
return attributes.ATTR_WAS_SET
- attributes.register_class(MyTest, deferred_scalar_loader=loader)
+ attributes.register_class(MyTest)
+ manager = attributes.manager_of_class(MyTest)
+ manager.deferred_scalar_loader=loader
attributes.register_attribute(MyTest, 'a', uselist=False, useobject=False)
attributes.register_attribute(MyTest, 'b', uselist=False, useobject=False)
m = MyTest()
- m._state.expire_attributes(None)
+ attributes.instance_state(m).expire_attributes(None)
assert 'a' not in m.__dict__
m2 = pickle.loads(pickle.dumps(m))
assert 'a' not in m2.__dict__
@@ -176,7 +186,7 @@ class AttributesTest(TestBase):
u.addresses.append(a)
self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
- u, a._state.commit_all()
+ u, attributes.instance_state(a).commit_all()
self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
u.user_name = 'heythere'
@@ -186,6 +196,45 @@ class AttributesTest(TestBase):
u.addresses.append(a)
self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.addresses[0].email_address == 'lala@123.com' and u.addresses[1].email_address == 'foo@bar.com')
+ def test_scalar_listener(self):
+ # listeners on ScalarAttributeImpl and MutableScalarAttributeImpl aren't used normally.
+ # test that they work for the benefit of user extensions
+ class Foo(object):
+ pass
+
+ results = []
+ class ReceiveEvents(AttributeExtension):
+ def append(self, state, child, initiator):
+ assert False
+
+ def remove(self, state, child, initiator):
+ results.append(("remove", state.obj(), child))
+
+ def set(self, state, child, oldchild, initiator):
+ results.append(("set", state.obj(), child, oldchild))
+
+ attributes.register_class(Foo)
+ attributes.register_attribute(Foo, 'x', uselist=False, mutable_scalars=False, useobject=False, extension=ReceiveEvents())
+ attributes.register_attribute(Foo, 'y', uselist=False, mutable_scalars=True, useobject=False, copy_function=lambda x:x, extension=ReceiveEvents())
+
+ f = Foo()
+ f.x = 5
+ f.x = 17
+ del f.x
+ f.y = [1,2,3]
+ f.y = [4,5,6]
+ del f.y
+
+ self.assertEquals(results, [
+ ('set', f, 5, None),
+ ('set', f, 17, 5),
+ ('remove', f, 17),
+ ('set', f, [1,2,3], None),
+ ('set', f, [4,5,6], [1,2,3]),
+ ('remove', f, [4,5,6])
+ ])
+
+
def test_lazytrackparent(self):
"""test that the "hasparent" flag works properly when lazy loaders and backrefs are used"""
@@ -201,9 +250,9 @@ class AttributesTest(TestBase):
# create objects as if they'd been freshly loaded from the database (without history)
b = Blog()
p1 = Post()
- b._state.set_callable('posts', lambda:[p1])
- p1._state.set_callable('blog', lambda:b)
- p1, b._state.commit_all()
+ attributes.instance_state(b).set_callable('posts', lambda:[p1])
+ attributes.instance_state(p1).set_callable('blog', lambda:b)
+ p1, attributes.instance_state(b).commit_all()
# no orphans (called before the lazy loaders fire off)
assert attributes.has_parent(Blog, p1, 'posts', optimistic=True)
@@ -253,10 +302,10 @@ class AttributesTest(TestBase):
states = set()
class Foo(object):
def __init__(self):
- states.add(self._state)
+ states.add(attributes.instance_state(self))
class Bar(Foo):
def __init__(self):
- states.add(self._state)
+ states.add(attributes.instance_state(self))
Foo.__init__(self)
@@ -283,10 +332,10 @@ class AttributesTest(TestBase):
el = Element()
x = Bar()
x.element = el
- self.assertEquals(attributes.get_history(x._state, 'element'), ([el],[], []))
- x._state.commit_all()
+ self.assertEquals(attributes.get_history(attributes.instance_state(x), 'element'), ([el],[], []))
+ attributes.instance_state(x).commit_all()
- (added, unchanged, deleted) = attributes.get_history(x._state, 'element')
+ (added, unchanged, deleted) = attributes.get_history(attributes.instance_state(x), 'element')
assert added == []
assert unchanged == [el]
@@ -312,9 +361,9 @@ class AttributesTest(TestBase):
attributes.register_attribute(Bar, 'id', uselist=False, useobject=True)
x = Foo()
- x._state.commit_all()
+ attributes.instance_state(x).commit_all()
x.col2.append(bar4)
- self.assertEquals(attributes.get_history(x._state, 'col2'), ([bar4], [bar1, bar2, bar3], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(x), 'col2'), ([bar4], [bar1, bar2, bar3], []))
def test_parenttrack(self):
class Foo(object):pass
@@ -358,9 +407,9 @@ class AttributesTest(TestBase):
attributes.register_attribute(Foo, 'element', uselist=False, copy_function=lambda x:[y for y in x], mutable_scalars=True, useobject=False)
x = Foo()
x.element = ['one', 'two', 'three']
- x._state.commit_all()
+ attributes.instance_state(x).commit_all()
x.element[1] = 'five'
- assert x._state.is_modified()
+ assert attributes.instance_state(x).check_modified()
attributes.unregister_class(Foo)
@@ -368,9 +417,9 @@ class AttributesTest(TestBase):
attributes.register_attribute(Foo, 'element', uselist=False, useobject=False)
x = Foo()
x.element = ['one', 'two', 'three']
- x._state.commit_all()
+ attributes.instance_state(x).commit_all()
x.element[1] = 'five'
- assert not x._state.is_modified()
+ assert not attributes.instance_state(x).check_modified()
def test_descriptorattributes(self):
"""changeset: 1633 broke ability to use ORM to map classes with unusual
@@ -379,27 +428,31 @@ class AttributesTest(TestBase):
This is a simple regression test to prevent that defect.
"""
class des(object):
- def __get__(self, instance, owner): raise AttributeError('fake attribute')
+ def __get__(self, instance, owner):
+ raise AttributeError('fake attribute')
class Foo(object):
A = des()
-
+ attributes.register_class(Foo)
attributes.unregister_class(Foo)
def test_collectionclasses(self):
class Foo(object):pass
attributes.register_class(Foo)
+
attributes.register_attribute(Foo, "collection", uselist=True, typecallable=set, useobject=True)
+ assert attributes.manager_of_class(Foo).is_instrumented("collection")
assert isinstance(Foo().collection, set)
attributes.unregister_attribute(Foo, "collection")
-
+ assert not attributes.manager_of_class(Foo).is_instrumented("collection")
+
try:
attributes.register_attribute(Foo, "collection", uselist=True, typecallable=dict, useobject=True)
assert False
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
assert str(e) == "Type InstrumentedDict must elect an appender method to be a collection class"
class MyDict(dict):
@@ -418,7 +471,7 @@ class AttributesTest(TestBase):
try:
attributes.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl, useobject=True)
assert False
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
assert str(e) == "Type MyColl must elect an appender method to be a collection class"
class MyColl(object):
@@ -435,7 +488,7 @@ class AttributesTest(TestBase):
try:
Foo().collection
assert True
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
assert False
@@ -512,7 +565,7 @@ class BackrefTest(TestBase):
j.port = None
self.assert_(p.jack is None)
-class DeferredBackrefTest(TestBase):
+class PendingBackrefTest(TestBase):
def setUp(self):
global Post, Blog, called, lazy_load
@@ -550,6 +603,7 @@ class DeferredBackrefTest(TestBase):
b = Blog("blog 1")
p = Post("post 4")
+
p.blog = b
p = Post("post 5")
p.blog = b
@@ -559,6 +613,22 @@ class DeferredBackrefTest(TestBase):
# calling backref calls the callable, populates extra posts
assert b.posts == [p1, p2, p3, Post("post 4"), Post("post 5")]
assert called[0] == 1
+
+ def test_lazy_history(self):
+ global lazy_load
+
+ p1, p2, p3 = Post("post 1"), Post("post 2"), Post("post 3")
+ lazy_load = [p1, p2, p3]
+
+ b = Blog("blog 1")
+ p = Post("post 4")
+ p.blog = b
+
+ p4 = Post("post 5")
+ p4.blog = b
+ assert called[0] == 0
+ self.assertEquals(attributes.instance_state(b).get_history('posts'), ([p, p4], [p1, p2, p3], []))
+ assert called[0] == 1
def test_lazy_remove(self):
global lazy_load
@@ -609,17 +679,17 @@ class HistoryTest(TestBase):
attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=False)
f = Foo()
- self.assertEquals(Foo.someattr.impl.get_committed_value(f._state), None)
+ self.assertEquals(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None)
f.someattr = 3
- self.assertEquals(Foo.someattr.impl.get_committed_value(f._state), None)
+ self.assertEquals(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None)
f = Foo()
f.someattr = 3
- self.assertEquals(Foo.someattr.impl.get_committed_value(f._state), None)
+ self.assertEquals(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None)
- f._state.commit(['someattr'])
- self.assertEquals(Foo.someattr.impl.get_committed_value(f._state), 3)
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), 3)
def test_scalar(self):
class Foo(fixtures.Base):
@@ -630,48 +700,59 @@ class HistoryTest(TestBase):
# case 1. new object
f = Foo()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], []))
f.someattr = "hi"
- self.assertEquals(attributes.get_history(f._state, 'someattr'), (['hi'], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['hi'], [], []))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['hi'], []))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['hi'], []))
f.someattr = 'there'
- self.assertEquals(attributes.get_history(f._state, 'someattr'), (['there'], [], ['hi']))
- f._state.commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['there'], [], ['hi']))
+ attributes.instance_state(f).commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['there'], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['there'], []))
del f.someattr
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], ['there']))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], ['there']))
# case 2. object with direct dictionary settings (similar to a load operation)
f = Foo()
f.__dict__['someattr'] = 'new'
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['new'], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['new'], []))
f.someattr = 'old'
- self.assertEquals(attributes.get_history(f._state, 'someattr'), (['old'], [], ['new']))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['old'], [], ['new']))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['old'], []))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['old'], []))
# setting None on uninitialized is currently a change for a scalar attribute
# no lazyload occurs so this allows overwrite operation to proceed
f = Foo()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], []))
+ print f._foostate.committed_state
f.someattr = None
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], []))
+ print f._foostate.committed_state, f._foostate.dict
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], [], []))
f = Foo()
f.__dict__['someattr'] = 'new'
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['new'], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['new'], []))
f.someattr = None
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], ['new']))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], [], ['new']))
+ # set same value twice
+ f = Foo()
+ attributes.instance_state(f).commit(['someattr'])
+ f.someattr = 'one'
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['one'], [], []))
+ f.someattr = 'two'
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['two'], [], []))
+
+
def test_mutable_scalar(self):
class Foo(fixtures.Base):
pass
@@ -681,33 +762,33 @@ class HistoryTest(TestBase):
# case 1. new object
f = Foo()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], []))
f.someattr = {'foo':'hi'}
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([{'foo':'hi'}], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'hi'}], [], []))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'hi'}], []))
- self.assertEquals(f._state.committed_state['someattr'], {'foo':'hi'})
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [{'foo':'hi'}], []))
+ self.assertEquals(attributes.instance_state(f).committed_state['someattr'], {'foo':'hi'})
f.someattr['foo'] = 'there'
- self.assertEquals(f._state.committed_state['someattr'], {'foo':'hi'})
+ self.assertEquals(attributes.instance_state(f).committed_state['someattr'], {'foo':'hi'})
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([{'foo':'there'}], [], [{'foo':'hi'}]))
- f._state.commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'there'}], [], [{'foo':'hi'}]))
+ attributes.instance_state(f).commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'there'}], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [{'foo':'there'}], []))
# case 2. object with direct dictionary settings (similar to a load operation)
f = Foo()
f.__dict__['someattr'] = {'foo':'new'}
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'new'}], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [{'foo':'new'}], []))
f.someattr = {'foo':'old'}
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([{'foo':'old'}], [], [{'foo':'new'}]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'old'}], [], [{'foo':'new'}]))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'old'}], []))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [{'foo':'old'}], []))
def test_use_object(self):
@@ -729,48 +810,56 @@ class HistoryTest(TestBase):
# case 1. new object
f = Foo()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [None], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [None], []))
f.someattr = hi
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], []))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi], []))
f.someattr = there
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [], [hi]))
- f._state.commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], [hi]))
+ attributes.instance_state(f).commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [there], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [there], []))
del f.someattr
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], [there]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], [], [there]))
# case 2. object with direct dictionary settings (similar to a load operation)
f = Foo()
- f.__dict__['someattr'] = new
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], []))
+ f.__dict__['someattr'] = 'new'
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['new'], []))
f.someattr = old
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old], [], [new]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [], ['new']))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [old], []))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [old], []))
# setting None on uninitialized is currently not a change for an object attribute
# (this is different than scalar attribute). a lazyload has occured so if its
# None, its really None
f = Foo()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [None], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [None], []))
f.someattr = None
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [None], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [None], []))
f = Foo()
- f.__dict__['someattr'] = new
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], []))
+ f.__dict__['someattr'] = 'new'
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['new'], []))
f.someattr = None
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], [new]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], [], ['new']))
+
+ # set same value twice
+ f = Foo()
+ attributes.instance_state(f).commit(['someattr'])
+ f.someattr = 'one'
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['one'], [], []))
+ f.someattr = 'two'
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['two'], [], []))
def test_object_collections_set(self):
class Foo(fixtures.Base):
@@ -789,39 +878,39 @@ class HistoryTest(TestBase):
# case 1. new object
f = Foo()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], []))
f.someattr = [hi]
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], []))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi], []))
f.someattr = [there]
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [], [hi]))
- f._state.commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], [hi]))
+ attributes.instance_state(f).commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [there], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [there], []))
f.someattr = [hi]
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], [there]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], [there]))
f.someattr = [old, new]
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old, new], [], [there]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old, new], [], [there]))
# case 2. object with direct settings (similar to a load operation)
f = Foo()
- collection = attributes.init_collection(f, 'someattr')
+ collection = attributes.init_collection(attributes.instance_state(f), 'someattr')
collection.append_without_event(new)
- f._state.commit_all()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], []))
+ attributes.instance_state(f).commit_all()
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [new], []))
f.someattr = [old]
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old], [], [new]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [], [new]))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [old], []))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [old], []))
def test_dict_collections(self):
class Foo(fixtures.Base):
@@ -840,16 +929,16 @@ class HistoryTest(TestBase):
new = Bar(name='new')
f = Foo()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], []))
f.someattr['hi'] = hi
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
f.someattr['there'] = there
- self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set([hi, there]), set([]), set([])))
+ self.assertEquals(tuple([set(x) for x in attributes.get_history(attributes.instance_state(f), 'someattr')]), (set([hi, there]), set([]), set([])))
- f._state.commit(['someattr'])
- self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set([]), set([hi, there]), set([])))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(tuple([set(x) for x in attributes.get_history(attributes.instance_state(f), 'someattr')]), (set([]), set([hi, there]), set([])))
def test_object_collections_mutate(self):
class Foo(fixtures.Base):
@@ -868,65 +957,65 @@ class HistoryTest(TestBase):
# case 1. new object
f = Foo(id=1)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], []))
f.someattr.append(hi)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], []))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi], []))
f.someattr.append(there)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [hi], []))
- f._state.commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [hi], []))
+ attributes.instance_state(f).commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi, there], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi, there], []))
f.someattr.remove(there)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], [there]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi], [there]))
f.someattr.append(old)
f.someattr.append(new)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old, new], [hi], [there]))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi, old, new], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old, new], [hi], [there]))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi, old, new], []))
f.someattr.pop(0)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [old, new], [hi]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [old, new], [hi]))
# case 2. object with direct settings (similar to a load operation)
f = Foo()
f.__dict__['id'] = 1
- collection = attributes.init_collection(f, 'someattr')
+ collection = attributes.init_collection(attributes.instance_state(f), 'someattr')
collection.append_without_event(new)
- f._state.commit_all()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], []))
+ attributes.instance_state(f).commit_all()
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [new], []))
f.someattr.append(old)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old], [new], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [new], []))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new, old], []))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [new, old], []))
f = Foo()
- collection = attributes.init_collection(f, 'someattr')
+ collection = attributes.init_collection(attributes.instance_state(f), 'someattr')
collection.append_without_event(new)
- f._state.commit_all()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], []))
+ attributes.instance_state(f).commit_all()
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [new], []))
f.id = 1
f.someattr.remove(new)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], [new]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], [new]))
# case 3. mixing appends with sets
f = Foo()
f.someattr.append(hi)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
f.someattr.append(there)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi, there], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi, there], [], []))
f.someattr = [there]
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], []))
def test_collections_via_backref(self):
class Foo(fixtures.Base):
@@ -941,19 +1030,19 @@ class HistoryTest(TestBase):
f1 = Foo()
b1 = Bar()
- self.assertEquals(attributes.get_history(f1._state, 'bars'), ([], [], []))
- self.assertEquals(attributes.get_history(b1._state, 'foo'), ([], [None], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(b1), 'foo'), ([], [None], []))
#b1.foo = f1
f1.bars.append(b1)
- self.assertEquals(attributes.get_history(f1._state, 'bars'), ([b1], [], []))
- self.assertEquals(attributes.get_history(b1._state, 'foo'), ([f1], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(b1), 'foo'), ([f1], [], []))
b2 = Bar()
f1.bars.append(b2)
- self.assertEquals(attributes.get_history(f1._state, 'bars'), ([b1, b2], [], []))
- self.assertEquals(attributes.get_history(b1._state, 'foo'), ([f1], [], []))
- self.assertEquals(attributes.get_history(b2._state, 'foo'), ([f1], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1, b2], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(b1), 'foo'), ([f1], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(b2), 'foo'), ([f1], [], []))
def test_lazy_backref_collections(self):
class Foo(fixtures.Base):
@@ -978,17 +1067,17 @@ class HistoryTest(TestBase):
f = Foo()
bar4 = Bar()
bar4.foo = f
- self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [bar1, bar2, bar3], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [bar1, bar2, bar3], []))
lazy_load = None
f = Foo()
bar4 = Bar()
bar4.foo = f
- self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [], []))
lazy_load = [bar1, bar2, bar3]
- f._state.expire_attributes(['bars'])
- self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar2, bar3], []))
+ attributes.instance_state(f).expire_attributes(['bars'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([], [bar1, bar2, bar3], []))
def test_collections_via_lazyload(self):
class Foo(fixtures.Base):
@@ -1011,26 +1100,26 @@ class HistoryTest(TestBase):
f = Foo()
f.bars = []
- self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [], [bar1, bar2, bar3]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([], [], [bar1, bar2, bar3]))
f = Foo()
f.bars.append(bar4)
- self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [bar1, bar2, bar3], []) )
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [bar1, bar2, bar3], []) )
f = Foo()
f.bars.remove(bar2)
- self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar3], [bar2]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([], [bar1, bar3], [bar2]))
f.bars.append(bar4)
- self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [bar1, bar3], [bar2]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [bar1, bar3], [bar2]))
f = Foo()
del f.bars[1]
- self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar3], [bar2]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([], [bar1, bar3], [bar2]))
lazy_load = None
f = Foo()
f.bars.append(bar2)
- self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar2], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar2], [], []))
def test_scalar_via_lazyload(self):
class Foo(fixtures.Base):
@@ -1051,24 +1140,24 @@ class HistoryTest(TestBase):
f = Foo()
self.assertEquals(f.bar, "hi")
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([], ["hi"], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([], ["hi"], []))
f = Foo()
f.bar = None
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], [], []))
f = Foo()
f.bar = "there"
- self.assertEquals(attributes.get_history(f._state, 'bar'), (["there"], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), (["there"], [], []))
f.bar = "hi"
- self.assertEquals(attributes.get_history(f._state, 'bar'), (["hi"], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), (["hi"], [], []))
f = Foo()
self.assertEquals(f.bar, "hi")
del f.bar
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([], [], ["hi"]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([], [], ["hi"]))
assert f.bar is None
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], ["hi"]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], [], ["hi"]))
def test_scalar_object_via_lazyload(self):
class Foo(fixtures.Base):
@@ -1092,24 +1181,25 @@ class HistoryTest(TestBase):
# operations
f = Foo()
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([], [bar1], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([], [bar1], []))
f = Foo()
f.bar = None
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], [bar1]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], [], [bar1]))
f = Foo()
f.bar = bar2
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([bar2], [], [bar1]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([bar2], [], [bar1]))
f.bar = bar1
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([], [bar1], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([], [bar1], []))
f = Foo()
self.assertEquals(f.bar, bar1)
del f.bar
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], [bar1]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], [], [bar1]))
assert f.bar is None
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], [bar1]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], [], [bar1]))
+
if __name__ == "__main__":
testenv.main()
diff --git a/test/orm/cascade.py b/test/orm/cascade.py
index 7a68a4d58..4a2dc4419 100644
--- a/test/orm/cascade.py
+++ b/test/orm/cascade.py
@@ -1,8 +1,9 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import *
+from sqlalchemy.orm import attributes, exc as orm_exc
from testlib import *
from testlib import fixtures
@@ -45,7 +46,7 @@ class O2MCascadeTest(fixtures.FixtureTest):
try:
sess.flush()
assert False
- except exceptions.FlushError, e:
+ except orm_exc.FlushError, e:
assert "is an orphan" in str(e)
def test_delete(self):
@@ -571,7 +572,7 @@ class UnsavedOrphansTest(ORMTest):
s.save(a)
try:
s.flush()
- except exceptions.FlushError, e:
+ except orm_exc.FlushError, e:
pass
assert a.address_id is None, "Error: address should not be persistent"
@@ -794,7 +795,7 @@ class DoubleParentOrphanTest(ORMTest):
try:
session.flush()
assert False
- except exceptions.FlushError, e:
+ except orm_exc.FlushError, e:
assert True
class CollectionAssignmentOrphanTest(ORMTest):
@@ -831,7 +832,7 @@ class CollectionAssignmentOrphanTest(ORMTest):
self.assertEquals(sess.query(A).get(a1.id), A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')]))
a1 = sess.query(A).get(a1.id)
- assert not class_mapper(B)._is_orphan(a1.bs[0])
+ assert not class_mapper(B)._is_orphan(attributes.instance_state(a1.bs[0]))
a1.bs[0].foo='b2modified'
a1.bs[1].foo='b3modified'
sess.flush()
diff --git a/test/orm/collection.py b/test/orm/collection.py
index 711dc730b..94e36f366 100644
--- a/test/orm/collection.py
+++ b/test/orm/collection.py
@@ -2,7 +2,7 @@ import testenv; testenv.configure_for_tests()
import sys
from operator import and_
from sqlalchemy import *
-import sqlalchemy.exceptions as exceptions
+import sqlalchemy.exc as sa_exc
from sqlalchemy.orm import create_session, mapper, relation, \
interfaces, attributes
import sqlalchemy.orm.collections as collections
@@ -933,13 +933,13 @@ class CollectionsTest(TestBase):
self._test_adapter(dict, dictable_entity,
to_set=lambda c: set(c.values()))
self.assert_(False)
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class')
try:
self._test_dict(dict)
self.assert_(False)
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class')
def test_dict_subclass(self):
diff --git a/test/orm/compile.py b/test/orm/compile.py
index 31b686062..59d636bae 100644
--- a/test/orm/compile.py
+++ b/test/orm/compile.py
@@ -1,6 +1,6 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import *
from testlib import *
@@ -118,7 +118,7 @@ class CompileTest(TestBase, AssertsExecutionResults):
try:
class_mapper(Product).compile()
assert False
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
assert str(e).index("Error creating backref ") > -1
def testthree(self):
@@ -177,7 +177,7 @@ class CompileTest(TestBase, AssertsExecutionResults):
try:
compile_mappers()
assert False
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
assert str(e).index("Error creating backref") > -1
if __name__ == '__main__':
diff --git a/test/orm/cycles.py b/test/orm/cycles.py
index f956a4529..8b5173d3c 100644
--- a/test/orm/cycles.py
+++ b/test/orm/cycles.py
@@ -173,22 +173,25 @@ class InheritTestOne(TestBase, AssertsExecutionResults):
Column("child2_data", String(50))
)
meta.create_all()
+
def tearDownAll(self):
meta.drop_all()
+
def testmanytooneonly(self):
"""test similar to SelfReferentialTest.testmanytooneonly"""
+
class Parent(object):
- pass
+ pass
mapper(Parent, parent)
class Child1(Parent):
- pass
+ pass
mapper(Child1, child1, inherits=Parent)
class Child2(Parent):
- pass
+ pass
mapper(Child2, child2, properties={
"child1": relation(Child1,
@@ -216,7 +219,9 @@ class InheritTestOne(TestBase, AssertsExecutionResults):
class InheritTestTwo(ORMTest):
"""the fix in BiDirectionalManyToOneTest raised this issue, regarding
the 'circular sort' containing UOWTasks that were still polymorphic, which could
- create duplicate entries in the final sort"""
+ create duplicate entries in the final sort
+
+ """
def define_tables(self, metadata):
global a, b, c
a = Table('a', metadata,
@@ -235,6 +240,7 @@ class InheritTestTwo(ORMTest):
Column('data', String(30)),
Column('aid', Integer, ForeignKey('a.id', use_alter=True, name="foo")),
)
+
def test_flush(self):
class A(object):pass
class B(A):pass
@@ -484,17 +490,19 @@ class OneToManyManyToOneTest(TestBase, AssertsExecutionResults):
def testcycle(self):
"""this test has a peculiar aspect in that it doesnt create as many dependent
- relationships as the other tests, and revealed a small glitch in the circular dependency sorting."""
+ relationships as the other tests, and revealed a small glitch in the circular dependency sorting.
+
+ """
class Person(object):
- pass
+ pass
class Ball(object):
- pass
+ pass
Ball.mapper = mapper(Ball, ball)
Person.mapper = mapper(Person, person, properties= dict(
- balls = relation(Ball.mapper, primaryjoin=ball.c.person_id==person.c.id, remote_side=ball.c.person_id),
- favorateBall = relation(Ball.mapper, primaryjoin=person.c.favorite_ball_id==ball.c.id, remote_side=person.c.favorite_ball_id),
+ balls = relation(Ball.mapper, primaryjoin=ball.c.person_id==person.c.id, remote_side=ball.c.person_id),
+ favorateBall = relation(Ball.mapper, primaryjoin=person.c.favorite_ball_id==ball.c.id, remote_side=ball.c.id),
)
)
@@ -502,10 +510,9 @@ class OneToManyManyToOneTest(TestBase, AssertsExecutionResults):
p = Person()
p.balls.append(b)
sess = create_session()
- sess.save(b)
- sess.save(b)
+ sess.save(p)
sess.flush()
-
+
def testpostupdate_m2o(self):
"""tests a cycle between two rows, with a post_update on the many-to-one"""
class Person(object):
@@ -860,6 +867,7 @@ class SelfReferentialPostUpdateTest2(TestBase, AssertsExecutionResults):
a_table.create()
def tearDownAll(self):
a_table.drop()
+
def testbasic(self):
"""test that post_update remembers to be involved in update operations as well,
since it replaces the normal dependency processing completely [ticket:413]"""
diff --git a/test/orm/deprecations.py b/test/orm/deprecations.py
new file mode 100644
index 000000000..d6caaa196
--- /dev/null
+++ b/test/orm/deprecations.py
@@ -0,0 +1,394 @@
+"""The collection of modern alternatives to deprecated & removed functionality.
+
+Collects specimens of old ORM code and explicitly covers the recommended
+modern (i.e. not deprecated) alternative to them. The tests snippets here can
+be migrated directly to the wiki, docs, etc.
+
+"""
+import testenv; testenv.configure_for_tests()
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+
+
+users, addresses = None, None
+session = None
+
+class Base(object):
+ def __init__(self, **kw):
+ for k, v in kw.iteritems():
+ setattr(self, k, v)
+
+class User(Base): pass
+class Address(Base): pass
+
+
+class QueryAlternativesTest(ORMTest):
+ '''Collects modern idioms for Queries
+
+ The docstring for each test case serves as miniature documentation about
+ the deprecated use case, and the test body illustrates (and covers) the
+ intended replacement code to accomplish the same task.
+
+ Documenting the "old way" including the argument signature helps these
+ cases remain useful to readers even after the deprecated method has been
+ removed from the modern codebase.
+
+ Format:
+
+ def test_deprecated_thing(self):
+ """Query.methodname(old, arg, **signature)
+
+ output = session.query(User).deprecatedmethod(inputs)
+
+ """
+ # 0.4+
+ output = session.query(User).newway(inputs)
+ assert output is correct
+
+ # 0.5+
+ output = session.query(User).evennewerway(inputs)
+ assert output is correct
+
+ '''
+ keep_mappers = True
+ keep_data = True
+
+ def define_tables(self, metadata):
+ global users_table, addresses_table
+ users_table = Table(
+ 'users', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String(64)))
+
+ addresses_table = Table(
+ 'addresses', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('user_id', Integer, ForeignKey('users.id')),
+ Column('email_address', String(128)),
+ Column('purpose', String(16)),
+ Column('bounces', Integer, default=0))
+
+ def setup_mappers(self):
+ mapper(User, users_table, properties=dict(
+ addresses=relation(Address, backref='user'),
+ ))
+ mapper(Address, addresses_table)
+
+ def insert_data(self):
+ user_cols = ('id', 'name')
+ user_rows = ((1, 'jack'), (2, 'ed'), (3, 'fred'), (4, 'chuck'))
+ users_table.insert().execute(
+ [dict(zip(user_cols, row)) for row in user_rows])
+
+ add_cols = ('id', 'user_id', 'email_address', 'purpose', 'bounces')
+ add_rows = (
+ (1, 1, 'jack@jack.home', 'Personal', 0),
+ (2, 1, 'jack@jack.bizz', 'Work', 1),
+ (3, 2, 'ed@foo.bar', 'Personal', 0),
+ (4, 3, 'fred@the.fred', 'Personal', 10))
+
+ addresses_table.insert().execute(
+ [dict(zip(add_cols, row)) for row in add_rows])
+
+ def setUp(self):
+ super(QueryAlternativesTest, self).setUp()
+ global session
+ if session is None:
+ session = create_session()
+
+ def tearDown(self):
+ super(QueryAlternativesTest, self).tearDown()
+ session.clear()
+
+ ######################################################################
+
+ def test_apply_max(self):
+ """Query.apply_max(col)
+
+ max = session.query(Address).apply_max(Address.bounces)
+
+ """
+ # 0.5.0
+ maxes = list(session.query(Address).values(func.max(Address.bounces)))
+ max = maxes[0][0]
+ assert max == 10
+
+ max = session.query(func.max(Address.bounces)).one()[0]
+ assert max == 10
+
+ def test_apply_min(self):
+ """Query.apply_min(col)
+
+ min = session.query(Address).apply_min(Address.bounces)
+
+ """
+ # 0.5.0
+ mins = list(session.query(Address).values(func.min(Address.bounces)))
+ min = mins[0][0]
+ assert min == 0
+
+ min = session.query(func.min(Address.bounces)).one()[0]
+ assert min == 0
+
+ def test_apply_avg(self):
+ """Query.apply_avg(col)
+
+ avg = session.query(Address).apply_avg(Address.bounces)
+
+ """
+ avgs = list(session.query(Address).values(func.avg(Address.bounces)))
+ avg = avgs[0][0]
+ assert avg > 0 and avg < 10
+
+ avg = session.query(func.avg(Address.bounces)).one()[0]
+ assert avg > 0 and avg < 10
+
+ def test_apply_sum(self):
+ """Query.apply_sum(col)
+
+ avg = session.query(Address).apply_avg(Address.bounces)
+
+ """
+ avgs = list(session.query(Address).values(func.avg(Address.bounces)))
+ avg = avgs[0][0]
+ assert avg > 0 and avg < 10
+
+ avg = session.query(func.avg(Address.bounces)).one()[0]
+ assert avg > 0 and avg < 10
+
+ def test_count_by(self):
+ """Query.count_by(*args, **params)
+
+ num = session.query(Address).count_by(purpose='Personal')
+
+ # old-style implicit *_by join
+ num = session.query(User).count_by(purpose='Personal')
+
+ """
+ num = session.query(Address).filter_by(purpose='Personal').count()
+ assert num == 3, num
+
+ num = (session.query(User).join('addresses').
+ filter(Address.purpose=='Personal')).count()
+ assert num == 3, num
+
+ def test_count_whereclause(self):
+ """Query.count(whereclause=None, params=None, **kwargs)
+
+ num = session.query(Address).count(address_table.c.bounces > 1)
+
+ """
+ num = session.query(Address).filter(Address.bounces > 1).count()
+ assert num == 1, num
+
+ def test_execute(self):
+ """Query.execute(clauseelement, params=None, *args, **kwargs)
+
+ users = session.query(User).execute(users_table.select())
+
+ """
+ users = session.query(User).from_statement(users_table.select()).all()
+ assert len(users) == 4
+
+ def test_get_by(self):
+ """Query.get_by(*args, **params)
+
+ user = session.query(User).get_by(name='ed')
+
+ # 0.3-style implicit *_by join
+ user = session.query(User).get_by(email_addresss='fred@the.fred')
+
+ """
+ user = session.query(User).filter_by(name='ed').first()
+ assert user.name == 'ed'
+
+ user = (session.query(User).join('addresses').
+ filter(Address.email_address=='fred@the.fred')).first()
+ assert user.name == 'fred'
+
+ user = session.query(User).filter(
+ User.addresses.any(Address.email_address=='fred@the.fred')).first()
+ assert user.name == 'fred'
+
+ def test_instances_entities(self):
+ """Query.instances(cursor, *mappers_or_columns, **kwargs)
+
+ sel = users_table.join(addresses_table).select(use_labels=True)
+ res = session.query(User).instances(sel.execute(), Address)
+
+ """
+ sel = users_table.join(addresses_table).select(use_labels=True)
+ res = session.query(User, Address).instances(sel.execute())
+
+ assert len(res) == 4
+ cola, colb = res[0]
+ assert isinstance(cola, User) and isinstance(colb, Address)
+
+
+ def test_join_by(self):
+ """Query.join_by(*args, **params)
+
+ TODO
+ """
+
+ def test_join_to(self):
+ """Query.join_to(key)
+
+ TODO
+ """
+
+ def test_join_via(self):
+ """Query.join_via(keys)
+
+ TODO
+ """
+
+ def test_list(self):
+ """Query.list()
+
+ users = session.query(User).list()
+
+ """
+ users = session.query(User).all()
+ assert len(users) == 4
+
+ def test_scalar(self):
+ """Query.scalar()
+
+ user = session.query(User).filter(User.id==1).scalar()
+
+ """
+ user = session.query(User).filter(User.id==1).first()
+ assert user.id==1
+
+ def test_select(self):
+ """Query.select(arg=None, **kwargs)
+
+ users = session.query(User).select(users_table.c.name != None)
+
+ """
+ users = session.query(User).filter(User.name != None).all()
+ assert len(users) == 4
+
+ def test_select_by(self):
+ """Query.select_by(*args, **params)
+
+ users = session.query(User).select_by(name='fred')
+
+ # 0.3 magic join on *_by methods
+ users = session.query(User).select_by(email_address='fred@the.fred')
+
+ """
+ users = session.query(User).filter_by(name='fred').all()
+ assert len(users) == 1
+
+ users = session.query(User).filter(User.name=='fred').all()
+ assert len(users) == 1
+
+ users = (session.query(User).join('addresses').
+ filter_by(email_address='fred@the.fred')).all()
+ assert len(users) == 1
+
+ users = session.query(User).filter(User.addresses.any(
+ Address.email_address == 'fred@the.fred')).all()
+ assert len(users) == 1
+
+ def test_selectfirst(self):
+ """Query.selectfirst(arg=None, **kwargs)
+
+ bounced = session.query(Address).selectfirst(
+ addresses_table.c.bounces > 0)
+
+ """
+ bounced = session.query(Address).filter(Address.bounces > 0).first()
+ assert bounced.bounces > 0
+
+ def test_selectfirst_by(self):
+ """Query.selectfirst_by(*args, **params)
+
+ onebounce = session.query(Address).selectfirst_by(bounces=1)
+
+ # 0.3 magic join on *_by methods
+ onebounce_user = session.query(User).selectfirst_by(bounces=1)
+
+ """
+ onebounce = session.query(Address).filter_by(bounces=1).first()
+ assert onebounce.bounces == 1
+
+ onebounce_user = (session.query(User).join('addresses').
+ filter_by(bounces=1)).first()
+ assert onebounce_user.name == 'jack'
+
+ onebounce_user = (session.query(User).join('addresses').
+ filter(Address.bounces == 1)).first()
+ assert onebounce_user.name == 'jack'
+
+ onebounce_user = session.query(User).filter(User.addresses.any(
+ Address.bounces == 1)).first()
+ assert onebounce_user.name == 'jack'
+
+
+ def test_selectone(self):
+ """Query.selectone(arg=None, **kwargs)
+
+ ed = session.query(User).selectone(users_table.c.name == 'ed')
+
+ """
+ ed = session.query(User).filter(User.name == 'jack').one()
+
+ def test_selectone_by(self):
+ """Query.selectone_by
+
+ ed = session.query(User).selectone_by(name='ed')
+
+ # 0.3 magic join on *_by methods
+ ed = session.query(User).selectone_by(email_address='ed@foo.bar')
+
+ """
+ ed = session.query(User).filter_by(name='jack').one()
+
+ ed = session.query(User).filter(User.name == 'jack').one()
+
+ ed = session.query(User).join('addresses').filter(
+ Address.email_address == 'ed@foo.bar').one()
+
+ ed = session.query(User).filter(User.addresses.any(
+ Address.email_address == 'ed@foo.bar')).one()
+
+ def test_select_statement(self):
+ """Query.select_statement(statement, **params)
+
+ users = session.query(User).select_statement(users_table.select())
+
+ """
+ users = session.query(User).from_statement(users_table.select()).all()
+ assert len(users) == 4
+
+ def test_select_text(self):
+ """Query.select_text(text, **params)
+
+ users = session.query(User).select_text('SELECT * FROM users')
+
+ """
+ users = session.query(User).from_statement('SELECT * FROM users').all()
+ assert len(users) == 4
+
+ def test_select_whereclause(self):
+ """Query.select_whereclause(whereclause=None, params=None, **kwargs)
+
+
+ users = session,query(User).select_whereclause(users.c.name=='ed')
+ users = session.query(User).select_whereclause("name='ed'")
+
+ """
+ users = session.query(User).filter(User.name=='ed').all()
+ assert len(users) == 1 and users[0].name == 'ed'
+
+ users = session.query(User).filter("name='ed'").all()
+ assert len(users) == 1 and users[0].name == 'ed'
+
+
+
+if __name__ == '__main__':
+ testenv.main()
diff --git a/test/orm/dynamic.py b/test/orm/dynamic.py
index c38b27823..0c3f1a95d 100644
--- a/test/orm/dynamic.py
+++ b/test/orm/dynamic.py
@@ -129,7 +129,25 @@ class FlushTest(FixtureTest):
User(name='jack', addresses=[Address(email_address='lala@hoho.com')]),
User(name='ed', addresses=[Address(email_address='foo@bar.com')])
] == sess.query(User).all()
+
+ def test_rollback(self):
+ class Fixture(Base):
+ pass
+ mapper(User, users, properties={
+ 'addresses':dynamic_loader(mapper(Address, addresses))
+ })
+ sess = create_session(autoexpire=False, autocommit=False, autoflush=True)
+ u1 = User(name='jack')
+ u1.addresses.append(Address(email_address='lala@hoho.com'))
+ sess.save(u1)
+ sess.flush()
+ sess.commit()
+ u1.addresses.append(Address(email_address='foo@bar.com'))
+ self.assertEquals(u1.addresses.all(), [Address(email_address='lala@hoho.com'), Address(email_address='foo@bar.com')])
+ sess.rollback()
+ self.assertEquals(u1.addresses.all(), [Address(email_address='lala@hoho.com')])
+
@testing.fails_on('maxdb')
def test_delete_nocascade(self):
mapper(User, users, properties={
diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py
index 418df83dd..94723a20b 100644
--- a/test/orm/eager_relations.py
+++ b/test/orm/eager_relations.py
@@ -6,6 +6,7 @@ from sqlalchemy.orm import *
from testlib import *
from testlib.fixtures import *
from query import QueryTest
+from sqlalchemy.orm import attributes
class EagerTest(FixtureTest):
keep_mappers = False
@@ -31,8 +32,8 @@ class EagerTest(FixtureTest):
sess = create_session()
user = sess.query(User).get(7)
- assert getattr(User, 'addresses').hasparent(user.addresses[0], optimistic=True)
- assert not class_mapper(Address)._is_orphan(user.addresses[0])
+ assert getattr(User, 'addresses').hasparent(attributes.instance_state(user.addresses[0]), optimistic=True)
+ assert not class_mapper(Address)._is_orphan(attributes.instance_state(user.addresses[0]))
def test_orderby(self):
mapper(User, users, properties = {
@@ -129,12 +130,18 @@ class EagerTest(FixtureTest):
})
mapper(User, users)
- assert [Address(id=1, user=User(id=7)), Address(id=4, user=User(id=8)), Address(id=5, user=User(id=9))] == create_session().query(Address).filter(Address.id.in_([1, 4, 5])).all()
-
- assert [Address(id=1, user=User(id=7)), Address(id=4, user=User(id=8)), Address(id=5, user=User(id=9))] == create_session().query(Address).filter(Address.id.in_([1, 4, 5])).limit(3).all()
-
sess = create_session()
- a = sess.query(Address).get(1)
+
+ for q in [
+ sess.query(Address).filter(Address.id.in_([1, 4, 5])),
+ sess.query(Address).filter(Address.id.in_([1, 4, 5])).limit(3)
+ ]:
+ sess.clear()
+ self.assertEquals(q.all(),
+ [Address(id=1, user=User(id=7)), Address(id=4, user=User(id=8)), Address(id=5, user=User(id=9))]
+ )
+
+ a = sess.query(Address).filter(Address.id==1).first()
def go():
assert a.user_id==7
# assert that the eager loader added 'user_id' to the row
@@ -150,12 +157,17 @@ class EagerTest(FixtureTest):
'user_id':deferred(addresses.c.user_id),
})
mapper(User, users, properties={'addresses':relation(Address, lazy=False)})
+
+ for q in [
+ sess.query(User).filter(User.id==7),
+ sess.query(User).filter(User.id==7).limit(1)
+ ]:
+ sess.clear()
+ self.assertEquals(q.all(),
+ [User(id=7, addresses=[Address(id=1)])]
+ )
- assert [User(id=7, addresses=[Address(id=1)])] == create_session().query(User).filter(User.id==7).all()
-
- assert [User(id=7, addresses=[Address(id=1)])] == create_session().query(User).limit(1).filter(User.id==7).all()
-
- sess = create_session()
+ sess.clear()
u = sess.query(User).get(7)
def go():
assert u.addresses[0].user_id==7
@@ -173,9 +185,9 @@ class EagerTest(FixtureTest):
mapper(Dingaling, dingalings, properties={
'address_id':deferred(dingalings.c.address_id)
})
- sess = create_session()
+ sess.clear()
def go():
- u = sess.query(User).limit(1).get(8)
+ u = sess.query(User).get(8)
assert User(id=8, addresses=[Address(id=2, dingalings=[Dingaling(id=1)]), Address(id=3), Address(id=4)]) == u
self.assert_sql_count(testing.db, go, 1)
@@ -192,11 +204,11 @@ class EagerTest(FixtureTest):
self.assert_sql_count(testing.db, go, 1)
def go():
- assert fixtures.item_keyword_result[0:2] == q.join('keywords').filter(keywords.c.name == 'red').all()
+ assert fixtures.item_keyword_result[0:2] == q.join('keywords').filter(Keyword.name == 'red').all()
self.assert_sql_count(testing.db, go, 1)
def go():
- assert fixtures.item_keyword_result[0:2] == q.join('keywords', aliased=True).filter(keywords.c.name == 'red').all()
+ assert fixtures.item_keyword_result[0:2] == q.join('keywords', aliased=True).filter(Keyword.name == 'red').all()
self.assert_sql_count(testing.db, go, 1)
@@ -364,7 +376,7 @@ class EagerTest(FixtureTest):
q = sess.query(User)
def go():
- l = q.filter(s.c.u2_id==User.c.id).distinct().all()
+ l = q.filter(s.c.u2_id==User.id).distinct().all()
assert fixtures.user_address_result == l
self.assert_sql_count(testing.db, go, 1)
@@ -377,7 +389,7 @@ class EagerTest(FixtureTest):
sess = create_session()
q = sess.query(Item)
- l = q.filter((Item.c.description=='item 2') | (Item.c.description=='item 5') | (Item.c.description=='item 3')).\
+ l = q.filter((Item.description=='item 2') | (Item.description=='item 5') | (Item.description=='item 3')).\
order_by(Item.id).limit(2).all()
assert fixtures.item_keyword_result[1:3] == l
@@ -607,7 +619,7 @@ class AddEntityTest(FixtureTest):
)
]
- def test_basic(self):
+ def test_mapper_configured(self):
mapper(User, users, properties={
'addresses':relation(Address, lazy=False),
'orders':relation(Order)
@@ -620,8 +632,9 @@ class AddEntityTest(FixtureTest):
sess = create_session()
+ oalias = aliased(Order)
def go():
- ret = sess.query(User).add_entity(Order).join('orders', aliased=True).order_by(User.id).order_by(Order.id).all()
+ ret = sess.query(User, oalias).join(('orders', oalias)).order_by(User.id, oalias.id).all()
self.assertEquals(ret, self._assert_result())
self.assert_sql_count(testing.db, go, 1)
@@ -638,14 +651,15 @@ class AddEntityTest(FixtureTest):
sess = create_session()
+ oalias = aliased(Order)
def go():
- ret = sess.query(User).options(eagerload('addresses')).add_entity(Order).join('orders', aliased=True).order_by(User.id).order_by(Order.id).all()
+ ret = sess.query(User, oalias).options(eagerload('addresses')).join(('orders', oalias)).order_by(User.id, oalias.id).all()
self.assertEquals(ret, self._assert_result())
self.assert_sql_count(testing.db, go, 6)
sess.clear()
def go():
- ret = sess.query(User).options(eagerload('addresses')).add_entity(Order).options(eagerload('items', Order)).join('orders', aliased=True).order_by(User.id).order_by(Order.id).all()
+ ret = sess.query(User, oalias).options(eagerload('addresses'), eagerload(oalias.items)).join(('orders', oalias)).order_by(User.id, oalias.id).all()
self.assertEquals(ret, self._assert_result())
self.assert_sql_count(testing.db, go, 1)
@@ -933,11 +947,94 @@ class SelfReferentialM2MEagerTest(ORMTest):
sess.flush()
sess.clear()
-# l = sess.query(Widget).filter(Widget.name=='w1').all()
-# print l
assert [Widget(name='w1', children=[Widget(name='w2')])] == sess.query(Widget).filter(Widget.name==u'w1').all()
+class MixedEntitiesTest(FixtureTest, AssertsCompiledSQL):
+ keep_mappers = True
+ keep_data = True
+
+ def setup_mappers(self):
+ mapper(User, users, properties={
+ 'addresses':relation(Address, backref='user'),
+ 'orders':relation(Order, backref='user'), # o2m, m2o
+ })
+ mapper(Address, addresses)
+ mapper(Order, orders, properties={
+ 'items':relation(Item, secondary=order_items, order_by=items.c.id), #m2m
+ })
+ mapper(Item, items, properties={
+ 'keywords':relation(Keyword, secondary=item_keywords) #m2m
+ })
+ mapper(Keyword, keywords)
+
+ def test_two_entities(self):
+ sess = create_session()
+
+ # two FROM clauses
+ def go():
+ self.assertEquals(
+ [
+ (User(id=9, addresses=[Address(id=5)]), Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)])),
+ (User(id=9, addresses=[Address(id=5)]), Order(id=4, items=[Item(id=1), Item(id=5)])),
+ ],
+ sess.query(User, Order).filter(User.id==Order.user_id).options(eagerload(User.addresses), eagerload(Order.items)).filter(User.id==9).all(),
+ )
+ self.assert_sql_count(testing.db, go, 1)
+
+ # one FROM clause
+ def go():
+ self.assertEquals(
+ [
+ (User(id=9, addresses=[Address(id=5)]), Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)])),
+ (User(id=9, addresses=[Address(id=5)]), Order(id=4, items=[Item(id=1), Item(id=5)])),
+ ],
+ sess.query(User, Order).join(User.orders).options(eagerload(User.addresses), eagerload(Order.items)).filter(User.id==9).all(),
+ )
+ self.assert_sql_count(testing.db, go, 1)
+
+ def test_aliased_entity(self):
+ sess = create_session()
+
+ oalias = aliased(Order)
+
+ # two FROM clauses
+ def go():
+ self.assertEquals(
+ [
+ (User(id=9, addresses=[Address(id=5)]), Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)])),
+ (User(id=9, addresses=[Address(id=5)]), Order(id=4, items=[Item(id=1), Item(id=5)])),
+ ],
+ sess.query(User, oalias).filter(User.id==oalias.user_id).options(eagerload(User.addresses), eagerload(oalias.items)).filter(User.id==9).all(),
+ )
+ self.assert_sql_count(testing.db, go, 1)
+
+ # one FROM clause
+ def go():
+ self.assertEquals(
+ [
+ (User(id=9, addresses=[Address(id=5)]), Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)])),
+ (User(id=9, addresses=[Address(id=5)]), Order(id=4, items=[Item(id=1), Item(id=5)])),
+ ],
+ sess.query(User, oalias).join((User.orders, oalias)).options(eagerload(User.addresses), eagerload(oalias.items)).filter(User.id==9).all(),
+ )
+ self.assert_sql_count(testing.db, go, 1)
+
+ from sqlalchemy.engine.default import DefaultDialect
+
+ # improper setup: oalias in the columns clause but join to usual orders alias.
+ # this should create two FROM clauses even though the query has a from_clause set up via the join
+ self.assert_compile(sess.query(User, oalias).join(User.orders).options(eagerload(oalias.items)).with_labels().statement,
+ "SELECT users.id AS users_id, users.name AS users_name, orders_1.id AS orders_1_id, "\
+ "orders_1.user_id AS orders_1_user_id, orders_1.address_id AS orders_1_address_id, "\
+ "orders_1.description AS orders_1_description, orders_1.isopen AS orders_1_isopen, items_1.id AS items_1_id, "\
+ "items_1.description AS items_1_description FROM users JOIN orders ON users.id = orders.user_id, "\
+ "orders AS orders_1 LEFT OUTER JOIN order_items AS order_items_1 ON orders_1.id = order_items_1.order_id "\
+ "LEFT OUTER JOIN items AS items_1 ON items_1.id = order_items_1.item_id ORDER BY users.id, items_1.id",
+ dialect=DefaultDialect()
+ )
+
class CyclicalInheritingEagerTest(ORMTest):
+
def define_tables(self, metadata):
global t1, t2
t1 = Table('t1', metadata,
@@ -1041,22 +1138,14 @@ class SubqueryTest(ORMTest):
session.save(User(name='bar', tags=[Tag(score1=5.0, score2=4.0), Tag(score1=50.0, score2=1.0), Tag(score1=15.0, score2=2.0)]))
session.flush()
session.clear()
+
+ for user in session.query(User).all():
+ self.assertEquals(user.query_score, user.prop_score)
def go():
- for user in session.query(User).all():
- self.assertEquals(user.query_score, user.prop_score)
- self.assert_sql_count(testing.db, go, 1)
-
-
- # fails for non labeled (fixed in 0.5):
- if labeled:
- def go():
- u = session.query(User).filter_by(name='joe').one()
- self.assertEquals(u.query_score, u.prop_score)
- self.assert_sql_count(testing.db, go, 1)
- else:
u = session.query(User).filter_by(name='joe').one()
self.assertEquals(u.query_score, u.prop_score)
+ self.assert_sql_count(testing.db, go, 1)
for t in (tags_table, users_table):
t.delete().execute()
diff --git a/test/orm/entity.py b/test/orm/entity.py
index 760f8fce9..d9c9e4002 100644
--- a/test/orm/entity.py
+++ b/test/orm/entity.py
@@ -1,19 +1,18 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext
from testlib import *
from testlib.tables import *
+from testlib import fixtures
class EntityTest(TestBase, AssertsExecutionResults):
"""tests mappers that are constructed based on "entity names", which allows the same class
to have multiple primary mappers """
- @testing.uses_deprecated('SessionContext')
def setUpAll(self):
global user1, user2, address1, address2, metadata, ctx
metadata = MetaData(testing.db)
- ctx = SessionContext(create_session)
+ ctx = scoped_session(create_session)
user1 = Table('user1', metadata,
Column('user_id', Integer, Sequence('user1_id_seq', optional=True),
@@ -45,28 +44,31 @@ class EntityTest(TestBase, AssertsExecutionResults):
def tearDownAll(self):
metadata.drop_all()
def tearDown(self):
- ctx.current.clear()
+ ctx.clear()
clear_mappers()
for t in metadata.table_iterator(reverse=True):
t.delete().execute()
- @testing.uses_deprecated('SessionContextExt')
def testbasic(self):
"""tests a pair of one-to-many mapper structures, establishing that both
parent and child objects honor the "entity_name" attribute attached to the object
instances."""
- class User(object):pass
- class Address(object):pass
+ class User(object):
+ def __init__(self, **kw):
+ pass
+ class Address(object):
+ def __init__(self, **kw):
+ pass
- a1mapper = mapper(Address, address1, entity_name='address1', extension=ctx.mapper_extension)
- a2mapper = mapper(Address, address2, entity_name='address2', extension=ctx.mapper_extension)
+ a1mapper = mapper(Address, address1, entity_name='address1', extension=ctx.extension)
+ a2mapper = mapper(Address, address2, entity_name='address2', extension=ctx.extension)
u1mapper = mapper(User, user1, entity_name='user1', properties ={
'addresses':relation(a1mapper)
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
u2mapper =mapper(User, user2, entity_name='user2', properties={
'addresses':relation(a2mapper)
- }, extension=ctx.mapper_extension)
-
+ }, extension=ctx.extension)
+
u1 = User(_sa_entity_name='user1')
u1.name = 'this is user 1'
a1 = Address(_sa_entity_name='address1')
@@ -79,22 +81,22 @@ class EntityTest(TestBase, AssertsExecutionResults):
a2.email='a2@foo.com'
u2.addresses.append(a2)
- ctx.current.flush()
+ ctx.flush()
assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)]
assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)]
assert address1.select().execute().fetchall() == [(a1.address_id, u1.user_id, 'a1@foo.com')]
assert address2.select().execute().fetchall() == [(a1.address_id, u2.user_id, 'a2@foo.com')]
- ctx.current.clear()
- u1list = ctx.current.query(User, entity_name='user1').all()
- u2list = ctx.current.query(User, entity_name='user2').all()
+ ctx.clear()
+ u1list = ctx.query(User, entity_name='user1').all()
+ u2list = ctx.query(User, entity_name='user2').all()
assert len(u1list) == len(u2list) == 1
assert u1list[0] is not u2list[0]
assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1
- u1 = ctx.current.query(User, entity_name='user1').first()
- ctx.current.refresh(u1)
- ctx.current.expire(u1)
+ u1 = ctx.query(User, entity_name='user1').first()
+ ctx.refresh(u1)
+ ctx.expire(u1)
def testcascade(self):
@@ -142,18 +144,24 @@ class EntityTest(TestBase, AssertsExecutionResults):
def testpolymorphic(self):
"""tests that entity_name can be used to have two kinds of relations on the same class."""
- class User(object):pass
- class Address1(object):pass
- class Address2(object):pass
+ class User(object):
+ def __init__(self, **kw):
+ pass
+ class Address1(object):
+ def __init__(self, **kw):
+ pass
+ class Address2(object):
+ def __init__(self, **kw):
+ pass
- a1mapper = mapper(Address1, address1, extension=ctx.mapper_extension)
- a2mapper = mapper(Address2, address2, extension=ctx.mapper_extension)
+ a1mapper = mapper(Address1, address1, extension=ctx.extension)
+ a2mapper = mapper(Address2, address2, extension=ctx.extension)
u1mapper = mapper(User, user1, entity_name='user1', properties ={
'addresses':relation(a1mapper)
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
u2mapper =mapper(User, user2, entity_name='user2', properties={
'addresses':relation(a2mapper)
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
u1 = User(_sa_entity_name='user1')
u1.name = 'this is user 1'
@@ -167,15 +175,15 @@ class EntityTest(TestBase, AssertsExecutionResults):
a2.email='a2@foo.com'
u2.addresses.append(a2)
- ctx.current.flush()
+ ctx.flush()
assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)]
assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)]
assert address1.select().execute().fetchall() == [(a1.address_id, u1.user_id, 'a1@foo.com')]
assert address2.select().execute().fetchall() == [(a1.address_id, u2.user_id, 'a2@foo.com')]
- ctx.current.clear()
- u1list = ctx.current.query(User, entity_name='user1').all()
- u2list = ctx.current.query(User, entity_name='user2').all()
+ ctx.clear()
+ u1list = ctx.query(User, entity_name='user1').all()
+ u2list = ctx.query(User, entity_name='user2').all()
assert len(u1list) == len(u2list) == 1
assert u1list[0] is not u2list[0]
assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1
@@ -186,13 +194,15 @@ class EntityTest(TestBase, AssertsExecutionResults):
def testpolymorphic_deferred(self):
"""test that deferred columns load properly using entity names"""
- class User(object):pass
+ class User(object):
+ def __init__(self, **kwargs):
+ pass
u1mapper = mapper(User, user1, entity_name='user1', properties ={
'name':deferred(user1.c.name)
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
u2mapper =mapper(User, user2, entity_name='user2', properties={
'name':deferred(user2.c.name)
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
u1 = User(_sa_entity_name='user1')
u1.name = 'this is user 1'
@@ -200,13 +210,13 @@ class EntityTest(TestBase, AssertsExecutionResults):
u2 = User(_sa_entity_name='user2')
u2.name='this is user 2'
- ctx.current.flush()
+ ctx.flush()
assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)]
assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)]
- ctx.current.clear()
- u1list = ctx.current.query(User, entity_name='user1').all()
- u2list = ctx.current.query(User, entity_name='user2').all()
+ ctx.clear()
+ u1list = ctx.query(User, entity_name='user1').all()
+ u2list = ctx.query(User, entity_name='user2').all()
assert len(u1list) == len(u2list) == 1
assert u1list[0] is not u2list[0]
# the deferred column load requires that setup_loader() check that the correct DeferredColumnLoader
@@ -214,6 +224,49 @@ class EntityTest(TestBase, AssertsExecutionResults):
assert u1list[0].name == 'this is user 1'
assert u2list[0].name == 'this is user 2'
+class SelfReferentialTest(ORMTest):
+ def define_tables(self, metadata):
+ global nodes
+
+ nodes = Table('nodes', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('parent_id', Integer, ForeignKey('nodes.id')),
+ Column('data', String(50)),
+ Column('type', String(50)),
+ )
+
+ # fails inconsistently. entity name needs deterministic
+ # instrumentation.
+ def dont_test_relation(self):
+ class Node(fixtures.Base):
+ pass
+
+ foonodes = nodes.select().where(nodes.c.type=='foo').alias()
+ barnodes = nodes.select().where(nodes.c.type=='bar').alias()
+
+ # TODO: the order of instrumentation here is not deterministic;
+ # therefore the test fails sporadically since "Node.data" references
+ # different mappers at different times
+ m1 = mapper(Node, nodes)
+ m2 = mapper(Node, foonodes, entity_name='foo')
+ m3 = mapper(Node, barnodes, entity_name='bar')
+
+ m1.add_property('foonodes', relation(m2, primaryjoin=nodes.c.id==foonodes.c.parent_id,
+ backref=backref('foo_parent', remote_side=nodes.c.id, primaryjoin=nodes.c.id==foonodes.c.parent_id)))
+ m1.add_property('barnodes', relation(m3, primaryjoin=nodes.c.id==barnodes.c.parent_id,
+ backref=backref('bar_parent', remote_side=nodes.c.id, primaryjoin=nodes.c.id==barnodes.c.parent_id)))
+
+ sess = create_session()
+
+ n1 = Node(data='n1', type='bat')
+ n1.foonodes.append(Node(data='n2', type='foo'))
+ Node(data='n3', type='bar', bar_parent=n1)
+ sess.save(n1)
+ sess.flush()
+ sess.clear()
+
+ self.assertEquals(sess.query(Node, entity_name="bar").one(), Node(data='n3'))
+ self.assertEquals(sess.query(Node).filter(Node.data=='n1').one(), Node(data='n1', foonodes=[Node(data='n2')], barnodes=[Node(data='n3')]))
if __name__ == "__main__":
testenv.main()
diff --git a/test/orm/expire.py b/test/orm/expire.py
index 58c05a382..e99607866 100644
--- a/test/orm/expire.py
+++ b/test/orm/expire.py
@@ -2,8 +2,9 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import *
+from sqlalchemy.orm import attributes, exc as orm_exc
from testlib import *
from testlib.fixtures import *
import gc
@@ -39,12 +40,12 @@ class ExpireTest(FixtureTest):
sess.expire(u)
# object isnt refreshed yet, using dict to bypass trigger
assert u.__dict__.get('name') != 'jack'
- assert 'name' in u._state.expired_attributes
+ assert 'name' in attributes.instance_state(u).expired_attributes
sess.query(User).all()
# test that it refreshed
assert u.__dict__['name'] == 'jack'
- assert 'name' not in u._state.expired_attributes
+ assert 'name' not in attributes.instance_state(u).expired_attributes
def go():
assert u.name == 'jack'
@@ -56,8 +57,49 @@ class ExpireTest(FixtureTest):
u = s.get(User, 7)
s.clear()
- self.assertRaisesMessage(exceptions.InvalidRequestError, r"is not persistent within this Session", lambda: s.expire(u))
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, r"is not persistent within this Session", s.expire, u)
+
+ def test_get_refreshes(self):
+ mapper(User, users)
+ s = create_session()
+ u = s.get(User, 10)
+ s.expire_all()
+ def go():
+ u = s.get(User, 10) # get() refreshes
+ self.assert_sql_count(testing.db, go, 1)
+ def go():
+ self.assertEquals(u.name, 'chuck') # attributes unexpired
+ self.assert_sql_count(testing.db, go, 0)
+ def go():
+ u = s.get(User, 10) # expire flag reset, so not expired
+ self.assert_sql_count(testing.db, go, 0)
+
+ s.expire_all()
+ users.delete().where(User.id==10).execute()
+
+ # object is gone, get() returns None
+ assert u in s
+ assert s.get(User, 10) is None
+ assert u not in s # and expunges
+
+ # add it back
+ s.add(u)
+ # nope, raises ObjectDeletedError
+ self.assertRaises(orm_exc.ObjectDeletedError, getattr, u, 'name')
+
+ def test_refresh_cancels_expire(self):
+ mapper(User, users)
+ s = create_session()
+ u = s.get(User, 7)
+ s.expire(u)
+ s.refresh(u)
+
+ def go():
+ u = s.get(User, 7)
+ self.assertEquals(u.name, 'jack')
+ self.assert_sql_count(testing.db, go, 0)
+
def test_expire_doesntload_on_set(self):
mapper(User, users)
@@ -79,18 +121,16 @@ class ExpireTest(FixtureTest):
sess.expire(u, attribute_names=['name'])
sess.expunge(u)
- try:
- u.name
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Instance <class 'testlib.fixtures.User'> is not bound to a Session, and no contextual session is established; attribute refresh operation cannot proceed"
+ self.assertRaises(sa_exc.UnboundExecutionError, getattr, u, 'name')
- def test_pending_doesnt_raise(self):
+ def test_pending_raises(self):
+ # this was the opposite in 0.4, but the reasoning there seemed off.
+ # expiring a pending instance makes no sense, so should raise
mapper(User, users)
sess = create_session()
u = User(id=15)
sess.save(u)
- sess.expire(u, ['name'])
- assert u.name is None
+ self.assertRaises(sa_exc.InvalidRequestError, sess.expire, u, ['name'])
def test_no_instance_key(self):
# this tests an artificial condition such that
@@ -103,11 +143,11 @@ class ExpireTest(FixtureTest):
sess.expire(u, attribute_names=['name'])
sess.expunge(u)
- del u._instance_key
+ attributes.instance_state(u).key = None
assert 'name' not in u.__dict__
sess.save(u)
assert u.name == 'jack'
-
+
def test_expire_preserves_changes(self):
"""test that the expire load operation doesn't revert post-expire changes"""
@@ -163,7 +203,7 @@ class ExpireTest(FixtureTest):
orders.update(id=3).execute(description='order 3 modified')
assert o.isopen == 1
- assert o._state.dict['description'] == 'order 3 modified'
+ assert attributes.instance_state(o).dict['description'] == 'order 3 modified'
def go():
sess.flush()
self.assert_sql_count(testing.db, go, 0)
@@ -180,7 +220,7 @@ class ExpireTest(FixtureTest):
u.addresses[0].email_address = 'someotheraddress'
s.expire(u)
u.name
- print u._state.dict
+ print attributes.instance_state(u).dict
assert u.addresses[0].email_address == 'ed@wood.com'
def test_expired_lazy(self):
@@ -307,28 +347,28 @@ class ExpireTest(FixtureTest):
sess.expire(o, attribute_names=['description'])
assert 'id' in o.__dict__
assert 'description' not in o.__dict__
- assert o._state.dict['isopen'] == 1
+ assert attributes.instance_state(o).dict['isopen'] == 1
orders.update(orders.c.id==3).execute(description='order 3 modified')
def go():
assert o.description == 'order 3 modified'
self.assert_sql_count(testing.db, go, 1)
- assert o._state.dict['description'] == 'order 3 modified'
+ assert attributes.instance_state(o).dict['description'] == 'order 3 modified'
o.isopen = 5
sess.expire(o, attribute_names=['description'])
assert 'id' in o.__dict__
assert 'description' not in o.__dict__
assert o.__dict__['isopen'] == 5
- assert o._state.committed_state['isopen'] == 1
+ assert attributes.instance_state(o).committed_state['isopen'] == 1
def go():
assert o.description == 'order 3 modified'
self.assert_sql_count(testing.db, go, 1)
assert o.__dict__['isopen'] == 5
- assert o._state.dict['description'] == 'order 3 modified'
- assert o._state.committed_state['isopen'] == 1
+ assert attributes.instance_state(o).dict['description'] == 'order 3 modified'
+ assert attributes.instance_state(o).committed_state['isopen'] == 1
sess.flush()
@@ -578,44 +618,8 @@ class PolymorphicExpireTest(ORMTest):
{'person_id':3, 'status':'old engineer'},
)
- def test_poly_select(self):
- mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person')
- mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer')
-
- sess = create_session()
- [p1, e1, e2] = sess.query(Person).order_by(people.c.person_id).all()
-
- sess.expire(p1)
- sess.expire(e1, ['status'])
- sess.expire(e2)
-
- for p in [p1, e2]:
- assert 'name' not in p.__dict__
-
- assert 'name' in e1.__dict__
- assert 'status' not in e2.__dict__
- assert 'status' not in e1.__dict__
-
- e1.name = 'new engineer name'
-
- def go():
- sess.query(Person).all()
- self.assert_sql_count(testing.db, go, 3)
-
- for p in [p1, e1, e2]:
- assert 'name' in p.__dict__
-
- assert 'status' in e2.__dict__
- assert 'status' in e1.__dict__
- def go():
- assert e1.name == 'new engineer name'
- assert e2.name == 'engineer2'
- assert e1.status == 'new engineer'
- self.assert_sql_count(testing.db, go, 0)
- self.assertEquals(Engineer.name.get_history(e1), (['new engineer name'], [], ['engineer1']))
-
def test_poly_deferred(self):
- mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person', polymorphic_fetch='deferred')
+ mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person')
mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer')
sess = create_session()
@@ -700,7 +704,7 @@ class RefreshTest(FixtureTest):
s = create_session()
u = s.get(User, 7)
s.clear()
- self.assertRaisesMessage(exceptions.InvalidRequestError, r"is not persistent within this Session", lambda: s.refresh(u))
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, r"is not persistent within this Session", lambda: s.refresh(u))
def test_refresh_expired(self):
mapper(User, users)
diff --git a/test/orm/extendedattr.py b/test/orm/extendedattr.py
new file mode 100644
index 000000000..a5c2c4ace
--- /dev/null
+++ b/test/orm/extendedattr.py
@@ -0,0 +1,303 @@
+import testenv; testenv.configure_for_tests()
+import pickle
+from sqlalchemy import util
+import sqlalchemy.orm.attributes as attributes
+from sqlalchemy.orm.collections import collection
+from sqlalchemy.orm.attributes import set_attribute, get_attribute, del_attribute, is_instrumented
+from sqlalchemy.orm import clear_mappers
+from sqlalchemy.orm import InstrumentationManager
+
+from testlib import *
+
+class MyTypesManager(InstrumentationManager):
+
+ def instrument_attribute(self, class_, key, attr):
+ pass
+
+ def install_descriptor(self, class_, key, attr):
+ pass
+
+ def uninstall_descriptor(self, class_, key):
+ pass
+
+ def instrument_collection_class(self, class_, key, collection_class):
+ return MyListLike
+
+ def get_instance_dict(self, class_, instance):
+ return instance._goofy_dict
+
+ def initialize_instance_dict(self, class_, instance):
+ instance.__dict__['_goofy_dict'] = {}
+
+ def install_state(self, class_, instance, state):
+ instance.__dict__['_my_state'] = state
+
+ def state_getter(self, class_):
+ return lambda instance: instance.__dict__['_my_state']
+
+class MyListLike(list):
+ # add @appender, @remover decorators as needed
+ _sa_iterator = list.__iter__
+ def _sa_appender(self, item, _sa_initiator=None):
+ if _sa_initiator is not False:
+ self._sa_adapter.fire_append_event(item, _sa_initiator)
+ list.append(self, item)
+ append = _sa_appender
+ def _sa_remover(self, item, _sa_initiator=None):
+ self._sa_adapter.fire_pre_remove_event(_sa_initiator)
+ if _sa_initiator is not False:
+ self._sa_adapter.fire_remove_event(item, _sa_initiator)
+ list.remove(self, item)
+ remove = _sa_remover
+
+class MyBaseClass(object):
+ __sa_instrumentation_manager__ = InstrumentationManager
+
+class MyClass(object):
+
+ # This proves that a staticmethod will work here; don't
+ # flatten this back to a class assignment!
+ def __sa_instrumentation_manager__(cls):
+ return MyTypesManager(cls)
+
+ __sa_instrumentation_manager__ = staticmethod(__sa_instrumentation_manager__)
+
+ # This proves SA can handle a class with non-string dict keys
+ locals()[42] = 99 # Don't remove this line!
+
+ def __init__(self, **kwargs):
+ for k in kwargs:
+ setattr(self, k, kwargs[k])
+
+ def __getattr__(self, key):
+ if is_instrumented(self, key):
+ return get_attribute(self, key)
+ else:
+ try:
+ return self._goofy_dict[key]
+ except KeyError:
+ raise AttributeError(key)
+
+ def __setattr__(self, key, value):
+ if is_instrumented(self, key):
+ set_attribute(self, key, value)
+ else:
+ self._goofy_dict[key] = value
+
+ def __hasattr__(self, key):
+ if is_instrumented(self, key):
+ return True
+ else:
+ return key in self._goofy_dict
+
+ def __delattr__(self, key):
+ if is_instrumented(self, key):
+ del_attribute(self, key)
+ else:
+ del self._goofy_dict[key]
+
+class UserDefinedExtensionTest(TestBase):
+ def tearDownAll(self):
+ clear_mappers()
+ attributes._install_lookup_strategy(util.symbol('native'))
+
+ def test_basic(self):
+ for base in (object, MyBaseClass, MyClass):
+ class User(base):
+ pass
+
+ attributes.register_class(User)
+ attributes.register_attribute(User, 'user_id', uselist = False, useobject=False)
+ attributes.register_attribute(User, 'user_name', uselist = False, useobject=False)
+ attributes.register_attribute(User, 'email_address', uselist = False, useobject=False)
+
+ u = User()
+ u.user_id = 7
+ u.user_name = 'john'
+ u.email_address = 'lala@123.com'
+
+ self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
+ attributes.instance_state(u).commit_all()
+ self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
+
+ u.user_name = 'heythere'
+ u.email_address = 'foo@bar.com'
+ self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.email_address == 'foo@bar.com')
+
+ def test_deferred(self):
+ for base in (object, MyBaseClass, MyClass):
+ class Foo(base):pass
+
+ data = {'a':'this is a', 'b':12}
+ def loader(state, keys):
+ for k in keys:
+ state.dict[k] = data[k]
+ return attributes.ATTR_WAS_SET
+
+ attributes.register_class(Foo)
+ manager = attributes.manager_of_class(Foo)
+ manager.deferred_scalar_loader = loader
+ attributes.register_attribute(Foo, 'a', uselist=False, useobject=False)
+ attributes.register_attribute(Foo, 'b', uselist=False, useobject=False)
+
+ assert Foo in attributes.instrumentation_registry.state_finders
+ f = Foo()
+ attributes.instance_state(f).expire_attributes(None)
+ self.assertEquals(f.a, "this is a")
+ self.assertEquals(f.b, 12)
+
+ f.a = "this is some new a"
+ attributes.instance_state(f).expire_attributes(None)
+ self.assertEquals(f.a, "this is a")
+ self.assertEquals(f.b, 12)
+
+ attributes.instance_state(f).expire_attributes(None)
+ f.a = "this is another new a"
+ self.assertEquals(f.a, "this is another new a")
+ self.assertEquals(f.b, 12)
+
+ attributes.instance_state(f).expire_attributes(None)
+ self.assertEquals(f.a, "this is a")
+ self.assertEquals(f.b, 12)
+
+ del f.a
+ self.assertEquals(f.a, None)
+ self.assertEquals(f.b, 12)
+
+ attributes.instance_state(f).commit_all()
+ self.assertEquals(f.a, None)
+ self.assertEquals(f.b, 12)
+
+ def test_inheritance(self):
+ """tests that attributes are polymorphic"""
+
+ for base in (object, MyBaseClass, MyClass):
+ class Foo(base):pass
+ class Bar(Foo):pass
+
+ attributes.register_class(Foo)
+ attributes.register_class(Bar)
+
+ def func1():
+ print "func1"
+ return "this is the foo attr"
+ def func2():
+ print "func2"
+ return "this is the bar attr"
+ def func3():
+ print "func3"
+ return "this is the shared attr"
+ attributes.register_attribute(Foo, 'element', uselist=False, callable_=lambda o:func1, useobject=True)
+ attributes.register_attribute(Foo, 'element2', uselist=False, callable_=lambda o:func3, useobject=True)
+ attributes.register_attribute(Bar, 'element', uselist=False, callable_=lambda o:func2, useobject=True)
+
+ x = Foo()
+ y = Bar()
+ assert x.element == 'this is the foo attr'
+ assert y.element == 'this is the bar attr', y.element
+ assert x.element2 == 'this is the shared attr'
+ assert y.element2 == 'this is the shared attr'
+
+ def test_collection_with_backref(self):
+ for base in (object, MyBaseClass, MyClass):
+ class Post(base):pass
+ class Blog(base):pass
+
+ attributes.register_class(Post)
+ attributes.register_class(Blog)
+ attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True)
+ attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True)
+ b = Blog()
+ (p1, p2, p3) = (Post(), Post(), Post())
+ b.posts.append(p1)
+ b.posts.append(p2)
+ b.posts.append(p3)
+ self.assert_(b.posts == [p1, p2, p3])
+ self.assert_(p2.blog is b)
+
+ p3.blog = None
+ self.assert_(b.posts == [p1, p2])
+ p4 = Post()
+ p4.blog = b
+ self.assert_(b.posts == [p1, p2, p4])
+
+ p4.blog = b
+ p4.blog = b
+ self.assert_(b.posts == [p1, p2, p4])
+
+ # assert no failure removing None
+ p5 = Post()
+ p5.blog = None
+ del p5.blog
+
+ def test_history(self):
+ for base in (object, MyBaseClass, MyClass):
+ class Foo(base):
+ pass
+ class Bar(base):
+ pass
+
+ attributes.register_class(Foo)
+ attributes.register_class(Bar)
+ attributes.register_attribute(Foo, "name", uselist=False, useobject=False)
+ attributes.register_attribute(Foo, "bars", uselist=True, trackparent=True, useobject=True)
+ attributes.register_attribute(Bar, "name", uselist=False, useobject=False)
+
+
+ f1 = Foo()
+ f1.name = 'f1'
+
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1'], [], []))
+
+ b1 = Bar()
+ b1.name = 'b1'
+ f1.bars.append(b1)
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], []))
+
+ attributes.instance_state(f1).commit_all()
+ attributes.instance_state(b1).commit_all()
+
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), ([], ['f1'], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([], [b1], []))
+
+ f1.name = 'f1mod'
+ b2 = Bar()
+ b2.name = 'b2'
+ f1.bars.append(b2)
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1mod'], [], ['f1']))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [b1], []))
+ f1.bars.remove(b1)
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [], [b1]))
+
+ def test_null_instrumentation(self):
+ class Foo(MyBaseClass):
+ pass
+ attributes.register_class(Foo)
+ attributes.register_attribute(Foo, "name", uselist=False, useobject=False)
+ attributes.register_attribute(Foo, "bars", uselist=True, trackparent=True, useobject=True)
+
+ assert Foo.name == attributes.manager_of_class(Foo).get_inst('name')
+ assert Foo.bars == attributes.manager_of_class(Foo).get_inst('bars')
+
+ def test_alternate_finders(self):
+ """Ensure the generic finder front-end deals with edge cases."""
+
+ class Unknown(object): pass
+ class Known(MyBaseClass): pass
+
+ attributes.register_class(Known)
+ k, u = Known(), Unknown()
+
+ assert attributes.manager_of_class(Unknown) is None
+ assert attributes.manager_of_class(Known) is not None
+ assert attributes.manager_of_class(None) is None
+
+ assert attributes.instance_state(k) is not None
+ self.assertRaises((AttributeError, KeyError),
+ attributes.instance_state, u)
+ self.assertRaises((AttributeError, KeyError),
+ attributes.instance_state, None)
+
+
+if __name__ == '__main__':
+ testing.main()
diff --git a/test/orm/generative.py b/test/orm/generative.py
index aced8f626..88793f743 100644
--- a/test/orm/generative.py
+++ b/test/orm/generative.py
@@ -1,7 +1,6 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy import exceptions
from testlib import *
import testlib.tables as tables
@@ -35,8 +34,8 @@ class GenerativeQueryTest(TestBase):
def test_selectby(self):
res = create_session(bind=testing.db).query(Foo).filter_by(range=5)
- assert res.order_by([Foo.c.bar])[0].bar == 5
- assert res.order_by([desc(Foo.c.bar)])[0].bar == 95
+ assert res.order_by([Foo.bar])[0].bar == 5
+ assert res.order_by([desc(Foo.bar)])[0].bar == 95
@testing.unsupported('mssql')
@testing.fails_on('maxdb')
@@ -60,8 +59,8 @@ class GenerativeQueryTest(TestBase):
assert query.count() == 100
assert query.filter(foo.c.bar<30).min(foo.c.bar) == 0
assert query.filter(foo.c.bar<30).max(foo.c.bar) == 29
- assert query.filter(foo.c.bar<30).apply_max(foo.c.bar).first() == 29
- assert query.filter(foo.c.bar<30).apply_max(foo.c.bar).one() == 29
+ assert query.filter(foo.c.bar<30).values(func.max(foo.c.bar)).next()[0] == 29
+ assert query.filter(foo.c.bar<30).values(func.max(foo.c.bar)).next()[0] == 29
def test_aggregate_1(self):
if (testing.against('mysql') and
@@ -77,22 +76,20 @@ class GenerativeQueryTest(TestBase):
avg = query.filter(foo.c.bar < 30).avg(foo.c.bar)
assert round(avg, 1) == 14.5
- @testing.fails_on('firebird', 'mssql')
- @testing.uses_deprecated('Call to deprecated function apply_avg')
def test_aggregate_3(self):
query = create_session(bind=testing.db).query(Foo)
- avg_f = query.filter(foo.c.bar<30).apply_avg(foo.c.bar).first()
+ avg_f = query.filter(foo.c.bar<30).values(func.avg(foo.c.bar)).next()[0]
assert round(avg_f, 1) == 14.5
- avg_o = query.filter(foo.c.bar<30).apply_avg(foo.c.bar).one()
+ avg_o = query.filter(foo.c.bar<30).values(func.avg(foo.c.bar)).next()[0]
assert round(avg_o, 1) == 14.5
def test_filter(self):
query = create_session(bind=testing.db).query(Foo)
assert query.count() == 100
- assert query.filter(Foo.c.bar < 30).count() == 30
- res2 = query.filter(Foo.c.bar < 30).filter(Foo.c.bar > 10)
+ assert query.filter(Foo.bar < 30).count() == 30
+ res2 = query.filter(Foo.bar < 30).filter(Foo.bar > 10)
assert res2.count() == 19
def test_options(self):
@@ -105,12 +102,12 @@ class GenerativeQueryTest(TestBase):
def test_order_by(self):
query = create_session(bind=testing.db).query(Foo)
- assert query.order_by([Foo.c.bar])[0].bar == 0
- assert query.order_by([desc(Foo.c.bar)])[0].bar == 99
+ assert query.order_by([Foo.bar])[0].bar == 0
+ assert query.order_by([desc(Foo.bar)])[0].bar == 99
def test_offset(self):
query = create_session(bind=testing.db).query(Foo)
- assert list(query.order_by([Foo.c.bar]).offset(10))[0].bar == 10
+ assert list(query.order_by([Foo.bar]).offset(10))[0].bar == 10
def test_offset(self):
query = create_session(bind=testing.db).query(Foo)
@@ -168,7 +165,7 @@ class RelationsTest(TestBase, AssertsExecutionResults):
})
session = create_session(bind=testing.db)
query = session.query(tables.User)
- x = query.join(['orders', 'items']).filter(tables.Item.c.item_id==2)
+ x = query.join(['orders', 'items']).filter(tables.Item.item_id==2)
print x.compile()
self.assert_result(list(x), tables.User, tables.user_result[2])
def test_outerjointo(self):
@@ -180,7 +177,7 @@ class RelationsTest(TestBase, AssertsExecutionResults):
})
session = create_session(bind=testing.db)
query = session.query(tables.User)
- x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
+ x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.order_id==None,tables.Item.item_id==2))
print x.compile()
self.assert_result(list(x), tables.User, *tables.user_result[1:3])
def test_outerjointo_count(self):
@@ -192,7 +189,7 @@ class RelationsTest(TestBase, AssertsExecutionResults):
})
session = create_session(bind=testing.db)
query = session.query(tables.User)
- x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count()
+ x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.order_id==None,tables.Item.item_id==2)).count()
assert x==2
def test_from(self):
mapper(tables.User, tables.users, properties={
@@ -203,7 +200,7 @@ class RelationsTest(TestBase, AssertsExecutionResults):
session = create_session(bind=testing.db)
query = session.query(tables.User)
x = query.select_from(tables.users.outerjoin(tables.orders).outerjoin(tables.orderitems)).\
- filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
+ filter(or_(tables.Order.order_id==None,tables.Item.item_id==2))
print x.compile()
self.assert_result(list(x), tables.User, *tables.user_result[1:3])
@@ -238,27 +235,6 @@ class CaseSensitiveTest(TestBase):
res = q.filter(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1)).distinct()
self.assertEqual(res.count(), 1)
-class SelfRefTest(ORMTest):
- def define_tables(self, metadata):
- global t1
- t1 = Table('t1', metadata,
- Column('id', Integer, primary_key=True),
- Column('parent_id', Integer, ForeignKey('t1.id'))
- )
- def test_noautojoin(self):
- class T(object):pass
- mapper(T, t1, properties={'children':relation(T)})
- sess = create_session(bind=testing.db)
- def go():
- sess.query(T).join('children')
- self.assertRaisesMessage(exceptions.InvalidRequestError,
- "Self-referential query on 'T\.children \(T\)' property requires aliased=True argument.", go)
-
- def go():
- sess.query(T).join(['children']).select_by(id=7)
- self.assertRaisesMessage(exceptions.InvalidRequestError,
- "Self-referential query on 'T\.children \(T\)' property requires aliased=True argument.", go)
-
if __name__ == "__main__":
diff --git a/test/orm/inheritance/abc_inheritance.py b/test/orm/inheritance/abc_inheritance.py
index 5f7a10756..e6977506a 100644
--- a/test/orm/inheritance/abc_inheritance.py
+++ b/test/orm/inheritance/abc_inheritance.py
@@ -1,7 +1,7 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy.orm.sync import ONETOMANY, MANYTOONE
+from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE
from testlib import *
diff --git a/test/orm/inheritance/abc_polymorphic.py b/test/orm/inheritance/abc_polymorphic.py
index 076c7b76b..367c2e73c 100644
--- a/test/orm/inheritance/abc_polymorphic.py
+++ b/test/orm/inheritance/abc_polymorphic.py
@@ -1,6 +1,6 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions, util
+from sqlalchemy import util
from sqlalchemy.orm import *
from testlib import *
from testlib import fixtures
@@ -32,8 +32,8 @@ class ABCTest(ORMTest):
else:
abc = bc = None
- mapper(A, a, select_table=abc, polymorphic_on=a.c.type, polymorphic_identity='a', polymorphic_fetch=fetchtype)
- mapper(B, b, select_table=bc, inherits=A, polymorphic_identity='b', polymorphic_fetch=fetchtype)
+ mapper(A, a, select_table=abc, polymorphic_on=a.c.type, polymorphic_identity='a')
+ mapper(B, b, select_table=bc, inherits=A, polymorphic_identity='b')
mapper(C, c, inherits=B, polymorphic_identity='c')
a1 = A(adata='a1')
@@ -82,8 +82,7 @@ class ABCTest(ORMTest):
return test_roundtrip
test_union = make_test('union')
- test_select = make_test('select')
- test_deferred = make_test('deferred')
+ test_none = make_test('none')
if __name__ == '__main__':
diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py
index 8a0b6f30a..91e7b3b7f 100644
--- a/test/orm/inheritance/basic.py
+++ b/test/orm/inheritance/basic.py
@@ -1,7 +1,8 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions, util
+from sqlalchemy import exc as sa_exc, util
from sqlalchemy.orm import *
+from sqlalchemy.orm import exc as orm_exc
from testlib import *
from testlib import fixtures
@@ -302,7 +303,7 @@ class ConstructionTest(ORMTest):
'content_type':relation(content_types)
}, polymorphic_identity='contents')
assert False
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
assert str(e) == "Mapper 'Mapper|Content|content' specifies a polymorphic_identity of 'contents', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument"
def testbackref(self):
@@ -397,7 +398,7 @@ class FlushTest(ORMTest):
class Admin(User):pass
role_mapper = mapper(Role, roles)
user_mapper = mapper(User, users, properties = {
- 'roles' : relation(Role, secondary=user_roles, lazy=False, private=False)
+ 'roles' : relation(Role, secondary=user_roles, lazy=False)
}
)
admin_mapper = mapper(Admin, admins, inherits=user_mapper)
@@ -432,7 +433,7 @@ class FlushTest(ORMTest):
role_mapper = mapper(Role, roles)
user_mapper = mapper(User, users, properties = {
- 'roles' : relation(Role, secondary=user_roles, lazy=False, private=False)
+ 'roles' : relation(Role, secondary=user_roles, lazy=False)
}
)
@@ -507,13 +508,13 @@ class VersioningTest(ORMTest):
try:
sess2.query(Base).with_lockmode('read').get(s1.id)
assert False
- except exceptions.ConcurrentModificationError, e:
+ except orm_exc.ConcurrentModificationError, e:
assert True
try:
sess2.flush()
assert False
- except exceptions.ConcurrentModificationError, e:
+ except orm_exc.ConcurrentModificationError, e:
assert True
sess2.refresh(s2)
@@ -553,7 +554,7 @@ class VersioningTest(ORMTest):
s1.subdata = 'some new subdata'
sess.flush()
assert False
- except exceptions.ConcurrentModificationError, e:
+ except orm_exc.ConcurrentModificationError, e:
assert True
@@ -608,7 +609,7 @@ class DistinctPKTest(ORMTest):
mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id, employee_table.c.id])
self._do_test(True)
assert False
- except exceptions.SAWarning, e:
+ except sa_exc.SAWarning, e:
assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. Use explicit properties to give each column its own mapped attribute name.", str(e)
def test_explicit_pk(self):
diff --git a/test/orm/inheritance/concrete.py b/test/orm/inheritance/concrete.py
index 29fa1df60..ffc95ac05 100644
--- a/test/orm/inheritance/concrete.py
+++ b/test/orm/inheritance/concrete.py
@@ -74,6 +74,10 @@ class ConcreteTest(ORMTest):
assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"])
assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Kurt knows how to hack"])
+ manager = session.query(Manager).one()
+ session.expire(manager, ['manager_data'])
+ self.assertEquals(manager.manager_data, "knows how to manage things")
+
def test_multi_level(self):
class Employee(object):
def __init__(self, name):
diff --git a/test/orm/inheritance/poly_linked_list.py b/test/orm/inheritance/poly_linked_list.py
index b2dd6c658..e9e5e1ef6 100644
--- a/test/orm/inheritance/poly_linked_list.py
+++ b/test/orm/inheritance/poly_linked_list.py
@@ -166,7 +166,7 @@ class PolymorphicCircularTest(ORMTest):
# clear and query forwards
sess.clear()
- node = sess.query(Table1).filter(Table1.c.id==t.id).first()
+ node = sess.query(Table1).filter(Table1.id==t.id).first()
assertlist = []
while (node):
assertlist.append(node)
@@ -178,7 +178,7 @@ class PolymorphicCircularTest(ORMTest):
# clear and query backwards
sess.clear()
- node = sess.query(Table1).filter(Table1.c.id==obj.id).first()
+ node = sess.query(Table1).filter(Table1.id==obj.id).first()
assertlist = []
while (node):
assertlist.insert(0, node)
@@ -189,9 +189,6 @@ class PolymorphicCircularTest(ORMTest):
backwards = repr(assertlist)
# everything should match !
- print "ORIGNAL", original
- print "BACKWARDS",backwards
- print "FORWARDS", forwards
assert original == forwards == backwards
if __name__ == '__main__':
diff --git a/test/orm/inheritance/polymorph.py b/test/orm/inheritance/polymorph.py
index 544252024..141aedcac 100644
--- a/test/orm/inheritance/polymorph.py
+++ b/test/orm/inheritance/polymorph.py
@@ -4,7 +4,7 @@ import testenv; testenv.configure_for_tests()
import sets
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy import exceptions
+from sqlalchemy.orm import exc as orm_exc
from testlib import *
from testlib import fixtures
@@ -122,7 +122,7 @@ class RelationToSubclassTest(PolymorphTest):
class RoundTripTest(PolymorphTest):
pass
-def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_colprop=False, use_literal_join=False, polymorphic_fetch=None, use_outer_joins=False):
+def generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic):
"""generates a round trip test.
include_base - whether or not to include the base 'person' type in the union.
@@ -131,62 +131,52 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
use_literal_join - primary join condition is explicitly specified
"""
def test_roundtrip(self):
- # create a union that represents both types of joins.
- if not polymorphic_fetch == 'union':
- person_join = None
- manager_join = None
- elif include_base:
- if use_outer_joins:
- person_join = people.outerjoin(engineers).outerjoin(managers).outerjoin(boss)
- manager_join = people.join(managers).outerjoin(boss)
- else:
+ if with_polymorphic == 'unions':
+ if include_base:
person_join = polymorphic_union(
{
'engineer':people.join(engineers),
'manager':people.join(managers),
'person':people.select(people.c.type=='person'),
}, None, 'pjoin')
-
- manager_join = people.join(managers).outerjoin(boss)
- else:
- if use_outer_joins:
- person_join = people.outerjoin(engineers).outerjoin(managers).outerjoin(boss)
- manager_join = people.join(managers).outerjoin(boss)
else:
person_join = polymorphic_union(
{
'engineer':people.join(engineers),
'manager':people.join(managers),
}, None, 'pjoin')
- manager_join = people.join(managers).outerjoin(boss)
+
+ manager_join = people.join(managers).outerjoin(boss)
+ person_with_polymorphic = ['*', person_join]
+ manager_with_polymorphic = ['*', manager_join]
+ elif with_polymorphic == 'joins':
+ person_join = people.outerjoin(engineers).outerjoin(managers).outerjoin(boss)
+ manager_join = people.join(managers).outerjoin(boss)
+ person_with_polymorphic = ['*', person_join]
+ manager_with_polymorphic = ['*', manager_join]
+ elif with_polymorphic == 'auto':
+ person_with_polymorphic = '*'
+ manager_with_polymorphic = '*'
+ else:
+ person_with_polymorphic = None
+ manager_with_polymorphic = None
if redefine_colprop:
- person_mapper = mapper(Person, people, select_table=person_join, polymorphic_fetch=polymorphic_fetch, polymorphic_on=people.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name})
+ person_mapper = mapper(Person, people, with_polymorphic=person_with_polymorphic, polymorphic_on=people.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name})
else:
- person_mapper = mapper(Person, people, select_table=person_join, polymorphic_fetch=polymorphic_fetch, polymorphic_on=people.c.type, polymorphic_identity='person')
+ person_mapper = mapper(Person, people, with_polymorphic=person_with_polymorphic, polymorphic_on=people.c.type, polymorphic_identity='person')
mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer')
- mapper(Manager, managers, inherits=person_mapper, select_table=manager_join, polymorphic_identity='manager')
+ mapper(Manager, managers, inherits=person_mapper, with_polymorphic=manager_with_polymorphic, polymorphic_identity='manager')
mapper(Boss, boss, inherits=Manager, polymorphic_identity='boss')
- if use_literal_join:
- mapper(Company, companies, properties={
- 'employees': relation(Person, lazy=lazy_relation,
- primaryjoin=(people.c.company_id ==
- companies.c.company_id),
- cascade="all,delete-orphan",
- backref="company",
- order_by=people.c.person_id
- )
- })
- else:
- mapper(Company, companies, properties={
- 'employees': relation(Person, lazy=lazy_relation,
- cascade="all, delete-orphan",
- backref="company", order_by=people.c.person_id
- )
- })
+ mapper(Company, companies, properties={
+ 'employees': relation(Person, lazy=lazy_relation,
+ cascade="all, delete-orphan",
+ backref="company", order_by=people.c.person_id
+ )
+ })
if redefine_colprop:
person_attribute_name = 'person_name'
@@ -224,18 +214,16 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
def go():
cc = session.query(Company).get(c.company_id)
- for e in cc.employees:
- assert e._instance_key[0] == Person
self.assertEquals(cc.employees, employees)
if not lazy_relation:
- if polymorphic_fetch=='union':
+ if with_polymorphic != 'none':
self.assert_sql_count(testing.db, go, 1)
else:
self.assert_sql_count(testing.db, go, 5)
else:
- if polymorphic_fetch=='union':
+ if with_polymorphic != 'none':
self.assert_sql_count(testing.db, go, 2)
else:
self.assert_sql_count(testing.db, go, 6)
@@ -265,21 +253,20 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
session.flush()
session.clear()
- if polymorphic_fetch == 'select':
- def go():
- session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
- self.assert_sql_count(testing.db, go, 2)
- session.clear()
- dilbert = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
- def go():
- # assert that only primary table is queried for already-present-in-session
- d = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
- self.assert_sql_count(testing.db, go, 1)
+ def go():
+ session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
+ self.assert_sql_count(testing.db, go, 1)
+ session.clear()
+ dilbert = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
+ def go():
+ # assert that only primary table is queried for already-present-in-session
+ d = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
+ self.assert_sql_count(testing.db, go, 1)
# test standalone orphans
daboss = Boss(status='BBB', manager_name='boss', golf_swing='fore', **{person_attribute_name:'daboss'})
session.save(daboss)
- self.assertRaises(exceptions.FlushError, session.flush)
+ self.assertRaises(orm_exc.FlushError, session.flush)
c = session.query(Company).first()
daboss.company = c
manager_list = [e for e in c.employees if isinstance(e, Manager)]
@@ -295,24 +282,21 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
self.assertEquals(people.count().scalar(), 0)
test_roundtrip = _function_named(
- test_roundtrip, "test_%s%s%s%s%s" % (
+ test_roundtrip, "test_%s%s%s_%s" % (
(lazy_relation and "lazy" or "eager"),
(include_base and "_inclbase" or ""),
(redefine_colprop and "_redefcol" or ""),
- (polymorphic_fetch != 'union' and '_' + polymorphic_fetch or (use_literal_join and "_litjoin" or "")),
- (use_outer_joins and '_outerjoins' or '')))
+ with_polymorphic))
setattr(RoundTripTest, test_roundtrip.__name__, test_roundtrip)
-for include_base in [True, False]:
- for lazy_relation in [True, False]:
- for redefine_colprop in [True, False]:
- for use_literal_join in [True, False]:
- for polymorphic_fetch in ['union', 'select', 'deferred']:
- if polymorphic_fetch == 'union':
- for use_outer_joins in [True, False]:
- generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join, polymorphic_fetch, use_outer_joins)
- else:
- generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join, polymorphic_fetch, False)
+for lazy_relation in [True, False]:
+ for redefine_colprop in [True, False]:
+ for with_polymorphic in ['unions', 'joins', 'auto', 'none']:
+ if with_polymorphic == 'unions':
+ for include_base in [True, False]:
+ generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic)
+ else:
+ generate_round_trip_test(False, lazy_relation, redefine_colprop, with_polymorphic)
if __name__ == "__main__":
testenv.main()
diff --git a/test/orm/inheritance/polymorph2.py b/test/orm/inheritance/polymorph2.py
index ed003927b..4b17e9e9d 100644
--- a/test/orm/inheritance/polymorph2.py
+++ b/test/orm/inheritance/polymorph2.py
@@ -4,7 +4,7 @@ inheritance setups for which we maintain compatibility.
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions, util
+from sqlalchemy import util
from sqlalchemy.orm import *
from testlib import *
from testlib import fixtures
@@ -560,7 +560,7 @@ class RelationTest7(ORMTest):
class Car(PersistentObject):
def __repr__(self):
- return "Car number %d, name %s" % i(self.car_id, self.name)
+ return "Car number %d, name %s" % (self.car_id, self.name)
class Offraod_Car(Car):
def __repr__(self):
@@ -725,18 +725,18 @@ class GenerativeTest(TestBase, AssertsExecutionResults):
session.save(car2)
session.flush()
- # test these twice because theres caching involved, as well previous issues that modified the polymorphic union
- for x in range(0, 2):
- r = session.query(Person).filter(people.c.name.like('%2')).join('status').filter_by(name="active")
- assert str(list(r)) == "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]"
- r = session.query(Engineer).join('status').filter(people.c.name.in_(['E2', 'E3', 'E4', 'M4', 'M2', 'M1']) & (status.c.name=="active"))
- assert str(list(r)) == "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]"
- # this test embeds the original polymorphic union (employee_join) fully
- # into the WHERE criterion, using a correlated select. ticket #577 tracks
- # that Query's adaptation of the WHERE clause does not dig into the
- # mapped selectable itself, which permanently breaks the mapped selectable.
- r = session.query(Person).filter(exists([Car.c.owner], Car.c.owner==employee_join.c.person_id))
- assert str(list(r)) == "[Engineer E4, field X, status Status dead]"
+ # this particular adapt used to cause a recursion overflow;
+ # added here for testing
+ e = exists([Car.owner], Car.owner==employee_join.c.person_id)
+ Query(Person)._adapt_clause(employee_join, False, False)
+
+ r = session.query(Person).filter(Person.name.like('%2')).join('status').filter_by(name="active")
+ assert str(list(r)) == "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]"
+ r = session.query(Engineer).join('status').filter(Person.name.in_(['E2', 'E3', 'E4', 'M4', 'M2', 'M1']) & (status.c.name=="active"))
+ assert str(list(r)) == "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]"
+
+ r = session.query(Person).filter(exists([1], Car.owner==Person.person_id))
+ assert str(list(r)) == "[Engineer E4, field X, status Status dead]"
class MultiLevelTest(ORMTest):
def define_tables(self, metadata):
diff --git a/test/orm/inheritance/query.py b/test/orm/inheritance/query.py
index 34ead1622..6a40efc4a 100644
--- a/test/orm/inheritance/query.py
+++ b/test/orm/inheritance/query.py
@@ -7,9 +7,11 @@ import testenv; testenv.configure_for_tests()
import sets
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
from testlib import *
from testlib import fixtures
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.engine import default
class Company(fixtures.Base):
pass
@@ -30,7 +32,7 @@ class Paperwork(fixtures.Base):
pass
def make_test(select_type):
- class PolymorphicQueryTest(ORMTest):
+ class PolymorphicQueryTest(ORMTest, AssertsCompiledSQL):
keep_data = True
keep_mappers = True
@@ -184,11 +186,42 @@ def make_test(select_type):
def test_primary_eager_aliasing(self):
sess = create_session()
+
+ # assert the SQL itself here to ensure no over-joining is taking place
+ if select_type == '':
+ self.assert_compile(
+ sess.query(Person).options(eagerload(Engineer.machines)).limit(2).offset(1).with_labels().statement,
+ "SELECT people.person_id AS people_person_id, people.company_id AS people_company_id, "\
+ "people.name AS people_name, people.type AS people_type FROM people ORDER BY people.person_id LIMIT 2 OFFSET 1",
+ dialect=default.DefaultDialect())
+
def go():
self.assertEquals(sess.query(Person).options(eagerload(Engineer.machines))[1:3].all(), all_employees[1:3])
self.assert_sql_count(testing.db, go, {'':6, 'Polymorphic':3}.get(select_type, 4))
sess = create_session()
+
+ if select_type == '':
+ self.assert_compile(
+ sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines)).limit(2).offset(1).with_labels().statement,
+ "SELECT anon_1.people_person_id AS anon_1_people_person_id, anon_1.people_company_id AS anon_1_people_company_id, "\
+ "anon_1.people_name AS anon_1_people_name, anon_1.people_type AS anon_1_people_type, anon_1.engineers_person_id AS "\
+ "anon_1_engineers_person_id, anon_1.engineers_status AS anon_1_engineers_status, anon_1.engineers_engineer_name AS "\
+ "anon_1_engineers_engineer_name, anon_1.engineers_primary_language AS anon_1_engineers_primary_language, "\
+ "anon_1.managers_person_id AS anon_1_managers_person_id, anon_1.managers_status AS anon_1_managers_status, "\
+ "anon_1.managers_manager_name AS anon_1_managers_manager_name, anon_1.boss_boss_id AS anon_1_boss_boss_id, "\
+ "anon_1.boss_golf_swing AS anon_1_boss_golf_swing, machines_1.machine_id AS machines_1_machine_id, machines_1.name AS "\
+ "machines_1_name, machines_1.engineer_id AS machines_1_engineer_id FROM (SELECT people.person_id AS people_person_id, "\
+ "people.company_id AS people_company_id, people.name AS people_name, people.type AS people_type, engineers.person_id AS "\
+ "engineers_person_id, engineers.status AS engineers_status, engineers.engineer_name AS engineers_engineer_name, "\
+ "engineers.primary_language AS engineers_primary_language, managers.person_id AS managers_person_id, managers.status "\
+ "AS managers_status, managers.manager_name AS managers_manager_name, boss.boss_id AS boss_boss_id, boss.golf_swing "\
+ "AS boss_golf_swing FROM people LEFT OUTER JOIN engineers ON people.person_id = engineers.person_id LEFT OUTER JOIN "\
+ "managers ON people.person_id = managers.person_id LEFT OUTER JOIN boss ON managers.person_id = boss.boss_id ORDER BY "\
+ "people.person_id LIMIT 2 OFFSET 1) AS anon_1 LEFT OUTER JOIN machines AS machines_1 ON anon_1.engineers_person_id = "\
+ "machines_1.engineer_id ORDER BY anon_1.people_person_id, machines_1.machine_id",
+ dialect=default.DefaultDialect())
+
def go():
self.assertEquals(sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines))[1:3].all(), all_employees[1:3])
self.assert_sql_count(testing.db, go, 3)
@@ -199,9 +232,9 @@ def make_test(select_type):
# for all mappers, ensure the primary key has been calculated as just the "person_id"
# column
- self.assertEquals(sess.query(Person).get(e1.person_id), Engineer(name="dilbert"))
- self.assertEquals(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert"))
- self.assertEquals(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss"))
+ self.assertEquals(sess.query(Person).get(e1.person_id), Engineer(name="dilbert", primary_language="java"))
+ self.assertEquals(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert", primary_language="java"))
+ self.assertEquals(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss", golf_swing="fore"))
def test_filter_on_subclass(self):
sess = create_session()
@@ -219,7 +252,7 @@ def make_test(select_type):
def test_join_from_polymorphic(self):
sess = create_session()
-
+
for aliased in (True, False):
self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1])
@@ -227,7 +260,7 @@ def make_test(select_type):
self.assertEquals(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1])
- self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
+ self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
def test_join_from_with_polymorphic(self):
sess = create_session()
@@ -240,14 +273,14 @@ def make_test(select_type):
self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1])
sess.clear()
- self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
+ self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
def test_join_to_polymorphic(self):
sess = create_session()
self.assertEquals(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2)
self.assertEquals(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2)
-
+
def test_polymorphic_any(self):
sess = create_session()
@@ -305,6 +338,8 @@ def make_test(select_type):
Manager(name="dogbert", manager_name="dogbert", status="regular manager"),
Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer")
]
+ self.assertEquals(sess.query(Person).with_polymorphic('*').all(), emps_without_relations)
+
def go():
self.assertEquals(sess.query(Person).with_polymorphic(Engineer).filter(Engineer.primary_language=='java').all(), emps_without_relations[0:1])
@@ -345,6 +380,7 @@ def make_test(select_type):
]
sess = create_session()
+
def go():
# test load Companies with lazy load to 'employees'
self.assertEquals(sess.query(Company).all(), assert_result)
@@ -359,7 +395,7 @@ def make_test(select_type):
# in the case of select_type='', the eagerload doesn't take in this case;
# it eagerloads company->people, then a load for each of 5 rows, then lazyload of "machines"
self.assert_sql_count(testing.db, go, {'':7, 'Polymorphic':1}.get(select_type, 2))
-
+
def test_eagerload_on_subclass(self):
sess = create_session()
def go():
@@ -371,10 +407,15 @@ def make_test(select_type):
def test_join_to_subclass(self):
sess = create_session()
+ self.assertEquals(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1])
if select_type == '':
self.assertEquals(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1])
self.assertEquals(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1])
+
+ ealias = aliased(Engineer)
+ self.assertEquals(sess.query(Company).join(('employees', ealias)).filter(ealias.primary_language=='java').all(), [c1])
+
self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).all(), [e1, e2, e3])
self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
self.assertEquals(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).all(), [c1, c2])
@@ -445,6 +486,150 @@ def make_test(select_type):
self.assertEquals(sess.query(Person).filter(Person.person_id==e2.person_id).one(), e2)
+ def test_from_alias(self):
+ sess = create_session()
+
+ palias = aliased(Person)
+ self.assertEquals(
+ sess.query(palias).filter(palias.name.in_(['dilbert', 'wally'])).all(),
+ [e1, e2]
+ )
+
+ def test_self_referential(self):
+ sess = create_session()
+
+ c1_employees = [e1, e2, b1, m1]
+
+ palias = aliased(Person)
+ self.assertEquals(
+ sess.query(Person, palias).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\
+ filter(Person.person_id>palias.person_id).order_by(Person.person_id, palias.person_id).all(),
+ [
+ (m1, e1),
+ (m1, e2),
+ (m1, b1),
+ ]
+ )
+
+ self.assertEquals(
+ sess.query(Person, palias).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\
+ filter(Person.person_id>palias.person_id).from_self().order_by(Person.person_id, palias.person_id).all(),
+ [
+ (m1, e1),
+ (m1, e2),
+ (m1, b1),
+ ]
+ )
+
+ def test_nesting_queries(self):
+ sess = create_session()
+
+ # query.statement places a flag "no_adapt" on the returned statement. This prevents
+ # the polymorphic adaptation in the second "filter" from hitting it, which would pollute
+ # the subquery and usually results in recursion overflow errors within the adaption.
+ subq = sess.query(engineers.c.person_id).filter(Engineer.primary_language=='java').statement.as_scalar()
+
+ self.assertEquals(sess.query(Person).filter(Person.person_id==subq).one(), e1)
+
+
+ def test_mixed_entities(self):
+ sess = create_session()
+
+ self.assertEquals(
+ sess.query(Company.name, Person).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(),
+ [(u'Elbonia, Inc.',
+ Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'))]
+ )
+
+ self.assertEquals(
+ sess.query(Person, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(),
+ [(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'),
+ u'Elbonia, Inc.')]
+ )
+
+ self.assertEquals(
+ sess.query(Person.name, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(),
+ [(u'vlad',u'Elbonia, Inc.')]
+ )
+
+ self.assertEquals(
+ sess.query(Engineer.primary_language).filter(Person.type=='engineer').all(),
+ [(u'java',), (u'c++',), (u'cobol',)]
+ )
+
+ if select_type != '':
+ self.assertEquals(
+ sess.query(Engineer, Company.name).join(Company.employees).filter(Person.type=='engineer').all(),
+ [
+ (Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'), u'MegaCorp, Inc.'),
+ (Engineer(status=u'regular engineer',engineer_name=u'wally',name=u'wally',company_id=1,primary_language=u'c++',person_id=2,type=u'engineer'), u'MegaCorp, Inc.'),
+ (Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',company_id=2,primary_language=u'cobol',person_id=5,type=u'engineer'), u'Elbonia, Inc.')
+ ]
+ )
+
+ self.assertEquals(
+ sess.query(Engineer.primary_language, Company.name).join(Company.employees).filter(Person.type=='engineer').order_by(desc(Engineer.primary_language)).all(),
+ [(u'java', u'MegaCorp, Inc.'), (u'cobol', u'Elbonia, Inc.'), (u'c++', u'MegaCorp, Inc.')]
+ )
+
+ palias = aliased(Person)
+ self.assertEquals(
+ sess.query(Person, Company.name, palias).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(),
+ [(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'),
+ u'Elbonia, Inc.',
+ Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'))]
+ )
+
+ self.assertEquals(
+ sess.query(palias, Company.name, Person).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(),
+ [(Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'),
+ u'Elbonia, Inc.',
+ Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'),)
+ ]
+ )
+
+ self.assertEquals(
+ sess.query(Person.name, Company.name, palias.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(),
+ [(u'vlad', u'Elbonia, Inc.', u'dilbert')]
+ )
+
+ palias = aliased(Person)
+ self.assertEquals(
+ sess.query(Person.type, Person.name, palias.type, palias.name).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\
+ filter(Person.person_id>palias.person_id).order_by(Person.person_id, palias.person_id).all(),
+ [(u'manager', u'dogbert', u'engineer', u'dilbert'),
+ (u'manager', u'dogbert', u'engineer', u'wally'),
+ (u'manager', u'dogbert', u'boss', u'pointy haired boss')]
+ )
+
+ self.assertEquals(
+ sess.query(Person.name, Paperwork.description).filter(Person.person_id==Paperwork.person_id).order_by(Person.name, Paperwork.description).all(),
+ [(u'dilbert', u'tps report #1'), (u'dilbert', u'tps report #2'), (u'dogbert', u'review #2'),
+ (u'dogbert', u'review #3'),
+ (u'pointy haired boss', u'review #1'),
+ (u'vlad', u'elbonian missive #3'),
+ (u'wally', u'tps report #3'),
+ (u'wally', u'tps report #4'),
+ ]
+ )
+
+ if select_type != '':
+ self.assertEquals(
+ sess.query(func.count(Person.person_id)).filter(Engineer.primary_language=='java').all(),
+ [(1, )]
+ )
+
+ self.assertEquals(
+ sess.query(Company.name, func.count(Person.person_id)).filter(Company.company_id==Person.company_id).group_by(Company.name).order_by(Company.name).all(),
+ [(u'Elbonia, Inc.', 1), (u'MegaCorp, Inc.', 4)]
+ )
+
+ self.assertEquals(
+ sess.query(Company.name, func.count(Person.person_id)).join(Company.employees).group_by(Company.name).order_by(Company.name).all(),
+ [(u'Elbonia, Inc.', 1), (u'MegaCorp, Inc.', 4)]
+ )
+
+
PolymorphicQueryTest.__name__ = "Polymorphic%sTest" % select_type
return PolymorphicQueryTest
@@ -500,11 +685,6 @@ class SelfReferentialTest(ORMTest):
self.assertEquals(sess.query(Engineer).join('reports_to', aliased=True).filter(Person.name=='dogbert').first(), Engineer(name='dilbert'))
- def test_noalias_raises(self):
- sess = create_session()
- def go():
- sess.query(Engineer).join('reports_to')
- self.assertRaises(exceptions.InvalidRequestError, go)
class M2MFilterTest(ORMTest):
keep_mappers = True
@@ -570,6 +750,59 @@ class M2MFilterTest(ORMTest):
sess = create_session()
self.assertEquals(sess.query(Organization).filter(Organization.engineers.of_type(Engineer).any(Engineer.name=='e1')).all(), [Organization(name='org1')])
self.assertEquals(sess.query(Organization).filter(Organization.engineers.any(Engineer.name=='e1')).all(), [Organization(name='org1')])
+
+class SelfReferentialM2MTest(ORMTest, AssertsCompiledSQL):
+ def define_tables(self, metadata):
+ Base = declarative_base(metadata=metadata)
+
+ secondary_table = Table('secondary', Base.metadata,
+ Column('left_id', Integer, ForeignKey('parent.id'), nullable=False),
+ Column('right_id', Integer, ForeignKey('parent.id'), nullable=False))
+
+ global Parent, Child1, Child2
+ class Parent(Base):
+ __tablename__ = 'parent'
+ id = Column(Integer, primary_key=True)
+ cls = Column(String(50))
+ __mapper_args__ = dict(polymorphic_on = cls )
+
+ class Child1(Parent):
+ __tablename__ = 'child1'
+ id = Column(Integer, ForeignKey('parent.id'), primary_key=True)
+ __mapper_args__ = dict(polymorphic_identity = 'child1')
+
+ class Child2(Parent):
+ __tablename__ = 'child2'
+ id = Column(Integer, ForeignKey('parent.id'), primary_key=True)
+ __mapper_args__ = dict(polymorphic_identity = 'child2')
+
+ Child1.left_child2 = relation(Child2, secondary = secondary_table,
+ primaryjoin = Parent.id == secondary_table.c.right_id,
+ secondaryjoin = Parent.id == secondary_table.c.left_id,
+ uselist = False,
+ )
+
+ def test_eager_join(self):
+ session = create_session()
+ c1 = Child1()
+ c1.left_child2 = Child2()
+ session.add(c1)
+ session.flush()
+
+ q = session.query(Child1).options(eagerload('left_child2'))
+
+ # test that the splicing of the join works here, doesnt break in the middle of "parent join child1"
+ self.assert_compile(q.limit(1).with_labels().statement,
+ "SELECT anon_1.child1_id AS anon_1_child1_id, anon_1.parent_id AS anon_1_parent_id, "\
+ "anon_1.parent_cls AS anon_1_parent_cls, anon_2.child2_id AS anon_2_child2_id, anon_2.parent_id AS anon_2_parent_id, "\
+ "anon_2.parent_cls AS anon_2_parent_cls FROM (SELECT child1.id AS child1_id, parent.id AS parent_id, "\
+ "parent.cls AS parent_cls, parent.id AS parent_oid FROM parent JOIN child1 ON parent.id = child1.id ORDER BY parent.id "\
+ "LIMIT 1) AS anon_1 LEFT OUTER JOIN secondary AS secondary_1 ON anon_1.parent_id = secondary_1.right_id LEFT OUTER JOIN "\
+ "(SELECT parent.id AS parent_id, parent.cls AS parent_cls, child2.id AS child2_id FROM parent JOIN child2 ON parent.id = child2.id) "\
+ "AS anon_2 ON anon_2.parent_id = secondary_1.left_id ORDER BY anon_1.child1_id"
+ , dialect=default.DefaultDialect())
+ assert q.first() is c1
+
if __name__ == "__main__":
testenv.main()
diff --git a/test/orm/inheritance/single.py b/test/orm/inheritance/single.py
index 81223cc02..dabb701cd 100644
--- a/test/orm/inheritance/single.py
+++ b/test/orm/inheritance/single.py
@@ -61,6 +61,10 @@ class SingleInheritanceTest(TestBase, AssertsExecutionResults):
assert session.query(Engineer).all() == [e1, e2]
assert session.query(Manager).all() == [m1]
assert session.query(JuniorEngineer).all() == [e2]
-
+
+ m1 = session.query(Manager).one()
+ session.expire(m1, ['manager_data'])
+ self.assertEquals(m1.manager_data, "knows how to manage things")
+
if __name__ == '__main__':
testenv.main()
diff --git a/test/orm/instrumentation.py b/test/orm/instrumentation.py
new file mode 100644
index 000000000..5cb3a5c59
--- /dev/null
+++ b/test/orm/instrumentation.py
@@ -0,0 +1,745 @@
+import testenv; testenv.configure_for_tests()
+from sqlalchemy import MetaData, Table, Column, Integer, ForeignKey
+from sqlalchemy import util
+from sqlalchemy.orm import attributes
+from sqlalchemy.orm import create_session
+from sqlalchemy.orm import interfaces
+from sqlalchemy.orm import mapper
+from sqlalchemy.orm import relation
+
+from testlib.testing import eq_, ne_
+from testlib.compat import _function_named
+from testlib import TestBase
+
+
+def modifies_instrumentation_finders(fn):
+ def decorated(*args, **kw):
+ pristine = attributes.instrumentation_finders[:]
+ try:
+ fn(*args, **kw)
+ finally:
+ del attributes.instrumentation_finders[:]
+ attributes.instrumentation_finders.extend(pristine)
+ return _function_named(decorated, fn.func_name)
+
+def with_lookup_strategy(strategy):
+ def decorate(fn):
+ def wrapped(*args, **kw):
+ current = attributes._lookup_strategy
+ try:
+ attributes._install_lookup_strategy(strategy)
+ return fn(*args, **kw)
+ finally:
+ attributes._install_lookup_strategy(current)
+ return _function_named(wrapped, fn.func_name)
+ return decorate
+
+
+class InitTest(TestBase):
+ def fixture(self):
+ return Table('t', MetaData(),
+ Column('id', Integer, primary_key=True),
+ Column('type', Integer),
+ Column('x', Integer),
+ Column('y', Integer))
+
+ def register(self, cls, canary):
+ original_init = cls.__init__
+ attributes.register_class(cls)
+ ne_(cls.__init__, original_init)
+ manager = attributes.manager_of_class(cls)
+ def on_init(state, instance, args, kwargs):
+ canary.append((cls, 'on_init', type(instance)))
+ manager.events.add_listener('on_init', on_init)
+
+ def test_ai(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+
+ obj = A()
+ eq_(inits, [(A, '__init__')])
+
+ def test_A(self):
+ inits = []
+
+ class A(object): pass
+ self.register(A, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A)])
+
+ def test_Ai(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+ self.register(A, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+ def test_ai_B(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+
+ class B(A): pass
+ self.register(B, inits)
+
+ obj = A()
+ eq_(inits, [(A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B), (A, '__init__')])
+
+ def test_ai_Bi(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+
+ class B(A):
+ def __init__(self):
+ inits.append((B, '__init__'))
+ super(B, self).__init__()
+ self.register(B, inits)
+
+ obj = A()
+ eq_(inits, [(A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B), (B, '__init__'), (A, '__init__')])
+
+ def test_Ai_bi(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+ self.register(A, inits)
+
+ class B(A):
+ def __init__(self):
+ inits.append((B, '__init__'))
+ super(B, self).__init__()
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, '__init__'), (A, 'on_init', B), (A, '__init__')])
+
+ def test_Ai_Bi(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+ self.register(A, inits)
+
+ class B(A):
+ def __init__(self):
+ inits.append((B, '__init__'))
+ super(B, self).__init__()
+ self.register(B, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B), (B, '__init__'), (A, '__init__')])
+
+ def test_Ai_B(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+ self.register(A, inits)
+
+ class B(A): pass
+ self.register(B, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B), (A, '__init__')])
+
+ def test_Ai_Bi_Ci(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+ self.register(A, inits)
+
+ class B(A):
+ def __init__(self):
+ inits.append((B, '__init__'))
+ super(B, self).__init__()
+ self.register(B, inits)
+
+ class C(B):
+ def __init__(self):
+ inits.append((C, '__init__'))
+ super(C, self).__init__()
+ self.register(C, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B), (B, '__init__'), (A, '__init__')])
+
+ del inits[:]
+ obj = C()
+ eq_(inits, [(C, 'on_init', C), (C, '__init__'), (B, '__init__'),
+ (A, '__init__')])
+
+ def test_Ai_bi_Ci(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+ self.register(A, inits)
+
+ class B(A):
+ def __init__(self):
+ inits.append((B, '__init__'))
+ super(B, self).__init__()
+
+ class C(B):
+ def __init__(self):
+ inits.append((C, '__init__'))
+ super(C, self).__init__()
+ self.register(C, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, '__init__'), (A, 'on_init', B), (A, '__init__')])
+
+ del inits[:]
+ obj = C()
+ eq_(inits, [(C, 'on_init', C), (C, '__init__'), (B, '__init__'),
+ (A, '__init__')])
+
+ def test_Ai_b_Ci(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+ self.register(A, inits)
+
+ class B(A): pass
+
+ class C(B):
+ def __init__(self):
+ inits.append((C, '__init__'))
+ super(C, self).__init__()
+ self.register(C, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(A, 'on_init', B), (A, '__init__')])
+
+ del inits[:]
+ obj = C()
+ eq_(inits, [(C, 'on_init', C), (C, '__init__'), (A, '__init__')])
+
+ def test_Ai_B_Ci(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+ self.register(A, inits)
+
+ class B(A): pass
+ self.register(B, inits)
+
+ class C(B):
+ def __init__(self):
+ inits.append((C, '__init__'))
+ super(C, self).__init__()
+ self.register(C, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B), (A, '__init__')])
+
+ del inits[:]
+ obj = C()
+ eq_(inits, [(C, 'on_init', C), (C, '__init__'), (A, '__init__')])
+
+ def test_Ai_B_C(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+ self.register(A, inits)
+
+ class B(A): pass
+ self.register(B, inits)
+
+ class C(B): pass
+ self.register(C, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B), (A, '__init__')])
+
+ del inits[:]
+ obj = C()
+ eq_(inits, [(C, 'on_init', C), (A, '__init__')])
+
+ def test_A_Bi_C(self):
+ inits = []
+
+ class A(object): pass
+ self.register(A, inits)
+
+ class B(A):
+ def __init__(self):
+ inits.append((B, '__init__'))
+ self.register(B, inits)
+
+ class C(B): pass
+ self.register(C, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A)])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B), (B, '__init__')])
+
+ del inits[:]
+ obj = C()
+ eq_(inits, [(C, 'on_init', C), (B, '__init__')])
+
+ def test_A_B_Ci(self):
+ inits = []
+
+ class A(object): pass
+ self.register(A, inits)
+
+ class B(A): pass
+ self.register(B, inits)
+
+ class C(B):
+ def __init__(self):
+ inits.append((C, '__init__'))
+ self.register(C, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A)])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B)])
+
+ del inits[:]
+ obj = C()
+ eq_(inits, [(C, 'on_init', C), (C, '__init__')])
+
+ def test_A_B_C(self):
+ inits = []
+
+ class A(object): pass
+ self.register(A, inits)
+
+ class B(A): pass
+ self.register(B, inits)
+
+ class C(B): pass
+ self.register(C, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A)])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B)])
+
+ del inits[:]
+ obj = C()
+ eq_(inits, [(C, 'on_init', C)])
+
+
+class MapperInitTest(TestBase):
+
+ def fixture(self):
+ return Table('t', MetaData(),
+ Column('id', Integer, primary_key=True),
+ Column('type', Integer),
+ Column('x', Integer),
+ Column('y', Integer))
+
+ def test_partially_mapped_inheritance(self):
+ class A(object):
+ pass
+
+ class B(A):
+ pass
+
+ class C(B):
+ def __init__(self):
+ pass
+
+ mapper(A, self.fixture())
+
+ a = attributes.instance_state(A())
+ assert isinstance(a, attributes.InstanceState)
+ assert type(a) is not attributes.InstanceState
+
+ b = attributes.instance_state(B())
+ assert isinstance(b, attributes.InstanceState)
+ assert type(b) is not attributes.InstanceState
+
+ # C is unmanaged
+ cobj = C()
+ self.assertRaises((AttributeError, TypeError),
+ attributes.instance_state, cobj)
+
+class InstrumentationCollisionTest(TestBase):
+ def test_none(self):
+ class A(object): pass
+ attributes.register_class(A)
+
+ mgr_factory = lambda cls: attributes.ClassManager(cls)
+ class B(object):
+ __sa_instrumentation_manager__ = staticmethod(mgr_factory)
+ attributes.register_class(B)
+
+ class C(object):
+ __sa_instrumentation_manager__ = attributes.ClassManager
+ attributes.register_class(C)
+
+ def test_single_down(self):
+ class A(object): pass
+ attributes.register_class(A)
+
+ mgr_factory = lambda cls: attributes.ClassManager(cls)
+ class B(A):
+ __sa_instrumentation_manager__ = staticmethod(mgr_factory)
+
+ self.assertRaises(TypeError, attributes.register_class, B)
+
+ def test_single_up(self):
+
+ class A(object): pass
+ # delay registration
+
+ mgr_factory = lambda cls: attributes.ClassManager(cls)
+ class B(A):
+ __sa_instrumentation_manager__ = staticmethod(mgr_factory)
+ attributes.register_class(B)
+ self.assertRaises(TypeError, attributes.register_class, A)
+
+ def test_diamond_b1(self):
+ mgr_factory = lambda cls: attributes.ClassManager(cls)
+
+ class A(object): pass
+ class B1(A): pass
+ class B2(A):
+ __sa_instrumentation_manager__ = mgr_factory
+ class C(object): pass
+
+ self.assertRaises(TypeError, attributes.register_class, B1)
+
+ def test_diamond_b2(self):
+ mgr_factory = lambda cls: attributes.ClassManager(cls)
+
+ class A(object): pass
+ class B1(A): pass
+ class B2(A):
+ __sa_instrumentation_manager__ = mgr_factory
+ class C(object): pass
+
+ self.assertRaises(TypeError, attributes.register_class, B2)
+
+ def test_diamond_c_b(self):
+ mgr_factory = lambda cls: attributes.ClassManager(cls)
+
+ class A(object): pass
+ class B1(A): pass
+ class B2(A):
+ __sa_instrumentation_manager__ = mgr_factory
+ class C(object): pass
+
+ attributes.register_class(C)
+ self.assertRaises(TypeError, attributes.register_class, B1)
+
+
+class OnLoadTest(TestBase):
+ """Check that Events.on_load is not hit in regular attributes operations."""
+
+ def test_basic(self):
+ import pickle
+
+ global A
+ class A(object):
+ pass
+
+ def canary(instance): assert False
+
+ try:
+ attributes.register_class(A)
+ manager = attributes.manager_of_class(A)
+ manager.events.add_listener('on_load', canary)
+
+ a = A()
+ p_a = pickle.dumps(a)
+ re_a = pickle.loads(p_a)
+ finally:
+ del A
+
+
+class ExtendedEventsTest(TestBase):
+ """Allow custom Events implementations."""
+
+ @modifies_instrumentation_finders
+ def test_subclassed(self):
+ class MyEvents(attributes.Events):
+ pass
+ class MyClassManager(attributes.ClassManager):
+ event_registry_factory = MyEvents
+
+ attributes.instrumentation_finders.insert(0, lambda cls: MyClassManager)
+
+ class A(object): pass
+
+ attributes.register_class(A)
+ manager = attributes.manager_of_class(A)
+ assert isinstance(manager.events, MyEvents)
+
+
+class NativeInstrumentationTest(TestBase):
+ @with_lookup_strategy(util.symbol('native'))
+ def test_register_reserved_attribute(self):
+ class T(object): pass
+
+ attributes.register_class(T)
+ manager = attributes.manager_of_class(T)
+
+ sa = attributes.ClassManager.STATE_ATTR
+ ma = attributes.ClassManager.MANAGER_ATTR
+
+ fails = lambda method, attr: self.assertRaises(
+ KeyError, getattr(manager, method), attr, property())
+
+ fails('install_member', sa)
+ fails('install_member', ma)
+ fails('install_descriptor', sa)
+ fails('install_descriptor', ma)
+
+ @with_lookup_strategy(util.symbol('native'))
+ def test_mapped_stateattr(self):
+ t = Table('t', MetaData(),
+ Column('id', Integer, primary_key=True),
+ Column(attributes.ClassManager.STATE_ATTR, Integer))
+
+ class T(object): pass
+
+ self.assertRaises(KeyError, mapper, T, t)
+
+ @with_lookup_strategy(util.symbol('native'))
+ def test_mapped_managerattr(self):
+ t = Table('t', MetaData(),
+ Column('id', Integer, primary_key=True),
+ Column(attributes.ClassManager.MANAGER_ATTR, Integer))
+
+ class T(object): pass
+ self.assertRaises(KeyError, mapper, T, t)
+
+
+class MiscTest(TestBase):
+ """Seems basic, but not directly covered elsewhere!"""
+
+ def test_compileonattr(self):
+ t = Table('t', MetaData(),
+ Column('id', Integer, primary_key=True),
+ Column('x', Integer))
+ class A(object): pass
+ mapper(A, t)
+
+ a = A()
+ assert a.id is None
+
+ def test_compileonattr_rel(self):
+ m = MetaData()
+ t1 = Table('t1', m,
+ Column('id', Integer, primary_key=True),
+ Column('x', Integer))
+ t2 = Table('t2', m,
+ Column('id', Integer, primary_key=True),
+ Column('t1_id', Integer, ForeignKey('t1.id')))
+ class A(object): pass
+ class B(object): pass
+ mapper(A, t1, properties=dict(bs=relation(B)))
+ mapper(B, t2)
+
+ a = A()
+ assert not a.bs
+
+ def test_compileonattr_rel_backref_a(self):
+ m = MetaData()
+ t1 = Table('t1', m,
+ Column('id', Integer, primary_key=True),
+ Column('x', Integer))
+ t2 = Table('t2', m,
+ Column('id', Integer, primary_key=True),
+ Column('t1_id', Integer, ForeignKey('t1.id')))
+
+ class Base(object):
+ def __init__(self, *args, **kwargs):
+ pass
+
+ for base in object, Base:
+ class A(base): pass
+ class B(base): pass
+ mapper(A, t1, properties=dict(bs=relation(B, backref='a')))
+ mapper(B, t2)
+
+ b = B()
+ assert b.a is None
+ a = A()
+ b.a = a
+
+ session = create_session()
+ session.save(b)
+ assert a in session, "base is %s" % base
+
+ def test_compileonattr_rel_backref_b(self):
+ m = MetaData()
+ t1 = Table('t1', m,
+ Column('id', Integer, primary_key=True),
+ Column('x', Integer))
+ t2 = Table('t2', m,
+ Column('id', Integer, primary_key=True),
+ Column('t1_id', Integer, ForeignKey('t1.id')))
+
+ class Base(object):
+ def __init__(self): pass
+ class Base_AKW(object):
+ def __init__(self, *args, **kwargs): pass
+
+ for base in object, Base, Base_AKW:
+ class A(base): pass
+ class B(base): pass
+ mapper(A, t1)
+ mapper(B, t2, properties=dict(a=relation(A, backref='bs')))
+
+ a = A()
+ b = B()
+ b.a = a
+
+ session = create_session()
+ session.save(a)
+ assert b in session, 'base: %s' % base
+
+ def test_compileonattr_rel_entity_name(self):
+ m = MetaData()
+ t1 = Table('t1', m,
+ Column('id', Integer, primary_key=True),
+ Column('x', Integer))
+ t2 = Table('t2', m,
+ Column('id', Integer, primary_key=True),
+ Column('t1_id', Integer, ForeignKey('t1.id')))
+ class A(object): pass
+ class B(object): pass
+ mapper(A, t1, properties=dict(bs=relation(B)), entity_name='x')
+ mapper(B, t2)
+
+ a = A()
+ assert not a.bs
+
+class FinderTest(TestBase):
+ def test_standard(self):
+ class A(object): pass
+
+ attributes.register_class(A)
+
+ eq_(type(attributes.manager_of_class(A)), attributes.ClassManager)
+
+ def test_nativeext_interfaceexact(self):
+ class A(object):
+ __sa_instrumentation_manager__ = interfaces.InstrumentationManager
+
+ attributes.register_class(A)
+ ne_(type(attributes.manager_of_class(A)), attributes.ClassManager)
+
+ def test_nativeext_submanager(self):
+ class Mine(attributes.ClassManager): pass
+ class A(object):
+ __sa_instrumentation_manager__ = Mine
+
+ attributes.register_class(A)
+ eq_(type(attributes.manager_of_class(A)), Mine)
+
+ @modifies_instrumentation_finders
+ def test_customfinder_greedy(self):
+ class Mine(attributes.ClassManager): pass
+ class A(object): pass
+ def find(cls):
+ return Mine
+
+ attributes.instrumentation_finders.insert(0, find)
+ attributes.register_class(A)
+ eq_(type(attributes.manager_of_class(A)), Mine)
+
+ @modifies_instrumentation_finders
+ def test_customfinder_pass(self):
+ class A(object): pass
+ def find(cls):
+ return None
+
+ attributes.instrumentation_finders.insert(0, find)
+ attributes.register_class(A)
+ eq_(type(attributes.manager_of_class(A)), attributes.ClassManager)
+
+
+if __name__ == "__main__":
+ testenv.main()
diff --git a/test/orm/lazy_relations.py b/test/orm/lazy_relations.py
index 55d79fd32..1dd5d5e94 100644
--- a/test/orm/lazy_relations.py
+++ b/test/orm/lazy_relations.py
@@ -2,12 +2,13 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import *
from testlib import *
from testlib.fixtures import *
from query import QueryTest
import datetime
+from sqlalchemy.orm import attributes
class LazyTest(FixtureTest):
keep_mappers = False
@@ -21,35 +22,17 @@ class LazyTest(FixtureTest):
q = sess.query(User)
assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(users.c.id == 7).all()
- @testing.uses_deprecated('SessionContext')
- def test_bindstosession(self):
- """test that lazy loaders use the mapper's contextual session if the parent instance
- is not in a session, and that an error is raised if no contextual session"""
-
- from sqlalchemy.ext.sessioncontext import SessionContext
- ctx = SessionContext(create_session)
- m = mapper(User, users, properties = dict(
- addresses = relation(mapper(Address, addresses, extension=ctx.mapper_extension), lazy=True)
- ), extension=ctx.mapper_extension)
- q = ctx.current.query(m)
- u = q.filter(users.c.id == 7).first()
- ctx.current.expunge(u)
- assert User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')]) == u
-
- clear_mappers()
+ def test_needs_parent(self):
+ """test the error raised when parent object is not bound."""
mapper(User, users, properties={
'addresses':relation(mapper(Address, addresses), lazy=True)
})
- try:
- sess = create_session()
- q = sess.query(User)
- u = q.filter(users.c.id == 7).first()
- sess.expunge(u)
- assert User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')]) == u
- assert False
- except exceptions.InvalidRequestError, err:
- assert "not bound to a Session, and no contextual session" in str(err)
+ sess = create_session()
+ q = sess.query(User)
+ u = q.filter(users.c.id == 7).first()
+ sess.expunge(u)
+ self.assertRaises(sa_exc.InvalidRequestError, getattr, u, 'addresses')
def test_orderby(self):
mapper(User, users, properties = {
@@ -127,8 +110,8 @@ class LazyTest(FixtureTest):
sess = create_session()
user = sess.query(User).get(7)
- assert getattr(User, 'addresses').hasparent(user.addresses[0], optimistic=True)
- assert not class_mapper(Address)._is_orphan(user.addresses[0])
+ assert getattr(User, 'addresses').hasparent(attributes.instance_state(user.addresses[0]), optimistic=True)
+ assert not class_mapper(Address)._is_orphan(attributes.instance_state(user.addresses[0]))
def test_limit(self):
@@ -170,7 +153,7 @@ class LazyTest(FixtureTest):
u2 = users.alias('u2')
s = union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u')
print [key for key in s.c.keys()]
- l = q.filter(s.c.u2_id==User.c.id).distinct().all()
+ l = q.filter(s.c.u2_id==User.id).distinct().all()
assert fixtures.user_all_result == l
def test_one_to_many_scalar(self):
diff --git a/test/orm/manytomany.py b/test/orm/manytomany.py
index ca6410533..e8580af4a 100644
--- a/test/orm/manytomany.py
+++ b/test/orm/manytomany.py
@@ -1,8 +1,8 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
+from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import *
from testlib import *
-from sqlalchemy import exceptions
class Place(object):
'''represents a place'''
@@ -75,14 +75,7 @@ class M2MTest(ORMTest):
mapper(Transition, transition, properties={
'places':relation(Place, secondary=place_input, backref='transitions')
})
- try:
- compile_mappers()
- assert False
- except exceptions.ArgumentError, e:
- assert str(e) in [
- "Error creating backref 'transitions' on relation 'Transition.places (Place)': property of that name exists on mapper 'Mapper|Place|place'",
- "Error creating backref 'places' on relation 'Place.transitions (Transition)': property of that name exists on mapper 'Mapper|Transition|transition'"
- ]
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Error creating backref", compile_mappers)
def testcircular(self):
diff --git a/test/orm/mapper.py b/test/orm/mapper.py
index 7dce09614..017b2534c 100644
--- a/test/orm/mapper.py
+++ b/test/orm/mapper.py
@@ -2,9 +2,8 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions, sql
+from sqlalchemy import exc as sa_exc, sql
from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext, SessionContextExt
from testlib import *
from testlib import fixtures
from testlib.tables import *
@@ -32,15 +31,44 @@ class MapperTest(MapperSuperTest):
properties={
'addresses':relation(Address, backref='email_address')
})
- self.assertRaises(exceptions.ArgumentError, compile_mappers)
+ self.assertRaises(sa_exc.ArgumentError, compile_mappers)
def test_prop_accessor(self):
mapper(User, users)
self.assertRaises(NotImplementedError, getattr, class_mapper(User), 'properties')
+ @testing.uses_deprecated(
+ 'Call to deprecated function _instance_key',
+ 'Call to deprecated function _sa_session_id',
+ 'Call to deprecated function _entity_name')
+ def test_legacy_accesors(self):
+ u1 = User()
+ assert not hasattr(u1, '_instance_key')
+ assert not hasattr(u1, '_sa_session_id')
+ assert not hasattr(u1, '_entity_name')
+
+ mapper(User, users)
+ u1 = User()
+ assert not hasattr(u1, '_instance_key')
+ assert not hasattr(u1, '_sa_session_id')
+ assert u1._entity_name is None
+
+ sess = create_session()
+ sess.save(u1)
+ assert not hasattr(u1, '_instance_key')
+ assert u1._sa_session_id == sess.hash_key
+ assert u1._entity_name is None
+
+ sess.flush()
+ assert u1._instance_key == class_mapper(u1).identity_key_from_instance(u1)
+ assert u1._sa_session_id == sess.hash_key
+ assert u1._entity_name is None
+ sess.delete(u1)
+ sess.flush()
+
def test_badcascade(self):
mapper(Address, addresses)
- self.assertRaises(exceptions.ArgumentError, relation, Address, cascade="fake, all, delete-orphan")
+ self.assertRaises(sa_exc.ArgumentError, relation, Address, cascade="fake, all, delete-orphan")
def test_columnprefix(self):
mapper(User, users, column_prefix='_', properties={
@@ -56,26 +84,27 @@ class MapperTest(MapperSuperTest):
def test_no_pks(self):
s = select([users.c.user_name]).alias('foo')
- self.assertRaises(exceptions.ArgumentError, mapper, User, s)
-
+ self.assertRaises(sa_exc.ArgumentError, mapper, User, s)
+
def test_recompile_on_othermapper(self):
- """test the global '_new_mappers' flag such that a compile
+ """test the global '_new_mappers' flag such that a compile
trigger on an already-compiled mapper still triggers a check against all mappers."""
from sqlalchemy.orm import mapperlib
-
+
mapper(User, users)
compile_mappers()
assert mapperlib._new_mappers is False
-
- m = mapper(Address, addresses, properties={'user':relation(User, backref="addresses")})
-
- assert m._Mapper__props_init is False
+
+ m = mapper(Address, addresses, properties={
+ 'user': relation(User, backref="addresses")})
+
+ assert m.compiled is False
assert mapperlib._new_mappers is True
u = User()
assert User.addresses
assert mapperlib._new_mappers is False
-
+
def test_compileonsession(self):
m = mapper(User, users)
session = create_session()
@@ -95,7 +124,7 @@ class MapperTest(MapperSuperTest):
def test_badconstructor(self):
"""test that if the construction of a mapped class fails, the instnace does not get placed in the session"""
class Foo(object):
- def __init__(self, one, two):
+ def __init__(self, one, two, _sa_session=None):
pass
mapper(Foo, users)
sess = create_session()
@@ -103,14 +132,13 @@ class MapperTest(MapperSuperTest):
assert len(list(sess)) == 0
self.assertRaises(TypeError, Foo, 'one')
- @testing.uses_deprecated('SessionContext', 'SessionContextExt')
- def test_constructorexceptions(self):
+ def test_constructorexc(self):
"""test that exceptions raised in the mapped class are not masked by sa decorations"""
ex = AssertionError('oops')
sess = create_session()
class Foo(object):
- def __init__(self):
+ def __init__(self, **kw):
raise ex
mapper(Foo, users)
@@ -121,7 +149,7 @@ class MapperTest(MapperSuperTest):
assert e is ex
clear_mappers()
- mapper(Foo, users, extension=SessionContextExt(SessionContext()))
+ mapper(Foo, users, extension=scoped_session(create_session).extension)
def bad_expunge(foo):
raise Exception("this exception should be stated as a warning")
@@ -130,7 +158,7 @@ class MapperTest(MapperSuperTest):
Foo(_sa_session=sess)
assert False
except Exception, e:
- assert isinstance(e, exceptions.SAWarning)
+ assert isinstance(e, sa_exc.SAWarning), e
clear_mappers()
@@ -172,7 +200,7 @@ class MapperTest(MapperSuperTest):
mapper(User, users, properties = {
'addresses' : relation(mapper(Address, addresses))
})
- assert (User.user_id==3).compare(users.c.user_id==3)
+ self.assertEquals((User.user_id==3).__str__(), (users.c.user_id==3).__str__())
clear_mappers()
@@ -232,7 +260,7 @@ class MapperTest(MapperSuperTest):
m.add_property('uc_user_name2', comparable_property(
UCComparator, User.uc_user_name2))
- sess = create_session(transactional=True)
+ sess = create_session(autocommit=False)
assert sess.query(User).get(7)
u = sess.query(User).filter_by(user_name='jack').one()
@@ -337,14 +365,14 @@ class MapperTest(MapperSuperTest):
'addresses':relation(Address)
}).compile()
assert False
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
assert "Attempting to assign a new relation 'addresses' to a non-primary mapper on class 'User'" in str(e)
def test_illegal_non_primary_2(self):
try:
mapper(User, users, non_primary=True)
assert False
- except exceptions.InvalidRequestError, e:
+ except sa_exc.InvalidRequestError, e:
assert "Configure a primary mapper first" in str(e)
def test_propfilters(self):
@@ -386,7 +414,6 @@ class MapperTest(MapperSuperTest):
def assert_props(cls, want):
have = set([n for n in dir(cls) if not n.startswith('_')])
want = set(want)
- want.add('c')
self.assert_(have == want, repr(have) + " " + repr(want))
assert_props(Person, ['id', 'name', 'type'])
@@ -398,16 +425,6 @@ class MapperTest(MapperSuperTest):
assert_props(Hoho, ['id', 'name', 'type'])
assert_props(Lala, ['p_employee_number', 'p_id', 'p_name', 'p_type'])
- @testing.uses_deprecated('//select_by', '//join_via', '//list')
- def test_recursive_select_by_deprecated(self):
- """test that no endless loop occurs when traversing for select_by"""
- m = mapper(User, users, properties={
- 'orders':relation(mapper(Order, orders), backref='user'),
- 'addresses':relation(mapper(Address, addresses), backref='user'),
- })
- q = create_session().query(m)
- q.select_by(email_address='foo')
-
def test_mappingtojoin(self):
"""test mapping to a join"""
usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
@@ -472,21 +489,6 @@ class MapperTest(MapperSuperTest):
self.assert_result(l, User, user_result[0])
- @testing.uses_deprecated('//select')
- def test_customjoin_deprecated(self):
- """test that the from_obj parameter to query.select() can be used
- to totally replace the FROM parameters of the generated query."""
-
- m = mapper(User, users, properties={
- 'orders':relation(mapper(Order, orders, properties={
- 'items':relation(mapper(Item, orderitems))
- }))
- })
-
- q = create_session().query(m)
- l = q.select((orderitems.c.item_name=='item 4'), from_obj=[users.join(orders).join(orderitems)])
- self.assert_result(l, User, user_result[0])
-
def test_orderby(self):
"""test ordering at the mapper and query level"""
@@ -527,21 +529,14 @@ class MapperTest(MapperSuperTest):
mapper(User, users)
q = create_session().query(User)
self.assert_(q.count()==3)
- self.assert_(q.count(users.c.user_id.in_([8,9]))==2)
-
- @testing.unsupported('firebird')
- @testing.uses_deprecated('//count_by', '//join_by', '//join_via')
- def test_count_by_deprecated(self):
- mapper(User, users)
- q = create_session().query(User)
- self.assert_(q.count_by(user_name='fred')==1)
+ self.assert_(q.filter(users.c.user_id.in_([8,9])).count()==2)
def test_manytomany_count(self):
mapper(Item, orderitems, properties = dict(
keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy = True),
))
q = create_session().query(Item)
- assert q.join('keywords').distinct().count(Keyword.c.name=="red") == 2
+ assert q.join('keywords').distinct().filter(Keyword.name=="red").count() == 2
def test_override(self):
# assert that overriding a column raises an error
@@ -550,7 +545,7 @@ class MapperTest(MapperSuperTest):
'user_name' : relation(mapper(Address, addresses)),
}).compile()
self.assert_(False, "should have raised ArgumentError")
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
self.assert_(True)
clear_mappers()
@@ -601,8 +596,8 @@ class MapperTest(MapperSuperTest):
self.assert_result(u.adlist, Address, *(user_address_result[0]['addresses'][1]))
addr = sess.query(Address).filter_by(address_id=user_address_result[0]['addresses'][1][0]['address_id']).one()
- u = sess.query(User).filter_by(adname=addr).one()
- u2 = sess.query(User).filter_by(adlist=addr).one()
+ u = sess.query(User).filter(User.adname.contains(addr)).one()
+ u2 = sess.query(User).filter(User.adlist.contains(addr)).one()
assert u is u2
@@ -641,7 +636,7 @@ class MapperTest(MapperSuperTest):
})
User.not_user_name
assert False
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
assert str(e) == "Can't compile synonym '_user_name': no column on table 'users' named 'not_user_name'"
clear_mappers()
@@ -742,33 +737,6 @@ class OptionsTest(MapperSuperTest):
self.assert_result(u.adlist, Address, *(user_address_result[0]['addresses'][1]))
self.assert_sql_count(testing.db, go, 1)
- @testing.uses_deprecated('//select_by')
- def test_extension_options(self):
- sess = create_session()
- class ext1(MapperExtension):
- def populate_instance(self, mapper, selectcontext, row, instance, **flags):
- """test options at the Mapper._instance level"""
- instance.TEST = "hello world"
- return EXT_CONTINUE
- mapper(User, users, extension=ext1(), properties={
- 'addresses':relation(mapper(Address, addresses), lazy=False)
- })
- class testext(MapperExtension):
- def select_by(self, *args, **kwargs):
- """test options at the Query level"""
- return "HI"
- def populate_instance(self, mapper, selectcontext, row, instance, **flags):
- """test options at the Mapper._instance level"""
- instance.TEST_2 = "also hello world"
- return EXT_CONTINUE
- l = sess.query(User).options(extension(testext())).select_by(x=5)
- assert l == "HI"
- l = sess.query(User).options(extension(testext())).get(7)
- assert l.user_id == 7
- assert l.TEST == "hello world"
- assert l.TEST_2 == "also hello world"
- assert not hasattr(l.addresses[0], 'TEST')
- assert not hasattr(l.addresses[0], 'TEST2')
def test_eageroptions(self):
"""tests that a lazy relation can be upgraded to an eager relation via the options method"""
@@ -927,9 +895,9 @@ class OptionsTest(MapperSuperTest):
sess.clear()
- self.assertRaisesMessage(exceptions.ArgumentError,
- r"Can't find entity Mapper\|Order\|orders in Query. Current list: \['Mapper\|User\|users'\]",
- sess.query(User).options, eagerload('items', Order)
+ self.assertRaisesMessage(sa_exc.ArgumentError,
+ r"Can't find entity Mapper\|Order\|orders in Query. Current list: \['Mapper\|User\|users'\]",
+ sess.query(User).options, eagerload(Order.items)
)
# eagerload "keywords" on items. it will lazy load "orders", then lazy load
@@ -1333,11 +1301,29 @@ class MapperExtensionTest(TestBase):
def setUpAll(self):
tables.create()
- global methods, Ext
+ def tearDown(self):
+ clear_mappers()
+ tables.delete()
+ def tearDownAll(self):
+ tables.drop()
+
+ def extension(self):
methods = []
class Ext(MapperExtension):
+ def instrument_class(self, mapper, cls):
+ methods.append('instrument_class')
+ return EXT_CONTINUE
+
+ def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
+ methods.append('init_instance')
+ return EXT_CONTINUE
+
+ def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
+ methods.append('init_failed')
+ return EXT_CONTINUE
+
def load(self, query, *args, **kwargs):
methods.append('load')
return EXT_CONTINUE
@@ -1386,16 +1372,12 @@ class MapperExtensionTest(TestBase):
methods.append('after_delete')
return EXT_CONTINUE
- def tearDown(self):
- clear_mappers()
- methods[:] = []
- tables.delete()
-
- def tearDownAll(self):
- tables.drop()
+ return Ext, methods
def test_basic(self):
"""test that common user-defined methods get called."""
+ Ext, methods = self.extension()
+
mapper(User, users, extension=Ext())
sess = create_session()
u = User()
@@ -1408,13 +1390,17 @@ class MapperExtensionTest(TestBase):
sess.flush()
sess.delete(u)
sess.flush()
- self.assertEquals(methods,
- ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'append_result', 'get', 'translate_row',
- 'create_instance', 'populate_instance', 'append_result', 'before_update', 'after_update', 'before_delete', 'after_delete']
- )
+ self.assertEquals(methods,
+ ['instrument_class', 'init_instance', 'before_insert',
+ 'after_insert', 'load', 'translate_row', 'populate_instance',
+ 'append_result', 'get', 'translate_row', 'create_instance',
+ 'populate_instance', 'append_result', 'before_update',
+ 'after_update', 'before_delete', 'after_delete'])
+
def test_inheritance(self):
- # test using inheritance
+ Ext, methods = self.extension()
+
class AdminUser(User):
pass
@@ -1432,13 +1418,18 @@ class MapperExtensionTest(TestBase):
sess.flush()
sess.delete(am)
sess.flush()
- self.assertEquals(methods,
- ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'append_result', 'get',
- 'translate_row', 'create_instance', 'populate_instance', 'append_result', 'before_update', 'after_update', 'before_delete', 'after_delete'])
+ self.assertEquals(methods,
+ ['instrument_class', 'instrument_class', 'init_instance',
+ 'before_insert', 'after_insert', 'load', 'translate_row',
+ 'populate_instance', 'append_result', 'get', 'translate_row',
+ 'create_instance', 'populate_instance', 'append_result',
+ 'before_update', 'after_update', 'before_delete', 'after_delete'])
def test_after_with_no_changes(self):
# test that after_update is called even if no cols were updated
+ Ext, methods = self.extension()
+
mapper(Item, orderitems, extension=Ext() , properties={
'keywords':relation(Keyword, secondary=itemkeywords)
})
@@ -1450,15 +1441,20 @@ class MapperExtensionTest(TestBase):
sess.save(i1)
sess.save(k1)
sess.flush()
- self.assertEquals(methods, ['before_insert', 'after_insert', 'before_insert', 'after_insert'])
+ self.assertEquals(methods,
+ ['instrument_class', 'instrument_class', 'init_instance',
+ 'init_instance', 'before_insert', 'after_insert',
+ 'before_insert', 'after_insert'])
- methods[:] = []
+ del methods[:]
i1.keywords.append(k1)
sess.flush()
self.assertEquals(methods, ['before_update', 'after_update'])
def test_inheritance_with_dupes(self):
+ Ext, methods = self.extension()
+
# test using inheritance, same extension on both mappers
class AdminUser(User):
pass
@@ -1478,10 +1474,49 @@ class MapperExtensionTest(TestBase):
sess.flush()
sess.delete(am)
sess.flush()
- self.assertEquals(methods,
- ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'append_result', 'get', 'translate_row',
- 'create_instance', 'populate_instance', 'append_result', 'before_update', 'after_update', 'before_delete', 'after_delete']
- )
+ self.assertEquals(methods,
+ ['instrument_class', 'instrument_class', 'init_instance',
+ 'before_insert', 'after_insert', 'load', 'translate_row',
+ 'populate_instance', 'append_result', 'get', 'translate_row',
+ 'create_instance', 'populate_instance', 'append_result',
+ 'before_update', 'after_update', 'before_delete', 'after_delete'])
+
+ def test_single_instrumentor(self):
+ ext_None, methods_None = self.extension()
+ ext_x, methods_x = self.extension()
+
+ def reset():
+ clear_mappers()
+ del methods_None[:]
+ del methods_x[:]
+
+ mapper(User, users, extension=ext_None())
+ mapper(User, users, extension=ext_x(), entity_name='x')
+ User()
+
+ self.assertEquals(methods_None, ['instrument_class', 'init_instance'])
+ self.assertEquals(methods_x, [])
+
+ reset()
+
+ mapper(User, users, extension=ext_x(), entity_name='x')
+ mapper(User, users, extension=ext_None())
+ User()
+
+ self.assertEquals(methods_x, ['instrument_class', 'init_instance'])
+ self.assertEquals(methods_None, [])
+
+ reset()
+
+ ext_y, methods_y = self.extension()
+
+ mapper(User, users, extension=ext_x(), entity_name='x')
+ mapper(User, users, extension=ext_y(), entity_name='y')
+ User()
+
+ self.assertEquals(methods_x, ['instrument_class', 'init_instance'])
+ self.assertEquals(methods_y, [])
+
class RequirementsTest(ORMTest):
"""Tests the contract for user classes."""
@@ -1519,13 +1554,13 @@ class RequirementsTest(ORMTest):
class OldStyle:
pass
- self.assertRaises(exceptions.ArgumentError, mapper, OldStyle, t1)
+ self.assertRaises(sa_exc.ArgumentError, mapper, OldStyle, t1)
class NoWeakrefSupport(str):
pass
# TODO: is weakref support detectable without an instance?
- #self.assertRaises(exceptions.ArgumentError, mapper, NoWeakrefSupport, t2)
+ #self.assertRaises(sa_exc.ArgumentError, mapper, NoWeakrefSupport, t2)
def test_comparison_overrides(self):
"""Simple tests to ensure users can supply comparison __methods__.
@@ -1584,7 +1619,6 @@ class RequirementsTest(ORMTest):
return self.value == other.value
return False
-
mapper(H1, t1, properties={
'h2s': relation(H2, backref='h1'),
'h3s': relation(H3, secondary=t4, backref='h1s'),
@@ -1654,6 +1688,92 @@ class NoEqFoo(object):
def __ne__(self, other):
raise NotImplementedError()
+class MagicNamesTest(ORMTest):
+
+ def define_tables(self, metadata):
+ Table('cartographers', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String(50)),
+ Column('alias', String(50)),
+ Column('quip', String(100)))
+ Table('maps', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('cart_id', Integer,
+ ForeignKey('cartographers.id')),
+ Column('state', String(2)),
+ Column('data', Text))
+
+ def tables(self):
+ cat = testing._otest_metadata.tables
+ return cat['cartographers'], cat['maps']
+
+ def classes(self):
+ class Base(object):
+ def __init__(self, **kw):
+ for key, value in kw.iteritems():
+ setattr(self, key, value)
+ class Cartographer(Base): pass
+ class Map(Base): pass
+
+ return Cartographer, Map
+
+ @testing.future
+ def test_mappish(self):
+ t1, t2 = self.tables()
+ Cartographer, Map = self.classes()
+ mapper(Cartographer, t1, properties=dict(
+ query=t1.c.quip))
+ mapper(Map, t2, properties=dict(
+ mapper=relation(Cartographer, backref='maps')))
+
+ c = Cartographer(name='Lenny', alias='The Dude',
+ query='Where be dragons?')
+ m = Map(state='AK', mapper=c)
+
+ sess = create_session()
+ sess.save(c)
+ sess.flush()
+ sess.clear()
+
+ for C, M in ((Cartographer, Map), (aliased(Cartographer), aliased(Map))):
+ print C, M
+ c1 = (sess.query(C).
+ filter(C.alias=='The Dude').
+ filter(C.query=='Where be dragons?')).one()
+ m1 = sess.query(M).filter(M.mapper==c1).one()
+
+ @testing.future
+ def test_stateish(self):
+ from sqlalchemy.orm import attributes
+ if hasattr(attributes, 'ClassManager'):
+ syn1 = attributes.ClassManager.STATE_ATTR
+ syn2 = attributes.ClassManager.MANAGER_ATTR
+ else:
+ syn1 = '_state'
+ syn2 = '_class_state'
+
+
+ t1, t2 = self.tables()
+ Cartographer, Map = self.classes()
+ mapper(Map, t2, properties=dict(
+ syn1=t2.c.state,
+ syn2=t2.c.data))
+
+ m = Map()
+ setattr(m, syn1, 'AK')
+ setattr(m, syn2, '10x10')
+
+ sess = create_session()
+ sess.save(m)
+ sess.flush()
+ sess.clear()
+
+ for M in (Map, aliased(Map)):
+ print M
+ sess.query(M).filter(getattr(M, syn1) == 'AK').one()
+ sess.query(M).filter(getattr(M, syn2) == '10x10').one()
+
+
class ScalarRequirementsTest(ORMTest):
def define_tables(self, metadata):
import pickle
@@ -1661,14 +1781,14 @@ class ScalarRequirementsTest(ORMTest):
t1 = Table('t1', metadata, Column('id', Integer, primary_key=True),
Column('data', PickleType(pickler=pickle)) # dont use cPickle due to import weirdness
)
-
+
def test_correct_comparison(self):
-
+
class H1(fixtures.Base):
pass
-
+
mapper(H1, t1)
-
+
h1 = H1(data=NoEqFoo('12345'))
s = create_session()
s.save(h1)
@@ -1676,7 +1796,7 @@ class ScalarRequirementsTest(ORMTest):
s.clear()
h1 = s.get(H1, h1.id)
assert h1.data.data == '12345'
-
+
if __name__ == "__main__":
testenv.main()
diff --git a/test/orm/merge.py b/test/orm/merge.py
index fd61ccc28..6ca42d53d 100644
--- a/test/orm/merge.py
+++ b/test/orm/merge.py
@@ -1,8 +1,8 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import *
-from sqlalchemy.orm import mapperlib
+from sqlalchemy.orm import mapperlib, attributes
from sqlalchemy.util import OrderedSet
from testlib import *
from testlib import fixtures
@@ -21,20 +21,34 @@ class MergeTest(TestBase, AssertsExecutionResults):
clear_mappers()
tables.delete()
+ def on_load_tracker(self, cls, canary=None):
+ if canary is None:
+ def canary(instance):
+ canary.called += 1
+ canary.called = 0
+
+ manager = attributes.manager_of_class(cls)
+ manager.events.add_listener('on_load', canary)
+
+ return canary
+
def test_transient_to_pending(self):
class User(fixtures.Base):
pass
mapper(User, users)
sess = create_session()
+ on_load = self.on_load_tracker(User)
u = User(user_id=7, user_name='fred')
+ assert on_load.called == 0
u2 = sess.merge(u)
+ assert on_load.called == 1
assert u2 in sess
self.assertEquals(u2, User(user_id=7, user_name='fred'))
sess.flush()
sess.clear()
self.assertEquals(sess.query(User).first(), User(user_id=7, user_name='fred'))
-
+
def test_transient_to_pending_collection(self):
class User(fixtures.Base):
pass
@@ -42,47 +56,72 @@ class MergeTest(TestBase, AssertsExecutionResults):
pass
mapper(User, users, properties={'addresses':relation(Address, backref='user', collection_class=OrderedSet)})
mapper(Address, addresses)
+ on_load = self.on_load_tracker(User)
+ self.on_load_tracker(Address, on_load)
u = User(user_id=7, user_name='fred', addresses=OrderedSet([
Address(address_id=1, email_address='fred1'),
Address(address_id=2, email_address='fred2'),
- ]))
+ ]))
+ assert on_load.called == 0
+
sess = create_session()
sess.merge(u)
+ assert on_load.called == 3
+
+ merged_users = [e for e in sess if isinstance(e, User)]
+ assert len(merged_users) == 1
+ assert merged_users[0] is not u
+
sess.flush()
sess.clear()
- self.assertEquals(sess.query(User).one(),
+ self.assertEquals(sess.query(User).one(),
User(user_id=7, user_name='fred', addresses=OrderedSet([
Address(address_id=1, email_address='fred1'),
Address(address_id=2, email_address='fred2'),
]))
)
-
+
def test_transient_to_persistent(self):
class User(fixtures.Base):
pass
mapper(User, users)
+ on_load = self.on_load_tracker(User)
+
sess = create_session()
u = User(user_id=7, user_name='fred')
sess.save(u)
sess.flush()
sess.clear()
-
- u2 = User(user_id=7, user_name='fred jones')
+
+ assert on_load.called == 0
+
+ _u2 = u2 = User(user_id=7, user_name='fred jones')
+ assert on_load.called == 0
u2 = sess.merge(u2)
+ assert u2 is not _u2
+ assert on_load.called == 1
sess.flush()
sess.clear()
self.assertEquals(sess.query(User).first(), User(user_id=7, user_name='fred jones'))
-
+ assert on_load.called == 2
+
def test_transient_to_persistent_collection(self):
class User(fixtures.Base):
pass
class Address(fixtures.Base):
pass
- mapper(User, users, properties={'addresses':relation(Address, backref='user', collection_class=OrderedSet)})
+ mapper(User, users, properties={
+ 'addresses':relation(Address,
+ backref='user',
+ collection_class=OrderedSet, cascade="all, delete-orphan")
+ })
mapper(Address, addresses)
+ on_load = self.on_load_tracker(User)
+ self.on_load_tracker(Address, on_load)
+
u = User(user_id=7, user_name='fred', addresses=OrderedSet([
Address(address_id=1, email_address='fred1'),
Address(address_id=2, email_address='fred2'),
@@ -91,14 +130,21 @@ class MergeTest(TestBase, AssertsExecutionResults):
sess.save(u)
sess.flush()
sess.clear()
-
+
+ assert on_load.called == 0
+
u = User(user_id=7, user_name='fred', addresses=OrderedSet([
Address(address_id=3, email_address='fred3'),
Address(address_id=4, email_address='fred4'),
]))
-
+
u = sess.merge(u)
- self.assertEquals(u,
+
+ assert on_load.called == 5, on_load.called # 1. merges User object. updates into session.
+ # 2.,3. merges Address ids 3 & 4, saves into session.
+ # 4.,5. loads pre-existing elements in "addresses" collection,
+ # marks as deleted, Address ids 1 and 2.
+ self.assertEquals(u,
User(user_id=7, user_name='fred', addresses=OrderedSet([
Address(address_id=3, email_address='fred3'),
Address(address_id=4, email_address='fred4'),
@@ -106,13 +152,13 @@ class MergeTest(TestBase, AssertsExecutionResults):
)
sess.flush()
sess.clear()
- self.assertEquals(sess.query(User).one(),
+ self.assertEquals(sess.query(User).one(),
User(user_id=7, user_name='fred', addresses=OrderedSet([
Address(address_id=3, email_address='fred3'),
Address(address_id=4, email_address='fred4'),
]))
)
-
+
def test_detached_to_persistent_collection(self):
class User(fixtures.Base):
pass
@@ -120,7 +166,9 @@ class MergeTest(TestBase, AssertsExecutionResults):
pass
mapper(User, users, properties={'addresses':relation(Address, backref='user', collection_class=OrderedSet)})
mapper(Address, addresses)
-
+ on_load = self.on_load_tracker(User)
+ self.on_load_tracker(Address, on_load)
+
a = Address(address_id=1, email_address='fred1')
u = User(user_id=7, user_name='fred', addresses=OrderedSet([
a,
@@ -130,34 +178,39 @@ class MergeTest(TestBase, AssertsExecutionResults):
sess.save(u)
sess.flush()
sess.clear()
-
+
u.user_name='fred jones'
u.addresses.add(Address(address_id=3, email_address='fred3'))
u.addresses.remove(a)
-
+
+ assert on_load.called == 0
u = sess.merge(u)
+ assert on_load.called == 4
sess.flush()
sess.clear()
-
- self.assertEquals(sess.query(User).first(),
+
+ self.assertEquals(sess.query(User).first(),
User(user_id=7, user_name='fred jones', addresses=OrderedSet([
Address(address_id=2, email_address='fred2'),
Address(address_id=3, email_address='fred3'),
]))
)
-
+
def test_unsaved_cascade(self):
"""test merge of a transient entity with two child transient entities, with a bidirectional relation."""
-
+
class User(fixtures.Base):
pass
class Address(fixtures.Base):
pass
-
+
mapper(User, users, properties={
'addresses':relation(mapper(Address, addresses), cascade="all", backref="user")
})
+ on_load = self.on_load_tracker(User)
+ self.on_load_tracker(Address, on_load)
sess = create_session()
+
u = User(user_id=7, user_name='fred')
a1 = Address(email_address='foo@bar.com')
a2 = Address(email_address='hoho@bar.com')
@@ -165,12 +218,16 @@ class MergeTest(TestBase, AssertsExecutionResults):
u.addresses.append(a2)
u2 = sess.merge(u)
+ assert on_load.called == 3
+
self.assertEquals(u, User(user_id=7, user_name='fred', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@bar.com')]))
self.assertEquals(u2, User(user_id=7, user_name='fred', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@bar.com')]))
sess.flush()
sess.clear()
u2 = sess.query(User).get(7)
self.assertEquals(u2, User(user_id=7, user_name='fred', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@bar.com')]))
+ assert on_load.called == 6
+
def test_attribute_cascade(self):
"""test merge of a persistent entity with two child persistent entities."""
@@ -183,6 +240,9 @@ class MergeTest(TestBase, AssertsExecutionResults):
mapper(User, users, properties={
'addresses':relation(mapper(Address, addresses), backref='user')
})
+ on_load = self.on_load_tracker(User)
+ self.on_load_tracker(Address, on_load)
+
sess = create_session()
# set up data and save
@@ -202,9 +262,12 @@ class MergeTest(TestBase, AssertsExecutionResults):
u.user_name = 'fred2'
u.addresses[1].email_address = 'hoho@lalala.com'
+ assert on_load.called == 3
+
# new session, merge modified data into session
sess3 = create_session()
u3 = sess3.merge(u)
+ assert on_load.called == 6
# ensure local changes are pending
self.assertEquals(u3, User(user_id=7, user_name='fred2', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@lalala.com')]))
@@ -216,6 +279,7 @@ class MergeTest(TestBase, AssertsExecutionResults):
sess.clear()
u = sess.query(User).get(7)
self.assertEquals(u, User(user_id=7, user_name='fred2', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@lalala.com')]))
+ assert on_load.called == 9
# merge persistent object into another session
sess4 = create_session()
@@ -227,6 +291,7 @@ class MergeTest(TestBase, AssertsExecutionResults):
sess4.flush()
# no changes; therefore flush should do nothing
self.assert_sql_count(testing.db, go, 0)
+ assert on_load.called == 12
# test with "dontload" merge
sess5 = create_session()
@@ -240,6 +305,7 @@ class MergeTest(TestBase, AssertsExecutionResults):
# but also, dont_load wipes out any difference in committed state,
# so no flush at all
self.assert_sql_count(testing.db, go, 0)
+ assert on_load.called == 15
sess4 = create_session()
u = sess4.merge(u, dont_load=True)
@@ -249,11 +315,13 @@ class MergeTest(TestBase, AssertsExecutionResults):
sess4.flush()
# afafds change flushes
self.assert_sql_count(testing.db, go, 1)
+ assert on_load.called == 18
sess5 = create_session()
u2 = sess5.query(User).get(u.user_id)
assert u2.user_name == 'fred2'
assert u2.addresses[1].email_address == 'afafds'
+ assert on_load.called == 21
def test_one_to_many_cascade(self):
@@ -265,6 +333,9 @@ class MergeTest(TestBase, AssertsExecutionResults):
'addresses':relation(mapper(Address, addresses)),
'orders':relation(Order, backref='customer')
})
+ on_load = self.on_load_tracker(User)
+ self.on_load_tracker(Address, on_load)
+ self.on_load_tracker(Order, on_load)
sess = create_session()
u = User()
@@ -282,16 +353,24 @@ class MergeTest(TestBase, AssertsExecutionResults):
sess.save(u)
sess.flush()
+ assert on_load.called == 0
+
sess2 = create_session()
u2 = sess2.query(User).get(u.user_id)
+ assert on_load.called == 1
+
u.orders[0].items[1].item_name = 'item 2 modified'
sess2.merge(u)
assert u2.orders[0].items[1].item_name == 'item 2 modified'
+ assert on_load.called == 2
+
+ sess3 = create_session()
+ o2 = sess3.query(Order).get(o.order_id)
+ assert on_load.called == 3
- sess2 = create_session()
- o2 = sess2.query(Order).get(o.order_id)
o.customer.user_name = 'also fred'
- sess2.merge(o)
+ sess3.merge(o)
+ assert on_load.called == 4
assert o2.customer.user_name == 'also fred'
def test_one_to_one_cascade(self):
@@ -299,7 +378,10 @@ class MergeTest(TestBase, AssertsExecutionResults):
mapper(User, users, properties={
'address':relation(mapper(Address, addresses),uselist = False)
})
+ on_load = self.on_load_tracker(User)
+ self.on_load_tracker(Address, on_load)
sess = create_session()
+
u = User()
u.user_id = 7
u.user_name = "fred"
@@ -310,19 +392,25 @@ class MergeTest(TestBase, AssertsExecutionResults):
sess.save(u)
sess.flush()
+ assert on_load.called == 0
+
sess2 = create_session()
u2 = sess2.query(User).get(7)
+ assert on_load.called == 1
u2.user_name = 'fred2'
u2.address.email_address = 'hoho@lalala.com'
+ assert on_load.called == 2
u3 = sess.merge(u2)
-
+ assert on_load.called == 2
+ assert u3 is u
+
def test_transient_dontload(self):
mapper(User, users)
sess = create_session()
u = User()
- self.assertRaisesMessage(exceptions.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True)
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True)
def test_dontload_with_backrefs(self):
@@ -407,7 +495,7 @@ class MergeTest(TestBase, AssertsExecutionResults):
try:
sess2.merge(u, dont_load=True)
assert False
- except exceptions.InvalidRequestError, e:
+ except sa_exc.InvalidRequestError, e:
assert "merge() with dont_load=True option does not support objects marked as 'dirty'. flush() all changes on mapped instances before merging with dont_load=True." in str(e)
u2 = sess2.query(User).get(7)
@@ -443,7 +531,8 @@ class MergeTest(TestBase, AssertsExecutionResults):
u2 = sess2.merge(u, dont_load=True)
assert not sess2.dirty
# assert merged instance has a mapper and lazy load proceeds
- assert hasattr(u2, '_entity_name')
+ state = attributes.instance_state(u2)
+ assert state.entity_name is not attributes.NO_ENTITY_NAME
assert mapperlib.has_mapper(u2)
def go():
assert u2.addresses != []
@@ -505,7 +594,7 @@ class MergeTest(TestBase, AssertsExecutionResults):
assert not sess2.dirty
a2 = u2.addresses[0]
a2.email_address='somenewaddress'
- assert not object_mapper(a2)._is_orphan(a2)
+ assert not object_mapper(a2)._is_orphan(attributes.instance_state(a2))
sess2.flush()
sess2.clear()
assert sess2.query(User).get(u2.user_id).addresses[0].email_address == 'somenewaddress'
@@ -526,11 +615,11 @@ class MergeTest(TestBase, AssertsExecutionResults):
# if dont_load is changed to support dirty objects, this code needs to pass
a2 = u2.addresses[0]
a2.email_address='somenewaddress'
- assert not object_mapper(a2)._is_orphan(a2)
+ assert not object_mapper(a2)._is_orphan(attributes.instance_state(a2))
sess2.flush()
sess2.clear()
assert sess2.query(User).get(u2.user_id).addresses[0].email_address == 'somenewaddress'
- except exceptions.InvalidRequestError, e:
+ except sa_exc.InvalidRequestError, e:
assert "dont_load=True option does not support" in str(e)
diff --git a/test/orm/naturalpks.py b/test/orm/naturalpks.py
index ec7d2fca9..67cf5e9ad 100644
--- a/test/orm/naturalpks.py
+++ b/test/orm/naturalpks.py
@@ -1,8 +1,7 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy import exceptions
-
+from sqlalchemy.orm import attributes, exc as orm_exc
from testlib.fixtures import *
from testlib import *
@@ -62,17 +61,13 @@ class NaturalPKTest(ORMTest):
sess.flush()
assert sess.get(User, 'jack') is u1
- users.update(values={u1.c.username:'jack'}).execute(username='ed')
+ users.update(values={User.username:'jack'}).execute(username='ed')
- try:
- # expire/refresh works off of primary key. the PK is gone
- # in this case so theres no way to look it up. criterion-
- # based session invalidation could solve this [ticket:911]
- sess.expire(u1)
- u1.username
- assert False
- except exceptions.InvalidRequestError, e:
- assert "Could not refresh instance" in str(e)
+ # expire/refresh works off of primary key. the PK is gone
+ # in this case so theres no way to look it up. criterion-
+ # based session invalidation could solve this [ticket:911]
+ sess.expire(u1)
+ self.assertRaises(orm_exc.ObjectDeletedError, getattr, u1, 'username')
sess.clear()
assert sess.get(User, 'jack') is None
@@ -154,7 +149,7 @@ class NaturalPKTest(ORMTest):
u1.username = 'ed'
print id(a1), id(a2), id(u1)
- print u1._state.parents
+ print attributes.instance_state(u1).parents
def go():
sess.flush()
if passive_updates:
diff --git a/test/orm/onetoone.py b/test/orm/onetoone.py
index ae0d6ef86..eb425c577 100644
--- a/test/orm/onetoone.py
+++ b/test/orm/onetoone.py
@@ -1,7 +1,6 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext
from testlib import *
class Jack(object):
@@ -29,7 +28,7 @@ class O2OTest(TestBase, AssertsExecutionResults):
def setUpAll(self):
global jack, port, metadata, ctx
metadata = MetaData(testing.db)
- ctx = SessionContext(create_session)
+ ctx = scoped_session(create_session)
jack = Table('jack', metadata,
Column('id', Integer, primary_key=True),
#Column('room_id', Integer, ForeignKey("room.id")),
@@ -54,22 +53,21 @@ class O2OTest(TestBase, AssertsExecutionResults):
def tearDownAll(self):
metadata.drop_all()
- @testing.uses_deprecated('SessionContext')
def test1(self):
- mapper(Port, port, extension=ctx.mapper_extension)
+ mapper(Port, port, extension=ctx.extension)
mapper(Jack, jack, order_by=[jack.c.number],properties = {
'port': relation(Port, backref='jack', uselist=False, lazy=True),
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
j=Jack(number='101')
p=Port(name='fa0/1')
j.port=p
- ctx.current.flush()
+ ctx.flush()
jid = j.id
pid = p.id
- j=ctx.current.query(Jack).get(jid)
- p=ctx.current.query(Port).get(pid)
+ j=ctx.query(Jack).get(jid)
+ p=ctx.query(Port).get(pid)
print p.jack
assert p.jack is not None
assert p.jack is j
@@ -77,17 +75,17 @@ class O2OTest(TestBase, AssertsExecutionResults):
p.jack=None
assert j.port is None #works
- ctx.current.clear()
+ ctx.clear()
- j=ctx.current.query(Jack).get(jid)
- p=ctx.current.query(Port).get(pid)
+ j=ctx.query(Jack).get(jid)
+ p=ctx.query(Port).get(pid)
j.port=None
self.assert_(p.jack is None)
- ctx.current.flush()
+ ctx.flush()
- ctx.current.delete(j)
- ctx.current.flush()
+ ctx.delete(j)
+ ctx.flush()
if __name__ == "__main__":
testenv.main()
diff --git a/test/orm/pickled.py b/test/orm/pickled.py
index 84f5e5daf..6bb455d41 100644
--- a/test/orm/pickled.py
+++ b/test/orm/pickled.py
@@ -1,6 +1,5 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions
from sqlalchemy.orm import *
from testlib import *
from testlib.fixtures import *
@@ -113,7 +112,7 @@ class PolymorphicDeferredTest(ORMTest):
)
def test_polymorphic_deferred(self):
- mapper(User, users, polymorphic_identity='user', polymorphic_on=users.c.type, polymorphic_fetch='deferred')
+ mapper(User, users, polymorphic_identity='user', polymorphic_on=users.c.type)
mapper(EmailUser, email_users, inherits=User, polymorphic_identity='emailuser')
eu = EmailUser(name="user1", email_address='foo@bar.com')
diff --git a/test/orm/query.py b/test/orm/query.py
index f1afdb90b..bc67740f2 100644
--- a/test/orm/query.py
+++ b/test/orm/query.py
@@ -1,7 +1,7 @@
import testenv; testenv.configure_for_tests()
import operator
from sqlalchemy import *
-from sqlalchemy import exceptions, util
+from sqlalchemy import exc as sa_exc, util
from sqlalchemy.sql import compiler
from sqlalchemy.engine import default
from sqlalchemy.orm import *
@@ -10,12 +10,13 @@ from testlib import *
from testlib import engines
from testlib.fixtures import *
-from sqlalchemy.orm.util import _join as join, _outerjoin as outerjoin
+from sqlalchemy.orm.util import join, outerjoin, with_parent
class QueryTest(FixtureTest):
keep_mappers = True
keep_data = True
+
def setup_mappers(self):
mapper(User, users, properties={
'addresses':relation(Address, backref='user'),
@@ -68,11 +69,8 @@ class GetTest(QueryTest):
s = create_session()
- try:
- s.query(User).join('addresses').filter(Address.user_id==8).get(7)
- assert False
- except exceptions.SAWarning, e:
- assert str(e) == "Query.get() being called on a Query with existing criterion; criterion is being ignored."
+ q = s.query(User).join('addresses').filter(Address.user_id==8)
+ self.assertRaises(sa_exc.SAWarning, q.get, 7)
@testing.emits_warning('Query.*')
def warns():
@@ -119,7 +117,7 @@ class GetTest(QueryTest):
try:
assert s.query(User).load(19) is None
assert False
- except exceptions.InvalidRequestError:
+ except sa_exc.InvalidRequestError:
assert True
u = s.query(User).load(7)
@@ -193,6 +191,29 @@ class GetTest(QueryTest):
assert u.addresses[0].email_address == 'jack@bean.com'
assert u.orders[1].items[2].description == 'item 5'
+class InvalidGenerationsTest(QueryTest):
+ def test_no_limit_offset(self):
+ s = create_session()
+
+ q = s.query(User).limit(2)
+ self.assertRaises(sa_exc.SAWarning, q.join, "addresses")
+
+ self.assertRaises(sa_exc.SAWarning, q.filter, User.name=='ed')
+
+ self.assertRaises(sa_exc.SAWarning, q.filter_by, name='ed')
+
+ def test_no_from(self):
+ s = create_session()
+
+ q = s.query(User).select_from(users)
+ self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users)
+
+ q = s.query(User).join('addresses')
+ self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users)
+
+ # this is fine, however
+ q.from_self()
+
class OperatorTest(QueryTest):
"""test sql.Comparator implementation for MapperProperties"""
@@ -268,8 +289,40 @@ class OperatorTest(QueryTest):
c = expr.compile(dialect=default.DefaultDialect())
assert str(c) == compare, "%s != %s" % (str(c), compare)
+class RawSelectTest(QueryTest, AssertsCompiledSQL):
+ """compare a bunch of select() tests with the equivalent Query using straight table/columns.
+
+ Results should be the same as Query should act as a select() pass-thru for ClauseElement entities.
+
+ """
+ def test_select(self):
+ sess = create_session()
+
+ self.assert_compile(sess.query(users).select_from(users.select()).with_labels().statement,
+ "SELECT users.id AS users_id, users.name AS users_name FROM users, (SELECT users.id AS id, users.name AS name FROM users) AS anon_1")
+
+ self.assert_compile(sess.query(users, exists([1], from_obj=addresses)).with_labels().statement,
+ "SELECT users.id AS users_id, users.name AS users_name, EXISTS (SELECT 1 FROM addresses) AS anon_1 FROM users")
+ # a little tedious here, adding labels to work around Query's auto-labelling.
+ # also correlate needed explicitly. hmmm.....
+ # TODO: can we detect only one table in the "froms" and then turn off use_labels ?
+ s = sess.query(addresses.c.id.label('id'), addresses.c.email_address.label('email')).\
+ filter(addresses.c.user_id==users.c.id).correlate(users).statement.alias()
+
+ self.assert_compile(sess.query(users, s.c.email).select_from(users.join(s, s.c.id==users.c.id)).with_labels().statement,
+ "SELECT users.id AS users_id, users.name AS users_name, anon_1.email AS anon_1_email "
+ "FROM users JOIN (SELECT addresses.id AS id, addresses.email_address AS email FROM addresses "
+ "WHERE addresses.user_id = users.id) AS anon_1 ON anon_1.id = users.id",
+ dialect=default.DefaultDialect()
+ )
+
+ x = func.lala(users.c.id).label('foo')
+ self.assert_compile(sess.query(x).filter(x==5).statement,
+ "SELECT lala(users.id) AS foo FROM users WHERE lala(users.id) = :param_1", dialect=default.DefaultDialect())
+
class CompileTest(QueryTest):
+
def test_deferred(self):
session = create_session()
s = session.query(User).filter(and_(addresses.c.email_address == bindparam('emailad'), Address.user_id==User.id)).compile()
@@ -324,7 +377,7 @@ class FilterTest(QueryTest):
try:
sess.query(User).filter(User.addresses == address)
assert False
- except exceptions.InvalidRequestError:
+ except sa_exc.InvalidRequestError:
assert True
assert [User(id=10)] == sess.query(User).filter(User.addresses==None).all()
@@ -332,7 +385,7 @@ class FilterTest(QueryTest):
try:
assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all()
assert False
- except exceptions.InvalidRequestError:
+ except sa_exc.InvalidRequestError:
assert True
#assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all()
@@ -348,33 +401,15 @@ class FilterTest(QueryTest):
filter(User.addresses.any(id=4)).all()
assert [User(id=9)] == sess.query(User).filter(User.addresses.any(email_address='fred@fred.com')).all()
-
- @testing.fails_on_everything_except()
- def test_broken_any_1(self):
- sess = create_session()
- # overcorrelates
+ # test that any() doesn't overcorrelate
assert [User(id=7), User(id=8)] == sess.query(User).join("addresses").filter(~User.addresses.any(Address.email_address=='fred@fred.com')).all()
-
- def test_broken_any_2(self):
- sess = create_session()
- # works, filter is before the join
- assert [User(id=7), User(id=8)] == sess.query(User).filter(~User.addresses.any(Address.email_address=='fred@fred.com')).join("addresses", aliased=True).all()
-
- def test_broken_any_3(self):
- sess = create_session()
-
- # works, filter is after the join, but reset_joinpoint is called, removing aliasing
- assert [User(id=7), User(id=8)] == sess.query(User).join("addresses", aliased=True).filter(Address.email_address != None).reset_joinpoint().filter(~User.addresses.any(email_address='fred@fred.com')).all()
+ # test that the contents are not adapted by the aliased join
+ assert [User(id=7), User(id=8)] == sess.query(User).join("addresses", aliased=True).filter(~User.addresses.any(Address.email_address=='fred@fred.com')).all()
- @testing.fails_on_everything_except()
- def test_broken_any_4(self):
- sess = create_session()
-
- # filter is after the join, gets aliased. in 0.5 any(), has() and not contains() are shielded from aliasing
assert [User(id=10)] == sess.query(User).outerjoin("addresses", aliased=True).filter(~User.addresses.any()).all()
-
+
@testing.unsupported('maxdb') # can core
def test_has(self):
sess = create_session()
@@ -384,6 +419,12 @@ class FilterTest(QueryTest):
assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'), id=8)).all()
+ # test has() doesn't overcorrelate
+ assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).join("user").filter(Address.user.has(User.name.like('%ed%'), id=8)).all()
+
+ # test has() doesnt' get subquery contents adapted by aliased join
+ assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).join("user", aliased=True).filter(Address.user.has(User.name.like('%ed%'), id=8)).all()
+
dingaling = sess.query(Dingaling).get(2)
assert [User(id=9)] == sess.query(User).filter(User.addresses.any(Address.dingaling==dingaling)).all()
@@ -457,23 +498,39 @@ class FromSelfTest(QueryTest):
(User(id=8), Address(id=4)),
(User(id=9), Address(id=5))
] == create_session().query(User).filter(User.id.in_([8,9]))._from_self().join('addresses').add_entity(Address).order_by(User.id, Address.id).all()
+
+ def test_multiple_entities(self):
+ sess = create_session()
+
+ self.assertEquals(
+ sess.query(User, Address).filter(User.id==Address.user_id).filter(Address.id.in_([2, 5]))._from_self().all(),
+ [
+ (User(id=8), Address(id=2)),
+ (User(id=9), Address(id=5))
+ ]
+ )
+
+ self.assertEquals(
+ sess.query(User, Address).filter(User.id==Address.user_id).filter(Address.id.in_([2, 5]))._from_self().options(eagerload('addresses')).first(),
+ (User(id=8, addresses=[Address(), Address(), Address()]), Address(id=2)),
+ )
class AggregateTest(QueryTest):
+
def test_sum(self):
sess = create_session()
orders = sess.query(Order).filter(Order.id.in_([2, 3, 4]))
assert orders.sum(Order.user_id * Order.address_id) == 79
- @testing.uses_deprecated('Call to deprecated function apply_sum')
def test_apply(self):
sess = create_session()
- assert sess.query(Order).apply_sum(Order.user_id * Order.address_id).filter(Order.id.in_([2, 3, 4])).one() == 79
+ assert sess.query(func.sum(Order.user_id * Order.address_id)).filter(Order.id.in_([2, 3, 4])).one() == (79,)
def test_having(self):
sess = create_session()
- assert [User(name=u'ed',id=8)] == sess.query(User).group_by([c for c in User.c]).join('addresses').having(func.count(Address.c.id)> 2).all()
+ assert [User(name=u'ed',id=8)] == sess.query(User).group_by(User).join('addresses').having(func.count(Address.id)> 2).all()
- assert [User(name=u'jack',id=7), User(name=u'fred',id=9)] == sess.query(User).group_by([c for c in User.c]).join('addresses').having(func.count(Address.c.id)< 2).all()
+ assert [User(name=u'jack',id=7), User(name=u'fred',id=9)] == sess.query(User).group_by(User).join('addresses').having(func.count(Address.id)< 2).all()
class CountTest(QueryTest):
def test_basic(self):
@@ -561,10 +618,16 @@ class ParentTest(QueryTest):
o = sess.query(Order).with_parent(u1, property='orders').all()
assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o
- # test static method
- o = Query.query_from_parent(u1, property='orders', session=sess).all()
+ o = sess.query(Order).filter(with_parent(u1, User.orders)).all()
assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o
-
+
+ # test static method
+ @testing.uses_deprecated(".*query_from_parent")
+ def go():
+ o = Query.query_from_parent(u1, property='orders', session=sess).all()
+ assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o
+ go()
+
# test generative criterion
o = sess.query(Order).with_parent(u1).filter(orders.c.id>2).all()
assert [Order(description="order 3"), Order(description="order 5")] == o
@@ -582,7 +645,7 @@ class ParentTest(QueryTest):
try:
q = sess.query(Item).with_parent(u1)
assert False
- except exceptions.InvalidRequestError, e:
+ except sa_exc.InvalidRequestError, e:
assert str(e) == "Could not locate a property which relates instances of class 'Item' to instances of class 'User'"
def test_m2m(self):
@@ -594,28 +657,6 @@ class ParentTest(QueryTest):
class JoinTest(QueryTest):
- def test_getjoinable_tables(self):
- sess = create_session()
-
- sel1 = select([users]).alias()
- sel2 = select([users], from_obj=users.join(addresses)).alias()
-
- j1 = sel1.join(users, sel1.c.id==users.c.id)
- j2 = j1.join(addresses)
-
- for from_obj, assert_cond in (
- (users, [users]),
- (users.join(addresses), [users, addresses]),
- (sel1, [sel1]),
- (sel2, [sel2]),
- (sel1.join(users, sel1.c.id==users.c.id), [sel1, users]),
- (sel2.join(users, sel2.c.id==users.c.id), [sel2, users]),
- (j2, [j1, j2, sel1, users, addresses])
-
- ):
- ret = set(sess.query(User).select_from(from_obj)._get_joinable_tables())
- self.assertEquals(ret, set(assert_cond).union([from_obj]), [x.description for x in ret])
-
def test_overlapping_paths(self):
for aliased in (True,False):
# load a user who has an order that contains item id 3 and address id 1 (order 3, owned by jack)
@@ -654,7 +695,34 @@ class JoinTest(QueryTest):
def test_orderby_arg_bug(self):
sess = create_session()
+ # no arg error
+ result = sess.query(User).join('orders', aliased=True).order_by([Order.id]).reset_joinpoint().order_by(users.c.id).all()
+
+ def test_no_onclause(self):
+ sess = create_session()
+
+ self.assertEquals(
+ sess.query(User).select_from(join(User, Order).join(Item, Order.items)).filter(Item.description == 'item 4').all(),
+ [User(name='jack')]
+ )
+
+ self.assertEquals(
+ sess.query(User).join(Order, (Item, Order.items)).filter(Item.description == 'item 4').all(),
+ [User(name='jack')]
+ )
+ def test_clause_onclause(self):
+ sess = create_session()
+
+ self.assertEquals(
+ sess.query(User).join(
+ (Order, User.id==Order.user_id),
+ (order_items, Order.id==order_items.c.order_id),
+ (Item, order_items.c.item_id==Item.id)
+ ).filter(Item.description == 'item 4').all(),
+ [User(name='jack')]
+ )
+
# no arg error
result = sess.query(User).join('orders', aliased=True).order_by([Order.id]).reset_joinpoint().order_by(users.c.id).all()
@@ -682,13 +750,43 @@ class JoinTest(QueryTest):
l = q.select_from(outerjoin(User, AdAlias)).filter(AdAlias.email_address=='ed@bettyboop.com').all()
self.assertEquals(l, [(user8, address3)])
-
l = q.select_from(outerjoin(User, AdAlias, 'addresses')).filter(AdAlias.email_address=='ed@bettyboop.com').all()
self.assertEquals(l, [(user8, address3)])
l = q.select_from(outerjoin(User, AdAlias, User.id==AdAlias.user_id)).filter(AdAlias.email_address=='ed@bettyboop.com').all()
self.assertEquals(l, [(user8, address3)])
+ # this is the first test where we are joining "backwards" - from AdAlias to User even though
+ # the query is against User
+ q = sess.query(User, AdAlias)
+ l = q.join(AdAlias.user).filter(User.name=='ed')
+ self.assertEquals(l.all(), [(user8, address2),(user8, address3),(user8, address4),])
+
+ q = sess.query(User, AdAlias).select_from(join(AdAlias, User, AdAlias.user)).filter(User.name=='ed')
+ self.assertEquals(l.all(), [(user8, address2),(user8, address3),(user8, address4),])
+
+ def test_implicit_joins_from_aliases(self):
+ sess = create_session()
+ OrderAlias = aliased(Order)
+
+ self.assertEquals(
+ sess.query(OrderAlias).join('items').filter_by(description='item 3').all(),
+ [
+ Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1),
+ Order(address_id=4,description=u'order 2',isopen=0,user_id=9,id=2),
+ Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3)
+ ]
+ )
+
+ self.assertEquals(
+ sess.query(User, OrderAlias, Item.description).join(('orders', OrderAlias), 'items').filter_by(description='item 3').all(),
+ [
+ (User(name=u'jack',id=7), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1), u'item 3'),
+ (User(name=u'jack',id=7), Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3), u'item 3'),
+ (User(name=u'fred',id=9), Order(address_id=4,description=u'order 2',isopen=0,user_id=9,id=2), u'item 3')
+ ]
+ )
+
def test_aliased_classes_m2m(self):
sess = create_session()
@@ -725,20 +823,6 @@ class JoinTest(QueryTest):
]
)
- def test_generative_join(self):
- # test that alised_ids is copied
- sess = create_session()
- q = sess.query(User).add_entity(Address)
- q1 = q.join('addresses', aliased=True)
- q2 = q.join('addresses', aliased=True)
- q3 = q2.join('addresses', aliased=True)
- q4 = q2.join('addresses', aliased=True, id='someid')
- q5 = q2.join('addresses', aliased=True, id='someid')
- q6 = q5.join('addresses', aliased=True, id='someid')
- assert q1._alias_ids[class_mapper(Address)] != q2._alias_ids[class_mapper(Address)]
- assert q2._alias_ids[class_mapper(Address)] != q3._alias_ids[class_mapper(Address)]
- assert q4._alias_ids['someid'] != q5._alias_ids['someid']
-
def test_reset_joinpoint(self):
for aliased in (True, False):
# load a user who has an order that contains item id 3 and address id 1 (order 3, owned by jack)
@@ -779,43 +863,19 @@ class JoinTest(QueryTest):
assert q.count() == 1
assert [User(id=7)] == q.all()
+
# test the control version - same joins but not aliased. rows are not returned because order 3 does not have item 1
- # addtionally by placing this test after the previous one, test that the "aliasing" step does not corrupt the
- # join clauses that are cached by the relationship.
- q = sess.query(User).join('orders').filter(Order.description=="order 3").join(['orders', 'items']).filter(Order.description=="item 1")
+ q = sess.query(User).join('orders').filter(Order.description=="order 3").join(['orders', 'items']).filter(Item.description=="item 1")
assert [] == q.all()
assert q.count() == 0
q = sess.query(User).join('orders', aliased=True).filter(Order.items.any(Item.description=='item 4'))
assert [User(id=7)] == q.all()
-
- def test_aliased_add_entity(self):
- """test the usage of aliased joins with add_entity()"""
- sess = create_session()
- q = sess.query(User).join('orders', aliased=True, id='order1').filter(Order.description=="order 3").join(['orders', 'items'], aliased=True, id='item1').filter(Item.description=="item 1")
-
- try:
- q.add_entity(Order, id='fakeid').compile()
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Query has no alias identified by 'fakeid'"
-
- try:
- q.add_entity(Order, id='fakeid').instances(None)
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Query has no alias identified by 'fakeid'"
-
- q = q.add_entity(Order, id='order1').add_entity(Item, id='item1')
+
+ # test that aliasing gets reset when join() is called
+ q = sess.query(User).join('orders', aliased=True).filter(Order.description=="order 3").join('orders', aliased=True).filter(Order.description=="order 5")
assert q.count() == 1
- assert [(User(id=7), Order(description='order 3'), Item(description='item 1'))] == q.all()
-
- q = sess.query(User).add_entity(Order).join('orders', aliased=True).filter(Order.description=="order 3").join('orders', aliased=True).filter(Order.description=='order 4')
- try:
- q.compile()
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Ambiguous join for entity 'Mapper|Order|orders'; specify id=<someid> to query.join()/query.add_entity()"
+ assert [User(id=7)] == q.all()
class MultiplePathTest(ORMTest):
def define_tables(self, metadata):
@@ -849,11 +909,10 @@ class MultiplePathTest(ORMTest):
})
mapper(T2, t2)
- try:
- create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint().join('t2s_2')
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Can't join to property 't2s_2'; a path to this table along a different secondary table already exists. Use the `alias=True` argument to `join()`."
+ q = create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint()
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "a path to this table along a different secondary table already exists.",
+ q.join, 't2s_2'
+ )
create_session().query(T1).join('t2s_1', aliased=True).filter(t2.c.id==5).reset_joinpoint().join('t2s_2').all()
create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint().join('t2s_2', aliased=True).all()
@@ -926,26 +985,34 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
assert fixtures.user_address_result == l
self.assert_sql_count(testing.db, go, 1)
+ # better way. use select_from()
+ def go():
+ l = sess.query(User).select_from(query).options(contains_eager('addresses')).all()
+ assert fixtures.user_address_result == l
+ self.assert_sql_count(testing.db, go, 1)
+
def test_contains_eager(self):
sess = create_session()
+ # test that contains_eager suppresses the normal outer join rendering
q = sess.query(User).outerjoin(User.addresses).options(contains_eager(User.addresses))
- self.assert_compile(q.statement, "SELECT addresses.id AS addresses_id, addresses.user_id AS addresses_user_id, "
- "addresses.email_address AS addresses_email_address, users.id AS users_id, users.name AS users_name "\
- "FROM users LEFT OUTER JOIN addresses ON users.id = addresses.user_id ORDER BY users.id", dialect=default.DefaultDialect())
-
+ self.assert_compile(q.with_labels().statement, "SELECT users.id AS users_id, users.name AS users_name, "\
+ "addresses.id AS addresses_id, addresses.user_id AS addresses_user_id, "\
+ "addresses.email_address AS addresses_email_address FROM users LEFT OUTER JOIN addresses "\
+ "ON users.id = addresses.user_id ORDER BY users.id", dialect=default.DefaultDialect())
+
def go():
assert fixtures.user_address_result == q.all()
self.assert_sql_count(testing.db, go, 1)
sess.clear()
-
+
adalias = addresses.alias()
q = sess.query(User).select_from(users.outerjoin(adalias)).options(contains_eager(User.addresses, alias=adalias))
def go():
assert fixtures.user_address_result == q.all()
self.assert_sql_count(testing.db, go, 1)
sess.clear()
-
+
selectquery = users.outerjoin(addresses).select(users.c.id<10, use_labels=True, order_by=[users.c.id, addresses.c.id])
q = sess.query(User)
@@ -956,6 +1023,13 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
sess.clear()
+
+ def go():
+ l = q.options(contains_eager(User.addresses)).instances(selectquery.execute())
+ assert fixtures.user_address_result[0:3] == l
+ self.assert_sql_count(testing.db, go, 1)
+ sess.clear()
+
def go():
l = q.options(contains_eager('addresses')).from_statement(selectquery).all()
assert fixtures.user_address_result[0:3] == l
@@ -966,38 +1040,34 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.id, adalias.c.id])
sess = create_session()
q = sess.query(User)
-
+
+ # string alias name
def go():
- # test using a string alias name
l = q.options(contains_eager('addresses', alias="adalias")).instances(selectquery.execute())
assert fixtures.user_address_result == l
self.assert_sql_count(testing.db, go, 1)
sess.clear()
+ # expression.Alias object
def go():
- # test using the Alias object itself
l = q.options(contains_eager('addresses', alias=adalias)).instances(selectquery.execute())
assert fixtures.user_address_result == l
self.assert_sql_count(testing.db, go, 1)
sess.clear()
- def decorate(row):
- d = {}
- for c in addresses.c:
- d[c] = row[adalias.corresponding_column(c)]
- return d
-
+ # Aliased object
+ adalias = aliased(Address)
def go():
- # test using a custom 'decorate' function
- l = q.options(contains_eager('addresses', decorator=decorate)).instances(selectquery.execute())
- assert fixtures.user_address_result == l
+ l = q.options(contains_eager('addresses', alias=adalias)).outerjoin((adalias, User.addresses)).order_by(User.id, adalias.id)
+ assert fixtures.user_address_result == l.all()
self.assert_sql_count(testing.db, go, 1)
sess.clear()
+
oalias = orders.alias('o1')
ialias = items.alias('i1')
- query = users.outerjoin(oalias).outerjoin(order_items).outerjoin(ialias).select(use_labels=True).order_by(users.c.id).order_by(oalias.c.id).order_by(ialias.c.id)
+ query = users.outerjoin(oalias).outerjoin(order_items).outerjoin(ialias).select(use_labels=True).order_by(users.c.id, oalias.c.id, ialias.c.id)
q = create_session().query(User)
# test using string alias with more than one level deep
def go():
@@ -1014,9 +1084,24 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
self.assert_sql_count(testing.db, go, 1)
sess.clear()
+ # test using Aliased with more than one level deep
+ oalias = aliased(Order)
+ ialias = aliased(Item)
+ def go():
+ l = q.options(contains_eager(User.orders, alias=oalias), contains_eager(User.orders, Order.items, alias=ialias)).\
+ outerjoin((oalias, User.orders), (ialias, Order.items)).order_by(User.id, oalias.id, ialias.id)
+ assert fixtures.user_order_result == l.all()
+ self.assert_sql_count(testing.db, go, 1)
+ sess.clear()
+
+
+class MixedEntitiesTest(QueryTest):
+
def test_values(self):
sess = create_session()
+ assert list(sess.query(User).values()) == list()
+
sel = users.select(User.id.in_([7, 8])).alias()
q = sess.query(User)
q2 = q.select_from(sel).values(User.name)
@@ -1035,19 +1120,166 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(desc(Address.email_address))[1:3].values(User.name, Address.email_address)
self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@lala.com')])
- q2 = q.join('addresses', aliased=True).filter(User.name.like('%e%')).values(User.name, Address.email_address)
+ adalias = aliased(Address)
+ q2 = q.join(('addresses', adalias)).filter(User.name.like('%e%')).values(User.name, adalias.email_address)
self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')])
q2 = q.values(func.count(User.name))
assert q2.next() == (4,)
- u2 = users.alias()
- q2 = q.select_from(sel).filter(u2.c.id>1).order_by([users.c.id, sel.c.id, u2.c.id]).values(users.c.name, sel.c.name, u2.c.name)
+ u2 = aliased(User)
+ q2 = q.select_from(sel).filter(u2.id>1).order_by([User.id, sel.c.id, u2.id]).values(User.name, sel.c.name, u2.name)
self.assertEquals(list(q2), [(u'jack', u'jack', u'jack'), (u'jack', u'jack', u'ed'), (u'jack', u'jack', u'fred'), (u'jack', u'jack', u'chuck'), (u'ed', u'ed', u'jack'), (u'ed', u'ed', u'ed'), (u'ed', u'ed', u'fred'), (u'ed', u'ed', u'chuck')])
- q2 = q.select_from(sel).filter(users.c.id>1).values(users.c.name, sel.c.name, User.name)
- self.assertEquals(list(q2), [(u'jack', u'jack', u'jack'), (u'ed', u'ed', u'ed')])
+ q2 = q.select_from(sel).filter(User.id==8).values(User.name, sel.c.name, User.name)
+ self.assertEquals(list(q2), [(u'ed', u'ed', u'ed')])
+
+ # using User.xxx is alised against "sel", so this query returns nothing
+ q2 = q.select_from(sel).filter(User.id==8).filter(User.id>sel.c.id).values(User.name, sel.c.name, User.name)
+ self.assertEquals(list(q2), [])
+
+ # whereas this uses users.c.xxx, is not aliased and creates a new join
+ q2 = q.select_from(sel).filter(users.c.id==8).filter(users.c.id>sel.c.id).values(users.c.name, sel.c.name, User.name)
+ self.assertEquals(list(q2), [(u'ed', u'jack', u'jack')])
+ def test_tuple_labeling(self):
+ sess = create_session()
+ for row in sess.query(User, Address).join(User.addresses).all():
+ self.assertEquals(set(row.keys()), set(['User', 'Address']))
+ self.assertEquals(row.User, row[0])
+ self.assertEquals(row.Address, row[1])
+
+ for row in sess.query(User.name, User.id.label('foobar')):
+ self.assertEquals(set(row.keys()), set(['name', 'foobar']))
+ self.assertEquals(row.name, row[0])
+ self.assertEquals(row.foobar, row[1])
+
+ for row in sess.query(User).values(User.name, User.id.label('foobar')):
+ self.assertEquals(set(row.keys()), set(['name', 'foobar']))
+ self.assertEquals(row.name, row[0])
+ self.assertEquals(row.foobar, row[1])
+
+ oalias = aliased(Order)
+ for row in sess.query(User, oalias).join(User.orders).all():
+ self.assertEquals(set(row.keys()), set(['User']))
+ self.assertEquals(row.User, row[0])
+
+ oalias = aliased(Order, name='orders')
+ for row in sess.query(User, oalias).join(User.orders).all():
+ self.assertEquals(set(row.keys()), set(['User', 'orders']))
+ self.assertEquals(row.User, row[0])
+ self.assertEquals(row.orders, row[1])
+
+
+ def test_column_queries(self):
+ sess = create_session()
+
+ self.assertEquals(sess.query(User.name).all(), [(u'jack',), (u'ed',), (u'fred',), (u'chuck',)])
+
+ sel = users.select(User.id.in_([7, 8])).alias()
+ q = sess.query(User.name)
+ q2 = q.select_from(sel).all()
+ self.assertEquals(list(q2), [(u'jack',), (u'ed',)])
+
+ self.assertEquals(sess.query(User.name, Address.email_address).filter(User.id==Address.user_id).all(), [
+ (u'jack', u'jack@bean.com'), (u'ed', u'ed@wood.com'),
+ (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'),
+ (u'fred', u'fred@fred.com')
+ ])
+
+ self.assertEquals(sess.query(User.name, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User.id, User.name).order_by(User.id).all(),
+ [(u'jack', 1), (u'ed', 3), (u'fred', 1), (u'chuck', 0)]
+ )
+
+ self.assertEquals(sess.query(User, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User).order_by(User.id).all(),
+ [(User(name='jack',id=7), 1), (User(name='ed',id=8), 3), (User(name='fred',id=9), 1), (User(name='chuck',id=10), 0)]
+ )
+
+ self.assertEquals(sess.query(func.count(Address.email_address), User).outerjoin(User.addresses).group_by(User).order_by(User.id).all(),
+ [(1, User(name='jack',id=7)), (3, User(name='ed',id=8)), (1, User(name='fred',id=9)), (0, User(name='chuck',id=10))]
+ )
+
+ adalias = aliased(Address)
+ self.assertEquals(sess.query(User, func.count(adalias.email_address)).outerjoin(('addresses', adalias)).group_by(User).order_by(User.id).all(),
+ [(User(name='jack',id=7), 1), (User(name='ed',id=8), 3), (User(name='fred',id=9), 1), (User(name='chuck',id=10), 0)]
+ )
+
+ self.assertEquals(sess.query(func.count(adalias.email_address), User).outerjoin((User.addresses, adalias)).group_by(User).order_by(User.id).all(),
+ [(1, User(name=u'jack',id=7)), (3, User(name=u'ed',id=8)), (1, User(name=u'fred',id=9)), (0, User(name=u'chuck',id=10))]
+ )
+
+ # select from aliasing + explicit aliasing
+ self.assertEquals(
+ sess.query(User, adalias.email_address, adalias.id).outerjoin((User.addresses, adalias)).from_self(User, adalias.email_address).order_by(User.id, adalias.id).all(),
+ [
+ (User(name=u'jack',id=7), u'jack@bean.com'),
+ (User(name=u'ed',id=8), u'ed@wood.com'),
+ (User(name=u'ed',id=8), u'ed@bettyboop.com'),
+ (User(name=u'ed',id=8), u'ed@lala.com'),
+ (User(name=u'fred',id=9), u'fred@fred.com'),
+ (User(name=u'chuck',id=10), None)
+ ]
+ )
+
+ # anon + select from aliasing
+ self.assertEquals(
+ sess.query(User).join(User.addresses, aliased=True).filter(Address.email_address.like('%ed%')).from_self().all(),
+ [
+ User(name=u'ed',id=8),
+ User(name=u'fred',id=9),
+ ]
+ )
+
+ # test eager aliasing, with/without select_from aliasing
+ for q in [
+ sess.query(User, adalias.email_address).outerjoin((User.addresses, adalias)).options(eagerload(User.addresses)).order_by(User.id, adalias.id).limit(10),
+ sess.query(User, adalias.email_address, adalias.id).outerjoin((User.addresses, adalias)).from_self(User, adalias.email_address).options(eagerload(User.addresses)).order_by(User.id, adalias.id).limit(10),
+ ]:
+ self.assertEquals(
+ q.all(),
+ [(User(addresses=[Address(user_id=7,email_address=u'jack@bean.com',id=1)],name=u'jack',id=7), u'jack@bean.com'),
+ (User(addresses=[
+ Address(user_id=8,email_address=u'ed@wood.com',id=2),
+ Address(user_id=8,email_address=u'ed@bettyboop.com',id=3),
+ Address(user_id=8,email_address=u'ed@lala.com',id=4)],name=u'ed',id=8), u'ed@wood.com'),
+ (User(addresses=[
+ Address(user_id=8,email_address=u'ed@wood.com',id=2),
+ Address(user_id=8,email_address=u'ed@bettyboop.com',id=3),
+ Address(user_id=8,email_address=u'ed@lala.com',id=4)],name=u'ed',id=8), u'ed@bettyboop.com'),
+ (User(addresses=[
+ Address(user_id=8,email_address=u'ed@wood.com',id=2),
+ Address(user_id=8,email_address=u'ed@bettyboop.com',id=3),
+ Address(user_id=8,email_address=u'ed@lala.com',id=4)],name=u'ed',id=8), u'ed@lala.com'),
+ (User(addresses=[Address(user_id=9,email_address=u'fred@fred.com',id=5)],name=u'fred',id=9), u'fred@fred.com'),
+
+ (User(addresses=[],name=u'chuck',id=10), None)]
+ )
+
+ def test_self_referential(self):
+
+ sess = create_session()
+ oalias = aliased(Order)
+
+ for q in [
+ sess.query(Order, oalias).filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id),
+ sess.query(Order, oalias)._from_self().filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id),
+ # here we go....two layers of aliasing
+ sess.query(Order, oalias).filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id)._from_self().order_by(Order.id, oalias.id).limit(10).options(eagerload(Order.items)),
+
+ # gratuitous four layers
+ sess.query(Order, oalias).filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id)._from_self()._from_self()._from_self().order_by(Order.id, oalias.id).limit(10).options(eagerload(Order.items)),
+
+ ]:
+
+ self.assertEquals(
+ q.all(),
+ [
+ (Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1)),
+ (Order(address_id=None,description=u'order 5',isopen=0,user_id=7,id=5), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1)),
+ (Order(address_id=None,description=u'order 5',isopen=0,user_id=7,id=5), Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3))
+ ]
+ )
+
def test_multi_mappers(self):
test_session = create_session()
@@ -1055,7 +1287,6 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
(user7, user8, user9, user10) = test_session.query(User).all()
(address1, address2, address3, address4, address5) = test_session.query(Address).all()
- # note the result is a cartesian product
expected = [(user7, address1),
(user8, address2),
(user8, address3),
@@ -1066,30 +1297,24 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
sess = create_session()
selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.id, addresses.c.id])
- q = sess.query(User)
- l = q.instances(selectquery.execute(), Address)
- assert l == expected
-
+ self.assertEquals(sess.query(User, Address).instances(selectquery.execute()), expected)
sess.clear()
- for aliased in (False, True):
- q = sess.query(User)
-
- q = q.add_entity(Address).outerjoin('addresses', aliased=aliased)
- l = q.all()
- assert l == expected
+ for address_entity in (Address, aliased(Address)):
+ q = sess.query(User).add_entity(address_entity).outerjoin(('addresses', address_entity)).order_by(User.id, address_entity.id)
+ self.assertEquals(q.all(), expected)
sess.clear()
- q = sess.query(User).add_entity(Address)
- l = q.join('addresses', aliased=aliased).filter_by(email_address='ed@bettyboop.com').all()
- assert l == [(user8, address3)]
+ q = sess.query(User).add_entity(address_entity)
+ q = q.join(('addresses', address_entity)).filter_by(email_address='ed@bettyboop.com')
+ self.assertEquals(q.all(), [(user8, address3)])
sess.clear()
- q = sess.query(User, Address).join('addresses', aliased=aliased).filter_by(email_address='ed@bettyboop.com')
- assert q.all() == [(user8, address3)]
+ q = sess.query(User, address_entity).join(('addresses', address_entity)).filter_by(email_address='ed@bettyboop.com')
+ self.assertEquals(q.all(), [(user8, address3)])
sess.clear()
- q = sess.query(User, Address).join('addresses', aliased=aliased).options(eagerload('addresses')).filter_by(email_address='ed@bettyboop.com')
+ q = sess.query(User, address_entity).join(('addresses', address_entity)).options(eagerload('addresses')).filter_by(email_address='ed@bettyboop.com')
self.assertEquals(list(util.OrderedSet(q.all())), [(user8, address3)])
sess.clear()
@@ -1123,18 +1348,12 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
expected = [(u, u.name) for u in sess.query(User).all()]
- for add_col in (User.name, users.c.name, User.c.name):
+ for add_col in (User.name, users.c.name):
assert sess.query(User).add_column(add_col).all() == expected
sess.clear()
- self.assertRaises(exceptions.InvalidRequestError, sess.query(User).add_column, object())
+ self.assertRaises(sa_exc.InvalidRequestError, sess.query(User).add_column, object())
- def test_ambiguous_column(self):
- sess = create_session()
-
- q = sess.query(User).join('addresses', aliased=True).join('addresses', aliased=True).add_column(Address.id)
- self.assertRaises(exceptions.InvalidRequestError, iter, q)
-
def test_multi_columns_2(self):
"""test aliased/nonalised joins with the usage of add_column()"""
sess = create_session()
@@ -1146,12 +1365,16 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
(user10, 0)
]
- for aliased in (False, True):
- q = sess.query(User)
- q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses', aliased=aliased).add_column(func.count(Address.id).label('count'))
- l = q.all()
- assert l == expected
- sess.clear()
+ q = sess.query(User)
+ q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses').add_column(func.count(Address.id).label('count'))
+ self.assertEquals(q.all(), expected)
+ sess.clear()
+
+ adalias = aliased(Address)
+ q = sess.query(User)
+ q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin(('addresses', adalias)).add_column(func.count(adalias.id).label('count'))
+ self.assertEquals(q.all(), expected)
+ sess.clear()
s = select([users, func.count(addresses.c.id).label('count')]).select_from(users.outerjoin(addresses)).group_by(*[c for c in users.c]).order_by(User.id)
q = sess.query(User)
@@ -1159,7 +1382,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
assert l == expected
- def test_two_columns(self):
+ def test_raw_columns(self):
sess = create_session()
(user7, user8, user9, user10) = sess.query(User).all()
expected = [
@@ -1168,8 +1391,9 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
(user9, 1, "Name:fred"),
(user10, 0, "Name:chuck")]
- q = create_session().query(User).add_column(func.count(addresses.c.id))\
- .add_column(("Name:" + users.c.name)).outerjoin('addresses', aliased=True)\
+ adalias = addresses.alias()
+ q = create_session().query(User).add_column(func.count(adalias.c.id))\
+ .add_column(("Name:" + users.c.name)).outerjoin(('addresses', adalias))\
.group_by([c for c in users.c]).order_by(users.c.id)
assert q.all() == expected
@@ -1190,14 +1414,19 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
assert q.all() == expected
sess.clear()
- # test with outerjoin() both aliased and non
- for aliased in (False, True):
- q = create_session().query(User).add_column(func.count(addresses.c.id))\
- .add_column(("Name:" + users.c.name)).outerjoin('addresses', aliased=aliased)\
- .group_by([c for c in users.c]).order_by(users.c.id)
+ q = create_session().query(User).add_column(func.count(addresses.c.id))\
+ .add_column(("Name:" + users.c.name)).outerjoin('addresses')\
+ .group_by([c for c in users.c]).order_by(users.c.id)
- assert q.all() == expected
- sess.clear()
+ assert q.all() == expected
+ sess.clear()
+
+ q = create_session().query(User).add_column(func.count(adalias.c.id))\
+ .add_column(("Name:" + users.c.name)).outerjoin(('addresses', adalias))\
+ .group_by([c for c in users.c]).order_by(users.c.id)
+
+ assert q.all() == expected
+ sess.clear()
class SelectFromTest(QueryTest):
@@ -1217,7 +1446,7 @@ class SelectFromTest(QueryTest):
self.assertEquals(sess.query(User).select_from(sel).all(), [User(id=7), User(id=8)])
- self.assertEquals(sess.query(User).select_from(sel).filter(User.c.id==8).all(), [User(id=8)])
+ self.assertEquals(sess.query(User).select_from(sel).filter(User.id==8).all(), [User(id=8)])
self.assertEquals(sess.query(User).select_from(sel).order_by(desc(User.name)).all(), [
User(name='jack',id=7), User(name='ed',id=8)
@@ -1273,7 +1502,8 @@ class SelectFromTest(QueryTest):
]
)
- self.assertEquals(sess.query(User).select_from(sel).join('addresses', aliased=True).add_entity(Address).order_by(User.id).order_by(Address.id).all(),
+ adalias = aliased(Address)
+ self.assertEquals(sess.query(User).select_from(sel).join(('addresses', adalias)).add_entity(adalias).order_by(User.id).order_by(adalias.id).all(),
[
(User(name='jack',id=7), Address(user_id=7,email_address='jack@bean.com',id=1)),
(User(name='ed',id=8), Address(user_id=8,email_address='ed@wood.com',id=2)),
@@ -1297,12 +1527,15 @@ class SelectFromTest(QueryTest):
sel = users.select(users.c.id.in_([7, 8]))
sess = create_session()
+
+ # TODO: remove
+ sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all()
- self.assertEquals(sess.query(User).select_from(sel).join(['orders', 'items', 'keywords']).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
+ self.assertEquals(sess.query(User).select_from(sel).join('orders', 'items', 'keywords').filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
User(name=u'jack',id=7)
])
- self.assertEquals(sess.query(User).select_from(sel).join(['orders', 'items', 'keywords'], aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
+ self.assertEquals(sess.query(User).select_from(sel).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
User(name=u'jack',id=7)
])
@@ -1355,7 +1588,7 @@ class SelectFromTest(QueryTest):
sess.clear()
def go():
- self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).filter(User.c.id==8).all(),
+ self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).filter(User.id==8).all(),
[User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])]
)
self.assert_sql_count(testing.db, go, 1)
@@ -1364,7 +1597,7 @@ class SelectFromTest(QueryTest):
def go():
self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel)[1], User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)]))
self.assert_sql_count(testing.db, go, 1)
-
+
class CustomJoinTest(QueryTest):
keep_mappers = False
@@ -1428,6 +1661,10 @@ class SelfReferentialTest(ORMTest):
node = sess.query(Node).join('children', aliased=True).filter_by(data='n122').first()
assert node.data=='n12'
+ ret = sess.query(Node.data).join(Node.children, aliased=True).filter_by(data='n122').all()
+ assert ret == [('n12',)]
+
+
node = sess.query(Node).join(['children', 'children'], aliased=True).filter_by(data='n122').first()
assert node.data=='n1'
@@ -1461,10 +1698,66 @@ class SelfReferentialTest(ORMTest):
list(sess.query(Node).select_from(join(Node, n1, 'parent').join(n2, 'parent')).\
filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).values(Node.data, n1.data, n2.data)),
[('n122', 'n12', 'n1')])
+
+ def test_join_to_nonaliased(self):
+ sess = create_session()
- def test_any(self):
+ n1 = aliased(Node)
+
+ # using 'n1.parent' implicitly joins to unaliased Node
+ self.assertEquals(
+ sess.query(n1).join(n1.parent).filter(Node.data=='n1').all(),
+ [Node(parent_id=1,data=u'n11',id=2), Node(parent_id=1,data=u'n12',id=3), Node(parent_id=1,data=u'n13',id=4)]
+ )
+
+ # explicit (new syntax)
+ self.assertEquals(
+ sess.query(n1).join((Node, n1.parent)).filter(Node.data=='n1').all(),
+ [Node(parent_id=1,data=u'n11',id=2), Node(parent_id=1,data=u'n12',id=3), Node(parent_id=1,data=u'n13',id=4)]
+ )
+
+ def test_multiple_explicit_entities(self):
sess = create_session()
+ parent = aliased(Node)
+ grandparent = aliased(Node)
+ self.assertEquals(
+ sess.query(Node, parent, grandparent).\
+ join((Node.parent, parent), (parent.parent, grandparent)).\
+ filter(Node.data=='n122').filter(parent.data=='n12').\
+ filter(grandparent.data=='n1').first(),
+ (Node(data='n122'), Node(data='n12'), Node(data='n1'))
+ )
+
+ self.assertEquals(
+ sess.query(Node, parent, grandparent).\
+ join((Node.parent, parent), (parent.parent, grandparent)).\
+ filter(Node.data=='n122').filter(parent.data=='n12').\
+ filter(grandparent.data=='n1')._from_self().first(),
+ (Node(data='n122'), Node(data='n12'), Node(data='n1'))
+ )
+
+ self.assertEquals(
+ sess.query(Node, parent, grandparent).\
+ join((Node.parent, parent), (parent.parent, grandparent)).\
+ filter(Node.data=='n122').filter(parent.data=='n12').\
+ filter(grandparent.data=='n1').\
+ options(eagerload(Node.children)).first(),
+ (Node(data='n122'), Node(data='n12'), Node(data='n1'))
+ )
+
+ self.assertEquals(
+ sess.query(Node, parent, grandparent).\
+ join((Node.parent, parent), (parent.parent, grandparent)).\
+ filter(Node.data=='n122').filter(parent.data=='n12').\
+ filter(grandparent.data=='n1')._from_self().\
+ options(eagerload(Node.children)).first(),
+ (Node(data='n122'), Node(data='n12'), Node(data='n1'))
+ )
+
+
+ def test_any(self):
+ sess = create_session()
self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), [])
self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')])
self.assertEquals(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),])
@@ -1561,6 +1854,8 @@ class SelfReferentialM2MTest(ORMTest):
)
class ExternalColumnsTest(QueryTest):
+ """test mappers with SQL-expressions added as column properties."""
+
keep_mappers = False
def setup_mappers(self):
@@ -1568,15 +1863,11 @@ class ExternalColumnsTest(QueryTest):
def test_external_columns_bad(self):
- self.assertRaisesMessage(exceptions.ArgumentError, "not represented in mapper's table", mapper, User, users, properties={
+ self.assertRaisesMessage(sa_exc.ArgumentError, "not represented in mapper's table", mapper, User, users, properties={
'concat': (users.c.id * 2),
})
clear_mappers()
- self.assertRaisesMessage(exceptions.ArgumentError, "must be given a ColumnElement as its argument.", column_property,
- select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).correlate(users)
- )
-
def test_external_columns_good(self):
"""test querying mappings that reference external columns or selectables."""
@@ -1586,19 +1877,21 @@ class ExternalColumnsTest(QueryTest):
})
mapper(Address, addresses, properties={
- 'user':relation(User, lazy=True)
+ 'user':relation(User)
})
sess = create_session()
-
- l = sess.query(User).all()
- assert [
- User(id=7, concat=14, count=1),
- User(id=8, concat=16, count=3),
- User(id=9, concat=18, count=1),
- User(id=10, concat=20, count=0),
- ] == l
+ sess.query(Address).options(eagerload('user')).all()
+
+ self.assertEquals(sess.query(User).all(),
+ [
+ User(id=7, concat=14, count=1),
+ User(id=8, concat=16, count=3),
+ User(id=9, concat=18, count=1),
+ User(id=10, concat=20, count=0),
+ ]
+ )
address_result = [
Address(id=1, user=User(id=7, concat=14, count=1)),
@@ -1617,15 +1910,24 @@ class ExternalColumnsTest(QueryTest):
self.assertEquals(sess.query(Address).options(eagerload('user')).all(), address_result)
self.assert_sql_count(testing.db, go, 1)
- tuple_address_result = [(address, address.user) for address in address_result]
-
- q =sess.query(Address).join('user', aliased=True, id='ualias').join('user', aliased=True).add_column(User.concat)
- self.assertRaisesMessage(exceptions.InvalidRequestError, "Ambiguous", q.all)
-
- self.assertEquals(sess.query(Address).join('user', aliased=True, id='ualias').add_entity(User, id='ualias').all(), tuple_address_result)
+ ualias = aliased(User)
+ self.assertEquals(
+ sess.query(Address, ualias).join(('user', ualias)).all(),
+ [(address, address.user) for address in address_result]
+ )
- self.assertEquals(sess.query(Address).join('user', aliased=True, id='ualias').join('user', aliased=True).\
- add_column(User.concat, id='ualias').add_column(User.count, id='ualias').all(),
+ self.assertEquals(
+ sess.query(Address, ualias.count).join(('user', ualias)).join('user', aliased=True).all(),
+ [
+ (Address(id=1), 1),
+ (Address(id=2), 3),
+ (Address(id=3), 3),
+ (Address(id=4), 3),
+ (Address(id=5), 1)
+ ]
+ )
+
+ self.assertEquals(sess.query(Address, ualias.concat, ualias.count).join(('user', ualias)).join('user', aliased=True).all(),
[
(Address(id=1), 14, 1),
(Address(id=2), 16, 3),
@@ -1635,15 +1937,21 @@ class ExternalColumnsTest(QueryTest):
]
)
- self.assertEquals(list(sess.query(Address).join('user').values(Address.id, User.id, User.concat, User.count)),
- [(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)]
+ ua = aliased(User)
+ self.assertEquals(sess.query(Address, ua.concat, ua.count).select_from(join(Address, ua, 'user')).options(eagerload(Address.user)).all(),
+ [
+ (Address(id=1, user=User(id=7, concat=14, count=1)), 14, 1),
+ (Address(id=2, user=User(id=8, concat=16, count=3)), 16, 3),
+ (Address(id=3, user=User(id=8, concat=16, count=3)), 16, 3),
+ (Address(id=4, user=User(id=8, concat=16, count=3)), 16, 3),
+ (Address(id=5, user=User(id=9, concat=18, count=1)), 18, 1)
+ ]
)
- self.assertEquals(list(sess.query(Address).join('user', aliased=True).values(Address.id, User.id, User.concat, User.count)),
+ self.assertEquals(list(sess.query(Address).join('user').values(Address.id, User.id, User.concat, User.count)),
[(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)]
)
- ua = aliased(User)
self.assertEquals(list(sess.query(Address, ua).select_from(join(Address,ua, 'user')).values(Address.id, ua.id, ua.concat, ua.count)),
[(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)]
)
diff --git a/test/orm/relationships.py b/test/orm/relationships.py
index 40773f835..b33684e2f 100644
--- a/test/orm/relationships.py
+++ b/test/orm/relationships.py
@@ -1,9 +1,9 @@
import testenv; testenv.configure_for_tests()
import datetime
from sqlalchemy import *
-from sqlalchemy import exceptions, types
+from sqlalchemy import exc as sa_exc, types
from sqlalchemy.orm import *
-from sqlalchemy.orm import collections
+from sqlalchemy.orm import collections, attributes, exc as orm_exc
from sqlalchemy.orm.collections import collection
from testlib import *
from testlib import fixtures
@@ -278,7 +278,13 @@ class RelationTest3(TestBase):
self.pagename = pagename
self.currentversion = PageVersion(self, 1)
def __repr__(self):
- return "Page jobno:%s pagename:%s %s" % (self.jobno, self.pagename, getattr(self, '_instance_key', None))
+ try:
+ state = attributes.instance_state(self)
+ key = state.key
+ except (KeyError, AttributeError):
+ key = None
+ return ("Page jobno:%s pagename:%s %s" %
+ (self.jobno, self.pagename, key))
def add_version(self):
self.currentversion = PageVersion(self, self.currentversion.version+1)
comment = self.add_comment()
@@ -393,7 +399,7 @@ class RelationTest4(ORMTest):
try:
sess.flush()
assert False
- except exceptions.AssertionError, e:
+ except AssertionError, e:
assert str(e).startswith("Dependency rule tried to blank-out primary key column 'B.id' on instance ")
def test_no_delete_PK_BtoA(self):
@@ -413,7 +419,7 @@ class RelationTest4(ORMTest):
try:
sess.flush()
assert False
- except exceptions.AssertionError, e:
+ except AssertionError, e:
assert str(e).startswith("Dependency rule tried to blank-out primary key column 'B.id' on instance ")
@testing.fails_on_everything_except('sqlite', 'mysql')
@@ -627,7 +633,7 @@ class TypeMatchTest(ORMTest):
try:
sess.save(a1)
assert False
- except exceptions.AssertionError, err:
+ except AssertionError, err:
assert str(err) == "Attribute 'bs' on class '%s' doesn't handle objects of type '%s'" % (A, C)
def test_o2m_onflush(self):
class A(object):pass
@@ -646,11 +652,8 @@ class TypeMatchTest(ORMTest):
sess.save(a1)
sess.save(b1)
sess.save(c1)
- try:
- sess.flush()
- assert False
- except exceptions.FlushError, err:
- assert str(err).startswith("Attempting to flush an item of type %s on collection 'A.bs (B)', which is handled by mapper 'Mapper|B|b' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship ?" % C)
+ self.assertRaisesMessage(orm_exc.FlushError, "Attempting to flush an item", sess.flush)
+
def test_o2m_nopoly_onflush(self):
class A(object):pass
class B(object):pass
@@ -668,11 +671,7 @@ class TypeMatchTest(ORMTest):
sess.save(a1)
sess.save(b1)
sess.save(c1)
- try:
- sess.flush()
- assert False
- except exceptions.FlushError, err:
- assert str(err).startswith("Attempting to flush an item of type %s on collection 'A.bs (B)', which is handled by mapper 'Mapper|B|b' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship ?" % C)
+ self.assertRaisesMessage(orm_exc.FlushError, "Attempting to flush an item", sess.flush)
def test_m2o_nopoly_onflush(self):
class A(object):pass
@@ -687,11 +686,8 @@ class TypeMatchTest(ORMTest):
sess = create_session()
sess.save(b1)
sess.save(d1)
- try:
- sess.flush()
- assert False
- except exceptions.FlushError, err:
- assert str(err).startswith("Attempting to flush an item of type %s on collection 'D.a (A)', which is handled by mapper 'Mapper|A|a' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship ?" % B)
+ self.assertRaisesMessage(orm_exc.FlushError, "Attempting to flush an item", sess.flush)
+
def test_m2o_oncascade(self):
class A(object):pass
class B(object):pass
@@ -703,11 +699,7 @@ class TypeMatchTest(ORMTest):
d1 = D()
d1.a = b1
sess = create_session()
- try:
- sess.save(d1)
- assert False
- except exceptions.AssertionError, err:
- assert str(err) == "Attribute 'a' on class '%s' doesn't handle objects of type '%s'" % (D, B)
+ self.assertRaisesMessage(AssertionError, "doesn't handle objects of type", sess.save, d1)
class TypedAssociationTable(ORMTest):
def define_tables(self, metadata):
@@ -1030,6 +1022,7 @@ class ViewOnlyTest6(ORMTest):
a = sess.query(T1).first()
self.assertEquals(a.t3s, [T3(data='t3')])
+
def test_remote_side_escalation(self):
class T1(fixtures.Base):
@@ -1051,7 +1044,7 @@ class ViewOnlyTest6(ORMTest):
't3s':relation(T3, secondary=t2tot3)
})
mapper(T3, t3)
- self.assertRaisesMessage(exceptions.ArgumentError, "Specify remote_side argument", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Specify remote_side argument", compile_mappers)
class ExplicitLocalRemoteTest(ORMTest):
def define_tables(self, metadata):
@@ -1210,7 +1203,7 @@ class ExplicitLocalRemoteTest(ORMTest):
)
})
mapper(T2, t2)
- self.assertRaises(exceptions.ArgumentError, compile_mappers)
+ self.assertRaises(sa_exc.ArgumentError, compile_mappers)
clear_mappers()
mapper(T1, t1, properties={
@@ -1219,7 +1212,7 @@ class ExplicitLocalRemoteTest(ORMTest):
)
})
mapper(T2, t2)
- self.assertRaises(exceptions.ArgumentError, compile_mappers)
+ self.assertRaises(sa_exc.ArgumentError, compile_mappers)
class InvalidRelationEscalationTest(ORMTest):
def define_tables(self, metadata):
@@ -1237,7 +1230,7 @@ class InvalidRelationEscalationTest(ORMTest):
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
def test_no_join_self_ref(self):
mapper(Foo, foos, properties={
@@ -1245,7 +1238,7 @@ class InvalidRelationEscalationTest(ORMTest):
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
def test_no_equated(self):
mapper(Foo, foos, properties={
@@ -1253,7 +1246,7 @@ class InvalidRelationEscalationTest(ORMTest):
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
def test_no_equated_fks(self):
mapper(Foo, foos, properties={
@@ -1261,7 +1254,7 @@ class InvalidRelationEscalationTest(ORMTest):
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated, locally mapped column pairs for primaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not locate any equated, locally mapped column pairs for primaryjoin condition", compile_mappers)
def test_no_equated_self_ref(self):
mapper(Foo, foos, properties={
@@ -1269,7 +1262,7 @@ class InvalidRelationEscalationTest(ORMTest):
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
def test_no_equated_self_ref(self):
mapper(Foo, foos, properties={
@@ -1277,7 +1270,7 @@ class InvalidRelationEscalationTest(ORMTest):
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated, locally mapped column pairs for primaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not locate any equated, locally mapped column pairs for primaryjoin condition", compile_mappers)
def test_no_equated_viewonly(self):
mapper(Foo, foos, properties={
@@ -1285,7 +1278,7 @@ class InvalidRelationEscalationTest(ORMTest):
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
def test_no_equated_self_ref_viewonly(self):
mapper(Foo, foos, properties={
@@ -1294,7 +1287,7 @@ class InvalidRelationEscalationTest(ORMTest):
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Specify the foreign_keys argument to indicate which columns on the relation are foreign.", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Specify the foreign_keys argument to indicate which columns on the relation are foreign.", compile_mappers)
def test_no_equated_self_ref_viewonly_fks(self):
mapper(Foo, foos, properties={
@@ -1308,21 +1301,21 @@ class InvalidRelationEscalationTest(ORMTest):
'bars':relation(Bar, primaryjoin=foos.c.id==bars.c.fid)
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
def test_equated_self_ref(self):
mapper(Foo, foos, properties={
'foos':relation(Foo, primaryjoin=foos.c.id==foos.c.fid)
})
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
def test_equated_self_ref_wrong_fks(self):
mapper(Foo, foos, properties={
'foos':relation(Foo, primaryjoin=foos.c.id==foos.c.fid, foreign_keys=[bars.c.id])
})
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
class InvalidRelationEscalationTestM2M(ORMTest):
def define_tables(self, metadata):
@@ -1341,7 +1334,7 @@ class InvalidRelationEscalationTestM2M(ORMTest):
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
def test_no_secondaryjoin(self):
mapper(Foo, foos, properties={
@@ -1349,7 +1342,7 @@ class InvalidRelationEscalationTestM2M(ORMTest):
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
def test_bad_primaryjoin(self):
mapper(Foo, foos, properties={
@@ -1357,7 +1350,7 @@ class InvalidRelationEscalationTestM2M(ORMTest):
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
def test_bad_secondaryjoin(self):
mapper(Foo, foos, properties={
@@ -1365,7 +1358,7 @@ class InvalidRelationEscalationTestM2M(ORMTest):
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for secondaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for secondaryjoin condition", compile_mappers)
def test_no_equated_secondaryjoin(self):
mapper(Foo, foos, properties={
@@ -1373,7 +1366,7 @@ class InvalidRelationEscalationTestM2M(ORMTest):
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated, locally mapped column pairs for secondaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not locate any equated, locally mapped column pairs for secondaryjoin condition", compile_mappers)
if __name__ == "__main__":
diff --git a/test/orm/scoping.py b/test/orm/scoping.py
new file mode 100644
index 000000000..523f37671
--- /dev/null
+++ b/test/orm/scoping.py
@@ -0,0 +1,171 @@
+import testenv; testenv.configure_for_tests()
+from sqlalchemy import *
+from sqlalchemy import exc as sa_exc
+from sqlalchemy.orm import *
+from testlib import *
+from testlib import fixtures
+
+
+class ScopedSessionTest(ORMTest):
+
+ def define_tables(self, metadata):
+ global table, table2
+ table = Table('sometable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(30)))
+ table2 = Table('someothertable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('someid', None, ForeignKey('sometable.id'))
+ )
+
+ def test_basic(self):
+ Session = scoped_session(sessionmaker())
+
+ class SomeObject(fixtures.Base):
+ query = Session.query_property()
+ class SomeOtherObject(fixtures.Base):
+ query = Session.query_property()
+
+ mapper(SomeObject, table, properties={
+ 'options':relation(SomeOtherObject)
+ })
+ mapper(SomeOtherObject, table2)
+
+ s = SomeObject(id=1, data="hello")
+ sso = SomeOtherObject()
+ s.options.append(sso)
+ Session.save(s)
+ Session.commit()
+ Session.refresh(sso)
+ Session.remove()
+
+ self.assertEquals(SomeObject(id=1, data="hello", options=[SomeOtherObject(someid=1)]), Session.query(SomeObject).one())
+ self.assertEquals(SomeObject(id=1, data="hello", options=[SomeOtherObject(someid=1)]), SomeObject.query.one())
+ self.assertEquals(SomeOtherObject(someid=1), SomeOtherObject.query.filter(SomeOtherObject.someid==sso.someid).one())
+
+
+class ScopedMapperTest(TestBase):
+ def setUpAll(self):
+ global metadata, table, table2
+ metadata = MetaData(testing.db)
+ table = Table('sometable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(30)))
+ table2 = Table('someothertable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('someid', None, ForeignKey('sometable.id'))
+ )
+ metadata.create_all()
+
+ def setUp(self):
+ global SomeObject, SomeOtherObject
+ class SomeObject(fixtures.Base):pass
+ class SomeOtherObject(fixtures.Base):pass
+
+ global Session
+
+ Session = scoped_session(create_session)
+ Session.mapper(SomeObject, table, properties={
+ 'options':relation(SomeOtherObject)
+ })
+ Session.mapper(SomeOtherObject, table2)
+
+ s = SomeObject()
+ s.id = 1
+ s.data = 'hello'
+ sso = SomeOtherObject()
+ s.options.append(sso)
+ Session.flush()
+ Session.clear()
+
+ def tearDownAll(self):
+ metadata.drop_all()
+
+ def tearDown(self):
+ for table in metadata.table_iterator(reverse=True):
+ table.delete().execute()
+ clear_mappers()
+
+ def test_query(self):
+ sso = SomeOtherObject.query().first()
+ assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id
+
+ def test_query_compiles(self):
+ class Foo(object):
+ pass
+ Session.mapper(Foo, table2)
+ assert hasattr(Foo, 'query')
+
+ ext = MapperExtension()
+
+ class Bar(object):
+ pass
+ Session.mapper(Bar, table2, extension=[ext])
+ assert hasattr(Bar, 'query')
+
+ class Baz(object):
+ pass
+ Session.mapper(Baz, table2, extension=ext)
+ assert hasattr(Baz, 'query')
+
+ def test_validating_constructor(self):
+ s2 = SomeObject(someid=12)
+ s3 = SomeOtherObject(someid=123, bogus=345)
+
+ class ValidatedOtherObject(object): pass
+ Session.mapper(ValidatedOtherObject, table2, validate=True)
+
+ v1 = ValidatedOtherObject(someid=12)
+ self.assertRaises(sa_exc.ArgumentError, ValidatedOtherObject, someid=12, bogus=345)
+
+ def test_dont_clobber_methods(self):
+ class MyClass(object):
+ def expunge(self):
+ return "an expunge !"
+
+ Session.mapper(MyClass, table2)
+
+ assert MyClass().expunge() == "an expunge !"
+
+class ScopedMapperTest2(ORMTest):
+ def define_tables(self, metadata):
+ global table, table2
+ table = Table('sometable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(30)),
+ Column('type', String(30))
+
+ )
+ table2 = Table('someothertable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('someid', None, ForeignKey('sometable.id')),
+ Column('somedata', String(30)),
+ )
+
+ def test_inheritance(self):
+ def expunge_list(l):
+ for x in l:
+ Session.expunge(x)
+ return l
+
+ class BaseClass(fixtures.Base):
+ pass
+ class SubClass(BaseClass):
+ pass
+
+ Session = scoped_session(sessionmaker())
+ Session.mapper(BaseClass, table, polymorphic_identity='base', polymorphic_on=table.c.type)
+ Session.mapper(SubClass, table2, polymorphic_identity='sub', inherits=BaseClass)
+
+ b = BaseClass(data='b1')
+ s = SubClass(data='s1', somedata='somedata')
+ Session.commit()
+ Session.clear()
+
+ assert expunge_list([BaseClass(data='b1'), SubClass(data='s1', somedata='somedata')]) == BaseClass.query.all()
+ assert expunge_list([SubClass(data='s1', somedata='somedata')]) == SubClass.query.all()
+
+
+
+if __name__ == "__main__":
+ testenv.main()
diff --git a/test/orm/selectable.py b/test/orm/selectable.py
index fc5be6f50..a16c24fc1 100644
--- a/test/orm/selectable.py
+++ b/test/orm/selectable.py
@@ -2,7 +2,7 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import *
from testlib import *
from testlib.fixtures import *
@@ -21,7 +21,7 @@ class SelectableNoFromsTest(ORMTest):
class Subset(object):
pass
selectable = select(["x", "y", "z"])
- self.assertRaisesMessage(exceptions.InvalidRequestError, "Could not find any Table objects", mapper, Subset, selectable)
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "Could not find any Table objects", mapper, Subset, selectable)
@testing.emits_warning('.*creating an Alias.*')
def test_basic(self):
diff --git a/test/orm/session.py b/test/orm/session.py
index 49932f8d9..719ecccf9 100644
--- a/test/orm/session.py
+++ b/test/orm/session.py
@@ -1,14 +1,15 @@
import testenv; testenv.configure_for_tests()
+import gc
+import pickle
from sqlalchemy import *
-from sqlalchemy import exceptions, util
+from sqlalchemy import exc as sa_exc, util
from sqlalchemy.orm import *
+from sqlalchemy.orm import attributes
from sqlalchemy.orm.session import SessionExtension
from sqlalchemy.orm.session import Session as SessionCls
from testlib import *
from testlib.tables import *
from testlib import fixtures, tables
-import pickle
-import gc
class SessionTest(TestBase, AssertsExecutionResults):
@@ -27,7 +28,8 @@ class SessionTest(TestBase, AssertsExecutionResults):
pass
def test_close(self):
- """test that flush() doenst close a connection the session didnt open"""
+ """test that flush() doesn't close a connection the session didn't open"""
+
c = testing.db.connect()
class User(object):pass
mapper(User, users)
@@ -83,9 +85,14 @@ class SessionTest(TestBase, AssertsExecutionResults):
# then see if expunge fails
session.expunge(u)
+ assert object_session(u) is attributes.instance_state(u).session_id is None
+ for a in u.addresses:
+ assert object_session(a) is attributes.instance_state(a).session_id is None
+
@engines.close_open_connections
def test_binds_from_expression(self):
"""test that Session can extract Table objects from ClauseElements and match them to tables."""
+
Session = sessionmaker(binds={users:testing.db, addresses:testing.db})
sess = Session()
sess.execute(users.insert(), params=dict(user_id=1, user_name='ed'))
@@ -123,7 +130,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
conn1 = testing.db.connect()
conn2 = testing.db.connect()
- sess = create_session(transactional=True, bind=conn1)
+ sess = create_session(autocommit=False, bind=conn1)
u = User()
sess.save(u)
sess.flush()
@@ -134,20 +141,6 @@ class SessionTest(TestBase, AssertsExecutionResults):
assert testing.db.connect().execute("select count(1) from users").scalar() == 1
sess.close()
- def test_flush_noop(self):
- session = create_session()
- session.uow = object()
-
- self.assertRaises(AttributeError, session.flush)
-
- session = create_session()
- session.uow = object()
-
- session.flush(objects=[])
- session.flush(objects=set())
- session.flush(objects=())
- session.flush(objects=iter([]))
-
@testing.unsupported('sqlite', 'mssql') # TEMP: test causes mssql to hang
@engines.close_open_connections
def test_autoflush(self):
@@ -156,7 +149,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
conn1 = testing.db.connect()
conn2 = testing.db.connect()
- sess = create_session(bind=conn1, transactional=True, autoflush=True)
+ sess = create_session(bind=conn1, autocommit=False, autoflush=True)
u = User()
u.user_name='ed'
sess.save(u)
@@ -179,7 +172,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
})
mapper(Address, addresses)
- sess = create_session(autoflush=True, transactional=True)
+ sess = create_session(autoflush=True, autocommit=False)
u = User(user_name='ed', addresses=[Address(email_address='foo')])
sess.save(u)
self.assertEquals(sess.query(Address).filter(Address.user==u).one(), Address(email_address='foo'))
@@ -191,7 +184,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
mapper(User, users)
try:
- sess = create_session(transactional=True, autoflush=True)
+ sess = create_session(autocommit=False, autoflush=True)
u = User()
u.user_name='ed'
sess.save(u)
@@ -214,7 +207,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
conn1 = testing.db.connect()
conn2 = testing.db.connect()
- sess = create_session(bind=conn1, transactional=True, autoflush=True)
+ sess = create_session(bind=conn1, autocommit=False, autoflush=True)
u = User()
u.user_name='ed'
sess.save(u)
@@ -223,18 +216,17 @@ class SessionTest(TestBase, AssertsExecutionResults):
assert testing.db.connect().execute("select count(1) from users").scalar() == 1
sess.commit()
- # TODO: not doing rollback of attributes right now.
- def dont_test_autoflush_rollback(self):
+ def test_autoflush_rollback(self):
tables.data()
mapper(Address, addresses)
mapper(User, users, properties={
'addresses':relation(Address)
})
- sess = create_session(transactional=True, autoflush=True)
+ sess = create_session(autocommit=False, autoflush=True)
u = sess.query(User).get(8)
newad = Address()
- newad.email_address == 'something new'
+ newad.email_address = 'something new'
u.addresses.append(newad)
u.user_name = 'some new name'
assert u.user_name == 'some new name'
@@ -244,16 +236,26 @@ class SessionTest(TestBase, AssertsExecutionResults):
assert u.user_name == 'ed'
assert len(u.addresses) == 3
assert newad not in u.addresses
-
+
+ # pending objects dont get expired
+ assert newad.email_address == 'something new'
+
+ def test_textual_execute(self):
+ """test that Session.execute() converts to text()"""
+
+ tables.data()
+ sess = create_session(bind=testing.db)
+ # use :bindparam style
+ self.assertEquals(sess.execute("select * from users where user_id=:id", {'id':7}).fetchall(), [(7, u'jack')])
@engines.close_open_connections
- def test_external_joined_transaction(self):
+ def test_subtransaction_on_external(self):
class User(object):pass
mapper(User, users)
conn = testing.db.connect()
trans = conn.begin()
- sess = create_session(bind=conn, transactional=True, autoflush=True)
- sess.begin()
+ sess = create_session(bind=conn, autocommit=False, autoflush=True)
+ sess.begin(subtransactions=True)
u = User()
sess.save(u)
sess.flush()
@@ -271,7 +273,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
try:
conn = testing.db.connect()
trans = conn.begin()
- sess = create_session(bind=conn, transactional=True, autoflush=True)
+ sess = create_session(bind=conn, autocommit=False, autoflush=True)
u1 = User()
sess.save(u1)
sess.flush()
@@ -288,16 +290,14 @@ class SessionTest(TestBase, AssertsExecutionResults):
conn.close()
raise
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @engines.close_open_connections
+ @testing.requires.savepoints
def test_heavy_nesting(self):
session = create_session(bind=testing.db)
session.begin()
session.connection().execute("insert into users (user_name) values ('user1')")
- session.begin()
+ session.begin(subtransactions=True)
session.begin_nested()
@@ -312,9 +312,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
assert session.connection().execute("select count(1) from users").scalar() == 2
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
+ @testing.requires.two_phase_transactions
def test_twophase(self):
# TODO: mock up a failure condition here
# to ensure a rollback succeeds
@@ -324,7 +322,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
mapper(Address, addresses)
engine2 = create_engine(testing.db.url)
- sess = create_session(transactional=False, autoflush=False, twophase=True)
+ sess = create_session(autocommit=True, autoflush=False, twophase=True)
sess.bind_mapper(User, testing.db)
sess.bind_mapper(Address, engine2)
sess.begin()
@@ -338,11 +336,11 @@ class SessionTest(TestBase, AssertsExecutionResults):
assert users.count().scalar() == 1
assert addresses.count().scalar() == 1
- def test_joined_transaction(self):
+ def test_subtransaction_on_noautocommit(self):
class User(object):pass
mapper(User, users)
- sess = create_session(transactional=True, autoflush=True)
- sess.begin()
+ sess = create_session(autocommit=False, autoflush=True)
+ sess.begin(subtransactions=True)
u = User()
sess.save(u)
sess.flush()
@@ -351,9 +349,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
assert len(sess.query(User).all()) == 0
sess.close()
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
+ @testing.requires.savepoints
def test_nested_transaction(self):
class User(object):pass
mapper(User, users)
@@ -376,13 +372,11 @@ class SessionTest(TestBase, AssertsExecutionResults):
assert len(sess.query(User).all()) == 1
sess.close()
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
+ @testing.requires.savepoints
def test_nested_autotrans(self):
class User(object):pass
mapper(User, users)
- sess = create_session(transactional=True)
+ sess = create_session(autocommit=False)
u = User()
sess.save(u)
sess.flush()
@@ -399,14 +393,12 @@ class SessionTest(TestBase, AssertsExecutionResults):
assert len(sess.query(User).all()) == 1
sess.close()
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
+ @testing.requires.savepoints
def test_nested_transaction_connection_add(self):
class User(object): pass
mapper(User, users)
- sess = create_session(transactional=False)
+ sess = create_session(autocommit=True)
sess.begin()
sess.begin_nested()
@@ -436,18 +428,16 @@ class SessionTest(TestBase, AssertsExecutionResults):
sess.close()
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
+ @testing.requires.savepoints
def test_mixed_transaction_control(self):
class User(object): pass
mapper(User, users)
- sess = create_session(transactional=False)
+ sess = create_session(autocommit=True)
sess.begin()
sess.begin_nested()
- transaction = sess.begin()
+ transaction = sess.begin(subtransactions=True)
sess.save(User())
@@ -469,14 +459,12 @@ class SessionTest(TestBase, AssertsExecutionResults):
sess.close()
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
+ @testing.requires.savepoints
def test_mixed_transaction_close(self):
class User(object): pass
mapper(User, users)
- sess = create_session(transactional=True)
+ sess = create_session(autocommit=False)
sess.begin_nested()
@@ -492,27 +480,20 @@ class SessionTest(TestBase, AssertsExecutionResults):
self.assertEquals(len(sess.query(User).all()), 1)
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
def test_error_on_using_inactive_session(self):
class User(object): pass
mapper(User, users)
- sess = create_session(transactional=False)
+ sess = create_session(autocommit=True)
- try:
- sess.begin()
- sess.begin()
+ sess.begin()
+ sess.begin(subtransactions=True)
- sess.save(User())
- sess.flush()
+ sess.save(User())
+ sess.flush()
- sess.rollback()
- sess.begin()
- assert False
- except exceptions.InvalidRequestError, e:
- self.assertEquals(str(e), "The transaction is inactive due to a rollback in a subtransaction and should be closed")
+ sess.rollback()
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "inactive due to a rollback in a subtransaction", sess.begin, subtransactions=True)
sess.close()
@engines.close_open_connections
@@ -521,30 +502,14 @@ class SessionTest(TestBase, AssertsExecutionResults):
mapper(User, users)
c = testing.db.connect()
sess = create_session(bind=c)
- sess.create_transaction()
+ sess.begin()
transaction = sess.transaction
u = User()
sess.save(u)
sess.flush()
- assert transaction.get_or_add(testing.db) is transaction.get_or_add(c) is c
+ assert transaction._connection_for_bind(testing.db) is transaction._connection_for_bind(c) is c
- try:
- transaction.add(testing.db.connect())
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Session already has a Connection associated for the given Connection's Engine"
-
- try:
- transaction.get_or_add(testing.db.connect())
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Session already has a Connection associated for the given Connection's Engine"
-
- try:
- transaction.add(testing.db)
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Session already has a Connection associated for the given Engine"
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "Session already has a Connection associated", transaction._connection_for_bind, testing.db.connect())
transaction.rollback()
assert len(sess.query(User).all()) == 0
@@ -555,7 +520,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
mapper(User, users)
c = testing.db.connect()
- sess = create_session(bind=c, transactional=True)
+ sess = create_session(bind=c, autocommit=False)
u = User()
sess.save(u)
sess.flush()
@@ -563,7 +528,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
assert not c.in_transaction()
assert c.scalar("select count(1) from users") == 0
- sess = create_session(bind=c, transactional=True)
+ sess = create_session(bind=c, autocommit=False)
u = User()
sess.save(u)
sess.flush()
@@ -576,7 +541,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
c = testing.db.connect()
trans = c.begin()
- sess = create_session(bind=c, transactional=False)
+ sess = create_session(bind=c, autocommit=True)
u = User()
sess.save(u)
sess.flush()
@@ -596,17 +561,8 @@ class SessionTest(TestBase, AssertsExecutionResults):
user = User()
- try:
- s.update(user)
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Instance 'User@%s' is not persisted" % hex(id(user))
-
- try:
- s.delete(user)
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Instance 'User@%s' is not persisted" % hex(id(user))
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "is not persisted", s.update, user)
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "is not persisted", s.delete, user)
s.save(user)
s.flush()
@@ -632,25 +588,13 @@ class SessionTest(TestBase, AssertsExecutionResults):
assert user in s
assert user not in s.dirty
- try:
- s.save(user)
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Instance 'User@%s' is already persistent" % hex(id(user))
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "is already persistent", s.save, user)
s2 = create_session()
- try:
- s2.delete(user)
- assert False
- except exceptions.InvalidRequestError, e:
- assert "is already attached to session" in str(e)
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "is already attached to session", s2.delete, user)
u2 = s2.query(User).get(user.user_id)
- try:
- s.delete(u2)
- assert False
- except exceptions.InvalidRequestError, e:
- assert "already persisted with a different identity" in str(e)
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "already persisted with a different identity", s.delete, u2)
s.delete(user)
s.flush()
@@ -707,21 +651,18 @@ class SessionTest(TestBase, AssertsExecutionResults):
del user
gc.collect()
assert len(s.identity_map) == 0
- assert len(s.identity_map.data) == 0
user = s.query(User).one()
user.user_name = 'fred'
del user
gc.collect()
assert len(s.identity_map) == 1
- assert len(s.identity_map.data) == 1
assert len(s.dirty) == 1
s.flush()
gc.collect()
assert not s.dirty
assert not s.identity_map
- assert not s.identity_map.data
user = s.query(User).one()
assert user.user_name == 'fred'
@@ -890,7 +831,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
assert log == ['before_flush', 'after_begin', 'after_flush', 'before_commit', 'after_commit', 'after_flush_postexec']
log = []
- sess = create_session(transactional=True, extension=MyExt())
+ sess = create_session(autocommit=False, extension=MyExt())
u = User()
sess.save(u)
sess.flush()
@@ -906,7 +847,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
assert log == ['before_commit', 'after_commit']
log = []
- sess = create_session(transactional=True, extension=MyExt(), bind=testing.db)
+ sess = create_session(autocommit=False, extension=MyExt(), bind=testing.db)
conn = sess.connection()
assert log == ['after_begin']
@@ -918,11 +859,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
u1 = User()
sess1.save(u1)
- try:
- sess2.save(u1)
- assert False
- except exceptions.InvalidRequestError, e:
- assert "already attached to session" in str(e)
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "already attached to session", sess2.save, u1)
u2 = pickle.loads(pickle.dumps(u1))
@@ -941,6 +878,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
sess.expunge(u1)
assert u1 not in sess
+ assert Session.object_session(u1) is None
u2 = sess.query(User).get(u1.user_id)
assert u2 is not None and u2 is not u1
@@ -950,12 +888,14 @@ class SessionTest(TestBase, AssertsExecutionResults):
sess.expunge(u2)
assert u2 not in sess
+ assert Session.object_session(u2) is None
u1.user_name = "John"
u2.user_name = "Doe"
sess.update(u1)
assert u1 in sess
+ assert Session.object_session(u1) is sess
sess.flush()
@@ -981,197 +921,39 @@ class SessionTest(TestBase, AssertsExecutionResults):
assert len(list(sess)) == 1
-class ScopedSessionTest(ORMTest):
-
- def define_tables(self, metadata):
- global table, table2
- table = Table('sometable', metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(30)))
- table2 = Table('someothertable', metadata,
- Column('id', Integer, primary_key=True),
- Column('someid', None, ForeignKey('sometable.id'))
- )
-
- def test_basic(self):
- Session = scoped_session(sessionmaker())
-
- class SomeObject(fixtures.Base):
- query = Session.query_property()
- class SomeOtherObject(fixtures.Base):
- query = Session.query_property()
-
- mapper(SomeObject, table, properties={
- 'options':relation(SomeOtherObject)
- })
- mapper(SomeOtherObject, table2)
-
- s = SomeObject(id=1, data="hello")
- sso = SomeOtherObject()
- s.options.append(sso)
- Session.save(s)
- Session.commit()
- Session.remove()
-
- self.assertEquals(SomeObject(id=1, data="hello", options=[SomeOtherObject(someid=1)]), Session.query(SomeObject).one())
- self.assertEquals(SomeObject(id=1, data="hello", options=[SomeOtherObject(someid=1)]), SomeObject.query.one())
- self.assertEquals(SomeOtherObject(someid=1), SomeOtherObject.query.filter(SomeOtherObject.someid==sso.someid).one())
-
-class ScopedMapperTest(TestBase):
+class TLTransactionTest(TestBase):
def setUpAll(self):
- global metadata, table, table2
- metadata = MetaData(testing.db)
- table = Table('sometable', metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(30), nullable=False))
- table2 = Table('someothertable', metadata,
- Column('id', Integer, primary_key=True),
- Column('someid', None, ForeignKey('sometable.id'))
- )
- metadata.create_all()
-
- def setUp(self):
- global SomeObject, SomeOtherObject
- class SomeObject(object):pass
- class SomeOtherObject(object):pass
-
- global Session
-
- Session = scoped_session(create_session)
- Session.mapper(SomeObject, table, properties={
- 'options':relation(SomeOtherObject)
- })
- Session.mapper(SomeOtherObject, table2)
-
- s = SomeObject()
- s.id = 1
- s.data = 'hello'
- sso = SomeOtherObject()
- s.options.append(sso)
- Session.flush()
- Session.clear()
-
- def tearDownAll(self):
- metadata.drop_all()
-
+ global users, metadata, tlengine
+ tlengine = create_engine(testing.db.url, strategy='threadlocal')
+ metadata = MetaData()
+ users = Table('query_users', metadata,
+ Column('user_id', INT, Sequence('query_users_id_seq', optional=True), primary_key=True),
+ Column('user_name', VARCHAR(20)),
+ test_needs_acid=True,
+ )
+ users.create(tlengine)
def tearDown(self):
- for table in metadata.table_iterator(reverse=True):
- table.delete().execute()
- clear_mappers()
-
- def test_query(self):
- sso = SomeOtherObject.query().first()
- assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id
+ tlengine.execute(users.delete())
- def test_query_compiles(self):
- class Foo(object):
- pass
- Session.mapper(Foo, table2)
- assert hasattr(Foo, 'query')
-
- ext = MapperExtension()
-
- class Bar(object):
- pass
- Session.mapper(Bar, table2, extension=[ext])
- assert hasattr(Bar, 'query')
+ def tearDownAll(self):
+ users.drop(tlengine)
+ tlengine.dispose()
- class Baz(object):
+ @testing.exclude('mysql', '<', (5, 0, 3))
+ def testsessionnesting(self):
+ class User(object):
pass
- Session.mapper(Baz, table2, extension=ext)
- assert hasattr(Baz, 'query')
-
- def test_validating_constructor(self):
- s2 = SomeObject(someid=12)
- s3 = SomeOtherObject(someid=123, bogus=345)
-
- class ValidatedOtherObject(object):pass
- Session.mapper(ValidatedOtherObject, table2, validate=True)
-
- v1 = ValidatedOtherObject(someid=12)
try:
- v2 = ValidatedOtherObject(someid=12, bogus=345)
- assert False
- except exceptions.ArgumentError:
- pass
-
- def test_dont_clobber_methods(self):
- class MyClass(object):
- def expunge(self):
- return "an expunge !"
-
- Session.mapper(MyClass, table2)
-
- assert MyClass().expunge() == "an expunge !"
-
- def _test_autoflush_saveoninit(self, on_init, autoflush=None):
- Session = scoped_session(
- sessionmaker(transactional=True, autoflush=True))
-
- class Foo(object):
- def __init__(self, data=None):
- if autoflush is not None:
- friends = Session.query(Foo).autoflush(autoflush).all()
- else:
- friends = Session.query(Foo).all()
- self.data = data
-
- Session.mapper(Foo, table, save_on_init=on_init)
-
- a1 = Foo('an address')
- Session.flush()
-
- def test_autoflush_saveoninit(self):
- """Test save_on_init + query.autoflush()"""
- self._test_autoflush_saveoninit(False)
- self._test_autoflush_saveoninit(False, True)
- self._test_autoflush_saveoninit(False, False)
-
- self.assertRaises(exceptions.DBAPIError,
- self._test_autoflush_saveoninit, True)
- self.assertRaises(exceptions.DBAPIError,
- self._test_autoflush_saveoninit, True, True)
- self._test_autoflush_saveoninit(True, False)
-
-
-class ScopedMapperTest2(ORMTest):
- def define_tables(self, metadata):
- global table, table2
- table = Table('sometable', metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(30)),
- Column('type', String(30))
-
- )
- table2 = Table('someothertable', metadata,
- Column('id', Integer, primary_key=True),
- Column('someid', None, ForeignKey('sometable.id')),
- Column('somedata', String(30)),
- )
-
- def test_inheritance(self):
- def expunge_list(l):
- for x in l:
- Session.expunge(x)
- return l
-
- class BaseClass(fixtures.Base):
- pass
- class SubClass(BaseClass):
- pass
-
- Session = scoped_session(sessionmaker())
- Session.mapper(BaseClass, table, polymorphic_identity='base', polymorphic_on=table.c.type)
- Session.mapper(SubClass, table2, polymorphic_identity='sub', inherits=BaseClass)
-
- b = BaseClass(data='b1')
- s = SubClass(data='s1', somedata='somedata')
- Session.commit()
- Session.clear()
-
- assert expunge_list([BaseClass(data='b1'), SubClass(data='s1', somedata='somedata')]) == BaseClass.query.all()
- assert expunge_list([SubClass(data='s1', somedata='somedata')]) == SubClass.query.all()
+ mapper(User, users)
+ sess = create_session(bind=tlengine)
+ tlengine.begin()
+ u = User()
+ sess.save(u)
+ sess.flush()
+ tlengine.commit()
+ finally:
+ clear_mappers()
if __name__ == "__main__":
diff --git a/test/orm/sessioncontext.py b/test/orm/sessioncontext.py
deleted file mode 100644
index c743dabf9..000000000
--- a/test/orm/sessioncontext.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import testenv; testenv.configure_for_tests()
-from sqlalchemy import *
-from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext
-from sqlalchemy.orm.session import object_session, Session
-from testlib import *
-
-
-metadata = MetaData()
-users = Table('users', metadata,
- Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
- Column('user_name', String(40)),
-)
-
-class SessionContextTest(TestBase, AssertsExecutionResults):
- def setUp(self):
- clear_mappers()
-
- def do_test(self, class_, context):
- """test session assignment on object creation"""
- obj = class_()
- assert context.current == object_session(obj)
-
- # keep a reference so the old session doesn't get gc'd
- old_session = context.current
-
- context.current = Session()
- assert context.current != object_session(obj)
- assert old_session == object_session(obj)
-
- new_session = context.current
- del context.current
- assert context.current != new_session
- assert old_session == object_session(obj)
-
- obj2 = class_()
- assert context.current == object_session(obj2)
-
- @testing.uses_deprecated('SessionContext')
- def test_mapper_extension(self):
- context = SessionContext(Session)
- class User(object): pass
- User.mapper = mapper(User, users, extension=context.mapper_extension)
- self.do_test(User, context)
-
-
-if __name__ == "__main__":
- testenv.main()
diff --git a/test/orm/sharding/shard.py b/test/orm/sharding/shard.py
index d231b14a2..f25d097fd 100644
--- a/test/orm/sharding/shard.py
+++ b/test/orm/sharding/shard.py
@@ -1,7 +1,7 @@
import testenv; testenv.configure_for_tests()
import datetime, os
from sqlalchemy import *
-from sqlalchemy import exceptions, sql
+from sqlalchemy import sql
from sqlalchemy.orm import *
from sqlalchemy.orm.shard import ShardedSession
from sqlalchemy.sql import operators
@@ -93,7 +93,7 @@ class ShardTest(TestBase):
else:
return ids
- create_session = sessionmaker(class_=ShardedSession, autoflush=True, transactional=True)
+ create_session = sessionmaker(class_=ShardedSession, autoflush=True, autocommit=False)
create_session.configure(shards={
'north_america':db1,
@@ -139,7 +139,7 @@ class ShardTest(TestBase):
for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:
sess.save(c)
sess.commit()
-
+ tokyo.city # reload 'city' attribute on tokyo
sess.clear()
assert db2.execute(weather_locations.select()).fetchall() == [(1, 'Asia', 'Tokyo')]
diff --git a/test/orm/transaction.py b/test/orm/transaction.py
new file mode 100644
index 000000000..ca3680057
--- /dev/null
+++ b/test/orm/transaction.py
@@ -0,0 +1,360 @@
+import testenv; testenv.configure_for_tests()
+import operator
+from sqlalchemy import *
+from sqlalchemy import exc as sa_exc
+from sqlalchemy.orm import *
+from testlib import *
+from testlib.fixtures import *
+
+
+class TransactionTest(FixtureTest):
+ keep_mappers = True
+ session = sessionmaker()
+
+ def setup_mappers(self):
+ mapper(User, users, properties={
+ 'addresses':relation(Address, backref='user',
+ cascade="all, delete-orphan"),
+ })
+ mapper(Address, addresses)
+
+
+class FixtureDataTest(TransactionTest):
+ refresh_data = True
+
+ def test_attrs_on_rollback(self):
+ sess = self.session()
+ u1 = sess.get(User, 7)
+ u1.name = 'ed'
+ sess.rollback()
+ self.assertEquals(u1.name, 'jack')
+
+ def test_commit_persistent(self):
+ sess = self.session()
+ u1 = sess.get(User, 7)
+ u1.name = 'ed'
+ sess.flush()
+ sess.commit()
+ self.assertEquals(u1.name, 'ed')
+
+ def test_concurrent_commit_persistent(self):
+ s1 = self.session()
+ u1 = s1.get(User, 7)
+ u1.name = 'ed'
+ s1.commit()
+
+ s2 = self.session()
+ u2 = s2.get(User, 7)
+ assert u2.name == 'ed'
+ u2.name = 'will'
+ s2.commit()
+
+ assert u1.name == 'will'
+
+class AutoExpireTest(TransactionTest):
+ tables_only = True
+
+ def test_expunge_pending_on_rollback(self):
+ sess = self.session()
+ u2= User(name='newuser')
+ sess.add(u2)
+ assert u2 in sess
+ sess.rollback()
+ assert u2 not in sess
+
+ def test_trans_pending_cleared_on_commit(self):
+ sess = self.session()
+ u2= User(name='newuser')
+ sess.add(u2)
+ assert u2 in sess
+ sess.commit()
+ assert u2 in sess
+ u3 = User(name='anotheruser')
+ sess.add(u3)
+ sess.rollback()
+ assert u3 not in sess
+ assert u2 in sess
+
+ def test_update_deleted_on_rollback(self):
+ s = self.session()
+ u1 = User(name='ed')
+ s.add(u1)
+ s.commit()
+
+ s.delete(u1)
+ assert u1 in s.deleted
+ s.rollback()
+ assert u1 in s
+ assert u1 not in s.deleted
+
+ def test_trans_deleted_cleared_on_rollback(self):
+ s = self.session()
+ u1 = User(name='ed')
+ s.add(u1)
+ s.commit()
+
+ s.delete(u1)
+ s.commit()
+ assert u1 not in s
+ s.rollback()
+ assert u1 not in s
+
+ def test_update_deleted_on_rollback_cascade(self):
+ s = self.session()
+ u1 = User(name='ed', addresses=[Address(email_address='foo')])
+ s.add(u1)
+ s.commit()
+
+ s.delete(u1)
+ assert u1 in s.deleted
+ assert u1.addresses[0] in s.deleted
+ s.rollback()
+ assert u1 in s
+ assert u1 not in s.deleted
+ assert u1.addresses[0] not in s.deleted
+
+ def test_update_deleted_on_rollback_orphan(self):
+ s = self.session()
+ u1 = User(name='ed', addresses=[Address(email_address='foo')])
+ s.add(u1)
+ s.commit()
+
+ a1 = u1.addresses[0]
+ u1.addresses.remove(a1)
+
+ s.flush()
+ self.assertEquals(s.query(Address).filter(Address.email_address=='foo').all(), [])
+ s.rollback()
+ assert a1 not in s.deleted
+ assert u1.addresses == [a1]
+
+ def test_commit_pending(self):
+ sess = self.session()
+ u1 = User(name='newuser')
+ sess.add(u1)
+ sess.flush()
+ sess.commit()
+ self.assertEquals(u1.name, 'newuser')
+
+
+ def test_concurrent_commit_pending(self):
+ s1 = self.session()
+ u1 = User(name='edward')
+ s1.add(u1)
+ s1.commit()
+
+ s2 = self.session()
+ u2 = s2.query(User).filter(User.name=='edward').one()
+ u2.name = 'will'
+ s2.commit()
+
+ assert u1.name == 'will'
+
+class RollbackRecoverTest(TransactionTest):
+ only_tables = True
+
+ def test_pk_violation(self):
+ s = self.session()
+ a1 = Address(email_address='foo')
+ u1 = User(id=1, name='ed', addresses=[a1])
+ s.add(u1)
+ s.commit()
+
+ a2 = Address(email_address='bar')
+ u2 = User(id=1, name='jack', addresses=[a2])
+
+ u1.name = 'edward'
+ a1.email_address = 'foober'
+ s.add(u2)
+ self.assertRaises(sa_exc.FlushError, s.commit)
+ self.assertRaises(sa_exc.InvalidRequestError, s.commit)
+ s.rollback()
+ assert u2 not in s
+ assert a2 not in s
+ assert u1 in s
+ assert a1 in s
+ assert u1.name == 'ed'
+ assert a1.email_address == 'foo'
+ u1.name = 'edward'
+ a1.email_address = 'foober'
+ s.commit()
+ assert s.query(User).all() == [User(id=1, name='edward', addresses=[Address(email_address='foober')])]
+
+ @testing.requires.savepoints
+ def test_pk_violation_with_savepoint(self):
+ s = self.session()
+ a1 = Address(email_address='foo')
+ u1 = User(id=1, name='ed', addresses=[a1])
+ s.add(u1)
+ s.commit()
+
+ a2 = Address(email_address='bar')
+ u2 = User(id=1, name='jack', addresses=[a2])
+
+ u1.name = 'edward'
+ a1.email_address = 'foober'
+ s.begin_nested()
+ s.add(u2)
+ self.assertRaises(sa_exc.FlushError, s.commit)
+ self.assertRaises(sa_exc.InvalidRequestError, s.commit)
+ s.rollback()
+ assert u2 not in s
+ assert a2 not in s
+ assert u1 in s
+ assert a1 in s
+
+ s.commit()
+ assert s.query(User).all() == [User(id=1, name='edward', addresses=[Address(email_address='foober')])]
+
+
+class SavepointTest(TransactionTest):
+
+ only_tables = True
+
+ @testing.requires.savepoints
+ def test_savepoint_rollback(self):
+ s = self.session()
+ u1 = User(name='ed')
+ u2 = User(name='jack')
+ s.add_all([u1, u2])
+
+ s.begin_nested()
+ u3 = User(name='wendy')
+ u4 = User(name='foo')
+ u1.name = 'edward'
+ u2.name = 'jackward'
+ s.add_all([u3, u4])
+ self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+ s.rollback()
+ assert u1.name == 'ed'
+ assert u2.name == 'jack'
+ self.assertEquals(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)])
+ s.commit()
+ assert u1.name == 'ed'
+ assert u2.name == 'jack'
+ self.assertEquals(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)])
+
+ @testing.requires.savepoints
+ def test_savepoint_commit(self):
+ s = self.session()
+ u1 = User(name='ed')
+ u2 = User(name='jack')
+ s.add_all([u1, u2])
+
+ s.begin_nested()
+ u3 = User(name='wendy')
+ u4 = User(name='foo')
+ u1.name = 'edward'
+ u2.name = 'jackward'
+ s.add_all([u3, u4])
+ self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+ s.commit()
+ def go():
+ assert u1.name == 'edward'
+ assert u2.name == 'jackward'
+ self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+ self.assert_sql_count(testing.db, go, 1)
+
+ s.commit()
+ self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+
+ @testing.requires.savepoints
+ def test_savepoint_rollback_collections(self):
+ s = self.session()
+ u1 = User(name='ed', addresses=[Address(email_address='foo')])
+ s.add(u1)
+ s.commit()
+
+ u1.name='edward'
+ u1.addresses.append(Address(email_address='bar'))
+ s.begin_nested()
+ u2 = User(name='jack', addresses=[Address(email_address='bat')])
+ s.add(u2)
+ self.assertEquals(s.query(User).order_by(User.id).all(),
+ [
+ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
+ User(name='jack', addresses=[Address(email_address='bat')])
+ ]
+ )
+ s.rollback()
+ self.assertEquals(s.query(User).order_by(User.id).all(),
+ [
+ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
+ ]
+ )
+ s.commit()
+ self.assertEquals(s.query(User).order_by(User.id).all(),
+ [
+ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
+ ]
+ )
+
+ @testing.requires.savepoints
+ def test_savepoint_commit_collections(self):
+ s = self.session()
+ u1 = User(name='ed', addresses=[Address(email_address='foo')])
+ s.add(u1)
+ s.commit()
+
+ u1.name='edward'
+ u1.addresses.append(Address(email_address='bar'))
+ s.begin_nested()
+ u2 = User(name='jack', addresses=[Address(email_address='bat')])
+ s.add(u2)
+ self.assertEquals(s.query(User).order_by(User.id).all(),
+ [
+ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
+ User(name='jack', addresses=[Address(email_address='bat')])
+ ]
+ )
+ s.commit()
+ self.assertEquals(s.query(User).order_by(User.id).all(),
+ [
+ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
+ User(name='jack', addresses=[Address(email_address='bat')])
+ ]
+ )
+ s.commit()
+ self.assertEquals(s.query(User).order_by(User.id).all(),
+ [
+ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
+ User(name='jack', addresses=[Address(email_address='bat')])
+ ]
+ )
+
+ @testing.requires.savepoints
+ def test_expunge_pending_on_rollback(self):
+ sess = self.session()
+
+ sess.begin_nested()
+ u2= User(name='newuser')
+ sess.add(u2)
+ assert u2 in sess
+ sess.rollback()
+ assert u2 not in sess
+
+ @testing.requires.savepoints
+ def test_update_deleted_on_rollback(self):
+ s = self.session()
+ u1 = User(name='ed')
+ s.add(u1)
+ s.commit()
+
+ s.begin_nested()
+ s.delete(u1)
+ assert u1 in s.deleted
+ s.rollback()
+ assert u1 in s
+ assert u1 not in s.deleted
+
+
+
+class AutocommitTest(TransactionTest):
+ def test_begin_nested_requires_trans(self):
+ sess = create_session(autocommit=True)
+ self.assertRaises(sa_exc.InvalidRequestError, sess.begin_nested)
+
+
+
+if __name__ == '__main__':
+ testenv.main()
diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py
index cd2a3005e..4c6f6f4cf 100644
--- a/test/orm/unitofwork.py
+++ b/test/orm/unitofwork.py
@@ -5,8 +5,9 @@
import testenv; testenv.configure_for_tests()
import pickleable
from sqlalchemy import *
-from sqlalchemy import exceptions, sql
+from sqlalchemy import exc as sa_exc, sql
from sqlalchemy.orm import *
+from sqlalchemy.orm import attributes, exc as orm_exc
from testlib import *
from testlib.tables import *
from testlib import engines, tables, fixtures
@@ -14,7 +15,7 @@ from testlib import engines, tables, fixtures
# TODO: convert suite to not use Session.mapper, use fixtures.Base
# with explicit session.save()
-Session = scoped_session(sessionmaker(autoflush=True, transactional=True))
+Session = scoped_session(sessionmaker(autoflush=True, autocommit=False, autoexpire=False))
orm_mapper = mapper
mapper = Session.mapper
@@ -28,8 +29,10 @@ class HistoryTest(ORMTest):
def test_backref(self):
s = Session()
- class User(object):pass
- class Address(object):pass
+ class User(object):
+ def __init__(self, **kw): pass
+ class Address(object):
+ def __init__(self, _sa_session=None): pass
am = mapper(Address, addresses)
m = mapper(User, users, properties = dict(
addresses = relation(am, backref='user', lazy=False))
@@ -59,7 +62,9 @@ class VersioningTest(ORMTest):
@engines.close_open_connections
def test_basic(self):
s = Session(scope=None)
- class Foo(object):pass
+ class Foo(object):
+ def __init__(self, value, _sa_session=None):
+ self.value = value
mapper(Foo, version_table, version_id_col=version_table.c.version_id)
f1 = Foo(value='f1', _sa_session=s)
f2 = Foo(value='f2', _sa_session=s)
@@ -67,26 +72,22 @@ class VersioningTest(ORMTest):
f1.value='f1rev2'
s.commit()
+
s2 = Session()
f1_s = s2.query(Foo).get(f1.id)
f1_s.value='f1rev3'
s2.commit()
f1.value='f1rev3mine'
- success = False
- try:
- # a concurrent session has modified this, should throw
- # an exception
- s.commit()
- except exceptions.ConcurrentModificationError, e:
- #print e
- success = True
# Only dialects with a sane rowcount can detect the ConcurrentModificationError
if testing.db.dialect.supports_sane_rowcount:
- assert success
-
- s.close()
+ self.assertRaises(orm_exc.ConcurrentModificationError, s.commit)
+ s.rollback()
+ else:
+ s.commit()
+
+ # new in 0.5 ! dont need to close the session
f1 = s.query(Foo).get(f1.id)
f2 = s.query(Foo).get(f2.id)
@@ -95,33 +96,29 @@ class VersioningTest(ORMTest):
s.delete(f1)
s.delete(f2)
- success = False
- try:
- s.commit()
- except exceptions.ConcurrentModificationError, e:
- #print e
- success = True
+
if testing.db.dialect.supports_sane_multi_rowcount:
- assert success
+ self.assertRaises(orm_exc.ConcurrentModificationError, s.commit)
+ else:
+ s.commit()
@engines.close_open_connections
def test_versioncheck(self):
"""test that query.with_lockmode performs a 'version check' on an already loaded instance"""
s1 = Session(scope=None)
- class Foo(object):pass
+ class Foo(object):
+ def __init__(self, _sa_session=None): pass
mapper(Foo, version_table, version_id_col=version_table.c.version_id)
- f1s1 =Foo(value='f1', _sa_session=s1)
+ f1s1 = Foo(_sa_session=s1)
+ f1s1.value = 'f1 value'
s1.commit()
s2 = Session()
f1s2 = s2.query(Foo).get(f1s1.id)
f1s2.value='f1 new value'
s2.commit()
- try:
- # load, version is wrong
- s1.query(Foo).with_lockmode('read').get(f1s1.id)
- assert False
- except exceptions.ConcurrentModificationError, e:
- assert True
+ # load, version is wrong
+ self.assertRaises(orm_exc.ConcurrentModificationError, s1.query(Foo).with_lockmode('read').get, f1s1.id)
+
# reload it
s1.query(Foo).load(f1s1.id)
# now assert version OK
@@ -135,9 +132,11 @@ class VersioningTest(ORMTest):
def test_noversioncheck(self):
"""test that query.with_lockmode works OK when the mapper has no version id col"""
s1 = Session()
- class Foo(object):pass
+ class Foo(object):
+ def __init__(self, _sa_session=None): pass
mapper(Foo, version_table)
- f1s1 =Foo(value='f1', _sa_session=s1)
+ f1s1 =Foo(_sa_session=s1)
+ f1s1.value = 'foo'
f1s1.version_id=0
s1.commit()
s2 = Session()
@@ -271,9 +270,11 @@ class MutableTypesTest(ORMTest):
Session.commit()
Session.close()
f2 = Session.query(Foo).filter_by(id=f1.id).one()
+ assert 'data' in attributes.instance_state(f2).unmodified
assert f2.data == f1.data
f2.data.y = 19
assert f2 in Session.dirty
+ assert 'data' not in attributes.instance_state(f2).unmodified
Session.commit()
Session.close()
f3 = Session.query(Foo).filter_by(id=f1.id).one()
@@ -439,8 +440,11 @@ class PKTest(ORMTest):
e.multi_rev = 2
Session.commit()
Session.close()
- e2 = Query(Entry).get((e.multi_id, 2))
- self.assert_(e is not e2 and e._instance_key == e2._instance_key)
+ e2 = Session.query(Entry).get((e.multi_id, 2))
+ self.assert_(e is not e2)
+ state = attributes.instance_state(e)
+ state2 = attributes.instance_state(e2)
+ self.assert_(state.key == state2.key)
# this one works with sqlite since we are manually setting up pk values
def test_manualpk(self):
@@ -514,8 +518,7 @@ class ClauseAttributesTest(ORMTest):
Column('counter', Integer, default=1))
def test_update(self):
- class User(object):
- pass
+ class User(fixtures.Base): pass
mapper(User, users_table)
u = User(name='test')
sess = Session()
@@ -530,8 +533,7 @@ class ClauseAttributesTest(ORMTest):
self.assert_sql_count(testing.db, go, 1)
def test_multi_update(self):
- class User(object):
- pass
+ class User(fixtures.Base): pass
mapper(User, users_table)
u = User(name='test')
sess = Session()
@@ -553,8 +555,7 @@ class ClauseAttributesTest(ORMTest):
@testing.unsupported('mssql')
def test_insert(self):
- class User(object):
- pass
+ class User(fixtures.Base): pass
mapper(User, users_table)
u = User(name='test', counter=select([5]))
sess = Session()
@@ -641,7 +642,7 @@ class ExtraPassiveDeletesTest(ORMTest):
'children':relation(MyOtherClass, passive_deletes='all', cascade="all")
})
assert False
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
assert str(e) == "Can't set passive_deletes='all' in conjunction with 'delete' or 'delete-orphan' cascade"
@testing.unsupported('sqlite')
@@ -669,7 +670,7 @@ class ExtraPassiveDeletesTest(ORMTest):
assert myothertable.count().scalar() == 4
mc = sess.query(MyClass).get(mc.id)
sess.delete(mc)
- self.assertRaises(exceptions.DBAPIError, sess.commit)
+ self.assertRaises(sa_exc.DBAPIError, sess.commit)
@testing.unsupported('sqlite')
def test_extra_passive_2(self):
@@ -694,7 +695,7 @@ class ExtraPassiveDeletesTest(ORMTest):
mc = sess.query(MyClass).get(mc.id)
sess.delete(mc)
mc.children[0].data = 'some new data'
- self.assertRaises(exceptions.DBAPIError, sess.commit)
+ self.assertRaises(sa_exc.DBAPIError, sess.commit)
class DefaultTest(ORMTest):
@@ -736,7 +737,7 @@ class DefaultTest(ORMTest):
secondary_table.append_column(Column('hoho', hohotype, ForeignKey('default_test.hoho')))
def test_insert(self):
- class Hoho(object):pass
+ class Hoho(fixtures.Base): pass
mapper(Hoho, default_table)
h1 = Hoho(hoho=althohoval)
@@ -790,7 +791,7 @@ class DefaultTest(ORMTest):
def test_insert_nopostfetch(self):
# populates the PassiveDefaults explicitly so there is no "post-update"
- class Hoho(object):pass
+ class Hoho(fixtures.Base): pass
mapper(Hoho, default_table)
h1 = Hoho(hoho="15", counter="15")
@@ -803,7 +804,7 @@ class DefaultTest(ORMTest):
self.assert_sql_count(testing.db, go, 0)
def test_update(self):
- class Hoho(object):pass
+ class Hoho(fixtures.Base): pass
mapper(Hoho, default_table)
h1 = Hoho()
Session.commit()
@@ -971,7 +972,7 @@ class OneToManyTest(ORMTest):
def test_o2m_delete_parent(self):
m = mapper(User, users, properties = dict(
- address = relation(mapper(Address, addresses), lazy = True, uselist = False, private = False)
+ address = relation(mapper(Address, addresses), lazy=True, uselist=False)
))
u = User()
a = Address()
@@ -981,7 +982,10 @@ class OneToManyTest(ORMTest):
Session.commit()
Session.delete(u)
Session.commit()
- self.assert_(a.address_id is not None and a.user_id is None and u._instance_key not in Session.identity_map and a._instance_key in Session.identity_map)
+ self.assert_(a.address_id is not None)
+ self.assert_(a.user_id is None)
+ self.assert_(attributes.instance_state(a).key in Session.identity_map)
+ self.assert_(attributes.instance_state(u).key not in Session.identity_map)
def test_onetoone(self):
m = mapper(User, users, properties = dict(
@@ -2029,7 +2033,7 @@ class TransactionTest(ORMTest):
orm_mapper(T2, t2)
def test_close_transaction_on_commit_fail(self):
- Session = sessionmaker(autoflush=False, transactional=False)
+ Session = sessionmaker(autoflush=False, autocommit=True)
sess = Session()
# with a deferred constraint, this fails at COMMIT time instead
diff --git a/test/orm/utils.py b/test/orm/utils.py
new file mode 100644
index 000000000..4bb2464b3
--- /dev/null
+++ b/test/orm/utils.py
@@ -0,0 +1,208 @@
+import testenv; testenv.configure_for_tests()
+from sqlalchemy.orm import interfaces, util
+from testlib import *
+from testlib import fixtures
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import Table
+from sqlalchemy.orm import aliased
+from sqlalchemy.orm import mapper
+
+
+class ExtensionCarrierTest(TestBase):
+ def test_basic(self):
+ carrier = util.ExtensionCarrier()
+
+ assert 'translate_row' not in carrier.methods
+ assert carrier.translate_row() is interfaces.EXT_CONTINUE
+ assert 'translate_row' not in carrier.methods
+
+ self.assertRaises(AttributeError, lambda: carrier.snickysnack)
+
+ class Partial(object):
+ def __init__(self, marker):
+ self.marker = marker
+ def translate_row(self, row):
+ return self.marker
+
+ carrier.append(Partial('end'))
+ assert 'translate_row' in carrier.methods
+ assert carrier.translate_row(None) == 'end'
+
+ carrier.push(Partial('front'))
+ assert carrier.translate_row(None) == 'front'
+
+ assert 'populate_instance' not in carrier.methods
+ carrier.append(interfaces.MapperExtension)
+ assert 'populate_instance' in carrier.methods
+
+ assert carrier.interface
+ for m in carrier.interface:
+ assert getattr(interfaces.MapperExtension, m)
+
+class AliasedClassTest(TestBase):
+ def point_map(self, cls):
+ table = Table('point', MetaData(),
+ Column('id', Integer(), primary_key=True),
+ Column('x', Integer),
+ Column('y', Integer))
+ mapper(cls, table)
+ return table
+
+ def test_simple(self):
+ class Point(object):
+ pass
+ table = self.point_map(Point)
+
+ alias = aliased(Point)
+
+ assert alias.id
+ assert alias.x
+ assert alias.y
+
+ assert Point.id.__clause_element__().table is table
+ assert alias.id.__clause_element__().table is not table
+
+ def test_notcallable(self):
+ class Point(object):
+ pass
+ table = self.point_map(Point)
+ alias = aliased(Point)
+
+ self.assertRaises(TypeError, alias)
+
+ def test_instancemethods(self):
+ class Point(object):
+ def zero(self):
+ self.x, self.y = 0, 0
+
+ table = self.point_map(Point)
+ alias = aliased(Point)
+
+ assert Point.zero
+ assert not getattr(alias, 'zero')
+
+ def test_classmethods(self):
+ class Point(object):
+ @classmethod
+ def max_x(cls):
+ return 100
+
+ table = self.point_map(Point)
+ alias = aliased(Point)
+
+ assert Point.max_x
+ assert alias.max_x
+ assert Point.max_x() == alias.max_x()
+
+ def test_simpleproperties(self):
+ class Point(object):
+ @property
+ def max_x(self):
+ return 100
+
+ table = self.point_map(Point)
+ alias = aliased(Point)
+
+ assert Point.max_x
+ assert Point.max_x != 100
+ assert alias.max_x
+ assert Point.max_x is alias.max_x
+
+ def test_descriptors(self):
+ class descriptor(object):
+ """Tortured..."""
+ def __init__(self, fn):
+ self.fn = fn
+ def __get__(self, obj, owner):
+ if obj is not None:
+ return self.fn(obj, obj)
+ else:
+ return self
+ def method(self):
+ return 'method'
+
+ class Point(object):
+ center = (0, 0)
+ @descriptor
+ def thing(self, arg):
+ return arg.center
+
+ table = self.point_map(Point)
+ alias = aliased(Point)
+
+ assert Point.thing != (0, 0)
+ assert Point().thing == (0, 0)
+ assert Point.thing.method() == 'method'
+
+ assert alias.thing != (0, 0)
+ assert alias.thing.method() == 'method'
+
+ def test_hybrid_descriptors(self):
+ from sqlalchemy import Column # override testlib's override
+ import new
+
+ class MethodDescriptor(object):
+ def __init__(self, func):
+ self.func = func
+ def __get__(self, instance, owner):
+ if instance is None:
+ args = (self.func, owner, owner.__class__)
+ else:
+ args = (self.func, instance, owner)
+ return new.instancemethod(*args)
+
+ class PropertyDescriptor(object):
+ def __init__(self, fget, fset, fdel):
+ self.fget = fget
+ self.fset = fset
+ self.fdel = fdel
+ def __get__(self, instance, owner):
+ if instance is None:
+ return self.fget(owner)
+ else:
+ return self.fget(instance)
+ def __set__(self, instance, value):
+ self.fset(instance, value)
+ def __delete__(self, instance):
+ self.fdel(instance)
+ hybrid = MethodDescriptor
+ def hybrid_property(fget, fset=None, fdel=None):
+ return PropertyDescriptor(fget, fset, fdel)
+
+ def assert_table(expr, table):
+ for child in expr.get_children():
+ if isinstance(child, Column):
+ assert child.table is table
+
+ class Point(object):
+ def __init__(self, x, y):
+ self.x, self.y = x, y
+ @hybrid
+ def left_of(self, other):
+ return self.x < other.x
+
+ double_x = hybrid_property(lambda self: self.x * 2)
+
+ table = self.point_map(Point)
+ alias = aliased(Point)
+ alias_table = alias.x.__clause_element__().table
+ assert table is not alias_table
+
+ p1 = Point(-10, -10)
+ p2 = Point(20, 20)
+
+ assert p1.left_of(p2)
+ assert p1.double_x == -20
+
+ assert_table(Point.double_x, table)
+ assert_table(alias.double_x, alias_table)
+
+ assert_table(Point.left_of(p2), table)
+ assert_table(alias.left_of(p2), alias_table)
+
+
+if __name__ == '__main__':
+ testenv.main()
+
diff --git a/test/perf/masseagerload.py b/test/perf/masseagerload.py
index bc2834ff7..a848b866c 100644
--- a/test/perf/masseagerload.py
+++ b/test/perf/masseagerload.py
@@ -37,6 +37,7 @@ def load():
@profiling.profiled('masseagerload', always=True, sort=['cumulative'])
def masseagerload(session):
+ session.begin()
query = session.query(Item)
l = query.all()
print "loaded ", len(l), " items each with ", len(l[0].subs), "subitems"
diff --git a/test/profiling/compiler.py b/test/profiling/compiler.py
index 4e1111aa2..cd0a29ee3 100644
--- a/test/profiling/compiler.py
+++ b/test/profiling/compiler.py
@@ -15,11 +15,11 @@ class CompileTest(TestBase, AssertsExecutionResults):
Column('c1', Integer, primary_key=True),
Column('c2', String(30)))
- @profiling.function_call_count(74, {'2.3': 44, '2.4': 42})
+ @profiling.function_call_count(67, {'2.3': 44, '2.4': 42})
def test_insert(self):
t1.insert().compile()
- @profiling.function_call_count(75, {'2.3': 47, '2.4': 42})
+ @profiling.function_call_count(68, {'2.3': 47, '2.4': 42})
def test_update(self):
t1.update().compile()
diff --git a/test/profiling/zoomark.py b/test/profiling/zoomark.py
index 0994b5d4b..cdf663a4e 100644
--- a/test/profiling/zoomark.py
+++ b/test/profiling/zoomark.py
@@ -332,7 +332,7 @@ class ZooMarkTest(TestBase):
def test_profile_2_insert(self):
self.test_baseline_2_insert()
- @profiling.function_call_count(4923, {'2.4': 2557})
+ @profiling.function_call_count(4662, {'2.4': 2557})
def test_profile_3_properties(self):
self.test_baseline_3_properties()
@@ -344,7 +344,7 @@ class ZooMarkTest(TestBase):
def test_profile_5_aggregates(self):
self.test_baseline_5_aggregates()
- @profiling.function_call_count(1988, {'2.4': 1048})
+ @profiling.function_call_count(1882, {'2.4': 1048})
def test_profile_6_editing(self):
self.test_baseline_6_editing()
diff --git a/test/sql/case_statement.py b/test/sql/case_statement.py
index 6aecefd3c..876f820b5 100644
--- a/test/sql/case_statement.py
+++ b/test/sql/case_statement.py
@@ -2,7 +2,7 @@ import testenv; testenv.configure_for_tests()
import sys
from sqlalchemy import *
from testlib import *
-from sqlalchemy import util, exceptions
+from sqlalchemy import util, exc
from sqlalchemy.sql import table, column
@@ -91,7 +91,7 @@ class CaseTest(TestBase, AssertsCompiledSQL):
def test_literal_interpretation(self):
t = table('test', column('col1'))
- self.assertRaises(exceptions.ArgumentError, case, [("x", "y")])
+ self.assertRaises(exc.ArgumentError, case, [("x", "y")])
self.assert_compile(case([("x", "y")], value=t.c.col1), "CASE test.col1 WHEN :param_1 THEN :param_2 END")
self.assert_compile(case([(t.c.col1==7, "y")], else_="z"), "CASE WHEN (test.col1 = :col1_1) THEN :param_1 ELSE :param_2 END")
diff --git a/test/sql/columns.py b/test/sql/columns.py
index 76bf9b389..661be891a 100644
--- a/test/sql/columns.py
+++ b/test/sql/columns.py
@@ -1,6 +1,6 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions, sql
+from sqlalchemy import exc, sql
from testlib import *
from sqlalchemy import Table, Column # don't use testlib's wrappers
@@ -37,7 +37,7 @@ class ColumnDefinitionTest(TestBase):
def test_incomplete(self):
c = self.columns()
- self.assertRaises(exceptions.ArgumentError, Table, 't', MetaData(), *c)
+ self.assertRaises(exc.ArgumentError, Table, 't', MetaData(), *c)
def test_incomplete_key(self):
c = Column(Integer)
@@ -52,8 +52,8 @@ class ColumnDefinitionTest(TestBase):
def test_bogus(self):
- self.assertRaises(exceptions.ArgumentError, Column, 'foo', name='bar')
- self.assertRaises(exceptions.ArgumentError, Column, 'foo', Integer,
+ self.assertRaises(exc.ArgumentError, Column, 'foo', name='bar')
+ self.assertRaises(exc.ArgumentError, Column, 'foo', Integer,
type_=Integer())
if __name__ == "__main__":
diff --git a/test/sql/constraints.py b/test/sql/constraints.py
index 2908e07da..966930ca9 100644
--- a/test/sql/constraints.py
+++ b/test/sql/constraints.py
@@ -1,6 +1,6 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc
from testlib import *
from testlib import config, engines
@@ -72,14 +72,14 @@ class ConstraintTest(TestBase, AssertsExecutionResults):
try:
foo.insert().execute(id=2,x=5,y=9)
assert False
- except exceptions.SQLError:
+ except exc.SQLError:
assert True
bar.insert().execute(id=1,x=10)
try:
bar.insert().execute(id=2,x=5)
assert False
- except exceptions.SQLError:
+ except exc.SQLError:
assert True
def test_unique_constraint(self):
@@ -100,12 +100,12 @@ class ConstraintTest(TestBase, AssertsExecutionResults):
try:
foo.insert().execute(id=3, value='value1')
assert False
- except exceptions.SQLError:
+ except exc.SQLError:
assert True
try:
bar.insert().execute(id=3, value='a', value2='b')
assert False
- except exceptions.SQLError:
+ except exc.SQLError:
assert True
def test_index_create(self):
diff --git a/test/sql/defaults.py b/test/sql/defaults.py
index 22660c060..e9ed21a65 100644
--- a/test/sql/defaults.py
+++ b/test/sql/defaults.py
@@ -1,7 +1,7 @@
import testenv; testenv.configure_for_tests()
import datetime
from sqlalchemy import *
-from sqlalchemy import exceptions, schema, util
+from sqlalchemy import exc, schema, util
from sqlalchemy.orm import mapper, create_session
from testlib import *
@@ -122,7 +122,7 @@ class DefaultTest(TestBase):
try:
c = ColumnDefault(fn)
assert False, str(fn)
- except exceptions.ArgumentError, e:
+ except exc.ArgumentError, e:
assert str(e) == ex_msg
def test_argsignature(self):
@@ -327,7 +327,7 @@ class AutoIncrementTest(TestBase):
nonai_table.insert().execute(data='row 1')
nonai_table.insert().execute(data='row 2')
assert False
- except exceptions.SQLError, e:
+ except exc.SQLError, e:
print "Got exception", str(e)
assert True
diff --git a/test/sql/functions.py b/test/sql/functions.py
index d1ce17c72..82814ef1b 100644
--- a/test/sql/functions.py
+++ b/test/sql/functions.py
@@ -2,7 +2,7 @@ import testenv; testenv.configure_for_tests()
import datetime
from sqlalchemy import *
from sqlalchemy.sql import table, column
-from sqlalchemy import databases, exceptions, sql, util
+from sqlalchemy import databases, sql, util
from sqlalchemy.sql.compiler import BIND_TEMPLATES
from sqlalchemy.engine import default
from sqlalchemy import types as sqltypes
diff --git a/test/sql/generative.py b/test/sql/generative.py
index 820474282..cf5ea8235 100644
--- a/test/sql/generative.py
+++ b/test/sql/generative.py
@@ -1,7 +1,7 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.sql import table, column, ClauseElement
-from sqlalchemy.sql.expression import _clone
+from sqlalchemy.sql.expression import _clone, _from_objects
from testlib import *
from sqlalchemy.sql.visitors import *
from sqlalchemy import util
@@ -82,14 +82,14 @@ class TraversalTest(TestBase, AssertsExecutionResults):
def test_clone(self):
struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
- class Vis(ClauseVisitor):
+ class Vis(CloningVisitor):
def visit_a(self, a):
pass
def visit_b(self, b):
pass
vis = Vis()
- s2 = vis.traverse(struct, clone=True)
+ s2 = vis.traverse(struct)
assert struct == s2
assert not struct.is_other(s2)
@@ -103,7 +103,7 @@ class TraversalTest(TestBase, AssertsExecutionResults):
pass
vis = Vis()
- s2 = vis.traverse(struct, clone=False)
+ s2 = vis.traverse(struct)
assert struct == s2
assert struct.is_other(s2)
@@ -112,7 +112,7 @@ class TraversalTest(TestBase, AssertsExecutionResults):
struct2 = B(A("expr1"), A("expr2modified"), B(A("expr1b"), A("expr2b")), A("expr3"))
struct3 = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2bmodified")), A("expr3"))
- class Vis(ClauseVisitor):
+ class Vis(CloningVisitor):
def visit_a(self, a):
if a.expr == "expr2":
a.expr = "expr2modified"
@@ -120,12 +120,12 @@ class TraversalTest(TestBase, AssertsExecutionResults):
pass
vis = Vis()
- s2 = vis.traverse(struct, clone=True)
+ s2 = vis.traverse(struct)
assert struct != s2
assert not struct.is_other(s2)
assert struct2 == s2
- class Vis2(ClauseVisitor):
+ class Vis2(CloningVisitor):
def visit_a(self, a):
if a.expr == "expr2b":
a.expr = "expr2bmodified"
@@ -133,7 +133,7 @@ class TraversalTest(TestBase, AssertsExecutionResults):
pass
vis2 = Vis2()
- s3 = vis2.traverse(struct, clone=True)
+ s3 = vis2.traverse(struct)
assert struct != s3
assert struct3 == s3
@@ -156,7 +156,7 @@ class ClauseTest(TestBase, AssertsCompiledSQL):
def test_binary(self):
clause = t1.c.col2 == t2.c.col2
- assert str(clause) == ClauseVisitor().traverse(clause, clone=True)
+ assert str(clause) == CloningVisitor().traverse(clause)
def test_binary_anon_label_quirk(self):
t = table('t1', column('col1'))
@@ -175,25 +175,25 @@ class ClauseTest(TestBase, AssertsCompiledSQL):
def test_join(self):
clause = t1.join(t2, t1.c.col2==t2.c.col2)
c1 = str(clause)
- assert str(clause) == str(ClauseVisitor().traverse(clause, clone=True))
+ assert str(clause) == str(CloningVisitor().traverse(clause))
- class Vis(ClauseVisitor):
+ class Vis(CloningVisitor):
def visit_binary(self, binary):
binary.right = t2.c.col3
- clause2 = Vis().traverse(clause, clone=True)
+ clause2 = Vis().traverse(clause)
assert c1 == str(clause)
assert str(clause2) == str(t1.join(t2, t1.c.col2==t2.c.col3))
def test_text(self):
clause = text("select * from table where foo=:bar", bindparams=[bindparam('bar')])
c1 = str(clause)
- class Vis(ClauseVisitor):
+ class Vis(CloningVisitor):
def visit_textclause(self, text):
text.text = text.text + " SOME MODIFIER=:lala"
text.bindparams['lala'] = bindparam('lala')
- clause2 = Vis().traverse(clause, clone=True)
+ clause2 = Vis().traverse(clause)
assert c1 == str(clause)
assert str(clause2) == c1 + " SOME MODIFIER=:lala"
assert clause.bindparams.keys() == ['bar']
@@ -203,24 +203,27 @@ class ClauseTest(TestBase, AssertsCompiledSQL):
s2 = select([t1])
s2_assert = str(s2)
s3_assert = str(select([t1], t1.c.col2==7))
- class Vis(ClauseVisitor):
+ class Vis(CloningVisitor):
def visit_select(self, select):
select.append_whereclause(t1.c.col2==7)
- s3 = Vis().traverse(s2, clone=True)
+ s3 = Vis().traverse(s2)
assert str(s3) == s3_assert
assert str(s2) == s2_assert
print str(s2)
print str(s3)
+ class Vis(ClauseVisitor):
+ def visit_select(self, select):
+ select.append_whereclause(t1.c.col2==7)
Vis().traverse(s2)
assert str(s2) == s3_assert
print "------------------"
s4_assert = str(select([t1], and_(t1.c.col2==7, t1.c.col3==9)))
- class Vis(ClauseVisitor):
+ class Vis(CloningVisitor):
def visit_select(self, select):
select.append_whereclause(t1.c.col3==9)
- s4 = Vis().traverse(s3, clone=True)
+ s4 = Vis().traverse(s3)
print str(s3)
print str(s4)
assert str(s4) == s4_assert
@@ -228,12 +231,12 @@ class ClauseTest(TestBase, AssertsCompiledSQL):
print "------------------"
s5_assert = str(select([t1], and_(t1.c.col2==7, t1.c.col1==9)))
- class Vis(ClauseVisitor):
+ class Vis(CloningVisitor):
def visit_binary(self, binary):
if binary.left is t1.c.col3:
binary.left = t1.c.col1
binary.right = bindparam("col1", unique=True)
- s5 = Vis().traverse(s4, clone=True)
+ s5 = Vis().traverse(s4)
print str(s4)
print str(s5)
assert str(s5) == s5_assert
@@ -241,13 +244,13 @@ class ClauseTest(TestBase, AssertsCompiledSQL):
def test_union(self):
u = union(t1.select(), t2.select())
- u2 = ClauseVisitor().traverse(u, clone=True)
+ u2 = CloningVisitor().traverse(u)
assert str(u) == str(u2)
assert [str(c) for c in u2.c] == [str(c) for c in u.c]
u = union(t1.select(), t2.select())
cols = [str(c) for c in u.c]
- u2 = ClauseVisitor().traverse(u, clone=True)
+ u2 = CloningVisitor().traverse(u)
assert str(u) == str(u2)
assert [str(c) for c in u2.c] == cols
@@ -265,7 +268,7 @@ class ClauseTest(TestBase, AssertsCompiledSQL):
"""test that unique bindparams change their name upon clone() to prevent conflicts"""
s = select([t1], t1.c.col1==bindparam(None, unique=True)).alias()
- s2 = ClauseVisitor().traverse(s, clone=True).alias()
+ s2 = CloningVisitor().traverse(s).alias()
s3 = select([s], s.c.col2==s2.c.col2)
self.assert_compile(s3, "SELECT anon_1.col1, anon_1.col2, anon_1.col3 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, "\
@@ -274,7 +277,7 @@ class ClauseTest(TestBase, AssertsCompiledSQL):
"WHERE anon_1.col2 = anon_2.col2")
s = select([t1], t1.c.col1==4).alias()
- s2 = ClauseVisitor().traverse(s, clone=True).alias()
+ s2 = CloningVisitor().traverse(s).alias()
s3 = select([s], s.c.col2==s2.c.col2)
self.assert_compile(s3, "SELECT anon_1.col1, anon_1.col2, anon_1.col3 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, "\
"table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) AS anon_1, "\
@@ -286,26 +289,51 @@ class ClauseTest(TestBase, AssertsCompiledSQL):
subq = t2.select().alias('subq')
s = select([t1.c.col1, subq.c.col1], from_obj=[t1, subq, t1.join(subq, t1.c.col1==subq.c.col2)])
orig = str(s)
- s2 = ClauseVisitor().traverse(s, clone=True)
+ s2 = CloningVisitor().traverse(s)
assert orig == str(s) == str(s2)
- s4 = ClauseVisitor().traverse(s2, clone=True)
+ s4 = CloningVisitor().traverse(s2)
assert orig == str(s) == str(s2) == str(s4)
- s3 = sql_util.ClauseAdapter(table('foo')).traverse(s, clone=True)
+ s3 = sql_util.ClauseAdapter(table('foo')).traverse(s)
assert orig == str(s) == str(s3)
- s4 = sql_util.ClauseAdapter(table('foo')).traverse(s3, clone=True)
+ s4 = sql_util.ClauseAdapter(table('foo')).traverse(s3)
assert orig == str(s) == str(s3) == str(s4)
def test_correlated_select(self):
s = select(['*'], t1.c.col1==t2.c.col1, from_obj=[t1, t2]).correlate(t2)
- class Vis(ClauseVisitor):
+ class Vis(CloningVisitor):
def visit_select(self, select):
select.append_whereclause(t1.c.col2==7)
- self.assert_compile(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :col2_1")
-
+ self.assert_compile(Vis().traverse(s), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :col2_1")
+
+ def test_this_thing(self):
+ s = select([t1]).where(t1.c.col1=='foo').alias()
+ s2 = select([s.c.col1])
+
+ self.assert_compile(s2, "SELECT anon_1.col1 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) AS anon_1")
+ t1a = t1.alias()
+ s2 = sql_util.ClauseAdapter(t1a).traverse(s2)
+ self.assert_compile(s2, "SELECT anon_1.col1 FROM (SELECT table1_1.col1 AS col1, table1_1.col2 AS col2, table1_1.col3 AS col3 FROM table1 AS table1_1 WHERE table1_1.col1 = :col1_1) AS anon_1")
+
+ def test_select_fromtwice(self):
+ t1a = t1.alias()
+
+ s = select([1], t1.c.col1==t1a.c.col1, from_obj=t1a).correlate(t1)
+ self.assert_compile(s, "SELECT 1 FROM table1 AS table1_1 WHERE table1.col1 = table1_1.col1")
+
+ s = CloningVisitor().traverse(s)
+ self.assert_compile(s, "SELECT 1 FROM table1 AS table1_1 WHERE table1.col1 = table1_1.col1")
+
+ s = select([t1]).where(t1.c.col1=='foo').alias()
+
+ s2 = select([1], t1.c.col1==s.c.col1, from_obj=s).correlate(t1)
+ self.assert_compile(s2, "SELECT 1 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) AS anon_1 WHERE table1.col1 = anon_1.col1")
+ s2 = ReplacingCloningVisitor().traverse(s2)
+ self.assert_compile(s2, "SELECT 1 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) AS anon_1 WHERE table1.col1 = anon_1.col1")
+
class ClauseAdapterTest(TestBase, AssertsCompiledSQL):
def setUpAll(self):
global t1, t2
@@ -330,69 +358,88 @@ class ClauseAdapterTest(TestBase, AssertsCompiledSQL):
assert t1alias in s._froms
self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)")
- s = vis.traverse(s, clone=True)
+ s = vis.traverse(s)
+
assert t2alias not in s._froms # not present because it's been cloned
+
assert t1alias in s._froms # present because the adapter placed it there
+
# correlate list on "s" needs to take into account the full _cloned_set for each element in _froms when correlating
self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)")
s = select(['*'], from_obj=[t1alias, t2alias]).correlate(t2alias).as_scalar()
self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)")
- s = vis.traverse(s, clone=True)
+ s = vis.traverse(s)
self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)")
- s = ClauseVisitor().traverse(s, clone=True)
+ s = CloningVisitor().traverse(s)
self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)")
s = select(['*']).where(t1.c.col1==t2.c.col1).as_scalar()
self.assert_compile(select([t1.c.col1, s]), "SELECT table1.col1, (SELECT * FROM table2 WHERE table1.col1 = table2.col1) AS anon_1 FROM table1")
vis = sql_util.ClauseAdapter(t1alias)
- s = vis.traverse(s, clone=True)
+ s = vis.traverse(s)
self.assert_compile(select([t1alias.c.col1, s]), "SELECT t1alias.col1, (SELECT * FROM table2 WHERE t1alias.col1 = table2.col1) AS anon_1 FROM table1 AS t1alias")
- s = ClauseVisitor().traverse(s, clone=True)
+ s = CloningVisitor().traverse(s)
self.assert_compile(select([t1alias.c.col1, s]), "SELECT t1alias.col1, (SELECT * FROM table2 WHERE t1alias.col1 = table2.col1) AS anon_1 FROM table1 AS t1alias")
s = select(['*']).where(t1.c.col1==t2.c.col1).correlate(t1).as_scalar()
self.assert_compile(select([t1.c.col1, s]), "SELECT table1.col1, (SELECT * FROM table2 WHERE table1.col1 = table2.col1) AS anon_1 FROM table1")
vis = sql_util.ClauseAdapter(t1alias)
- s = vis.traverse(s, clone=True)
+ s = vis.traverse(s)
self.assert_compile(select([t1alias.c.col1, s]), "SELECT t1alias.col1, (SELECT * FROM table2 WHERE t1alias.col1 = table2.col1) AS anon_1 FROM table1 AS t1alias")
- s = ClauseVisitor().traverse(s, clone=True)
+ s = CloningVisitor().traverse(s)
self.assert_compile(select([t1alias.c.col1, s]), "SELECT t1alias.col1, (SELECT * FROM table2 WHERE t1alias.col1 = table2.col1) AS anon_1 FROM table1 AS t1alias")
-
+
+ @testing.fails_on_everything_except()
+ def test_joins_dont_adapt(self):
+ # adapting to a join, i.e. ClauseAdapter(t1.join(t2)), doesn't make much sense.
+ # ClauseAdapter doesn't make any changes if it's against a straight join.
+ users = table('users', column('id'))
+ addresses = table('addresses', column('id'), column('user_id'))
+
+ ualias = users.alias()
+
+ s = select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).correlate(users) #.as_scalar().label(None)
+ s= sql_util.ClauseAdapter(ualias).traverse(s)
+
+ j1 = addresses.join(ualias, addresses.c.user_id==ualias.c.id)
+
+ self.assert_compile(sql_util.ClauseAdapter(j1).traverse(s), "SELECT count(addresses.id) AS count_1 FROM addresses WHERE users_1.id = addresses.user_id")
def test_table_to_alias(self):
t1alias = t1.alias('t1alias')
vis = sql_util.ClauseAdapter(t1alias)
- ff = vis.traverse(func.count(t1.c.col1).label('foo'), clone=True)
- assert ff._get_from_objects() == [t1alias]
+ ff = vis.traverse(func.count(t1.c.col1).label('foo'))
+ assert list(_from_objects(ff)) == [t1alias]
- self.assert_compile(vis.traverse(select(['*'], from_obj=[t1]), clone=True), "SELECT * FROM table1 AS t1alias")
- self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2")
- self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2")
- self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 WHERE t1alias.col1 = table2.col2")
- self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = table2.col2")
+ self.assert_compile(vis.traverse(select(['*'], from_obj=[t1])), "SELECT * FROM table1 AS t1alias")
+ self.assert_compile(select(['*'], t1.c.col1==t2.c.col2), "SELECT * FROM table1, table2 WHERE table1.col1 = table2.col2")
+ self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2)), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2")
+ self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2])), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2")
+ self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1)), "SELECT * FROM table2 WHERE t1alias.col1 = table2.col2")
+ self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2)), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = table2.col2")
s = select(['*'], from_obj=[t1]).alias('foo')
self.assert_compile(s.select(), "SELECT foo.* FROM (SELECT * FROM table1) AS foo")
- self.assert_compile(vis.traverse(s.select(), clone=True), "SELECT foo.* FROM (SELECT * FROM table1 AS t1alias) AS foo")
+ self.assert_compile(vis.traverse(s.select()), "SELECT foo.* FROM (SELECT * FROM table1 AS t1alias) AS foo")
self.assert_compile(s.select(), "SELECT foo.* FROM (SELECT * FROM table1) AS foo")
- ff = vis.traverse(func.count(t1.c.col1).label('foo'), clone=True)
- self.assert_compile(ff, "count(t1alias.col1) AS foo")
- assert ff._get_from_objects() == [t1alias]
+ ff = vis.traverse(func.count(t1.c.col1).label('foo'))
+ self.assert_compile(select([ff]), "SELECT count(t1alias.col1) AS foo FROM table1 AS t1alias")
+ assert list(_from_objects(ff)) == [t1alias]
# TODO:
# self.assert_compile(vis.traverse(select([func.count(t1.c.col1).label('foo')]), clone=True), "SELECT count(t1alias.col1) AS foo FROM table1 AS t1alias")
t2alias = t2.alias('t2alias')
vis.chain(sql_util.ClauseAdapter(t2alias))
- self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
- self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
- self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
- self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2")
+ self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2)), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
+ self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2])), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
+ self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1)), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
+ self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2)), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2")
def test_include_exclude(self):
m = MetaData()
@@ -517,6 +564,65 @@ class ClauseAdapterTest(TestBase, AssertsCompiledSQL):
"WHERE c.bid = anon_1.b_aid"
)
+class SpliceJoinsTest(TestBase, AssertsCompiledSQL):
+ def setUpAll(self):
+ global table1, table2, table3, table4
+ def _table(name):
+ return table(name, column("col1"), column("col2"),column("col3"))
+
+ table1, table2, table3, table4 = [_table(name) for name in ("table1", "table2", "table3", "table4")]
+
+ def test_splice(self):
+ (t1, t2, t3, t4) = (table1, table2, table1.alias(), table2.alias())
+
+ j = t1.join(t2, t1.c.col1==t2.c.col1).join(t3, t2.c.col1==t3.c.col1).join(t4, t4.c.col1==t1.c.col1)
+
+ s = select([t1]).where(t1.c.col2<5).alias()
+
+ self.assert_compile(sql_util.splice_joins(s, j),
+ "(SELECT table1.col1 AS col1, table1.col2 AS col2, "\
+ "table1.col3 AS col3 FROM table1 WHERE table1.col2 < :col2_1) AS anon_1 "\
+ "JOIN table2 ON anon_1.col1 = table2.col1 JOIN table1 AS table1_1 ON table2.col1 = table1_1.col1 "\
+ "JOIN table2 AS table2_1 ON table2_1.col1 = anon_1.col1")
+
+ def test_stop_on(self):
+ (t1, t2, t3) = (table1, table2, table3)
+
+ j1= t1.join(t2, t1.c.col1==t2.c.col1)
+ j2 = j1.join(t3, t2.c.col1==t3.c.col1)
+
+ s = select([t1]).select_from(j1).alias()
+
+ self.assert_compile(sql_util.splice_joins(s, j2),
+ "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 JOIN table2 "\
+ "ON table1.col1 = table2.col1) AS anon_1 JOIN table2 ON anon_1.col1 = table2.col1 JOIN table3 "\
+ "ON table2.col1 = table3.col1"
+ )
+
+ self.assert_compile(sql_util.splice_joins(s, j2, j1),
+ "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 "\
+ "JOIN table2 ON table1.col1 = table2.col1) AS anon_1 JOIN table3 ON table2.col1 = table3.col1")
+
+ def test_splice_2(self):
+ t2a = table2.alias()
+ t3a = table3.alias()
+ j1 = table1.join(t2a, table1.c.col1==t2a.c.col1).join(t3a, t2a.c.col2==t3a.c.col2)
+
+ t2b = table4.alias()
+ j2 = table1.join(t2b, table1.c.col3==t2b.c.col3)
+
+ self.assert_compile(sql_util.splice_joins(table1, j1),
+ "table1 JOIN table2 AS table2_1 ON table1.col1 = table2_1.col1 "\
+ "JOIN table3 AS table3_1 ON table2_1.col2 = table3_1.col2")
+
+ self.assert_compile(sql_util.splice_joins(table1, j2), "table1 JOIN table4 AS table4_1 ON table1.col3 = table4_1.col3")
+
+ self.assert_compile(sql_util.splice_joins(sql_util.splice_joins(table1, j1), j2),
+ "table1 JOIN table2 AS table2_1 ON table1.col1 = table2_1.col1 "\
+ "JOIN table3 AS table3_1 ON table2_1.col2 = table3_1.col2 "\
+ "JOIN table4 AS table4_1 ON table1.col3 = table4_1.col3")
+
+
class SelectTest(TestBase, AssertsCompiledSQL):
"""tests the generative capability of Select"""
diff --git a/test/sql/query.py b/test/sql/query.py
index e6d6714c2..a305a5314 100644
--- a/test/sql/query.py
+++ b/test/sql/query.py
@@ -1,7 +1,7 @@
import testenv; testenv.configure_for_tests()
import datetime
from sqlalchemy import *
-from sqlalchemy import exceptions, sql
+from sqlalchemy import exc, sql
from sqlalchemy.engine import default
from testlib import *
@@ -426,7 +426,7 @@ class QueryTest(TestBase):
try:
print r['user_id']
assert False
- except exceptions.InvalidRequestError, e:
+ except exc.InvalidRequestError, e:
assert str(e) == "Ambiguous column name 'user_id' in result set! try 'use_labels' option on select statement." or \
str(e) == "Ambiguous column name 'USER_ID' in result set! try 'use_labels' option on select statement."
@@ -466,7 +466,7 @@ class QueryTest(TestBase):
def test_cant_execute_join(self):
try:
users.join(addresses).execute()
- except exceptions.ArgumentError, e:
+ except exc.ArgumentError, e:
assert str(e).startswith('Not an executable clause: ')
diff --git a/test/sql/quote.py b/test/sql/quote.py
index 825e836ff..d137b44a3 100644
--- a/test/sql/quote.py
+++ b/test/sql/quote.py
@@ -4,7 +4,7 @@ from sqlalchemy import sql
from sqlalchemy.sql import compiler
from testlib import *
-class QuoteTest(TestBase):
+class QuoteTest(TestBase, AssertsCompiledSQL):
def setUpAll(self):
# TODO: figure out which databases/which identifiers allow special characters to be used,
# such as: spaces, quote characters, punctuation characters, set up tests for those as
@@ -67,7 +67,23 @@ class QuoteTest(TestBase):
res2 = select([table2.c.d123, table2.c.u123, table2.c.MixedCase], use_labels=True).execute().fetchall()
print res2
assert(res2==[(1,2,3),(2,2,3),(4,3,2)])
+
+ def test_quote_flag(self):
+ metadata = MetaData()
+ t1 = Table('TableOne', metadata,
+ Column('ColumnOne', Integer), schema="FooBar")
+ self.assert_compile(t1.select(), '''SELECT "FooBar"."TableOne"."ColumnOne" FROM "FooBar"."TableOne"''')
+
+ metadata = MetaData()
+ t1 = Table('t1', metadata,
+ Column('col1', Integer, quote=True), quote=True, schema="foo", quote_schema=True)
+ self.assert_compile(t1.select(), '''SELECT "foo"."t1"."col1" FROM "foo"."t1"''')
+ metadata = MetaData()
+ t1 = Table('TableOne', metadata,
+ Column('ColumnOne', Integer, quote=False), quote=False, schema="FooBar", quote_schema=False)
+ self.assert_compile(t1.select(), '''SELECT FooBar.TableOne.ColumnOne FROM FooBar.TableOne''')
+
@testing.unsupported('oracle')
def testlabels(self):
"""test the quoting of labels.
@@ -86,16 +102,19 @@ class QuoteTest(TestBase):
table = Table("ImATable", metadata,
Column("col1", Integer))
x = select([table.c.col1.label("ImATable_col1")]).alias("SomeAlias")
- assert str(select([x.c.ImATable_col1])) == '''SELECT "SomeAlias"."ImATable_col1" \nFROM (SELECT "ImATable".col1 AS "ImATable_col1" \nFROM "ImATable") AS "SomeAlias"'''
+ self.assert_compile(select([x.c.ImATable_col1]),
+ '''SELECT "SomeAlias"."ImATable_col1" FROM (SELECT "ImATable".col1 AS "ImATable_col1" FROM "ImATable") AS "SomeAlias"''')
# note that 'foo' and 'FooCol' are literals already quoted
x = select([sql.literal_column("'foo'").label("somelabel")], from_obj=[table]).alias("AnAlias")
x = x.select()
- assert str(x) == '''SELECT "AnAlias".somelabel \nFROM (SELECT 'foo' AS somelabel \nFROM "ImATable") AS "AnAlias"'''
+ self.assert_compile(x,
+ '''SELECT "AnAlias".somelabel FROM (SELECT 'foo' AS somelabel FROM "ImATable") AS "AnAlias"''')
x = select([sql.literal_column("'FooCol'").label("SomeLabel")], from_obj=[table])
x = x.select()
- assert str(x) == '''SELECT "SomeLabel" \nFROM (SELECT 'FooCol' AS "SomeLabel" \nFROM "ImATable")'''
+ self.assert_compile(x,
+ '''SELECT "SomeLabel" FROM (SELECT 'FooCol' AS "SomeLabel" FROM "ImATable")''')
class PreparerTest(TestBase):
diff --git a/test/sql/select.py b/test/sql/select.py
index bea862112..3ecf63d34 100644
--- a/test/sql/select.py
+++ b/test/sql/select.py
@@ -1,7 +1,7 @@
import testenv; testenv.configure_for_tests()
import datetime, re, operator
from sqlalchemy import *
-from sqlalchemy import exceptions, sql, util
+from sqlalchemy import exc, sql, util
from sqlalchemy.sql import table, column, compiler
from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql
from testlib import *
@@ -154,7 +154,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
t2 = table('t2', column('c'), column('d'))
s = select([t.c.a]).where(t.c.a==t2.c.d).as_scalar()
s2 =select([t, t2, s])
- self.assertRaises(exceptions.InvalidRequestError, str, s2)
+ self.assertRaises(exc.InvalidRequestError, str, s2)
# intentional again
s = s.correlate(t, t2)
@@ -245,14 +245,14 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
try:
s = select([table1.c.myid, table1.c.name]).as_scalar()
assert False
- except exceptions.InvalidRequestError, err:
+ except exc.InvalidRequestError, err:
assert str(err) == "Scalar select can only be created from a Select object that has exactly one column expression.", str(err)
try:
# generic function which will look at the type of expression
func.coalesce(select([table1.c.myid]))
assert False
- except exceptions.InvalidRequestError, err:
+ except exc.InvalidRequestError, err:
assert str(err) == "Select objects don't have a type. Call as_scalar() on this Select object to return a 'scalar' version of this Select.", str(err)
s = select([table1.c.myid], scalar=True, correlate=False)
@@ -278,12 +278,12 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
s = select([table1.c.myid]).as_scalar()
try:
s.c.foo
- except exceptions.InvalidRequestError, err:
+ except exc.InvalidRequestError, err:
assert str(err) == 'Scalar Select expression has no columns; use this object directly within a column-level expression.'
try:
s.columns.foo
- except exceptions.InvalidRequestError, err:
+ except exc.InvalidRequestError, err:
assert str(err) == 'Scalar Select expression has no columns; use this object directly within a column-level expression.'
zips = table('zips',
@@ -807,8 +807,8 @@ mytable.description FROM myothertable JOIN mytable ON mytable.myid = myothertabl
self.assert_compile(
select(
- [join(join(table1, table2, table1.c.myid == table2.c.otherid), table3, table1.c.myid == table3.c.userid)
- ]),
+ [join(join(table1, table2, table1.c.myid == table2.c.otherid), table3, table1.c.myid == table3.c.userid)]
+ ),
"SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid JOIN thirdtable ON mytable.myid = thirdtable.userid"
)
@@ -854,7 +854,7 @@ EXISTS (select yay from foo where boo = lar)",
def test_compound_selects(self):
try:
union(table3.select(), table1.select())
- except exceptions.ArgumentError, err:
+ except exc.ArgumentError, err:
assert str(err) == "All selectables passed to CompoundSelect must have identical numbers of columns; select #1 has 2 columns, select #2 has 3"
x = union(
@@ -1048,10 +1048,10 @@ UNION SELECT mytable.myid FROM mytable"
# check that conflicts with "unique" params are caught
s = select([table1], or_(table1.c.myid==7, table1.c.myid==bindparam('myid_1')))
- self.assertRaisesMessage(exceptions.CompileError, "conflicts with unique bind parameter of the same name", str, s)
+ self.assertRaisesMessage(exc.CompileError, "conflicts with unique bind parameter of the same name", str, s)
s = select([table1], or_(table1.c.myid==7, table1.c.myid==8, table1.c.myid==bindparam('myid_1')))
- self.assertRaisesMessage(exceptions.CompileError, "conflicts with unique bind parameter of the same name", str, s)
+ self.assertRaisesMessage(exc.CompileError, "conflicts with unique bind parameter of the same name", str, s)
@@ -1153,20 +1153,6 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE
self.assert_compile(select([table1], table1.c.myid.in_([])),
"SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (CASE WHEN (mytable.myid IS NULL) THEN NULL ELSE 0 END = 1)")
- @testing.uses_deprecated('passing in_')
- def test_in_deprecated_api(self):
- self.assert_compile(select([table1], table1.c.myid.in_('abc')),
- "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1)")
-
- self.assert_compile(select([table1], table1.c.myid.in_(1)),
- "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1)")
-
- self.assert_compile(select([table1], table1.c.myid.in_(1,2)),
- "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1, :myid_2)")
-
- self.assert_compile(select([table1], table1.c.myid.in_()),
- "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (CASE WHEN (mytable.myid IS NULL) THEN NULL ELSE 0 END = 1)")
-
def test_cast(self):
tbl = table('casttest',
column('id', Integer),
diff --git a/test/sql/selectable.py b/test/sql/selectable.py
index b29ba8d5c..66793a25b 100755
--- a/test/sql/selectable.py
+++ b/test/sql/selectable.py
@@ -6,7 +6,7 @@ import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from testlib import *
from sqlalchemy.sql import util as sql_util
-from sqlalchemy import exceptions
+from sqlalchemy import exc
metadata = MetaData()
table = Table('table1', metadata,
@@ -164,7 +164,7 @@ class SelectableTest(TestBase, AssertsExecutionResults):
print str(j)
self.assert_(criterion.compare(j.onclause))
- def testcolumnlabels(self):
+ def test_column_labels(self):
a = select([table.c.col1.label('acol1'), table.c.col2.label('acol2'), table.c.col3.label('acol3')])
print str(a)
print [c for c in a.columns]
@@ -173,13 +173,13 @@ class SelectableTest(TestBase, AssertsExecutionResults):
criterion = a.c.acol1 == table2.c.col2
print str(j)
self.assert_(criterion.compare(j.onclause))
-
+
def test_labeled_select_correspoinding(self):
l1 = select([func.max(table.c.col1)]).label('foo')
s = select([l1])
assert s.corresponding_column(l1).name == s.c.foo
-
+
s = select([table.c.col1, l1])
assert s.corresponding_column(l1).name == s.c.foo
@@ -193,7 +193,7 @@ class SelectableTest(TestBase, AssertsExecutionResults):
print str(j.onclause)
self.assert_(criterion.compare(j.onclause))
- def testtablejoinedtoselectoftable(self):
+ def test_table_joined_to_select_of_table(self):
metadata = MetaData()
a = Table('a', metadata,
Column('id', Integer, primary_key=True))
@@ -242,7 +242,7 @@ class SelectableTest(TestBase, AssertsExecutionResults):
s = select([t2, t3], use_labels=True)
- self.assertRaises(exceptions.NoReferencedTableError, s.join, t1)
+ self.assertRaises(exc.NoReferencedTableError, s.join, t1)
class PrimaryKeyTest(TestBase, AssertsExecutionResults):
def test_join_pk_collapse_implicit(self):
diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py
index 09a3702ee..9cd6f9bdb 100644
--- a/test/sql/testtypes.py
+++ b/test/sql/testtypes.py
@@ -1,7 +1,7 @@
import testenv; testenv.configure_for_tests()
import datetime, os, pickleable, re
from sqlalchemy import *
-from sqlalchemy import exceptions, types, util
+from sqlalchemy import exc, types, util
from sqlalchemy.sql import operators
import sqlalchemy.engine.url as url
from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird
@@ -40,17 +40,6 @@ class AdaptTest(TestBase):
assert isinstance(dialect_type, mssql.MSNVarchar)
assert dialect_type.get_col_spec() == 'NVARCHAR(10)'
- def testoracletext(self):
- dialect = oracle.OracleDialect()
- class MyDecoratedType(types.TypeDecorator):
- impl = String
- def copy(self):
- return MyDecoratedType()
-
- col = Column('', MyDecoratedType)
- dialect_type = col.type.dialect_impl(dialect)
- assert isinstance(dialect_type.impl, oracle.OracleText), repr(dialect_type.impl)
-
def testoracletimestamp(self):
dialect = oracle.OracleDialect()
@@ -77,29 +66,29 @@ class AdaptTest(TestBase):
firebird_dialect = firebird.FBDialect()
for dialect, start, test in [
- (oracle_dialect, String(), oracle.OracleText),
+ (oracle_dialect, String(), oracle.OracleString),
(oracle_dialect, VARCHAR(), oracle.OracleString),
(oracle_dialect, String(50), oracle.OracleString),
- (oracle_dialect, Unicode(), oracle.OracleText),
+ (oracle_dialect, Unicode(), oracle.OracleString),
(oracle_dialect, UnicodeText(), oracle.OracleText),
(oracle_dialect, NCHAR(), oracle.OracleString),
(oracle_dialect, oracle.OracleRaw(50), oracle.OracleRaw),
- (mysql_dialect, String(), mysql.MSText),
+ (mysql_dialect, String(), mysql.MSString),
(mysql_dialect, VARCHAR(), mysql.MSString),
(mysql_dialect, String(50), mysql.MSString),
- (mysql_dialect, Unicode(), mysql.MSText),
+ (mysql_dialect, Unicode(), mysql.MSString),
(mysql_dialect, UnicodeText(), mysql.MSText),
(mysql_dialect, NCHAR(), mysql.MSNChar),
- (postgres_dialect, String(), postgres.PGText),
+ (postgres_dialect, String(), postgres.PGString),
(postgres_dialect, VARCHAR(), postgres.PGString),
(postgres_dialect, String(50), postgres.PGString),
- (postgres_dialect, Unicode(), postgres.PGText),
+ (postgres_dialect, Unicode(), postgres.PGString),
(postgres_dialect, UnicodeText(), postgres.PGText),
(postgres_dialect, NCHAR(), postgres.PGString),
- (firebird_dialect, String(), firebird.FBText),
+ (firebird_dialect, String(), firebird.FBString),
(firebird_dialect, VARCHAR(), firebird.FBString),
(firebird_dialect, String(50), firebird.FBString),
- (firebird_dialect, Unicode(), firebird.FBText),
+ (firebird_dialect, Unicode(), firebird.FBString),
(firebird_dialect, UnicodeText(), firebird.FBText),
(firebird_dialect, NCHAR(), firebird.FBString),
]:
@@ -118,9 +107,9 @@ class UserDefinedTest(TestBase):
def testprocessing(self):
global users
- users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy3='jack', goofy4=u'jack', goofy5=u'jack', goofy6='jack', goofy7=u'jack', goofy8=12, goofy9=12)
- users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy3='lala', goofy4=u'lala', goofy5=u'lala', goofy6='lala', goofy7=u'lala', goofy8=15, goofy9=15)
- users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy3='fred', goofy4=u'fred', goofy5=u'fred', goofy6='fred', goofy7=u'fred', goofy8=9, goofy9=9)
+ users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy4=u'jack', goofy5=u'jack', goofy6='jack', goofy7=u'jack', goofy8=12, goofy9=12)
+ users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy4=u'lala', goofy5=u'lala', goofy6='lala', goofy7=u'lala', goofy8=15, goofy9=15)
+ users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy4=u'fred', goofy5=u'fred', goofy6='fred', goofy7=u'fred', goofy8=9, goofy9=9)
l = users.select().execute().fetchall()
for assertstr, assertint, assertint2, row in zip(
@@ -130,11 +119,11 @@ class UserDefinedTest(TestBase):
l
):
- for col in row[1:8]:
+ for col in row[1:7]:
self.assertEquals(col, assertstr)
- self.assertEquals(row[8], assertint)
- self.assertEquals(row[9], assertint2)
- for col in (row[4], row[5], row[7]):
+ self.assertEquals(row[7], assertint)
+ self.assertEquals(row[8], assertint2)
+ for col in (row[3], row[4], row[6]):
assert isinstance(col, unicode)
def setUpAll(self):
@@ -250,13 +239,10 @@ class UserDefinedTest(TestBase):
# decorated type with an argument, so its a String
Column('goofy2', MyDecoratedType(50), nullable = False),
- # decorated type without an argument, it will adapt_args to TEXT
- Column('goofy3', MyDecoratedType, nullable = False),
-
- Column('goofy4', MyUnicodeType, nullable = False),
- Column('goofy5', LegacyUnicodeType, nullable = False),
+ Column('goofy4', MyUnicodeType(50), nullable = False),
+ Column('goofy5', LegacyUnicodeType(50), nullable = False),
Column('goofy6', LegacyType, nullable = False),
- Column('goofy7', MyNewUnicodeType, nullable = False),
+ Column('goofy7', MyNewUnicodeType(50), nullable = False),
Column('goofy8', MyNewIntType, nullable = False),
Column('goofy9', MyNewIntSubClass, nullable = False),
@@ -344,7 +330,7 @@ class UnicodeTest(TestBase, AssertsExecutionResults):
try:
unicode_table.insert().execute(unicode_varchar='not unicode')
assert False
- except exceptions.SAWarning, e:
+ except exc.SAWarning, e:
assert str(e) == "Unicode type received non-unicode bind param value 'not unicode'", str(e)
unicode_engine = engines.utf8_engine(options={'convert_unicode':True,
@@ -353,7 +339,7 @@ class UnicodeTest(TestBase, AssertsExecutionResults):
try:
unicode_engine.execute(unicode_table.insert(), plain_varchar='im not unicode')
assert False
- except exceptions.InvalidRequestError, e:
+ except exc.InvalidRequestError, e:
assert str(e) == "Unicode type received non-unicode bind param value 'im not unicode'"
@testing.emits_warning('.*non-unicode bind')
@@ -664,33 +650,20 @@ class DateTest(TestBase, AssertsExecutionResults):
t.drop(checkfirst=True)
class StringTest(TestBase, AssertsExecutionResults):
- def test_nolen_string_deprecated(self):
+
+
+ def test_nolength_string(self):
+ # this tests what happens with String DDL with no length.
+ # seems like we need to decide amongst "VARCHAR" (sqlite, postgres), "TEXT" (mysql)
+ # i.e. theres some inconsisency here.
+
metadata = MetaData(testing.db)
foo =Table('foo', metadata,
Column('one', String))
-
- # no warning
- select([func.count("*")], bind=testing.db).execute()
-
- try:
- # warning during CREATE
- foo.create()
- assert False
- except exceptions.SADeprecationWarning, e:
- assert "Using String type with no length" in str(e)
- assert re.search(r'\bone\b', str(e))
-
- bar = Table('bar', metadata, Column('one', String(40)))
-
- try:
- # no warning
- bar.create()
-
- # no warning for non-lengthed string
- select([func.count("*")], from_obj=bar).execute()
- finally:
- bar.drop()
-
+
+ foo.create()
+ foo.drop()
+
def _missing_decimal():
"""Python implementation supports decimals"""
try:
diff --git a/test/testlib/__init__.py b/test/testlib/__init__.py
index 98552b0f3..67e56e3d8 100644
--- a/test/testlib/__init__.py
+++ b/test/testlib/__init__.py
@@ -3,14 +3,21 @@
Load after sqlalchemy imports to use instrumented stand-ins like Table.
"""
+import sys
import testlib.config
from testlib.schema import Table, Column
from testlib.orm import mapper
import testlib.testing as testing
-from testlib.testing import rowset
-from testlib.testing import TestBase, AssertsExecutionResults, ORMTest, AssertsCompiledSQL, ComparesTables
+from testlib.testing import \
+ AssertsCompiledSQL, \
+ AssertsExecutionResults, \
+ ComparesTables, \
+ ORMTest, \
+ TestBase, \
+ rowset
import testlib.profiling as profiling
import testlib.engines as engines
+import testlib.requires as requires
from testlib.compat import set, frozenset, sorted, _function_named
@@ -18,6 +25,15 @@ __all__ = ('testing',
'mapper',
'Table', 'Column',
'rowset',
- 'TestBase', 'AssertsExecutionResults', 'ORMTest', 'AssertsCompiledSQL', 'ComparesTables',
+ 'TestBase', 'AssertsExecutionResults', 'ORMTest',
+ 'AssertsCompiledSQL', 'ComparesTables',
'profiling', 'engines',
'set', 'frozenset', 'sorted', '_function_named')
+
+
+testing.requires = requires
+
+sys.modules['testlib.sa'] = sa = testing.CompositeModule(
+ 'testlib.sa', 'sqlalchemy', 'testlib.schema', orm=testing.CompositeModule(
+ 'testlib.sa.orm', 'sqlalchemy.orm', 'testlib.orm'))
+sys.modules['testlib.sa.orm'] = sa.orm
diff --git a/test/testlib/compat.py b/test/testlib/compat.py
index ba12b78ac..fcb7fa1e9 100644
--- a/test/testlib/compat.py
+++ b/test/testlib/compat.py
@@ -1,6 +1,6 @@
-import itertools, new, sys, warnings
+import new
-__all__ = 'set', 'frozenset', 'sorted', '_function_named', 'deque'
+__all__ = 'set', 'frozenset', 'sorted', '_function_named', 'deque', 'reversed'
try:
set = set
@@ -69,6 +69,16 @@ except NameError:
return l
try:
+ reversed = reversed
+except NameError:
+ def reversed(seq):
+ i = len(seq) - 1
+ while i >= 0:
+ yield seq[i]
+ i -= 1
+ raise StopIteration()
+
+try:
from collections import deque
except ImportError:
class deque(list):
@@ -77,9 +87,7 @@ except ImportError:
def popleft(self):
return self.pop(0)
def extendleft(self, iterable):
- items = list(iterable)
- items.reverse()
- for x in items:
+ for x in reversed(list(iterable)):
self.insert(0, x)
def _function_named(fn, newname):
diff --git a/test/testlib/engines.py b/test/testlib/engines.py
index f5694df57..5ad35a066 100644
--- a/test/testlib/engines.py
+++ b/test/testlib/engines.py
@@ -1,6 +1,6 @@
import sys, types, weakref
from testlib import config
-from testlib.compat import *
+from testlib.compat import set, _function_named, deque
class ConnectionKiller(object):
diff --git a/test/testlib/filters.py b/test/testlib/filters.py
index eb7eff279..2d559f53b 100644
--- a/test/testlib/filters.py
+++ b/test/testlib/filters.py
@@ -14,8 +14,8 @@ Includes::
"""
import sys
-from StringIO import StringIO
-from tokenize import *
+from tokenize import generate_tokens, INDENT, DEDENT, NAME, OP, NL, NEWLINE, \
+ NUMBER, STRING, COMMENT
__all__ = ['py23_decorators', 'py23']
diff --git a/test/testlib/fixtures.py b/test/testlib/fixtures.py
index e8d71179a..f56b865c6 100644
--- a/test/testlib/fixtures.py
+++ b/test/testlib/fixtures.py
@@ -1,14 +1,16 @@
-# can't be imported until the path is setup; be sure to configure
-# first if covering.
-from sqlalchemy import *
-from sqlalchemy import util
-from testlib import *
-
-__all__ = ['keywords', 'addresses', 'Base', 'Keyword', 'FixtureTest', 'Dingaling', 'item_keywords',
- 'dingalings', 'User', 'items', 'Fixtures', 'orders', 'install_fixture_data', 'Address', 'users',
+from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey
+from testlib.sa.orm import attributes
+from testlib import ORMTest
+from testlib.compat import set
+
+
+__all__ = ['keywords', 'addresses', 'Base', 'Keyword', 'FixtureTest',
+ 'Dingaling', 'item_keywords', 'dingalings', 'User', 'items',
+ 'Fixtures', 'orders', 'install_fixture_data', 'Address', 'users',
'order_items', 'Item', 'Order', 'fixtures']
-
-_recursion_stack = util.Set()
+
+
+_recursion_stack = set()
class Base(object):
def __init__(self, **kwargs):
for k in kwargs:
@@ -36,10 +38,15 @@ class Base(object):
_recursion_stack.add(self)
try:
# pick the entity thats not SA persisted as the source
+ try:
+ state = attributes.instance_state(self)
+ key = state.key
+ except (KeyError, AttributeError):
+ key = None
if other is None:
a = self
b = other
- elif hasattr(self, '_instance_key'):
+ elif key is not None:
a = other
b = self
else:
@@ -57,8 +64,9 @@ class Base(object):
battr = getattr(b, attr)
except AttributeError:
#print "b class does not have attribute named '%s'" % attr
+ #raise
return False
-
+
if list(value) == list(battr):
continue
else:
@@ -84,43 +92,60 @@ metadata = MetaData()
users = Table('users', metadata,
Column('id', Integer, primary_key=True),
- Column('name', String(30), nullable=False))
+ Column('name', String(30), nullable=False),
+ test_needs_acid=True,
+ test_needs_fk=True
+ )
orders = Table('orders', metadata,
Column('id', Integer, primary_key=True),
Column('user_id', None, ForeignKey('users.id')),
Column('address_id', None, ForeignKey('addresses.id')),
Column('description', String(30)),
- Column('isopen', Integer)
+ Column('isopen', Integer),
+ test_needs_acid=True,
+ test_needs_fk=True
)
addresses = Table('addresses', metadata,
Column('id', Integer, primary_key=True),
Column('user_id', None, ForeignKey('users.id')),
- Column('email_address', String(50), nullable=False))
+ Column('email_address', String(50), nullable=False),
+ test_needs_acid=True,
+ test_needs_fk=True)
dingalings = Table("dingalings", metadata,
Column('id', Integer, primary_key=True),
Column('address_id', None, ForeignKey('addresses.id')),
- Column('data', String(30))
+ Column('data', String(30)),
+ test_needs_acid=True,
+ test_needs_fk=True
)
items = Table('items', metadata,
Column('id', Integer, primary_key=True),
- Column('description', String(30), nullable=False)
+ Column('description', String(30), nullable=False),
+ test_needs_acid=True,
+ test_needs_fk=True
)
order_items = Table('order_items', metadata,
Column('item_id', None, ForeignKey('items.id')),
- Column('order_id', None, ForeignKey('orders.id')))
+ Column('order_id', None, ForeignKey('orders.id')),
+ test_needs_acid=True,
+ test_needs_fk=True)
item_keywords = Table('item_keywords', metadata,
Column('item_id', None, ForeignKey('items.id')),
- Column('keyword_id', None, ForeignKey('keywords.id')))
+ Column('keyword_id', None, ForeignKey('keywords.id')),
+ test_needs_acid=True,
+ test_needs_fk=True)
keywords = Table('keywords', metadata,
Column('id', Integer, primary_key=True),
- Column('name', String(30), nullable=False)
+ Column('name', String(30), nullable=False),
+ test_needs_acid=True,
+ test_needs_fk=True
)
def install_fixture_data():
@@ -203,14 +228,15 @@ def install_fixture_data():
class FixtureTest(ORMTest):
refresh_data = False
-
+ only_tables = False
+
def setUpAll(self):
super(FixtureTest, self).setUpAll()
- if self.keep_data:
+ if not self.only_tables and self.keep_data:
install_fixture_data()
def setUp(self):
- if self.refresh_data:
+ if not self.only_tables and self.refresh_data:
install_fixture_data()
def define_tables(self, meta):
diff --git a/test/testlib/profiling.py b/test/testlib/profiling.py
index b452d1fb8..e423b9904 100644
--- a/test/testlib/profiling.py
+++ b/test/testlib/profiling.py
@@ -1,8 +1,7 @@
"""Profiling support for unit and performance tests."""
import os, sys
-from testlib.config import parser, post_configure
-from testlib.compat import *
+from testlib.compat import set, _function_named
import testlib.config
__all__ = 'profiled', 'function_call_count', 'conditional_call_count'
@@ -26,8 +25,6 @@ def profiled(target=None, **target_opts):
configuration and command-line options.
"""
- import time, hotshot, hotshot.stats
-
# manual or automatic namespacing by module would remove conflict issues
if target is None:
target = 'anonymous_target'
diff --git a/test/testlib/requires.py b/test/testlib/requires.py
new file mode 100644
index 000000000..a4604ff7f
--- /dev/null
+++ b/test/testlib/requires.py
@@ -0,0 +1,32 @@
+"""Global database feature support policy.
+
+Provides decorators to mark tests requiring specific feature support from the
+target database.
+
+"""
+from testlib import testing
+
+def savepoints(fn):
+ """Target database must support savepoints."""
+ return (testing.unsupported(
+ 'access',
+ 'mssql',
+ 'sqlite',
+ 'sybase',
+ )
+ (testing.exclude('mysql', '<', (5, 0, 3))
+ (fn)))
+
+def two_phase_transactions(fn):
+ """Target database must support two-phase transactions."""
+ return (testing.unsupported(
+ 'access',
+ 'firebird',
+ 'maxdb',
+ 'mssql',
+ 'oracle',
+ 'sqlite',
+ 'sybase',
+ )
+ (testing.exclude('mysql', '<', (5, 0, 3))
+ (fn)))
diff --git a/test/testlib/schema.py b/test/testlib/schema.py
index 37f3591ad..9cedc02f0 100644
--- a/test/testlib/schema.py
+++ b/test/testlib/schema.py
@@ -1,5 +1,5 @@
from testlib import testing
-import itertools
+
schema = None
__all__ = 'Table', 'Column',
diff --git a/test/testlib/tables.py b/test/testlib/tables.py
index 33b1b20db..3399acaae 100644
--- a/test/testlib/tables.py
+++ b/test/testlib/tables.py
@@ -1,8 +1,9 @@
# can't be imported until the path is setup; be sure to configure
# first if covering.
-from sqlalchemy import *
+
from testlib import testing
-from testlib.schema import Table, Column
+from testlib.sa import MetaData, Table, Column, Integer, String, Sequence, \
+ ForeignKey, VARCHAR, INT
# these are older test fixtures, used primarily by test/orm/mapper.py and
diff --git a/test/testlib/testing.py b/test/testlib/testing.py
index cf0936e92..1e2ca62e9 100644
--- a/test/testlib/testing.py
+++ b/test/testlib/testing.py
@@ -2,15 +2,27 @@
# monkeypatches unittest.TestLoader.suiteClass at import time
-import itertools, os, operator, re, sys, unittest, warnings
+import itertools
+import operator
+import re
+import sys
+import types
+import unittest
+import warnings
from cStringIO import StringIO
+
import testlib.config as config
-from testlib.compat import *
+from testlib.compat import set, _function_named, reversed
-sql, sqltypes, schema, MetaData, clear_mappers, Session, util = None, None, None, None, None, None, None
-sa_exceptions = None
+# Delayed imports
+MetaData = None
+Session = None
+clear_mappers = None
+sa_exc = None
+schema = None
+sqltypes = None
+util = None
-__all__ = ('TestBase', 'AssertsExecutionResults', 'ComparesTables', 'ORMTest', 'AssertsCompiledSQL')
_ops = { '<': operator.lt,
'>': operator.gt,
@@ -25,6 +37,9 @@ _ops = { '<': operator.lt,
# sugar ('testing.db'); set here by config() at runtime
db = None
+# more sugar, installed by __init__
+requires = None
+
def fails_if(callable_):
"""Mark a test as expected to fail if callable_ returns True.
@@ -224,17 +239,17 @@ def emits_warning(*messages):
# - update: jython looks ok, it uses cpython's module
def decorate(fn):
def safe(*args, **kw):
- global sa_exceptions
- if sa_exceptions is None:
- import sqlalchemy.exceptions as sa_exceptions
+ global sa_exc
+ if sa_exc is None:
+ import sqlalchemy.exc as sa_exc
if not messages:
filters = [dict(action='ignore',
- category=sa_exceptions.SAWarning)]
+ category=sa_exc.SAWarning)]
else:
filters = [dict(action='ignore',
message=message,
- category=sa_exceptions.SAWarning)
+ category=sa_exc.SAWarning)
for message in messages ]
for f in filters:
warnings.filterwarnings(**f)
@@ -259,17 +274,17 @@ def uses_deprecated(*messages):
def decorate(fn):
def safe(*args, **kw):
- global sa_exceptions
- if sa_exceptions is None:
- import sqlalchemy.exceptions as sa_exceptions
+ global sa_exc
+ if sa_exc is None:
+ import sqlalchemy.exc as sa_exc
if not messages:
filters = [dict(action='ignore',
- category=sa_exceptions.SADeprecationWarning)]
+ category=sa_exc.SADeprecationWarning)]
else:
filters = [dict(action='ignore',
message=message,
- category=sa_exceptions.SADeprecationWarning)
+ category=sa_exc.SADeprecationWarning)
for message in
[ (m.startswith('//') and
('Call to deprecated function ' + m[2:]) or m)
@@ -287,13 +302,13 @@ def uses_deprecated(*messages):
def resetwarnings():
"""Reset warning behavior to testing defaults."""
- global sa_exceptions
- if sa_exceptions is None:
- import sqlalchemy.exceptions as sa_exceptions
+ global sa_exc
+ if sa_exc is None:
+ import sqlalchemy.exc as sa_exc
warnings.resetwarnings()
- warnings.filterwarnings('error', category=sa_exceptions.SADeprecationWarning)
- warnings.filterwarnings('error', category=sa_exceptions.SAWarning)
+ warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
+ warnings.filterwarnings('error', category=sa_exc.SAWarning)
if sys.version_info < (2, 4):
warnings.filterwarnings('ignore', category=FutureWarning)
@@ -338,6 +353,23 @@ def rowset(results):
return set([tuple(row) for row in results])
+def eq_(a, b, msg=None):
+ """Assert a == b, with repr messaging on failure."""
+ assert a == b, msg or "%r != %r" % (a, b)
+
+def ne_(a, b, msg=None):
+ """Assert a != b, with repr messaging on failure."""
+ assert a != b, msg or "%r == %r" % (a, b)
+
+def is_(a, b, msg=None):
+ """Assert a is b, with repr messaging on failure."""
+ assert a is b, msg or "%r is not %r" % (a, b)
+
+def is_not_(a, b, msg=None):
+ """Assert a is not b, with repr messaging on failure."""
+ assert a is not b, msg or "%r is %r" % (a, b)
+
+
class TestData(object):
"""Tracks SQL expressions as they are executed via an instrumented ExecutionContext."""
@@ -360,10 +392,6 @@ class ExecutionContextWrapper(object):
can be tracked."""
def __init__(self, ctx):
- global sql
- if sql is None:
- from sqlalchemy import sql
-
self.__dict__['ctx'] = ctx
def __getattr__(self, key):
return getattr(self.ctx, key)
@@ -414,7 +442,7 @@ class ExecutionContextWrapper(object):
query = self.convert_statement(query)
equivalent = ( (statement == query)
- or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) )
+ or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) )
) \
and \
( (params is None) or (params == parameters)
@@ -422,7 +450,7 @@ class ExecutionContextWrapper(object):
for (k, v) in p.items()])
for p in parameters]
)
- testdata.unittest.assert_(equivalent,
+ testdata.unittest.assert_(equivalent,
"Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
testdata.sql_count += 1
self.ctx.post_execution()
@@ -445,6 +473,44 @@ class ExecutionContextWrapper(object):
query = re.sub(r':([\w_]+)', repl, query)
return query
+
+def _import_by_name(name):
+ submodule = name.split('.')[-1]
+ return __import__(name, globals(), locals(), [submodule])
+
+class CompositeModule(types.ModuleType):
+ """Merged attribute access for multiple modules."""
+
+ # break the habit
+ __all__ = ()
+
+ def __init__(self, name, *modules, **overrides):
+ """Construct a new lazy composite of modules.
+
+ Modules may be string names or module-like instances. Individual
+ attribute overrides may be specified as keyword arguments for
+ convenience.
+
+ The constructed module will resolve attribute access in reverse order:
+ overrides, then each member of reversed(modules). Modules specified
+ by name will be loaded lazily when encountered in attribute
+ resolution.
+
+ """
+ types.ModuleType.__init__(self, name)
+ self.__modules = list(reversed(modules))
+ for key, value in overrides.iteritems():
+ setattr(self, key, value)
+
+ def __getattr__(self, key):
+ for idx, mod in enumerate(self.__modules):
+ if isinstance(mod, basestring):
+ self.__modules[idx] = mod = _import_by_name(mod)
+ if hasattr(mod, key):
+ return getattr(mod, key)
+ raise AttributeError(key)
+
+
class TestBase(unittest.TestCase):
# A sequence of dialect names to exclude from the test class.
__unsupported_on__ = ()
@@ -469,14 +535,14 @@ class TestBase(unittest.TestCase):
def shortDescription(self):
"""overridden to not return docstrings"""
return None
-
+
def assertRaisesMessage(self, except_cls, msg, callable_, *args, **kwargs):
try:
callable_(*args, **kwargs)
assert False, "Callable did not raise expected exception"
except except_cls, e:
assert re.search(msg, str(e)), "Exception message did not match: '%s'" % str(e)
-
+
if not hasattr(unittest.TestCase, 'assertTrue'):
assertTrue = unittest.TestCase.failUnless
if not hasattr(unittest.TestCase, 'assertFalse'):
@@ -522,7 +588,7 @@ class ComparesTables(object):
set(type(c.type).__mro__).difference(base_mro)
)
) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
-
+
if isinstance(c.type, sqltypes.String):
self.assertEquals(c.type.length, reflected_c.type.length)
@@ -535,18 +601,18 @@ class ComparesTables(object):
elif not c.primary_key or not against('postgres'):
print repr(c)
assert reflected_c.default is None, reflected_c.default
-
+
assert len(table.primary_key) == len(reflected_table.primary_key)
for c in table.primary_key:
assert reflected_table.primary_key.columns[c.name]
-
+
class AssertsExecutionResults(object):
def assert_result(self, result, class_, *objects):
result = list(result)
print repr(result)
self.assert_list(result, class_, objects)
-
+
def assert_list(self, result, class_, list):
self.assert_(len(result) == len(list),
"result list is not the same size as test list, " +
@@ -675,10 +741,10 @@ class ORMTest(TestBase, AssertsExecutionResults):
def define_tables(self, _otest_metadata):
raise NotImplementedError()
-
+
def setup_mappers(self):
pass
-
+
def insert_data(self):
pass