diff options
author | Joachim Hereth <joachim.hereth@numberfour.eu> | 2018-10-21 22:18:34 +0200 |
---|---|---|
committer | Joachim Hereth <joachim.hereth@numberfour.eu> | 2018-10-21 22:18:34 +0200 |
commit | af6cb03920f3ae62cb8a8c871edeccbcd8609955 (patch) | |
tree | 489a19b6a5e866a2ca262e0f5b89cdc4301fb8ad /numpy | |
parent | db5750f6cdc2715f1c65be31f985e2cd2699d2e0 (diff) | |
download | numpy-af6cb03920f3ae62cb8a8c871edeccbcd8609955.tar.gz |
BUG: polyval returned Non-Masked Arrays for Masked Input.
This fix will preserve subtypes of ndarray when given as input (x)
to the polyval function. In particular, the results for masked
values of a masked array will be masked.
Fixes #2477.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/lib/polynomial.py | 4 | ||||
-rw-r--r-- | numpy/polynomial/tests/test_polynomial.py | 15 |
2 files changed, 17 insertions, 2 deletions
diff --git a/numpy/lib/polynomial.py b/numpy/lib/polynomial.py index 9f3b84732..165fd1b95 100644 --- a/numpy/lib/polynomial.py +++ b/numpy/lib/polynomial.py @@ -651,6 +651,8 @@ def polyval(p, x): for polynomials of high degree the values may be inaccurate due to rounding errors. Use carefully. + If `x` is a subtype of `ndarray` the return value will be of the same type. + References ---------- .. [1] I. N. Bronshtein, K. A. Semendyayev, and K. A. Hirsch (Eng. @@ -673,7 +675,7 @@ def polyval(p, x): if isinstance(x, poly1d): y = 0 else: - x = NX.asarray(x) + x = NX.asanyarray(x) y = NX.zeros_like(x) for i in range(len(p)): y = y * x + p[i] diff --git a/numpy/polynomial/tests/test_polynomial.py b/numpy/polynomial/tests/test_polynomial.py index 0c93be278..562aa904d 100644 --- a/numpy/polynomial/tests/test_polynomial.py +++ b/numpy/polynomial/tests/test_polynomial.py @@ -9,7 +9,7 @@ import numpy as np import numpy.polynomial.polynomial as poly from numpy.testing import ( assert_almost_equal, assert_raises, assert_equal, assert_, - ) + assert_array_equal) def trim(x): @@ -147,6 +147,19 @@ class TestEvaluation(object): assert_equal(poly.polyval(x, [1, 0]).shape, dims) assert_equal(poly.polyval(x, [1, 0, 0]).shape, dims) + #check masked arrays are processed correctly + mask = [False, True, False] + mx = np.ma.array([1, 2, 3], mask=mask) + res = np.polyval([7, 5, 3], mx) + assert_array_equal(res.mask, mask) + + #check subtypes of ndarray are preserved + class C(np.ndarray): + pass + + cx = np.array([1, 2, 3]).view(C) + assert_equal(type(np.polyval([2, 3, 4], cx)), C) + def test_polyvalfromroots(self): # check exception for broadcasting x values over root array with # too few dimensions |