summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2017-05-24 14:11:34 -0400
committerGitHub <noreply@github.com>2017-05-24 14:11:34 -0400
commit4f4b0a43c2c5c1441a507a214045a50a1bfc03ed (patch)
tree830ec571d6c0fbe8b1d9c3cdba72133e5d498514
parentd0c15fab8e3c0324914baede7496810485eb4e56 (diff)
parent3c0b4db0b51ec4f4b67f63d647d8f8a668b7e754 (diff)
downloadnumpy-4f4b0a43c2c5c1441a507a214045a50a1bfc03ed.tar.gz
Merge pull request #9164 from pitrou/as_strided_custom_dtype
BUG: have as_strided() keep custom dtypes
-rw-r--r--numpy/lib/stride_tricks.py7
-rw-r--r--numpy/lib/tests/test_stride_tricks.py8
2 files changed, 11 insertions, 4 deletions
diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py
index 545623c38..6c240db7f 100644
--- a/numpy/lib/stride_tricks.py
+++ b/numpy/lib/stride_tricks.py
@@ -100,10 +100,9 @@ def as_strided(x, shape=None, strides=None, subok=False, writeable=True):
interface['strides'] = tuple(strides)
array = np.asarray(DummyArray(interface, base=x))
-
- if array.dtype.fields is None and x.dtype.fields is not None:
- # This should only happen if x.dtype is [('', 'Vx')]
- array.dtype = x.dtype
+ # The route via `__interface__` does not preserve structured
+ # dtypes. Since dtype should remain unchanged, we set it explicitly.
+ array.dtype = x.dtype
view = _maybe_view_as_subclass(x, array)
diff --git a/numpy/lib/tests/test_stride_tricks.py b/numpy/lib/tests/test_stride_tricks.py
index 7dc3c4d24..0599324d7 100644
--- a/numpy/lib/tests/test_stride_tricks.py
+++ b/numpy/lib/tests/test_stride_tricks.py
@@ -1,6 +1,7 @@
from __future__ import division, absolute_import, print_function
import numpy as np
+from numpy.core.test_rational import rational
from numpy.testing import (
run_module_suite, assert_equal, assert_array_equal,
assert_raises, assert_
@@ -317,6 +318,13 @@ def test_as_strided():
a_view = as_strided(a, shape=(3, 4), strides=(0, a.itemsize))
assert_equal(a.dtype, a_view.dtype)
+ # Custom dtypes should not be lost (gh-9161)
+ r = [rational(i) for i in range(4)]
+ a = np.array(r, dtype=rational)
+ a_view = as_strided(a, shape=(3, 4), strides=(0, a.itemsize))
+ assert_equal(a.dtype, a_view.dtype)
+ assert_array_equal([r] * 3, a_view)
+
def as_strided_writeable():
arr = np.ones(10)
view = as_strided(arr, writeable=False)