diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2015-05-05 09:46:22 -0400 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2015-05-05 09:46:22 -0400 |
commit | 09e38f9e2f3a81a9a4c98f36780e791f60c665b1 (patch) | |
tree | cbf65a059d4e923201a20cca0780b63ba7f215d2 | |
parent | 7b0b7cc601d1839657aca03a01056101beb64a68 (diff) | |
parent | c01165f43068fea96722c172eb23efed4ca99763 (diff) | |
download | numpy-09e38f9e2f3a81a9a4c98f36780e791f60c665b1.tar.gz |
Merge pull request #5805 from jaimefrio/ix_intp
Ix intp
-rw-r--r-- | numpy/lib/index_tricks.py | 20 | ||||
-rw-r--r-- | numpy/lib/tests/test_index_tricks.py | 30 |
2 files changed, 39 insertions, 11 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index e97338106..752407f18 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -7,7 +7,7 @@ import numpy.core.numeric as _nx from numpy.core.numeric import ( asarray, ScalarType, array, alltrue, cumprod, arange ) -from numpy.core.numerictypes import find_common_type +from numpy.core.numerictypes import find_common_type, issubdtype from . import function_base import numpy.matrixlib as matrix @@ -71,17 +71,17 @@ def ix_(*args): """ out = [] nd = len(args) - baseshape = [1]*nd - for k in range(nd): - new = _nx.asarray(args[k]) - if (new.ndim != 1): + for k, new in enumerate(args): + new = asarray(new) + if new.ndim != 1: raise ValueError("Cross index must be 1 dimensional") - if issubclass(new.dtype.type, _nx.bool_): - new = new.nonzero()[0] - baseshape[k] = len(new) - new = new.reshape(tuple(baseshape)) + if new.size == 0: + # Explicitly type empty arrays to avoid float default + new = new.astype(_nx.intp) + if issubdtype(new.dtype, _nx.bool_): + new, = new.nonzero() + new.shape = (1,)*k + (new.size,) + (1,)*(nd-k-1) out.append(new) - baseshape[k] = 1 return tuple(out) class nd_grid(object): diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py index 97047c53a..0e3c98ee1 100644 --- a/numpy/lib/tests/test_index_tricks.py +++ b/numpy/lib/tests/test_index_tricks.py @@ -7,7 +7,7 @@ from numpy.testing import ( ) from numpy.lib.index_tricks import ( mgrid, ndenumerate, fill_diagonal, diag_indices, diag_indices_from, - index_exp, ndindex, r_, s_ + index_exp, ndindex, r_, s_, ix_ ) @@ -169,6 +169,34 @@ class TestIndexExpression(TestCase): assert_equal(a[:, :3, [1, 2]], a[s_[:, :3, [1, 2]]]) +class TestIx_(TestCase): + def test_regression_1(self): + # Test empty inputs create ouputs of indexing type, gh-5804 + # Test both lists and arrays + for func in (range, np.arange): + a, = np.ix_(func(0)) + assert_equal(a.dtype, np.intp) + + def test_shape_and_dtype(self): + sizes = (4, 5, 3, 2) + # Test both lists and arrays + for func in (range, np.arange): + arrays = np.ix_(*[func(sz) for sz in sizes]) + for k, (a, sz) in enumerate(zip(arrays, sizes)): + assert_equal(a.shape[k], sz) + assert_(all(sh == 1 for j, sh in enumerate(a.shape) if j != k)) + assert_(np.issubdtype(a.dtype, int)) + + def test_bool(self): + bool_a = [True, False, True, True] + int_a, = np.nonzero(bool_a) + assert_equal(np.ix_(bool_a)[0], int_a) + + def test_1d_only(self): + idx2d = [[1, 2, 3], [4, 5, 6]] + assert_raises(ValueError, np.ix_, idx2d) + + def test_c_(): a = np.c_[np.array([[1, 2, 3]]), 0, 0, np.array([[4, 5, 6]])] assert_equal(a, [[1, 2, 3, 0, 0, 4, 5, 6]]) |