diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-08-01 10:42:34 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-08-01 10:42:34 +0000 |
commit | 1821e585b3858ad81e4d5b4bdce5d7f2e720f73b (patch) | |
tree | b9cccb5ea581f0a540aedb5908e16bcc588387eb /numpy/lib/tests/test_function_base.py | |
parent | 2bc9c0bb5935bdd6e7a4dd3fd54b4ee8ce394f98 (diff) | |
parent | 01fd010c6fd29a38b8013ba87c2fb89e8c4c7706 (diff) | |
download | numpy-1821e585b3858ad81e4d5b4bdce5d7f2e720f73b.tar.gz |
Merge pull request #9482 from madphysicist/iterative-diff
MAINT: Make diff iterative instead of recursive
Diffstat (limited to 'numpy/lib/tests/test_function_base.py')
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 53 |
1 files changed, 53 insertions, 0 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 099a2d407..4ecb02821 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -6,6 +6,7 @@ import sys import decimal import numpy as np +from numpy import ma from numpy.testing import ( run_module_suite, assert_, assert_equal, assert_array_equal, assert_almost_equal, assert_array_almost_equal, assert_raises, @@ -647,6 +648,19 @@ class TestDiff(object): assert_array_equal(diff(x), out) assert_array_equal(diff(x, n=2), out2) + def test_axis(self): + x = np.zeros((10, 20, 30)) + x[:, 1::2, :] = 1 + exp = np.ones((10, 19, 30)) + exp[:, 1::2, :] = -1 + assert_array_equal(diff(x), np.zeros((10, 20, 29))) + assert_array_equal(diff(x, axis=-1), np.zeros((10, 20, 29))) + assert_array_equal(diff(x, axis=0), np.zeros((9, 20, 30))) + assert_array_equal(diff(x, axis=1), exp) + assert_array_equal(diff(x, axis=-2), exp) + assert_raises(np.AxisError, diff, x, axis=3) + assert_raises(np.AxisError, diff, x, axis=-4) + def test_nd(self): x = 20 * rand(10, 20, 30) out1 = x[:, :, 1:] - x[:, :, :-1] @@ -658,6 +672,45 @@ class TestDiff(object): assert_array_equal(diff(x, axis=0), out3) assert_array_equal(diff(x, n=2, axis=0), out4) + def test_n(self): + x = list(range(3)) + assert_raises(ValueError, diff, x, n=-1) + output = [diff(x, n=n) for n in range(1, 5)] + expected = [[1, 1], [0], [], []] + assert_(diff(x, n=0) is x) + for n, (expected, out) in enumerate(zip(expected, output), start=1): + assert_(type(out) is np.ndarray) + assert_array_equal(out, expected) + assert_equal(out.dtype, np.int_) + assert_equal(len(out), max(0, len(x) - n)) + + def test_times(self): + x = np.arange('1066-10-13', '1066-10-16', dtype=np.datetime64) + expected = [ + np.array([1, 1], dtype='timedelta64[D]'), + np.array([0], dtype='timedelta64[D]'), + ] + expected.extend([np.array([], dtype='timedelta64[D]')] * 3) + for n, exp in enumerate(expected, start=1): + out = diff(x, n=n) + assert_array_equal(out, exp) + assert_equal(out.dtype, exp.dtype) + + def test_subclass(self): + x = ma.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]], + mask=[[False, False], [True, False], + [False, True], [True, True], [False, False]]) + out = diff(x) + assert_array_equal(out.data, [[1], [1], [1], [1], [1]]) + assert_array_equal(out.mask, [[False], [True], + [True], [True], [False]]) + assert_(type(out) is type(x)) + + out3 = diff(x, n=3) + assert_array_equal(out3.data, [[], [], [], [], []]) + assert_array_equal(out3.mask, [[], [], [], [], []]) + assert_(type(out3) is type(x)) + class TestDelete(object): |