diff options
author | Shota Kawabuchi <shota.kawabuchi+Github@gmail.com> | 2016-10-22 12:26:46 +0900 |
---|---|---|
committer | Shota Kawabuchi <shota.kawabuchi+Github@gmail.com> | 2016-10-22 13:25:41 +0900 |
commit | 2a4dd999c82276d00ef96d0d5839ff8b1f8a8871 (patch) | |
tree | a23b200f266ad5a906d6742685b363077a553cb4 /numpy/core/arrayprint.py | |
parent | ebc9910d1f1d84106c17174a7d3a87a651d62a93 (diff) | |
download | numpy-2a4dd999c82276d00ef96d0d5839ff8b1f8a8871.tar.gz |
BUG: Fix subarray format changed in #8160
Preserving structured array element format,
this commit fixes subarray format changed in PR #8160.
This commit also changes iterator for field name from dtype_.descr to
dtype_.names (Related to #8174).
Diffstat (limited to 'numpy/core/arrayprint.py')
-rw-r--r-- | numpy/core/arrayprint.py | 49 |
1 files changed, 31 insertions, 18 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index 7a84eb7c2..1d93a0c0b 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -234,24 +234,7 @@ def _boolFormatter(x): def repr_format(x): return repr(x) -def _get_format_function(data, precision, suppress_small, formatter): - """ - find the right formatting function for the dtype_ - """ - dtype_ = data.dtype - if dtype_.fields is not None: - format_functions = [] - for descr in dtype_.descr: - field_name = descr[0] - field_values = data[field_name] - if len(field_values.shape) <= 1: - format_function = _get_format_function( - field_values, precision, suppress_small, formatter) - else: - format_function = repr_format - format_functions.append(format_function) - return StructureFormat(format_functions) - +def _get_formatdict(data, precision, suppress_small, formatter): formatdict = {'bool': _boolFormatter, 'int': IntegerFormat(data), 'float': FloatFormat(data, precision, suppress_small), @@ -285,7 +268,27 @@ def _get_format_function(data, precision, suppress_small, formatter): if key in fkeys: formatdict[key] = formatter[key] + return formatdict + +def _get_format_function(data, precision, suppress_small, formatter): + """ + find the right formatting function for the dtype_ + """ + dtype_ = data.dtype + if dtype_.fields is not None: + format_functions = [] + for field_name in dtype_.names: + field_values = data[field_name] + is_array_field = 1 < field_values.ndim + format_function = _get_format_function( + ravel(field_values), precision, suppress_small, formatter) + if is_array_field: + format_function = SubArrayFormat(format_function) + format_functions.append(format_function) + return StructureFormat(format_functions) + dtypeobj = dtype_.type + formatdict = _get_formatdict(data, precision, suppress_small, formatter) if issubclass(dtypeobj, _nt.bool_): return formatdict['bool'] elif issubclass(dtypeobj, _nt.integer): @@ -781,6 +784,16 @@ class TimedeltaFormat(object): return self.format % x.astype('i8') +class SubArrayFormat(object): + def __init__(self, format_function): + self.format_function = format_function + + def __call__(self, arr): + if arr.ndim <= 1: + return "[" + ", ".join(self.format_function(a) for a in arr) + "]" + return "[" + ", ".join(self.__call__(a) for a in arr) + "]" + + class StructureFormat(object): def __init__(self, format_functions): self.format_functions = format_functions |