diff options
Diffstat (limited to 'numpy/core/arrayprint.py')
-rw-r--r-- | numpy/core/arrayprint.py | 105 |
1 files changed, 77 insertions, 28 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index 318ad5495..dba9dffb3 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -16,7 +16,18 @@ __docformat__ = 'restructuredtext' # and by Travis Oliphant 2005-8-22 for numpy import sys -from functools import reduce +import functools +if sys.version_info[0] >= 3: + try: + from _thread import get_ident + except ImportError: + from _dummy_thread import get_ident +else: + try: + from thread import get_ident + except ImportError: + from dummy_thread import get_ident + from . import numerictypes as _nt from .umath import maximum, minimum, absolute, not_equal, isnan, isinf from .multiarray import (array, format_longfloat, datetime_as_string, @@ -235,38 +246,44 @@ def repr_format(x): return repr(x) def _get_formatdict(data, precision, suppress_small, formatter): - formatdict = {'bool': _boolFormatter, - 'int': IntegerFormat(data), - 'float': FloatFormat(data, precision, suppress_small), - 'longfloat': LongFloatFormat(precision), - 'complexfloat': ComplexFormat(data, precision, + # wrapped in lambdas to avoid taking a code path with the wrong type of data + formatdict = {'bool': lambda: _boolFormatter, + 'int': lambda: IntegerFormat(data), + 'float': lambda: FloatFormat(data, precision, suppress_small), + 'longfloat': lambda: LongFloatFormat(precision), + 'complexfloat': lambda: ComplexFormat(data, precision, suppress_small), - 'longcomplexfloat': LongComplexFormat(precision), - 'datetime': DatetimeFormat(data), - 'timedelta': TimedeltaFormat(data), - 'numpystr': repr_format, - 'str': str} + 'longcomplexfloat': lambda: LongComplexFormat(precision), + 'datetime': lambda: DatetimeFormat(data), + 'timedelta': lambda: TimedeltaFormat(data), + 'numpystr': lambda: repr_format, + 'str': lambda: str} + + # we need to wrap values in `formatter` in a lambda, so that the interface + # is the same as the above values. + def indirect(x): + return lambda: x if formatter is not None: fkeys = [k for k in formatter.keys() if formatter[k] is not None] if 'all' in fkeys: for key in formatdict.keys(): - formatdict[key] = formatter['all'] + formatdict[key] = indirect(formatter['all']) if 'int_kind' in fkeys: for key in ['int']: - formatdict[key] = formatter['int_kind'] + formatdict[key] = indirect(formatter['int_kind']) if 'float_kind' in fkeys: for key in ['float', 'longfloat']: - formatdict[key] = formatter['float_kind'] + formatdict[key] = indirect(formatter['float_kind']) if 'complex_kind' in fkeys: for key in ['complexfloat', 'longcomplexfloat']: - formatdict[key] = formatter['complex_kind'] + formatdict[key] = indirect(formatter['complex_kind']) if 'str_kind' in fkeys: for key in ['numpystr', 'str']: - formatdict[key] = formatter['str_kind'] + formatdict[key] = indirect(formatter['str_kind']) for key in formatdict.keys(): if key in fkeys: - formatdict[key] = formatter[key] + formatdict[key] = indirect(formatter[key]) return formatdict @@ -289,28 +306,28 @@ def _get_format_function(data, precision, suppress_small, formatter): dtypeobj = dtype_.type formatdict = _get_formatdict(data, precision, suppress_small, formatter) if issubclass(dtypeobj, _nt.bool_): - return formatdict['bool'] + return formatdict['bool']() elif issubclass(dtypeobj, _nt.integer): if issubclass(dtypeobj, _nt.timedelta64): - return formatdict['timedelta'] + return formatdict['timedelta']() else: - return formatdict['int'] + return formatdict['int']() elif issubclass(dtypeobj, _nt.floating): if issubclass(dtypeobj, _nt.longfloat): - return formatdict['longfloat'] + return formatdict['longfloat']() else: - return formatdict['float'] + return formatdict['float']() elif issubclass(dtypeobj, _nt.complexfloating): if issubclass(dtypeobj, _nt.clongfloat): - return formatdict['longcomplexfloat'] + return formatdict['longcomplexfloat']() else: - return formatdict['complexfloat'] + return formatdict['complexfloat']() elif issubclass(dtypeobj, (_nt.unicode_, _nt.string_)): - return formatdict['numpystr'] + return formatdict['numpystr']() elif issubclass(dtypeobj, _nt.datetime64): - return formatdict['datetime'] + return formatdict['datetime']() else: - return formatdict['numpystr'] + return formatdict['numpystr']() def _array2string(a, max_line_width, precision, suppress_small, separator=' ', prefix="", formatter=None): @@ -337,6 +354,38 @@ def _array2string(a, max_line_width, precision, suppress_small, separator=' ', return lst +def _recursive_guard(fillvalue='...'): + """ + Like the python 3.2 reprlib.recursive_repr, but forwards *args and **kwargs + + Decorates a function such that if it calls itself with the same first + argument, it returns `fillvalue` instead of recursing. + + Largely copied from reprlib.recursive_repr + """ + + def decorating_function(f): + repr_running = set() + + @functools.wraps(f) + def wrapper(self, *args, **kwargs): + key = id(self), get_ident() + if key in repr_running: + return fillvalue + repr_running.add(key) + try: + return f(self, *args, **kwargs) + finally: + repr_running.discard(key) + + return wrapper + + return decorating_function + + +# gracefully handle recursive calls - this comes up when object arrays contain +# themselves +@_recursive_guard() def array2string(a, max_line_width=None, precision=None, suppress_small=None, separator=' ', prefix="", style=repr, formatter=None): @@ -455,7 +504,7 @@ def array2string(a, max_line_width=None, precision=None, lst = format_function(arr[0]) else: lst = style(x) - elif reduce(product, a.shape) == 0: + elif functools.reduce(product, a.shape) == 0: # treat as a null array if any of shape elements == 0 lst = "[]" else: |