summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2009-01-17 21:24:13 +0000
committerPauli Virtanen <pav@iki.fi>2009-01-17 21:24:13 +0000
commit9efaa09baca5ce2d64447bdcc91556227ab717c2 (patch)
treeea5d43b99c8cae7ea120e1234bb6e99b2e96bade /numpy/lib
parent4b9c0f2208305046275c74ec1c9d6fda8af1f5bb (diff)
downloadnumpy-9efaa09baca5ce2d64447bdcc91556227ab717c2.tar.gz
Make `trapz` accept 1-D `x` parameter for n-d `y`, even if axis != -1.
Additional tests included.
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/function_base.py14
-rw-r--r--numpy/lib/tests/test_function_base.py38
2 files changed, 49 insertions, 3 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 425960639..269a97721 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -2818,9 +2818,9 @@ def trapz(y, x=None, dx=1.0, axis=-1):
y : array_like
Input array to integrate.
x : array_like, optional
- If `x` is None, then spacing between all `y` elements is 1.
+ If `x` is None, then spacing between all `y` elements is `dx`.
dx : scalar, optional
- If `x` is None, spacing given by `dx` is assumed.
+ If `x` is None, spacing given by `dx` is assumed. Default is 1.
axis : int, optional
Specify the axis.
@@ -2836,7 +2836,15 @@ def trapz(y, x=None, dx=1.0, axis=-1):
if x is None:
d = dx
else:
- d = diff(x,axis=axis)
+ x = asarray(x)
+ if x.ndim == 1:
+ d = diff(x)
+ # reshape to correct shape
+ shape = [1]*y.ndim
+ shape[axis] = d.shape[0]
+ d = d.reshape(shape)
+ else:
+ d = diff(x, axis=axis)
nd = len(y.shape)
slice1 = [slice(None)]*nd
slice2 = [slice(None)]*nd
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index ca8104b53..143e28ae5 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -430,6 +430,44 @@ class TestTrapz(TestCase):
#check integral of normal equals 1
assert_almost_equal(sum(r,axis=0),1,7)
+ def test_ndim(self):
+ x = linspace(0, 1, 3)
+ y = linspace(0, 2, 8)
+ z = linspace(0, 3, 13)
+
+ wx = ones_like(x) * (x[1]-x[0])
+ wx[0] /= 2
+ wx[-1] /= 2
+ wy = ones_like(y) * (y[1]-y[0])
+ wy[0] /= 2
+ wy[-1] /= 2
+ wz = ones_like(z) * (z[1]-z[0])
+ wz[0] /= 2
+ wz[-1] /= 2
+
+ q = x[:,None,None] + y[None,:,None] + z[None,None,:]
+
+ qx = (q*wx[:,None,None]).sum(axis=0)
+ qy = (q*wy[None,:,None]).sum(axis=1)
+ qz = (q*wz[None,None,:]).sum(axis=2)
+
+ # n-d `x`
+ r = trapz(q, x=x[:,None,None], axis=0)
+ assert_almost_equal(r, qx)
+ r = trapz(q, x=y[None,:,None], axis=1)
+ assert_almost_equal(r, qy)
+ r = trapz(q, x=z[None,None,:], axis=2)
+ assert_almost_equal(r, qz)
+
+ # 1-d `x`
+ r = trapz(q, x=x, axis=0)
+ assert_almost_equal(r, qx)
+ r = trapz(q, x=y, axis=1)
+ assert_almost_equal(r, qy)
+ r = trapz(q, x=z, axis=2)
+ assert_almost_equal(r, qz)
+
+
class TestSinc(TestCase):
def test_simple(self):
assert(sinc(0)==1)