summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRobert Cimrman <cimrman3@ntc.zcu.cz>2009-07-20 11:59:40 +0000
committerRobert Cimrman <cimrman3@ntc.zcu.cz>2009-07-20 11:59:40 +0000
commit1319b2ca9c0a7491e5143f2608f444c2b1b7e346 (patch)
treeac7569257d268dd5f62c018247e32ab5c67434fc
parent520872500893f95f1676cdf6715c8d34b32a2449 (diff)
downloadnumpy-1319b2ca9c0a7491e5143f2608f444c2b1b7e346.tar.gz
Fix to setdiff1d (and masked version) + tests (#1133, by N.C.)
-rw-r--r--numpy/lib/arraysetops.py5
-rw-r--r--numpy/lib/tests/test_arraysetops.py2
-rw-r--r--numpy/ma/extras.py5
-rw-r--r--numpy/ma/tests/test_extras.py2
4 files changed, 10 insertions, 4 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py
index fe1326aa0..b8ae9a9f3 100644
--- a/numpy/lib/arraysetops.py
+++ b/numpy/lib/arraysetops.py
@@ -346,7 +346,10 @@ def setdiff1d(ar1, ar2, assume_unique=False):
performing set operations on arrays.
"""
- aux = in1d(ar1, ar2, assume_unique=assume_unique)
+ if not assume_unique:
+ ar1 = unique(ar1)
+ ar2 = unique(ar2)
+ aux = in1d(ar1, ar2, assume_unique=True)
if aux.size == 0:
return aux
else:
diff --git a/numpy/lib/tests/test_arraysetops.py b/numpy/lib/tests/test_arraysetops.py
index a83ab1394..92305129a 100644
--- a/numpy/lib/tests/test_arraysetops.py
+++ b/numpy/lib/tests/test_arraysetops.py
@@ -210,7 +210,7 @@ class TestAso(TestCase):
assert_array_equal([], union1d([],[]))
def test_setdiff1d( self ):
- a = np.array( [6, 5, 4, 7, 1, 2] )
+ a = np.array( [6, 5, 4, 7, 1, 2, 7, 4] )
b = np.array( [2, 4, 3, 3, 2, 1, 5] )
ec = np.array( [6, 7] )
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index 85cc834c1..d4b78f986 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -1002,7 +1002,10 @@ def setdiff1d(ar1, ar2, assume_unique=False):
numpy.setdiff1d : equivalent function for ndarrays
"""
- aux = in1d(ar1, ar2, assume_unique=assume_unique)
+ if not assume_unique:
+ ar1 = unique(ar1)
+ ar2 = unique(ar2)
+ aux = in1d(ar1, ar2, assume_unique=True)
if aux.size == 0:
return aux
else:
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index 687fc9b81..e40c56ed5 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -771,7 +771,7 @@ class TestArraySetOps(TestCase):
def test_setdiff1d( self ):
"Test setdiff1d"
- a = array([6, 5, 4, 7, 1, 2, 1], mask=[0, 0, 0, 0, 0, 0, 1])
+ a = array([6, 5, 4, 7, 7, 1, 2, 1], mask=[0, 0, 0, 0, 0, 0, 0, 1])
b = array([2, 4, 3, 3, 2, 1, 5])
test = setdiff1d(a, b)
assert_equal(test, array([6, 7, -1], mask=[0, 0, 1]))