diff options
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 425960639..269a97721 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -2818,9 +2818,9 @@ def trapz(y, x=None, dx=1.0, axis=-1): y : array_like Input array to integrate. x : array_like, optional - If `x` is None, then spacing between all `y` elements is 1. + If `x` is None, then spacing between all `y` elements is `dx`. dx : scalar, optional - If `x` is None, spacing given by `dx` is assumed. + If `x` is None, spacing given by `dx` is assumed. Default is 1. axis : int, optional Specify the axis. @@ -2836,7 +2836,15 @@ def trapz(y, x=None, dx=1.0, axis=-1): if x is None: d = dx else: - d = diff(x,axis=axis) + x = asarray(x) + if x.ndim == 1: + d = diff(x) + # reshape to correct shape + shape = [1]*y.ndim + shape[axis] = d.shape[0] + d = d.reshape(shape) + else: + d = diff(x, axis=axis) nd = len(y.shape) slice1 = [slice(None)]*nd slice2 = [slice(None)]*nd |