summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2012-08-16 16:11:42 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2012-08-16 16:11:42 -0400
commitf327eaea478670198fbaa5b16047be73e9dd6aba (patch)
tree5627656802b9c67b0f0c1ba6bad26a058bca3e18
parent079f5a282b4b37ff2a2a7172aa289eff49509f17 (diff)
downloadsqlalchemy-f327eaea478670198fbaa5b16047be73e9dd6aba.tar.gz
_adapt_expression() moves fully to _DefaultColumnComparator which resumes
its original role as stateful, forms the basis of TypeEngine.Comparator. lots of code goes back mostly as it was just with cleaner typing behavior, such as simple flow in _binary_operate now.
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py7
-rw-r--r--lib/sqlalchemy/sql/expression.py81
-rw-r--r--lib/sqlalchemy/types.py143
-rw-r--r--test/sql/test_operators.py58
-rw-r--r--test/sql/test_types.py1
5 files changed, 117 insertions, 173 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index a9ff988e8..36da14d33 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -708,9 +708,10 @@ class PGCompiler(compiler.SQLCompiler):
affinity = None
casts = {
- sqltypes.Date:'date',
- sqltypes.DateTime:'timestamp',
- sqltypes.Interval:'interval', sqltypes.Time:'time'
+ sqltypes.Date: 'date',
+ sqltypes.DateTime: 'timestamp',
+ sqltypes.Interval: 'interval',
+ sqltypes.Time: 'time'
}
cast = casts.get(affinity, None)
if isinstance(extract.expr, sql.ColumnElement) and cast is not None:
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 844293c73..63fa23c15 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -1875,7 +1875,7 @@ class Immutable(object):
return self
-class _DefaultColumnComparator(object):
+class _DefaultColumnComparator(operators.ColumnOperators):
"""Defines comparison and math operations.
See :class:`.ColumnOperators` and :class:`.Operators` for descriptions
@@ -1883,6 +1883,45 @@ class _DefaultColumnComparator(object):
"""
+ @util.memoized_property
+ def type(self):
+ return self.expr.type
+
+ def operate(self, op, *other, **kwargs):
+ o = self.operators[op.__name__]
+ return o[0](self, self.expr, op, *(other + o[1:]), **kwargs)
+
+ def reverse_operate(self, op, other, **kwargs):
+ o = self.operators[op.__name__]
+ return o[0](self, self.expr, op, other, reverse=True, *o[1:], **kwargs)
+
+ def _adapt_expression(self, op, other_comparator):
+ """evaluate the return type of <self> <op> <othertype>,
+ and apply any adaptations to the given operator.
+
+ This method determines the type of a resulting binary expression
+ given two source types and an operator. For example, two
+ :class:`.Column` objects, both of the type :class:`.Integer`, will
+ produce a :class:`.BinaryExpression` that also has the type
+ :class:`.Integer` when compared via the addition (``+``) operator.
+ However, using the addition operator with an :class:`.Integer`
+ and a :class:`.Date` object will produce a :class:`.Date`, assuming
+ "days delta" behavior by the database (in reality, most databases
+ other than Postgresql don't accept this particular operation).
+
+ The method returns a tuple of the form <operator>, <type>.
+ The resulting operator and type will be those applied to the
+ resulting :class:`.BinaryExpression` as the final operator and the
+ right-hand side of the expression.
+
+ Note that only a subset of operators make usage of
+ :meth:`._adapt_expression`,
+ including math operators and user-defined operators, but not
+ boolean comparison or special SQL keywords like MATCH or BETWEEN.
+
+ """
+ return op, other_comparator.type
+
def _boolean_compare(self, expr, op, obj, negate=None, reverse=False,
**kwargs
):
@@ -1912,7 +1951,7 @@ class _DefaultColumnComparator(object):
type_=sqltypes.BOOLEANTYPE,
negate=negate, modifiers=kwargs)
- def _binary_operate(self, expr, op, obj, result_type, reverse=False):
+ def _binary_operate(self, expr, op, obj, reverse=False):
obj = self._check_literal(expr, op, obj)
if reverse:
@@ -1920,6 +1959,8 @@ class _DefaultColumnComparator(object):
else:
left, right = expr, obj
+ op, result_type = left.comparator._adapt_expression(op, right.comparator)
+
return BinaryExpression(left, right, op, type_=result_type)
def _scalar(self, expr, op, fn, **kw):
@@ -1986,7 +2027,8 @@ class _DefaultColumnComparator(object):
expr,
operators.like_op,
literal_column("'%'", type_=sqltypes.String).__radd__(
- self._check_literal(expr, operators.like_op, other)
+ self._check_literal(expr,
+ operators.like_op, other)
),
escape=escape)
@@ -2068,21 +2110,16 @@ class _DefaultColumnComparator(object):
"neg": (_neg_impl,),
}
- def operate(self, expr, op, *other, **kwargs):
- o = self.operators[op.__name__]
- return o[0](self, expr, op, *(other + o[1:]), **kwargs)
-
- def reverse_operate(self, expr, op, other, **kwargs):
- o = self.operators[op.__name__]
- return o[0](self, expr, op, other, reverse=True, *o[1:], **kwargs)
def _check_literal(self, expr, operator, other):
- if isinstance(other, BindParameter) and \
- isinstance(other.type, sqltypes.NullType):
- # TODO: perhaps we should not mutate the incoming bindparam()
- # here and instead make a copy of it. this might
- # be the only place that we're mutating an incoming construct.
- other.type = expr.type
+ if isinstance(other, (ColumnElement, TextClause)):
+ if isinstance(other, BindParameter) and \
+ isinstance(other.type, sqltypes.NullType):
+ # TODO: perhaps we should not mutate the incoming
+ # bindparam() here and instead make a copy of it.
+ # this might be the only place that we're mutating
+ # an incoming construct.
+ other.type = expr.type
return other
elif hasattr(other, '__clause_element__'):
other = other.__clause_element__()
@@ -2096,8 +2133,6 @@ class _DefaultColumnComparator(object):
else:
return other
-_DEFAULT_COMPARATOR = _DefaultColumnComparator()
-
class ColumnElement(ClauseElement, ColumnOperators):
"""Represent an element that is usable within the "column clause" portion
@@ -2155,11 +2190,7 @@ class ColumnElement(ClauseElement, ColumnOperators):
def comparator(self):
return self.type.comparator_factory(self)
- #def _assert_comparator(self):
- # assert self.comparator.expr is self
-
def __getattr__(self, key):
- #self._assert_comparator()
try:
return getattr(self.comparator, key)
except AttributeError:
@@ -2171,11 +2202,9 @@ class ColumnElement(ClauseElement, ColumnOperators):
)
def operate(self, op, *other, **kwargs):
- #self._assert_comparator()
return op(self.comparator, *other, **kwargs)
def reverse_operate(self, op, other, **kwargs):
- #self._assert_comparator()
return op(other, self.comparator, **kwargs)
def _bind_param(self, operator, obj):
@@ -3090,6 +3119,10 @@ class TextClause(Executable, ClauseElement):
else:
return sqltypes.NULLTYPE
+ @property
+ def comparator(self):
+ return self.type.comparator_factory(self)
+
def self_group(self, against=None):
if against is operators.in_op:
return Grouping(self)
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index d4dbd648c..bbeebf5d3 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -11,21 +11,21 @@ types.
For more information see the SQLAlchemy documentation on types.
"""
-__all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', 'UserDefinedType',
- 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR','TEXT', 'Text',
+__all__ = ['TypeEngine', 'TypeDecorator', 'AbstractType', 'UserDefinedType',
+ 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR', 'TEXT', 'Text',
'FLOAT', 'NUMERIC', 'REAL', 'DECIMAL', 'TIMESTAMP', 'DATETIME',
'CLOB', 'BLOB', 'BINARY', 'VARBINARY', 'BOOLEAN', 'BIGINT', 'SMALLINT',
'INTEGER', 'DATE', 'TIME', 'String', 'Integer', 'SmallInteger',
'BigInteger', 'Numeric', 'Float', 'DateTime', 'Date', 'Time',
'LargeBinary', 'Binary', 'Boolean', 'Unicode', 'Concatenable',
- 'UnicodeText','PickleType', 'Interval', 'Enum' ]
+ 'UnicodeText', 'PickleType', 'Interval', 'Enum']
import datetime as dt
import codecs
from . import exc, schema, util, processors, events, event
from .sql import operators
-from .sql.expression import _DEFAULT_COMPARATOR
+from .sql.expression import _DefaultColumnComparator
from .util import pickle
from .util.compat import decimal
from .sql.visitors import Visitable
@@ -42,7 +42,7 @@ class AbstractType(Visitable):
class TypeEngine(AbstractType):
"""Base for built-in types."""
- class Comparator(operators.ColumnOperators):
+ class Comparator(_DefaultColumnComparator):
"""Base class for custom comparison operations defined at the
type level. See :attr:`.TypeEngine.comparator_factory`.
@@ -54,24 +54,6 @@ class TypeEngine(AbstractType):
def __reduce__(self):
return _reconstitute_comparator, (self.expr, )
- def operate(self, op, *other, **kwargs):
- if len(other) == 1:
- obj = other[0]
- obj = _DEFAULT_COMPARATOR._check_literal(self.expr, op, obj)
- op, adapt_type = self.expr.type._adapt_expression(op,
- obj.type)
- kwargs['result_type'] = adapt_type
-
- return _DEFAULT_COMPARATOR.operate(self.expr, op, *other, **kwargs)
-
- def reverse_operate(self, op, other, **kwargs):
-
- obj = _DEFAULT_COMPARATOR._check_literal(self.expr, op, other)
- op, adapt_type = obj.type._adapt_expression(op, self.expr.type)
- kwargs['result_type'] = adapt_type
-
- return _DEFAULT_COMPARATOR.reverse_operate(self.expr, op, obj,
- **kwargs)
comparator_factory = Comparator
"""A :class:`.TypeEngine.Comparator` class which will apply
@@ -143,11 +125,6 @@ class TypeEngine(AbstractType):
>>> (c1 == c2).type
Boolean()
- The propagation of :class:`.TypeEngine.Comparator` throughout an expression
- will follow with how the :class:`.TypeEngine` itself is propagated. To
- customize the behavior of most operators in this regard, see the
- :meth:`._adapt_expression` method.
-
.. versionadded:: 0.8 The expression system was reworked to support
user-defined comparator objects specified at the type level.
@@ -247,34 +224,7 @@ class TypeEngine(AbstractType):
.. versionadded:: 0.7.2
"""
- return Variant(self, {dialect_name:type_})
-
- def _adapt_expression(self, op, othertype):
- """evaluate the return type of <self> <op> <othertype>,
- and apply any adaptations to the given operator.
-
- This method determines the type of a resulting binary expression
- given two source types and an operator. For example, two
- :class:`.Column` objects, both of the type :class:`.Integer`, will
- produce a :class:`.BinaryExpression` that also has the type
- :class:`.Integer` when compared via the addition (``+``) operator.
- However, using the addition operator with an :class:`.Integer`
- and a :class:`.Date` object will produce a :class:`.Date`, assuming
- "days delta" behavior by the database (in reality, most databases
- other than Postgresql don't accept this particular operation).
-
- The method returns a tuple of the form <operator>, <type>.
- The resulting operator and type will be those applied to the
- resulting :class:`.BinaryExpression` as the final operator and the
- right-hand side of the expression.
-
- Note that only a subset of operators make usage of
- :meth:`._adapt_expression`,
- including math operators and user-defined operators, but not
- boolean comparison or special SQL keywords like MATCH or BETWEEN.
-
- """
- return op, self
+ return Variant(self, {dialect_name: type_})
@util.memoized_property
def _type_affinity(self):
@@ -334,7 +284,7 @@ class TypeEngine(AbstractType):
impl = self.adapt(type(self))
# this can't be self, else we create a cycle
assert impl is not self
- dialect._type_memos[self] = d = {'impl':impl}
+ dialect._type_memos[self] = d = {'impl': impl}
return d
def _gen_dialect_impl(self, dialect):
@@ -461,22 +411,21 @@ class UserDefinedType(TypeEngine):
"""
__visit_name__ = "user_defined"
- def _adapt_expression(self, op, othertype):
- """evaluate the return type of <self> <op> <othertype>,
- and apply any adaptations to the given operator.
-
- """
- return self.adapt_operator(op), self
-
- def adapt_operator(self, op):
- """A hook which allows the given operator to be adapted
- to something new.
+ class Comparator(TypeEngine.Comparator):
+ def _adapt_expression(self, op, other_comparator):
+ if hasattr(self.type, 'adapt_operator'):
+ util.warn_deprecated(
+ "UserDefinedType.adapt_operator is deprecated. Create "
+ "a UserDefinedType.Comparator subclass instead which "
+ "generates the desired expression constructs, given a "
+ "particular operator."
+ )
+ return self.type.adapt_operator(op), self.type
+ else:
+ return op, self.type
- See also UserDefinedType._adapt_expression(), an as-yet-
- semi-public method with greater capability in this regard.
+ comparator_factory = Comparator
- """
- return op
class TypeDecorator(TypeEngine):
"""Allows the creation of types which add additional functionality
@@ -837,13 +786,6 @@ class TypeDecorator(TypeEngine):
"""
return self.impl.compare_values(x, y)
- def _adapt_expression(self, op, othertype):
- op, typ = self.impl._adapt_expression(op, othertype)
- typ = to_instance(typ)
- if typ._compare_type_affinity(self.impl):
- return op, self
- else:
- return op, typ
class Variant(TypeDecorator):
"""A wrapping type that selects among a variety of
@@ -926,8 +868,6 @@ def adapt_type(typeobj, colspecs):
return typeobj.adapt(impltype)
-
-
class NullType(TypeEngine):
"""An unknown type.
@@ -943,11 +883,14 @@ class NullType(TypeEngine):
"""
__visit_name__ = 'null'
- def _adapt_expression(self, op, othertype):
- if isinstance(othertype, NullType) or not operators.is_commutative(op):
- return op, self
- else:
- return othertype._adapt_expression(op, self)
+ class Comparator(TypeEngine.Comparator):
+ def _adapt_expression(self, op, other_comparator):
+ if isinstance(other_comparator, NullType.Comparator) or \
+ not operators.is_commutative(op):
+ return op, self.expr.type
+ else:
+ return other_comparator._adapt_expression(op, self)
+ comparator_factory = Comparator
NullTypeEngine = NullType
@@ -955,12 +898,16 @@ class Concatenable(object):
"""A mixin that marks a type as supporting 'concatenation',
typically strings."""
- def _adapt_expression(self, op, othertype):
- if op is operators.add and issubclass(othertype._type_affinity,
- (Concatenable, NullType)):
- return operators.concat_op, self
- else:
- return op, self
+ class Comparator(TypeEngine.Comparator):
+ def _adapt_expression(self, op, other_comparator):
+ if op is operators.add and isinstance(other_comparator,
+ (Concatenable.Comparator, NullType.Comparator)):
+ return operators.concat_op, self.expr.type
+ else:
+ return op, self.expr.type
+
+ comparator_factory = Comparator
+
class _DateAffinity(object):
"""Mixin date/time specific expression adaptations.
@@ -975,12 +922,14 @@ class _DateAffinity(object):
def _expression_adaptations(self):
raise NotImplementedError()
- _blank_dict = util.immutabledict()
- def _adapt_expression(self, op, othertype):
- othertype = othertype._type_affinity
- return op, \
- self._expression_adaptations.get(op, self._blank_dict).\
- get(othertype, NULLTYPE)
+ class Comparator(TypeEngine.Comparator):
+ _blank_dict = util.immutabledict()
+ def _adapt_expression(self, op, other_comparator):
+ othertype = other_comparator.type._type_affinity
+ return op, \
+ self.type._expression_adaptations.get(op, self._blank_dict).\
+ get(othertype, NULLTYPE)
+ comparator_factory = Comparator
class String(Concatenable, TypeEngine):
"""The base for all string and character types.
diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py
index c38f95a01..05de8c9ef 100644
--- a/test/sql/test_operators.py
+++ b/test/sql/test_operators.py
@@ -12,18 +12,16 @@ from sqlalchemy.types import Integer, TypeEngine, TypeDecorator
class DefaultColumnComparatorTest(fixtures.TestBase):
def _do_scalar_test(self, operator, compare_to):
- cc = _DefaultColumnComparator()
left = column('left')
- assert cc.operate(left, operator).compare(
+ assert left.comparator.operate(operator).compare(
compare_to(left)
)
def _do_operate_test(self, operator):
- cc = _DefaultColumnComparator()
left = column('left')
right = column('right')
- assert cc.operate(left, operator, right, result_type=Integer).compare(
+ assert left.comparator.operate(operator, right).compare(
BinaryExpression(left, right, operator)
)
@@ -37,9 +35,8 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
self._do_operate_test(operators.add)
def test_in(self):
- cc = _DefaultColumnComparator()
left = column('left')
- assert cc.operate(left, operators.in_op, [1, 2, 3]).compare(
+ assert left.comparator.operate(operators.in_op, [1, 2, 3]).compare(
BinaryExpression(
left,
Grouping(ClauseList(
@@ -50,10 +47,9 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
)
def test_collate(self):
- cc = _DefaultColumnComparator()
left = column('left')
right = "some collation"
- cc.operate(left, operators.collate, right).compare(
+ left.comparator.operate(operators.collate, right).compare(
collate(left, right)
)
@@ -144,12 +140,8 @@ class _CustomComparatorTests(object):
self._assert_add_override(6 - c1)
def test_binary_multi_propagate(self):
- c1 = Column('foo', self._add_override_factory(True))
- self._assert_add_override((c1 - 6) + 5)
-
- def test_no_binary_multi_propagate_wo_adapt(self):
c1 = Column('foo', self._add_override_factory())
- self._assert_not_add_override((c1 - 6) + 5)
+ self._assert_add_override((c1 - 6) + 5)
def test_no_boolean_propagate(self):
c1 = Column('foo', self._add_override_factory())
@@ -166,7 +158,7 @@ class _CustomComparatorTests(object):
)
class CustomComparatorTest(_CustomComparatorTests, fixtures.TestBase):
- def _add_override_factory(self, include_adapt=False):
+ def _add_override_factory(self):
class MyInteger(Integer):
class comparator_factory(TypeEngine.Comparator):
@@ -176,19 +168,12 @@ class CustomComparatorTest(_CustomComparatorTests, fixtures.TestBase):
def __add__(self, other):
return self.expr.op("goofy")(other)
- if include_adapt:
- def _adapt_expression(self, op, othertype):
- if op.__name__ == 'custom_op':
- return op, self
- else:
- return super(MyInteger, self)._adapt_expression(
- op, othertype)
return MyInteger
class TypeDecoratorComparatorTest(_CustomComparatorTests, fixtures.TestBase):
- def _add_override_factory(self, include_adapt=False):
+ def _add_override_factory(self):
class MyInteger(TypeDecorator):
impl = Integer
@@ -200,19 +185,12 @@ class TypeDecoratorComparatorTest(_CustomComparatorTests, fixtures.TestBase):
def __add__(self, other):
return self.expr.op("goofy")(other)
- if include_adapt:
- def _adapt_expression(self, op, othertype):
- if op.__name__ == 'custom_op':
- return op, self
- else:
- return super(MyInteger, self)._adapt_expression(
- op, othertype)
return MyInteger
class CustomEmbeddedinTypeDecoratorTest(_CustomComparatorTests, fixtures.TestBase):
- def _add_override_factory(self, include_adapt=False):
+ def _add_override_factory(self):
class MyInteger(Integer):
class comparator_factory(TypeEngine.Comparator):
def __init__(self, expr):
@@ -221,13 +199,6 @@ class CustomEmbeddedinTypeDecoratorTest(_CustomComparatorTests, fixtures.TestBas
def __add__(self, other):
return self.expr.op("goofy")(other)
- if include_adapt:
- def _adapt_expression(self, op, othertype):
- if op.__name__ == 'custom_op':
- return op, self
- else:
- return super(MyInteger, self)._adapt_expression(
- op, othertype)
class MyDecInteger(TypeDecorator):
impl = MyInteger
@@ -235,7 +206,7 @@ class CustomEmbeddedinTypeDecoratorTest(_CustomComparatorTests, fixtures.TestBas
return MyDecInteger
class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase):
- def _add_override_factory(self, include_adapt=False):
+ def _add_override_factory(self):
class MyInteger(Integer):
class comparator_factory(TypeEngine.Comparator):
def __init__(self, expr):
@@ -243,15 +214,6 @@ class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase):
def foob(self, other):
return self.expr.op("foob")(other)
-
- if include_adapt:
- def _adapt_expression(self, op, othertype):
- if op.__name__ == 'custom_op':
- return op, self
- else:
- return super(MyInteger, self)._adapt_expression(
- op, othertype)
-
return MyInteger
def _assert_add_override(self, expr):
@@ -262,5 +224,3 @@ class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase):
def _assert_not_add_override(self, expr):
assert not hasattr(expr, "foob")
- def test_no_binary_multi_propagate_wo_adapt(self):
- pass \ No newline at end of file
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index 91bf17175..279ae36a0 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -1222,6 +1222,7 @@ class ExpressionTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiled
eq_(expr.right.type.__class__, CHAR)
+ @testing.uses_deprecated
@testing.fails_on('firebird', 'Data type unknown on the parameter')
@testing.fails_on('mssql', 'int is unsigned ? not clear')
def test_operator_adapt(self):