diff options
| -rw-r--r-- | lib/sqlalchemy/ext/declarative.py | 152 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 6 | ||||
| -rw-r--r-- | test/ext/declarative.py | 41 | ||||
| -rw-r--r-- | test/orm/query.py | 4 |
4 files changed, 129 insertions, 74 deletions
diff --git a/lib/sqlalchemy/ext/declarative.py b/lib/sqlalchemy/ext/declarative.py index 4778b9eba..b29f051b1 100644 --- a/lib/sqlalchemy/ext/declarative.py +++ b/lib/sqlalchemy/ext/declarative.py @@ -188,79 +188,93 @@ from sqlalchemy import util, exceptions from sqlalchemy.sql import util as sql_util -__all__ = 'declarative_base', 'synonym_for', 'comparable_using' +__all__ = 'declarative_base', 'synonym_for', 'comparable_using', 'instrument_declarative' +def instrument_declarative(cls, registry, metadata): + """Given a class, configure the class declaratively, + using the given registry (any dictionary) and MetaData object. + This operation does not assume any kind of class hierarchy. + + """ + if '_decl_class_registry' in cls.__dict__: + raise exceptions.InvalidRequestError("Class %r already has been instrumented declaratively" % cls) + cls._decl_class_registry = registry + cls.metadata = metadata + _as_declarative(cls, cls.__name__, cls.__dict__) + +def _as_declarative(cls, classname, dict_): + cls._decl_class_registry[classname] = cls + our_stuff = util.OrderedDict() + for k in dict_: + value = dict_[k] + if (isinstance(value, tuple) and len(value) == 1 and + isinstance(value[0], (Column, MapperProperty))): + util.warn("Ignoring declarative-like tuple value of attribute " + "%s: possibly a copy-and-paste error with a comma " + "left at the end of the line?" % k) + continue + if not isinstance(value, (Column, MapperProperty)): + continue + prop = _deferred_relation(cls, value) + our_stuff[k] = prop + + # set up attributes in the order they were created + our_stuff.sort(lambda x, y: cmp(our_stuff[x]._creation_order, + our_stuff[y]._creation_order)) + + table = None + if '__table__' not in cls.__dict__: + if '__tablename__' in cls.__dict__: + tablename = cls.__tablename__ + autoload = cls.__dict__.get('__autoload__') + if autoload: + table_kw = {'autoload': True} + else: + table_kw = {} + cols = [] + for key, c in our_stuff.iteritems(): + if isinstance(c, ColumnProperty): + for col in c.columns: + if isinstance(col, Column) and col.table is None: + _undefer_column_name(key, col) + cols.append(col) + elif isinstance(c, Column): + _undefer_column_name(key, c) + cols.append(c) + cls.__table__ = table = Table(tablename, cls.metadata, + *cols, **table_kw) + else: + table = cls.__table__ + + mapper_args = getattr(cls, '__mapper_args__', {}) + if 'inherits' not in mapper_args: + inherits = cls.__mro__[1] + inherits = cls._decl_class_registry.get(inherits.__name__, None) + if inherits: + mapper_args['inherits'] = inherits + if not mapper_args.get('concrete', False) and table: + # figure out the inherit condition with relaxed rules + # about nonexistent tables, to allow for ForeignKeys to + # not-yet-defined tables (since we know for sure that our + # parent table is defined within the same MetaData) + mapper_args['inherit_condition'] = sql_util.join_condition( + inherits.__table__, table, + ignore_nonexistent_tables=True) + + if hasattr(cls, '__mapper_cls__'): + mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__) + else: + mapper_cls = mapper + + cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff, + **mapper_args) class DeclarativeMeta(type): def __init__(cls, classname, bases, dict_): if '_decl_class_registry' in cls.__dict__: return type.__init__(cls, classname, bases, dict_) - - cls._decl_class_registry[classname] = cls - our_stuff = util.OrderedDict() - for k in dict_: - value = dict_[k] - if (isinstance(value, tuple) and len(value) == 1 and - isinstance(value[0], (Column, MapperProperty))): - util.warn("Ignoring declarative-like tuple value of attribute " - "%s: possibly a copy-and-paste error with a comma " - "left at the end of the line?" % k) - continue - if not isinstance(value, (Column, MapperProperty)): - continue - prop = _deferred_relation(cls, value) - our_stuff[k] = prop - - # set up attributes in the order they were created - our_stuff.sort(lambda x, y: cmp(our_stuff[x]._creation_order, - our_stuff[y]._creation_order)) - - table = None - if '__table__' not in cls.__dict__: - if '__tablename__' in cls.__dict__: - tablename = cls.__tablename__ - autoload = cls.__dict__.get('__autoload__') - if autoload: - table_kw = {'autoload': True} - else: - table_kw = {} - cols = [] - for key, c in our_stuff.iteritems(): - if isinstance(c, ColumnProperty): - for col in c.columns: - if isinstance(col, Column) and col.table is None: - _undefer_column_name(key, col) - cols.append(col) - elif isinstance(c, Column): - _undefer_column_name(key, c) - cols.append(c) - cls.__table__ = table = Table(tablename, cls.metadata, - *cols, **table_kw) - else: - table = cls.__table__ - - mapper_args = getattr(cls, '__mapper_args__', {}) - if 'inherits' not in mapper_args: - inherits = cls.__mro__[1] - inherits = cls._decl_class_registry.get(inherits.__name__, None) - if inherits: - mapper_args['inherits'] = inherits - if not mapper_args.get('concrete', False) and table: - # figure out the inherit condition with relaxed rules - # about nonexistent tables, to allow for ForeignKeys to - # not-yet-defined tables (since we know for sure that our - # parent table is defined within the same MetaData) - mapper_args['inherit_condition'] = sql_util.join_condition( - inherits.__table__, table, - ignore_nonexistent_tables=True) - - if hasattr(cls, '__mapper_cls__'): - mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__) - else: - mapper_cls = mapper - - cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff, - **mapper_args) + + _as_declarative(cls, classname, dict_) return type.__init__(cls, classname, bases, dict_) def __setattr__(cls, key, value): @@ -337,11 +351,11 @@ def comparable_using(comparator_factory): return comparable_property(comparator_factory, fn) return decorate -def declarative_base(engine=None, metadata=None, mapper=None): +def declarative_base(engine=None, metadata=None, mapper=None, cls=object): lcl_metadata = metadata or MetaData() if engine: lcl_metadata.bind = engine - class Base(object): + class Base(cls): __metaclass__ = DeclarativeMeta metadata = lcl_metadata if mapper: diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index b7d6199b8..608f3f734 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -956,7 +956,7 @@ class Query(object): if start < 0 or stop < 0: return list(self)[item] else: - res = self.slice_(start, stop) + res = self.slice(start, stop) if step is not None: return list(res)[None:None:item.step] else: @@ -964,7 +964,7 @@ class Query(object): else: return list(self[item:item+1])[0] - def slice_(self, start, stop): + def slice(self, start, stop): """apply LIMIT/OFFSET to the ``Query`` based on a range and return the newly resulting ``Query``.""" if start is not None and stop is not None: @@ -974,7 +974,7 @@ class Query(object): self._limit = stop elif start is not None and stop is None: self._offset = (self._offset or 0) + start - slice_ = _generative(__no_statement_condition)(slice_) + slice = _generative(__no_statement_condition)(slice) def limit(self, limit): """Apply a ``LIMIT`` to the query and return the newly resulting diff --git a/test/ext/declarative.py b/test/ext/declarative.py index ca91f98fc..d5ea6df47 100644 --- a/test/ext/declarative.py +++ b/test/ext/declarative.py @@ -95,6 +95,14 @@ class DeclarativeTest(testing.TestBase, testing.AssertsExecutionResults): foo = sa.orm.column_property(User.id == 5) self.assertRaises(sa.exc.InvalidRequestError, go) + def test_custom_base(self): + class MyBase(object): + def foobar(self): + return "foobar" + Base = decl.declarative_base(cls=MyBase) + assert hasattr(Base, 'metadata') + assert Base().foobar() == "foobar" + def test_add_prop(self): class User(Base, ComparableEntity): __tablename__ = 'users' @@ -135,7 +143,40 @@ class DeclarativeTest(testing.TestBase, testing.AssertsExecutionResults): eq_(a1, Address(email='two')) eq_(a1.user, User(name='u1')) + def test_as_declarative(self): + class User(ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + addresses = relation("Address", backref="user") + + class Address(ComparableEntity): + __tablename__ = 'addresses' + + id = Column('id', Integer, primary_key=True) + email = Column('email', String(50)) + user_id = Column('user_id', Integer, ForeignKey('users.id')) + + reg = {} + decl.instrument_declarative(User, reg, Base.metadata) + decl.instrument_declarative(Address, reg, Base.metadata) + Base.metadata.create_all() + + u1 = User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.save(u1) + sess.flush() + sess.clear() + eq_(sess.query(User).all(), [User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ])]) + def test_custom_mapper(self): class MyExt(sa.orm.MapperExtension): def create_instance(self): diff --git a/test/orm/query.py b/test/orm/query.py index 7a51c3f7e..eb7a0f3d3 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -489,7 +489,7 @@ class FromSelfTest(QueryTest): assert [User(id=8), User(id=9)] == create_session().query(User).filter(User.id.in_([8,9]))._from_self().all() - assert [User(id=8), User(id=9)] == create_session().query(User).slice_(1,3)._from_self().all() + assert [User(id=8), User(id=9)] == create_session().query(User).slice(1,3)._from_self().all() assert [User(id=8)] == list(create_session().query(User).filter(User.id.in_([8,9]))._from_self()[0:1]) def test_join(self): @@ -1123,7 +1123,7 @@ class MixedEntitiesTest(QueryTest): q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(User.id, Address.id).values(User.name, Address.email_address) self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')]) - q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(desc(Address.email_address)).slice_(1, 3).values(User.name, Address.email_address) + q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(desc(Address.email_address)).slice(1, 3).values(User.name, Address.email_address) self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@lala.com')]) adalias = aliased(Address) |
