diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-02-11 19:33:06 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-02-11 19:33:06 +0000 |
commit | 85d335b01bf64a27e99cee915205afd99e7191b5 (patch) | |
tree | 24afce742247b27fe02da2ab32635cd7ff8590cc /lib/sqlalchemy/sql/expression.py | |
parent | 9d7335f934d3197f572017865220897763d4582b (diff) | |
download | sqlalchemy-85d335b01bf64a27e99cee915205afd99e7191b5.tar.gz |
- The type/expression system now does a more complete job
of determining the return type from an expression
as well as the adaptation of the Python operator into
a SQL operator, based on the full left/right/operator
of the given expression. In particular
the date/time/interval system created for Postgresql
EXTRACT in [ticket:1647] has now been generalized into
the type system. The previous behavior which often
occured of an expression "column + literal" forcing
the type of "literal" to be the same as that of "column"
will now usually not occur - the type of
"literal" is first derived from the Python type of the
literal, assuming standard native Python types + date
types, before falling back to that of the known type
on the other side of the expression. Also part
of [ticket:1683].
Diffstat (limited to 'lib/sqlalchemy/sql/expression.py')
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 88 |
1 files changed, 43 insertions, 45 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 878b0d826..1ae706999 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -407,14 +407,7 @@ def between(ctest, cleft, cright): """ ctest = _literal_as_binds(ctest) - return _BinaryExpression( - ctest, - ClauseList( - _literal_as_binds(cleft, type_=ctest.type), - _literal_as_binds(cright, type_=ctest.type), - operator=operators.and_, - group=False), - operators.between_op) + return ctest.between(cleft, cright) def case(whens, value=None, else_=None): @@ -1453,19 +1446,35 @@ class _CompareMixin(ColumnOperators): obj = self._check_literal(obj) if reverse: - return _BinaryExpression(obj, self, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) + return _BinaryExpression(obj, + self, + op, + type_=sqltypes.BOOLEANTYPE, + negate=negate, modifiers=kwargs) else: - return _BinaryExpression(self, obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) + return _BinaryExpression(self, + obj, + op, + type_=sqltypes.BOOLEANTYPE, + negate=negate, modifiers=kwargs) def __operate(self, op, obj, reverse=False): obj = self._check_literal(obj) - - type_ = self._compare_type(obj) - + if reverse: - return _BinaryExpression(obj, self, type_.adapt_operator(op), type_=type_) + left, right = obj, self else: - return _BinaryExpression(self, obj, type_.adapt_operator(op), type_=type_) + left, right = self, obj + + if left.type is None: + op, result_type = sqltypes.NULLTYPE._adapt_expression(op, right.type) + elif right.type is None: + op, result_type = left.type._adapt_expression(op, sqltypes.NULLTYPE) + else: + op, result_type = left.type._adapt_expression(op, right.type) + + return _BinaryExpression(left, right, op, type_=result_type) + # a mapping of operators with the method they use, along with their negated # operator for comparison operators @@ -1643,7 +1652,7 @@ class _CompareMixin(ColumnOperators): return lambda other: self.__operate(operator, other) def _bind_param(self, obj): - return _BindParamClause(None, obj, type_=self.type, unique=True) + return _BindParamClause(None, obj, _fallback_type=self.type, unique=True) def _check_literal(self, other): if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType): @@ -1658,14 +1667,6 @@ class _CompareMixin(ColumnOperators): else: return other - def _compare_type(self, obj): - """Allow subclasses to override the type used in constructing - :class:`_BinaryExpression` objects. - - Default return value is the type of the given object. - - """ - return obj.type class ColumnElement(ClauseElement, _CompareMixin): """Represent an element that is usable within the "column clause" portion of a ``SELECT`` statement. @@ -2105,7 +2106,9 @@ class _BindParamClause(ColumnElement): __visit_name__ = 'bindparam' quote = None - def __init__(self, key, value, type_=None, unique=False, isoutparam=False, required=False): + def __init__(self, key, value, type_=None, unique=False, + isoutparam=False, required=False, + _fallback_type=None): """Construct a _BindParamClause. key @@ -2151,12 +2154,12 @@ class _BindParamClause(ColumnElement): self.required = required if type_ is None: - self.type = sqltypes.type_map.get(type(value), sqltypes.NullType)() + self.type = sqltypes.type_map.get(type(value), _fallback_type or sqltypes.NULLTYPE) elif isinstance(type_, type): self.type = type_() else: self.type = type_ - + def _clone(self): c = ClauseElement._clone(self) if self.unique: @@ -2171,12 +2174,6 @@ class _BindParamClause(ColumnElement): def bind_processor(self, dialect): return self.type.dialect_impl(dialect).bind_processor(dialect) - def _compare_type(self, obj): - if not isinstance(self.type, sqltypes.NullType): - return self.type - else: - return obj.type - def compare(self, other, **kw): """Compare this :class:`_BindParamClause` to the given clause.""" @@ -2342,7 +2339,14 @@ class ClauseList(ClauseElement): self.clauses = [ _literal_as_text(clause) for clause in clauses if clause is not None] - + + @util.memoized_property + def type(self): + if self.clauses: + return self.clauses[0].type + else: + return sqltypes.NULLTYPE + def __iter__(self): return iter(self.clauses) @@ -2419,7 +2423,7 @@ class _Tuple(ClauseList, ColumnElement): def _bind_param(self, obj): return _Tuple(*[ - _BindParamClause(None, o, type_=self.type, unique=True) + _BindParamClause(None, o, _fallback_type=self.type, unique=True) for o in obj ]).self_group() @@ -2518,11 +2522,8 @@ class FunctionElement(ColumnElement, FromClause): def execute(self): return select([self]).execute() - def _compare_type(self, obj): - return self.type - def _bind_param(self, obj): - return _BindParamClause(None, obj, type_=self.type, unique=True) + return _BindParamClause(None, obj, _fallback_type=self.type, unique=True) class Function(FunctionElement): @@ -2539,7 +2540,7 @@ class Function(FunctionElement): FunctionElement.__init__(self, *clauses, **kw) def _bind_param(self, obj): - return _BindParamClause(self.name, obj, type_=self.type, unique=True) + return _BindParamClause(self.name, obj, _fallback_type=self.type, unique=True) class _Cast(ColumnElement): @@ -2698,7 +2699,7 @@ class _BinaryExpression(ColumnElement): self.right, self.negate, negate=self.operator, - type_=self.type, + type_=sqltypes.BOOLEANTYPE, modifiers=self.modifiers) else: return super(_BinaryExpression, self)._negate() @@ -3149,7 +3150,7 @@ class ColumnClause(_Immutable, ColumnElement): return [] def _bind_param(self, obj): - return _BindParamClause(self.name, obj, type_=self.type, unique=True) + return _BindParamClause(self.name, obj, _fallback_type=self.type, unique=True) def _make_proxy(self, selectable, name=None, attach=True): # propagate the "is_literal" flag only if we are keeping our name, @@ -3166,9 +3167,6 @@ class ColumnClause(_Immutable, ColumnElement): selectable.columns[c.name] = c return c - def _compare_type(self, obj): - return self.type - class TableClause(_Immutable, FromClause): """Represents a "table" construct. |