summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMichael Trier <mtrier@gmail.com>2010-12-19 19:38:03 -0500
committerMichael Trier <mtrier@gmail.com>2010-12-19 19:38:03 -0500
commit650bbcc8fd404b2122f1f5ab10eadb4fe3837274 (patch)
tree52755cef4e4cf4274681dc385c6d96e4113dadf2 /lib/sqlalchemy
parent15ea17d7f882fec3f892a22612da4827780c8dae (diff)
parent0a46523a92dbf5229575cd75bb1be989024676ec (diff)
downloadsqlalchemy-650bbcc8fd404b2122f1f5ab10eadb4fe3837274.tar.gz
merge tip
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/access/base.py5
-rw-r--r--lib/sqlalchemy/dialects/maxdb/base.py10
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py39
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py20
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py4
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py34
-rw-r--r--lib/sqlalchemy/dialects/sqlite/base.py3
-rw-r--r--lib/sqlalchemy/engine/default.py9
-rw-r--r--lib/sqlalchemy/ext/sqlsoup.py5
-rw-r--r--lib/sqlalchemy/orm/dependency.py18
-rw-r--r--lib/sqlalchemy/orm/mapper.py8
-rw-r--r--lib/sqlalchemy/orm/properties.py17
-rw-r--r--lib/sqlalchemy/orm/session.py110
-rw-r--r--lib/sqlalchemy/orm/unitofwork.py45
-rw-r--r--lib/sqlalchemy/sql/compiler.py16
-rw-r--r--lib/sqlalchemy/types.py164
-rw-r--r--lib/sqlalchemy/util/__init__.py3
-rw-r--r--lib/sqlalchemy/util/_collections.py1
-rw-r--r--lib/sqlalchemy/util/langhelpers.py44
19 files changed, 259 insertions, 296 deletions
diff --git a/lib/sqlalchemy/dialects/access/base.py b/lib/sqlalchemy/dialects/access/base.py
index 75ea91287..cf35b3e0a 100644
--- a/lib/sqlalchemy/dialects/access/base.py
+++ b/lib/sqlalchemy/dialects/access/base.py
@@ -50,15 +50,10 @@ class AcSmallInteger(types.SmallInteger):
return "SMALLINT"
class AcDateTime(types.DateTime):
- def __init__(self, *a, **kw):
- super(AcDateTime, self).__init__(False)
-
def get_col_spec(self):
return "DATETIME"
class AcDate(types.Date):
- def __init__(self, *a, **kw):
- super(AcDate, self).__init__(False)
def get_col_spec(self):
return "DATETIME"
diff --git a/lib/sqlalchemy/dialects/maxdb/base.py b/lib/sqlalchemy/dialects/maxdb/base.py
index 9a1e10f51..3d45bb670 100644
--- a/lib/sqlalchemy/dialects/maxdb/base.py
+++ b/lib/sqlalchemy/dialects/maxdb/base.py
@@ -116,15 +116,13 @@ class _StringType(sqltypes.String):
class MaxString(_StringType):
_type = 'VARCHAR'
- def __init__(self, *a, **kw):
- super(MaxString, self).__init__(*a, **kw)
-
class MaxUnicode(_StringType):
_type = 'VARCHAR'
def __init__(self, length=None, **kw):
- super(MaxUnicode, self).__init__(length=length, encoding='unicode')
+ kw['encoding'] = 'unicode'
+ super(MaxUnicode, self).__init__(length=length, **kw)
class MaxChar(_StringType):
@@ -134,8 +132,8 @@ class MaxChar(_StringType):
class MaxText(_StringType):
_type = 'LONG'
- def __init__(self, *a, **kw):
- super(MaxText, self).__init__(*a, **kw)
+ def __init__(self, length=None, **kw):
+ super(MaxText, self).__init__(length, **kw)
def get_col_spec(self):
spec = 'LONG'
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index b1fb46041..028322677 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -175,8 +175,9 @@ class REAL(sqltypes.Float):
__visit_name__ = 'REAL'
- def __init__(self):
- super(REAL, self).__init__(precision=24)
+ def __init__(self, **kw):
+ kw.setdefault('precision', 24)
+ super(REAL, self).__init__(**kw)
class TINYINT(sqltypes.Integer):
__visit_name__ = 'TINYINT'
@@ -258,7 +259,8 @@ class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime):
class DATETIME2(_DateTimeBase, sqltypes.DateTime):
__visit_name__ = 'DATETIME2'
- def __init__(self, precision=None, **kwargs):
+ def __init__(self, precision=None, **kw):
+ super(DATETIME2, self).__init__(**kw)
self.precision = precision
@@ -278,16 +280,15 @@ class _StringType(object):
class TEXT(_StringType, sqltypes.TEXT):
"""MSSQL TEXT type, for variable-length text up to 2^31 characters."""
- def __init__(self, *args, **kw):
+ def __init__(self, length=None, collation=None, **kw):
"""Construct a TEXT.
:param collation: Optional, a column-level collation for this string
value. Accepts a Windows Collation Name or a SQL Collation Name.
"""
- collation = kw.pop('collation', None)
_StringType.__init__(self, collation)
- sqltypes.Text.__init__(self, *args, **kw)
+ sqltypes.Text.__init__(self, length, **kw)
class NTEXT(_StringType, sqltypes.UnicodeText):
"""MSSQL NTEXT type, for variable-length unicode text up to 2^30
@@ -295,24 +296,22 @@ class NTEXT(_StringType, sqltypes.UnicodeText):
__visit_name__ = 'NTEXT'
- def __init__(self, *args, **kwargs):
+ def __init__(self, length=None, collation=None, **kw):
"""Construct a NTEXT.
:param collation: Optional, a column-level collation for this string
value. Accepts a Windows Collation Name or a SQL Collation Name.
"""
- collation = kwargs.pop('collation', None)
_StringType.__init__(self, collation)
- length = kwargs.pop('length', None)
- sqltypes.UnicodeText.__init__(self, length, **kwargs)
+ sqltypes.UnicodeText.__init__(self, length, **kw)
class VARCHAR(_StringType, sqltypes.VARCHAR):
"""MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum
of 8,000 characters."""
- def __init__(self, *args, **kw):
+ def __init__(self, length=None, collation=None, **kw):
"""Construct a VARCHAR.
:param length: Optinal, maximum data length, in characters.
@@ -333,16 +332,15 @@ class VARCHAR(_StringType, sqltypes.VARCHAR):
value. Accepts a Windows Collation Name or a SQL Collation Name.
"""
- collation = kw.pop('collation', None)
_StringType.__init__(self, collation)
- sqltypes.VARCHAR.__init__(self, *args, **kw)
+ sqltypes.VARCHAR.__init__(self, length, **kw)
class NVARCHAR(_StringType, sqltypes.NVARCHAR):
"""MSSQL NVARCHAR type.
For variable-length unicode character data up to 4,000 characters."""
- def __init__(self, *args, **kw):
+ def __init__(self, length=None, collation=None, **kw):
"""Construct a NVARCHAR.
:param length: Optional, Maximum data length, in characters.
@@ -351,15 +349,14 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR):
value. Accepts a Windows Collation Name or a SQL Collation Name.
"""
- collation = kw.pop('collation', None)
_StringType.__init__(self, collation)
- sqltypes.NVARCHAR.__init__(self, *args, **kw)
+ sqltypes.NVARCHAR.__init__(self, length, **kw)
class CHAR(_StringType, sqltypes.CHAR):
"""MSSQL CHAR type, for fixed-length non-Unicode data with a maximum
of 8,000 characters."""
- def __init__(self, *args, **kw):
+ def __init__(self, length=None, collation=None, **kw):
"""Construct a CHAR.
:param length: Optinal, maximum data length, in characters.
@@ -380,16 +377,15 @@ class CHAR(_StringType, sqltypes.CHAR):
value. Accepts a Windows Collation Name or a SQL Collation Name.
"""
- collation = kw.pop('collation', None)
_StringType.__init__(self, collation)
- sqltypes.CHAR.__init__(self, *args, **kw)
+ sqltypes.CHAR.__init__(self, length, **kw)
class NCHAR(_StringType, sqltypes.NCHAR):
"""MSSQL NCHAR type.
For fixed-length unicode character data up to 4,000 characters."""
- def __init__(self, *args, **kw):
+ def __init__(self, length=None, collation=None, **kw):
"""Construct an NCHAR.
:param length: Optional, Maximum data length, in characters.
@@ -398,9 +394,8 @@ class NCHAR(_StringType, sqltypes.NCHAR):
value. Accepts a Windows Collation Name or a SQL Collation Name.
"""
- collation = kw.pop('collation', None)
_StringType.__init__(self, collation)
- sqltypes.NCHAR.__init__(self, *args, **kw)
+ sqltypes.NCHAR.__init__(self, length, **kw)
class IMAGE(sqltypes.LargeBinary):
__visit_name__ = 'IMAGE'
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 42072699e..528e94965 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -233,11 +233,11 @@ SET_RE = re.compile(
class _NumericType(object):
"""Base for MySQL numeric types."""
- def __init__(self, **kw):
- self.unsigned = kw.pop('unsigned', False)
- self.zerofill = kw.pop('zerofill', False)
+ def __init__(self, unsigned=False, zerofill=False, **kw):
+ self.unsigned = unsigned
+ self.zerofill = zerofill
super(_NumericType, self).__init__(**kw)
-
+
class _FloatType(_NumericType, sqltypes.Float):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
if isinstance(self, (REAL, DOUBLE)) and \
@@ -276,7 +276,7 @@ class _StringType(sqltypes.String):
self.binary = binary
self.national = national
super(_StringType, self).__init__(**kw)
-
+
def __repr__(self):
attributes = inspect.getargspec(self.__init__)[0][1:]
attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:])
@@ -749,7 +749,7 @@ class CHAR(_StringType, sqltypes.CHAR):
__visit_name__ = 'CHAR'
- def __init__(self, length, **kwargs):
+ def __init__(self, length=None, **kwargs):
"""Construct a CHAR.
:param length: Maximum data length, in characters.
@@ -943,6 +943,10 @@ class ENUM(sqltypes.Enum, _StringType):
else:
return value
return process
+
+ def adapt(self, impltype, **kw):
+ kw['strict'] = self.strict
+ return sqltypes.Enum.adapt(self, impltype, **kw)
class SET(_StringType):
"""MySQL SET type."""
@@ -990,8 +994,8 @@ class SET(_StringType):
strip_values.append(a)
self.values = strip_values
- length = max([len(v) for v in strip_values] + [0])
- super(SET, self).__init__(length=length, **kw)
+ kw.setdefault('length', max([len(v) for v in strip_values] + [0]))
+ super(SET, self).__init__(**kw)
def result_processor(self, dialect, coltype):
def process(value):
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index 256972696..3d97b504e 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -215,10 +215,6 @@ class INTERVAL(sqltypes.TypeEngine):
return INTERVAL(day_precision=interval.day_precision,
second_precision=interval.second_precision)
- def adapt(self, impltype):
- return impltype(day_precision=self.day_precision,
- second_precision=self.second_precision)
-
@property
def _type_affinity(self):
return sqltypes.Interval
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 821ec5cfb..72b58a71c 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -132,12 +132,13 @@ class TIMESTAMP(sqltypes.TIMESTAMP):
def __init__(self, timezone=False, precision=None):
super(TIMESTAMP, self).__init__(timezone=timezone)
self.precision = precision
+
class TIME(sqltypes.TIME):
def __init__(self, timezone=False, precision=None):
super(TIME, self).__init__(timezone=timezone)
self.precision = precision
-
+
class INTERVAL(sqltypes.TypeEngine):
"""Postgresql INTERVAL type.
@@ -149,9 +150,6 @@ class INTERVAL(sqltypes.TypeEngine):
def __init__(self, precision=None):
self.precision = precision
- def adapt(self, impltype):
- return impltype(self.precision)
-
@classmethod
def _adapt_from_generic_interval(cls, interval):
return INTERVAL(precision=interval.second_precision)
@@ -164,6 +162,9 @@ PGInterval = INTERVAL
class BIT(sqltypes.TypeEngine):
__visit_name__ = 'BIT'
+ def __init__(self, length=1):
+ self.length= length
+
PGBit = BIT
class UUID(sqltypes.TypeEngine):
@@ -213,7 +214,7 @@ class UUID(sqltypes.TypeEngine):
return process
else:
return None
-
+
PGUuid = UUID
class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
@@ -285,23 +286,8 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
def is_mutable(self):
return self.mutable
- def dialect_impl(self, dialect, **kwargs):
- impl = super(ARRAY, self).dialect_impl(dialect, **kwargs)
- if impl is self:
- impl = self.__class__.__new__(self.__class__)
- impl.__dict__.update(self.__dict__)
- impl.item_type = self.item_type.dialect_impl(dialect)
- return impl
-
- def adapt(self, impltype):
- return impltype(
- self.item_type,
- mutable=self.mutable,
- as_tuple=self.as_tuple
- )
-
def bind_processor(self, dialect):
- item_proc = self.item_type.bind_processor(dialect)
+ item_proc = self.item_type.dialect_impl(dialect).bind_processor(dialect)
if item_proc:
def convert_item(item):
if isinstance(item, (list, tuple)):
@@ -321,7 +307,7 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
return process
def result_processor(self, dialect, coltype):
- item_proc = self.item_type.result_processor(dialect, coltype)
+ item_proc = self.item_type.dialect_impl(dialect).result_processor(dialect, coltype)
if item_proc:
def convert_item(item):
if isinstance(item, list):
@@ -640,7 +626,7 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
return "INTERVAL"
def visit_BIT(self, type_):
- return "BIT"
+ return "BIT(%d)" % type_.length
def visit_UUID(self, type_):
return "UUID"
@@ -1095,7 +1081,7 @@ class PGDialect(default.DefaultDialect):
elif attype == 'double precision':
args = (53, )
elif attype == 'integer':
- args = (32, 0)
+ args = ()
elif attype in ('timestamp with time zone',
'time with time zone'):
kwargs['timezone'] = True
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index 707bc1630..1ab5173fe 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -72,7 +72,8 @@ class _DateTimeMixin(object):
_reg = None
_storage_format = None
- def __init__(self, storage_format=None, regexp=None, **kwargs):
+ def __init__(self, storage_format=None, regexp=None, **kw):
+ super(_DateTimeMixin, self).__init__(**kw)
if regexp is not None:
self._reg = re.compile(regexp)
if storage_format is not None:
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 8647ba385..38ff76b89 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -16,6 +16,7 @@ import re, random
from sqlalchemy.engine import base, reflection
from sqlalchemy.sql import compiler, expression
from sqlalchemy import exc, types as sqltypes, util, pool
+import weakref
AUTOCOMMIT_REGEXP = re.compile(
r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
@@ -133,13 +134,17 @@ class DefaultDialect(base.Dialect):
" maximum identifier length of %d" %
(label_length, self.max_identifier_length))
self.label_length = label_length
-
+
if not hasattr(self, 'description_encoding'):
self.description_encoding = getattr(
self,
'description_encoding',
encoding)
+ @util.memoized_property
+ def _type_memos(self):
+ return weakref.WeakKeyDictionary()
+
@property
def dialect_description(self):
return self.name + "+" + self.driver
@@ -398,7 +403,7 @@ class DefaultExecutionContext(base.ExecutionContext):
self.postfetch_cols = self.compiled.postfetch
self.prefetch_cols = self.compiled.prefetch
- processors = compiled._get_bind_processors(dialect)
+ processors = compiled._bind_processors
# Convert the dictionary of bind parameter values
# into a dict or list to be sent to the DBAPI's
diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py
index 9ff503dfa..6981919cf 100644
--- a/lib/sqlalchemy/ext/sqlsoup.py
+++ b/lib/sqlalchemy/ext/sqlsoup.py
@@ -359,7 +359,7 @@ from sqlalchemy import schema, sql, util
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import scoped_session, sessionmaker, mapper, \
class_mapper, relationship, session,\
- object_session
+ object_session, attributes
from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE
from sqlalchemy.exceptions import SQLAlchemyError, InvalidRequestError, ArgumentError
from sqlalchemy.sql import expression
@@ -384,7 +384,8 @@ class AutoAdd(MapperExtension):
def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
session = self.scoped_session()
- session._save_without_cascade(instance)
+ state = attributes.instance_state(instance)
+ session._save_impl(state)
return EXT_CONTINUE
def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py
index 39ea1db35..35b197f63 100644
--- a/lib/sqlalchemy/orm/dependency.py
+++ b/lib/sqlalchemy/orm/dependency.py
@@ -152,7 +152,7 @@ class DependencyProcessor(object):
# detect if there's anything changed or loaded
# by a preprocessor on this state/attribute. if not,
# we should be able to skip it entirely.
- sum_ = attributes.get_all_pending(state, state.dict, self.key)
+ sum_ = state.manager[self.key].impl.get_all_pending(state, state.dict)
if not sum_:
continue
@@ -439,10 +439,10 @@ class OneToManyDP(DependencyProcessor):
elif self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True,
operation="delete", prop=self.prop)
- for c, m in self.mapper.cascade_iterator(
+ for c, m, st_, dct_ in self.mapper.cascade_iterator(
'delete', child):
uowcommit.register_object(
- attributes.instance_state(c),
+ st_,
isdelete=True)
if pks_changed:
@@ -661,10 +661,10 @@ class ManyToOneDP(DependencyProcessor):
continue
uowcommit.register_object(child, isdelete=True,
operation="delete", prop=self.prop)
- for c, m in self.mapper.cascade_iterator(
+ for c, m, st_, dct_ in self.mapper.cascade_iterator(
'delete', child):
uowcommit.register_object(
- attributes.instance_state(c), isdelete=True)
+ st_, isdelete=True)
def presort_saves(self, uowcommit, states):
for state in states:
@@ -681,10 +681,10 @@ class ManyToOneDP(DependencyProcessor):
uowcommit.register_object(child, isdelete=True,
operation="delete", prop=self.prop)
- for c, m in self.mapper.cascade_iterator(
+ for c, m, st_, dct_ in self.mapper.cascade_iterator(
'delete', child):
uowcommit.register_object(
- attributes.instance_state(c),
+ st_,
isdelete=True)
def process_deletes(self, uowcommit, states):
@@ -939,11 +939,11 @@ class ManyToManyDP(DependencyProcessor):
if self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True,
operation="delete", prop=self.prop)
- for c, m in self.mapper.cascade_iterator(
+ for c, m, st_, dct_ in self.mapper.cascade_iterator(
'delete',
child):
uowcommit.register_object(
- attributes.instance_state(c), isdelete=True)
+ st_, isdelete=True)
def process_deletes(self, uowcommit, states):
secondary_delete = []
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 48c37f80d..8bd8bf3c8 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1382,10 +1382,10 @@ class Mapper(object):
reference so that they don't fall out of scope immediately.
"""
- visited_instances = util.IdentitySet()
+ visited_states = set()
prp, mpp = object(), object()
- visitables = [(deque(self._props.values()), prp, state, state.dict)]
+ visitables = deque([(deque(self._props.values()), prp, state, state.dict)])
while visitables:
iterator, item_type, parent_state, parent_dict = visitables[-1]
@@ -1398,13 +1398,13 @@ class Mapper(object):
if type_ not in prop.cascade:
continue
queue = deque(prop.cascade_iterator(type_, parent_state,
- parent_dict, visited_instances, halt_on))
+ parent_dict, visited_states, halt_on))
if queue:
visitables.append((queue,mpp, None, None))
elif item_type is mpp:
instance, instance_mapper, corresponding_state, \
corresponding_dict = iterator.popleft()
- yield (instance, instance_mapper)
+ yield instance, instance_mapper, corresponding_state, corresponding_dict
visitables.append((deque(instance_mapper._props.values()),
prp, corresponding_state, corresponding_dict))
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 239159f3e..da6d309e0 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -820,9 +820,8 @@ class RelationshipProperty(StrategizedProperty):
dest_state.get_impl(self.key).set(dest_state,
dest_dict, obj, None)
- def cascade_iterator(self, type_, state, dict_, visited_instances, halt_on=None):
- if not type_ in self.cascade:
- return
+ def cascade_iterator(self, type_, state, dict_, visited_states, halt_on=None):
+ #assert type_ in self.cascade
# only actively lazy load on the 'delete' cascade
if type_ != 'delete' or self.passive_deletes:
@@ -831,7 +830,7 @@ class RelationshipProperty(StrategizedProperty):
passive = attributes.PASSIVE_OFF
if type_ == 'save-update':
- instances = attributes.get_all_pending(state, dict_, self.key)
+ instances = state.manager[self.key].impl.get_all_pending(state, dict_)
else:
instances = state.value_as_iterable(dict_, self.key,
@@ -842,10 +841,12 @@ class RelationshipProperty(StrategizedProperty):
if instances:
for c in instances:
if c is not None and \
- c is not attributes.PASSIVE_NO_RESULT and \
- c not in visited_instances:
-
+ c is not attributes.PASSIVE_NO_RESULT:
+
instance_state = attributes.instance_state(c)
+ if instance_state in visited_states:
+ continue
+
instance_dict = attributes.instance_dict(c)
if halt_on and halt_on(instance_state):
@@ -865,7 +866,7 @@ class RelationshipProperty(StrategizedProperty):
c.__class__
))
- visited_instances.add(c)
+ visited_states.add(instance_state)
# cascade using the mapper local to this
# object, so that its individual properties are located
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index e2c1308b8..1f704f502 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -981,16 +981,18 @@ class Session(object):
else:
# pre-fetch the full cascade since the expire is going to
# remove associations
- cascaded = list(_cascade_state_iterator('refresh-expire', state))
+ cascaded = list(state.manager.mapper.cascade_iterator(
+ 'refresh-expire', state))
self._conditional_expire(state)
- for (state, m, o) in cascaded:
- self._conditional_expire(state)
+ for o, m, st_, dct_ in cascaded:
+ self._conditional_expire(st_)
def _conditional_expire(self, state):
"""Expire a state if persistent, else expunge if pending"""
if state.key:
- _expire_state(state, state.dict, None, instance_dict=self.identity_map)
+ _expire_state(state, state.dict, None,
+ instance_dict=self.identity_map)
elif state in self._new:
self._new.pop(state)
state.detach()
@@ -1023,8 +1025,12 @@ class Session(object):
raise sa_exc.InvalidRequestError(
"Instance %s is not present in this Session" %
mapperutil.state_str(state))
- for s, m, o in [(state, None, None)] + list(_cascade_state_iterator('expunge', state)):
- self._expunge_state(s)
+
+ cascaded = list(state.manager.mapper.cascade_iterator(
+ 'expunge', state))
+ self._expunge_state(state)
+ for o, m, st_, dct_ in cascaded:
+ self._expunge_state(st_)
def _expunge_state(self, state):
if state in self._new:
@@ -1078,12 +1084,6 @@ class Session(object):
self._deleted.pop(state, None)
state.deleted = True
- def _save_without_cascade(self, instance):
- """Used by scoping.py to save on init without cascade."""
-
- state = _state_for_unsaved_instance(instance, create=True)
- self._save_impl(state)
-
def add(self, instance):
"""Place an object in the ``Session``.
@@ -1094,7 +1094,11 @@ class Session(object):
is ``expunge()``.
"""
- state = _state_for_unknown_persistence_instance(instance)
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE:
+ raise exc.UnmappedInstanceError(instance)
+
self._save_or_update_state(state)
def add_all(self, instances):
@@ -1105,13 +1109,13 @@ class Session(object):
def _save_or_update_state(self, state):
self._save_or_update_impl(state)
- self._cascade_save_or_update(state)
- def _cascade_save_or_update(self, state):
- for state, mapper in _cascade_unknown_state_iterator(
- 'save-update', state,
+ mapper = _state_mapper(state)
+ for o, m, st_, dct_ in mapper.cascade_iterator(
+ 'save-update',
+ state,
halt_on=self._contains_state):
- self._save_or_update_impl(state)
+ self._save_or_update_impl(st_)
def delete(self, instance):
"""Mark an instance as deleted.
@@ -1139,16 +1143,19 @@ class Session(object):
# grab the cascades before adding the item to the deleted list
# so that autoflush does not delete the item
- cascade_states = list(_cascade_state_iterator('delete', state))
+ # the strong reference to the instance itself is significant here
+ cascade_states = list(state.manager.mapper.cascade_iterator(
+ 'delete', state))
self._deleted[state] = state.obj()
self.identity_map.add(state)
- for state, m, o in cascade_states:
- self._delete_impl(state)
+ for o, m, st_, dct_ in cascade_states:
+ self._delete_impl(st_)
def merge(self, instance, load=True, **kw):
- """Copy the state an instance onto the persistent instance with the same identifier.
+ """Copy the state an instance onto the persistent instance with the
+ same identifier.
If there is no persistent instance currently associated with the
session, it will be loaded. Return the persistent instance. If the
@@ -1164,7 +1171,8 @@ class Session(object):
"""
if 'dont_load' in kw:
load = not kw['dont_load']
- util.warn_deprecated("dont_load=True has been renamed to load=False.")
+ util.warn_deprecated('dont_load=True has been renamed to '
+ 'load=False.')
_recursive = {}
@@ -1241,7 +1249,9 @@ class Session(object):
merged_state.load_options = state.load_options
for prop in mapper.iterate_properties:
- prop.merge(self, state, state_dict, merged_state, merged_dict, load, _recursive)
+ prop.merge(self, state, state_dict,
+ merged_state, merged_dict,
+ load, _recursive)
if not load:
# remove any history
@@ -1319,10 +1329,10 @@ class Session(object):
if state.key and \
state.key in self.identity_map and \
not self.identity_map.contains_state(state):
- raise sa_exc.InvalidRequestError(
- "Can't attach instance %s; another instance with key %s is already present in this session." %
- (mapperutil.state_str(state), state.key)
- )
+ raise sa_exc.InvalidRequestError("Can't attach instance "
+ "%s; another instance with key %s is already "
+ "present in this session."
+ % (mapperutil.state_str(state), state.key))
if state.session_id and state.session_id is not self.hash_key:
raise sa_exc.InvalidRequestError(
@@ -1475,7 +1485,8 @@ class Session(object):
#if not objects:
# assert not self.identity_map._modified
#else:
- # assert self.identity_map._modified == self.identity_map._modified.difference(objects)
+ # assert self.identity_map._modified == \
+ # self.identity_map._modified.difference(objects)
#self.identity_map._modified.clear()
self.dispatch.on_after_flush_postexec(self, flush_context)
@@ -1598,47 +1609,6 @@ UOWEventHandler = unitofwork.UOWEventHandler
_sessions = weakref.WeakValueDictionary()
-def _cascade_state_iterator(cascade, state, **kwargs):
- mapper = _state_mapper(state)
- # yield the state, object, mapper. yielding the object
- # allows the iterator's results to be held in a list without
- # states being garbage collected
- for (o, m) in mapper.cascade_iterator(cascade, state, **kwargs):
- yield attributes.instance_state(o), o, m
-
-def _cascade_unknown_state_iterator(cascade, state, **kwargs):
- mapper = _state_mapper(state)
- for (o, m) in mapper.cascade_iterator(cascade, state, **kwargs):
- yield _state_for_unknown_persistence_instance(o), m
-
-def _state_for_unsaved_instance(instance, create=False):
- try:
- state = attributes.instance_state(instance)
- except AttributeError:
- raise exc.UnmappedInstanceError(instance)
- if state:
- if state.key is not None:
- raise sa_exc.InvalidRequestError(
- "Instance '%s' is already persistent" %
- mapperutil.state_str(state))
- elif create:
- manager = attributes.manager_of_class(instance.__class__)
- if manager is None:
- raise exc.UnmappedInstanceError(instance)
- state = manager.setup_instance(instance)
- else:
- raise exc.UnmappedInstanceError(instance)
-
- return state
-
-def _state_for_unknown_persistence_instance(instance):
- try:
- state = attributes.instance_state(instance)
- except exc.NO_STATE:
- raise exc.UnmappedInstanceError(instance)
-
- return state
-
def make_transient(instance):
"""Make the given instance 'transient'.
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index 1e1eda4a3..ba43b1359 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -16,7 +16,6 @@ from sqlalchemy import util
from sqlalchemy.util import topological
from sqlalchemy.orm import attributes, interfaces
from sqlalchemy.orm import util as mapperutil
-from sqlalchemy.orm.util import _state_mapper
session = util.importlater("sqlalchemy.orm", "session")
class UOWEventHandler(interfaces.AttributeExtension):
@@ -28,28 +27,32 @@ class UOWEventHandler(interfaces.AttributeExtension):
def __init__(self, key):
self.key = key
-
+
+ # TODO: migrate these to unwrapped events
+
def append(self, state, item, initiator):
# process "save_update" cascade rules for when
# an instance is appended to the list of another instance
sess = session._state_session(state)
if sess:
- prop = _state_mapper(state)._props[self.key]
+ prop = state.manager.mapper._props[self.key]
+ item_state = attributes.instance_state(item)
if prop.cascade.save_update and \
(prop.cascade_backrefs or self.key == initiator.key) and \
- item not in sess:
- sess.add(item)
+ not sess._contains_state(item_state):
+ sess._save_or_update_state(item_state)
return item
def remove(self, state, item, initiator):
sess = session._state_session(state)
if sess:
- prop = _state_mapper(state)._props[self.key]
+ prop = state.manager.mapper._props[self.key]
# expunge pending orphans
+ item_state = attributes.instance_state(item)
if prop.cascade.delete_orphan and \
- item in sess.new and \
- prop.mapper._is_orphan(attributes.instance_state(item)):
+ item_state in sess._new and \
+ prop.mapper._is_orphan(item_state):
sess.expunge(item)
def set(self, state, newvalue, oldvalue, initiator):
@@ -60,16 +63,20 @@ class UOWEventHandler(interfaces.AttributeExtension):
sess = session._state_session(state)
if sess:
- prop = _state_mapper(state)._props[self.key]
- if newvalue is not None and \
- prop.cascade.save_update and \
- (prop.cascade_backrefs or self.key == initiator.key) and \
- newvalue not in sess:
- sess.add(newvalue)
- if prop.cascade.delete_orphan and \
- oldvalue in sess.new and \
- prop.mapper._is_orphan(attributes.instance_state(oldvalue)):
- sess.expunge(oldvalue)
+ prop = state.manager.mapper._props[self.key]
+ if newvalue is not None:
+ newvalue_state = attributes.instance_state(newvalue)
+ if prop.cascade.save_update and \
+ (prop.cascade_backrefs or self.key == initiator.key) and \
+ not sess._contains_state(newvalue_state):
+ sess._save_or_update_state(newvalue_state)
+
+ if oldvalue is not None and prop.cascade.delete_orphan:
+ oldvalue_state = attributes.instance_state(oldvalue)
+
+ if oldvalue_state in sess._new and \
+ prop.mapper._is_orphan(oldvalue_state):
+ sess.expunge(oldvalue)
return newvalue
@@ -196,7 +203,7 @@ class UOWTransaction(object):
return False
if state not in self.states:
- mapper = _state_mapper(state)
+ mapper = state.manager.mapper
if mapper not in self.mappers:
mapper._per_mapper_flush_actions(self)
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 8bd728a7c..cf1e28f50 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -253,24 +253,16 @@ class SQLCompiler(engine.Compiled):
# or dialect.max_identifier_length
self.truncated_names = {}
- # other memoized things
- self._memos ={}
- def _get_bind_processors(self, dialect):
- key = 'bind_processors', dialect.__class__, \
- dialect.server_version_info
-
- if key not in self._memos:
- self._memos[key] = processors = dict(
+ @util.memoized_property
+ def _bind_processors(self):
+ return dict(
(key, value) for key, value in
( (self.bind_names[bindparam],
- bindparam.type._cached_bind_processor(dialect))
+ bindparam.type._cached_bind_processor(self.dialect))
for bindparam in self.bind_names )
if value is not None
)
- return processors
- else:
- return self._memos[key]
def is_subquery(self):
return len(self.stack) > 1
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index e31486eff..f5df02367 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -131,57 +131,58 @@ class TypeEngine(AbstractType):
else:
return self.__class__
- @util.memoized_property
- def _impl_dict(self):
- return {}
-
- def __getstate__(self):
- d = self.__dict__.copy()
- d.pop('_impl_dict', None)
- return d
-
- def dialect_impl(self, dialect, **kwargs):
- key = dialect.__class__, dialect.server_version_info
+ def dialect_impl(self, dialect):
+ """Return a dialect-specific implementation for this type."""
+
try:
- return self._impl_dict[key]
+ return dialect._type_memos[self]['impl']
except KeyError:
- return self._impl_dict.setdefault(key,
- dialect.type_descriptor(self))
+ return self._dialect_info(dialect)['impl']
def _cached_bind_processor(self, dialect):
- return self.dialect_impl(dialect).bind_processor(dialect)
-
- # TODO: can't do this until we find a way to link with the
- # specific attributes of the dialect, i.e. convert_unicode,
- # etc. might need to do a weakmap again. needs tests
- # to ensure two dialects with different flags. use a mock
- # dialect.
- #key = "bind", dialect.__class__, dialect.server_version_info
- #try:
- # return self._impl_dict[key]
- #except KeyError:
- # self._impl_dict[key] = bp = \
- # self.dialect_impl(dialect).bind_processor(dialect)
- # return bp
+ """Return a dialect-specific bind processor for this type."""
+ try:
+ return dialect._type_memos[self]['bind']
+ except KeyError:
+ d = self._dialect_info(dialect)
+ d['bind'] = bp = d['impl'].bind_processor(dialect)
+ return bp
+
def _cached_result_processor(self, dialect, coltype):
- return self.dialect_impl(dialect).result_processor(dialect, coltype)
+ """Return a dialect-specific result processor for this type."""
+
+ try:
+ return dialect._type_memos[self][coltype]
+ except KeyError:
+ d = self._dialect_info(dialect)
+ # key assumption: DBAPI type codes are
+ # constants. Else this dictionary would
+ # grow unbounded.
+ d[coltype] = rp = d['impl'].result_processor(dialect, coltype)
+ return rp
+
+ def _dialect_info(self, dialect):
+ """Return a dialect-specific registry which
+ caches a dialect-specific implementation, bind processing
+ function, and one or more result processing functions."""
+
+ if self in dialect._type_memos:
+ return dialect._type_memos[self]
+ else:
+ impl = self._gen_dialect_impl(dialect)
+ if impl is self:
+ impl = self.adapt(type(self))
+ # this can't be self, else we create a cycle
+ assert impl is not self
+ dialect._type_memos[self] = d = {'impl':impl}
+ return d
+
+ def _gen_dialect_impl(self, dialect):
+ return dialect.type_descriptor(self)
- # TODO: can't do this until we find a way to link with the
- # specific attributes of the dialect, i.e. convert_unicode,
- # etc. might need to do a weakmap again. needs tests
- # to ensure two dialects with different flags. use a mock
- # dialect.
- #key = "result", dialect.__class__, dialect.server_version_info, coltype
- #try:
- # return self._impl_dict[key]
- #except KeyError:
- # self._impl_dict[key] = rp = self.dialect_impl(dialect).\
- # result_processor(dialect, coltype)
- # return rp
-
- def adapt(self, cls):
- return cls()
+ def adapt(self, cls, **kw):
+ return util.constructor_copy(self, cls, **kw)
def _coerce_compared_value(self, op, value):
_coerced_type = _type_map.get(type(value), NULLTYPE)
@@ -220,7 +221,7 @@ class TypeEngine(AbstractType):
encode('ascii', 'backslashreplace')
# end Py2K
- def __init__(self, *args, **kwargs):
+ def __init__(self):
# supports getargspec of the __init__ method
# used by generic __repr__
pass
@@ -376,17 +377,10 @@ class TypeDecorator(TypeEngine):
"type being decorated")
self.impl = to_instance(self.__class__.impl, *args, **kwargs)
- def dialect_impl(self, dialect):
- key = (dialect.__class__, dialect.server_version_info)
-
- try:
- return self._impl_dict[key]
- except KeyError:
- pass
-
+
+ def _gen_dialect_impl(self, dialect):
adapted = dialect.type_descriptor(self)
if adapted is not self:
- self._impl_dict[key] = adapted
return adapted
# otherwise adapt the impl type, link
@@ -400,7 +394,6 @@ class TypeDecorator(TypeEngine):
'return an object of type %s' % (self,
self.__class__))
tt.impl = typedesc
- self._impl_dict[key] = tt
return tt
@util.memoized_property
@@ -499,7 +492,6 @@ class TypeDecorator(TypeEngine):
def copy(self):
instance = self.__class__.__new__(self.__class__)
instance.__dict__.update(self.__dict__)
- instance._impl_dict = {}
return instance
def get_dbapi_type(self, dbapi):
@@ -650,6 +642,9 @@ def adapt_type(typeobj, colspecs):
return typeobj
return typeobj.adapt(impltype)
+
+
+
class NullType(TypeEngine):
"""An unknown type.
@@ -796,14 +791,6 @@ class String(Concatenable, TypeEngine):
self.unicode_error = unicode_error
self._warn_on_bytestring = _warn_on_bytestring
- def adapt(self, impltype):
- return impltype(
- length=self.length,
- convert_unicode=self.convert_unicode,
- unicode_error=self.unicode_error,
- _warn_on_bytestring=True,
- )
-
def bind_processor(self, dialect):
if self.convert_unicode or dialect.convert_unicode:
if dialect.supports_unicode_binds and \
@@ -1100,12 +1087,6 @@ class Numeric(_DateAffinity, TypeEngine):
self.scale = scale
self.asdecimal = asdecimal
- def adapt(self, impltype):
- return impltype(
- precision=self.precision,
- scale=self.scale,
- asdecimal=self.asdecimal)
-
def get_dbapi_type(self, dbapi):
return dbapi.NUMBER
@@ -1179,7 +1160,9 @@ class Float(Numeric):
"""
__visit_name__ = 'float'
-
+
+ scale = None
+
def __init__(self, precision=None, asdecimal=False, **kwargs):
"""
Construct a Float.
@@ -1195,9 +1178,6 @@ class Float(Numeric):
self.precision = precision
self.asdecimal = asdecimal
- def adapt(self, impltype):
- return impltype(precision=self.precision, asdecimal=self.asdecimal)
-
def result_processor(self, dialect, coltype):
if self.asdecimal:
return processors.to_decimal_processor_factory(decimal.Decimal)
@@ -1244,9 +1224,6 @@ class DateTime(_DateAffinity, TypeEngine):
def __init__(self, timezone=False):
self.timezone = timezone
- def adapt(self, impltype):
- return impltype(timezone=self.timezone)
-
def get_dbapi_type(self, dbapi):
return dbapi.DATETIME
@@ -1304,9 +1281,6 @@ class Time(_DateAffinity,TypeEngine):
def __init__(self, timezone=False):
self.timezone = timezone
- def adapt(self, impltype):
- return impltype(timezone=self.timezone)
-
def get_dbapi_type(self, dbapi):
return dbapi.DATETIME
@@ -1366,9 +1340,6 @@ class _Binary(TypeEngine):
else:
return super(_Binary, self)._coerce_compared_value(op, value)
- def adapt(self, impltype):
- return impltype(length=self.length)
-
def get_dbapi_type(self, dbapi):
return dbapi.BINARY
@@ -1453,7 +1424,7 @@ class SchemaType(object):
if bind is None:
bind = schema._bind_or_error(self)
t = self.dialect_impl(bind.dialect)
- if t is not self and isinstance(t, SchemaType):
+ if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
t.create(bind=bind, checkfirst=checkfirst)
def drop(self, bind=None, checkfirst=False):
@@ -1462,27 +1433,27 @@ class SchemaType(object):
if bind is None:
bind = schema._bind_or_error(self)
t = self.dialect_impl(bind.dialect)
- if t is not self and isinstance(t, SchemaType):
+ if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
t.drop(bind=bind, checkfirst=checkfirst)
def _on_table_create(self, event, target, bind, **kw):
t = self.dialect_impl(bind.dialect)
- if t is not self and isinstance(t, SchemaType):
+ if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
t._on_table_create(event, target, bind, **kw)
def _on_table_drop(self, event, target, bind, **kw):
t = self.dialect_impl(bind.dialect)
- if t is not self and isinstance(t, SchemaType):
+ if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
t._on_table_drop(event, target, bind, **kw)
def _on_metadata_create(self, event, target, bind, **kw):
t = self.dialect_impl(bind.dialect)
- if t is not self and isinstance(t, SchemaType):
+ if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
t._on_metadata_create(event, target, bind, **kw)
def _on_metadata_drop(self, event, target, bind, **kw):
t = self.dialect_impl(bind.dialect)
- if t is not self and isinstance(t, SchemaType):
+ if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
t._on_metadata_drop(event, target, bind, **kw)
class Enum(String, SchemaType):
@@ -1579,7 +1550,7 @@ class Enum(String, SchemaType):
)
table.append_constraint(e)
- def adapt(self, impltype):
+ def adapt(self, impltype, **kw):
if issubclass(impltype, Enum):
return impltype(name=self.name,
quote=self.quote,
@@ -1587,10 +1558,11 @@ class Enum(String, SchemaType):
metadata=self.metadata,
convert_unicode=self.convert_unicode,
native_enum=self.native_enum,
- *self.enums
+ *self.enums,
+ **kw
)
else:
- return super(Enum, self).adapt(impltype)
+ return super(Enum, self).adapt(impltype, **kw)
class PickleType(MutableType, TypeDecorator):
"""Holds Python objects, which are serialized using pickle.
@@ -1792,11 +1764,11 @@ class Interval(_DateAffinity, TypeDecorator):
self.second_precision = second_precision
self.day_precision = day_precision
- def adapt(self, cls):
- if self.native:
- return cls._adapt_from_generic_interval(self)
+ def adapt(self, cls, **kw):
+ if self.native and hasattr(cls, '_adapt_from_generic_interval'):
+ return cls._adapt_from_generic_interval(self, **kw)
else:
- return self
+ return cls(**kw)
def bind_processor(self, dialect):
impl_processor = self.impl.bind_processor(dialect)
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index 9119e35b7..ae1eb3ac5 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -24,7 +24,8 @@ from langhelpers import iterate_attributes, class_hierarchy, \
reset_memoized, group_expirable_memoized_property, importlater, \
monkeypatch_proxied_specials, asbool, bool_or_str, coerce_kw_type,\
duck_type_collection, assert_arg_type, symbol, dictlike_iteritems,\
- classproperty, set_creation_order, warn_exception, warn, NoneType
+ classproperty, set_creation_order, warn_exception, warn, NoneType,\
+ constructor_copy
from deprecations import warn_deprecated, warn_pending_deprecation, \
deprecated, pending_deprecation
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py
index 022da2de8..57ce02a1e 100644
--- a/lib/sqlalchemy/util/_collections.py
+++ b/lib/sqlalchemy/util/_collections.py
@@ -875,7 +875,6 @@ class ThreadLocalRegistry(ScopedRegistry):
except AttributeError:
pass
-
def _iter_id(iterable):
"""Generator: ((id(o), o) for o in iterable)."""
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index d85793ee0..2b9e890fc 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -54,6 +54,9 @@ def get_cls_kwargs(cls):
pass along unrecognized keywords to it's base classes, and the collection
process is repeated recursively on each of the bases.
+ Uses a subset of inspect.getargspec() to cut down on method overhead.
+ No anonymous tuple arguments please !
+
"""
for c in cls.__mro__:
@@ -70,15 +73,39 @@ def get_cls_kwargs(cls):
if not ctr or not isinstance(ctr, types.FunctionType):
stack.update(class_.__bases__)
continue
- names, _, has_kw, _ = inspect.getargspec(ctr)
+
+ # this is shorthand for
+ # names, _, has_kw, _ = inspect.getargspec(ctr)
+
+ names, has_kw = inspect_func_args(ctr)
args.update(names)
if has_kw:
stack.update(class_.__bases__)
args.discard('self')
return args
+try:
+ from inspect import CO_VARKEYWORDS
+ def inspect_func_args(fn):
+ co = fn.func_code
+ nargs = co.co_argcount
+ names = co.co_varnames
+ args = list(names[:nargs])
+ has_kw = bool(co.co_flags & CO_VARKEYWORDS)
+ return args, has_kw
+except ImportError:
+ def inspect_func_args(fn):
+ names, _, has_kw, _ = inspect.getargspec(fn)
+ return names, bool(has_kw)
+
def get_func_kwargs(func):
- """Return the full set of legal kwargs for the given `func`."""
+ """Return the set of legal kwargs for the given `func`.
+
+ Uses getargspec so is safe to call for methods, functions,
+ etc.
+
+ """
+
return inspect.getargspec(func)[0]
def format_argspec_plus(fn, grouped=True):
@@ -516,6 +543,19 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True):
else:
kw[key] = type_(kw[key])
+
+def constructor_copy(obj, cls, **kw):
+ """Instantiate cls using the __dict__ of obj as constructor arguments.
+
+ Uses inspect to match the named arguments of ``cls``.
+
+ """
+
+ names = get_cls_kwargs(cls)
+ kw.update((k, obj.__dict__[k]) for k in names if k in obj.__dict__)
+ return cls(**kw)
+
+
def duck_type_collection(specimen, default=None):
"""Given an instance or class, guess if it is or is acting as one of
the basic collection types: list, set and dict. If the __emulates__