From ff12de3f3dead7f522ab5a1076fcef4fbbdf3314 Mon Sep 17 00:00:00 2001 From: Allan Haldane Date: Thu, 8 Feb 2018 00:03:41 -0500 Subject: BUG: infinite recursion in str of 0d subclasses Fixes #10360 --- numpy/core/arrayprint.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) (limited to 'numpy/core/arrayprint.py') diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index 472318098..0fdd6fc7f 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -468,14 +468,17 @@ def _recursive_guard(fillvalue='...'): # gracefully handle recursive calls, when object arrays contain themselves @_recursive_guard() def _array2string(a, options, separator=' ', prefix=""): - # The formatter __init__s cannot deal with subclasses yet - data = asarray(a) + # The formatter __init__s in _get_format_function cannot deal with + # subclasses yet, and we also need to avoid recursion issues in + # _formatArray with subclasses which return 0d arrays in place of scalars + a = asarray(a) if a.size > options['threshold']: summary_insert = "..." - data = _leading_trailing(data, options['edgeitems']) + data = _leading_trailing(a, options['edgeitems']) else: summary_insert = "" + data = a # find the right formatting function for the array format_function = _get_format_function(data, **options) @@ -501,7 +504,7 @@ def array2string(a, max_line_width=None, precision=None, Parameters ---------- - a : ndarray + a : array_like Input array. max_line_width : int, optional The maximum number of columns the string should span. Newline @@ -763,7 +766,7 @@ def _formatArray(a, format_function, line_width, next_line_prefix, if show_summary: if legacy == '1.13': - # trailing space, fixed number of newlines, and fixed separator + # trailing space, fixed nbr of newlines, and fixed separator s += hanging_indent + summary_insert + ", \n" else: s += hanging_indent + summary_insert + line_sep @@ -1413,6 +1416,8 @@ def array_repr(arr, max_line_width=None, precision=None, suppress_small=None): return arr_str + spacer + dtype_str +_guarded_str = _recursive_guard()(str) + def array_str(a, max_line_width=None, precision=None, suppress_small=None): """ Return a string representation of the data in an array. @@ -1455,7 +1460,10 @@ def array_str(a, max_line_width=None, precision=None, suppress_small=None): # so floats are not truncated by `precision`, and strings are not wrapped # in quotes. So we return the str of the scalar value. if a.shape == (): - return str(a[()]) + # obtain a scalar and call str on it, avoiding problems for subclasses + # for which indexing with () returns a 0d instead of a scalar by using + # ndarray's getindex. Also guard against recursive 0d object arrays. + return _guarded_str(np.ndarray.__getitem__(a, ())) return array2string(a, max_line_width, precision, suppress_small, ' ', "") -- cgit v1.2.1