diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-04-20 11:40:40 +0100 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-04-20 12:16:50 +0100 |
commit | a2aea7757aacaa140f22910de6b81f9196a4aecc (patch) | |
tree | 11b746555d6928e44ff58390729063cb773074a2 /numpy/core/arrayprint.py | |
parent | 3b2a7a761d5ceef3b9dcca3fff10380cb7a6f976 (diff) | |
download | numpy-a2aea7757aacaa140f22910de6b81f9196a4aecc.tar.gz |
BUG: Don't construct formatters until we're sure they're correct
Previously, formatters could incur errors from being run on object arrays, even
though the formatter was not used.
Diffstat (limited to 'numpy/core/arrayprint.py')
-rw-r--r-- | numpy/core/arrayprint.py | 59 |
1 files changed, 32 insertions, 27 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index 318ad5495..4e62a42fc 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -235,38 +235,44 @@ def repr_format(x): return repr(x) def _get_formatdict(data, precision, suppress_small, formatter): - formatdict = {'bool': _boolFormatter, - 'int': IntegerFormat(data), - 'float': FloatFormat(data, precision, suppress_small), - 'longfloat': LongFloatFormat(precision), - 'complexfloat': ComplexFormat(data, precision, + # wrapped in lambdas to avoid taking a code path with the wrong type of data + formatdict = {'bool': lambda: _boolFormatter, + 'int': lambda: IntegerFormat(data), + 'float': lambda: FloatFormat(data, precision, suppress_small), + 'longfloat': lambda: LongFloatFormat(precision), + 'complexfloat': lambda: ComplexFormat(data, precision, suppress_small), - 'longcomplexfloat': LongComplexFormat(precision), - 'datetime': DatetimeFormat(data), - 'timedelta': TimedeltaFormat(data), - 'numpystr': repr_format, - 'str': str} + 'longcomplexfloat': lambda: LongComplexFormat(precision), + 'datetime': lambda: DatetimeFormat(data), + 'timedelta': lambda: TimedeltaFormat(data), + 'numpystr': lambda: repr_format, + 'str': lambda: str} + + # we need to wrap values in `formatter` in a lambda, so that the interface + # is the same as the above values. + def indirect(x): + return lambda: x if formatter is not None: fkeys = [k for k in formatter.keys() if formatter[k] is not None] if 'all' in fkeys: for key in formatdict.keys(): - formatdict[key] = formatter['all'] + formatdict[key] = indirect(formatter['all']) if 'int_kind' in fkeys: for key in ['int']: - formatdict[key] = formatter['int_kind'] + formatdict[key] = indirect(formatter['int_kind']) if 'float_kind' in fkeys: for key in ['float', 'longfloat']: - formatdict[key] = formatter['float_kind'] + formatdict[key] = indirect(formatter['float_kind']) if 'complex_kind' in fkeys: for key in ['complexfloat', 'longcomplexfloat']: - formatdict[key] = formatter['complex_kind'] + formatdict[key] = indirect(formatter['complex_kind']) if 'str_kind' in fkeys: for key in ['numpystr', 'str']: - formatdict[key] = formatter['str_kind'] + formatdict[key] = indirect(formatter['str_kind']) for key in formatdict.keys(): if key in fkeys: - formatdict[key] = formatter[key] + formatdict[key] = indirect(formatter[key]) return formatdict @@ -289,28 +295,28 @@ def _get_format_function(data, precision, suppress_small, formatter): dtypeobj = dtype_.type formatdict = _get_formatdict(data, precision, suppress_small, formatter) if issubclass(dtypeobj, _nt.bool_): - return formatdict['bool'] + return formatdict['bool']() elif issubclass(dtypeobj, _nt.integer): if issubclass(dtypeobj, _nt.timedelta64): - return formatdict['timedelta'] + return formatdict['timedelta']() else: - return formatdict['int'] + return formatdict['int']() elif issubclass(dtypeobj, _nt.floating): if issubclass(dtypeobj, _nt.longfloat): - return formatdict['longfloat'] + return formatdict['longfloat']() else: - return formatdict['float'] + return formatdict['float']() elif issubclass(dtypeobj, _nt.complexfloating): if issubclass(dtypeobj, _nt.clongfloat): - return formatdict['longcomplexfloat'] + return formatdict['longcomplexfloat']() else: - return formatdict['complexfloat'] + return formatdict['complexfloat']() elif issubclass(dtypeobj, (_nt.unicode_, _nt.string_)): - return formatdict['numpystr'] + return formatdict['numpystr']() elif issubclass(dtypeobj, _nt.datetime64): - return formatdict['datetime'] + return formatdict['datetime']() else: - return formatdict['numpystr'] + return formatdict['numpystr']() def _array2string(a, max_line_width, precision, suppress_small, separator=' ', prefix="", formatter=None): @@ -336,7 +342,6 @@ def _array2string(a, max_line_width, precision, suppress_small, separator=' ', _summaryEdgeItems, summary_insert)[:-1] return lst - def array2string(a, max_line_width=None, precision=None, suppress_small=None, separator=' ', prefix="", style=repr, formatter=None): |