diff options
author | jaimefrio <jaime.frio@gmail.com> | 2014-09-25 10:45:08 -0700 |
---|---|---|
committer | jaimefrio <jaime.frio@gmail.com> | 2014-09-25 10:49:16 -0700 |
commit | 3a0587e545e959747d9b501dbf029a4cd6576547 (patch) | |
tree | e6cb41c3c6f67240a28690241b5a6d0e8e84aa2d /numpy/lib/polynomial.py | |
parent | f4fa7bd2a67a577eaa72af83028adcfbc71b7fd4 (diff) | |
download | numpy-3a0587e545e959747d9b501dbf029a4cd6576547.tar.gz |
ENH: Cast non-object arrays to float in np.poly
Closes #5096. Casts integer arrays to np.double, to prevent
integer overflow. Object arrays are left unchanged, to allow
use of arbitrary precision objects.
Diffstat (limited to 'numpy/lib/polynomial.py')
-rw-r--r-- | numpy/lib/polynomial.py | 18 |
1 files changed, 12 insertions, 6 deletions
diff --git a/numpy/lib/polynomial.py b/numpy/lib/polynomial.py index 7e4ec1485..2b867e244 100644 --- a/numpy/lib/polynomial.py +++ b/numpy/lib/polynomial.py @@ -12,10 +12,11 @@ import re import warnings import numpy.core.numeric as NX -from numpy.core import isscalar, abs, finfo, atleast_1d, hstack, dot +from numpy.core import (isscalar, abs, finfo, atleast_1d, hstack, dot, array, + ones) from numpy.lib.twodim_base import diag, vander from numpy.lib.function_base import trim_zeros, sort_complex -from numpy.lib.type_check import iscomplex, real, imag +from numpy.lib.type_check import iscomplex, real, imag, mintypecode from numpy.linalg import eigvals, lstsq, inv class RankWarning(UserWarning): @@ -122,19 +123,24 @@ def poly(seq_of_zeros): """ seq_of_zeros = atleast_1d(seq_of_zeros) sh = seq_of_zeros.shape + if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0: seq_of_zeros = eigvals(seq_of_zeros) elif len(sh) == 1: - pass + dt = seq_of_zeros.dtype + # Let object arrays slip through, e.g. for arbitrary precision + if dt != object: + seq_of_zeros = seq_of_zeros.astype(mintypecode(dt.char)) else: raise ValueError("input must be 1d or non-empty square 2d array.") if len(seq_of_zeros) == 0: return 1.0 - - a = [1] + dt = seq_of_zeros.dtype + a = ones((1,), dtype=dt) for k in range(len(seq_of_zeros)): - a = NX.convolve(a, [1, -seq_of_zeros[k]], mode='full') + a = NX.convolve(a, array([1, -seq_of_zeros[k]], dtype=dt), + mode='full') if issubclass(a.dtype.type, NX.complexfloating): # if complex roots are all complex conjugates, the roots are real. |