summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_function_base.py
diff options
context:
space:
mode:
authornjsmith <njs@pobox.com>2013-04-11 12:31:22 -0700
committernjsmith <njs@pobox.com>2013-04-11 12:31:22 -0700
commitb9232f34bfb53bf7c574bfa350dd981f58d6a2d4 (patch)
tree6208488b45ea2341df2151dbb09d685857cf703c /numpy/lib/tests/test_function_base.py
parentb7053e8d065f820a6c2b3db7e8df3feaf5adbd71 (diff)
parent1675ad9e5b95605a851337f407e1fad33cf10c9c (diff)
downloadnumpy-b9232f34bfb53bf7c574bfa350dd981f58d6a2d4.tar.gz
Merge pull request #452 from seberg/enhdel
ENH: delete and insert generalization and speed improvements
Diffstat (limited to 'numpy/lib/tests/test_function_base.py')
-rw-r--r--numpy/lib/tests/test_function_base.py114
1 files changed, 111 insertions, 3 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 28786dc3e..ae68be41f 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -5,7 +5,7 @@ import numpy as np
from numpy.testing import (
run_module_suite, TestCase, assert_, assert_equal,
assert_array_equal, assert_almost_equal, assert_array_almost_equal,
- assert_raises, assert_allclose, assert_array_max_ulp
+ assert_raises, assert_allclose, assert_array_max_ulp, assert_warns
)
from numpy.random import rand
from numpy.lib import *
@@ -174,17 +174,64 @@ class TestInsert(TestCase):
assert_equal(insert(a, 3, 1), [1, 2, 3, 1])
assert_equal(insert(a, [1, 1, 1], [1, 2, 3]), [1, 1, 2, 3, 2, 3])
assert_equal(insert(a, 1,[1,2,3]), [1, 1, 2, 3, 2, 3])
- assert_equal(insert(a,[1,2,3],9),[1,9,2,9,3,9])
+ assert_equal(insert(a,[1,-1,3],9),[1,9,2,9,3,9])
+ assert_equal(insert(a,slice(-1,None,-1), 9),[9,1,9,2,9,3])
+ assert_equal(insert(a,[-1,1,3], [7,8,9]),[1,8,2,7,3,9])
b = np.array([0, 1], dtype=np.float64)
assert_equal(insert(b, 0, b[0]), [0., 0., 1.])
+ assert_equal(insert(b, [], []), b)
+ # Bools will be treated differently in the future:
+ #assert_equal(insert(a, np.array([True]*4), 9), [9,1,9,2,9,3,9])
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', FutureWarning)
+ assert_equal(insert(a, np.array([True]*4), 9), [1,9,9,9,9,2,3])
+ assert_(w[0].category is FutureWarning)
+
def test_multidim(self):
a = [[1, 1, 1]]
r = [[2, 2, 2],
[1, 1, 1]]
+ assert_equal(insert(a, 0, [1]), [1,1,1,1])
assert_equal(insert(a, 0, [2, 2, 2], axis=0), r)
assert_equal(insert(a, 0, 2, axis=0), r)
assert_equal(insert(a, 2, 2, axis=1), [[1, 1, 2, 1]])
+ a = np.array([[1, 1], [2, 2], [3, 3]])
+ b = np.arange(1,4).repeat(3).reshape(3,3)
+ c = np.concatenate((a[:,0:1], np.arange(1,4).repeat(3).reshape(3,3).T,
+ a[:,1:2]), axis=1)
+ assert_equal(insert(a, [1], [[1],[2],[3]], axis=1), b)
+ assert_equal(insert(a, [1], [1, 2, 3], axis=1), c)
+ # scalars behave differently, in this case exactly opposite:
+ assert_equal(insert(a, 1, [1, 2, 3], axis=1), b)
+ assert_equal(insert(a, 1, [[1],[2],[3]], axis=1), c)
+
+ a = np.arange(4).reshape(2,2)
+ assert_equal(insert(a[:,:1], 1, a[:,1], axis=1), a)
+ assert_equal(insert(a[:1,:], 1, a[1,:], axis=0), a)
+
+ def test_0d(self):
+ # This is an error in the future
+ a = np.array(1)
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', DeprecationWarning)
+ assert_equal(insert(a, [], 2, axis=0), np.array(2))
+ assert_(w[0].category is DeprecationWarning)
+
+ def test_subclass(self):
+ class SubClass(np.ndarray):
+ pass
+ a = np.arange(10).view(SubClass)
+ assert_(isinstance(np.insert(a, 0, [0]), SubClass))
+ assert_(isinstance(np.insert(a, [], []), SubClass))
+ assert_(isinstance(np.insert(a, [0,1], [1,2]), SubClass))
+ assert_(isinstance(np.insert(a, slice(1,2), [1,2]), SubClass))
+ assert_(isinstance(np.insert(a, slice(1,-2), []), SubClass))
+ # This is an error in the future:
+ a = np.array(1).view(SubClass)
+ assert_(isinstance(np.insert(a, 0, [0]), SubClass))
+
+
class TestAmax(TestCase):
def test_basic(self):
a = [3, 4, 5, 10, -3, -5, 6.0]
@@ -302,6 +349,67 @@ class TestDiff(TestCase):
assert_array_equal(diff(x, n=2, axis=0), out4)
+class TestDelete(TestCase):
+ def setUp(self):
+ self.a = np.arange(5)
+ self.nd_a = np.arange(5).repeat(2).reshape(1,5,2)
+
+ def _check_inverse_of_slicing(self, indices):
+ a_del = delete(self.a, indices)
+ nd_a_del = delete(self.nd_a, indices, axis=1)
+ msg = 'Delete failed for obj: %r' % indices
+ # NOTE: The cast should be removed after warning phase for bools
+ if not isinstance(indices, (slice, int, long, np.integer)):
+ indices = np.asarray(indices, dtype=np.intp)
+ indices = indices[(indices >= 0) & (indices < 5)]
+ assert_array_equal(setxor1d(a_del, self.a[indices,]), self.a,
+ err_msg=msg)
+ xor = setxor1d(nd_a_del[0,:,0], self.nd_a[0,indices,0])
+ assert_array_equal(xor, self.nd_a[0,:,0], err_msg=msg)
+
+ def test_slices(self):
+ lims = [-6, -2, 0, 1, 2, 4, 5]
+ steps = [-3, -1, 1, 3]
+ for start in lims:
+ for stop in lims:
+ for step in steps:
+ s = slice(start, stop, step)
+ self._check_inverse_of_slicing(s)
+
+ def test_fancy(self):
+ # Deprecation/FutureWarning tests should be kept after change.
+ self._check_inverse_of_slicing(np.array([[0,1],[2,1]]))
+ assert_raises(DeprecationWarning, delete, self.a, [100])
+ assert_raises(DeprecationWarning, delete, self.a, [-100])
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', FutureWarning)
+ self._check_inverse_of_slicing([0, -1, 2, 2])
+ obj = np.array([True, False, False], dtype=bool)
+ self._check_inverse_of_slicing(obj)
+ assert_(w[0].category is FutureWarning)
+ assert_(w[1].category is FutureWarning)
+
+ def test_single(self):
+ self._check_inverse_of_slicing(0)
+ self._check_inverse_of_slicing(-4)
+
+ def test_0d(self):
+ a = np.array(1)
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', DeprecationWarning)
+ assert_equal(delete(a, [], axis=0), a)
+ assert_(w[0].category is DeprecationWarning)
+
+ def test_subclass(self):
+ class SubClass(np.ndarray):
+ pass
+ a = self.a.view(SubClass)
+ assert_(isinstance(delete(a, 0), SubClass))
+ assert_(isinstance(delete(a, []), SubClass))
+ assert_(isinstance(delete(a, [0,1]), SubClass))
+ assert_(isinstance(delete(a, slice(1,2)), SubClass))
+ assert_(isinstance(delete(a, slice(1,-2)), SubClass))
+
class TestGradient(TestCase):
def test_basic(self):
v = [[1, 1], [3, 4]]
@@ -531,7 +639,7 @@ class TestVectorize(TestCase):
res2a = f2(np.arange(3))
assert_equal(res1a, res2a)
assert_equal(res1b, res2b)
-
+
def test_string_ticket_1892(self):
"""Test vectorization over strings: issue 1892."""
f = np.vectorize(lambda x:x)