diff options
| author | Jason Kirtland <jek@discorporate.us> | 2008-02-04 20:49:38 +0000 |
|---|---|---|
| committer | Jason Kirtland <jek@discorporate.us> | 2008-02-04 20:49:38 +0000 |
| commit | 0de289921c4d52798248cfacbacc04ccad12cec9 (patch) | |
| tree | c3419490d745e18366a6b91310445b770875b3ca /lib/sqlalchemy/schema.py | |
| parent | 66df4b4958c2cd8dbb699a0c1fe70d0fe97474db (diff) | |
| download | sqlalchemy-0de289921c4d52798248cfacbacc04ccad12cec9.tar.gz | |
- ColumnDefault callables can now be any kind of compliant callable, previously only actual functions were allowed.
Diffstat (limited to 'lib/sqlalchemy/schema.py')
| -rw-r--r-- | lib/sqlalchemy/schema.py | 50 |
1 files changed, 34 insertions, 16 deletions
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 44dcb5755..98e375507 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -730,22 +730,40 @@ class ColumnDefault(DefaultGenerator): def __init__(self, arg, **kwargs): super(ColumnDefault, self).__init__(**kwargs) if callable(arg): - if not inspect.isfunction(arg): - self.arg = lambda ctx: arg() - else: - argspec = inspect.getargspec(arg) - if len(argspec[0]) == 0: - self.arg = lambda ctx: arg() - else: - defaulted = argspec[3] is not None and len(argspec[3]) or 0 - if len(argspec[0]) - defaulted > 1: - raise exceptions.ArgumentError( - "ColumnDefault Python function takes zero or one " - "positional arguments") - else: - self.arg = arg + arg = self._maybe_wrap_callable(arg) + self.arg = arg + + def _maybe_wrap_callable(self, fn): + """Backward compat: Wrap callables that don't accept a context.""" + + if inspect.isfunction(fn): + inspectable = fn + elif inspect.isclass(fn): + inspectable = fn.__init__ + elif hasattr(fn, '__call__'): + inspectable = fn.__call__ else: - self.arg = arg + # probably not inspectable, try anyways. + inspectable = fn + try: + argspec = inspect.getargspec(inspectable) + except TypeError: + return lambda ctx: fn() + + positionals = len(argspec[0]) + if inspect.ismethod(inspectable): + positionals -= 1 + + if positionals == 0: + return lambda ctx: fn() + + defaulted = argspec[3] is not None and len(argspec[3]) or 0 + if positionals - defaulted > 1: + raise exceptions.ArgumentError( + "ColumnDefault Python function takes zero or one " + "positional arguments") + return fn + def _visit_name(self): if self.for_update: @@ -783,7 +801,7 @@ class Sequence(DefaultGenerator): def create(self, bind=None, checkfirst=True): """Creates this sequence in the database.""" - + if bind is None: bind = _bind_or_error(self) bind.create(self, checkfirst=checkfirst) |
