summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2016-07-10 11:37:31 -0600
committerCharles Harris <charlesr.harris@gmail.com>2016-07-10 14:06:35 -0600
commit2d24f8a41da440b53ffe953de14c2e1117df07f0 (patch)
treef0cd5abc634ad0d708bd10274586220b3f79b13b
parentd3e3d91d401f5e47a81e748cd7fb12620b45947a (diff)
downloadnumpy-2d24f8a41da440b53ffe953de14c2e1117df07f0.tar.gz
BUG: Make sure npy_mul_with_overflow_<type> detects overflow.
The previous check for overflow would fail with signed integers as it was designed to check for overflow of the larger corresponding unsigned type.
-rw-r--r--numpy/core/src/private/templ_common.h.src9
-rw-r--r--numpy/core/tests/test_multiarray.py15
2 files changed, 21 insertions, 3 deletions
diff --git a/numpy/core/src/private/templ_common.h.src b/numpy/core/src/private/templ_common.h.src
index dd6c7bf23..a65a00758 100644
--- a/numpy/core/src/private/templ_common.h.src
+++ b/numpy/core/src/private/templ_common.h.src
@@ -17,6 +17,10 @@
/*
* writes result of a * b into r
* returns 1 if a * b overflowed else returns 0
+ *
+ * These functions are not designed to work if either a or b is negative, but
+ * that is not checked. Could use absolute values and adjust the sign if that
+ * functionality was desired.
*/
static NPY_INLINE int
npy_mul_with_overflow_@name@(@type@ * r, @type@ a, @type@ b)
@@ -24,17 +28,16 @@ npy_mul_with_overflow_@name@(@type@ * r, @type@ a, @type@ b)
#ifdef HAVE___BUILTIN_MUL_OVERFLOW
return __builtin_mul_overflow(a, b, r);
#else
- const @type@ half_sz = (((@type@)1 << (sizeof(a) * 8 / 2)) - 1);
+ const @type@ half_sz = ((@type@)1 << ((sizeof(a) * 8 - 1 ) / 2));
*r = a * b;
/*
* avoid expensive division on common no overflow case
*/
if (NPY_UNLIKELY((a | b) >= half_sz) &&
- a != 0 && b > @MAX@ / a) {
+ a != 0 && b > @MAX@ / a) {
return 1;
}
-
return 0;
#endif
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index ca13a172f..542dfa843 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -564,6 +564,21 @@ class TestCreation(TestCase):
arr = np.array([], dtype='V')
assert_equal(arr.dtype.kind, 'V')
+ def test_too_big_error(self):
+ # 45341 is the smallest integer greater than sqrt(2**31 - 1).
+ # 3037000500 is the smallest integer greater than sqrt(2**63 - 1).
+ # We want to make sure that the square byte array with those dimensions
+ # is too big on 32 or 64 bit systems respectively.
+ if np.iinfo('intp').max == 2**31 - 1:
+ shape = (46341, 46341)
+ elif np.iinfo('intp').max == 2**63 - 1:
+ shape = (3037000500, 3037000500)
+ else:
+ return
+ assert_raises(ValueError, np.empty, shape, dtype=np.int8)
+ assert_raises(ValueError, np.zeros, shape, dtype=np.int8)
+ assert_raises(ValueError, np.ones, shape, dtype=np.int8)
+
def test_zeros(self):
types = np.typecodes['AllInteger'] + np.typecodes['AllFloat']
for dt in types: