diff options
author | Pauli Virtanen <pav@iki.fi> | 2010-04-04 20:21:03 +0000 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2010-04-04 20:21:03 +0000 |
commit | 86d3b8181a44ab3394595c0323691d250c77d44a (patch) | |
tree | ece9fb3d9fa7a4e0a08c646b0a6df82ac906a523 | |
parent | b5967d0073b6eb2305ccfec07809e48d4134d7b4 (diff) | |
download | numpy-86d3b8181a44ab3394595c0323691d250c77d44a.tar.gz |
ENH: core: improve the way trailing padding is dealed with in PEP 3118 format strings
-rw-r--r-- | numpy/core/_internal.py | 69 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 39 |
2 files changed, 92 insertions, 16 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index 9463b5099..8cb5a82ac 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -485,10 +485,12 @@ def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False): else: raise ValueError("Unknown PEP 3118 data type specifier %r" % spec) + # # Native alignment may require padding # - # XXX: here we assume that the presence of a '@' character implies - # that the start of the array is *also* aligned. + # Here we assume that the presence of a '@' character implicitly implies + # that the start of the array is *already* aligned. + # extra_offset = 0 if byteorder == '@': start_padding = (-offset) % align @@ -497,10 +499,12 @@ def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False): offset += start_padding if intra_padding != 0: - if itemsize > 1 or shape is not None: - value = dtype([('f0', value), - ('pad', '%dV' % intra_padding)]) + if itemsize > 1 or (shape is not None and _prod(shape) > 1): + # Inject internal padding to the end of the sub-item + value = _add_trailing_padding(value, intra_padding) else: + # We can postpone the injection of internal padding, + # as the item appears at most once extra_offset += intra_padding # Update common alignment @@ -531,31 +535,74 @@ def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False): raise RuntimeError("Duplicate field name '%s' in PEP3118 format" % name) fields[name] = (value, offset) + last_offset = offset if not this_explicit_name: next_dummy_name() - last_offset = offset byteorder = next_byteorder offset += value.itemsize offset += extra_offset - if is_padding and not this_explicit_name: - # Trailing padding must be made explicit - name = get_dummy_name() - fields[name] = ('V%d' % (offset - last_offset), last_offset) - + # Check if this was a simple 1-item type if len(fields.keys()) == 1 and not explicit_name and fields['f0'][1] == 0 \ and not is_subdtype: ret = fields['f0'][0] else: ret = dtype(fields) + # Trailing padding must be explicitly added + padding = offset - ret.itemsize + if byteorder == '@': + padding += (-offset) % common_alignment + if is_padding and not this_explicit_name: + ret = _add_trailing_padding(ret, padding) + + # Finished if is_subdtype: return ret, spec, common_alignment, byteorder else: return ret +def _add_trailing_padding(value, padding): + """Inject the specified number of padding bytes at the end of a dtype""" + from numpy.core.multiarray import dtype + + if value.fields is None: + vfields = {'f0': (value, 0)} + else: + vfields = dict(value.fields) + + if value.names and value.names[-1] == '' and \ + value[''].char == 'V': + # A trailing padding field is already present + vfields[''] = ('V%d' % (vfields[''][0].itemsize + padding), + vfields[''][1]) + value = dtype(vfields) + else: + # Get a free name for the padding field + j = 0 + while True: + name = 'pad%d' % j + if name not in vfields: + vfields[name] = ('V%d' % padding, value.itemsize) + break + j += 1 + + value = dtype(vfields) + if '' not in vfields: + # Strip out the name of the padding field + names = list(value.names) + names[-1] = '' + value.names = tuple(names) + return value + +def _prod(a): + p = 1 + for x in a: + p *= x + return p + def _gcd(a, b): """Calculate the greatest common divisor of a and b""" while b: diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index f2974a82d..f3b691b68 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -1460,7 +1460,13 @@ if sys.version_info >= (2, 6): class TestPEP3118Dtype(object): def _check(self, spec, wanted): - assert_equal(_dtype_from_pep3118(spec), np.dtype(wanted), + dt = np.dtype(wanted) + if isinstance(wanted, list) and isinstance(wanted[-1], tuple): + if wanted[-1][0] == '': + names = list(dt.names) + names[-1] = '' + dt.names = tuple(names) + assert_equal(_dtype_from_pep3118(spec), dt, err_msg="spec %r != dtype %r" % (spec, wanted)) def test_native_padding(self): @@ -1481,14 +1487,37 @@ if sys.version_info >= (2, 6): self._check('^x3T{xi}', {'f0': (({'f0': ('i', 1)}, (3,)), 1)}) def test_trailing_padding(self): - # Trailing padding should be included - self._check('ix', [('f0', 'i'), ('f1', 'V1')]) + # Trailing padding should be included, *and*, the item size + # should match the alignment if in aligned mode + align = np.dtype('i').alignment + def VV(n): + return 'V%d' % (align*(1 + (n-1)//align)) + + self._check('ix', [('f0', 'i'), ('', VV(1))]) + self._check('ixx', [('f0', 'i'), ('', VV(2))]) + self._check('ixxx', [('f0', 'i'), ('', VV(3))]) + self._check('ixxxx', [('f0', 'i'), ('', VV(4))]) + self._check('i7x', [('f0', 'i'), ('', VV(7))]) + + self._check('^ix', [('f0', 'i'), ('', 'V1')]) + self._check('^ixx', [('f0', 'i'), ('', 'V2')]) + self._check('^ixxx', [('f0', 'i'), ('', 'V3')]) + self._check('^ixxxx', [('f0', 'i'), ('', 'V4')]) + self._check('^i7x', [('f0', 'i'), ('', 'V7')]) def test_byteorder_inside_struct(self): # The byte order after @T{=i} should be '=', not '@'. # Check this by noting the absence of native alignment. - self._check('@T{^i}xi', {'f0': ({'f0': (np.int32, 0)}, 0), - 'f1': (np.int32, 5)}) + self._check('@T{^i}xi', {'f0': ({'f0': ('i', 0)}, 0), + 'f1': ('i', 5)}) + + def test_intra_padding(self): + # Natively aligned sub-arrays may require some internal padding + align = np.dtype('i').alignment + def VV(n): + return 'V%d' % (align*(1 + (n-1)//align)) + + self._check('(3)T{ix}', ({'f0': ('i', 0), '': (VV(1), 4)}, (3,))) class TestNewBufferProtocol(object): def _check_roundtrip(self, obj): |