author    Allan Haldane <allan.haldane@gmail.com>  2019-08-20 16:42:04 -0400
committer Allan Haldane <allan.haldane@gmail.com>  2019-08-23 13:15:40 -0400
commit    63ecfb884d23a3edcddee55b2bc64582cf8db757 (patch)
tree      ed50d7bd0be55c79f9f39bc7441ed6306fc8feb8 /numpy/lib/recfunctions.py
parent    950dd4e15ea0976bb671148440e036c3ae2dc11d (diff)
MAINT: fix behavior of structured_to_unstructured on non-trivial dtypes
Fixes #13333
Diffstat (limited to 'numpy/lib/recfunctions.py')
-rw-r--r--  numpy/lib/recfunctions.py  46
1 file changed, 39 insertions(+), 7 deletions(-)
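
For context, a minimal sketch of the kind of call this commit is about: numpy.lib.recfunctions.structured_to_unstructured applied to a "non-trivial" structured dtype, here one mixing a nested struct with a subarray field. The dtype and values below are made up for illustration; only the public API names visible in the diff are taken from the source.

import numpy as np
from numpy.lib import recfunctions as rfn

# Illustrative "non-trivial" dtype: a nested struct plus a subarray field.
point = np.dtype([('x', 'f8'), ('y', 'f8')])
dt = np.dtype([('pos', point), ('vel', 'f4', (3,))])

a = np.zeros(2, dtype=dt)
a['pos']['x'] = [1.0, 2.0]
a['vel'] = [[1, 2, 3], [4, 5, 6]]

# Flatten all scalar fields (nested and subarray elements included) into a
# plain 2-D array; with 2 + 3 scalar elements per record this should be (2, 5).
u = rfn.structured_to_unstructured(a)
print(u.shape)
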
diff --git a/numpy/lib/recfunctions.py b/numpy/lib/recfunctions.py
index 6e257bb3f..c17c39c8a 100644
--- a/numpy/lib/recfunctions.py
+++ b/numpy/lib/recfunctions.py
@@ -874,16 +874,35 @@ def _get_fields_and_offsets(dt, offset=0):
     scalar fields in the dtype "dt", including nested fields, in left
     to right order.
     """
+
+    # counts up elements in subarrays, including nested subarrays, and returns
+    # base dtype and count
+    def count_elem(dt):
+        count = 1
+        while dt.shape != ():
+            for size in dt.shape:
+                count *= size
+            dt = dt.base
+        return dt, count
+
     fields = []
     for name in dt.names:
         field = dt.fields[name]
-        if field[0].names is None:
-            count = 1
-            for size in field[0].shape:
-                count *= size
-            fields.append((field[0], count, field[1] + offset))
+        f_dt, f_offset = field[0], field[1]
+        f_dt, n = count_elem(f_dt)
+
+        if f_dt.names is None:
+            fields.append((np.dtype((f_dt, (n,))), n, f_offset + offset))
         else:
-            fields.extend(_get_fields_and_offsets(field[0], field[1] + offset))
+            subfields = _get_fields_and_offsets(f_dt, f_offset + offset)
+            size = f_dt.itemsize
+
+            for i in range(n):
+                if i == 0:
+                    # optimization: avoid list comprehension if no subarray
+                    fields.extend(subfields)
+                else:
+                    fields.extend([(d, c, o + i*size) for d, c, o in subfields])
     return fields
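# Aside (not part of the patch): a standalone sanity check of the count_elem
# helper added above, copied verbatim from the hunk; the dtypes below are just
# example inputs.
import numpy as np

def count_elem(dt):
    # multiply out nested subarray shapes until only the base dtype is left
    count = 1
    while dt.shape != ():
        for size in dt.shape:
            count *= size
        dt = dt.base
    return dt, count

print(count_elem(np.dtype('f8')))            # expect (dtype('float64'), 1)
print(count_elem(np.dtype(('f8', (2, 3)))))  # expect (dtype('float64'), 6)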
@@ -948,6 +967,12 @@ def structured_to_unstructured(arr, dtype=None, copy=False, casting='unsafe'):
     fields = _get_fields_and_offsets(arr.dtype)
     n_fields = len(fields)
+    if n_fields == 0 and dtype is None:
+        raise ValueError("arr has no fields. Unable to guess dtype")
+    elif n_fields == 0:
+        # too many bugs elsewhere for this to work now
+        raise NotImplementedError("arr with no fields is not supported")
+
     dts, counts, offsets = zip(*fields)
     names = ['f{}'.format(n) for n in range(n_fields)]
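# Aside (not part of the patch): what the new guard above means at the call
# site, sketched with a made-up zero-field array.
import numpy as np
from numpy.lib import recfunctions as rfn

empty = np.zeros(3, dtype=np.dtype([]))   # structured dtype with no fields

try:
    rfn.structured_to_unstructured(empty)
except ValueError as e:
    print(e)   # with this patch: "arr has no fields. Unable to guess dtype"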
@@ -1039,6 +1064,9 @@ def unstructured_to_structured(arr, dtype=None, names=None, align=False,
     if arr.shape == ():
         raise ValueError('arr must have at least one dimension')
     n_elem = arr.shape[-1]
+    if n_elem == 0:
+        # too many bugs elsewhere for this to work now
+        raise NotImplementedError("last axis with size 0 is not supported")
 
     if dtype is None:
         if names is None:
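# Aside (not part of the patch): the corresponding guard in
# unstructured_to_structured, sketched with a made-up size-0 last axis.
import numpy as np
from numpy.lib import recfunctions as rfn

try:
    rfn.unstructured_to_structured(np.zeros((4, 0)))
except NotImplementedError as e:
    print(e)   # with this patch: "last axis with size 0 is not supported"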
@@ -1051,7 +1079,11 @@ def unstructured_to_structured(arr, dtype=None, names=None, align=False,
raise ValueError("don't supply both dtype and names")
# sanity check of the input dtype
fields = _get_fields_and_offsets(dtype)
- dts, counts, offsets = zip(*fields)
+ if len(fields) == 0:
+ dts, counts, offsets = [], [], []
+ else:
+ dts, counts, offsets = zip(*fields)
+
if n_elem != sum(counts):
raise ValueError('The length of the last dimension of arr must '
'be equal to the number of fields in dtype')
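
Taken together, a hedged round-trip sketch of the two public functions touched by this commit; the dtype and values are invented for illustration, and the shape comments describe expected rather than verified output.

import numpy as np
from numpy.lib import recfunctions as rfn

dt = np.dtype([('a', 'f8'), ('b', 'f8', (2,))])
s = np.array([(1.0, [2.0, 3.0]), (4.0, [5.0, 6.0])], dtype=dt)

u = rfn.structured_to_unstructured(s)       # 1 + 2 scalar fields -> shape (2, 3)
s2 = rfn.unstructured_to_structured(u, dt)  # back to the structured dtype

assert u.shape == (2, 3)
assert s2.dtype == dt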