diff options
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/array.py | 19 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/default_comparator.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 1 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_types.py | 21 |
4 files changed, 32 insertions, 12 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index b88f139de..8492f31a8 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -84,12 +84,19 @@ class array(expression.Tuple): super(array, self).__init__(*clauses, **kw) self.type = ARRAY(self.type) - def _bind_param(self, operator, obj): - return array([ - expression.BindParameter(None, o, _compared_to_operator=operator, - _compared_to_type=self.type, unique=True) - for o in obj - ]) + def _bind_param(self, operator, obj, _assume_scalar=False): + if _assume_scalar or operator is operators.getitem: + # if getitem->slice were called, Indexable produces + # a Slice object from that + assert isinstance(obj, int) + return expression.BindParameter( + None, obj, _compared_to_operator=operator, + _compared_to_type=self.type, unique=True) + + else: + return array([ + self._bind_param(operator, o, _assume_scalar=True) + for o in obj]) def self_group(self, against=None): if (against in ( diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 5504699bc..8b178d5ca 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -164,8 +164,7 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): def _getitem_impl(expr, op, other, **kw): if isinstance(expr.type, type_api.INDEXABLE): - other = _literal_as_binds( - other, name=expr.key, type_=type_api.INTEGERTYPE) + other = _check_literal(expr, op, other) return _binary_operate(expr, op, other, **kw) else: _unsupported_impl(expr, op, other, **kw) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index ca354a58a..552e23285 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -1607,7 +1607,6 @@ class Array(Indexable, Concatenable, TypeEngine): index.stop + 1, index.step ) - index = Slice( _literal_as_binds( index.start, name=self.expr.key, diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 408dedad4..1ba73eea6 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -892,7 +892,7 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase): type_=postgresql.ARRAY(Integer) )[3], "(array_cat(ARRAY[%(param_1)s, %(param_2)s, %(param_3)s], " - "ARRAY[%(param_4)s, %(param_5)s, %(param_6)s]))[%(param_7)s]" + "ARRAY[%(param_4)s, %(param_5)s, %(param_6)s]))[%(array_cat_1)s]" ) def test_array_agg_generic(self): @@ -1811,7 +1811,7 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): def test_where_getitem(self): self._test_where( self.hashcol['bar'] == None, - "(test_table.hash -> %(hash_1)s) IS NULL" + "test_table.hash -> %(hash_1)s IS NULL" ) def test_cols_get(self): @@ -1864,7 +1864,7 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): hstore(postgresql.array(['1', '2']), postgresql.array(['3', None]))['1'], ("hstore(ARRAY[%(param_1)s, %(param_2)s], " - "ARRAY[%(param_3)s, NULL]) -> %(param_4)s AS anon_1"), + "ARRAY[%(param_3)s, NULL]) -> %(hstore_1)s AS anon_1"), False ) @@ -1980,6 +1980,21 @@ class HStoreRoundTripTest(fixtures.TablesTest): cols = insp.get_columns('data_table') assert isinstance(cols[2]['type'], HSTORE) + def test_literal_round_trip(self): + # in particular, this tests that the array index + # operator against the function is handled by PG; with some + # array functions it requires outer parenthezisation on the left and + # we may not be doing that here + expr = hstore( + postgresql.array(['1', '2']), + postgresql.array(['3', None]))['1'] + eq_( + testing.db.scalar( + select([expr]) + ), + "3" + ) + @testing.requires.psycopg2_native_hstore def test_insert_native(self): engine = testing.db |
