summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2014-02-12 11:57:27 +0100
committerSebastian Berg <sebastian@sipsolutions.net>2014-02-16 00:00:22 +0100
commitab04e1ae0e8eca717bc7e42f3b0a60c9ff764289 (patch)
tree49ea02f820c4ee3eb484578abd0078f543ef4898
parent58e9e27c0c110f9be1558a53fb547dc1abc76fa4 (diff)
downloadnumpy-ab04e1ae0e8eca717bc7e42f3b0a60c9ff764289.tar.gz
BUG: Force allclose logic to use inexact type
Casting y to an inexact type fixes problems such as abs(MIN_INT) < 0, and generally makes sense since the allclose logic is inherently for float types.
-rw-r--r--numpy/core/numeric.py12
-rw-r--r--numpy/core/tests/test_numeric.py7
-rw-r--r--numpy/ma/core.py25
-rw-r--r--numpy/ma/tests/test_core.py4
-rw-r--r--numpy/testing/tests/test_utils.py8
-rw-r--r--numpy/testing/utils.py14
6 files changed, 47 insertions, 23 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 6b078ae31..3b8e52e71 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -2139,6 +2139,11 @@ def allclose(a, b, rtol=1.e-5, atol=1.e-8):
x = array(a, copy=False, ndmin=1)
y = array(b, copy=False, ndmin=1)
+ # make sure y is an inexact type to avoid abs(MIN_INT); will cause
+ # casting of x later.
+ dtype = multiarray.result_type(y, 1.)
+ y = array(y, dtype=dtype, copy=False)
+
xinf = isinf(x)
yinf = isinf(y)
if any(xinf) or any(yinf):
@@ -2154,12 +2159,7 @@ def allclose(a, b, rtol=1.e-5, atol=1.e-8):
# ignore invalid fpe's
with errstate(invalid='ignore'):
- if not x.dtype.kind == 'b' and not y.dtype.kind == 'b':
- diff = abs(x - y)
- else:
- diff = x ^ y
-
- r = all(less_equal(diff, atol + rtol * abs(y)))
+ r = all(less_equal(abs(x - y), atol + rtol * abs(y)))
return r
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index ac341468c..12a39a522 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -1420,6 +1420,13 @@ class TestAllclose(object):
assert_array_equal(y, array([0, inf]))
+ def test_min_int(self):
+ # Could make problems because of abs(min_int) == min_int
+ min_int = np.iinfo(np.int_).min
+ a = np.array([min_int], dtype=np.int_)
+ assert_(allclose(a, a))
+
+
class TestIsclose(object):
rtol = 1e-5
atol = 1e-8
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index c62e55c45..16df1ea76 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -6916,6 +6916,13 @@ def allclose (a, b, masked_equal=True, rtol=1e-5, atol=1e-8):
"""
x = masked_array(a, copy=False)
y = masked_array(b, copy=False)
+
+ # make sure y is an inexact type to avoid abs(MIN_INT); will cause
+ # casting of x later.
+ dtype = np.result_type(y, 1.)
+ if y.dtype != dtype:
+ y = masked_array(y, dtype=dtype, copy=False)
+
m = mask_or(getmask(x), getmask(y))
xinf = np.isinf(masked_array(x, copy=False, mask=m)).filled(False)
# If we have some infs, they should fall at the same place.
@@ -6923,26 +6930,20 @@ def allclose (a, b, masked_equal=True, rtol=1e-5, atol=1e-8):
return False
# No infs at all
if not np.any(xinf):
- if not x.dtype.kind == 'b' and not y.dtype.kind == 'b':
- diff = umath.absolute(x - y)
- else:
- diff = x ^ y
-
- d = filled(umath.less_equal(diff, atol + rtol * umath.absolute(y)),
+ d = filled(umath.less_equal(umath.absolute(x - y),
+ atol + rtol * umath.absolute(y)),
masked_equal)
return np.all(d)
+
if not np.all(filled(x[xinf] == y[xinf], masked_equal)):
return False
x = x[~xinf]
y = y[~xinf]
- if not x.dtype.kind == 'b' and not y.dtype.kind == 'b':
- diff = umath.absolute(x - y)
- else:
- diff = x ^ y
-
- d = filled(umath.less_equal(diff, atol + rtol * umath.absolute(y)),
+ d = filled(umath.less_equal(umath.absolute(x - y),
+ atol + rtol * umath.absolute(y)),
masked_equal)
+
return np.all(d)
#..............................................................................
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 8d8e1c947..19f13a8c4 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -1995,6 +1995,10 @@ class TestMaskedArrayMethods(TestCase):
a[0] = 0
self.assertTrue(allclose(a, 0, masked_equal=True))
+ # Test that the function works for MIN_INT integer typed arrays
+ a = masked_array([np.iinfo(np.int_).min], dtype=np.int_)
+ self.assertTrue(allclose(a, a))
+
def test_allany(self):
# Checks the any/all methods/functions.
x = np.array([[0.13, 0.26, 0.90],
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 94fc4d655..5956a4294 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -53,6 +53,9 @@ class _GenericTest(object):
a = np.array([1, 1], dtype=np.object)
self._test_equal(a, 1)
+ def test_array_likes(self):
+ self._test_equal([1, 2, 3], (1, 2, 3))
+
class TestArrayEqual(_GenericTest, unittest.TestCase):
def setUp(self):
self._assert_func = assert_array_equal
@@ -373,6 +376,11 @@ class TestAssertAllclose(unittest.TestCase):
assert_allclose(6, 10, rtol=0.5)
self.assertRaises(AssertionError, assert_allclose, 10, 6, rtol=0.5)
+ def test_min_int(self):
+ a = np.array([np.iinfo(np.int_).min], dtype=np.int_)
+ # Should not raise:
+ assert_allclose(a, a)
+
class TestArrayAlmostEqualNulp(unittest.TestCase):
@dec.knownfailureif(True, "Github issue #347")
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 82aa1e39c..97908c7e8 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -793,7 +793,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
y: array([ 1. , 2.33333, 5. ])
"""
- from numpy.core import around, number, float_
+ from numpy.core import around, number, float_, result_type, array
from numpy.core.numerictypes import issubdtype
from numpy.core.fromnumeric import any as npany
def compare(x, y):
@@ -811,17 +811,21 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
except (TypeError, NotImplementedError):
pass
- if x.dtype.kind == 'b' and y.dtype.kind == 'b':
- z = x ^ y
- else:
- z = abs(x-y)
+ # make sure y is an inexact type to avoid abs(MIN_INT); will cause
+ # casting of x later.
+ dtype = result_type(y, 1.)
+ y = array(y, dtype=dtype, copy=False)
+ z = abs(x-y)
if not issubdtype(z.dtype, number):
z = z.astype(float_) # handle object arrays
+
return around(z, decimal) <= 10.0**(-decimal)
+
assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
header=('Arrays are not almost equal to %d decimals' % decimal))
+
def assert_array_less(x, y, err_msg='', verbose=True):
"""
Raise an assertion if two array_like objects are not ordered by less than.