diff options
Diffstat (limited to 'numpy/lib')
| -rw-r--r-- | numpy/lib/index_tricks.py | 14 | ||||
| -rw-r--r-- | numpy/lib/tests/test_index_tricks.py | 27 |
2 files changed, 34 insertions, 7 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index cba713ede..2833e1072 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -154,15 +154,15 @@ class nd_grid: start = 0 if step is None: step = 1 - if isinstance(step, complex): + if isinstance(step, (_nx.complexfloating, complex)): size.append(int(abs(step))) typ = float else: size.append( int(math.ceil((key[k].stop - start)/(step*1.0)))) - if (isinstance(step, float) or - isinstance(start, float) or - isinstance(key[k].stop, float)): + if (isinstance(step, (_nx.floating, float)) or + isinstance(start, (_nx.floating, float)) or + isinstance(key[k].stop, (_nx.floating, float))): typ = float if self.sparse: nn = [_nx.arange(_x, dtype=_t) @@ -176,7 +176,7 @@ class nd_grid: start = 0 if step is None: step = 1 - if isinstance(step, complex): + if isinstance(step, (_nx.complexfloating, complex)): step = int(abs(step)) if step != 1: step = (key[k].stop - start)/float(step-1) @@ -194,7 +194,7 @@ class nd_grid: start = key.start if start is None: start = 0 - if isinstance(step, complex): + if isinstance(step, (_nx.complexfloating, complex)): step = abs(step) length = int(step) if step != 1: @@ -344,7 +344,7 @@ class AxisConcatenator: start = 0 if step is None: step = 1 - if isinstance(step, complex): + if isinstance(step, (_nx.complexfloating, complex)): size = int(abs(step)) newobj = linspace(start, stop, num=size) else: diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py index 905165a99..843e27cef 100644 --- a/numpy/lib/tests/test_index_tricks.py +++ b/numpy/lib/tests/test_index_tricks.py @@ -249,6 +249,29 @@ class TestGrid: assert_equal(grid.size, expected[0]) assert_equal(grid_small.size, expected[1]) + def test_accepts_npfloating(self): + # regression test for #16466 + grid64 = mgrid[0.1:0.33:0.1, ] + grid32 = mgrid[np.float32(0.1):np.float32(0.33):np.float32(0.1), ] + assert_(grid32.dtype == np.float64) + assert_array_almost_equal(grid64, grid32) + + # different code path for single slice + grid64 = mgrid[0.1:0.33:0.1] + grid32 = mgrid[np.float32(0.1):np.float32(0.33):np.float32(0.1)] + assert_(grid32.dtype == np.float64) + assert_array_almost_equal(grid64, grid32) + + def test_accepts_npcomplexfloating(self): + # Related to #16466 + assert_array_almost_equal( + mgrid[0.1:0.3:3j, ], mgrid[0.1:0.3:np.complex64(3j), ] + ) + + # different code path for single slice + assert_array_almost_equal( + mgrid[0.1:0.3:3j], mgrid[0.1:0.3:np.complex64(3j)] + ) class TestConcatenator: def test_1d(self): @@ -270,6 +293,10 @@ class TestConcatenator: g = r_[0:36:100j] assert_(g.shape == (100,)) + # Related to #16466 + g = r_[0:36:np.complex64(100j)] + assert_(g.shape == (100,)) + def test_2d(self): b = np.random.rand(5, 5) c = np.random.rand(5, 5) |
