summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2022-06-09 12:34:35 -0600
committerGitHub <noreply@github.com>2022-06-09 12:34:35 -0600
commit09dc8bcb8b729deb19e40fe9e9361690dc0113ec (patch)
tree34ead172fb598b43167d6596d0e503b1e5b282ae /numpy
parent5447192c8dc5d43d837b93ed5c5d47457e367669 (diff)
parent34ba7d7db0a43060f430e84d072f114faad80993 (diff)
downloadnumpy-09dc8bcb8b729deb19e40fe9e9361690dc0113ec.tar.gz
Merge pull request #16971 from BvB93/nd_grid
BUG: Fix three `complex`- & `float128`-related issues with `nd_grid`
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/index_tricks.py37
-rw-r--r--numpy/lib/tests/test_index_tricks.py34
2 files changed, 54 insertions, 17 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index b69226d48..4f414925d 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -148,24 +148,25 @@ class nd_grid:
def __getitem__(self, key):
try:
size = []
- typ = int
- for kk in key:
- step = kk.step
- start = kk.start
+ # Mimic the behavior of `np.arange` and use a data type
+ # which is at least as large as `np.int_`
+ num_list = [0]
+ for k in range(len(key)):
+ step = key[k].step
+ start = key[k].start
+ stop = key[k].stop
if start is None:
start = 0
if step is None:
step = 1
if isinstance(step, (_nx.complexfloating, complex)):
- size.append(int(abs(step)))
- typ = float
+ step = abs(step)
+ size.append(int(step))
else:
size.append(
- int(math.ceil((kk.stop - start) / (step * 1.0))))
- if (isinstance(step, (_nx.floating, float)) or
- isinstance(start, (_nx.floating, float)) or
- isinstance(kk.stop, (_nx.floating, float))):
- typ = float
+ int(math.ceil((stop - start) / (step*1.0))))
+ num_list += [start, stop, step]
+ typ = _nx.result_type(*num_list)
if self.sparse:
nn = [_nx.arange(_x, dtype=_t)
for _x, _t in zip(size, (typ,)*len(size))]
@@ -197,11 +198,13 @@ class nd_grid:
if start is None:
start = 0
if isinstance(step, (_nx.complexfloating, complex)):
- step = abs(step)
- length = int(step)
+ # Prevent the (potential) creation of integer arrays
+ step_float = abs(step)
+ step = length = int(step_float)
if step != 1:
step = (key.stop-start)/float(step-1)
- return _nx.arange(0, length, 1, float)*step + start
+ typ = _nx.result_type(start, stop, step_float)
+ return _nx.arange(0, length, 1, dtype=typ)*step + start
else:
return _nx.arange(start, stop, step)
@@ -621,7 +624,7 @@ class ndindex:
Parameters
----------
shape : ints, or a single tuple of ints
- The size of each dimension of the array can be passed as
+ The size of each dimension of the array can be passed as
individual parameters or as the elements of a tuple.
See Also
@@ -631,7 +634,7 @@ class ndindex:
Examples
--------
Dimensions as individual arguments
-
+
>>> for index in np.ndindex(3, 2, 1):
... print(index)
(0, 0, 0)
@@ -642,7 +645,7 @@ class ndindex:
(2, 1, 0)
Same dimensions - but in a tuple ``(3, 2, 1)``
-
+
>>> for index in np.ndindex((3, 2, 1)):
... print(index)
(0, 0, 0)
diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py
index 26a34be7e..b599cb345 100644
--- a/numpy/lib/tests/test_index_tricks.py
+++ b/numpy/lib/tests/test_index_tricks.py
@@ -254,6 +254,28 @@ class TestGrid:
assert_(grid32.dtype == np.float64)
assert_array_almost_equal(grid64, grid32)
+ def test_accepts_longdouble(self):
+ # regression tests for #16945
+ grid64 = mgrid[0.1:0.33:0.1, ]
+ grid128 = mgrid[
+ np.longdouble(0.1):np.longdouble(0.33):np.longdouble(0.1),
+ ]
+ assert_(grid128.dtype == np.longdouble)
+ assert_array_almost_equal(grid64, grid128)
+
+ grid128c_a = mgrid[0:np.longdouble(1):3.4j]
+ grid128c_b = mgrid[0:np.longdouble(1):3.4j, ]
+ assert_(grid128c_a.dtype == grid128c_b.dtype == np.longdouble)
+ assert_array_equal(grid128c_a, grid128c_b[0])
+
+ # different code path for single slice
+ grid64 = mgrid[0.1:0.33:0.1]
+ grid128 = mgrid[
+ np.longdouble(0.1):np.longdouble(0.33):np.longdouble(0.1)
+ ]
+ assert_(grid128.dtype == np.longdouble)
+ assert_array_almost_equal(grid64, grid128)
+
def test_accepts_npcomplexfloating(self):
# Related to #16466
assert_array_almost_equal(
@@ -265,6 +287,18 @@ class TestGrid:
mgrid[0.1:0.3:3j], mgrid[0.1:0.3:np.complex64(3j)]
)
+ # Related to #16945
+ grid64_a = mgrid[0.1:0.3:3.3j]
+ grid64_b = mgrid[0.1:0.3:3.3j, ][0]
+ assert_(grid64_a.dtype == grid64_b.dtype == np.float64)
+ assert_array_equal(grid64_a, grid64_b)
+
+ grid128_a = mgrid[0.1:0.3:np.clongdouble(3.3j)]
+ grid128_b = mgrid[0.1:0.3:np.clongdouble(3.3j), ][0]
+ assert_(grid128_a.dtype == grid128_b.dtype == np.longdouble)
+ assert_array_equal(grid64_a, grid64_b)
+
+
class TestConcatenator:
def test_1d(self):
assert_array_equal(r_[1, 2, 3, 4, 5, 6], np.array([1, 2, 3, 4, 5, 6]))