summaryrefslogtreecommitdiff
path: root/numpy/core/arrayprint.py
diff options
context:
space:
mode:
authorRalf Gommers <ralf.gommers@googlemail.com>2011-04-03 14:59:36 +0200
committerCharles Harris <charlesr.harris@gmail.com>2011-04-24 11:06:29 -0600
commitd8de71d14ce7ac08a51d06623437f1df35035a5e (patch)
tree9828a1271aa8ec9dafd94b1ac0a8c31fc993266c /numpy/core/arrayprint.py
parentbe364f74946c8fb4eb8486c51c96f70642175829 (diff)
downloadnumpy-d8de71d14ce7ac08a51d06623437f1df35035a5e.tar.gz
ENH: Ticket #1218, allow use of custom formatters in array2string and
set_printoptions. Add tests for the new functionality.
Diffstat (limited to 'numpy/core/arrayprint.py')
-rw-r--r--numpy/core/arrayprint.py217
1 files changed, 177 insertions, 40 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py
index ff6d0ae87..d0b899901 100644
--- a/numpy/core/arrayprint.py
+++ b/numpy/core/arrayprint.py
@@ -29,13 +29,14 @@ _float_output_suppress_small = False
_line_width = 75
_nan_str = 'nan'
_inf_str = 'inf'
+_formatter = None # formatting function for array elements
if sys.version_info[0] >= 3:
from functools import reduce
def set_printoptions(precision=None, threshold=None, edgeitems=None,
linewidth=None, suppress=None,
- nanstr=None, infstr=None):
+ nanstr=None, infstr=None, formatter=None):
"""
Set printing options.
@@ -62,10 +63,38 @@ def set_printoptions(precision=None, threshold=None, edgeitems=None,
String representation of floating point not-a-number (default nan).
infstr : str, optional
String representation of floating point infinity (default inf).
+ formatter : dict of callables, optional
+ If not None, the keys should indicate the type(s) that the respective
+ formatting function applies to. Callables should return a string.
+ Types that are not specified (by their corresponding keys) are handled
+ by the default formatters. Individual types for which a formatter
+ can be set are::
+
+ - 'bool'
+ - 'int'
+ - 'timeint' : a `numpy.timeinteger`
+ - 'float'
+ - 'longfloat' : 128-bit floats
+ - 'complexfloat'
+ - 'longcomplexfloat' : composed of two 128-bit floats
+ - 'numpy_str' : types `numpy.string_` and `numpy.unicode_`
+ - 'str' : all other strings
+
+ Other keys that can be used to set a group of types at once are::
+
+ - 'all' : sets all types
+ - 'int_kind' : sets 'int' and 'timeint'
+ - 'float_kind' : sets 'float' and 'longfloat'
+ - 'complex_kind' : sets 'complexfloat' and 'longcomplexfloat'
+ - 'str_kind' : sets 'str' and 'numpystr'
See Also
--------
- get_printoptions, set_string_function
+ get_printoptions, set_string_function, array2string
+
+ Notes
+ -----
+ `formatter` is always reset with a call to `set_printoptions`.
Examples
--------
@@ -91,15 +120,26 @@ def set_printoptions(precision=None, threshold=None, edgeitems=None,
>>> x**2 - (x + eps)**2
array([-0., -0., 0., 0.])
+ A custom formatter can be used to display array elements as desired:
+
+ >>> np.set_printoptions(formatter={'all':lambda x: 'int: '+str(-x)})
+ >>> x = np.arange(3)
+ >>> x
+ array([int: 0, int: -1, int: -2])
+ >>> np.set_printoptions() # formatter gets reset
+ >>> x
+ array([0, 1, 2])
+
To put back the default options, you can use:
- >>> np.set_printoptions(edgeitems=3,infstr='Inf',
- ... linewidth=75, nanstr='NaN', precision=8,
- ... suppress=False, threshold=1000)
+ >>> np.set_printoptions(edgeitems=3,infstr='inf',
+ ... linewidth=75, nanstr='nan', precision=8,
+ ... suppress=False, threshold=1000, formatter=None)
"""
global _summaryThreshold, _summaryEdgeItems, _float_output_precision, \
- _line_width, _float_output_suppress_small, _nan_str, _inf_str
+ _line_width, _float_output_suppress_small, _nan_str, _inf_str, \
+ _formatter
if linewidth is not None:
_line_width = linewidth
if threshold is not None:
@@ -114,6 +154,7 @@ def set_printoptions(precision=None, threshold=None, edgeitems=None,
_nan_str = nanstr
if infstr is not None:
_inf_str = infstr
+ _formatter = formatter
def get_printoptions():
"""
@@ -131,6 +172,7 @@ def get_printoptions():
- suppress : bool
- nanstr : str
- infstr : str
+ - formatter : dict of callables
For a full description of these options, see `set_printoptions`.
@@ -145,7 +187,8 @@ def get_printoptions():
linewidth=_line_width,
suppress=_float_output_suppress_small,
nanstr=_nan_str,
- infstr=_inf_str)
+ infstr=_inf_str,
+ formatter=_formatter)
return d
def _leading_trailing(a):
@@ -173,7 +216,7 @@ def _boolFormatter(x):
def _array2string(a, max_line_width, precision, suppress_small, separator=' ',
- prefix=""):
+ prefix="", formatter=None):
if max_line_width is None:
max_line_width = _line_width
@@ -184,6 +227,9 @@ def _array2string(a, max_line_width, precision, suppress_small, separator=' ',
if suppress_small is None:
suppress_small = _float_output_suppress_small
+ if formatter is None:
+ formatter = _formatter
+
if a.size > _summaryThreshold:
summary_insert = "..., "
data = _leading_trailing(a)
@@ -191,44 +237,76 @@ def _array2string(a, max_line_width, precision, suppress_small, separator=' ',
summary_insert = ""
data = ravel(a)
+ formatdict = {'bool' : _boolFormatter,
+ 'int' : IntegerFormat(data),
+ 'timeint' : str,
+ 'float' : FloatFormat(data, precision, suppress_small),
+ 'longfloat' : LongFloatFormat(precision),
+ 'complexfloat' : ComplexFormat(data, precision,
+ suppress_small),
+ 'longcomplexfloat' : LongComplexFormat(precision),
+ 'numpystr' : repr,
+ 'str' : str}
+ if formatter is not None:
+ fkeys = [k for k in formatter.keys() if formatter[k] is not None]
+ if 'all' in fkeys:
+ for key in formatdict.keys():
+ formatdict[key] = formatter['all']
+ if 'int_kind' in fkeys:
+ for key in ['int', 'timeint']:
+ formatdict[key] = formatter['int_kind']
+ if 'float_kind' in fkeys:
+ for key in ['float', 'longfloat']:
+ formatdict[key] = formatter['float_kind']
+ if 'complex_kind' in fkeys:
+ for key in ['complexfloat', 'longcomplexfloat']:
+ formatdict[key] = formatter['complex_kind']
+ if 'str_kind' in fkeys:
+ for key in ['numpystr', 'str']:
+ formatdict[key] = formatter['str_kind']
+ for key in formatdict.keys():
+ if key in fkeys:
+ formatdict[key] = formatter[key]
+
try:
format_function = a._format
+ msg = "The `_format` attribute is deprecated in Numpy 2.0 and " \
+ "will be removed in 2.1. Use the `formatter` kw instead."
+ import warnings
+ warnings.warn(msg, DeprecationWarning)
except AttributeError:
+ # find the right formatting function for the array
dtypeobj = a.dtype.type
if issubclass(dtypeobj, _nt.bool_):
- # make sure True and False line up.
- format_function = _boolFormatter
+ format_function = formatdict['bool']
elif issubclass(dtypeobj, _nt.integer):
if issubclass(dtypeobj, _nt.timeinteger):
- format_function = str
+ format_function = formatdict['timeint']
else:
- max_str_len = max(len(str(maximum.reduce(data))),
- len(str(minimum.reduce(data))))
- format = '%' + str(max_str_len) + 'd'
- format_function = lambda x: _formatInteger(x, format)
+ format_function = formatdict['int']
elif issubclass(dtypeobj, _nt.floating):
if issubclass(dtypeobj, _nt.longfloat):
- format_function = LongFloatFormat(precision)
+ format_function = formatdict['longfloat']
else:
- format_function = FloatFormat(data, precision, suppress_small)
+ format_function = formatdict['float']
elif issubclass(dtypeobj, _nt.complexfloating):
if issubclass(dtypeobj, _nt.clongfloat):
- format_function = LongComplexFormat(precision)
+ format_function = formatdict['longcomplexfloat']
else:
- format_function = ComplexFormat(data, precision, suppress_small)
- elif issubclass(dtypeobj, _nt.unicode_) or \
- issubclass(dtypeobj, _nt.string_):
- format_function = repr
+ format_function = formatdict['complexfloat']
+ elif issubclass(dtypeobj, (_nt.unicode_, _nt.string_)):
+ format_function = formatdict['numpystr']
else:
- format_function = str
+ format_function = formatdict['str']
- next_line_prefix = " " # skip over "["
- next_line_prefix += " "*len(prefix) # skip over array(
+ # skip over "["
+ next_line_prefix = " "
+ # skip over array(
+ next_line_prefix += " "*len(prefix)
lst = _formatArray(a, format_function, len(a.shape), max_line_width,
next_line_prefix, separator,
_summaryEdgeItems, summary_insert)[:-1]
-
return lst
def _convert_arrays(obj):
@@ -243,9 +321,9 @@ def _convert_arrays(obj):
return tuple(newtup)
-def array2string(a, max_line_width = None, precision = None,
- suppress_small = None, separator=' ', prefix="",
- style=repr):
+def array2string(a, max_line_width=None, precision=None,
+ suppress_small=None, separator=' ', prefix="",
+ style=repr, formatter=None):
"""
Return a string representation of an array.
@@ -273,16 +351,49 @@ def array2string(a, max_line_width = None, precision = None,
output correctly.
style : function, optional
A function that accepts an ndarray and returns a string. Used only
- when the shape of `a` is equal to ().
+ when the shape of `a` is equal to ``()``, i.e. for 0-D arrays.
+ formatter : dict of callables, optional
+ If not None, the keys should indicate the type(s) that the respective
+ formatting function applies to. Callables should return a string.
+ Types that are not specified (by their corresponding keys) are handled
+ by the default formatters. Individual types for which a formatter
+ can be set are::
+
+ - 'bool'
+ - 'int'
+ - 'timeint' : a `numpy.timeinteger`
+ - 'float'
+ - 'longfloat' : 128-bit floats
+ - 'complexfloat'
+ - 'longcomplexfloat' : composed of two 128-bit floats
+ - 'numpy_str' : types `numpy.string_` and `numpy.unicode_`
+ - 'str' : all other strings
+
+ Other keys that can be used to set a group of types at once are::
+
+ - 'all' : sets all types
+ - 'int_kind' : sets 'int' and 'timeint'
+ - 'float_kind' : sets 'float' and 'longfloat'
+ - 'complex_kind' : sets 'complexfloat' and 'longcomplexfloat'
+ - 'str_kind' : sets 'str' and 'numpystr'
Returns
-------
array_str : str
String representation of the array.
+ Raises
+ ------
+ TypeError : if a callable in `formatter` does not return a string.
+
See Also
--------
- array_str, array_repr, set_printoptions
+ array_str, array_repr, set_printoptions, get_printoptions
+
+ Notes
+ -----
+ If a formatter is specified for a certain type, the `precision` keyword is
+ ignored for that type.
Examples
--------
@@ -291,12 +402,24 @@ def array2string(a, max_line_width = None, precision = None,
... suppress_small=True)
[ 0., 1., 2., 3.]
+ >>> x = np.arange(3.)
+ >>> np.array2string(x, formatter={'float_kind':lambda x: "%.2f" % x})
+ '[0.00 1.00 2.00]'
+
+ >>> x = np.arange(3)
+ >>> np.array2string(x, formatter={'int':lambda x: hex(x)})
+ '[0x0L 0x1L 0x2L]'
+
"""
if a.shape == ():
x = a.item()
try:
lst = a._format(x)
+ msg = "The `_format` attribute is deprecated in Numpy 2.0 and " \
+ "will be removed in 2.1. Use the `formatter` kw instead."
+ import warnings
+ warnings.warn(msg, DeprecationWarning)
except AttributeError:
if isinstance(x, tuple):
x = _convert_arrays(x)
@@ -306,7 +429,7 @@ def array2string(a, max_line_width = None, precision = None,
lst = "[]"
else:
lst = _array2string(a, max_line_width, precision, suppress_small,
- separator, prefix)
+ separator, prefix, formatter=formatter)
return lst
def _extendLine(s, line, word, max_line_len, next_line_prefix):
@@ -392,7 +515,12 @@ class FloatFormat(object):
self.exp_format = False
self.large_exponent = False
self.max_str_len = 0
- self.fillFormat(data)
+ try:
+ self.fillFormat(data)
+ except (TypeError, NotImplementedError):
+ # if reduce(data) fails, this instance will not be called, just
+ # instantiated in formatdict.
+ pass
def fillFormat(self, data):
import numeric as _nc
@@ -490,11 +618,22 @@ def _digits(x, precision, format):
_MAXINT = sys.maxint
_MININT = -sys.maxint-1
-def _formatInteger(x, format):
- if _MININT < x < _MAXINT:
- return format % x
- else:
- return "%s" % x
+class IntegerFormat(object):
+ def __init__(self, data):
+ try:
+ max_str_len = max(len(str(maximum.reduce(data))),
+ len(str(minimum.reduce(data))))
+ self.format = '%' + str(max_str_len) + 'd'
+ except TypeError, NotImplementedError:
+ # if reduce(data) fails, this instance will not be called, just
+ # instantiated in formatdict.
+ pass
+
+ def __call__(self, x):
+ if _MININT < x < _MAXINT:
+ return self.format % x
+ else:
+ return "%s" % x
class LongFloatFormat(object):
# XXX Have to add something to determine the width to use a la FloatFormat
@@ -552,5 +691,3 @@ class ComplexFormat(object):
else:
i = i + 'j'
return r + i
-
-## end