diff options
author | Shota Kawabuchi <shota.kawabuchi+Github@gmail.com> | 2016-10-15 21:56:24 +0900 |
---|---|---|
committer | Shota Kawabuchi <shota.kawabuchi+Github@gmail.com> | 2016-10-18 13:03:18 +0900 |
commit | 3d75c3f5165d958ff109a4f877aeb18d77ce279f (patch) | |
tree | cda6eed31b8c9e24ea551d64759590aaf20269da /numpy/core/arrayprint.py | |
parent | b8da06bdc44b3481f9e0e17d6ff24e79176eed7c (diff) | |
download | numpy-3d75c3f5165d958ff109a4f877aeb18d77ce279f.tar.gz |
BUG: fix _array2string for strustured array (issue #5692)
The cause of issue #5692 is that `_array2string` (in
numpy/core/arrayprint.py) doesn’t have format function for structured
arrays and it uses general purpose format function to format array
elements of structured arrays.
This commit adds `StructureFormat` class to format structured array
elements. `_get_format_function` instantiates `StructureFormat` by
instantiating format function classes for each field of structure
recursively and merge them.
Closes #5692.
Diffstat (limited to 'numpy/core/arrayprint.py')
-rw-r--r-- | numpy/core/arrayprint.py | 99 |
1 files changed, 66 insertions, 33 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index cd618d72a..7a84eb7c2 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -234,27 +234,23 @@ def _boolFormatter(x): def repr_format(x): return repr(x) -def _array2string(a, max_line_width, precision, suppress_small, separator=' ', - prefix="", formatter=None): - - if max_line_width is None: - max_line_width = _line_width - - if precision is None: - precision = _float_output_precision - - if suppress_small is None: - suppress_small = _float_output_suppress_small - - if formatter is None: - formatter = _formatter - - if a.size > _summaryThreshold: - summary_insert = "..., " - data = _leading_trailing(a) - else: - summary_insert = "" - data = ravel(asarray(a)) +def _get_format_function(data, precision, suppress_small, formatter): + """ + find the right formatting function for the dtype_ + """ + dtype_ = data.dtype + if dtype_.fields is not None: + format_functions = [] + for descr in dtype_.descr: + field_name = descr[0] + field_values = data[field_name] + if len(field_values.shape) <= 1: + format_function = _get_format_function( + field_values, precision, suppress_small, formatter) + else: + format_function = repr_format + format_functions.append(format_function) + return StructureFormat(format_functions) formatdict = {'bool': _boolFormatter, 'int': IntegerFormat(data), @@ -289,31 +285,56 @@ def _array2string(a, max_line_width, precision, suppress_small, separator=' ', if key in fkeys: formatdict[key] = formatter[key] - # find the right formatting function for the array - dtypeobj = a.dtype.type + dtypeobj = dtype_.type if issubclass(dtypeobj, _nt.bool_): - format_function = formatdict['bool'] + return formatdict['bool'] elif issubclass(dtypeobj, _nt.integer): if issubclass(dtypeobj, _nt.timedelta64): - format_function = formatdict['timedelta'] + return formatdict['timedelta'] else: - format_function = formatdict['int'] + return formatdict['int'] elif issubclass(dtypeobj, _nt.floating): if issubclass(dtypeobj, _nt.longfloat): - format_function = formatdict['longfloat'] + return formatdict['longfloat'] else: - format_function = formatdict['float'] + return formatdict['float'] elif issubclass(dtypeobj, _nt.complexfloating): if issubclass(dtypeobj, _nt.clongfloat): - format_function = formatdict['longcomplexfloat'] + return formatdict['longcomplexfloat'] else: - format_function = formatdict['complexfloat'] + return formatdict['complexfloat'] elif issubclass(dtypeobj, (_nt.unicode_, _nt.string_)): - format_function = formatdict['numpystr'] + return formatdict['numpystr'] elif issubclass(dtypeobj, _nt.datetime64): - format_function = formatdict['datetime'] + return formatdict['datetime'] else: - format_function = formatdict['numpystr'] + return formatdict['numpystr'] + +def _array2string(a, max_line_width, precision, suppress_small, separator=' ', + prefix="", formatter=None): + + if max_line_width is None: + max_line_width = _line_width + + if precision is None: + precision = _float_output_precision + + if suppress_small is None: + suppress_small = _float_output_suppress_small + + if formatter is None: + formatter = _formatter + + if a.size > _summaryThreshold: + summary_insert = "..., " + data = _leading_trailing(a) + else: + summary_insert = "" + data = ravel(asarray(a)) + + # find the right formatting function for the array + format_function = _get_format_function(data, precision, + suppress_small, formatter) # skip over "[" next_line_prefix = " " @@ -758,3 +779,15 @@ class TimedeltaFormat(object): return self._nat else: return self.format % x.astype('i8') + + +class StructureFormat(object): + def __init__(self, format_functions): + self.format_functions = format_functions + self.num_fields = len(format_functions) + + def __call__(self, x): + s = "(" + for field, format_function in zip(x, self.format_functions): + s += format_function(field) + ", " + return (s[:-2] if 1 < self.num_fields else s[:-1]) + ")" |