summaryrefslogtreecommitdiff
path: root/numpy/random/tests
diff options
context:
space:
mode:
authorKevin Sheppard <kevin.k.sheppard@gmail.com>2019-05-18 23:47:51 +0100
committermattip <matti.picus@gmail.com>2019-05-20 19:00:38 +0300
commitb42a5ca0a076b40c612014dc540ca5f9bcf10f41 (patch)
tree339ed4a67923d7f27357b2aa6d997b14990efe26 /numpy/random/tests
parent17e0070df93f4262908f884dca4b08cb7d0bba7f (diff)
downloadnumpy-b42a5ca0a076b40c612014dc540ca5f9bcf10f41.tar.gz
BUG: Ensure integer-type stream on 32bit
Ensure integer type is stream compatible on 32 bit Fix incorrect clause end Add integer-generator tests that check long streams
Diffstat (limited to 'numpy/random/tests')
-rw-r--r--numpy/random/tests/test_generator_mt19937.py10
-rw-r--r--numpy/random/tests/test_randomstate.py61
2 files changed, 65 insertions, 6 deletions
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index 3bb3bd791..770c32a39 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -182,10 +182,10 @@ class TestIntegers(object):
lbnd = 0 if dt is bool else np.iinfo(dt).min
ubnd = 2 if dt is bool else np.iinfo(dt).max + (not endpoint)
- assert_raises(ValueError, self.rfunc, [
- lbnd - 1] * 2, [ubnd] * 2, endpoint=endpoint, dtype=dt)
- assert_raises(ValueError, self.rfunc, [
- lbnd] * 2, [ubnd + 1] * 2, endpoint=endpoint, dtype=dt)
+ assert_raises(ValueError, self.rfunc, [lbnd - 1] * 2, [ubnd] * 2,
+ endpoint=endpoint, dtype=dt)
+ assert_raises(ValueError, self.rfunc, [lbnd] * 2,
+ [ubnd + 1] * 2, endpoint=endpoint, dtype=dt)
assert_raises(ValueError, self.rfunc, ubnd, [lbnd] * 2,
endpoint=endpoint, dtype=dt)
assert_raises(ValueError, self.rfunc, [1] * 2, 0,
@@ -1895,7 +1895,7 @@ class TestBroadcast(object):
[4, 5, 1, 4, 3, 3]],
[[1, 1, 1, 0, 0, 2],
[2, 0, 4, 3, 7, 4]],
- [[1, 2, 0, 0, 2, 2],
+ [[1, 2, 0, 0, 2, 0],
[3, 2, 3, 4, 2, 6]]], dtype=np.int64)
assert_array_equal(actual, desired)
diff --git a/numpy/random/tests/test_randomstate.py b/numpy/random/tests/test_randomstate.py
index 0c57b9aaf..75c35ef62 100644
--- a/numpy/random/tests/test_randomstate.py
+++ b/numpy/random/tests/test_randomstate.py
@@ -1,8 +1,10 @@
+import hashlib
import pickle
import sys
import warnings
import numpy as np
+import pytest
from numpy.testing import (
assert_, assert_raises, assert_equal, assert_warns,
assert_no_warnings, assert_array_equal, assert_array_almost_equal,
@@ -11,6 +13,44 @@ from numpy.testing import (
from numpy.random import MT19937, Xoshiro256, mtrand as random
+INT_FUNCS = {'binomial': (100.0, 0.6),
+ 'geometric': (.5,),
+ 'hypergeometric': (20, 20, 10),
+ 'logseries': (.5,),
+ 'multinomial': (20, np.ones(6) / 6.0),
+ 'negative_binomial': (100, .5),
+ 'poisson': (10.0,),
+ 'zipf': (2,),
+ }
+
+if np.iinfo(int).max < 2**32:
+ # Windows and some 32-bit platforms, e.g., ARM
+ INT_FUNC_HASHES = {'binomial': '670e1c04223ffdbab27e08fbbad7bdba',
+ 'logseries': '6bd0183d2f8030c61b0d6e11aaa60caf',
+ 'geometric': '6e9df886f3e1e15a643168568d5280c0',
+ 'hypergeometric': '7964aa611b046aecd33063b90f4dec06',
+ 'multinomial': '68a0b049c16411ed0aa4aff3572431e4',
+ 'negative_binomial': 'dc265219eec62b4338d39f849cd36d09',
+ 'poisson': '7b4dce8e43552fc82701c2fa8e94dc6e',
+ 'zipf': 'fcd2a2095f34578723ac45e43aca48c5',
+ }
+else:
+ INT_FUNC_HASHES = {'binomial': 'b5f8dcd74f172836536deb3547257b14',
+ 'geometric': '8814571f45c87c59699d62ccd3d6c350',
+ 'hypergeometric': 'bc64ae5976eac452115a16dad2dcf642',
+ 'logseries': '84be924b37485a27c4a98797bc88a7a4',
+ 'multinomial': 'ec3c7f9cf9664044bb0c6fb106934200',
+ 'negative_binomial': '210533b2234943591364d0117a552969',
+ 'poisson': '0536a8850c79da0c78defd742dccc3e0',
+ 'zipf': 'f2841f504dd2525cd67cdcad7561e532',
+ }
+
+
+@pytest.fixture(scope='module', params=INT_FUNCS)
+def int_func(request):
+ return (request.param, INT_FUNCS[request.param],
+ INT_FUNC_HASHES[request.param])
+
def assert_mt19937_state_equal(a, b):
assert_equal(a['bit_generator'], b['bit_generator'])
@@ -269,7 +309,6 @@ class TestRandint(object):
assert_(vals.min() >= 0)
def test_repeatability(self):
- import hashlib
# We use a md5 hash of generated sequences of 1000 samples
# in the range [0, 6) for all but bool, where the range
# is [0, 2). Hashes are for little endian numbers.
@@ -1862,3 +1901,23 @@ class TestSingleEltArrayInput(object):
out = func(self.argOne, self.argTwo[0], self.argThree)
assert_equal(out.shape, self.tgtShape)
+
+
+# Ensure returned array dtype is corect for platform
+def test_integer_dtype(int_func):
+ random.seed(123456789)
+ fname, args, md5 = int_func
+ f = getattr(random, fname)
+ actual = f(*args, size=2)
+ assert_(actual.dtype == np.dtype('l'))
+
+
+def test_integer_repeat(int_func):
+ random.seed(123456789)
+ fname, args, md5 = int_func
+ f = getattr(random, fname)
+ val = f(*args, size=1000000)
+ if sys.byteorder != 'little':
+ val = val.byteswap()
+ res = hashlib.md5(val.view(np.int8)).hexdigest()
+ assert_(res == md5)