diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-07-31 18:19:04 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-07-31 18:19:04 -0400 |
commit | 6dcf65efa455612133ee667751327947b3a20ecc (patch) | |
tree | 61907d18e057dd3cee37da1792107393153d551c | |
parent | 196775b351c796498f393072cffaaf0d8205e9a3 (diff) | |
download | sqlalchemy-6dcf65efa455612133ee667751327947b3a20ecc.tar.gz |
- start refactoring the workings of index operators
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/hstore.py | 23 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/json.py | 62 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/operators.py | 12 |
4 files changed, 72 insertions, 30 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py index 9f369cb5b..cf9d79ee7 100644 --- a/lib/sqlalchemy/dialects/postgresql/hstore.py +++ b/lib/sqlalchemy/dialects/postgresql/hstore.py @@ -10,6 +10,8 @@ import re from .base import ARRAY, ischema_names from ... import types as sqltypes from ...sql import functions as sqlfunc +from ...sql import elements +from ...sql import default_comparator from ...sql.operators import custom_op from ... import util @@ -226,10 +228,9 @@ class HSTORE(sqltypes.Concatenable, sqltypes.TypeEngine): return self.expr.op('<@')(other) def __getitem__(self, other): - """Text expression. Get the value at a given key. Note that the - key may be a SQLA expression. - """ - return self.expr.op('->', precedence=5)(other) + """Get the value at a given key.""" + + return HStoreElement(self.expr, other) def delete(self, key): """HStore expression. Returns the contents of this hstore with the @@ -341,6 +342,20 @@ class hstore(sqlfunc.GenericFunction): name = 'hstore' +class HStoreElement(elements.IndexExpression): + INDEX = custom_op( + "->", precedence=5, natural_self_precedent=True + ) + + def __init__(self, left, right, astext=False): + self._astext = astext + operator = self.INDEX + right = default_comparator._check_literal( + left, operator, right) + super(HStoreElement, self).__init__( + left, right, operator, type_=sqltypes.String()) + + class _HStoreDefinedFunction(sqlfunc.GenericFunction): type = sqltypes.Boolean name = 'defined' diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index 3e30b8287..6f4ac4ac9 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -18,7 +18,7 @@ from ... import util __all__ = ('JSON', 'JSONElement', 'JSONB') -class JSONElement(elements.BinaryExpression): +class JSONElement(elements.IndexExpression): """Represents accessing an element of a :class:`.JSON` value. The :class:`.JSONElement` is produced whenever using the Python index @@ -32,20 +32,37 @@ class JSONElement(elements.BinaryExpression): """ - def __init__(self, left, right, astext=False, - opstring=None, result_type=None): - self._astext = astext - if opstring is None: - if hasattr(right, '__iter__') and \ - not isinstance(right, util.string_types): - opstring = "#>" - right = "{%s}" % ( - ", ".join(util.text_type(elem) for elem in right)) - else: - opstring = "->" - - self._json_opstring = opstring - operator = custom_op(opstring, precedence=5) + INDEX = custom_op( + "->", precedence=5, natural_self_precedent=True + ) + ARRAYIDX = custom_op( + "#>", precedence=5, natural_self_precedent=True + ) + ASTEXT = custom_op( + "->>", precedence=5, natural_self_precedent=True + ) + ASTEXT_ARRAYIDX = custom_op( + "#>>", precedence=5, natural_self_precedent=True + ) + + _ASTEXT_OPS = set([ASTEXT, ASTEXT_ARRAYIDX]) + _ARRIDX_OPS = set([ARRAYIDX, ASTEXT_ARRAYIDX]) + + def __init__(self, left, right, operator, result_type=None): + if hasattr(right, '__iter__') and \ + not isinstance(right, util.string_types): + right = "{%s}" % ( + ", ".join(util.text_type(elem) for elem in right)) + + if operator is self.INDEX: + operator = self.ARRAYIDX + elif operator is self.ASTEXT: + operator = self.ASTEXT_ARRAYIDX + + self._json_opstring = operator.opstring + self._astext = operator in self._ASTEXT_OPS + self._isarrayidx = operator in self._ARRIDX_OPS + right = default_comparator._check_literal( left, operator, right) super(JSONElement, self).__init__( @@ -71,8 +88,8 @@ class JSONElement(elements.BinaryExpression): return JSONElement( self.left, self.right, - astext=True, - opstring=self._json_opstring + ">", + self.ASTEXT_ARRAYIDX if self.operator is self.ARRAYIDX + else self.ASTEXT, result_type=sqltypes.String(convert_unicode=True) ) @@ -190,7 +207,9 @@ class JSON(sqltypes.TypeEngine): def __getitem__(self, other): """Get the value at a given key.""" - return JSONElement(self.expr, other) + return JSONElement( + self.expr, other, JSONElement.INDEX, + result_type=self.expr.type) def _adapt_expression(self, op, other_comparator): if isinstance(op, custom_op): @@ -309,14 +328,9 @@ class JSONB(JSON): __visit_name__ = 'JSONB' - class comparator_factory(sqltypes.Concatenable.Comparator): + class comparator_factory(JSON.comparator_factory): """Define comparison operations for :class:`.JSON`.""" - def __getitem__(self, other): - """Get the value at a given key.""" - - return JSONElement(self.expr, other) - def _adapt_expression(self, op, other_comparator): # How does one do equality?? jsonb also has "=" eg. # '[1,2,3]'::jsonb = '[1,2,3]'::jsonb diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index a44c308eb..046905ac7 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -2799,6 +2799,11 @@ class BinaryExpression(ColumnElement): return super(BinaryExpression, self)._negate() +class IndexExpression(BinaryExpression): + """Represent the class of expressions that are like an "index" operation. + """ + + class Grouping(ColumnElement): """Represent a grouping within a column expression""" diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 17a9d3086..a2778c7c4 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -214,10 +214,13 @@ class custom_op(object): """ __name__ = 'custom_op' - def __init__(self, opstring, precedence=0, is_comparison=False): + def __init__( + self, opstring, precedence=0, is_comparison=False, + natural_self_precedent=False): self.opstring = opstring self.precedence = precedence self.is_comparison = is_comparison + self.natural_self_precedent = natural_self_precedent def __eq__(self, other): return isinstance(other, custom_op) and \ @@ -826,6 +829,11 @@ def is_ordering_modifier(op): return op in (asc_op, desc_op, nullsfirst_op, nullslast_op) + +def is_natural_self_precedent(op): + return op in _natural_self_precedent or \ + isinstance(op, custom_op) and op.natural_self_precedent + _associative = _commutative.union([concat_op, and_, or_]) _natural_self_precedent = _associative.union([getitem]) @@ -893,7 +901,7 @@ _PRECEDENCE = { def is_precedent(operator, against): - if operator is against and operator in _natural_self_precedent: + if operator is against and is_natural_self_precedent(operator): return False else: return (_PRECEDENCE.get(operator, |