summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/ext/declarative.py152
-rw-r--r--lib/sqlalchemy/orm/query.py6
-rw-r--r--test/ext/declarative.py41
-rw-r--r--test/orm/query.py4
4 files changed, 129 insertions, 74 deletions
diff --git a/lib/sqlalchemy/ext/declarative.py b/lib/sqlalchemy/ext/declarative.py
index 4778b9eba..b29f051b1 100644
--- a/lib/sqlalchemy/ext/declarative.py
+++ b/lib/sqlalchemy/ext/declarative.py
@@ -188,79 +188,93 @@ from sqlalchemy import util, exceptions
from sqlalchemy.sql import util as sql_util
-__all__ = 'declarative_base', 'synonym_for', 'comparable_using'
+__all__ = 'declarative_base', 'synonym_for', 'comparable_using', 'instrument_declarative'
+def instrument_declarative(cls, registry, metadata):
+ """Given a class, configure the class declaratively,
+ using the given registry (any dictionary) and MetaData object.
+ This operation does not assume any kind of class hierarchy.
+
+ """
+ if '_decl_class_registry' in cls.__dict__:
+ raise exceptions.InvalidRequestError("Class %r already has been instrumented declaratively" % cls)
+ cls._decl_class_registry = registry
+ cls.metadata = metadata
+ _as_declarative(cls, cls.__name__, cls.__dict__)
+
+def _as_declarative(cls, classname, dict_):
+ cls._decl_class_registry[classname] = cls
+ our_stuff = util.OrderedDict()
+ for k in dict_:
+ value = dict_[k]
+ if (isinstance(value, tuple) and len(value) == 1 and
+ isinstance(value[0], (Column, MapperProperty))):
+ util.warn("Ignoring declarative-like tuple value of attribute "
+ "%s: possibly a copy-and-paste error with a comma "
+ "left at the end of the line?" % k)
+ continue
+ if not isinstance(value, (Column, MapperProperty)):
+ continue
+ prop = _deferred_relation(cls, value)
+ our_stuff[k] = prop
+
+ # set up attributes in the order they were created
+ our_stuff.sort(lambda x, y: cmp(our_stuff[x]._creation_order,
+ our_stuff[y]._creation_order))
+
+ table = None
+ if '__table__' not in cls.__dict__:
+ if '__tablename__' in cls.__dict__:
+ tablename = cls.__tablename__
+ autoload = cls.__dict__.get('__autoload__')
+ if autoload:
+ table_kw = {'autoload': True}
+ else:
+ table_kw = {}
+ cols = []
+ for key, c in our_stuff.iteritems():
+ if isinstance(c, ColumnProperty):
+ for col in c.columns:
+ if isinstance(col, Column) and col.table is None:
+ _undefer_column_name(key, col)
+ cols.append(col)
+ elif isinstance(c, Column):
+ _undefer_column_name(key, c)
+ cols.append(c)
+ cls.__table__ = table = Table(tablename, cls.metadata,
+ *cols, **table_kw)
+ else:
+ table = cls.__table__
+
+ mapper_args = getattr(cls, '__mapper_args__', {})
+ if 'inherits' not in mapper_args:
+ inherits = cls.__mro__[1]
+ inherits = cls._decl_class_registry.get(inherits.__name__, None)
+ if inherits:
+ mapper_args['inherits'] = inherits
+ if not mapper_args.get('concrete', False) and table:
+ # figure out the inherit condition with relaxed rules
+ # about nonexistent tables, to allow for ForeignKeys to
+ # not-yet-defined tables (since we know for sure that our
+ # parent table is defined within the same MetaData)
+ mapper_args['inherit_condition'] = sql_util.join_condition(
+ inherits.__table__, table,
+ ignore_nonexistent_tables=True)
+
+ if hasattr(cls, '__mapper_cls__'):
+ mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__)
+ else:
+ mapper_cls = mapper
+
+ cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff,
+ **mapper_args)
class DeclarativeMeta(type):
def __init__(cls, classname, bases, dict_):
if '_decl_class_registry' in cls.__dict__:
return type.__init__(cls, classname, bases, dict_)
-
- cls._decl_class_registry[classname] = cls
- our_stuff = util.OrderedDict()
- for k in dict_:
- value = dict_[k]
- if (isinstance(value, tuple) and len(value) == 1 and
- isinstance(value[0], (Column, MapperProperty))):
- util.warn("Ignoring declarative-like tuple value of attribute "
- "%s: possibly a copy-and-paste error with a comma "
- "left at the end of the line?" % k)
- continue
- if not isinstance(value, (Column, MapperProperty)):
- continue
- prop = _deferred_relation(cls, value)
- our_stuff[k] = prop
-
- # set up attributes in the order they were created
- our_stuff.sort(lambda x, y: cmp(our_stuff[x]._creation_order,
- our_stuff[y]._creation_order))
-
- table = None
- if '__table__' not in cls.__dict__:
- if '__tablename__' in cls.__dict__:
- tablename = cls.__tablename__
- autoload = cls.__dict__.get('__autoload__')
- if autoload:
- table_kw = {'autoload': True}
- else:
- table_kw = {}
- cols = []
- for key, c in our_stuff.iteritems():
- if isinstance(c, ColumnProperty):
- for col in c.columns:
- if isinstance(col, Column) and col.table is None:
- _undefer_column_name(key, col)
- cols.append(col)
- elif isinstance(c, Column):
- _undefer_column_name(key, c)
- cols.append(c)
- cls.__table__ = table = Table(tablename, cls.metadata,
- *cols, **table_kw)
- else:
- table = cls.__table__
-
- mapper_args = getattr(cls, '__mapper_args__', {})
- if 'inherits' not in mapper_args:
- inherits = cls.__mro__[1]
- inherits = cls._decl_class_registry.get(inherits.__name__, None)
- if inherits:
- mapper_args['inherits'] = inherits
- if not mapper_args.get('concrete', False) and table:
- # figure out the inherit condition with relaxed rules
- # about nonexistent tables, to allow for ForeignKeys to
- # not-yet-defined tables (since we know for sure that our
- # parent table is defined within the same MetaData)
- mapper_args['inherit_condition'] = sql_util.join_condition(
- inherits.__table__, table,
- ignore_nonexistent_tables=True)
-
- if hasattr(cls, '__mapper_cls__'):
- mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__)
- else:
- mapper_cls = mapper
-
- cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff,
- **mapper_args)
+
+ _as_declarative(cls, classname, dict_)
return type.__init__(cls, classname, bases, dict_)
def __setattr__(cls, key, value):
@@ -337,11 +351,11 @@ def comparable_using(comparator_factory):
return comparable_property(comparator_factory, fn)
return decorate
-def declarative_base(engine=None, metadata=None, mapper=None):
+def declarative_base(engine=None, metadata=None, mapper=None, cls=object):
lcl_metadata = metadata or MetaData()
if engine:
lcl_metadata.bind = engine
- class Base(object):
+ class Base(cls):
__metaclass__ = DeclarativeMeta
metadata = lcl_metadata
if mapper:
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index b7d6199b8..608f3f734 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -956,7 +956,7 @@ class Query(object):
if start < 0 or stop < 0:
return list(self)[item]
else:
- res = self.slice_(start, stop)
+ res = self.slice(start, stop)
if step is not None:
return list(res)[None:None:item.step]
else:
@@ -964,7 +964,7 @@ class Query(object):
else:
return list(self[item:item+1])[0]
- def slice_(self, start, stop):
+ def slice(self, start, stop):
"""apply LIMIT/OFFSET to the ``Query`` based on a range and return the newly resulting ``Query``."""
if start is not None and stop is not None:
@@ -974,7 +974,7 @@ class Query(object):
self._limit = stop
elif start is not None and stop is None:
self._offset = (self._offset or 0) + start
- slice_ = _generative(__no_statement_condition)(slice_)
+ slice = _generative(__no_statement_condition)(slice)
def limit(self, limit):
"""Apply a ``LIMIT`` to the query and return the newly resulting
diff --git a/test/ext/declarative.py b/test/ext/declarative.py
index ca91f98fc..d5ea6df47 100644
--- a/test/ext/declarative.py
+++ b/test/ext/declarative.py
@@ -95,6 +95,14 @@ class DeclarativeTest(testing.TestBase, testing.AssertsExecutionResults):
foo = sa.orm.column_property(User.id == 5)
self.assertRaises(sa.exc.InvalidRequestError, go)
+ def test_custom_base(self):
+ class MyBase(object):
+ def foobar(self):
+ return "foobar"
+ Base = decl.declarative_base(cls=MyBase)
+ assert hasattr(Base, 'metadata')
+ assert Base().foobar() == "foobar"
+
def test_add_prop(self):
class User(Base, ComparableEntity):
__tablename__ = 'users'
@@ -135,7 +143,40 @@ class DeclarativeTest(testing.TestBase, testing.AssertsExecutionResults):
eq_(a1, Address(email='two'))
eq_(a1.user, User(name='u1'))
+ def test_as_declarative(self):
+ class User(ComparableEntity):
+ __tablename__ = 'users'
+
+ id = Column('id', Integer, primary_key=True)
+ name = Column('name', String(50))
+ addresses = relation("Address", backref="user")
+
+ class Address(ComparableEntity):
+ __tablename__ = 'addresses'
+
+ id = Column('id', Integer, primary_key=True)
+ email = Column('email', String(50))
+ user_id = Column('user_id', Integer, ForeignKey('users.id'))
+
+ reg = {}
+ decl.instrument_declarative(User, reg, Base.metadata)
+ decl.instrument_declarative(Address, reg, Base.metadata)
+ Base.metadata.create_all()
+
+ u1 = User(name='u1', addresses=[
+ Address(email='one'),
+ Address(email='two'),
+ ])
+ sess = create_session()
+ sess.save(u1)
+ sess.flush()
+ sess.clear()
+ eq_(sess.query(User).all(), [User(name='u1', addresses=[
+ Address(email='one'),
+ Address(email='two'),
+ ])])
+
def test_custom_mapper(self):
class MyExt(sa.orm.MapperExtension):
def create_instance(self):
diff --git a/test/orm/query.py b/test/orm/query.py
index 7a51c3f7e..eb7a0f3d3 100644
--- a/test/orm/query.py
+++ b/test/orm/query.py
@@ -489,7 +489,7 @@ class FromSelfTest(QueryTest):
assert [User(id=8), User(id=9)] == create_session().query(User).filter(User.id.in_([8,9]))._from_self().all()
- assert [User(id=8), User(id=9)] == create_session().query(User).slice_(1,3)._from_self().all()
+ assert [User(id=8), User(id=9)] == create_session().query(User).slice(1,3)._from_self().all()
assert [User(id=8)] == list(create_session().query(User).filter(User.id.in_([8,9]))._from_self()[0:1])
def test_join(self):
@@ -1123,7 +1123,7 @@ class MixedEntitiesTest(QueryTest):
q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(User.id, Address.id).values(User.name, Address.email_address)
self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')])
- q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(desc(Address.email_address)).slice_(1, 3).values(User.name, Address.email_address)
+ q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(desc(Address.email_address)).slice(1, 3).values(User.name, Address.email_address)
self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@lala.com')])
adalias = aliased(Address)