summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2019-01-04 10:02:41 -0500
committerGitHub <noreply@github.com>2019-01-04 10:02:41 -0500
commitaef982e4482773e802cc0ef076bf5e76ff650cf9 (patch)
treed307bb118265aee51189674b3be14d5582e6e244
parentb60b58359eef967ce1e557fd8437a37b68330be9 (diff)
parentaf6cb03920f3ae62cb8a8c871edeccbcd8609955 (diff)
downloadnumpy-aef982e4482773e802cc0ef076bf5e76ff650cf9.tar.gz
Merge pull request #12239 from daten-kieker/polyval_2477
BUG: polyval returned non-masked arrays for masked input.
-rw-r--r--numpy/lib/polynomial.py4
-rw-r--r--numpy/polynomial/tests/test_polynomial.py15
2 files changed, 17 insertions, 2 deletions
diff --git a/numpy/lib/polynomial.py b/numpy/lib/polynomial.py
index 7904092ed..b55764b5d 100644
--- a/numpy/lib/polynomial.py
+++ b/numpy/lib/polynomial.py
@@ -704,6 +704,8 @@ def polyval(p, x):
for polynomials of high degree the values may be inaccurate due to
rounding errors. Use carefully.
+ If `x` is a subtype of `ndarray` the return value will be of the same type.
+
References
----------
.. [1] I. N. Bronshtein, K. A. Semendyayev, and K. A. Hirsch (Eng.
@@ -726,7 +728,7 @@ def polyval(p, x):
if isinstance(x, poly1d):
y = 0
else:
- x = NX.asarray(x)
+ x = NX.asanyarray(x)
y = NX.zeros_like(x)
for i in range(len(p)):
y = y * x + p[i]
diff --git a/numpy/polynomial/tests/test_polynomial.py b/numpy/polynomial/tests/test_polynomial.py
index 0c93be278..562aa904d 100644
--- a/numpy/polynomial/tests/test_polynomial.py
+++ b/numpy/polynomial/tests/test_polynomial.py
@@ -9,7 +9,7 @@ import numpy as np
import numpy.polynomial.polynomial as poly
from numpy.testing import (
assert_almost_equal, assert_raises, assert_equal, assert_,
- )
+ assert_array_equal)
def trim(x):
@@ -147,6 +147,19 @@ class TestEvaluation(object):
assert_equal(poly.polyval(x, [1, 0]).shape, dims)
assert_equal(poly.polyval(x, [1, 0, 0]).shape, dims)
+ #check masked arrays are processed correctly
+ mask = [False, True, False]
+ mx = np.ma.array([1, 2, 3], mask=mask)
+ res = np.polyval([7, 5, 3], mx)
+ assert_array_equal(res.mask, mask)
+
+ #check subtypes of ndarray are preserved
+ class C(np.ndarray):
+ pass
+
+ cx = np.array([1, 2, 3]).view(C)
+ assert_equal(type(np.polyval([2, 3, 4], cx)), C)
+
def test_polyvalfromroots(self):
# check exception for broadcasting x values over root array with
# too few dimensions