summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-12-01 23:00:05 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-12-01 23:00:05 +0000
commitc6d01a56e168f1c70461c2684c70a2c5967c4814 (patch)
tree513ee62c991b99542dd6c6cf846dbed2c1503b71
parentebb4b02c2136b3a71989867a52665e34bdfd4236 (diff)
downloadsqlalchemy-c6d01a56e168f1c70461c2684c70a2c5967c4814.tar.gz
- several ORM attributes have been removed or made private:
mapper.get_attr_by_column(), mapper.set_attr_by_column(), mapper.pks_by_table, mapper.cascade_callable(), MapperProperty.cascade_callable(), mapper.canload() - refinements to mapper PK/table column organization, session cascading, some naming convention work
-rw-r--r--CHANGES5
-rw-r--r--lib/sqlalchemy/orm/collections.py4
-rw-r--r--lib/sqlalchemy/orm/dependency.py10
-rw-r--r--lib/sqlalchemy/orm/interfaces.py15
-rw-r--r--lib/sqlalchemy/orm/mapper.py225
-rw-r--r--lib/sqlalchemy/orm/properties.py43
-rw-r--r--lib/sqlalchemy/orm/session.py238
-rw-r--r--lib/sqlalchemy/orm/strategies.py6
-rw-r--r--lib/sqlalchemy/orm/sync.py8
-rw-r--r--lib/sqlalchemy/sql/util.py5
-rw-r--r--test/orm/inheritance/basic.py2
-rw-r--r--test/orm/manytomany.py6
-rw-r--r--test/orm/mapper.py4
13 files changed, 256 insertions, 315 deletions
diff --git a/CHANGES b/CHANGES
index 1dc797d77..6e275254f 100644
--- a/CHANGES
+++ b/CHANGES
@@ -34,6 +34,11 @@ CHANGES
relationship where it takes effect for all inheriting mappers.
[ticket:883]
+ - several ORM attributes have been removed or made private:
+ mapper.get_attr_by_column(), mapper.set_attr_by_column(),
+ mapper.pks_by_table, mapper.cascade_callable(),
+ MapperProperty.cascade_callable(), mapper.canload()
+
- fixed endless loop issue when using lazy="dynamic" on both
sides of a bi-directional relationship [ticket:872]
diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py
index 942b880c9..c2cd4cf09 100644
--- a/lib/sqlalchemy/orm/collections.py
+++ b/lib/sqlalchemy/orm/collections.py
@@ -122,7 +122,7 @@ def column_mapped_collection(mapping_spec):
if isinstance(mapping_spec, schema.Column):
def keyfunc(value):
m = object_mapper(value)
- return m.get_attr_by_column(value, mapping_spec)
+ return m._get_attr_by_column(value, mapping_spec)
else:
cols = []
for c in mapping_spec:
@@ -133,7 +133,7 @@ def column_mapped_collection(mapping_spec):
mapping_spec = tuple(cols)
def keyfunc(value):
m = object_mapper(value)
- return tuple([m.get_attr_by_column(value, c) for c in mapping_spec])
+ return tuple([m._get_attr_by_column(value, c) for c in mapping_spec])
return lambda: MappedCollection(keyfunc)
def attribute_mapped_collection(attr_name):
diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py
index 968899916..9220c5743 100644
--- a/lib/sqlalchemy/orm/dependency.py
+++ b/lib/sqlalchemy/orm/dependency.py
@@ -111,7 +111,7 @@ class DependencyProcessor(object):
def _verify_canload(self, child):
if not self.enable_typechecks:
return
- if child is not None and not self.mapper.canload(child):
+ if child is not None and not self.mapper._canload(child):
raise exceptions.FlushError("Attempting to flush an item of type %s on collection '%s', which is handled by mapper '%s' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship ? Set 'enable_typechecks=False' on the relation() to disable this exception. Mismatched typeloading may cause bi-directional relationships (backrefs) to not function properly." % (child.__class__, self.prop, self.mapper))
def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit):
@@ -240,7 +240,7 @@ class OneToManyDP(DependencyProcessor):
uowcommit.register_object(child, isdelete=False)
elif self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True)
- for c in self.mapper.cascade_iterator('delete', child):
+ for c, m in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(c, isdelete=True)
def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit):
@@ -294,7 +294,7 @@ class ManyToOneDP(DependencyProcessor):
for child in childlist.deleted_items() + childlist.unchanged_items():
if child is not None and self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True)
- for c in self.mapper.cascade_iterator('delete', child):
+ for c, m in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(c, isdelete=True)
else:
for obj in deplist:
@@ -305,7 +305,7 @@ class ManyToOneDP(DependencyProcessor):
for child in childlist.deleted_items():
if self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True)
- for c in self.mapper.cascade_iterator('delete', child):
+ for c, m in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(c, isdelete=True)
def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit):
@@ -391,7 +391,7 @@ class ManyToManyDP(DependencyProcessor):
for child in childlist.deleted_items():
if self.cascade.delete_orphan and self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True)
- for c in self.mapper.cascade_iterator('delete', child):
+ for c, m in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(c, isdelete=True)
def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit):
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index aa0b2dcc2..815ea8ceb 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -319,18 +319,15 @@ class MapperProperty(object):
"""
raise NotImplementedError()
-
+
def cascade_iterator(self, type, object, recursive=None, halt_on=None):
- """return an iterator of objects which are child objects of the given object,
- as attached to the attribute corresponding to this MapperProperty."""
+ """iterate through instances related to the given instance along
+ a particular 'cascade' path, starting with this MapperProperty.
- return []
-
- def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None):
- """run the given callable across all objects which are child objects of
- the given object, as attached to the attribute corresponding to this MapperProperty."""
+ see PropertyLoader for the related instance implementation.
+ """
- return []
+ return iter([])
def get_criterion(self, query, key, value):
"""Return a ``WHERE`` clause suitable for this
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 5673be44c..67087c570 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -4,15 +4,15 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import weakref, warnings, operator
+import weakref, warnings
+from itertools import chain
from sqlalchemy import sql, util, exceptions, logging
-from sqlalchemy.sql import expression, visitors
+from sqlalchemy.sql import expression, visitors, operators
from sqlalchemy.sql import util as sqlutil
from sqlalchemy.orm import util as mapperutil
from sqlalchemy.orm.util import ExtensionCarrier, create_row_adapter
from sqlalchemy.orm import sync, attributes
from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, PropComparator
-deferred_load = None
__all__ = ['Mapper', 'class_mapper', 'object_mapper', 'mapper_registry']
@@ -22,7 +22,7 @@ mapper_registry = weakref.WeakKeyDictionary()
# a list of MapperExtensions that will be installed in all mappers by default
global_extensions = []
-# a constant returned by get_attr_by_column to indicate
+# a constant returned by _get_attr_by_column to indicate
# this mapper is not handling an attribute for a particular
# column
NO_ATTRIBUTE = object()
@@ -152,6 +152,7 @@ class Mapper(object):
self._compile_inheritance()
self._compile_tables()
self._compile_properties()
+ self._compile_pks()
self._compile_selectable()
self.__log("constructed")
@@ -376,12 +377,6 @@ class Mapper(object):
self.polymorphic_map[key] = class_or_mapper
def _compile_tables(self):
- """After the inheritance relationships have been reconciled,
- set up some more table-based instance variables and determine
- the *primary key* columns for all tables represented by this
- ``Mapper``.
- """
-
# summary of the various Selectable units:
# mapped_table - the Selectable that represents a join of the underlying Tables to be saved (or just the Table)
# local_table - the Selectable that was passed to this Mapper's constructor, if any
@@ -401,27 +396,25 @@ class Mapper(object):
if not self.tables:
raise exceptions.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table))
- # TODO: move the "figure pks" step down into compile_properties; after
- # all columns have been mapped, assemble PK columns and their
- # proxied parents into the pks_by_table collection, then get rid
- # of the _has_pks method
-
- # determine primary key columns
- self.pks_by_table = {}
+ def _compile_pks(self):
- # go through all of our represented tables
- # and assemble primary key columns
- for t in self.tables + [self.mapped_table]:
+ self._pks_by_table = {}
+ self._cols_by_table = {}
+
+ all_cols = util.Set(chain(*[c2 for c2 in [col.proxy_set for col in [c for c in self._columntoproperty]]]))
+ pk_cols = util.Set([c for c in all_cols if c.primary_key])
+
+ for t in util.Set(self.tables + [self.mapped_table]):
self._all_tables.add(t)
- if t not in self.pks_by_table:
- self.pks_by_table[t] = util.OrderedSet()
- self.pks_by_table[t].update(t.primary_key)
-
- if self.primary_key_argument is not None:
+ if t.primary_key and pk_cols.issuperset(t.primary_key):
+ self._pks_by_table[t] = util.Set(t.primary_key).intersection(pk_cols)
+ self._cols_by_table[t] = util.Set(t.c).intersection(all_cols)
+
+ if self.primary_key_argument:
for k in self.primary_key_argument:
- self.pks_by_table.setdefault(k.table, util.OrderedSet()).add(k)
+ self._pks_by_table.setdefault(k.table, util.Set()).add(k)
- if len(self.pks_by_table[self.mapped_table]) == 0:
+ if len(self._pks_by_table[self.mapped_table]) == 0:
raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name))
if self.inherits is not None and not self.concrete and not self.primary_key_argument:
@@ -437,7 +430,7 @@ class Mapper(object):
primary_key = expression.ColumnSet()
- for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]):
+ for col in (self.primary_key_argument or self._pks_by_table[self.mapped_table]):
c = self.mapped_table.corresponding_column(col, raiseerr=False)
if c is None:
for cc in self._equivalent_columns[col]:
@@ -474,7 +467,10 @@ class Mapper(object):
self.primary_key = primary_key
self.__log("Identified primary key columns: " + str(primary_key))
-
+
+ # create a "get clause" based on the primary key. this is used
+ # by query.get() and many-to-one lazyloads to load this item
+ # by primary key.
_get_clause = sql.and_()
_get_params = {}
for primary_key in self.primary_key:
@@ -510,7 +506,7 @@ class Mapper(object):
result = {}
def visit_binary(binary):
- if binary.operator == operator.eq:
+ if binary.operator == operators.eq:
if binary.left in result:
result[binary.left].add(binary.right)
else:
@@ -533,13 +529,17 @@ class Mapper(object):
return
recursive.add(col)
for fk in col.foreign_keys:
- result.setdefault(fk.column, util.Set()).add(equiv)
+ if fk.column not in result:
+ result[fk.column] = util.Set()
+ result[fk.column].add(equiv)
equivs(fk.column, recursive, col)
- for column in (self.primary_key_argument or self.pks_by_table[self.mapped_table]):
+ for column in (self.primary_key_argument or self._pks_by_table[self.mapped_table]):
for col in column.proxy_set:
if not col.foreign_keys:
- result.setdefault(col, util.Set()).add(col)
+ if col not in result:
+ result[col] = util.Set()
+ result[col].add(col)
else:
equivs(col, util.Set(), col)
@@ -571,11 +571,6 @@ class Mapper(object):
return getattr(getattr(cls, clskey), key)
def _compile_properties(self):
- """Inspect the properties dictionary sent to the Mapper's
- constructor as well as the mapped_table, and create
- ``MapperProperty`` objects corresponding to each mapped column
- and relation.
- """
# object attribute names mapped to MapperProperty objects
self.__props = util.OrderedDict()
@@ -637,9 +632,6 @@ class Mapper(object):
# TODO: the "property already exists" case is still not well defined here.
# assuming single-column, etc.
- if column in self.primary_key and prop.columns[-1] in self.primary_key:
- warnings.warn(RuntimeWarning("On mapper %s, primary key column '%s' is being combined with distinct primary key column '%s' in attribute '%s'. Use explicit properties to give each column its own mapped attribute name." % (str(self), str(column), str(prop.columns[-1]), key)))
-
if prop.parent is not self:
# existing ColumnProperty from an inheriting mapper.
# make a copy and append our column to it
@@ -896,38 +888,27 @@ class Mapper(object):
instance.
"""
- return [self.get_attr_by_column(instance, column) for column in self.primary_key]
+ return [self._get_attr_by_column(instance, column) for column in self.primary_key]
- def canload(self, instance):
+ def _canload(self, instance):
"""return true if this mapper is capable of loading the given instance"""
if self.polymorphic_on is not None:
return isinstance(instance, self.class_)
else:
return instance.__class__ is self.class_
- def _getpropbycolumn(self, column, raiseerror=True):
+ def _get_attr_by_column(self, obj, column):
+ """Return an instance attribute using a Column as the key."""
try:
- return self._columntoproperty[column]
+ return self._columntoproperty[column].getattr(obj, column)
except KeyError:
- try:
- prop = self.__props[column.key]
- if not raiseerror:
- return None
+ prop = self.__props.get(column.key, None)
+ if prop:
raise exceptions.InvalidRequestError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop)))
- except KeyError:
- if not raiseerror:
- return None
+ else:
raise exceptions.InvalidRequestError("No column %s.%s is configured on mapper %s..." % (column.table.name, column.name, str(self)))
-
- def get_attr_by_column(self, obj, column, raiseerror=True):
- """Return an instance attribute using a Column as the key."""
-
- prop = self._getpropbycolumn(column, raiseerror)
- if prop is None:
- return NO_ATTRIBUTE
- return prop.getattr(obj, column)
-
- def set_attr_by_column(self, obj, column, value):
+
+ def _set_attr_by_column(self, obj, column, value):
"""Set the value of an instance attribute using a Column as the key."""
self._columntoproperty[column].setattr(obj, value, column)
@@ -996,18 +977,18 @@ class Mapper(object):
table_to_mapper = {}
for mapper in self.base_mapper.polymorphic_iterator():
for t in mapper.tables:
- table_to_mapper.setdefault(t, mapper)
+ table_to_mapper[t] = mapper
- for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=False):
+ for table in sqlutil.sort_tables(table_to_mapper.keys()):
# two lists to store parameters for each table/object pair located
insert = []
update = []
for obj, connection in tups:
mapper = object_mapper(obj)
- if table not in mapper.tables or not mapper._has_pks(table):
+ if table not in mapper._pks_by_table:
continue
- pks = mapper.pks_by_table[table]
+ pks = mapper._pks_by_table[table]
instance_key = mapper.identity_key_from_instance(obj)
if self.__should_log_debug:
@@ -1019,11 +1000,11 @@ class Mapper(object):
hasdata = False
if isinsert:
- for col in table.columns:
+ for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
params[col.key] = 1
elif col in pks:
- value = mapper.get_attr_by_column(obj, col)
+ value = mapper._get_attr_by_column(obj, col)
if value is not None:
params[col.key] = value
elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col):
@@ -1033,9 +1014,7 @@ class Mapper(object):
if col.default is None or value is not None:
params[col.key] = value
else:
- value = mapper.get_attr_by_column(obj, col, False)
- if value is NO_ATTRIBUTE:
- continue
+ value = mapper._get_attr_by_column(obj, col)
if col.default is None or value is not None:
if isinstance(value, sql.ClauseElement):
value_params[col] = value
@@ -1043,24 +1022,22 @@ class Mapper(object):
params[col.key] = value
insert.append((obj, params, mapper, connection, value_params))
else:
- for col in table.columns:
+ for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
- params[col._label] = mapper.get_attr_by_column(obj, col)
+ params[col._label] = mapper._get_attr_by_column(obj, col)
params[col.key] = params[col._label] + 1
for prop in mapper._columntoproperty.values():
history = attributes.get_history(obj, prop.key, passive=True)
if history and history.added_items():
hasdata = True
elif col in pks:
- params[col._label] = mapper.get_attr_by_column(obj, col)
+ params[col._label] = mapper._get_attr_by_column(obj, col)
elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col):
pass
else:
if post_update_cols is not None and col not in post_update_cols:
continue
- prop = mapper._getpropbycolumn(col, False)
- if prop is None:
- continue
+ prop = mapper._columntoproperty[col]
history = attributes.get_history(obj, prop.key, passive=True)
if history:
a = history.added_items()
@@ -1076,14 +1053,14 @@ class Mapper(object):
if update:
mapper = table_to_mapper[table]
clause = sql.and_()
- for col in mapper.pks_by_table[table]:
+ for col in mapper._pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col._label, type_=col.type, unique=True))
if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type, unique=True))
statement = table.update(clause)
rows = 0
supports_sane_rowcount = True
- pks = mapper.pks_by_table[table]
+ pks = mapper._pks_by_table[table]
def comparator(a, b):
for col in pks:
x = cmp(a[1][col._label],b[1][col._label])
@@ -1115,23 +1092,22 @@ class Mapper(object):
if primary_key is not None:
i = 0
- for col in mapper.pks_by_table[table]:
- if mapper.get_attr_by_column(obj, col) is None and len(primary_key) > i:
- mapper.set_attr_by_column(obj, col, primary_key[i])
+ for col in mapper._pks_by_table[table]:
+ if mapper._get_attr_by_column(obj, col) is None and len(primary_key) > i:
+ mapper._set_attr_by_column(obj, col, primary_key[i])
i+=1
mapper._postfetch(connection, table, obj, c, c.last_inserted_params(), value_params)
# synchronize newly inserted ids from one table to the next
# TODO: this fires off more than needed, try to organize syncrules
# per table
- mappers = list(mapper.iterate_to_root())
- mappers.reverse()
- for m in mappers:
+ for m in util.reversed(list(mapper.iterate_to_root())):
if m._synchronizer is not None:
m._synchronizer.execute(obj, obj)
# testlib.pragma exempt:__hash__
inserted_objects.add((id(obj), obj, connection))
+
if not postupdate:
for id_, obj, connection in inserted_objects:
for mapper in object_mapper(obj).iterate_to_root():
@@ -1141,7 +1117,7 @@ class Mapper(object):
for mapper in object_mapper(obj).iterate_to_root():
if 'after_update' in mapper.extension.methods:
mapper.extension.after_update(mapper, connection, obj)
-
+
def _postfetch(self, connection, table, obj, resultproxy, params, value_params):
"""After an ``INSERT`` or ``UPDATE``, assemble newly generated
values on an instance. For columns which are marked as being generated
@@ -1152,20 +1128,15 @@ class Mapper(object):
postfetch_cols = resultproxy.postfetch_cols().union(util.Set(value_params.keys()))
deferred_props = []
- for c in table.c:
+ for c in self._cols_by_table[table]:
if c in postfetch_cols and (not c.key in params or c in value_params):
- prop = self._getpropbycolumn(c, raiseerror=False)
- if prop is None:
- continue
+ prop = self._columntoproperty[c]
deferred_props.append(prop.key)
continue
if c.primary_key or not c.key in params:
continue
- v = self.get_attr_by_column(obj, c, False)
- if v is NO_ATTRIBUTE:
- continue
- elif v != params[c.key]:
- self.set_attr_by_column(obj, c, params[c.key])
+ if self._get_attr_by_column(obj, c) != params[c.key]:
+ self._set_attr_by_column(obj, c, params[c.key])
if deferred_props:
expire_instance(obj, deferred_props)
@@ -1196,13 +1167,13 @@ class Mapper(object):
table_to_mapper = {}
for mapper in self.base_mapper.polymorphic_iterator():
for t in mapper.tables:
- table_to_mapper.setdefault(t, mapper)
+ table_to_mapper[t] = mapper
for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=True):
delete = {}
for (obj, connection) in tups:
mapper = object_mapper(obj)
- if table not in mapper.tables or not mapper._has_pks(table):
+ if table not in mapper._pks_by_table:
continue
params = {}
@@ -1210,23 +1181,23 @@ class Mapper(object):
continue
else:
delete.setdefault(connection, []).append(params)
- for col in mapper.pks_by_table[table]:
- params[col.key] = mapper.get_attr_by_column(obj, col)
+ for col in mapper._pks_by_table[table]:
+ params[col.key] = mapper._get_attr_by_column(obj, col)
if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
- params[mapper.version_id_col.key] = mapper.get_attr_by_column(obj, mapper.version_id_col)
+ params[mapper.version_id_col.key] = mapper._get_attr_by_column(obj, mapper.version_id_col)
# testlib.pragma exempt:__hash__
deleted_objects.add((id(obj), obj, connection))
for connection, del_objects in delete.iteritems():
mapper = table_to_mapper[table]
def comparator(a, b):
- for col in mapper.pks_by_table[table]:
+ for col in mapper._pks_by_table[table]:
x = cmp(a[col.key],b[col.key])
if x != 0:
return x
return 0
del_objects.sort(comparator)
clause = sql.and_()
- for col in mapper.pks_by_table[table]:
+ for col in mapper._pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col.key, type_=col.type, unique=True))
if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type, unique=True))
@@ -1240,17 +1211,6 @@ class Mapper(object):
if 'after_delete' in mapper.extension.methods:
mapper.extension.after_delete(mapper, connection, obj)
- def _has_pks(self, table):
- # TODO: determine this beforehand
- if self.pks_by_table.get(table, None):
- for k in self.pks_by_table[table]:
- if k not in self._columntoproperty:
- return False
- else:
- return True
- else:
- return False
-
def register_dependencies(self, uowcommit, *args, **kwargs):
"""Register ``DependencyProcessor`` instances with a
``unitofwork.UOWTransaction``.
@@ -1263,8 +1223,8 @@ class Mapper(object):
prop.register_dependencies(uowcommit, *args, **kwargs)
def cascade_iterator(self, type, object, recursive=None, halt_on=None):
- """Iterate each element in an object graph, for all relations
- taht meet the given cascade rule.
+ """Iterate each element and its mapper in an object graph,
+ for all relations that meet the given cascade rule.
type
The name of the cascade rule (i.e. save-update, delete,
@@ -1282,33 +1242,8 @@ class Mapper(object):
if recursive is None:
recursive=util.IdentitySet()
for prop in self.__props.values():
- for c in prop.cascade_iterator(type, object, recursive, halt_on=halt_on):
- yield c
-
- def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None):
- """Execute a callable for each element in an object graph, for
- all relations that meet the given cascade rule.
-
- type
- The name of the cascade rule (i.e. save-update, delete, etc.)
-
- object
- The lead object instance. child items will be processed per
- the relations defined for this object's mapper.
-
- callable\_
- The callable function.
-
- recursive
- Used by the function for internal context during recursive
- calls, leave as None.
-
- """
-
- if recursive is None:
- recursive=util.IdentitySet()
- for prop in self.__props.values():
- prop.cascade_callable(type, object, callable_, recursive, halt_on=halt_on)
+ for (c, m) in prop.cascade_iterator(type, object, recursive, halt_on=halt_on):
+ yield (c, m)
def get_select_mapper(self):
"""Return the mapper used for issuing selects.
@@ -1365,8 +1300,8 @@ class Mapper(object):
isnew = False
- if context.version_check and self.version_id_col is not None and self.get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]:
- raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self.get_attr_by_column(instance, self.version_id_col), row[self.version_id_col]))
+ if context.version_check and self.version_id_col is not None and self._get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]:
+ raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self._get_attr_by_column(instance, self.version_id_col), row[self.version_id_col]))
if context.populate_existing or self.always_refresh or instance._state.trigger is not None:
instance._state.trigger = None
@@ -1541,7 +1476,7 @@ class Mapper(object):
params = {}
for c in param_names:
- params[c.name] = self.get_attr_by_column(instance, c)
+ params[c.name] = self._get_attr_by_column(instance, c)
row = selectcontext.session.connection(self).execute(statement, params).fetchone()
self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True)
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 2c50ec92f..806c91a2b 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -12,13 +12,13 @@ to handle flush-time dependency sorting and processing.
"""
from sqlalchemy import sql, schema, util, exceptions, logging
-from sqlalchemy.sql import util as sql_util, visitors
+from sqlalchemy.sql import util as sql_util, visitors, operators
from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency
from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
-import operator
from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty
from sqlalchemy.exceptions import ArgumentError
+import warnings
__all__ = ['ColumnProperty', 'CompositeProperty', 'SynonymProperty', 'PropertyLoader', 'BackRef']
@@ -48,7 +48,12 @@ class ColumnProperty(StrategizedProperty):
return strategies.DeferredColumnLoader(self)
else:
return strategies.ColumnLoader(self)
-
+
+ def do_init(self):
+ super(ColumnProperty, self).do_init()
+ if len(self.columns) > 1 and self.parent.primary_key.issuperset(self.columns):
+ warnings.warn(RuntimeWarning("On mapper %s, primary key column '%s' is being combined with distinct primary key column '%s' in attribute '%s'. Use explicit properties to give each column its own mapped attribute name." % (str(self.parent), str(self.columns[1]), str(self.columns[0]), self.key)))
+
def copy(self):
return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns)
@@ -75,10 +80,8 @@ class ColumnProperty(StrategizedProperty):
col = self.prop.columns[0]
return op(col._bind_param(other), col)
-
ColumnProperty.logger = logging.class_logger(ColumnProperty)
-
class CompositeProperty(ColumnProperty):
"""subclasses ColumnProperty to provide composite type support."""
@@ -86,6 +89,10 @@ class CompositeProperty(ColumnProperty):
super(CompositeProperty, self).__init__(*columns, **kwargs)
self.composite_class = class_
self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator)(self)
+
+ def do_init(self):
+ super(ColumnProperty, self).do_init()
+ # TODO: similar PK check as ColumnProperty does ?
def copy(self):
return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns)
@@ -283,7 +290,7 @@ class PropertyLoader(StrategizedProperty):
return ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]))
def compare(self, op, value, value_is_parent=False):
- if op == operator.eq:
+ if op == operators.eq:
if value is None:
return ~sql.exists([1], self.prop.mapper.mapped_table, self.prop.primaryjoin)
else:
@@ -347,23 +354,9 @@ class PropertyLoader(StrategizedProperty):
if not isinstance(c, self.mapper.class_):
raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__)))
recursive.add(c)
- yield c
- for c2 in mapper.cascade_iterator(type, c, recursive):
- yield c2
-
- def cascade_callable(self, type, object, callable_, recursive, halt_on=None):
- if not type in self.cascade:
- return
-
- mapper = self.mapper.primary_mapper()
- passive = type != 'delete' or self.passive_deletes
- for c in attributes.get_as_list(object, self.key, passive=passive):
- if c is not None and c not in recursive and (halt_on is None or not halt_on(c)):
- if not isinstance(c, self.mapper.class_):
- raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__)))
- recursive.add(c)
- callable_(c, mapper.entity_name)
- mapper.cascade_callable(type, c, callable_, recursive)
+ yield (c, mapper)
+ for (c2, m) in mapper.cascade_iterator(type, c, recursive):
+ yield (c2, m)
def _get_target_class(self):
"""Return the target class of the relation, even if the
@@ -464,7 +457,7 @@ class PropertyLoader(StrategizedProperty):
if self.foreign_keys:
self._opposite_side = util.Set()
def visit_binary(binary):
- if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+ if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
if binary.left in self.foreign_keys:
self._opposite_side.add(binary.right)
@@ -477,7 +470,7 @@ class PropertyLoader(StrategizedProperty):
self.foreign_keys = util.Set()
self._opposite_side = util.Set()
def visit_binary(binary):
- if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+ if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
# this check is for when the user put the "view_only" flag on and has tables that have nothing
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 7097273f5..28ef39aba 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -1,4 +1,4 @@
-# objectstore.py
+# session.py
# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
@@ -103,10 +103,10 @@ class SessionExtension(object):
Note that this may not be per-flush if a longer running transaction is ongoing."""
- def before_flush(self, session, flush_context, objects):
+ def before_flush(self, session, flush_context, instances):
"""execute before flush process has started.
- 'objects' is an optional list of objects which were passed to the ``flush()``
+ 'instances' is an optional list of objects which were passed to the ``flush()``
method.
"""
@@ -719,7 +719,7 @@ class Session(object):
entity_name = kwargs.pop('entity_name', None)
return self.query(class_, entity_name=entity_name).load(ident, **kwargs)
- def refresh(self, obj, attribute_names=None):
+ def refresh(self, instance, attribute_names=None):
"""Refresh the attributes on the given instance.
When called, a query will be issued
@@ -738,12 +738,12 @@ class Session(object):
refreshed.
"""
- self._validate_persistent(obj)
+ self._validate_persistent(instance)
- if self.query(obj.__class__)._get(obj._instance_key, refresh_instance=obj, only_load_props=attribute_names) is None:
- raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(obj))
+ if self.query(instance.__class__)._get(instance._instance_key, refresh_instance=instance, only_load_props=attribute_names) is None:
+ raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance))
- def expire(self, obj, attribute_names=None):
+ def expire(self, instance, attribute_names=None):
"""Expire the attributes on the given instance.
The instance's attributes are instrumented such that
@@ -764,11 +764,16 @@ class Session(object):
"""
if attribute_names:
- self._validate_persistent(obj)
- expire_instance(obj, attribute_names=attribute_names)
+ self._validate_persistent(instance)
+ expire_instance(instance, attribute_names=attribute_names)
else:
- for c in [obj] + list(_object_mapper(obj).cascade_iterator('refresh-expire', obj)):
- self._validate_persistent(obj)
+ # pre-fetch the full cascade since the expire is going to
+ # remove associations
+ cascaded = list(_cascade_iterator('refresh-expire', instance))
+ self._validate_persistent(instance)
+ expire_instance(instance, None)
+ for (c, m) in cascaded:
+ self._validate_persistent(c)
expire_instance(c, None)
def prune(self):
@@ -784,20 +789,20 @@ class Session(object):
return self.uow.prune_identity_map()
- def expunge(self, object):
- """Remove the given `object` from this ``Session``.
+ def expunge(self, instance):
+ """Remove the given `instance` from this ``Session``.
- This will free all internal references to the object.
+ This will free all internal references to the instance.
Cascading will be applied according to the *expunge* cascade
rule.
"""
- self._validate_persistent(object)
- for c in [object] + list(_object_mapper(object).cascade_iterator('expunge', object)):
+ self._validate_persistent(instance)
+ for c, m in [(instance, None)] + list(_cascade_iterator('expunge', instance)):
if c in self:
self.uow._remove_deleted(c)
self._unattach(c)
- def save(self, object, entity_name=None):
+ def save(self, instance, entity_name=None):
"""Add a transient (unsaved) instance to this ``Session``.
This operation cascades the `save_or_update` method to
@@ -808,12 +813,10 @@ class Session(object):
specific ``Mapper`` used to handle this instance.
"""
- self._save_impl(object, entity_name=entity_name)
- _object_mapper(object).cascade_callable('save-update', object,
- lambda c, e:self._save_or_update_impl(c, e),
- halt_on=lambda c:c in self)
+ self._save_impl(instance, entity_name=entity_name)
+ self._cascade_save_or_update(instance)
- def update(self, object, entity_name=None):
+ def update(self, instance, entity_name=None):
"""Bring the given detached (saved) instance into this
``Session``.
@@ -826,37 +829,37 @@ class Session(object):
``cascade="save-update"``.
"""
- self._update_impl(object, entity_name=entity_name)
- _object_mapper(object).cascade_callable('save-update', object,
- lambda c, e:self._save_or_update_impl(c, e),
- halt_on=lambda c:c in self)
+ self._update_impl(instance, entity_name=entity_name)
+ self._cascade_save_or_update(instance)
- def save_or_update(self, object, entity_name=None):
- """Save or update the given object into this ``Session``.
+ def save_or_update(self, instance, entity_name=None):
+ """Save or update the given instance into this ``Session``.
The presence of an `_instance_key` attribute on the instance
determines whether to ``save()`` or ``update()`` the instance.
"""
- self._save_or_update_impl(object, entity_name=entity_name)
- _object_mapper(object).cascade_callable('save-update', object,
- lambda c, e:self._save_or_update_impl(c, e),
- halt_on=lambda c:c in self)
+ self._save_or_update_impl(instance, entity_name=entity_name)
+ self._cascade_save_or_update(instance)
+
+ def _cascade_save_or_update(self, instance):
+ for obj, mapper in _cascade_iterator('save-update', instance, halt_on=lambda c:c in self):
+ self._save_or_update_impl(obj, mapper.entity_name)
- def delete(self, object):
+ def delete(self, instance):
"""Mark the given instance as deleted.
The delete operation occurs upon ``flush()``.
"""
- self._delete_impl(object)
- for c in list(_object_mapper(object).cascade_iterator('delete', object)):
+ self._delete_impl(instance)
+ for c, m in _cascade_iterator('delete', instance):
self._delete_impl(c, ignore_transient=True)
- def merge(self, object, entity_name=None, dont_load=False, _recursive=None):
- """Copy the state of the given `object` onto the persistent
- object with the same identifier.
+ def merge(self, instance, entity_name=None, dont_load=False, _recursive=None):
+ """Copy the state of the given `instance` onto the persistent
+ instance with the same identifier.
If there is no persistent instance currently associated with
the session, it will be loaded. Return the persistent
@@ -871,20 +874,20 @@ class Session(object):
if _recursive is None:
_recursive = {} #TODO: this should be an IdentityDict
if entity_name is not None:
- mapper = _class_mapper(object.__class__, entity_name=entity_name)
+ mapper = _class_mapper(instance.__class__, entity_name=entity_name)
else:
- mapper = _object_mapper(object)
- if object in _recursive:
- return _recursive[object]
+ mapper = _object_mapper(instance)
+ if instance in _recursive:
+ return _recursive[instance]
- key = getattr(object, '_instance_key', None)
+ key = getattr(instance, '_instance_key', None)
if key is None:
merged = attributes.new_instance(mapper.class_)
else:
if key in self.identity_map:
merged = self.identity_map[key]
elif dont_load:
- if object._state.modified:
+ if instance._state.modified:
raise exceptions.InvalidRequestError("merge() with dont_load=True option does not support objects marked as 'dirty'. flush() all changes on mapped instances before merging with dont_load=True.")
merged = attributes.new_instance(mapper.class_)
@@ -894,10 +897,10 @@ class Session(object):
else:
merged = self.get(mapper.class_, key[1])
if merged is None:
- raise exceptions.AssertionError("Instance %s has an instance key but is not persisted" % mapperutil.instance_str(object))
- _recursive[object] = merged
+ raise exceptions.AssertionError("Instance %s has an instance key but is not persisted" % mapperutil.instance_str(instance))
+ _recursive[instance] = merged
for prop in mapper.iterate_properties:
- prop.merge(self, object, merged, dont_load, _recursive)
+ prop.merge(self, instance, merged, dont_load, _recursive)
if key is None:
self.save(merged, entity_name=mapper.entity_name)
elif dont_load:
@@ -968,96 +971,96 @@ class Session(object):
return mapper.identity_key_from_instance(instance)
identity_key = classmethod(identity_key)
- def object_session(cls, obj):
+ def object_session(cls, instance):
-        """return the ``Session`` to which the given object belongs."""
+        """return the ``Session`` to which the given instance belongs."""
- return object_session(obj)
+ return object_session(instance)
object_session = classmethod(object_session)
- def _save_impl(self, obj, **kwargs):
- if hasattr(obj, '_instance_key'):
- raise exceptions.InvalidRequestError("Instance '%s' is already persistent" % mapperutil.instance_str(obj))
+ def _save_impl(self, instance, **kwargs):
+ if hasattr(instance, '_instance_key'):
+ raise exceptions.InvalidRequestError("Instance '%s' is already persistent" % mapperutil.instance_str(instance))
else:
# TODO: consolidate the steps here
- attributes.manage(obj)
- obj._entity_name = kwargs.get('entity_name', None)
- self._attach(obj)
- self.uow.register_new(obj)
+ attributes.manage(instance)
+ instance._entity_name = kwargs.get('entity_name', None)
+ self._attach(instance)
+ self.uow.register_new(instance)
- def _update_impl(self, obj, **kwargs):
- if obj in self and obj not in self.deleted:
+ def _update_impl(self, instance, **kwargs):
+ if instance in self and instance not in self.deleted:
return
- if not hasattr(obj, '_instance_key'):
- raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(obj))
- elif self.identity_map.get(obj._instance_key, obj) is not obj:
+ if not hasattr(instance, '_instance_key'):
+ raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(instance))
+ elif self.identity_map.get(instance._instance_key, instance) is not instance:
-            raise exceptions.InvalidRequestError("Could not update instance '%s', identity key %s; a different instance with the same identity key already exists in this session." % (mapperutil.instance_str(obj), obj._instance_key))
+            raise exceptions.InvalidRequestError("Could not update instance '%s', identity key %s; a different instance with the same identity key already exists in this session." % (mapperutil.instance_str(instance), instance._instance_key))
- self._attach(obj)
+ self._attach(instance)
- def _save_or_update_impl(self, object, entity_name=None):
- key = getattr(object, '_instance_key', None)
+ def _save_or_update_impl(self, instance, entity_name=None):
+ key = getattr(instance, '_instance_key', None)
if key is None:
- self._save_impl(object, entity_name=entity_name)
+ self._save_impl(instance, entity_name=entity_name)
else:
- self._update_impl(object, entity_name=entity_name)
+ self._update_impl(instance, entity_name=entity_name)
- def _delete_impl(self, obj, ignore_transient=False):
- if obj in self and obj in self.deleted:
+ def _delete_impl(self, instance, ignore_transient=False):
+ if instance in self and instance in self.deleted:
return
- if not hasattr(obj, '_instance_key'):
+ if not hasattr(instance, '_instance_key'):
if ignore_transient:
return
else:
- raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(obj))
- if self.identity_map.get(obj._instance_key, obj) is not obj:
- raise exceptions.InvalidRequestError("Instance '%s' is with key %s already persisted with a different identity" % (mapperutil.instance_str(obj), obj._instance_key))
- self._attach(obj)
- self.uow.register_deleted(obj)
-
- def _register_persistent(self, obj):
- obj._sa_session_id = self.hash_key
- self.identity_map[obj._instance_key] = obj
- obj._state.commit_all()
-
- def _attach(self, obj):
- old_id = getattr(obj, '_sa_session_id', None)
+ raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(instance))
+ if self.identity_map.get(instance._instance_key, instance) is not instance:
+            raise exceptions.InvalidRequestError("Instance '%s' with key %s is already persisted with a different identity" % (mapperutil.instance_str(instance), instance._instance_key))
+ self._attach(instance)
+ self.uow.register_deleted(instance)
+
+ def _register_persistent(self, instance):
+ instance._sa_session_id = self.hash_key
+ self.identity_map[instance._instance_key] = instance
+ instance._state.commit_all()
+
+ def _attach(self, instance):
+ old_id = getattr(instance, '_sa_session_id', None)
if old_id != self.hash_key:
- if old_id is not None and old_id in _sessions and obj in _sessions[old_id]:
+ if old_id is not None and old_id in _sessions and instance in _sessions[old_id]:
raise exceptions.InvalidRequestError("Object '%s' is already attached "
"to session '%s' (this is '%s')" %
- (mapperutil.instance_str(obj), old_id, id(self)))
+ (mapperutil.instance_str(instance), old_id, id(self)))
- key = getattr(obj, '_instance_key', None)
+ key = getattr(instance, '_instance_key', None)
if key is not None:
- self.identity_map[key] = obj
- obj._sa_session_id = self.hash_key
+ self.identity_map[key] = instance
+ instance._sa_session_id = self.hash_key
- def _unattach(self, obj):
- if obj._sa_session_id == self.hash_key:
- del obj._sa_session_id
+ def _unattach(self, instance):
+ if instance._sa_session_id == self.hash_key:
+ del instance._sa_session_id
- def _validate_persistent(self, obj):
- """Validate that the given object is persistent within this
+ def _validate_persistent(self, instance):
+ """Validate that the given instance is persistent within this
``Session``.
"""
- return obj in self
+ return instance in self
- def __contains__(self, obj):
- """return True if the given object is associated with this session.
+ def __contains__(self, instance):
+ """return True if the given instance is associated with this session.
The instance may be pending or persistent within the Session for a
result of True.
"""
- return obj in self.uow.new or (hasattr(obj, '_instance_key') and self.identity_map.get(obj._instance_key) is obj)
+ return instance in self.uow.new or (hasattr(instance, '_instance_key') and self.identity_map.get(instance._instance_key) is instance)
def __iter__(self):
- """return an iterator of all objects which are pending or persistent within this Session."""
+ """return an iterator of all instances which are pending or persistent within this Session."""
return iter(list(self.uow.new) + self.uow.identity_map.values())
- def is_modified(self, obj, include_collections=True, passive=False):
- """return True if the given object has modified attributes.
+ def is_modified(self, instance, include_collections=True, passive=False):
+ """return True if the given instance has modified attributes.
This method retrieves a history instance for each instrumented attribute
on the instance and performs a comparison of the current value to its
@@ -1073,15 +1076,15 @@ class Session(object):
not be loaded in the course of performing this test.
"""
- for attr in attributes.managed_attributes(obj.__class__):
+ for attr in attributes.managed_attributes(instance.__class__):
if not include_collections and hasattr(attr.impl, 'get_collection'):
continue
- if attr.get_history(obj).is_modified():
+ if attr.get_history(instance).is_modified():
return True
return False
dirty = property(lambda s:s.uow.locate_dirty(),
- doc="""A ``Set`` of all objects marked as 'dirty' within this ``Session``.
+ doc="""A ``Set`` of all instances marked as 'dirty' within this ``Session``.
Note that the 'dirty' state here is 'optimistic'; most attribute-setting or collection
modification operations will mark an instance as 'dirty' and place it in this set,
@@ -1095,12 +1098,12 @@ class Session(object):
""")
deleted = property(lambda s:s.uow.deleted,
- doc="A ``Set`` of all objects marked as 'deleted' within this ``Session``")
+ doc="A ``Set`` of all instances marked as 'deleted' within this ``Session``")
new = property(lambda s:s.uow.new,
- doc="A ``Set`` of all objects marked as 'new' within this ``Session``.")
+ doc="A ``Set`` of all instances marked as 'new' within this ``Session``.")
-def expire_instance(obj, attribute_names):
+def expire_instance(instance, attribute_names):
"""standalone expire instance function.
installs a callable with the given instance's _state
@@ -1110,29 +1113,30 @@ def expire_instance(obj, attribute_names):
If the list is None or blank, the entire instance is expired.
"""
- if obj._state.trigger is None:
+ if instance._state.trigger is None:
def load_attributes(instance, attribute_names):
if object_session(instance).query(instance.__class__)._get(instance._instance_key, refresh_instance=instance, only_load_props=attribute_names) is None:
raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance))
- obj._state.trigger = load_attributes
+ instance._state.trigger = load_attributes
- obj._state.expire_attributes(attribute_names)
+ instance._state.expire_attributes(attribute_names)
register_attribute = unitofwork.register_attribute
-# this dictionary maps the hash key of a Session to the Session itself, and
-# acts as a Registry with which to locate Sessions. this is to enable
-# object instances to be associated with Sessions without having to attach the
-# actual Session object directly to the object instance.
_sessions = weakref.WeakValueDictionary()
-def object_session(obj):
- """Return the ``Session`` to which the given object is bound, or ``None`` if none."""
+def _cascade_iterator(cascade, instance, **kwargs):
+ mapper = _object_mapper(instance)
+ for (o, m) in mapper.cascade_iterator(cascade, instance, **kwargs):
+ yield o, m
+
+def object_session(instance):
+ """Return the ``Session`` to which the given instance is bound, or ``None`` if none."""
- hashkey = getattr(obj, '_sa_session_id', None)
+ hashkey = getattr(instance, '_sa_session_id', None)
if hashkey is not None:
sess = _sessions.get(hashkey)
- if obj in sess:
+ if instance in sess:
return sess
return None
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 096a42bb7..3c647ac60 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -109,7 +109,7 @@ class ColumnLoader(LoaderStrategy):
def create_statement(instance):
params = {}
for c in param_names:
- params[c.name] = mapper.get_attr_by_column(instance, c)
+ params[c.name] = mapper._get_attr_by_column(instance, c)
return (statement, params)
def new_execute(instance, row, isnew, **flags):
@@ -301,7 +301,7 @@ class LazyLoader(AbstractRelationLoader):
def visit_bindparam(s, bindparam):
mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
if bindparam.key in bind_to_col:
- bindparam.value = mapper.get_attr_by_column(instance, bind_to_col[bindparam.key])
+ bindparam.value = mapper._get_attr_by_column(instance, bind_to_col[bindparam.key])
return Visitor().traverse(criterion, clone=True)
def setup_loader(self, instance, options=None, path=None):
@@ -338,7 +338,7 @@ class LazyLoader(AbstractRelationLoader):
if self.use_get:
params = {}
for col, bind in self.lazybinds.iteritems():
- params[bind.key] = self.parent.get_attr_by_column(instance, col)
+ params[bind.key] = self.parent._get_attr_by_column(instance, col)
ident = []
nonnulls = False
for primary_key in self.select_mapper.primary_key:
diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py
index 9575aa958..8132c7e4a 100644
--- a/lib/sqlalchemy/orm/sync.py
+++ b/lib/sqlalchemy/orm/sync.py
@@ -115,10 +115,12 @@ class SyncRule(object):
#print "SyncRule", source_mapper, source_column, dest_column, dest_mapper
def dest_primary_key(self):
+ # late-evaluating boolean since some syncs are created
+ # before the mapper has assembled pks
try:
return self._dest_primary_key
except AttributeError:
- self._dest_primary_key = self.dest_mapper is not None and self.dest_column in self.dest_mapper.pks_by_table[self.dest_column.table] and not self.dest_mapper.allow_null_pks
+ self._dest_primary_key = self.dest_mapper is not None and self.dest_column in self.dest_mapper._pks_by_table[self.dest_column.table] and not self.dest_mapper.allow_null_pks
return self._dest_primary_key
def execute(self, source, dest, obj, child, clearkeys):
@@ -131,7 +133,7 @@ class SyncRule(object):
value = None
clearkeys = True
else:
- value = self.source_mapper.get_attr_by_column(source, self.source_column)
+ value = self.source_mapper._get_attr_by_column(source, self.source_column)
if isinstance(dest, dict):
dest[self.dest_column.key] = value
else:
@@ -140,7 +142,7 @@ class SyncRule(object):
if logging.is_debug_enabled(self.logger):
self.logger.debug("execute() instances: %s(%s)->%s(%s) ('%s')" % (mapperutil.instance_str(source), str(self.source_column), mapperutil.instance_str(dest), str(self.dest_column), value))
- self.dest_mapper.set_attr_by_column(dest, self.dest_column, value)
+ self.dest_mapper._set_attr_by_column(dest, self.dest_column, value)
SyncRule.logger = logging.class_logger(SyncRule)
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index d3e89d57e..1cf0cb1b0 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -18,8 +18,9 @@ def sort_tables(tables, reverse=False):
vis.traverse(table)
sequence = topological.QueueDependencySorter( tuples, tables).sort(create_tree=False)
if reverse:
- sequence.reverse()
- return sequence
+ return util.reversed(sequence)
+ else:
+ return sequence
def find_tables(clause, check_columns=False, include_aliases=False):
tables = []
diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py
index 0301d01c9..5affaa238 100644
--- a/test/orm/inheritance/basic.py
+++ b/test/orm/inheritance/basic.py
@@ -489,7 +489,7 @@ class DistinctPKTest(ORMTest):
self._do_test(True)
assert False
except RuntimeWarning, e:
- assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. Use explicit properties to give each column its own mapped attribute name."
+ assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. Use explicit properties to give each column its own mapped attribute name.", str(e)
def test_explicit_pk(self):
person_mapper = mapper(Person, person_table)
diff --git a/test/orm/manytomany.py b/test/orm/manytomany.py
index 5608ae67d..a94e9bbc4 100644
--- a/test/orm/manytomany.py
+++ b/test/orm/manytomany.py
@@ -79,7 +79,11 @@ class M2MTest(ORMTest):
compile_mappers()
assert False
except exceptions.ArgumentError, e:
- assert str(e) == "Error creating backref 'transitions' on relation 'Transition.places (Place)': property of that name exists on mapper 'Mapper|Place|place'"
+ assert str(e) in [
+ "Error creating backref 'transitions' on relation 'Transition.places (Place)': property of that name exists on mapper 'Mapper|Place|place'",
+ "Error creating backref 'places' on relation 'Place.transitions (Transition)': property of that name exists on mapper 'Mapper|Transition|transition'"
+ ]
+
def testcircular(self):
"""tests a many-to-many relationship from a table to itself."""
diff --git a/test/orm/mapper.py b/test/orm/mapper.py
index 9cd07b8fe..f0d553630 100644
--- a/test/orm/mapper.py
+++ b/test/orm/mapper.py
@@ -269,8 +269,8 @@ class MapperTest(MapperSuperTest):
class A(object):pass
m = mapper(A, account_ids_table.join(account_stuff_table))
m.compile()
- assert m._has_pks(account_ids_table)
- assert not m._has_pks(account_stuff_table)
+ assert account_ids_table in m._pks_by_table
+ assert account_stuff_table not in m._pks_by_table
metadata.create_all(testbase.db)
try:
sess = create_session(bind=testbase.db)