summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/schema.py
diff options
context:
space:
mode:
authorJason Kirtland <jek@discorporate.us>2008-02-04 20:49:38 +0000
committerJason Kirtland <jek@discorporate.us>2008-02-04 20:49:38 +0000
commit0de289921c4d52798248cfacbacc04ccad12cec9 (patch)
treec3419490d745e18366a6b91310445b770875b3ca /lib/sqlalchemy/schema.py
parent66df4b4958c2cd8dbb699a0c1fe70d0fe97474db (diff)
downloadsqlalchemy-0de289921c4d52798248cfacbacc04ccad12cec9.tar.gz
- ColumnDefault callables can now be any kind of compliant callable, previously only actual functions were allowed.
Diffstat (limited to 'lib/sqlalchemy/schema.py')
-rw-r--r--lib/sqlalchemy/schema.py50
1 files changed, 34 insertions, 16 deletions
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 44dcb5755..98e375507 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -730,22 +730,40 @@ class ColumnDefault(DefaultGenerator):
def __init__(self, arg, **kwargs):
super(ColumnDefault, self).__init__(**kwargs)
if callable(arg):
- if not inspect.isfunction(arg):
- self.arg = lambda ctx: arg()
- else:
- argspec = inspect.getargspec(arg)
- if len(argspec[0]) == 0:
- self.arg = lambda ctx: arg()
- else:
- defaulted = argspec[3] is not None and len(argspec[3]) or 0
- if len(argspec[0]) - defaulted > 1:
- raise exceptions.ArgumentError(
- "ColumnDefault Python function takes zero or one "
- "positional arguments")
- else:
- self.arg = arg
+ arg = self._maybe_wrap_callable(arg)
+ self.arg = arg
+
+ def _maybe_wrap_callable(self, fn):
+ """Backward compat: Wrap callables that don't accept a context."""
+
+ if inspect.isfunction(fn):
+ inspectable = fn
+ elif inspect.isclass(fn):
+ inspectable = fn.__init__
+ elif hasattr(fn, '__call__'):
+ inspectable = fn.__call__
else:
- self.arg = arg
+ # probably not inspectable, try anyways.
+ inspectable = fn
+ try:
+ argspec = inspect.getargspec(inspectable)
+ except TypeError:
+ return lambda ctx: fn()
+
+ positionals = len(argspec[0])
+ if inspect.ismethod(inspectable):
+ positionals -= 1
+
+ if positionals == 0:
+ return lambda ctx: fn()
+
+ defaulted = argspec[3] is not None and len(argspec[3]) or 0
+ if positionals - defaulted > 1:
+ raise exceptions.ArgumentError(
+ "ColumnDefault Python function takes zero or one "
+ "positional arguments")
+ return fn
+
def _visit_name(self):
if self.for_update:
@@ -783,7 +801,7 @@ class Sequence(DefaultGenerator):
def create(self, bind=None, checkfirst=True):
"""Creates this sequence in the database."""
-
+
if bind is None:
bind = _bind_or_error(self)
bind.create(self, checkfirst=checkfirst)