summary | refs | log | tree | commit | diff
path: root/numpy/core/_internal.py
diff options
context:
space:
mode:
author: ahaldane <ealloc@gmail.com> 2017-05-09 12:48:41 -0400
committer: GitHub <noreply@github.com> 2017-05-09 12:48:41 -0400
commit: 86180665a8ba05dc70144db50024f967ea9ccc78 (patch)
tree: eb627a83645f799f82060da8a76c8e9f49b3e516 /numpy/core/_internal.py
parent: 5e78b887b3c1a553a9181f059c3369cc59d744e9 (diff)
parent: a4f435c68c15bc43a9b09869aaedb94d6c1bee2a (diff)
download: numpy-86180665a8ba05dc70144db50024f967ea9ccc78.tar.gz
Merge pull request #9054 from eric-wieser/fix-pep3118
BUG: Various fixes to _dtype_from_pep3118
Diffstat (limited to 'numpy/core/_internal.py')
-rw-r--r-- numpy/core/_internal.py 240
1 files changed, 134 insertions, 106 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index c890eba17..f7f579c75 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -432,51 +432,83 @@ _pep3118_standard_map = {
}
_pep3118_standard_typechars = ''.join(_pep3118_standard_map.keys())
-def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False):
- fields = {}
+def _dtype_from_pep3118(spec):
+
+ class Stream(object):
+ def __init__(self, s):
+ self.s = s
+ self.byteorder = '@'
+
+ def advance(self, n):
+ res = self.s[:n]
+ self.s = self.s[n:]
+ return res
+
+ def consume(self, c):
+ if self.s[:len(c)] == c:
+ self.advance(len(c))
+ return True
+ return False
+
+ def consume_until(self, c):
+ if callable(c):
+ i = 0
+ while i < len(self.s) and not c(self.s[i]):
+ i = i + 1
+ return self.advance(i)
+ else:
+ i = self.s.index(c)
+ res = self.advance(i)
+ self.advance(len(c))
+ return res
+
+ @property
+ def next(self):
+ return self.s[0]
+
+ def __bool__(self):
+ return bool(self.s)
+ __nonzero__ = __bool__
+
+ stream = Stream(spec)
+
+ dtype, align = __dtype_from_pep3118(stream, is_subdtype=False)
+ return dtype
+
+def __dtype_from_pep3118(stream, is_subdtype):
+ field_spec = dict(
+ names=[],
+ formats=[],
+ offsets=[],
+ itemsize=0
+ )
offset = 0
- explicit_name = False
- this_explicit_name = False
common_alignment = 1
is_padding = False
- dummy_name_index = [0]
-
- def next_dummy_name():
- dummy_name_index[0] += 1
-
- def get_dummy_name():
- while True:
- name = 'f%d' % dummy_name_index[0]
- if name not in fields:
- return name
- next_dummy_name()
-
# Parse spec
- while spec:
+ while stream:
value = None
# End of structure, bail out to upper level
- if spec[0] == '}':
- spec = spec[1:]
+ if stream.consume('}'):
break
# Sub-arrays (1)
shape = None
- if spec[0] == '(':
- j = spec.index(')')
- shape = tuple(map(int, spec[1:j].split(',')))
- spec = spec[j+1:]
+ if stream.consume('('):
+ shape = stream.consume_until(')')
+ shape = tuple(map(int, shape.split(',')))
# Byte order
- if spec[0] in ('@', '=', '<', '>', '^', '!'):
- byteorder = spec[0]
+ if stream.next in ('@', '=', '<', '>', '^', '!'):
+ byteorder = stream.advance(1)
if byteorder == '!':
byteorder = '>'
- spec = spec[1:]
+ stream.byteorder = byteorder
# Byte order characters also control native vs. standard type sizes
- if byteorder in ('@', '^'):
+ if stream.byteorder in ('@', '^'):
type_map = _pep3118_native_map
type_map_chars = _pep3118_native_typechars
else:
@@ -484,39 +516,35 @@ def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False):
type_map_chars = _pep3118_standard_typechars
# Item sizes
- itemsize = 1
- if spec[0].isdigit():
- j = 1
- for j in range(1, len(spec)):
- if not spec[j].isdigit():
- break
- itemsize = int(spec[:j])
- spec = spec[j:]
+ itemsize_str = stream.consume_until(lambda c: not c.isdigit())
+ if itemsize_str:
+ itemsize = int(itemsize_str)
+ else:
+ itemsize = 1
# Data types
is_padding = False
- if spec[:2] == 'T{':
- value, spec, align, next_byteorder = _dtype_from_pep3118(
- spec[2:], byteorder=byteorder, is_subdtype=True)
- elif spec[0] in type_map_chars:
- next_byteorder = byteorder
- if spec[0] == 'Z':
- j = 2
+ if stream.consume('T{'):
+ value, align = __dtype_from_pep3118(
+ stream, is_subdtype=True)
+ elif stream.next in type_map_chars:
+ if stream.next == 'Z':
+ typechar = stream.advance(2)
else:
- j = 1
- typechar = spec[:j]
- spec = spec[j:]
+ typechar = stream.advance(1)
+
is_padding = (typechar == 'x')
dtypechar = type_map[typechar]
if dtypechar in 'USV':
dtypechar += '%d' % itemsize
itemsize = 1
- numpy_byteorder = {'@': '=', '^': '='}.get(byteorder, byteorder)
+ numpy_byteorder = {'@': '=', '^': '='}.get(
+ stream.byteorder, stream.byteorder)
value = dtype(numpy_byteorder + dtypechar)
align = value.alignment
else:
- raise ValueError("Unknown PEP 3118 data type specifier %r" % spec)
+ raise ValueError("Unknown PEP 3118 data type specifier %r" % stream.s)
#
# Native alignment may require padding
@@ -525,7 +553,7 @@ def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False):
# that the start of the array is *already* aligned.
#
extra_offset = 0
- if byteorder == '@':
+ if stream.byteorder == '@':
start_padding = (-offset) % align
intra_padding = (-value.itemsize) % align
@@ -541,8 +569,7 @@ def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False):
extra_offset += intra_padding
# Update common alignment
- common_alignment = (align*common_alignment
- / _gcd(align, common_alignment))
+ common_alignment = _lcm(align, common_alignment)
# Convert itemsize to sub-array
if itemsize != 1:
@@ -553,79 +580,77 @@ def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False):
value = dtype((value, shape))
# Field name
- this_explicit_name = False
- if spec and spec.startswith(':'):
- i = spec[1:].index(':') + 1
- name = spec[1:i]
- spec = spec[i+1:]
- explicit_name = True
- this_explicit_name = True
+ if stream.consume(':'):
+ name = stream.consume_until(':')
else:
- name = get_dummy_name()
+ name = None
- if not is_padding or this_explicit_name:
- if name in fields:
+ if not (is_padding and name is None):
+ if name is not None and name in field_spec['names']:
raise RuntimeError("Duplicate field name '%s' in PEP3118 format"
% name)
- fields[name] = (value, offset)
- if not this_explicit_name:
- next_dummy_name()
-
- byteorder = next_byteorder
+ field_spec['names'].append(name)
+ field_spec['formats'].append(value)
+ field_spec['offsets'].append(offset)
offset += value.itemsize
offset += extra_offset
- # Check if this was a simple 1-item type
- if (len(fields) == 1 and not explicit_name and
- fields['f0'][1] == 0 and not is_subdtype):
- ret = fields['f0'][0]
- else:
- ret = dtype(fields)
+ field_spec['itemsize'] = offset
- # Trailing padding must be explicitly added
- padding = offset - ret.itemsize
- if byteorder == '@':
- padding += (-offset) % common_alignment
- if is_padding and not this_explicit_name:
- ret = _add_trailing_padding(ret, padding)
+ # extra final padding for aligned types
+ if stream.byteorder == '@':
+ field_spec['itemsize'] += (-offset) % common_alignment
- # Finished
- if is_subdtype:
- return ret, spec, common_alignment, byteorder
+ # Check if this was a simple 1-item type, and unwrap it
+ if (field_spec['names'] == [None]
+ and field_spec['offsets'][0] == 0
+ and field_spec['itemsize'] == field_spec['formats'][0].itemsize
+ and not is_subdtype):
+ ret = field_spec['formats'][0]
else:
- return ret
+ _fix_names(field_spec)
+ ret = dtype(field_spec)
+
+ # Finished
+ return ret, common_alignment
+
+def _fix_names(field_spec):
+ """ Replace names which are None with the next unused f%d name """
+ names = field_spec['names']
+ for i, name in enumerate(names):
+ if name is not None:
+ continue
+
+ j = 0
+ while True:
+ name = 'f{}'.format(j)
+ if name not in names:
+ break
+ j = j + 1
+ names[i] = name
def _add_trailing_padding(value, padding):
"""Inject the specified number of padding bytes at the end of a dtype"""
if value.fields is None:
- vfields = {'f0': (value, 0)}
- else:
- vfields = dict(value.fields)
-
- if (value.names and value.names[-1] == '' and
- value[''].char == 'V'):
- # A trailing padding field is already present
- vfields[''] = ('V%d' % (vfields[''][0].itemsize + padding),
- vfields[''][1])
- value = dtype(vfields)
+ field_spec = dict(
+ names=['f0'],
+ formats=[value],
+ offsets=[0],
+ itemsize=value.itemsize
+ )
else:
- # Get a free name for the padding field
- j = 0
- while True:
- name = 'pad%d' % j
- if name not in vfields:
- vfields[name] = ('V%d' % padding, value.itemsize)
- break
- j += 1
+ fields = value.fields
+ names = value.names
+ field_spec = dict(
+ names=names,
+ formats=[fields[name][0] for name in names],
+ offsets=[fields[name][1] for name in names],
+ itemsize=value.itemsize
+ )
- value = dtype(vfields)
- if '' not in vfields:
- # Strip out the name of the padding field
- names = list(value.names)
- names[-1] = ''
- value.names = tuple(names)
- return value
+ field_spec['itemsize'] += padding
+ return dtype(field_spec)
def _prod(a):
p = 1
@@ -639,6 +664,9 @@ def _gcd(a, b):
a, b = b, a % b
return a
+def _lcm(a, b):
+ return a // _gcd(a, b) * b
+
# Exception used in shares_memory()
class TooHardError(RuntimeError):
pass