diff options
author | Ralf Gommers <ralf.gommers@googlemail.com> | 2011-04-03 14:59:36 +0200 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2011-04-24 11:06:29 -0600 |
commit | d8de71d14ce7ac08a51d06623437f1df35035a5e (patch) | |
tree | 9828a1271aa8ec9dafd94b1ac0a8c31fc993266c /numpy/core/arrayprint.py | |
parent | be364f74946c8fb4eb8486c51c96f70642175829 (diff) | |
download | numpy-d8de71d14ce7ac08a51d06623437f1df35035a5e.tar.gz |
ENH: Ticket #1218, allow use of custom formatters in array2string and
set_printoptions. Add tests for the new functionality.
Diffstat (limited to 'numpy/core/arrayprint.py')
-rw-r--r-- | numpy/core/arrayprint.py | 217 |
1 files changed, 177 insertions, 40 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index ff6d0ae87..d0b899901 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -29,13 +29,14 @@ _float_output_suppress_small = False _line_width = 75 _nan_str = 'nan' _inf_str = 'inf' +_formatter = None # formatting function for array elements if sys.version_info[0] >= 3: from functools import reduce def set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, suppress=None, - nanstr=None, infstr=None): + nanstr=None, infstr=None, formatter=None): """ Set printing options. @@ -62,10 +63,38 @@ def set_printoptions(precision=None, threshold=None, edgeitems=None, String representation of floating point not-a-number (default nan). infstr : str, optional String representation of floating point infinity (default inf). + formatter : dict of callables, optional + If not None, the keys should indicate the type(s) that the respective + formatting function applies to. Callables should return a string. + Types that are not specified (by their corresponding keys) are handled + by the default formatters. Individual types for which a formatter + can be set are:: + + - 'bool' + - 'int' + - 'timeint' : a `numpy.timeinteger` + - 'float' + - 'longfloat' : 128-bit floats + - 'complexfloat' + - 'longcomplexfloat' : composed of two 128-bit floats + - 'numpy_str' : types `numpy.string_` and `numpy.unicode_` + - 'str' : all other strings + + Other keys that can be used to set a group of types at once are:: + + - 'all' : sets all types + - 'int_kind' : sets 'int' and 'timeint' + - 'float_kind' : sets 'float' and 'longfloat' + - 'complex_kind' : sets 'complexfloat' and 'longcomplexfloat' + - 'str_kind' : sets 'str' and 'numpystr' See Also -------- - get_printoptions, set_string_function + get_printoptions, set_string_function, array2string + + Notes + ----- + `formatter` is always reset with a call to `set_printoptions`. Examples -------- @@ -91,15 +120,26 @@ def set_printoptions(precision=None, threshold=None, edgeitems=None, >>> x**2 - (x + eps)**2 array([-0., -0., 0., 0.]) + A custom formatter can be used to display array elements as desired: + + >>> np.set_printoptions(formatter={'all':lambda x: 'int: '+str(-x)}) + >>> x = np.arange(3) + >>> x + array([int: 0, int: -1, int: -2]) + >>> np.set_printoptions() # formatter gets reset + >>> x + array([0, 1, 2]) + To put back the default options, you can use: - >>> np.set_printoptions(edgeitems=3,infstr='Inf', - ... linewidth=75, nanstr='NaN', precision=8, - ... suppress=False, threshold=1000) + >>> np.set_printoptions(edgeitems=3,infstr='inf', + ... linewidth=75, nanstr='nan', precision=8, + ... suppress=False, threshold=1000, formatter=None) """ global _summaryThreshold, _summaryEdgeItems, _float_output_precision, \ - _line_width, _float_output_suppress_small, _nan_str, _inf_str + _line_width, _float_output_suppress_small, _nan_str, _inf_str, \ + _formatter if linewidth is not None: _line_width = linewidth if threshold is not None: @@ -114,6 +154,7 @@ def set_printoptions(precision=None, threshold=None, edgeitems=None, _nan_str = nanstr if infstr is not None: _inf_str = infstr + _formatter = formatter def get_printoptions(): """ @@ -131,6 +172,7 @@ def get_printoptions(): - suppress : bool - nanstr : str - infstr : str + - formatter : dict of callables For a full description of these options, see `set_printoptions`. @@ -145,7 +187,8 @@ def get_printoptions(): linewidth=_line_width, suppress=_float_output_suppress_small, nanstr=_nan_str, - infstr=_inf_str) + infstr=_inf_str, + formatter=_formatter) return d def _leading_trailing(a): @@ -173,7 +216,7 @@ def _boolFormatter(x): def _array2string(a, max_line_width, precision, suppress_small, separator=' ', - prefix=""): + prefix="", formatter=None): if max_line_width is None: max_line_width = _line_width @@ -184,6 +227,9 @@ def _array2string(a, max_line_width, precision, suppress_small, separator=' ', if suppress_small is None: suppress_small = _float_output_suppress_small + if formatter is None: + formatter = _formatter + if a.size > _summaryThreshold: summary_insert = "..., " data = _leading_trailing(a) @@ -191,44 +237,76 @@ def _array2string(a, max_line_width, precision, suppress_small, separator=' ', summary_insert = "" data = ravel(a) + formatdict = {'bool' : _boolFormatter, + 'int' : IntegerFormat(data), + 'timeint' : str, + 'float' : FloatFormat(data, precision, suppress_small), + 'longfloat' : LongFloatFormat(precision), + 'complexfloat' : ComplexFormat(data, precision, + suppress_small), + 'longcomplexfloat' : LongComplexFormat(precision), + 'numpystr' : repr, + 'str' : str} + if formatter is not None: + fkeys = [k for k in formatter.keys() if formatter[k] is not None] + if 'all' in fkeys: + for key in formatdict.keys(): + formatdict[key] = formatter['all'] + if 'int_kind' in fkeys: + for key in ['int', 'timeint']: + formatdict[key] = formatter['int_kind'] + if 'float_kind' in fkeys: + for key in ['float', 'longfloat']: + formatdict[key] = formatter['float_kind'] + if 'complex_kind' in fkeys: + for key in ['complexfloat', 'longcomplexfloat']: + formatdict[key] = formatter['complex_kind'] + if 'str_kind' in fkeys: + for key in ['numpystr', 'str']: + formatdict[key] = formatter['str_kind'] + for key in formatdict.keys(): + if key in fkeys: + formatdict[key] = formatter[key] + try: format_function = a._format + msg = "The `_format` attribute is deprecated in Numpy 2.0 and " \ + "will be removed in 2.1. Use the `formatter` kw instead." + import warnings + warnings.warn(msg, DeprecationWarning) except AttributeError: + # find the right formatting function for the array dtypeobj = a.dtype.type if issubclass(dtypeobj, _nt.bool_): - # make sure True and False line up. - format_function = _boolFormatter + format_function = formatdict['bool'] elif issubclass(dtypeobj, _nt.integer): if issubclass(dtypeobj, _nt.timeinteger): - format_function = str + format_function = formatdict['timeint'] else: - max_str_len = max(len(str(maximum.reduce(data))), - len(str(minimum.reduce(data)))) - format = '%' + str(max_str_len) + 'd' - format_function = lambda x: _formatInteger(x, format) + format_function = formatdict['int'] elif issubclass(dtypeobj, _nt.floating): if issubclass(dtypeobj, _nt.longfloat): - format_function = LongFloatFormat(precision) + format_function = formatdict['longfloat'] else: - format_function = FloatFormat(data, precision, suppress_small) + format_function = formatdict['float'] elif issubclass(dtypeobj, _nt.complexfloating): if issubclass(dtypeobj, _nt.clongfloat): - format_function = LongComplexFormat(precision) + format_function = formatdict['longcomplexfloat'] else: - format_function = ComplexFormat(data, precision, suppress_small) - elif issubclass(dtypeobj, _nt.unicode_) or \ - issubclass(dtypeobj, _nt.string_): - format_function = repr + format_function = formatdict['complexfloat'] + elif issubclass(dtypeobj, (_nt.unicode_, _nt.string_)): + format_function = formatdict['numpystr'] else: - format_function = str + format_function = formatdict['str'] - next_line_prefix = " " # skip over "[" - next_line_prefix += " "*len(prefix) # skip over array( + # skip over "[" + next_line_prefix = " " + # skip over array( + next_line_prefix += " "*len(prefix) lst = _formatArray(a, format_function, len(a.shape), max_line_width, next_line_prefix, separator, _summaryEdgeItems, summary_insert)[:-1] - return lst def _convert_arrays(obj): @@ -243,9 +321,9 @@ def _convert_arrays(obj): return tuple(newtup) -def array2string(a, max_line_width = None, precision = None, - suppress_small = None, separator=' ', prefix="", - style=repr): +def array2string(a, max_line_width=None, precision=None, + suppress_small=None, separator=' ', prefix="", + style=repr, formatter=None): """ Return a string representation of an array. @@ -273,16 +351,49 @@ def array2string(a, max_line_width = None, precision = None, output correctly. style : function, optional A function that accepts an ndarray and returns a string. Used only - when the shape of `a` is equal to (). + when the shape of `a` is equal to ``()``, i.e. for 0-D arrays. + formatter : dict of callables, optional + If not None, the keys should indicate the type(s) that the respective + formatting function applies to. Callables should return a string. + Types that are not specified (by their corresponding keys) are handled + by the default formatters. Individual types for which a formatter + can be set are:: + + - 'bool' + - 'int' + - 'timeint' : a `numpy.timeinteger` + - 'float' + - 'longfloat' : 128-bit floats + - 'complexfloat' + - 'longcomplexfloat' : composed of two 128-bit floats + - 'numpy_str' : types `numpy.string_` and `numpy.unicode_` + - 'str' : all other strings + + Other keys that can be used to set a group of types at once are:: + + - 'all' : sets all types + - 'int_kind' : sets 'int' and 'timeint' + - 'float_kind' : sets 'float' and 'longfloat' + - 'complex_kind' : sets 'complexfloat' and 'longcomplexfloat' + - 'str_kind' : sets 'str' and 'numpystr' Returns ------- array_str : str String representation of the array. + Raises + ------ + TypeError : if a callable in `formatter` does not return a string. + See Also -------- - array_str, array_repr, set_printoptions + array_str, array_repr, set_printoptions, get_printoptions + + Notes + ----- + If a formatter is specified for a certain type, the `precision` keyword is + ignored for that type. Examples -------- @@ -291,12 +402,24 @@ def array2string(a, max_line_width = None, precision = None, ... suppress_small=True) [ 0., 1., 2., 3.] + >>> x = np.arange(3.) + >>> np.array2string(x, formatter={'float_kind':lambda x: "%.2f" % x}) + '[0.00 1.00 2.00]' + + >>> x = np.arange(3) + >>> np.array2string(x, formatter={'int':lambda x: hex(x)}) + '[0x0L 0x1L 0x2L]' + """ if a.shape == (): x = a.item() try: lst = a._format(x) + msg = "The `_format` attribute is deprecated in Numpy 2.0 and " \ + "will be removed in 2.1. Use the `formatter` kw instead." + import warnings + warnings.warn(msg, DeprecationWarning) except AttributeError: if isinstance(x, tuple): x = _convert_arrays(x) @@ -306,7 +429,7 @@ def array2string(a, max_line_width = None, precision = None, lst = "[]" else: lst = _array2string(a, max_line_width, precision, suppress_small, - separator, prefix) + separator, prefix, formatter=formatter) return lst def _extendLine(s, line, word, max_line_len, next_line_prefix): @@ -392,7 +515,12 @@ class FloatFormat(object): self.exp_format = False self.large_exponent = False self.max_str_len = 0 - self.fillFormat(data) + try: + self.fillFormat(data) + except (TypeError, NotImplementedError): + # if reduce(data) fails, this instance will not be called, just + # instantiated in formatdict. + pass def fillFormat(self, data): import numeric as _nc @@ -490,11 +618,22 @@ def _digits(x, precision, format): _MAXINT = sys.maxint _MININT = -sys.maxint-1 -def _formatInteger(x, format): - if _MININT < x < _MAXINT: - return format % x - else: - return "%s" % x +class IntegerFormat(object): + def __init__(self, data): + try: + max_str_len = max(len(str(maximum.reduce(data))), + len(str(minimum.reduce(data)))) + self.format = '%' + str(max_str_len) + 'd' + except TypeError, NotImplementedError: + # if reduce(data) fails, this instance will not be called, just + # instantiated in formatdict. + pass + + def __call__(self, x): + if _MININT < x < _MAXINT: + return self.format % x + else: + return "%s" % x class LongFloatFormat(object): # XXX Have to add something to determine the width to use a la FloatFormat @@ -552,5 +691,3 @@ class ComplexFormat(object): else: i = i + 'j' return r + i - -## end |