summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_function_base.py
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-08-01 10:42:34 +0000
committerGitHub <noreply@github.com>2017-08-01 10:42:34 +0000
commit1821e585b3858ad81e4d5b4bdce5d7f2e720f73b (patch)
treeb9cccb5ea581f0a540aedb5908e16bcc588387eb /numpy/lib/tests/test_function_base.py
parent2bc9c0bb5935bdd6e7a4dd3fd54b4ee8ce394f98 (diff)
parent01fd010c6fd29a38b8013ba87c2fb89e8c4c7706 (diff)
downloadnumpy-1821e585b3858ad81e4d5b4bdce5d7f2e720f73b.tar.gz
Merge pull request #9482 from madphysicist/iterative-diff
MAINT: Make diff iterative instead of recursive
Diffstat (limited to 'numpy/lib/tests/test_function_base.py')
-rw-r--r--numpy/lib/tests/test_function_base.py53
1 files changed, 53 insertions, 0 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 099a2d407..4ecb02821 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -6,6 +6,7 @@ import sys
import decimal
import numpy as np
+from numpy import ma
from numpy.testing import (
run_module_suite, assert_, assert_equal, assert_array_equal,
assert_almost_equal, assert_array_almost_equal, assert_raises,
@@ -647,6 +648,19 @@ class TestDiff(object):
assert_array_equal(diff(x), out)
assert_array_equal(diff(x, n=2), out2)
+ def test_axis(self):
+ x = np.zeros((10, 20, 30))
+ x[:, 1::2, :] = 1
+ exp = np.ones((10, 19, 30))
+ exp[:, 1::2, :] = -1
+ assert_array_equal(diff(x), np.zeros((10, 20, 29)))
+ assert_array_equal(diff(x, axis=-1), np.zeros((10, 20, 29)))
+ assert_array_equal(diff(x, axis=0), np.zeros((9, 20, 30)))
+ assert_array_equal(diff(x, axis=1), exp)
+ assert_array_equal(diff(x, axis=-2), exp)
+ assert_raises(np.AxisError, diff, x, axis=3)
+ assert_raises(np.AxisError, diff, x, axis=-4)
+
def test_nd(self):
x = 20 * rand(10, 20, 30)
out1 = x[:, :, 1:] - x[:, :, :-1]
@@ -658,6 +672,45 @@ class TestDiff(object):
assert_array_equal(diff(x, axis=0), out3)
assert_array_equal(diff(x, n=2, axis=0), out4)
+ def test_n(self):
+ x = list(range(3))
+ assert_raises(ValueError, diff, x, n=-1)
+ output = [diff(x, n=n) for n in range(1, 5)]
+ expected = [[1, 1], [0], [], []]
+ assert_(diff(x, n=0) is x)
+ for n, (expected, out) in enumerate(zip(expected, output), start=1):
+ assert_(type(out) is np.ndarray)
+ assert_array_equal(out, expected)
+ assert_equal(out.dtype, np.int_)
+ assert_equal(len(out), max(0, len(x) - n))
+
+ def test_times(self):
+ x = np.arange('1066-10-13', '1066-10-16', dtype=np.datetime64)
+ expected = [
+ np.array([1, 1], dtype='timedelta64[D]'),
+ np.array([0], dtype='timedelta64[D]'),
+ ]
+ expected.extend([np.array([], dtype='timedelta64[D]')] * 3)
+ for n, exp in enumerate(expected, start=1):
+ out = diff(x, n=n)
+ assert_array_equal(out, exp)
+ assert_equal(out.dtype, exp.dtype)
+
+ def test_subclass(self):
+ x = ma.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]],
+ mask=[[False, False], [True, False],
+ [False, True], [True, True], [False, False]])
+ out = diff(x)
+ assert_array_equal(out.data, [[1], [1], [1], [1], [1]])
+ assert_array_equal(out.mask, [[False], [True],
+ [True], [True], [False]])
+ assert_(type(out) is type(x))
+
+ out3 = diff(x, n=3)
+ assert_array_equal(out3.data, [[], [], [], [], []])
+ assert_array_equal(out3.mask, [[], [], [], [], []])
+ assert_(type(out3) is type(x))
+
class TestDelete(object):