summaryrefslogtreecommitdiff
path: root/numpy/core/arrayprint.py
diff options
context:
space:
mode:
authorShota Kawabuchi <shota.kawabuchi+Github@gmail.com>2016-10-22 12:26:46 +0900
committerShota Kawabuchi <shota.kawabuchi+Github@gmail.com>2016-10-22 13:25:41 +0900
commit2a4dd999c82276d00ef96d0d5839ff8b1f8a8871 (patch)
treea23b200f266ad5a906d6742685b363077a553cb4 /numpy/core/arrayprint.py
parentebc9910d1f1d84106c17174a7d3a87a651d62a93 (diff)
downloadnumpy-2a4dd999c82276d00ef96d0d5839ff8b1f8a8871.tar.gz
BUG: Fix subarray format changed in #8160
Preserving structured array element format, this commit fixes subarray format changed in PR #8160. This commit also changes iterator for field name from dtype_.descr to dtype_.names (Related to #8174).
Diffstat (limited to 'numpy/core/arrayprint.py')
-rw-r--r--numpy/core/arrayprint.py49
1 files changed, 31 insertions, 18 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py
index 7a84eb7c2..1d93a0c0b 100644
--- a/numpy/core/arrayprint.py
+++ b/numpy/core/arrayprint.py
@@ -234,24 +234,7 @@ def _boolFormatter(x):
def repr_format(x):
return repr(x)
-def _get_format_function(data, precision, suppress_small, formatter):
- """
- find the right formatting function for the dtype_
- """
- dtype_ = data.dtype
- if dtype_.fields is not None:
- format_functions = []
- for descr in dtype_.descr:
- field_name = descr[0]
- field_values = data[field_name]
- if len(field_values.shape) <= 1:
- format_function = _get_format_function(
- field_values, precision, suppress_small, formatter)
- else:
- format_function = repr_format
- format_functions.append(format_function)
- return StructureFormat(format_functions)
-
+def _get_formatdict(data, precision, suppress_small, formatter):
formatdict = {'bool': _boolFormatter,
'int': IntegerFormat(data),
'float': FloatFormat(data, precision, suppress_small),
@@ -285,7 +268,27 @@ def _get_format_function(data, precision, suppress_small, formatter):
if key in fkeys:
formatdict[key] = formatter[key]
+ return formatdict
+
+def _get_format_function(data, precision, suppress_small, formatter):
+ """
+ find the right formatting function for the dtype_
+ """
+ dtype_ = data.dtype
+ if dtype_.fields is not None:
+ format_functions = []
+ for field_name in dtype_.names:
+ field_values = data[field_name]
+ is_array_field = 1 < field_values.ndim
+ format_function = _get_format_function(
+ ravel(field_values), precision, suppress_small, formatter)
+ if is_array_field:
+ format_function = SubArrayFormat(format_function)
+ format_functions.append(format_function)
+ return StructureFormat(format_functions)
+
dtypeobj = dtype_.type
+ formatdict = _get_formatdict(data, precision, suppress_small, formatter)
if issubclass(dtypeobj, _nt.bool_):
return formatdict['bool']
elif issubclass(dtypeobj, _nt.integer):
@@ -781,6 +784,16 @@ class TimedeltaFormat(object):
return self.format % x.astype('i8')
+class SubArrayFormat(object):
+ def __init__(self, format_function):
+ self.format_function = format_function
+
+ def __call__(self, arr):
+ if arr.ndim <= 1:
+ return "[" + ", ".join(self.format_function(a) for a in arr) + "]"
+ return "[" + ", ".join(self.__call__(a) for a in arr) + "]"
+
+
class StructureFormat(object):
def __init__(self, format_functions):
self.format_functions = format_functions