summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Pauli Virtanen <pav@iki.fi> 2010-04-04 20:21:03 +0000
committer Pauli Virtanen <pav@iki.fi> 2010-04-04 20:21:03 +0000
commit 86d3b8181a44ab3394595c0323691d250c77d44a (patch)
tree ece9fb3d9fa7a4e0a08c646b0a6df82ac906a523
parent b5967d0073b6eb2305ccfec07809e48d4134d7b4 (diff)
download numpy-86d3b8181a44ab3394595c0323691d250c77d44a.tar.gz
ENH: core: improve the way trailing padding is dealt with in PEP 3118 format strings
-rw-r--r-- numpy/core/_internal.py 69
-rw-r--r-- numpy/core/tests/test_multiarray.py 39
2 files changed, 92 insertions, 16 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index 9463b5099..8cb5a82ac 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -485,10 +485,12 @@ def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False):
else:
raise ValueError("Unknown PEP 3118 data type specifier %r" % spec)
+ #
# Native alignment may require padding
#
- # XXX: here we assume that the presence of a '@' character implies
- # that the start of the array is *also* aligned.
+ # Here we assume that the presence of a '@' character implicitly implies
+ # that the start of the array is *already* aligned.
+ #
extra_offset = 0
if byteorder == '@':
start_padding = (-offset) % align
@@ -497,10 +499,12 @@ def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False):
offset += start_padding
if intra_padding != 0:
- if itemsize > 1 or shape is not None:
- value = dtype([('f0', value),
- ('pad', '%dV' % intra_padding)])
+ if itemsize > 1 or (shape is not None and _prod(shape) > 1):
+ # Inject internal padding to the end of the sub-item
+ value = _add_trailing_padding(value, intra_padding)
else:
+ # We can postpone the injection of internal padding,
+ # as the item appears at most once
extra_offset += intra_padding
# Update common alignment
@@ -531,31 +535,74 @@ def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False):
raise RuntimeError("Duplicate field name '%s' in PEP3118 format"
% name)
fields[name] = (value, offset)
+ last_offset = offset
if not this_explicit_name:
next_dummy_name()
- last_offset = offset
byteorder = next_byteorder
offset += value.itemsize
offset += extra_offset
- if is_padding and not this_explicit_name:
- # Trailing padding must be made explicit
- name = get_dummy_name()
- fields[name] = ('V%d' % (offset - last_offset), last_offset)
-
+ # Check if this was a simple 1-item type
if len(fields.keys()) == 1 and not explicit_name and fields['f0'][1] == 0 \
and not is_subdtype:
ret = fields['f0'][0]
else:
ret = dtype(fields)
+ # Trailing padding must be explicitly added
+ padding = offset - ret.itemsize
+ if byteorder == '@':
+ padding += (-offset) % common_alignment
+ if is_padding and not this_explicit_name:
+ ret = _add_trailing_padding(ret, padding)
+
+ # Finished
if is_subdtype:
return ret, spec, common_alignment, byteorder
else:
return ret
+def _add_trailing_padding(value, padding):
+ """Inject the specified number of padding bytes at the end of a dtype"""
+ from numpy.core.multiarray import dtype
+
+ if value.fields is None:
+ vfields = {'f0': (value, 0)}
+ else:
+ vfields = dict(value.fields)
+
+ if value.names and value.names[-1] == '' and \
+ value[''].char == 'V':
+ # A trailing padding field is already present
+ vfields[''] = ('V%d' % (vfields[''][0].itemsize + padding),
+ vfields[''][1])
+ value = dtype(vfields)
+ else:
+ # Get a free name for the padding field
+ j = 0
+ while True:
+ name = 'pad%d' % j
+ if name not in vfields:
+ vfields[name] = ('V%d' % padding, value.itemsize)
+ break
+ j += 1
+
+ value = dtype(vfields)
+ if '' not in vfields:
+ # Strip out the name of the padding field
+ names = list(value.names)
+ names[-1] = ''
+ value.names = tuple(names)
+ return value
+
+def _prod(a):
+ p = 1
+ for x in a:
+ p *= x
+ return p
+
def _gcd(a, b):
"""Calculate the greatest common divisor of a and b"""
while b:
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index f2974a82d..f3b691b68 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -1460,7 +1460,13 @@ if sys.version_info >= (2, 6):
class TestPEP3118Dtype(object):
def _check(self, spec, wanted):
- assert_equal(_dtype_from_pep3118(spec), np.dtype(wanted),
+ dt = np.dtype(wanted)
+ if isinstance(wanted, list) and isinstance(wanted[-1], tuple):
+ if wanted[-1][0] == '':
+ names = list(dt.names)
+ names[-1] = ''
+ dt.names = tuple(names)
+ assert_equal(_dtype_from_pep3118(spec), dt,
err_msg="spec %r != dtype %r" % (spec, wanted))
def test_native_padding(self):
@@ -1481,14 +1487,37 @@ if sys.version_info >= (2, 6):
self._check('^x3T{xi}', {'f0': (({'f0': ('i', 1)}, (3,)), 1)})
def test_trailing_padding(self):
- # Trailing padding should be included
- self._check('ix', [('f0', 'i'), ('f1', 'V1')])
+ # Trailing padding should be included, *and*, the item size
+ # should match the alignment if in aligned mode
+ align = np.dtype('i').alignment
+ def VV(n):
+ return 'V%d' % (align*(1 + (n-1)//align))
+
+ self._check('ix', [('f0', 'i'), ('', VV(1))])
+ self._check('ixx', [('f0', 'i'), ('', VV(2))])
+ self._check('ixxx', [('f0', 'i'), ('', VV(3))])
+ self._check('ixxxx', [('f0', 'i'), ('', VV(4))])
+ self._check('i7x', [('f0', 'i'), ('', VV(7))])
+
+ self._check('^ix', [('f0', 'i'), ('', 'V1')])
+ self._check('^ixx', [('f0', 'i'), ('', 'V2')])
+ self._check('^ixxx', [('f0', 'i'), ('', 'V3')])
+ self._check('^ixxxx', [('f0', 'i'), ('', 'V4')])
+ self._check('^i7x', [('f0', 'i'), ('', 'V7')])
def test_byteorder_inside_struct(self):
# The byte order after @T{=i} should be '=', not '@'.
# Check this by noting the absence of native alignment.
- self._check('@T{^i}xi', {'f0': ({'f0': (np.int32, 0)}, 0),
- 'f1': (np.int32, 5)})
+ self._check('@T{^i}xi', {'f0': ({'f0': ('i', 0)}, 0),
+ 'f1': ('i', 5)})
+
+ def test_intra_padding(self):
+ # Natively aligned sub-arrays may require some internal padding
+ align = np.dtype('i').alignment
+ def VV(n):
+ return 'V%d' % (align*(1 + (n-1)//align))
+
+ self._check('(3)T{ix}', ({'f0': ('i', 0), '': (VV(1), 4)}, (3,)))
class TestNewBufferProtocol(object):
def _check_roundtrip(self, obj):