summaryrefslogtreecommitdiff
path: root/numpy/core/arrayprint.py
diff options
context:
space:
mode:
authorShota Kawabuchi <shota.kawabuchi+Github@gmail.com>2016-10-15 21:56:24 +0900
committerShota Kawabuchi <shota.kawabuchi+Github@gmail.com>2016-10-18 13:03:18 +0900
commit3d75c3f5165d958ff109a4f877aeb18d77ce279f (patch)
treecda6eed31b8c9e24ea551d64759590aaf20269da /numpy/core/arrayprint.py
parentb8da06bdc44b3481f9e0e17d6ff24e79176eed7c (diff)
downloadnumpy-3d75c3f5165d958ff109a4f877aeb18d77ce279f.tar.gz
BUG: fix _array2string for strustured array (issue #5692)
The cause of issue #5692 is that `_array2string` (in numpy/core/arrayprint.py) doesn’t have format function for structured arrays and it uses general purpose format function to format array elements of structured arrays. This commit adds `StructureFormat` class to format structured array elements. `_get_format_function` instantiates `StructureFormat` by instantiating format function classes for each field of structure recursively and merge them. Closes #5692.
Diffstat (limited to 'numpy/core/arrayprint.py')
-rw-r--r--numpy/core/arrayprint.py99
1 files changed, 66 insertions, 33 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py
index cd618d72a..7a84eb7c2 100644
--- a/numpy/core/arrayprint.py
+++ b/numpy/core/arrayprint.py
@@ -234,27 +234,23 @@ def _boolFormatter(x):
def repr_format(x):
return repr(x)
-def _array2string(a, max_line_width, precision, suppress_small, separator=' ',
- prefix="", formatter=None):
-
- if max_line_width is None:
- max_line_width = _line_width
-
- if precision is None:
- precision = _float_output_precision
-
- if suppress_small is None:
- suppress_small = _float_output_suppress_small
-
- if formatter is None:
- formatter = _formatter
-
- if a.size > _summaryThreshold:
- summary_insert = "..., "
- data = _leading_trailing(a)
- else:
- summary_insert = ""
- data = ravel(asarray(a))
+def _get_format_function(data, precision, suppress_small, formatter):
+ """
+ find the right formatting function for the dtype_
+ """
+ dtype_ = data.dtype
+ if dtype_.fields is not None:
+ format_functions = []
+ for descr in dtype_.descr:
+ field_name = descr[0]
+ field_values = data[field_name]
+ if len(field_values.shape) <= 1:
+ format_function = _get_format_function(
+ field_values, precision, suppress_small, formatter)
+ else:
+ format_function = repr_format
+ format_functions.append(format_function)
+ return StructureFormat(format_functions)
formatdict = {'bool': _boolFormatter,
'int': IntegerFormat(data),
@@ -289,31 +285,56 @@ def _array2string(a, max_line_width, precision, suppress_small, separator=' ',
if key in fkeys:
formatdict[key] = formatter[key]
- # find the right formatting function for the array
- dtypeobj = a.dtype.type
+ dtypeobj = dtype_.type
if issubclass(dtypeobj, _nt.bool_):
- format_function = formatdict['bool']
+ return formatdict['bool']
elif issubclass(dtypeobj, _nt.integer):
if issubclass(dtypeobj, _nt.timedelta64):
- format_function = formatdict['timedelta']
+ return formatdict['timedelta']
else:
- format_function = formatdict['int']
+ return formatdict['int']
elif issubclass(dtypeobj, _nt.floating):
if issubclass(dtypeobj, _nt.longfloat):
- format_function = formatdict['longfloat']
+ return formatdict['longfloat']
else:
- format_function = formatdict['float']
+ return formatdict['float']
elif issubclass(dtypeobj, _nt.complexfloating):
if issubclass(dtypeobj, _nt.clongfloat):
- format_function = formatdict['longcomplexfloat']
+ return formatdict['longcomplexfloat']
else:
- format_function = formatdict['complexfloat']
+ return formatdict['complexfloat']
elif issubclass(dtypeobj, (_nt.unicode_, _nt.string_)):
- format_function = formatdict['numpystr']
+ return formatdict['numpystr']
elif issubclass(dtypeobj, _nt.datetime64):
- format_function = formatdict['datetime']
+ return formatdict['datetime']
else:
- format_function = formatdict['numpystr']
+ return formatdict['numpystr']
+
+def _array2string(a, max_line_width, precision, suppress_small, separator=' ',
+ prefix="", formatter=None):
+
+ if max_line_width is None:
+ max_line_width = _line_width
+
+ if precision is None:
+ precision = _float_output_precision
+
+ if suppress_small is None:
+ suppress_small = _float_output_suppress_small
+
+ if formatter is None:
+ formatter = _formatter
+
+ if a.size > _summaryThreshold:
+ summary_insert = "..., "
+ data = _leading_trailing(a)
+ else:
+ summary_insert = ""
+ data = ravel(asarray(a))
+
+ # find the right formatting function for the array
+ format_function = _get_format_function(data, precision,
+ suppress_small, formatter)
# skip over "["
next_line_prefix = " "
@@ -758,3 +779,15 @@ class TimedeltaFormat(object):
return self._nat
else:
return self.format % x.astype('i8')
+
+
+class StructureFormat(object):
+ def __init__(self, format_functions):
+ self.format_functions = format_functions
+ self.num_fields = len(format_functions)
+
+ def __call__(self, x):
+ s = "("
+ for field, format_function in zip(x, self.format_functions):
+ s += format_function(field) + ", "
+ return (s[:-2] if 1 < self.num_fields else s[:-1]) + ")"