summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@google.com>2018-10-18 11:01:09 -0700
committerStephan Hoyer <shoyer@google.com>2018-10-18 11:01:09 -0700
commitd9917024e55d1ad99f8b47c5314629ccd821a47c (patch)
treea8b1697da446b361ce548586f95875ebf1bb75bc
parent2c4c93af0b2d20d85a7432093f31318cbf3c457f (diff)
downloadnumpy-d9917024e55d1ad99f8b47c5314629ccd821a47c.tar.gz
MAINT: ndarray.__repr__ should not rely on __array_function__
``ndarray.__repr__`` and ``ndarray.__str__`` should not rely upon ``__array_function__`` internally, so they are still well defined on subclasses even if ``array_repr`` and ``array_str`` are not implemented. Fixes gh-12162
-rw-r--r--numpy/core/arrayprint.py129
-rw-r--r--numpy/core/tests/test_overrides.py15
2 files changed, 91 insertions, 53 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py
index 1b9fbbfa9..0c45989e1 100644
--- a/numpy/core/arrayprint.py
+++ b/numpy/core/arrayprint.py
@@ -1381,6 +1381,53 @@ def dtype_short_repr(dtype):
return typename
+def _array_repr_implementation(
+ arr, max_line_width=None, precision=None, suppress_small=None,
+ array2string=array2string):
+ """Internal version of array_repr() that allows overriding array2string."""
+ if max_line_width is None:
+ max_line_width = _format_options['linewidth']
+
+ if type(arr) is not ndarray:
+ class_name = type(arr).__name__
+ else:
+ class_name = "array"
+
+ skipdtype = dtype_is_implied(arr.dtype) and arr.size > 0
+
+ prefix = class_name + "("
+ suffix = ")" if skipdtype else ","
+
+ if (_format_options['legacy'] == '1.13' and
+ arr.shape == () and not arr.dtype.names):
+ lst = repr(arr.item())
+ elif arr.size > 0 or arr.shape == (0,):
+ lst = array2string(arr, max_line_width, precision, suppress_small,
+ ', ', prefix, suffix=suffix)
+ else: # show zero-length shape unless it is (0,)
+ lst = "[], shape=%s" % (repr(arr.shape),)
+
+ arr_str = prefix + lst + suffix
+
+ if skipdtype:
+ return arr_str
+
+ dtype_str = "dtype={})".format(dtype_short_repr(arr.dtype))
+
+ # compute whether we should put dtype on a new line: Do so if adding the
+ # dtype would extend the last line past max_line_width.
+ # Note: This line gives the correct result even when rfind returns -1.
+ last_line_len = len(arr_str) - (arr_str.rfind('\n') + 1)
+ spacer = " "
+ if _format_options['legacy'] == '1.13':
+ if issubclass(arr.dtype.type, flexible):
+ spacer = '\n' + ' '*len(class_name + "(")
+ elif last_line_len + len(dtype_str) + 1 > max_line_width:
+ spacer = '\n' + ' '*len(class_name + "(")
+
+ return arr_str + spacer + dtype_str
+
+
def _array_repr_dispatcher(
arr, max_line_width=None, precision=None, suppress_small=None):
return (arr,)
@@ -1429,50 +1476,31 @@ def array_repr(arr, max_line_width=None, precision=None, suppress_small=None):
'array([ 0.000001, 0. , 2. , 3. ])'
"""
- if max_line_width is None:
- max_line_width = _format_options['linewidth']
+ return _array_repr_implementation(
+ arr, max_line_width, precision, suppress_small)
- if type(arr) is not ndarray:
- class_name = type(arr).__name__
- else:
- class_name = "array"
- skipdtype = dtype_is_implied(arr.dtype) and arr.size > 0
+_guarded_str = _recursive_guard()(str)
- prefix = class_name + "("
- suffix = ")" if skipdtype else ","
+def _array_str_implementation(
+ a, max_line_width=None, precision=None, suppress_small=None,
+ array2string=array2string):
+ """Internal version of array_str() that allows overriding array2string."""
if (_format_options['legacy'] == '1.13' and
- arr.shape == () and not arr.dtype.names):
- lst = repr(arr.item())
- elif arr.size > 0 or arr.shape == (0,):
- lst = array2string(arr, max_line_width, precision, suppress_small,
- ', ', prefix, suffix=suffix)
- else: # show zero-length shape unless it is (0,)
- lst = "[], shape=%s" % (repr(arr.shape),)
-
- arr_str = prefix + lst + suffix
-
- if skipdtype:
- return arr_str
-
- dtype_str = "dtype={})".format(dtype_short_repr(arr.dtype))
-
- # compute whether we should put dtype on a new line: Do so if adding the
- # dtype would extend the last line past max_line_width.
- # Note: This line gives the correct result even when rfind returns -1.
- last_line_len = len(arr_str) - (arr_str.rfind('\n') + 1)
- spacer = " "
- if _format_options['legacy'] == '1.13':
- if issubclass(arr.dtype.type, flexible):
- spacer = '\n' + ' '*len(class_name + "(")
- elif last_line_len + len(dtype_str) + 1 > max_line_width:
- spacer = '\n' + ' '*len(class_name + "(")
-
- return arr_str + spacer + dtype_str
+ a.shape == () and not a.dtype.names):
+ return str(a.item())
+ # the str of 0d arrays is a special case: It should appear like a scalar,
+ # so floats are not truncated by `precision`, and strings are not wrapped
+ # in quotes. So we return the str of the scalar value.
+ if a.shape == ():
+ # obtain a scalar and call str on it, avoiding problems for subclasses
+ # for which indexing with () returns a 0d instead of a scalar by using
+ # ndarray's getindex. Also guard against recursive 0d object arrays.
+ return _guarded_str(np.ndarray.__getitem__(a, ()))
-_guarded_str = _recursive_guard()(str)
+ return array2string(a, max_line_width, precision, suppress_small, ' ', "")
def _array_str_dispatcher(
@@ -1515,20 +1543,15 @@ def array_str(a, max_line_width=None, precision=None, suppress_small=None):
'[0 1 2]'
"""
- if (_format_options['legacy'] == '1.13' and
- a.shape == () and not a.dtype.names):
- return str(a.item())
+ return _array_str_implementation(
+ a, max_line_width, precision, suppress_small)
- # the str of 0d arrays is a special case: It should appear like a scalar,
- # so floats are not truncated by `precision`, and strings are not wrapped
- # in quotes. So we return the str of the scalar value.
- if a.shape == ():
- # obtain a scalar and call str on it, avoiding problems for subclasses
- # for which indexing with () returns a 0d instead of a scalar by using
- # ndarray's getindex. Also guard against recursive 0d object arrays.
- return _guarded_str(np.ndarray.__getitem__(a, ()))
- return array2string(a, max_line_width, precision, suppress_small, ' ', "")
+_default_array_str = functools.partial(_array_str_implementation,
+ array2string=array2string.__wrapped__)
+_default_array_repr = functools.partial(_array_repr_implementation,
+ array2string=array2string.__wrapped__)
+
def set_string_function(f, repr=True):
"""
@@ -1583,11 +1606,11 @@ def set_string_function(f, repr=True):
"""
if f is None:
if repr:
- return multiarray.set_string_function(array_repr, 1)
+ return multiarray.set_string_function(_default_array_repr, 1)
else:
- return multiarray.set_string_function(array_str, 0)
+ return multiarray.set_string_function(_default_array_str, 0)
else:
return multiarray.set_string_function(f, repr)
-set_string_function(array_str, 0)
-set_string_function(array_repr, 1)
+set_string_function(_default_array_str, 0)
+set_string_function(_default_array_repr, 1)
diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py
index 3f87a6afe..c959655a7 100644
--- a/numpy/core/tests/test_overrides.py
+++ b/numpy/core/tests/test_overrides.py
@@ -304,3 +304,18 @@ class TestArrayFunctionImplementation(object):
with assert_raises_regex(TypeError, 'no implementation found'):
func(MyArray())
+
+
+class TestNDArrayMethods(object):
+
+ def test_repr(self):
+ # gh-12162: should still be defined even if __array_function__ doesn't
+ # implement np.array_repr()
+
+ class MyArray(np.ndarray):
+ def __array_function__(*args, **kwargs):
+ return NotImplemented
+
+ array = np.array(1).view(MyArray)
+ assert_equal(repr(array), 'MyArray(1)')
+ assert_equal(str(array), '1')