diff options
Diffstat (limited to 'numpy/lib/polynomial.py')
-rw-r--r-- | numpy/lib/polynomial.py | 24 |
1 files changed, 15 insertions, 9 deletions
diff --git a/numpy/lib/polynomial.py b/numpy/lib/polynomial.py index 6a1adc773..de9376300 100644 --- a/numpy/lib/polynomial.py +++ b/numpy/lib/polynomial.py @@ -12,10 +12,11 @@ import re import warnings import numpy.core.numeric as NX -from numpy.core import isscalar, abs, finfo, atleast_1d, hstack, dot +from numpy.core import (isscalar, abs, finfo, atleast_1d, hstack, dot, array, + ones) from numpy.lib.twodim_base import diag, vander from numpy.lib.function_base import trim_zeros, sort_complex -from numpy.lib.type_check import iscomplex, real, imag +from numpy.lib.type_check import iscomplex, real, imag, mintypecode from numpy.linalg import eigvals, lstsq, inv class RankWarning(UserWarning): @@ -122,19 +123,24 @@ def poly(seq_of_zeros): """ seq_of_zeros = atleast_1d(seq_of_zeros) sh = seq_of_zeros.shape + if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0: seq_of_zeros = eigvals(seq_of_zeros) elif len(sh) == 1: - pass + dt = seq_of_zeros.dtype + # Let object arrays slip through, e.g. for arbitrary precision + if dt != object: + seq_of_zeros = seq_of_zeros.astype(mintypecode(dt.char)) else: raise ValueError("input must be 1d or non-empty square 2d array.") if len(seq_of_zeros) == 0: return 1.0 - - a = [1] + dt = seq_of_zeros.dtype + a = ones((1,), dtype=dt) for k in range(len(seq_of_zeros)): - a = NX.convolve(a, [1, -seq_of_zeros[k]], mode='full') + a = NX.convolve(a, array([1, -seq_of_zeros[k]], dtype=dt), + mode='full') if issubclass(a.dtype.type, NX.complexfloating): # if complex roots are all complex conjugates, the roots are real. @@ -247,12 +253,12 @@ def polyint(p, m=1, k=None): Parameters ---------- - p : {array_like, poly1d} + p : array_like or poly1d Polynomial to differentiate. A sequence is interpreted as polynomial coefficients, see `poly1d`. m : int, optional Order of the antiderivative. (Default: 1) - k : {None, list of `m` scalars, scalar}, optional + k : list of `m` scalars or scalar, optional Integration constants. They are given in the order of integration: those corresponding to highest-order terms come first. @@ -671,7 +677,7 @@ def polyval(p, x): x = NX.asarray(x) y = NX.zeros_like(x) for i in range(len(p)): - y = x * y + p[i] + y = y * x + p[i] return y def polyadd(a1, a2): |