summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/recfunctions.py78
-rw-r--r--numpy/lib/tests/test_recfunctions.py16
2 files changed, 92 insertions, 2 deletions
diff --git a/numpy/lib/recfunctions.py b/numpy/lib/recfunctions.py
index c455bd93f..b6453d5a2 100644
--- a/numpy/lib/recfunctions.py
+++ b/numpy/lib/recfunctions.py
@@ -732,6 +732,84 @@ def rec_append_fields(base, names, data, dtypes=None):
return append_fields(base, names, data=data, dtypes=dtypes,
asrecarray=True, usemask=False)
+def repack_fields(a, align=False, recurse=False):
+ """
+ Re-pack the fields of a structured array or dtype in memory.
+
+ The memory layout of structured datatypes allows fields at arbitrary
+ byte offsets. This means the fields can be separated by padding bytes,
+ their offsets can be non-monotonically increasing, and they can overlap.
+
+ This method removes any overlaps and reorders the fields in memory so they
+ have increasing byte offsets, and adds or removes padding bytes depending
+ on the `align` option, which behaves like the `align` option to `np.dtype`.
+
+ If `align=False`, this method produces a "packed" memory layout in which
+ each field starts at the byte the previous field ended, and any padding
+ bytes are removed.
+
+ If `align=True`, this methods produces an "aligned" memory layout in which
+ each field's offset is a multiple of its alignment, and the total itemsize
+ is a multiple of the largest alignment, by adding padding bytes as needed.
+
+ Parameters
+ ----------
+ a : ndarray or dtype
+ array or dtype for which to repack the fields.
+ align : boolean
+ If true, use an "aligned" memory layout, otherwise use a "packed" layout.
+ recurse : boolean
+ If True, also repack nested structures.
+
+ Returns
+ -------
+ repacked : ndarray or dtype
+ Copy of `a` with fields repacked, or `a` itself if no repacking was
+ needed.
+
+ Examples
+ --------
+
+ >>> def print_offsets(d):
+ ... print("offsets:", [d.fields[name][1] for name in d.names])
+ ... print("itemsize:", d.itemsize)
+ ...
+ >>> dt = np.dtype('u1,i4,f4', align=True)
+ >>> dt
+ dtype({'names':['f0','f1','f2'], 'formats':['u1','<i4','<f8'], 'offsets':[0,4,8], 'itemsize':16}, align=True)
+ >>> print_offsets(dt)
+ offsets: [0, 4, 8]
+ itemsize: 16
+ >>> packed_dt = repack_fields(dt)
+ >>> packed_dt
+ dtype([('f0', 'u1'), ('f1', '<i4'), ('f2', '<f8')])
+ >>> print_offsets(packed_dt)
+ offsets: [0, 1, 5]
+ itemsize: 13
+
+ """
+ if not isinstance(a, np.dtype):
+ dt = repack_fields(a.dtype, align=align, recurse=recurse)
+ return a.astype(dt, copy=False)
+
+ if a.names is None:
+ return a
+
+ fieldinfo = []
+ for name in a.names:
+ tup = a.fields[name]
+ if recurse:
+ fmt = repack_fields(tup[0], align=align, recurse=True)
+ else:
+ fmt = tup[0]
+
+ if len(tup) == 3:
+ name = (tup[2], name)
+
+ fieldinfo.append((name, fmt))
+
+ dt = np.dtype(fieldinfo, align=align)
+ return np.dtype((a.type, dt))
def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False,
autoconvert=False):
diff --git a/numpy/lib/tests/test_recfunctions.py b/numpy/lib/tests/test_recfunctions.py
index 219ae24fa..d4828bc1f 100644
--- a/numpy/lib/tests/test_recfunctions.py
+++ b/numpy/lib/tests/test_recfunctions.py
@@ -9,8 +9,8 @@ from numpy.ma.testutils import assert_equal
from numpy.testing import assert_, assert_raises
from numpy.lib.recfunctions import (
drop_fields, rename_fields, get_fieldstructure, recursive_fill_fields,
- find_duplicates, merge_arrays, append_fields, stack_arrays, join_by
- )
+ find_duplicates, merge_arrays, append_fields, stack_arrays, join_by,
+ repack_fields)
get_names = np.lib.recfunctions.get_names
get_names_flat = np.lib.recfunctions.get_names_flat
zip_descr = np.lib.recfunctions.zip_descr
@@ -192,6 +192,18 @@ class TestRecFunctions(object):
assert_equal(sorted(test[-1]), control)
assert_equal(test[0], a[test[-1]])
+ def test_repack_fields(self):
+ dt = np.dtype('u1,f4,i8', align=True)
+ a = np.zeros(2, dtype=dt)
+
+ assert_equal(repack_fields(dt), np.dtype('u1,f4,i8'))
+ assert_equal(repack_fields(a).itemsize, 13)
+ assert_equal(repack_fields(repack_fields(dt), align=True), dt)
+
+ # make sure type is preserved
+ dt = np.dtype((np.record, dt))
+ assert_(repack_fields(dt).type is np.record)
+
class TestRecursiveFillFields(object):
# Test recursive_fill_fields.