diff options
author | Robert Kern <robert.kern@gmail.com> | 2008-07-07 16:47:18 +0000 |
---|---|---|
committer | Robert Kern <robert.kern@gmail.com> | 2008-07-07 16:47:18 +0000 |
commit | a4847b1de5fc8398c2589e04aee1b63ba0592ff4 (patch) | |
tree | dc7ecc8c6dcb51887903ff64392a44e456f7fedc | |
parent | 01320e0149a08299461ab730bf0bdce6c9ba3483 (diff) | |
download | numpy-a4847b1de5fc8398c2589e04aee1b63ba0592ff4.tar.gz |
Return actual bools instead of 0 or 1.
-rw-r--r-- | numpy/core/numeric.py | 28 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 46 |
2 files changed, 60 insertions, 14 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 2945a892b..8fe1f9220 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -137,7 +137,7 @@ def asarray(a, dtype=None, order=None): def asanyarray(a, dtype=None, order=None): """Returns a as an array, but will pass subclasses through. """ - return array(a, dtype, copy=False, order=order, subok=1) + return array(a, dtype, copy=False, order=order, subok=True) def ascontiguousarray(a, dtype=None): """Return 'a' as an array contiguous in memory (C order). @@ -182,9 +182,9 @@ def require(a, dtype=None, requirements=None): return asanyarray(a, dtype=dtype) if 'ENSUREARRAY' in requirements or 'E' in requirements: - subok = 0 + subok = False else: - subok = 1 + subok = True arr = array(a, dtype=dtype, copy=False, subok=subok) @@ -342,12 +342,12 @@ def tensordot(a, b, axes=2): nda = len(a.shape) bs = b.shape ndb = len(b.shape) - equal = 1 - if (na != nb): equal = 0 + equal = True + if (na != nb): equal = False else: for k in xrange(na): if as_[axes_a[k]] != bs[axes_b[k]]: - equal = 0 + equal = False break if axes_a[k] < 0: axes_a[k] += nda @@ -391,10 +391,10 @@ def roll(a, shift, axis=None): a = asanyarray(a) if axis is None: n = a.size - reshape=1 + reshape = True else: n = a.shape[axis] - reshape=0 + reshape = False shift %= n indexes = concatenate((arange(n-shift,n),arange(n-shift))) res = a.take(indexes, axis) @@ -728,10 +728,10 @@ def array_equal(a1, a2): try: a1, a2 = asarray(a1), asarray(a2) except: - return 0 + return False if a1.shape != a2.shape: - return 0 - return logical_and.reduce(equal(a1,a2).ravel()) + return False + return bool(logical_and.reduce(equal(a1,a2).ravel())) def array_equiv(a1, a2): """Returns True if a1 and a2 are shape consistent @@ -741,11 +741,11 @@ def array_equiv(a1, a2): try: a1, a2 = asarray(a1), asarray(a2) except: - return 0 + return False try: - return logical_and.reduce(equal(a1,a2).ravel()) + return bool(logical_and.reduce(equal(a1,a2).ravel())) except ValueError: - return 0 + return False _errdict = {"ignore":ERR_IGNORE, diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index 23e6cf30a..24334dacd 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -256,6 +256,52 @@ class TestBinaryRepr(TestCase): assert_equal(binary_repr(-1, width=8), '11111111') +class TestArrayComparisons(TestCase): + def test_array_equal(self): + res = array_equal(array([1,2]), array([1,2])) + assert res + assert type(res) is bool + res = array_equal(array([1,2]), array([1,2,3])) + assert not res + assert type(res) is bool + res = array_equal(array([1,2]), array([3,4])) + assert not res + assert type(res) is bool + res = array_equal(array([1,2]), array([1,3])) + assert not res + assert type(res) is bool + + def test_array_equiv(self): + res = array_equiv(array([1,2]), array([1,2])) + assert res + assert type(res) is bool + res = array_equiv(array([1,2]), array([1,2,3])) + assert not res + assert type(res) is bool + res = array_equiv(array([1,2]), array([3,4])) + assert not res + assert type(res) is bool + res = array_equiv(array([1,2]), array([1,3])) + assert not res + assert type(res) is bool + + res = array_equiv(array([1,1]), array([1])) + assert res + assert type(res) is bool + res = array_equiv(array([1,1]), array([[1],[1]])) + assert res + assert type(res) is bool + res = array_equiv(array([1,2]), array([2])) + assert not res + assert type(res) is bool + res = array_equiv(array([1,2]), array([[1],[2]])) + assert not res + assert type(res) is bool + res = array_equiv(array([1,2]), array([[1,2,3],[4,5,6],[7,8,9]])) + assert not res + assert type(res) is bool + + def assert_array_strict_equal(x, y): assert_array_equal(x, y) # Check flags |