diff options
author | rgommers <ralf.gommers@googlemail.com> | 2010-07-31 04:41:08 +0000 |
---|---|---|
committer | rgommers <ralf.gommers@googlemail.com> | 2010-07-31 04:41:08 +0000 |
commit | d1a661df48625ce5544d83dc022c96cf8c5d41c7 (patch) | |
tree | 40d139668649832673e3f1803971183e937e7633 /numpy | |
parent | bff8bb580d6e915b9bdadea312482a169f7d1472 (diff) | |
download | numpy-d1a661df48625ce5544d83dc022c96cf8c5d41c7.tar.gz |
ENH: Make trapz work with ndarray subclasses. Thanks to Ryan May. Closes #1438.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/lib/function_base.py | 16 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 26 |
2 files changed, 37 insertions, 5 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index b2ec9bb5a..7610c2db3 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -2924,7 +2924,7 @@ def percentile(a, q, axis=None, out=None, overwrite_input=False): ----- Given a vector V of length N, the qth percentile of V is the qth ranked value in a sorted copy of V. A weighted average of the two nearest neighbors - is used if the normalized ranking does not match q exactly. + is used if the normalized ranking does not match q exactly. The same as the median if q is 0.5; the same as the min if q is 0; and the same as the max if q is 1 @@ -2962,7 +2962,7 @@ def percentile(a, q, axis=None, out=None, overwrite_input=False): return a.min(axis=axis, out=out) elif q == 100: return a.max(axis=axis, out=out) - + if overwrite_input: if axis is None: sorted = a.ravel() @@ -3072,11 +3072,11 @@ def trapz(y, x=None, dx=1.0, axis=-1): array([ 2., 8.]) """ - y = asarray(y) + y = asanyarray(y) if x is None: d = dx else: - x = asarray(x) + x = asanyarray(x) if x.ndim == 1: d = diff(x) # reshape to correct shape @@ -3090,7 +3090,13 @@ def trapz(y, x=None, dx=1.0, axis=-1): slice2 = [slice(None)]*nd slice1[axis] = slice(1,None) slice2[axis] = slice(None,-1) - return add.reduce(d * (y[slice1]+y[slice2])/2.0,axis) + try: + ret = (d * (y[slice1] +y [slice2]) / 2.0).sum(axis) + except ValueError: # Operations didn't work, cast to ndarray + d = np.asarray(d) + y = np.asarray(y) + ret = add.reduce(d * (y[slice1]+y[slice2])/2.0, axis) + return ret #always succeed def add_newdoc(place, obj, doc): diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 037e8043a..1d0d61be3 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -491,6 +491,32 @@ class TestTrapz(TestCase): r = trapz(q, x=z, axis=2) assert_almost_equal(r, qz) + def test_masked(self): + #Testing that masked arrays behave as if the function is 0 where + #masked + x = arange(5) + y = x * x + mask = x == 2 + ym = np.ma.array(y, mask=mask) + r = 13.0 # sum(0.5 * (0 + 1) * 1.0 + 0.5 * (9 + 16)) + assert_almost_equal(trapz(ym, x), r) + + xm = np.ma.array(x, mask=mask) + assert_almost_equal(trapz(ym, xm), r) + + xm = np.ma.array(x, mask=mask) + assert_almost_equal(trapz(y, xm), r) + + def test_matrix(self): + #Test to make sure matrices give the same answer as ndarrays + x = linspace(0, 5) + y = x * x + r = trapz(y, x) + mx = matrix(x) + my = matrix(y) + mr = trapz(my, mx) + assert_almost_equal(mr, r) + class TestSinc(TestCase): def test_simple(self): |