diff options
| author | Charles Harris <charlesr.harris@gmail.com> | 2009-03-09 03:36:49 +0000 |
|---|---|---|
| committer | Charles Harris <charlesr.harris@gmail.com> | 2009-03-09 03:36:49 +0000 |
| commit | eefc1417d12af4a5b3bfa11b52ec67b9fff641ad (patch) | |
| tree | 7590fef00fe3aba99983020373703c263301e465 | |
| parent | 4a632534604d686ff9ac5a9629ce06f7c895cd1e (diff) | |
| download | numpy-eefc1417d12af4a5b3bfa11b52ec67b9fff641ad.tar.gz | |
Fix polyint to work correctly with float, complex, and int inputs.
Fix polydiv to work correctly with float, complex, and int inputs.
| -rw-r--r-- | numpy/lib/polynomial.py | 20 | ||||
| -rw-r--r-- | numpy/lib/tests/test_regression.py | 18 |
2 files changed, 28 insertions, 10 deletions
diff --git a/numpy/lib/polynomial.py b/numpy/lib/polynomial.py index ec9b8c1f6..10fd6dd6c 100644 --- a/numpy/lib/polynomial.py +++ b/numpy/lib/polynomial.py @@ -239,17 +239,20 @@ def polyint(p, m=1, k=None): if len(k) < m: raise ValueError, \ "k must be a scalar or a rank-1 array of length 1 or >m." + + truepoly = isinstance(p, poly1d) + p = NX.asarray(p) + 0.0 if m == 0: + if truepoly: + return poly1d(p) return p else: - truepoly = isinstance(p, poly1d) - p = NX.asarray(p) y = NX.zeros(len(p) + 1, p.dtype) y[:-1] = p*1.0/NX.arange(len(p), 0, -1) y[-1] = k[0] val = polyint(y, m - 1, k=k[1:]) if truepoly: - val = poly1d(val) + return poly1d(val) return val def polyder(p, m=1): @@ -710,12 +713,14 @@ def polydiv(u, v): """ truepoly = (isinstance(u, poly1d) or isinstance(u, poly1d)) - u = atleast_1d(u) - v = atleast_1d(v) + u = atleast_1d(u) + 0.0 + v = atleast_1d(v) + 0.0 + # w has the common type + w = u[0] + v[0] m = len(u) - 1 n = len(v) - 1 scale = 1. / v[0] - q = NX.zeros((max(m-n+1,1),), float) + q = NX.zeros((max(m - n + 1, 1),), w.dtype) r = u.copy() for k in range(0, m-n+1): d = scale * r[k] @@ -724,8 +729,7 @@ def polydiv(u, v): while NX.allclose(r[0], 0, rtol=1e-14) and (r.shape[-1] > 1): r = r[1:] if truepoly: - q = poly1d(q) - r = poly1d(r) + return poly1d(q), poly1d(r) return q, r _poly_mat = re.compile(r"[*][*]([0-9]*)") diff --git a/numpy/lib/tests/test_regression.py b/numpy/lib/tests/test_regression.py index e465dc4b6..e58756b9b 100644 --- a/numpy/lib/tests/test_regression.py +++ b/numpy/lib/tests/test_regression.py @@ -27,8 +27,22 @@ class TestRegression(object): def test_polyint_type(self) : """Ticket #944""" msg = "Wrong type, should be complex" - x = np.polyint(np.ones(3, dtype=np.complex)) - assert_(np.asarray(x).dtype == np.complex, msg) + x = np.ones(3, dtype=np.complex) + assert_(np.polyint(x).dtype == np.complex, msg) + msg = "Wrong type, should be float" + x = np.ones(3, dtype=np.int) + assert_(np.polyint(x).dtype == np.float, msg) + + def test_polydiv_type(self) : + """Make polydiv work for complex types""" + msg = "Wrong type, should be complex" + x = np.ones(3, dtype=np.complex) + q,r = np.polydiv(x,x) + assert_(q.dtype == np.complex, msg) + msg = "Wrong type, should be float" + x = np.ones(3, dtype=np.int) + q,r = np.polydiv(x,x) + assert_(q.dtype == np.float, msg) if __name__ == "__main__": |
