diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/access/base.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/maxdb/base.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 39 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 20 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/oracle/base.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 34 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/sqlite/base.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/sqlsoup.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/dependency.py | 18 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 17 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 110 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/unitofwork.py | 45 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 16 | ||||
| -rw-r--r-- | lib/sqlalchemy/types.py | 164 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/__init__.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/_collections.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/langhelpers.py | 44 |
19 files changed, 259 insertions, 296 deletions
diff --git a/lib/sqlalchemy/dialects/access/base.py b/lib/sqlalchemy/dialects/access/base.py index 75ea91287..cf35b3e0a 100644 --- a/lib/sqlalchemy/dialects/access/base.py +++ b/lib/sqlalchemy/dialects/access/base.py @@ -50,15 +50,10 @@ class AcSmallInteger(types.SmallInteger): return "SMALLINT" class AcDateTime(types.DateTime): - def __init__(self, *a, **kw): - super(AcDateTime, self).__init__(False) - def get_col_spec(self): return "DATETIME" class AcDate(types.Date): - def __init__(self, *a, **kw): - super(AcDate, self).__init__(False) def get_col_spec(self): return "DATETIME" diff --git a/lib/sqlalchemy/dialects/maxdb/base.py b/lib/sqlalchemy/dialects/maxdb/base.py index 9a1e10f51..3d45bb670 100644 --- a/lib/sqlalchemy/dialects/maxdb/base.py +++ b/lib/sqlalchemy/dialects/maxdb/base.py @@ -116,15 +116,13 @@ class _StringType(sqltypes.String): class MaxString(_StringType): _type = 'VARCHAR' - def __init__(self, *a, **kw): - super(MaxString, self).__init__(*a, **kw) - class MaxUnicode(_StringType): _type = 'VARCHAR' def __init__(self, length=None, **kw): - super(MaxUnicode, self).__init__(length=length, encoding='unicode') + kw['encoding'] = 'unicode' + super(MaxUnicode, self).__init__(length=length, **kw) class MaxChar(_StringType): @@ -134,8 +132,8 @@ class MaxChar(_StringType): class MaxText(_StringType): _type = 'LONG' - def __init__(self, *a, **kw): - super(MaxText, self).__init__(*a, **kw) + def __init__(self, length=None, **kw): + super(MaxText, self).__init__(length, **kw) def get_col_spec(self): spec = 'LONG' diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index b1fb46041..028322677 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -175,8 +175,9 @@ class REAL(sqltypes.Float): __visit_name__ = 'REAL' - def __init__(self): - super(REAL, self).__init__(precision=24) + def __init__(self, **kw): + kw.setdefault('precision', 24) + super(REAL, self).__init__(**kw) class TINYINT(sqltypes.Integer): __visit_name__ = 'TINYINT' @@ -258,7 +259,8 @@ class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime): class DATETIME2(_DateTimeBase, sqltypes.DateTime): __visit_name__ = 'DATETIME2' - def __init__(self, precision=None, **kwargs): + def __init__(self, precision=None, **kw): + super(DATETIME2, self).__init__(**kw) self.precision = precision @@ -278,16 +280,15 @@ class _StringType(object): class TEXT(_StringType, sqltypes.TEXT): """MSSQL TEXT type, for variable-length text up to 2^31 characters.""" - def __init__(self, *args, **kw): + def __init__(self, length=None, collation=None, **kw): """Construct a TEXT. :param collation: Optional, a column-level collation for this string value. Accepts a Windows Collation Name or a SQL Collation Name. """ - collation = kw.pop('collation', None) _StringType.__init__(self, collation) - sqltypes.Text.__init__(self, *args, **kw) + sqltypes.Text.__init__(self, length, **kw) class NTEXT(_StringType, sqltypes.UnicodeText): """MSSQL NTEXT type, for variable-length unicode text up to 2^30 @@ -295,24 +296,22 @@ class NTEXT(_StringType, sqltypes.UnicodeText): __visit_name__ = 'NTEXT' - def __init__(self, *args, **kwargs): + def __init__(self, length=None, collation=None, **kw): """Construct a NTEXT. :param collation: Optional, a column-level collation for this string value. Accepts a Windows Collation Name or a SQL Collation Name. """ - collation = kwargs.pop('collation', None) _StringType.__init__(self, collation) - length = kwargs.pop('length', None) - sqltypes.UnicodeText.__init__(self, length, **kwargs) + sqltypes.UnicodeText.__init__(self, length, **kw) class VARCHAR(_StringType, sqltypes.VARCHAR): """MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum of 8,000 characters.""" - def __init__(self, *args, **kw): + def __init__(self, length=None, collation=None, **kw): """Construct a VARCHAR. :param length: Optinal, maximum data length, in characters. @@ -333,16 +332,15 @@ class VARCHAR(_StringType, sqltypes.VARCHAR): value. Accepts a Windows Collation Name or a SQL Collation Name. """ - collation = kw.pop('collation', None) _StringType.__init__(self, collation) - sqltypes.VARCHAR.__init__(self, *args, **kw) + sqltypes.VARCHAR.__init__(self, length, **kw) class NVARCHAR(_StringType, sqltypes.NVARCHAR): """MSSQL NVARCHAR type. For variable-length unicode character data up to 4,000 characters.""" - def __init__(self, *args, **kw): + def __init__(self, length=None, collation=None, **kw): """Construct a NVARCHAR. :param length: Optional, Maximum data length, in characters. @@ -351,15 +349,14 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR): value. Accepts a Windows Collation Name or a SQL Collation Name. """ - collation = kw.pop('collation', None) _StringType.__init__(self, collation) - sqltypes.NVARCHAR.__init__(self, *args, **kw) + sqltypes.NVARCHAR.__init__(self, length, **kw) class CHAR(_StringType, sqltypes.CHAR): """MSSQL CHAR type, for fixed-length non-Unicode data with a maximum of 8,000 characters.""" - def __init__(self, *args, **kw): + def __init__(self, length=None, collation=None, **kw): """Construct a CHAR. :param length: Optinal, maximum data length, in characters. @@ -380,16 +377,15 @@ class CHAR(_StringType, sqltypes.CHAR): value. Accepts a Windows Collation Name or a SQL Collation Name. """ - collation = kw.pop('collation', None) _StringType.__init__(self, collation) - sqltypes.CHAR.__init__(self, *args, **kw) + sqltypes.CHAR.__init__(self, length, **kw) class NCHAR(_StringType, sqltypes.NCHAR): """MSSQL NCHAR type. For fixed-length unicode character data up to 4,000 characters.""" - def __init__(self, *args, **kw): + def __init__(self, length=None, collation=None, **kw): """Construct an NCHAR. :param length: Optional, Maximum data length, in characters. @@ -398,9 +394,8 @@ class NCHAR(_StringType, sqltypes.NCHAR): value. Accepts a Windows Collation Name or a SQL Collation Name. """ - collation = kw.pop('collation', None) _StringType.__init__(self, collation) - sqltypes.NCHAR.__init__(self, *args, **kw) + sqltypes.NCHAR.__init__(self, length, **kw) class IMAGE(sqltypes.LargeBinary): __visit_name__ = 'IMAGE' diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 42072699e..528e94965 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -233,11 +233,11 @@ SET_RE = re.compile( class _NumericType(object): """Base for MySQL numeric types.""" - def __init__(self, **kw): - self.unsigned = kw.pop('unsigned', False) - self.zerofill = kw.pop('zerofill', False) + def __init__(self, unsigned=False, zerofill=False, **kw): + self.unsigned = unsigned + self.zerofill = zerofill super(_NumericType, self).__init__(**kw) - + class _FloatType(_NumericType, sqltypes.Float): def __init__(self, precision=None, scale=None, asdecimal=True, **kw): if isinstance(self, (REAL, DOUBLE)) and \ @@ -276,7 +276,7 @@ class _StringType(sqltypes.String): self.binary = binary self.national = national super(_StringType, self).__init__(**kw) - + def __repr__(self): attributes = inspect.getargspec(self.__init__)[0][1:] attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:]) @@ -749,7 +749,7 @@ class CHAR(_StringType, sqltypes.CHAR): __visit_name__ = 'CHAR' - def __init__(self, length, **kwargs): + def __init__(self, length=None, **kwargs): """Construct a CHAR. :param length: Maximum data length, in characters. @@ -943,6 +943,10 @@ class ENUM(sqltypes.Enum, _StringType): else: return value return process + + def adapt(self, impltype, **kw): + kw['strict'] = self.strict + return sqltypes.Enum.adapt(self, impltype, **kw) class SET(_StringType): """MySQL SET type.""" @@ -990,8 +994,8 @@ class SET(_StringType): strip_values.append(a) self.values = strip_values - length = max([len(v) for v in strip_values] + [0]) - super(SET, self).__init__(length=length, **kw) + kw.setdefault('length', max([len(v) for v in strip_values] + [0])) + super(SET, self).__init__(**kw) def result_processor(self, dialect, coltype): def process(value): diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 256972696..3d97b504e 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -215,10 +215,6 @@ class INTERVAL(sqltypes.TypeEngine): return INTERVAL(day_precision=interval.day_precision, second_precision=interval.second_precision) - def adapt(self, impltype): - return impltype(day_precision=self.day_precision, - second_precision=self.second_precision) - @property def _type_affinity(self): return sqltypes.Interval diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 821ec5cfb..72b58a71c 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -132,12 +132,13 @@ class TIMESTAMP(sqltypes.TIMESTAMP): def __init__(self, timezone=False, precision=None): super(TIMESTAMP, self).__init__(timezone=timezone) self.precision = precision + class TIME(sqltypes.TIME): def __init__(self, timezone=False, precision=None): super(TIME, self).__init__(timezone=timezone) self.precision = precision - + class INTERVAL(sqltypes.TypeEngine): """Postgresql INTERVAL type. @@ -149,9 +150,6 @@ class INTERVAL(sqltypes.TypeEngine): def __init__(self, precision=None): self.precision = precision - def adapt(self, impltype): - return impltype(self.precision) - @classmethod def _adapt_from_generic_interval(cls, interval): return INTERVAL(precision=interval.second_precision) @@ -164,6 +162,9 @@ PGInterval = INTERVAL class BIT(sqltypes.TypeEngine): __visit_name__ = 'BIT' + def __init__(self, length=1): + self.length= length + PGBit = BIT class UUID(sqltypes.TypeEngine): @@ -213,7 +214,7 @@ class UUID(sqltypes.TypeEngine): return process else: return None - + PGUuid = UUID class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): @@ -285,23 +286,8 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): def is_mutable(self): return self.mutable - def dialect_impl(self, dialect, **kwargs): - impl = super(ARRAY, self).dialect_impl(dialect, **kwargs) - if impl is self: - impl = self.__class__.__new__(self.__class__) - impl.__dict__.update(self.__dict__) - impl.item_type = self.item_type.dialect_impl(dialect) - return impl - - def adapt(self, impltype): - return impltype( - self.item_type, - mutable=self.mutable, - as_tuple=self.as_tuple - ) - def bind_processor(self, dialect): - item_proc = self.item_type.bind_processor(dialect) + item_proc = self.item_type.dialect_impl(dialect).bind_processor(dialect) if item_proc: def convert_item(item): if isinstance(item, (list, tuple)): @@ -321,7 +307,7 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): return process def result_processor(self, dialect, coltype): - item_proc = self.item_type.result_processor(dialect, coltype) + item_proc = self.item_type.dialect_impl(dialect).result_processor(dialect, coltype) if item_proc: def convert_item(item): if isinstance(item, list): @@ -640,7 +626,7 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): return "INTERVAL" def visit_BIT(self, type_): - return "BIT" + return "BIT(%d)" % type_.length def visit_UUID(self, type_): return "UUID" @@ -1095,7 +1081,7 @@ class PGDialect(default.DefaultDialect): elif attype == 'double precision': args = (53, ) elif attype == 'integer': - args = (32, 0) + args = () elif attype in ('timestamp with time zone', 'time with time zone'): kwargs['timezone'] = True diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 707bc1630..1ab5173fe 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -72,7 +72,8 @@ class _DateTimeMixin(object): _reg = None _storage_format = None - def __init__(self, storage_format=None, regexp=None, **kwargs): + def __init__(self, storage_format=None, regexp=None, **kw): + super(_DateTimeMixin, self).__init__(**kw) if regexp is not None: self._reg = re.compile(regexp) if storage_format is not None: diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 8647ba385..38ff76b89 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -16,6 +16,7 @@ import re, random from sqlalchemy.engine import base, reflection from sqlalchemy.sql import compiler, expression from sqlalchemy import exc, types as sqltypes, util, pool +import weakref AUTOCOMMIT_REGEXP = re.compile( r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', @@ -133,13 +134,17 @@ class DefaultDialect(base.Dialect): " maximum identifier length of %d" % (label_length, self.max_identifier_length)) self.label_length = label_length - + if not hasattr(self, 'description_encoding'): self.description_encoding = getattr( self, 'description_encoding', encoding) + @util.memoized_property + def _type_memos(self): + return weakref.WeakKeyDictionary() + @property def dialect_description(self): return self.name + "+" + self.driver @@ -398,7 +403,7 @@ class DefaultExecutionContext(base.ExecutionContext): self.postfetch_cols = self.compiled.postfetch self.prefetch_cols = self.compiled.prefetch - processors = compiled._get_bind_processors(dialect) + processors = compiled._bind_processors # Convert the dictionary of bind parameter values # into a dict or list to be sent to the DBAPI's diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py index 9ff503dfa..6981919cf 100644 --- a/lib/sqlalchemy/ext/sqlsoup.py +++ b/lib/sqlalchemy/ext/sqlsoup.py @@ -359,7 +359,7 @@ from sqlalchemy import schema, sql, util from sqlalchemy.engine.base import Engine from sqlalchemy.orm import scoped_session, sessionmaker, mapper, \ class_mapper, relationship, session,\ - object_session + object_session, attributes from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE from sqlalchemy.exceptions import SQLAlchemyError, InvalidRequestError, ArgumentError from sqlalchemy.sql import expression @@ -384,7 +384,8 @@ class AutoAdd(MapperExtension): def init_instance(self, mapper, class_, oldinit, instance, args, kwargs): session = self.scoped_session() - session._save_without_cascade(instance) + state = attributes.instance_state(instance) + session._save_impl(state) return EXT_CONTINUE def init_failed(self, mapper, class_, oldinit, instance, args, kwargs): diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index 39ea1db35..35b197f63 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -152,7 +152,7 @@ class DependencyProcessor(object): # detect if there's anything changed or loaded # by a preprocessor on this state/attribute. if not, # we should be able to skip it entirely. - sum_ = attributes.get_all_pending(state, state.dict, self.key) + sum_ = state.manager[self.key].impl.get_all_pending(state, state.dict) if not sum_: continue @@ -439,10 +439,10 @@ class OneToManyDP(DependencyProcessor): elif self.hasparent(child) is False: uowcommit.register_object(child, isdelete=True, operation="delete", prop=self.prop) - for c, m in self.mapper.cascade_iterator( + for c, m, st_, dct_ in self.mapper.cascade_iterator( 'delete', child): uowcommit.register_object( - attributes.instance_state(c), + st_, isdelete=True) if pks_changed: @@ -661,10 +661,10 @@ class ManyToOneDP(DependencyProcessor): continue uowcommit.register_object(child, isdelete=True, operation="delete", prop=self.prop) - for c, m in self.mapper.cascade_iterator( + for c, m, st_, dct_ in self.mapper.cascade_iterator( 'delete', child): uowcommit.register_object( - attributes.instance_state(c), isdelete=True) + st_, isdelete=True) def presort_saves(self, uowcommit, states): for state in states: @@ -681,10 +681,10 @@ class ManyToOneDP(DependencyProcessor): uowcommit.register_object(child, isdelete=True, operation="delete", prop=self.prop) - for c, m in self.mapper.cascade_iterator( + for c, m, st_, dct_ in self.mapper.cascade_iterator( 'delete', child): uowcommit.register_object( - attributes.instance_state(c), + st_, isdelete=True) def process_deletes(self, uowcommit, states): @@ -939,11 +939,11 @@ class ManyToManyDP(DependencyProcessor): if self.hasparent(child) is False: uowcommit.register_object(child, isdelete=True, operation="delete", prop=self.prop) - for c, m in self.mapper.cascade_iterator( + for c, m, st_, dct_ in self.mapper.cascade_iterator( 'delete', child): uowcommit.register_object( - attributes.instance_state(c), isdelete=True) + st_, isdelete=True) def process_deletes(self, uowcommit, states): secondary_delete = [] diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 48c37f80d..8bd8bf3c8 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1382,10 +1382,10 @@ class Mapper(object): reference so that they don't fall out of scope immediately. """ - visited_instances = util.IdentitySet() + visited_states = set() prp, mpp = object(), object() - visitables = [(deque(self._props.values()), prp, state, state.dict)] + visitables = deque([(deque(self._props.values()), prp, state, state.dict)]) while visitables: iterator, item_type, parent_state, parent_dict = visitables[-1] @@ -1398,13 +1398,13 @@ class Mapper(object): if type_ not in prop.cascade: continue queue = deque(prop.cascade_iterator(type_, parent_state, - parent_dict, visited_instances, halt_on)) + parent_dict, visited_states, halt_on)) if queue: visitables.append((queue,mpp, None, None)) elif item_type is mpp: instance, instance_mapper, corresponding_state, \ corresponding_dict = iterator.popleft() - yield (instance, instance_mapper) + yield instance, instance_mapper, corresponding_state, corresponding_dict visitables.append((deque(instance_mapper._props.values()), prp, corresponding_state, corresponding_dict)) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 239159f3e..da6d309e0 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -820,9 +820,8 @@ class RelationshipProperty(StrategizedProperty): dest_state.get_impl(self.key).set(dest_state, dest_dict, obj, None) - def cascade_iterator(self, type_, state, dict_, visited_instances, halt_on=None): - if not type_ in self.cascade: - return + def cascade_iterator(self, type_, state, dict_, visited_states, halt_on=None): + #assert type_ in self.cascade # only actively lazy load on the 'delete' cascade if type_ != 'delete' or self.passive_deletes: @@ -831,7 +830,7 @@ class RelationshipProperty(StrategizedProperty): passive = attributes.PASSIVE_OFF if type_ == 'save-update': - instances = attributes.get_all_pending(state, dict_, self.key) + instances = state.manager[self.key].impl.get_all_pending(state, dict_) else: instances = state.value_as_iterable(dict_, self.key, @@ -842,10 +841,12 @@ class RelationshipProperty(StrategizedProperty): if instances: for c in instances: if c is not None and \ - c is not attributes.PASSIVE_NO_RESULT and \ - c not in visited_instances: - + c is not attributes.PASSIVE_NO_RESULT: + instance_state = attributes.instance_state(c) + if instance_state in visited_states: + continue + instance_dict = attributes.instance_dict(c) if halt_on and halt_on(instance_state): @@ -865,7 +866,7 @@ class RelationshipProperty(StrategizedProperty): c.__class__ )) - visited_instances.add(c) + visited_states.add(instance_state) # cascade using the mapper local to this # object, so that its individual properties are located diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index e2c1308b8..1f704f502 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -981,16 +981,18 @@ class Session(object): else: # pre-fetch the full cascade since the expire is going to # remove associations - cascaded = list(_cascade_state_iterator('refresh-expire', state)) + cascaded = list(state.manager.mapper.cascade_iterator( + 'refresh-expire', state)) self._conditional_expire(state) - for (state, m, o) in cascaded: - self._conditional_expire(state) + for o, m, st_, dct_ in cascaded: + self._conditional_expire(st_) def _conditional_expire(self, state): """Expire a state if persistent, else expunge if pending""" if state.key: - _expire_state(state, state.dict, None, instance_dict=self.identity_map) + _expire_state(state, state.dict, None, + instance_dict=self.identity_map) elif state in self._new: self._new.pop(state) state.detach() @@ -1023,8 +1025,12 @@ class Session(object): raise sa_exc.InvalidRequestError( "Instance %s is not present in this Session" % mapperutil.state_str(state)) - for s, m, o in [(state, None, None)] + list(_cascade_state_iterator('expunge', state)): - self._expunge_state(s) + + cascaded = list(state.manager.mapper.cascade_iterator( + 'expunge', state)) + self._expunge_state(state) + for o, m, st_, dct_ in cascaded: + self._expunge_state(st_) def _expunge_state(self, state): if state in self._new: @@ -1078,12 +1084,6 @@ class Session(object): self._deleted.pop(state, None) state.deleted = True - def _save_without_cascade(self, instance): - """Used by scoping.py to save on init without cascade.""" - - state = _state_for_unsaved_instance(instance, create=True) - self._save_impl(state) - def add(self, instance): """Place an object in the ``Session``. @@ -1094,7 +1094,11 @@ class Session(object): is ``expunge()``. """ - state = _state_for_unknown_persistence_instance(instance) + try: + state = attributes.instance_state(instance) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(instance) + self._save_or_update_state(state) def add_all(self, instances): @@ -1105,13 +1109,13 @@ class Session(object): def _save_or_update_state(self, state): self._save_or_update_impl(state) - self._cascade_save_or_update(state) - def _cascade_save_or_update(self, state): - for state, mapper in _cascade_unknown_state_iterator( - 'save-update', state, + mapper = _state_mapper(state) + for o, m, st_, dct_ in mapper.cascade_iterator( + 'save-update', + state, halt_on=self._contains_state): - self._save_or_update_impl(state) + self._save_or_update_impl(st_) def delete(self, instance): """Mark an instance as deleted. @@ -1139,16 +1143,19 @@ class Session(object): # grab the cascades before adding the item to the deleted list # so that autoflush does not delete the item - cascade_states = list(_cascade_state_iterator('delete', state)) + # the strong reference to the instance itself is significant here + cascade_states = list(state.manager.mapper.cascade_iterator( + 'delete', state)) self._deleted[state] = state.obj() self.identity_map.add(state) - for state, m, o in cascade_states: - self._delete_impl(state) + for o, m, st_, dct_ in cascade_states: + self._delete_impl(st_) def merge(self, instance, load=True, **kw): - """Copy the state an instance onto the persistent instance with the same identifier. + """Copy the state an instance onto the persistent instance with the + same identifier. If there is no persistent instance currently associated with the session, it will be loaded. Return the persistent instance. If the @@ -1164,7 +1171,8 @@ class Session(object): """ if 'dont_load' in kw: load = not kw['dont_load'] - util.warn_deprecated("dont_load=True has been renamed to load=False.") + util.warn_deprecated('dont_load=True has been renamed to ' + 'load=False.') _recursive = {} @@ -1241,7 +1249,9 @@ class Session(object): merged_state.load_options = state.load_options for prop in mapper.iterate_properties: - prop.merge(self, state, state_dict, merged_state, merged_dict, load, _recursive) + prop.merge(self, state, state_dict, + merged_state, merged_dict, + load, _recursive) if not load: # remove any history @@ -1319,10 +1329,10 @@ class Session(object): if state.key and \ state.key in self.identity_map and \ not self.identity_map.contains_state(state): - raise sa_exc.InvalidRequestError( - "Can't attach instance %s; another instance with key %s is already present in this session." % - (mapperutil.state_str(state), state.key) - ) + raise sa_exc.InvalidRequestError("Can't attach instance " + "%s; another instance with key %s is already " + "present in this session." + % (mapperutil.state_str(state), state.key)) if state.session_id and state.session_id is not self.hash_key: raise sa_exc.InvalidRequestError( @@ -1475,7 +1485,8 @@ class Session(object): #if not objects: # assert not self.identity_map._modified #else: - # assert self.identity_map._modified == self.identity_map._modified.difference(objects) + # assert self.identity_map._modified == \ + # self.identity_map._modified.difference(objects) #self.identity_map._modified.clear() self.dispatch.on_after_flush_postexec(self, flush_context) @@ -1598,47 +1609,6 @@ UOWEventHandler = unitofwork.UOWEventHandler _sessions = weakref.WeakValueDictionary() -def _cascade_state_iterator(cascade, state, **kwargs): - mapper = _state_mapper(state) - # yield the state, object, mapper. yielding the object - # allows the iterator's results to be held in a list without - # states being garbage collected - for (o, m) in mapper.cascade_iterator(cascade, state, **kwargs): - yield attributes.instance_state(o), o, m - -def _cascade_unknown_state_iterator(cascade, state, **kwargs): - mapper = _state_mapper(state) - for (o, m) in mapper.cascade_iterator(cascade, state, **kwargs): - yield _state_for_unknown_persistence_instance(o), m - -def _state_for_unsaved_instance(instance, create=False): - try: - state = attributes.instance_state(instance) - except AttributeError: - raise exc.UnmappedInstanceError(instance) - if state: - if state.key is not None: - raise sa_exc.InvalidRequestError( - "Instance '%s' is already persistent" % - mapperutil.state_str(state)) - elif create: - manager = attributes.manager_of_class(instance.__class__) - if manager is None: - raise exc.UnmappedInstanceError(instance) - state = manager.setup_instance(instance) - else: - raise exc.UnmappedInstanceError(instance) - - return state - -def _state_for_unknown_persistence_instance(instance): - try: - state = attributes.instance_state(instance) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(instance) - - return state - def make_transient(instance): """Make the given instance 'transient'. diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 1e1eda4a3..ba43b1359 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -16,7 +16,6 @@ from sqlalchemy import util from sqlalchemy.util import topological from sqlalchemy.orm import attributes, interfaces from sqlalchemy.orm import util as mapperutil -from sqlalchemy.orm.util import _state_mapper session = util.importlater("sqlalchemy.orm", "session") class UOWEventHandler(interfaces.AttributeExtension): @@ -28,28 +27,32 @@ class UOWEventHandler(interfaces.AttributeExtension): def __init__(self, key): self.key = key - + + # TODO: migrate these to unwrapped events + def append(self, state, item, initiator): # process "save_update" cascade rules for when # an instance is appended to the list of another instance sess = session._state_session(state) if sess: - prop = _state_mapper(state)._props[self.key] + prop = state.manager.mapper._props[self.key] + item_state = attributes.instance_state(item) if prop.cascade.save_update and \ (prop.cascade_backrefs or self.key == initiator.key) and \ - item not in sess: - sess.add(item) + not sess._contains_state(item_state): + sess._save_or_update_state(item_state) return item def remove(self, state, item, initiator): sess = session._state_session(state) if sess: - prop = _state_mapper(state)._props[self.key] + prop = state.manager.mapper._props[self.key] # expunge pending orphans + item_state = attributes.instance_state(item) if prop.cascade.delete_orphan and \ - item in sess.new and \ - prop.mapper._is_orphan(attributes.instance_state(item)): + item_state in sess._new and \ + prop.mapper._is_orphan(item_state): sess.expunge(item) def set(self, state, newvalue, oldvalue, initiator): @@ -60,16 +63,20 @@ class UOWEventHandler(interfaces.AttributeExtension): sess = session._state_session(state) if sess: - prop = _state_mapper(state)._props[self.key] - if newvalue is not None and \ - prop.cascade.save_update and \ - (prop.cascade_backrefs or self.key == initiator.key) and \ - newvalue not in sess: - sess.add(newvalue) - if prop.cascade.delete_orphan and \ - oldvalue in sess.new and \ - prop.mapper._is_orphan(attributes.instance_state(oldvalue)): - sess.expunge(oldvalue) + prop = state.manager.mapper._props[self.key] + if newvalue is not None: + newvalue_state = attributes.instance_state(newvalue) + if prop.cascade.save_update and \ + (prop.cascade_backrefs or self.key == initiator.key) and \ + not sess._contains_state(newvalue_state): + sess._save_or_update_state(newvalue_state) + + if oldvalue is not None and prop.cascade.delete_orphan: + oldvalue_state = attributes.instance_state(oldvalue) + + if oldvalue_state in sess._new and \ + prop.mapper._is_orphan(oldvalue_state): + sess.expunge(oldvalue) return newvalue @@ -196,7 +203,7 @@ class UOWTransaction(object): return False if state not in self.states: - mapper = _state_mapper(state) + mapper = state.manager.mapper if mapper not in self.mappers: mapper._per_mapper_flush_actions(self) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8bd728a7c..cf1e28f50 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -253,24 +253,16 @@ class SQLCompiler(engine.Compiled): # or dialect.max_identifier_length self.truncated_names = {} - # other memoized things - self._memos ={} - def _get_bind_processors(self, dialect): - key = 'bind_processors', dialect.__class__, \ - dialect.server_version_info - - if key not in self._memos: - self._memos[key] = processors = dict( + @util.memoized_property + def _bind_processors(self): + return dict( (key, value) for key, value in ( (self.bind_names[bindparam], - bindparam.type._cached_bind_processor(dialect)) + bindparam.type._cached_bind_processor(self.dialect)) for bindparam in self.bind_names ) if value is not None ) - return processors - else: - return self._memos[key] def is_subquery(self): return len(self.stack) > 1 diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index e31486eff..f5df02367 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -131,57 +131,58 @@ class TypeEngine(AbstractType): else: return self.__class__ - @util.memoized_property - def _impl_dict(self): - return {} - - def __getstate__(self): - d = self.__dict__.copy() - d.pop('_impl_dict', None) - return d - - def dialect_impl(self, dialect, **kwargs): - key = dialect.__class__, dialect.server_version_info + def dialect_impl(self, dialect): + """Return a dialect-specific implementation for this type.""" + try: - return self._impl_dict[key] + return dialect._type_memos[self]['impl'] except KeyError: - return self._impl_dict.setdefault(key, - dialect.type_descriptor(self)) + return self._dialect_info(dialect)['impl'] def _cached_bind_processor(self, dialect): - return self.dialect_impl(dialect).bind_processor(dialect) - - # TODO: can't do this until we find a way to link with the - # specific attributes of the dialect, i.e. convert_unicode, - # etc. might need to do a weakmap again. needs tests - # to ensure two dialects with different flags. use a mock - # dialect. - #key = "bind", dialect.__class__, dialect.server_version_info - #try: - # return self._impl_dict[key] - #except KeyError: - # self._impl_dict[key] = bp = \ - # self.dialect_impl(dialect).bind_processor(dialect) - # return bp + """Return a dialect-specific bind processor for this type.""" + try: + return dialect._type_memos[self]['bind'] + except KeyError: + d = self._dialect_info(dialect) + d['bind'] = bp = d['impl'].bind_processor(dialect) + return bp + def _cached_result_processor(self, dialect, coltype): - return self.dialect_impl(dialect).result_processor(dialect, coltype) + """Return a dialect-specific result processor for this type.""" + + try: + return dialect._type_memos[self][coltype] + except KeyError: + d = self._dialect_info(dialect) + # key assumption: DBAPI type codes are + # constants. Else this dictionary would + # grow unbounded. + d[coltype] = rp = d['impl'].result_processor(dialect, coltype) + return rp + + def _dialect_info(self, dialect): + """Return a dialect-specific registry which + caches a dialect-specific implementation, bind processing + function, and one or more result processing functions.""" + + if self in dialect._type_memos: + return dialect._type_memos[self] + else: + impl = self._gen_dialect_impl(dialect) + if impl is self: + impl = self.adapt(type(self)) + # this can't be self, else we create a cycle + assert impl is not self + dialect._type_memos[self] = d = {'impl':impl} + return d + + def _gen_dialect_impl(self, dialect): + return dialect.type_descriptor(self) - # TODO: can't do this until we find a way to link with the - # specific attributes of the dialect, i.e. convert_unicode, - # etc. might need to do a weakmap again. needs tests - # to ensure two dialects with different flags. use a mock - # dialect. - #key = "result", dialect.__class__, dialect.server_version_info, coltype - #try: - # return self._impl_dict[key] - #except KeyError: - # self._impl_dict[key] = rp = self.dialect_impl(dialect).\ - # result_processor(dialect, coltype) - # return rp - - def adapt(self, cls): - return cls() + def adapt(self, cls, **kw): + return util.constructor_copy(self, cls, **kw) def _coerce_compared_value(self, op, value): _coerced_type = _type_map.get(type(value), NULLTYPE) @@ -220,7 +221,7 @@ class TypeEngine(AbstractType): encode('ascii', 'backslashreplace') # end Py2K - def __init__(self, *args, **kwargs): + def __init__(self): # supports getargspec of the __init__ method # used by generic __repr__ pass @@ -376,17 +377,10 @@ class TypeDecorator(TypeEngine): "type being decorated") self.impl = to_instance(self.__class__.impl, *args, **kwargs) - def dialect_impl(self, dialect): - key = (dialect.__class__, dialect.server_version_info) - - try: - return self._impl_dict[key] - except KeyError: - pass - + + def _gen_dialect_impl(self, dialect): adapted = dialect.type_descriptor(self) if adapted is not self: - self._impl_dict[key] = adapted return adapted # otherwise adapt the impl type, link @@ -400,7 +394,6 @@ class TypeDecorator(TypeEngine): 'return an object of type %s' % (self, self.__class__)) tt.impl = typedesc - self._impl_dict[key] = tt return tt @util.memoized_property @@ -499,7 +492,6 @@ class TypeDecorator(TypeEngine): def copy(self): instance = self.__class__.__new__(self.__class__) instance.__dict__.update(self.__dict__) - instance._impl_dict = {} return instance def get_dbapi_type(self, dbapi): @@ -650,6 +642,9 @@ def adapt_type(typeobj, colspecs): return typeobj return typeobj.adapt(impltype) + + + class NullType(TypeEngine): """An unknown type. @@ -796,14 +791,6 @@ class String(Concatenable, TypeEngine): self.unicode_error = unicode_error self._warn_on_bytestring = _warn_on_bytestring - def adapt(self, impltype): - return impltype( - length=self.length, - convert_unicode=self.convert_unicode, - unicode_error=self.unicode_error, - _warn_on_bytestring=True, - ) - def bind_processor(self, dialect): if self.convert_unicode or dialect.convert_unicode: if dialect.supports_unicode_binds and \ @@ -1100,12 +1087,6 @@ class Numeric(_DateAffinity, TypeEngine): self.scale = scale self.asdecimal = asdecimal - def adapt(self, impltype): - return impltype( - precision=self.precision, - scale=self.scale, - asdecimal=self.asdecimal) - def get_dbapi_type(self, dbapi): return dbapi.NUMBER @@ -1179,7 +1160,9 @@ class Float(Numeric): """ __visit_name__ = 'float' - + + scale = None + def __init__(self, precision=None, asdecimal=False, **kwargs): """ Construct a Float. @@ -1195,9 +1178,6 @@ class Float(Numeric): self.precision = precision self.asdecimal = asdecimal - def adapt(self, impltype): - return impltype(precision=self.precision, asdecimal=self.asdecimal) - def result_processor(self, dialect, coltype): if self.asdecimal: return processors.to_decimal_processor_factory(decimal.Decimal) @@ -1244,9 +1224,6 @@ class DateTime(_DateAffinity, TypeEngine): def __init__(self, timezone=False): self.timezone = timezone - def adapt(self, impltype): - return impltype(timezone=self.timezone) - def get_dbapi_type(self, dbapi): return dbapi.DATETIME @@ -1304,9 +1281,6 @@ class Time(_DateAffinity,TypeEngine): def __init__(self, timezone=False): self.timezone = timezone - def adapt(self, impltype): - return impltype(timezone=self.timezone) - def get_dbapi_type(self, dbapi): return dbapi.DATETIME @@ -1366,9 +1340,6 @@ class _Binary(TypeEngine): else: return super(_Binary, self)._coerce_compared_value(op, value) - def adapt(self, impltype): - return impltype(length=self.length) - def get_dbapi_type(self, dbapi): return dbapi.BINARY @@ -1453,7 +1424,7 @@ class SchemaType(object): if bind is None: bind = schema._bind_or_error(self) t = self.dialect_impl(bind.dialect) - if t is not self and isinstance(t, SchemaType): + if t.__class__ is not self.__class__ and isinstance(t, SchemaType): t.create(bind=bind, checkfirst=checkfirst) def drop(self, bind=None, checkfirst=False): @@ -1462,27 +1433,27 @@ class SchemaType(object): if bind is None: bind = schema._bind_or_error(self) t = self.dialect_impl(bind.dialect) - if t is not self and isinstance(t, SchemaType): + if t.__class__ is not self.__class__ and isinstance(t, SchemaType): t.drop(bind=bind, checkfirst=checkfirst) def _on_table_create(self, event, target, bind, **kw): t = self.dialect_impl(bind.dialect) - if t is not self and isinstance(t, SchemaType): + if t.__class__ is not self.__class__ and isinstance(t, SchemaType): t._on_table_create(event, target, bind, **kw) def _on_table_drop(self, event, target, bind, **kw): t = self.dialect_impl(bind.dialect) - if t is not self and isinstance(t, SchemaType): + if t.__class__ is not self.__class__ and isinstance(t, SchemaType): t._on_table_drop(event, target, bind, **kw) def _on_metadata_create(self, event, target, bind, **kw): t = self.dialect_impl(bind.dialect) - if t is not self and isinstance(t, SchemaType): + if t.__class__ is not self.__class__ and isinstance(t, SchemaType): t._on_metadata_create(event, target, bind, **kw) def _on_metadata_drop(self, event, target, bind, **kw): t = self.dialect_impl(bind.dialect) - if t is not self and isinstance(t, SchemaType): + if t.__class__ is not self.__class__ and isinstance(t, SchemaType): t._on_metadata_drop(event, target, bind, **kw) class Enum(String, SchemaType): @@ -1579,7 +1550,7 @@ class Enum(String, SchemaType): ) table.append_constraint(e) - def adapt(self, impltype): + def adapt(self, impltype, **kw): if issubclass(impltype, Enum): return impltype(name=self.name, quote=self.quote, @@ -1587,10 +1558,11 @@ class Enum(String, SchemaType): metadata=self.metadata, convert_unicode=self.convert_unicode, native_enum=self.native_enum, - *self.enums + *self.enums, + **kw ) else: - return super(Enum, self).adapt(impltype) + return super(Enum, self).adapt(impltype, **kw) class PickleType(MutableType, TypeDecorator): """Holds Python objects, which are serialized using pickle. @@ -1792,11 +1764,11 @@ class Interval(_DateAffinity, TypeDecorator): self.second_precision = second_precision self.day_precision = day_precision - def adapt(self, cls): - if self.native: - return cls._adapt_from_generic_interval(self) + def adapt(self, cls, **kw): + if self.native and hasattr(cls, '_adapt_from_generic_interval'): + return cls._adapt_from_generic_interval(self, **kw) else: - return self + return cls(**kw) def bind_processor(self, dialect): impl_processor = self.impl.bind_processor(dialect) diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 9119e35b7..ae1eb3ac5 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -24,7 +24,8 @@ from langhelpers import iterate_attributes, class_hierarchy, \ reset_memoized, group_expirable_memoized_property, importlater, \ monkeypatch_proxied_specials, asbool, bool_or_str, coerce_kw_type,\ duck_type_collection, assert_arg_type, symbol, dictlike_iteritems,\ - classproperty, set_creation_order, warn_exception, warn, NoneType + classproperty, set_creation_order, warn_exception, warn, NoneType,\ + constructor_copy from deprecations import warn_deprecated, warn_pending_deprecation, \ deprecated, pending_deprecation diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 022da2de8..57ce02a1e 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -875,7 +875,6 @@ class ThreadLocalRegistry(ScopedRegistry): except AttributeError: pass - def _iter_id(iterable): """Generator: ((id(o), o) for o in iterable).""" diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index d85793ee0..2b9e890fc 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -54,6 +54,9 @@ def get_cls_kwargs(cls): pass along unrecognized keywords to it's base classes, and the collection process is repeated recursively on each of the bases. + Uses a subset of inspect.getargspec() to cut down on method overhead. + No anonymous tuple arguments please ! + """ for c in cls.__mro__: @@ -70,15 +73,39 @@ def get_cls_kwargs(cls): if not ctr or not isinstance(ctr, types.FunctionType): stack.update(class_.__bases__) continue - names, _, has_kw, _ = inspect.getargspec(ctr) + + # this is shorthand for + # names, _, has_kw, _ = inspect.getargspec(ctr) + + names, has_kw = inspect_func_args(ctr) args.update(names) if has_kw: stack.update(class_.__bases__) args.discard('self') return args +try: + from inspect import CO_VARKEYWORDS + def inspect_func_args(fn): + co = fn.func_code + nargs = co.co_argcount + names = co.co_varnames + args = list(names[:nargs]) + has_kw = bool(co.co_flags & CO_VARKEYWORDS) + return args, has_kw +except ImportError: + def inspect_func_args(fn): + names, _, has_kw, _ = inspect.getargspec(fn) + return names, bool(has_kw) + def get_func_kwargs(func): - """Return the full set of legal kwargs for the given `func`.""" + """Return the set of legal kwargs for the given `func`. + + Uses getargspec so is safe to call for methods, functions, + etc. + + """ + return inspect.getargspec(func)[0] def format_argspec_plus(fn, grouped=True): @@ -516,6 +543,19 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True): else: kw[key] = type_(kw[key]) + +def constructor_copy(obj, cls, **kw): + """Instantiate cls using the __dict__ of obj as constructor arguments. + + Uses inspect to match the named arguments of ``cls``. + + """ + + names = get_cls_kwargs(cls) + kw.update((k, obj.__dict__[k]) for k in names if k in obj.__dict__) + return cls(**kw) + + def duck_type_collection(specimen, default=None): """Given an instance or class, guess if it is or is acting as one of the basic collection types: list, set and dict. If the __emulates__ |
