summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-12-29 19:20:38 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-12-29 19:20:38 +0000
commita582fe3b2645f4c12221b0dc8940cefffe674a93 (patch)
tree018ccd85ce1675c3bbefd83d7af3955c1d93a7b6
parent90d38c7407e532462059d2e98cb8d3bab31f7a36 (diff)
downloadsqlalchemy-a582fe3b2645f4c12221b0dc8940cefffe674a93.tar.gz
- mapped classes which extend "object" and do not provide an
__init__() method will now raise TypeError if non-empty *args or **kwargs are present at instance construction time (and are not consumed by any extensions such as the scoped_session mapper), consistent with the behavior of normal Python classes [ticket:908]
-rw-r--r--CHANGES8
-rw-r--r--lib/sqlalchemy/orm/attributes.py19
-rw-r--r--lib/sqlalchemy/orm/scoping.py8
-rw-r--r--test/orm/inheritance/basic.py2
-rw-r--r--test/orm/inheritance/manytomany.py14
-rw-r--r--test/orm/mapper.py26
-rw-r--r--test/orm/session.py2
7 files changed, 59 insertions, 20 deletions
diff --git a/CHANGES b/CHANGES
index e8d066fb2..4f9cf26f9 100644
--- a/CHANGES
+++ b/CHANGES
@@ -107,7 +107,13 @@ CHANGES
- columns which are missing from a Query's select statement
now get automatically deferred during load.
-
+
+ - mapped classes which extend "object" and do not provide an
+ __init__() method will now raise TypeError if non-empty *args
+ or **kwargs are present at instance construction time (and are
+ not consumed by any extensions such as the scoped_session mapper),
+ consistent with the behavior of normal Python classes [ticket:908]
+
- fixed Query bug when filter_by() compares a relation against None
[ticket:899]
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index f18e54521..09406652a 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -1121,14 +1121,19 @@ def register_class(class_, extra_init=None, on_exception=None, deferred_scalar_l
if extra_init:
extra_init(class_, oldinit, instance, args, kwargs)
- if doinit:
- try:
+ try:
+ if doinit:
oldinit(instance, *args, **kwargs)
- except:
- if on_exception:
- on_exception(class_, oldinit, instance, args, kwargs)
- raise
-
+ elif args or kwargs:
+ # simulate error message raised by object(), but don't copy
+ # the text verbatim
+ raise TypeError("default constructor for object() takes no parameters")
+ except:
+ if on_exception:
+ on_exception(class_, oldinit, instance, args, kwargs)
+ raise
+
+
# override oldinit
oldinit = class_.__init__
if oldinit is None or not hasattr(oldinit, '_oldinit'):
diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py
index 3f2f2f049..19cd44884 100644
--- a/lib/sqlalchemy/orm/scoping.py
+++ b/lib/sqlalchemy/orm/scoping.py
@@ -118,15 +118,19 @@ class _ScopedExt(MapperExtension):
class_.query = query()
def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
+ if self.save_on_init:
+ entity_name = kwargs.pop('_sa_entity_name', None)
+ session = kwargs.pop('_sa_session', None)
if not isinstance(oldinit, types.MethodType):
for key, value in kwargs.items():
if self.validate:
if not mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
setattr(instance, key, value)
+ kwargs.clear()
if self.save_on_init:
- session = kwargs.pop('_sa_session', self.context.registry())
- session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None))
+ session = session or self.context.registry()
+ session._save_impl(instance, entity_name=entity_name)
return EXT_CONTINUE
def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py
index 2ef76b6d8..f9c98ac1c 100644
--- a/test/orm/inheritance/basic.py
+++ b/test/orm/inheritance/basic.py
@@ -348,7 +348,7 @@ class FlushTest(ORMTest):
)
admin_mapper = mapper(Admin, admins, inherits=user_mapper)
sess = create_session()
- adminrole = Role('admin')
+ adminrole = Role()
sess.save(adminrole)
sess.flush()
diff --git a/test/orm/inheritance/manytomany.py b/test/orm/inheritance/manytomany.py
index 7886e90ad..d28ce8ada 100644
--- a/test/orm/inheritance/manytomany.py
+++ b/test/orm/inheritance/manytomany.py
@@ -83,9 +83,9 @@ class InheritTest2(ORMTest):
Column('bar_id', Integer, ForeignKey('bar.bid')))
def testget(self):
- class Foo(object):pass
- def __init__(self, data=None):
- self.data = data
+ class Foo(object):
+ def __init__(self, data=None):
+ self.data = data
class Bar(Foo):pass
mapper(Foo, foo)
@@ -128,7 +128,7 @@ class InheritTest2(ORMTest):
sess.flush()
sess.clear()
- l = sess.query(Bar).select()
+ l = sess.query(Bar).all()
print l[0]
print l[0].foos
self.assert_unordered_result(l, Bar,
@@ -191,7 +191,7 @@ class InheritTest3(ORMTest):
sess.flush()
compare = repr(b) + repr(sorted([repr(o) for o in b.foos]))
sess.clear()
- l = sess.query(Bar).select()
+ l = sess.query(Bar).all()
print repr(l[0]) + repr(l[0].foos)
found = repr(l[0]) + repr(sorted([repr(o) for o in l[0].foos]))
self.assertEqual(found, compare)
@@ -233,11 +233,11 @@ class InheritTest3(ORMTest):
blubid = bl1.id
sess.clear()
- l = sess.query(Blub).select()
+ l = sess.query(Blub).all()
print l
self.assert_(repr(l[0]) == compare)
sess.clear()
- x = sess.query(Blub).get_by(id=blubid)
+ x = sess.query(Blub).filter_by(id=blubid).one()
print x
self.assert_(repr(x) == compare)
diff --git a/test/orm/mapper.py b/test/orm/mapper.py
index e1aca1345..28885356c 100644
--- a/test/orm/mapper.py
+++ b/test/orm/mapper.py
@@ -140,6 +140,30 @@ class MapperTest(MapperSuperTest):
except Exception, e:
assert e is ex
+ clear_mappers()
+
+ # test that TypeError is raised for illegal constructor args,
+ # whether or not explicit __init__ is present [ticket:908]
+ class Foo(object):
+ def __init__(self):
+ pass
+ class Bar(object):
+ pass
+
+ mapper(Foo, users)
+ mapper(Bar, addresses)
+ try:
+ Foo(x=5)
+ assert False
+ except TypeError:
+ assert True
+
+ try:
+ Bar(x=5)
+ assert False
+ except TypeError:
+ assert True
+
def test_props(self):
m = mapper(User, users, properties = {
'addresses' : relation(mapper(Address, addresses))
@@ -1247,7 +1271,7 @@ class MapperExtensionTest(PersistTest):
sess = create_session()
i1 = Item()
- k1 = Keyword('blue')
+ k1 = Keyword()
sess.save(i1)
sess.save(k1)
sess.flush()
diff --git a/test/orm/session.py b/test/orm/session.py
index db9245c72..8baf75275 100644
--- a/test/orm/session.py
+++ b/test/orm/session.py
@@ -894,7 +894,7 @@ class ScopedMapperTest(PersistTest):
pass
Session.mapper(Baz, table2, extension=ext)
assert hasattr(Baz, 'query')
-
+
def test_validating_constructor(self):
s2 = SomeObject(someid=12)
s3 = SomeOtherObject(someid=123, bogus=345)