summaryrefslogtreecommitdiff
path: root/numpy/lib/polynomial.py
diff options
context:
space:
mode:
authorjaimefrio <jaime.frio@gmail.com>2014-09-25 10:45:08 -0700
committerjaimefrio <jaime.frio@gmail.com>2014-09-25 10:49:16 -0700
commit3a0587e545e959747d9b501dbf029a4cd6576547 (patch)
treee6cb41c3c6f67240a28690241b5a6d0e8e84aa2d /numpy/lib/polynomial.py
parentf4fa7bd2a67a577eaa72af83028adcfbc71b7fd4 (diff)
downloadnumpy-3a0587e545e959747d9b501dbf029a4cd6576547.tar.gz
ENH: Cast non-object arrays to float in np.poly
Closes #5096. Casts integer arrays to np.double, to prevent integer overflow. Object arrays are left unchanged, to allow use of arbitrary precision objects.
Diffstat (limited to 'numpy/lib/polynomial.py')
-rw-r--r--numpy/lib/polynomial.py18
1 files changed, 12 insertions, 6 deletions
diff --git a/numpy/lib/polynomial.py b/numpy/lib/polynomial.py
index 7e4ec1485..2b867e244 100644
--- a/numpy/lib/polynomial.py
+++ b/numpy/lib/polynomial.py
@@ -12,10 +12,11 @@ import re
import warnings
import numpy.core.numeric as NX
-from numpy.core import isscalar, abs, finfo, atleast_1d, hstack, dot
+from numpy.core import (isscalar, abs, finfo, atleast_1d, hstack, dot, array,
+ ones)
from numpy.lib.twodim_base import diag, vander
from numpy.lib.function_base import trim_zeros, sort_complex
-from numpy.lib.type_check import iscomplex, real, imag
+from numpy.lib.type_check import iscomplex, real, imag, mintypecode
from numpy.linalg import eigvals, lstsq, inv
class RankWarning(UserWarning):
@@ -122,19 +123,24 @@ def poly(seq_of_zeros):
"""
seq_of_zeros = atleast_1d(seq_of_zeros)
sh = seq_of_zeros.shape
+
if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0:
seq_of_zeros = eigvals(seq_of_zeros)
elif len(sh) == 1:
- pass
+ dt = seq_of_zeros.dtype
+ # Let object arrays slip through, e.g. for arbitrary precision
+ if dt != object:
+ seq_of_zeros = seq_of_zeros.astype(mintypecode(dt.char))
else:
raise ValueError("input must be 1d or non-empty square 2d array.")
if len(seq_of_zeros) == 0:
return 1.0
-
- a = [1]
+ dt = seq_of_zeros.dtype
+ a = ones((1,), dtype=dt)
for k in range(len(seq_of_zeros)):
- a = NX.convolve(a, [1, -seq_of_zeros[k]], mode='full')
+ a = NX.convolve(a, array([1, -seq_of_zeros[k]], dtype=dt),
+ mode='full')
if issubclass(a.dtype.type, NX.complexfloating):
# if complex roots are all complex conjugates, the roots are real.