summaryrefslogtreecommitdiff
path: root/numpy/core/arrayprint.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/arrayprint.py')
-rw-r--r--numpy/core/arrayprint.py34
1 files changed, 25 insertions, 9 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py
index dcfb6e6a8..62cd52707 100644
--- a/numpy/core/arrayprint.py
+++ b/numpy/core/arrayprint.py
@@ -1096,7 +1096,7 @@ def format_float_scientific(x, precision=None, unique=True, trim='k',
identify the value may be printed and rounded unbiased.
-- versionadded:: 1.21.0
-
+
Returns
-------
rep : string
@@ -1181,7 +1181,7 @@ def format_float_positional(x, precision=None, unique=True,
Minimum number of digits to print. Only has an effect if `unique=True`
in which case additional digits past those necessary to uniquely
identify the value may be printed, rounding the last additional digit.
-
+
-- versionadded:: 1.21.0
Returns
@@ -1339,13 +1339,29 @@ class TimedeltaFormat(_TimelikeFormat):
class SubArrayFormat:
- def __init__(self, format_function):
+ def __init__(self, format_function, **options):
self.format_function = format_function
+ self.threshold = options['threshold']
+ self.edge_items = options['edgeitems']
+
+ def __call__(self, a):
+ self.summary_insert = "..." if a.size > self.threshold else ""
+ return self.format_array(a)
+
+ def format_array(self, a):
+ if np.ndim(a) == 0:
+ return self.format_function(a)
+
+ if self.summary_insert and a.shape[0] > 2*self.edge_items:
+ formatted = (
+ [self.format_array(a_) for a_ in a[:self.edge_items]]
+ + [self.summary_insert]
+ + [self.format_array(a_) for a_ in a[-self.edge_items:]]
+ )
+ else:
+ formatted = [self.format_array(a_) for a_ in a]
- def __call__(self, arr):
- if arr.ndim <= 1:
- return "[" + ", ".join(self.format_function(a) for a in arr) + "]"
- return "[" + ", ".join(self.__call__(a) for a in arr) + "]"
+ return "[" + ", ".join(formatted) + "]"
class StructuredVoidFormat:
@@ -1369,7 +1385,7 @@ class StructuredVoidFormat:
for field_name in data.dtype.names:
format_function = _get_format_function(data[field_name], **options)
if data.dtype[field_name].shape != ():
- format_function = SubArrayFormat(format_function)
+ format_function = SubArrayFormat(format_function, **options)
format_functions.append(format_function)
return cls(format_functions)
@@ -1428,7 +1444,7 @@ def dtype_is_implied(dtype):
# not just void types can be structured, and names are not part of the repr
if dtype.names is not None:
return False
-
+
# should care about endianness *unless size is 1* (e.g., int8, bool)
if not dtype.isnative:
return False