summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2009-03-09 03:36:49 +0000
committerCharles Harris <charlesr.harris@gmail.com>2009-03-09 03:36:49 +0000
commiteefc1417d12af4a5b3bfa11b52ec67b9fff641ad (patch)
tree7590fef00fe3aba99983020373703c263301e465
parent4a632534604d686ff9ac5a9629ce06f7c895cd1e (diff)
downloadnumpy-eefc1417d12af4a5b3bfa11b52ec67b9fff641ad.tar.gz
Fix polyint to work correctly with float, complex, and int inputs.
Fix polydiv to work correctly with float, complex, and int inputs.
-rw-r--r--numpy/lib/polynomial.py20
-rw-r--r--numpy/lib/tests/test_regression.py18
2 files changed, 28 insertions, 10 deletions
diff --git a/numpy/lib/polynomial.py b/numpy/lib/polynomial.py
index ec9b8c1f6..10fd6dd6c 100644
--- a/numpy/lib/polynomial.py
+++ b/numpy/lib/polynomial.py
@@ -239,17 +239,20 @@ def polyint(p, m=1, k=None):
if len(k) < m:
raise ValueError, \
"k must be a scalar or a rank-1 array of length 1 or >m."
+
+ truepoly = isinstance(p, poly1d)
+ p = NX.asarray(p) + 0.0
if m == 0:
+ if truepoly:
+ return poly1d(p)
return p
else:
- truepoly = isinstance(p, poly1d)
- p = NX.asarray(p)
y = NX.zeros(len(p) + 1, p.dtype)
y[:-1] = p*1.0/NX.arange(len(p), 0, -1)
y[-1] = k[0]
val = polyint(y, m - 1, k=k[1:])
if truepoly:
- val = poly1d(val)
+ return poly1d(val)
return val
def polyder(p, m=1):
@@ -710,12 +713,14 @@ def polydiv(u, v):
"""
truepoly = (isinstance(u, poly1d) or isinstance(u, poly1d))
- u = atleast_1d(u)
- v = atleast_1d(v)
+ u = atleast_1d(u) + 0.0
+ v = atleast_1d(v) + 0.0
+ # w has the common type
+ w = u[0] + v[0]
m = len(u) - 1
n = len(v) - 1
scale = 1. / v[0]
- q = NX.zeros((max(m-n+1,1),), float)
+ q = NX.zeros((max(m - n + 1, 1),), w.dtype)
r = u.copy()
for k in range(0, m-n+1):
d = scale * r[k]
@@ -724,8 +729,7 @@ def polydiv(u, v):
while NX.allclose(r[0], 0, rtol=1e-14) and (r.shape[-1] > 1):
r = r[1:]
if truepoly:
- q = poly1d(q)
- r = poly1d(r)
+ return poly1d(q), poly1d(r)
return q, r
_poly_mat = re.compile(r"[*][*]([0-9]*)")
diff --git a/numpy/lib/tests/test_regression.py b/numpy/lib/tests/test_regression.py
index e465dc4b6..e58756b9b 100644
--- a/numpy/lib/tests/test_regression.py
+++ b/numpy/lib/tests/test_regression.py
@@ -27,8 +27,22 @@ class TestRegression(object):
def test_polyint_type(self) :
"""Ticket #944"""
msg = "Wrong type, should be complex"
- x = np.polyint(np.ones(3, dtype=np.complex))
- assert_(np.asarray(x).dtype == np.complex, msg)
+ x = np.ones(3, dtype=np.complex)
+ assert_(np.polyint(x).dtype == np.complex, msg)
+ msg = "Wrong type, should be float"
+ x = np.ones(3, dtype=np.int)
+ assert_(np.polyint(x).dtype == np.float, msg)
+
+ def test_polydiv_type(self) :
+ """Make polydiv work for complex types"""
+ msg = "Wrong type, should be complex"
+ x = np.ones(3, dtype=np.complex)
+ q,r = np.polydiv(x,x)
+ assert_(q.dtype == np.complex, msg)
+ msg = "Wrong type, should be float"
+ x = np.ones(3, dtype=np.int)
+ q,r = np.polydiv(x,x)
+ assert_(q.dtype == np.float, msg)
if __name__ == "__main__":