summaryrefslogtreecommitdiff
path: root/numpy/lib/index_tricks.py
diff options
context:
space:
mode:
authorCameron Blocker <cameronjblocker@gmail.com>2020-07-12 03:38:37 -0400
committerCameron Blocker <cameronjblocker@gmail.com>2020-07-12 03:38:37 -0400
commit1cdc9a8f72e41c21c44187701d74133d8ee58412 (patch)
tree4f1d4cfb877fad507195967c35ba41c2691500cb /numpy/lib/index_tricks.py
parent9fd8b2db731d79fabdc40de3a7111381fb4aae5a (diff)
downloadnumpy-1cdc9a8f72e41c21c44187701d74133d8ee58412.tar.gz
BUG: fix mgrid output for lower precision float inputs
Floats besides float64 were being coerced to integers and complex step sizes for the index trick classes would fail for complex64 input. Fixes #16466
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r--numpy/lib/index_tricks.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index e86c7a7bb..8a73e35ed 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -154,15 +154,15 @@ class nd_grid:
start = 0
if step is None:
step = 1
- if isinstance(step, complex):
+ if isinstance(step, (_nx.complexfloating, complex)):
size.append(int(abs(step)))
typ = float
else:
size.append(
int(math.ceil((key[k].stop - start)/(step*1.0))))
- if (isinstance(step, float) or
- isinstance(start, float) or
- isinstance(key[k].stop, float)):
+ if (isinstance(step, (_nx.floating, float)) or
+ isinstance(start, (_nx.floating, float)) or
+ isinstance(key[k].stop, (_nx.floating, float))):
typ = float
if self.sparse:
nn = [_nx.arange(_x, dtype=_t)
@@ -176,7 +176,7 @@ class nd_grid:
start = 0
if step is None:
step = 1
- if isinstance(step, complex):
+ if isinstance(step, (_nx.complexfloating, complex)):
step = int(abs(step))
if step != 1:
step = (key[k].stop - start)/float(step-1)
@@ -194,7 +194,7 @@ class nd_grid:
start = key.start
if start is None:
start = 0
- if isinstance(step, complex):
+ if isinstance(step, (_nx.complexfloating, complex)):
step = abs(step)
length = int(step)
if step != 1:
@@ -344,7 +344,7 @@ class AxisConcatenator:
start = 0
if step is None:
step = 1
- if isinstance(step, complex):
+ if isinstance(step, (_nx.complexfloating, complex)):
size = int(abs(step))
newobj = linspace(start, stop, num=size)
else: