summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRobert Kern <robert.kern@gmail.com>2008-07-07 16:47:18 +0000
committerRobert Kern <robert.kern@gmail.com>2008-07-07 16:47:18 +0000
commita4847b1de5fc8398c2589e04aee1b63ba0592ff4 (patch)
treedc7ecc8c6dcb51887903ff64392a44e456f7fedc
parent01320e0149a08299461ab730bf0bdce6c9ba3483 (diff)
downloadnumpy-a4847b1de5fc8398c2589e04aee1b63ba0592ff4.tar.gz
Return actual bools instead of 0 or 1.
-rw-r--r--numpy/core/numeric.py28
-rw-r--r--numpy/core/tests/test_numeric.py46
2 files changed, 60 insertions, 14 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 2945a892b..8fe1f9220 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -137,7 +137,7 @@ def asarray(a, dtype=None, order=None):
def asanyarray(a, dtype=None, order=None):
"""Returns a as an array, but will pass subclasses through.
"""
- return array(a, dtype, copy=False, order=order, subok=1)
+ return array(a, dtype, copy=False, order=order, subok=True)
def ascontiguousarray(a, dtype=None):
"""Return 'a' as an array contiguous in memory (C order).
@@ -182,9 +182,9 @@ def require(a, dtype=None, requirements=None):
return asanyarray(a, dtype=dtype)
if 'ENSUREARRAY' in requirements or 'E' in requirements:
- subok = 0
+ subok = False
else:
- subok = 1
+ subok = True
arr = array(a, dtype=dtype, copy=False, subok=subok)
@@ -342,12 +342,12 @@ def tensordot(a, b, axes=2):
nda = len(a.shape)
bs = b.shape
ndb = len(b.shape)
- equal = 1
- if (na != nb): equal = 0
+ equal = True
+ if (na != nb): equal = False
else:
for k in xrange(na):
if as_[axes_a[k]] != bs[axes_b[k]]:
- equal = 0
+ equal = False
break
if axes_a[k] < 0:
axes_a[k] += nda
@@ -391,10 +391,10 @@ def roll(a, shift, axis=None):
a = asanyarray(a)
if axis is None:
n = a.size
- reshape=1
+ reshape = True
else:
n = a.shape[axis]
- reshape=0
+ reshape = False
shift %= n
indexes = concatenate((arange(n-shift,n),arange(n-shift)))
res = a.take(indexes, axis)
@@ -728,10 +728,10 @@ def array_equal(a1, a2):
try:
a1, a2 = asarray(a1), asarray(a2)
except:
- return 0
+ return False
if a1.shape != a2.shape:
- return 0
- return logical_and.reduce(equal(a1,a2).ravel())
+ return False
+ return bool(logical_and.reduce(equal(a1,a2).ravel()))
def array_equiv(a1, a2):
"""Returns True if a1 and a2 are shape consistent
@@ -741,11 +741,11 @@ def array_equiv(a1, a2):
try:
a1, a2 = asarray(a1), asarray(a2)
except:
- return 0
+ return False
try:
- return logical_and.reduce(equal(a1,a2).ravel())
+ return bool(logical_and.reduce(equal(a1,a2).ravel()))
except ValueError:
- return 0
+ return False
_errdict = {"ignore":ERR_IGNORE,
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 23e6cf30a..24334dacd 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -256,6 +256,52 @@ class TestBinaryRepr(TestCase):
assert_equal(binary_repr(-1, width=8), '11111111')
+class TestArrayComparisons(TestCase):
+ def test_array_equal(self):
+ res = array_equal(array([1,2]), array([1,2]))
+ assert res
+ assert type(res) is bool
+ res = array_equal(array([1,2]), array([1,2,3]))
+ assert not res
+ assert type(res) is bool
+ res = array_equal(array([1,2]), array([3,4]))
+ assert not res
+ assert type(res) is bool
+ res = array_equal(array([1,2]), array([1,3]))
+ assert not res
+ assert type(res) is bool
+
+ def test_array_equiv(self):
+ res = array_equiv(array([1,2]), array([1,2]))
+ assert res
+ assert type(res) is bool
+ res = array_equiv(array([1,2]), array([1,2,3]))
+ assert not res
+ assert type(res) is bool
+ res = array_equiv(array([1,2]), array([3,4]))
+ assert not res
+ assert type(res) is bool
+ res = array_equiv(array([1,2]), array([1,3]))
+ assert not res
+ assert type(res) is bool
+
+ res = array_equiv(array([1,1]), array([1]))
+ assert res
+ assert type(res) is bool
+ res = array_equiv(array([1,1]), array([[1],[1]]))
+ assert res
+ assert type(res) is bool
+ res = array_equiv(array([1,2]), array([2]))
+ assert not res
+ assert type(res) is bool
+ res = array_equiv(array([1,2]), array([[1],[2]]))
+ assert not res
+ assert type(res) is bool
+ res = array_equiv(array([1,2]), array([[1,2,3],[4,5,6],[7,8,9]]))
+ assert not res
+ assert type(res) is bool
+
+
def assert_array_strict_equal(x, y):
assert_array_equal(x, y)
# Check flags