diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/arrayprint.py | 129 | ||||
-rw-r--r-- | numpy/core/tests/test_overrides.py | 15 |
2 files changed, 91 insertions, 53 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index 1b9fbbfa9..0c45989e1 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -1381,6 +1381,53 @@ def dtype_short_repr(dtype): return typename +def _array_repr_implementation( + arr, max_line_width=None, precision=None, suppress_small=None, + array2string=array2string): + """Internal version of array_repr() that allows overriding array2string.""" + if max_line_width is None: + max_line_width = _format_options['linewidth'] + + if type(arr) is not ndarray: + class_name = type(arr).__name__ + else: + class_name = "array" + + skipdtype = dtype_is_implied(arr.dtype) and arr.size > 0 + + prefix = class_name + "(" + suffix = ")" if skipdtype else "," + + if (_format_options['legacy'] == '1.13' and + arr.shape == () and not arr.dtype.names): + lst = repr(arr.item()) + elif arr.size > 0 or arr.shape == (0,): + lst = array2string(arr, max_line_width, precision, suppress_small, + ', ', prefix, suffix=suffix) + else: # show zero-length shape unless it is (0,) + lst = "[], shape=%s" % (repr(arr.shape),) + + arr_str = prefix + lst + suffix + + if skipdtype: + return arr_str + + dtype_str = "dtype={})".format(dtype_short_repr(arr.dtype)) + + # compute whether we should put dtype on a new line: Do so if adding the + # dtype would extend the last line past max_line_width. + # Note: This line gives the correct result even when rfind returns -1. + last_line_len = len(arr_str) - (arr_str.rfind('\n') + 1) + spacer = " " + if _format_options['legacy'] == '1.13': + if issubclass(arr.dtype.type, flexible): + spacer = '\n' + ' '*len(class_name + "(") + elif last_line_len + len(dtype_str) + 1 > max_line_width: + spacer = '\n' + ' '*len(class_name + "(") + + return arr_str + spacer + dtype_str + + def _array_repr_dispatcher( arr, max_line_width=None, precision=None, suppress_small=None): return (arr,) @@ -1429,50 +1476,31 @@ def array_repr(arr, max_line_width=None, precision=None, suppress_small=None): 'array([ 0.000001, 0. , 2. , 3. ])' """ - if max_line_width is None: - max_line_width = _format_options['linewidth'] + return _array_repr_implementation( + arr, max_line_width, precision, suppress_small) - if type(arr) is not ndarray: - class_name = type(arr).__name__ - else: - class_name = "array" - skipdtype = dtype_is_implied(arr.dtype) and arr.size > 0 +_guarded_str = _recursive_guard()(str) - prefix = class_name + "(" - suffix = ")" if skipdtype else "," +def _array_str_implementation( + a, max_line_width=None, precision=None, suppress_small=None, + array2string=array2string): + """Internal version of array_str() that allows overriding array2string.""" if (_format_options['legacy'] == '1.13' and - arr.shape == () and not arr.dtype.names): - lst = repr(arr.item()) - elif arr.size > 0 or arr.shape == (0,): - lst = array2string(arr, max_line_width, precision, suppress_small, - ', ', prefix, suffix=suffix) - else: # show zero-length shape unless it is (0,) - lst = "[], shape=%s" % (repr(arr.shape),) - - arr_str = prefix + lst + suffix - - if skipdtype: - return arr_str - - dtype_str = "dtype={})".format(dtype_short_repr(arr.dtype)) - - # compute whether we should put dtype on a new line: Do so if adding the - # dtype would extend the last line past max_line_width. - # Note: This line gives the correct result even when rfind returns -1. - last_line_len = len(arr_str) - (arr_str.rfind('\n') + 1) - spacer = " " - if _format_options['legacy'] == '1.13': - if issubclass(arr.dtype.type, flexible): - spacer = '\n' + ' '*len(class_name + "(") - elif last_line_len + len(dtype_str) + 1 > max_line_width: - spacer = '\n' + ' '*len(class_name + "(") - - return arr_str + spacer + dtype_str + a.shape == () and not a.dtype.names): + return str(a.item()) + # the str of 0d arrays is a special case: It should appear like a scalar, + # so floats are not truncated by `precision`, and strings are not wrapped + # in quotes. So we return the str of the scalar value. + if a.shape == (): + # obtain a scalar and call str on it, avoiding problems for subclasses + # for which indexing with () returns a 0d instead of a scalar by using + # ndarray's getindex. Also guard against recursive 0d object arrays. + return _guarded_str(np.ndarray.__getitem__(a, ())) -_guarded_str = _recursive_guard()(str) + return array2string(a, max_line_width, precision, suppress_small, ' ', "") def _array_str_dispatcher( @@ -1515,20 +1543,15 @@ def array_str(a, max_line_width=None, precision=None, suppress_small=None): '[0 1 2]' """ - if (_format_options['legacy'] == '1.13' and - a.shape == () and not a.dtype.names): - return str(a.item()) + return _array_str_implementation( + a, max_line_width, precision, suppress_small) - # the str of 0d arrays is a special case: It should appear like a scalar, - # so floats are not truncated by `precision`, and strings are not wrapped - # in quotes. So we return the str of the scalar value. - if a.shape == (): - # obtain a scalar and call str on it, avoiding problems for subclasses - # for which indexing with () returns a 0d instead of a scalar by using - # ndarray's getindex. Also guard against recursive 0d object arrays. - return _guarded_str(np.ndarray.__getitem__(a, ())) - return array2string(a, max_line_width, precision, suppress_small, ' ', "") +_default_array_str = functools.partial(_array_str_implementation, + array2string=array2string.__wrapped__) +_default_array_repr = functools.partial(_array_repr_implementation, + array2string=array2string.__wrapped__) + def set_string_function(f, repr=True): """ @@ -1583,11 +1606,11 @@ def set_string_function(f, repr=True): """ if f is None: if repr: - return multiarray.set_string_function(array_repr, 1) + return multiarray.set_string_function(_default_array_repr, 1) else: - return multiarray.set_string_function(array_str, 0) + return multiarray.set_string_function(_default_array_str, 0) else: return multiarray.set_string_function(f, repr) -set_string_function(array_str, 0) -set_string_function(array_repr, 1) +set_string_function(_default_array_str, 0) +set_string_function(_default_array_repr, 1) diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py index 3f87a6afe..c959655a7 100644 --- a/numpy/core/tests/test_overrides.py +++ b/numpy/core/tests/test_overrides.py @@ -304,3 +304,18 @@ class TestArrayFunctionImplementation(object): with assert_raises_regex(TypeError, 'no implementation found'): func(MyArray()) + + +class TestNDArrayMethods(object): + + def test_repr(self): + # gh-12162: should still be defined even if __array_function__ doesn't + # implement np.array_repr() + + class MyArray(np.ndarray): + def __array_function__(*args, **kwargs): + return NotImplemented + + array = np.array(1).view(MyArray) + assert_equal(repr(array), 'MyArray(1)') + assert_equal(str(array), '1') |