diff options
author | ahaldane <ealloc@gmail.com> | 2017-05-09 12:48:41 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-05-09 12:48:41 -0400 |
commit | 86180665a8ba05dc70144db50024f967ea9ccc78 (patch) | |
tree | eb627a83645f799f82060da8a76c8e9f49b3e516 /numpy/core/_internal.py | |
parent | 5e78b887b3c1a553a9181f059c3369cc59d744e9 (diff) | |
parent | a4f435c68c15bc43a9b09869aaedb94d6c1bee2a (diff) | |
download | numpy-86180665a8ba05dc70144db50024f967ea9ccc78.tar.gz |
Merge pull request #9054 from eric-wieser/fix-pep3118
BUG: Various fixes to _dtype_from_pep3118
Diffstat (limited to 'numpy/core/_internal.py')
-rw-r--r-- | numpy/core/_internal.py | 240 |
1 files changed, 134 insertions, 106 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index c890eba17..f7f579c75 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -432,51 +432,83 @@ _pep3118_standard_map = { } _pep3118_standard_typechars = ''.join(_pep3118_standard_map.keys()) -def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False): - fields = {} +def _dtype_from_pep3118(spec): + + class Stream(object): + def __init__(self, s): + self.s = s + self.byteorder = '@' + + def advance(self, n): + res = self.s[:n] + self.s = self.s[n:] + return res + + def consume(self, c): + if self.s[:len(c)] == c: + self.advance(len(c)) + return True + return False + + def consume_until(self, c): + if callable(c): + i = 0 + while i < len(self.s) and not c(self.s[i]): + i = i + 1 + return self.advance(i) + else: + i = self.s.index(c) + res = self.advance(i) + self.advance(len(c)) + return res + + @property + def next(self): + return self.s[0] + + def __bool__(self): + return bool(self.s) + __nonzero__ = __bool__ + + stream = Stream(spec) + + dtype, align = __dtype_from_pep3118(stream, is_subdtype=False) + return dtype + +def __dtype_from_pep3118(stream, is_subdtype): + field_spec = dict( + names=[], + formats=[], + offsets=[], + itemsize=0 + ) offset = 0 - explicit_name = False - this_explicit_name = False common_alignment = 1 is_padding = False - dummy_name_index = [0] - - def next_dummy_name(): - dummy_name_index[0] += 1 - - def get_dummy_name(): - while True: - name = 'f%d' % dummy_name_index[0] - if name not in fields: - return name - next_dummy_name() - # Parse spec - while spec: + while stream: value = None # End of structure, bail out to upper level - if spec[0] == '}': - spec = spec[1:] + if stream.consume('}'): break # Sub-arrays (1) shape = None - if spec[0] == '(': - j = spec.index(')') - shape = tuple(map(int, spec[1:j].split(','))) - spec = spec[j+1:] + if stream.consume('('): + shape = stream.consume_until(')') + shape = tuple(map(int, shape.split(','))) # Byte order - if spec[0] in ('@', '=', '<', '>', '^', '!'): - byteorder = spec[0] + if stream.next in ('@', '=', '<', '>', '^', '!'): + byteorder = stream.advance(1) if byteorder == '!': byteorder = '>' - spec = spec[1:] + stream.byteorder = byteorder # Byte order characters also control native vs. standard type sizes - if byteorder in ('@', '^'): + if stream.byteorder in ('@', '^'): type_map = _pep3118_native_map type_map_chars = _pep3118_native_typechars else: @@ -484,39 +516,35 @@ def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False): type_map_chars = _pep3118_standard_typechars # Item sizes - itemsize = 1 - if spec[0].isdigit(): - j = 1 - for j in range(1, len(spec)): - if not spec[j].isdigit(): - break - itemsize = int(spec[:j]) - spec = spec[j:] + itemsize_str = stream.consume_until(lambda c: not c.isdigit()) + if itemsize_str: + itemsize = int(itemsize_str) + else: + itemsize = 1 # Data types is_padding = False - if spec[:2] == 'T{': - value, spec, align, next_byteorder = _dtype_from_pep3118( - spec[2:], byteorder=byteorder, is_subdtype=True) - elif spec[0] in type_map_chars: - next_byteorder = byteorder - if spec[0] == 'Z': - j = 2 + if stream.consume('T{'): + value, align = __dtype_from_pep3118( + stream, is_subdtype=True) + elif stream.next in type_map_chars: + if stream.next == 'Z': + typechar = stream.advance(2) else: - j = 1 - typechar = spec[:j] - spec = spec[j:] + typechar = stream.advance(1) + is_padding = (typechar == 'x') dtypechar = type_map[typechar] if dtypechar in 'USV': dtypechar += '%d' % itemsize itemsize = 1 - numpy_byteorder = {'@': '=', '^': '='}.get(byteorder, byteorder) + numpy_byteorder = {'@': '=', '^': '='}.get( + stream.byteorder, stream.byteorder) value = dtype(numpy_byteorder + dtypechar) align = value.alignment else: - raise ValueError("Unknown PEP 3118 data type specifier %r" % spec) + raise ValueError("Unknown PEP 3118 data type specifier %r" % stream.s) # # Native alignment may require padding @@ -525,7 +553,7 @@ def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False): # that the start of the array is *already* aligned. # extra_offset = 0 - if byteorder == '@': + if stream.byteorder == '@': start_padding = (-offset) % align intra_padding = (-value.itemsize) % align @@ -541,8 +569,7 @@ def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False): extra_offset += intra_padding # Update common alignment - common_alignment = (align*common_alignment - / _gcd(align, common_alignment)) + common_alignment = _lcm(align, common_alignment) # Convert itemsize to sub-array if itemsize != 1: @@ -553,79 +580,77 @@ def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False): value = dtype((value, shape)) # Field name - this_explicit_name = False - if spec and spec.startswith(':'): - i = spec[1:].index(':') + 1 - name = spec[1:i] - spec = spec[i+1:] - explicit_name = True - this_explicit_name = True + if stream.consume(':'): + name = stream.consume_until(':') else: - name = get_dummy_name() + name = None - if not is_padding or this_explicit_name: - if name in fields: + if not (is_padding and name is None): + if name is not None and name in field_spec['names']: raise RuntimeError("Duplicate field name '%s' in PEP3118 format" % name) - fields[name] = (value, offset) - if not this_explicit_name: - next_dummy_name() - - byteorder = next_byteorder + field_spec['names'].append(name) + field_spec['formats'].append(value) + field_spec['offsets'].append(offset) offset += value.itemsize offset += extra_offset - # Check if this was a simple 1-item type - if (len(fields) == 1 and not explicit_name and - fields['f0'][1] == 0 and not is_subdtype): - ret = fields['f0'][0] - else: - ret = dtype(fields) + field_spec['itemsize'] = offset - # Trailing padding must be explicitly added - padding = offset - ret.itemsize - if byteorder == '@': - padding += (-offset) % common_alignment - if is_padding and not this_explicit_name: - ret = _add_trailing_padding(ret, padding) + # extra final padding for aligned types + if stream.byteorder == '@': + field_spec['itemsize'] += (-offset) % common_alignment - # Finished - if is_subdtype: - return ret, spec, common_alignment, byteorder + # Check if this was a simple 1-item type, and unwrap it + if (field_spec['names'] == [None] + and field_spec['offsets'][0] == 0 + and field_spec['itemsize'] == field_spec['formats'][0].itemsize + and not is_subdtype): + ret = field_spec['formats'][0] else: - return ret + _fix_names(field_spec) + ret = dtype(field_spec) + + # Finished + return ret, common_alignment + +def _fix_names(field_spec): + """ Replace names which are None with the next unused f%d name """ + names = field_spec['names'] + for i, name in enumerate(names): + if name is not None: + continue + + j = 0 + while True: + name = 'f{}'.format(j) + if name not in names: + break + j = j + 1 + names[i] = name def _add_trailing_padding(value, padding): """Inject the specified number of padding bytes at the end of a dtype""" if value.fields is None: - vfields = {'f0': (value, 0)} - else: - vfields = dict(value.fields) - - if (value.names and value.names[-1] == '' and - value[''].char == 'V'): - # A trailing padding field is already present - vfields[''] = ('V%d' % (vfields[''][0].itemsize + padding), - vfields[''][1]) - value = dtype(vfields) + field_spec = dict( + names=['f0'], + formats=[value], + offsets=[0], + itemsize=value.itemsize + ) else: - # Get a free name for the padding field - j = 0 - while True: - name = 'pad%d' % j - if name not in vfields: - vfields[name] = ('V%d' % padding, value.itemsize) - break - j += 1 + fields = value.fields + names = value.names + field_spec = dict( + names=names, + formats=[fields[name][0] for name in names], + offsets=[fields[name][1] for name in names], + itemsize=value.itemsize + ) - value = dtype(vfields) - if '' not in vfields: - # Strip out the name of the padding field - names = list(value.names) - names[-1] = '' - value.names = tuple(names) - return value + field_spec['itemsize'] += padding + return dtype(field_spec) def _prod(a): p = 1 @@ -639,6 +664,9 @@ def _gcd(a, b): a, b = b, a % b return a +def _lcm(a, b): + return a // _gcd(a, b) * b + # Exception used in shares_memory() class TooHardError(RuntimeError): pass |