diff options
| author | Allan Haldane <ealloc@gmail.com> | 2018-10-27 12:19:16 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2018-10-27 12:19:16 -0400 |
| commit | 45718fd73bc286e127772ee455d721d9a58665b3 (patch) | |
| tree | aa1a7464fc2973a6f4a64b164d9a4cc095599095 /numpy | |
| parent | 1dcb28da57b4dbfe5a1fe31bf3dce4d7a888c70c (diff) | |
| parent | fbc3ad69d2396fc5edbb2f145c82965756185f82 (diff) | |
| download | numpy-45718fd73bc286e127772ee455d721d9a58665b3.tar.gz | |
Merge pull request #12212 from shoyer/fix-overloaded-repr
MAINT: ndarray.__repr__ should not rely on __array_function__
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/core/arrayprint.py | 129 | ||||
| -rw-r--r-- | numpy/core/overrides.py | 4 | ||||
| -rw-r--r-- | numpy/core/tests/test_overrides.py | 15 |
3 files changed, 95 insertions, 53 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index 3201b2f78..ccc1468c4 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -1381,6 +1381,53 @@ def dtype_short_repr(dtype): return typename +def _array_repr_implementation( + arr, max_line_width=None, precision=None, suppress_small=None, + array2string=array2string): + """Internal version of array_repr() that allows overriding array2string.""" + if max_line_width is None: + max_line_width = _format_options['linewidth'] + + if type(arr) is not ndarray: + class_name = type(arr).__name__ + else: + class_name = "array" + + skipdtype = dtype_is_implied(arr.dtype) and arr.size > 0 + + prefix = class_name + "(" + suffix = ")" if skipdtype else "," + + if (_format_options['legacy'] == '1.13' and + arr.shape == () and not arr.dtype.names): + lst = repr(arr.item()) + elif arr.size > 0 or arr.shape == (0,): + lst = array2string(arr, max_line_width, precision, suppress_small, + ', ', prefix, suffix=suffix) + else: # show zero-length shape unless it is (0,) + lst = "[], shape=%s" % (repr(arr.shape),) + + arr_str = prefix + lst + suffix + + if skipdtype: + return arr_str + + dtype_str = "dtype={})".format(dtype_short_repr(arr.dtype)) + + # compute whether we should put dtype on a new line: Do so if adding the + # dtype would extend the last line past max_line_width. + # Note: This line gives the correct result even when rfind returns -1. + last_line_len = len(arr_str) - (arr_str.rfind('\n') + 1) + spacer = " " + if _format_options['legacy'] == '1.13': + if issubclass(arr.dtype.type, flexible): + spacer = '\n' + ' '*len(class_name + "(") + elif last_line_len + len(dtype_str) + 1 > max_line_width: + spacer = '\n' + ' '*len(class_name + "(") + + return arr_str + spacer + dtype_str + + def _array_repr_dispatcher( arr, max_line_width=None, precision=None, suppress_small=None): return (arr,) @@ -1429,50 +1476,31 @@ def array_repr(arr, max_line_width=None, precision=None, suppress_small=None): 'array([ 0.000001, 0. , 2. , 3. ])' """ - if max_line_width is None: - max_line_width = _format_options['linewidth'] + return _array_repr_implementation( + arr, max_line_width, precision, suppress_small) - if type(arr) is not ndarray: - class_name = type(arr).__name__ - else: - class_name = "array" - skipdtype = dtype_is_implied(arr.dtype) and arr.size > 0 +_guarded_str = _recursive_guard()(str) - prefix = class_name + "(" - suffix = ")" if skipdtype else "," +def _array_str_implementation( + a, max_line_width=None, precision=None, suppress_small=None, + array2string=array2string): + """Internal version of array_str() that allows overriding array2string.""" if (_format_options['legacy'] == '1.13' and - arr.shape == () and not arr.dtype.names): - lst = repr(arr.item()) - elif arr.size > 0 or arr.shape == (0,): - lst = array2string(arr, max_line_width, precision, suppress_small, - ', ', prefix, suffix=suffix) - else: # show zero-length shape unless it is (0,) - lst = "[], shape=%s" % (repr(arr.shape),) - - arr_str = prefix + lst + suffix - - if skipdtype: - return arr_str - - dtype_str = "dtype={})".format(dtype_short_repr(arr.dtype)) - - # compute whether we should put dtype on a new line: Do so if adding the - # dtype would extend the last line past max_line_width. - # Note: This line gives the correct result even when rfind returns -1. - last_line_len = len(arr_str) - (arr_str.rfind('\n') + 1) - spacer = " " - if _format_options['legacy'] == '1.13': - if issubclass(arr.dtype.type, flexible): - spacer = '\n' + ' '*len(class_name + "(") - elif last_line_len + len(dtype_str) + 1 > max_line_width: - spacer = '\n' + ' '*len(class_name + "(") - - return arr_str + spacer + dtype_str + a.shape == () and not a.dtype.names): + return str(a.item()) + # the str of 0d arrays is a special case: It should appear like a scalar, + # so floats are not truncated by `precision`, and strings are not wrapped + # in quotes. So we return the str of the scalar value. + if a.shape == (): + # obtain a scalar and call str on it, avoiding problems for subclasses + # for which indexing with () returns a 0d instead of a scalar by using + # ndarray's getindex. Also guard against recursive 0d object arrays. + return _guarded_str(np.ndarray.__getitem__(a, ())) -_guarded_str = _recursive_guard()(str) + return array2string(a, max_line_width, precision, suppress_small, ' ', "") def _array_str_dispatcher( @@ -1515,20 +1543,15 @@ def array_str(a, max_line_width=None, precision=None, suppress_small=None): '[0 1 2]' """ - if (_format_options['legacy'] == '1.13' and - a.shape == () and not a.dtype.names): - return str(a.item()) + return _array_str_implementation( + a, max_line_width, precision, suppress_small) - # the str of 0d arrays is a special case: It should appear like a scalar, - # so floats are not truncated by `precision`, and strings are not wrapped - # in quotes. So we return the str of the scalar value. - if a.shape == (): - # obtain a scalar and call str on it, avoiding problems for subclasses - # for which indexing with () returns a 0d instead of a scalar by using - # ndarray's getindex. Also guard against recursive 0d object arrays. - return _guarded_str(np.ndarray.__getitem__(a, ())) - return array2string(a, max_line_width, precision, suppress_small, ' ', "") +_default_array_str = functools.partial(_array_str_implementation, + array2string=array2string.__wrapped__) +_default_array_repr = functools.partial(_array_repr_implementation, + array2string=array2string.__wrapped__) + def set_string_function(f, repr=True): """ @@ -1583,11 +1606,11 @@ def set_string_function(f, repr=True): """ if f is None: if repr: - return multiarray.set_string_function(array_repr, 1) + return multiarray.set_string_function(_default_array_repr, 1) else: - return multiarray.set_string_function(array_str, 0) + return multiarray.set_string_function(_default_array_str, 0) else: return multiarray.set_string_function(f, repr) -set_string_function(array_str, 0) -set_string_function(array_repr, 1) +set_string_function(_default_array_str, 0) +set_string_function(_default_array_repr, 1) diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index 4640efd31..85a8c32bb 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -154,6 +154,10 @@ def array_function_dispatch(dispatcher, module=None, verify=True): if module is not None: public_api.__module__ = module + # TODO: remove this when we drop Python 2 support (functools.wraps + # adds __wrapped__ automatically in later versions) + public_api.__wrapped__ = implementation + return public_api return decorator diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py index 7b3472f96..ee6d5da4a 100644 --- a/numpy/core/tests/test_overrides.py +++ b/numpy/core/tests/test_overrides.py @@ -308,6 +308,21 @@ class TestArrayFunctionImplementation(object): func(MyArray()) +class TestNDArrayMethods(object): + + def test_repr(self): + # gh-12162: should still be defined even if __array_function__ doesn't + # implement np.array_repr() + + class MyArray(np.ndarray): + def __array_function__(*args, **kwargs): + return NotImplemented + + array = np.array(1).view(MyArray) + assert_equal(repr(array), 'MyArray(1)') + assert_equal(str(array), '1') + + class TestNumPyFunctions(object): def test_module(self): |
