summaryrefslogtreecommitdiff
path: root/numpy/random/tests
diff options
context:
space:
mode:
authorKevin Sheppard <kevin.k.sheppard@gmail.com>2019-04-13 15:35:52 +0100
committermattip <matti.picus@gmail.com>2019-05-20 18:45:27 +0300
commitb2f9bea1aaa1dfe7c1b007902160ee434762351c (patch)
tree4c7b4b0ff678059a26dbb4835285c3d99ff170d7 /numpy/random/tests
parent0f931b344b2723a21921b927d6bf5548ddd26d02 (diff)
downloadnumpy-b2f9bea1aaa1dfe7c1b007902160ee434762351c.tar.gz
ENH: Improvce choice without replacement
Improve performance in all cases Large improvement with size is small xref numpy/numpy#5299 xref numpy/numpy#2764 xref numpy/numpy#9855 xref numpy/numpy#7810
Diffstat (limited to 'numpy/random/tests')
-rw-r--r--numpy/random/tests/test_against_numpy.py1
-rw-r--r--numpy/random/tests/test_generator_mt19937.py32
2 files changed, 28 insertions, 5 deletions
diff --git a/numpy/random/tests/test_against_numpy.py b/numpy/random/tests/test_against_numpy.py
index 9846ba38f..3c83932e8 100644
--- a/numpy/random/tests/test_against_numpy.py
+++ b/numpy/random/tests/test_against_numpy.py
@@ -328,6 +328,7 @@ class TestAgainstNumPy(object):
g(100, np.array(p), size=(7, 23)))
self._is_state_common()
+ @pytest.mark.xfail(reason='Stream broken for performance')
def test_choice(self):
self._set_common_state()
self._is_state_common()
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index d76291d1a..d591ca339 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -542,25 +542,25 @@ class TestRandomDist(object):
def test_choice_uniform_replace(self):
random.brng.seed(self.seed)
actual = random.choice(4, 4)
- desired = np.array([2, 3, 2, 3])
+ desired = np.array([2, 3, 2, 3], dtype=np.int64)
assert_array_equal(actual, desired)
def test_choice_nonuniform_replace(self):
random.brng.seed(self.seed)
actual = random.choice(4, 4, p=[0.4, 0.4, 0.1, 0.1])
- desired = np.array([1, 1, 2, 2])
+ desired = np.array([1, 1, 2, 2], dtype=np.int64)
assert_array_equal(actual, desired)
def test_choice_uniform_noreplace(self):
random.brng.seed(self.seed)
actual = random.choice(4, 3, replace=False)
- desired = np.array([0, 1, 3])
+ desired = np.array([0, 2, 3], dtype=np.int64)
assert_array_equal(actual, desired)
def test_choice_nonuniform_noreplace(self):
random.brng.seed(self.seed)
actual = random.choice(4, 3, replace=False, p=[0.1, 0.3, 0.5, 0.1])
- desired = np.array([2, 3, 1])
+ desired = np.array([2, 3, 1], dtype=np.int64)
assert_array_equal(actual, desired)
def test_choice_noninteger(self):
@@ -569,11 +569,22 @@ class TestRandomDist(object):
desired = np.array(['c', 'd', 'c', 'd'])
assert_array_equal(actual, desired)
+ def test_choice_multidimensional_default_axis(self):
+ random.brng.seed(self.seed)
+ actual = random.choice([[0, 1], [2, 3], [4, 5], [6, 7]], 3)
+ desired = np.array([[4, 5], [6, 7], [4, 5]])
+ assert_array_equal(actual, desired)
+
+ def test_choice_multidimensional_custom_axis(self):
+ random.brng.seed(self.seed)
+ actual = random.choice([[0, 1], [2, 3], [4, 5], [6, 7]], 1, axis=1)
+ desired = np.array([[0], [2], [4], [6]])
+ assert_array_equal(actual, desired)
+
def test_choice_exceptions(self):
sample = random.choice
assert_raises(ValueError, sample, -1, 3)
assert_raises(ValueError, sample, 3., 3)
- assert_raises(ValueError, sample, [[1, 2], [3, 4]], 3)
assert_raises(ValueError, sample, [], 3)
assert_raises(ValueError, sample, [1, 2, 3, 4], 3,
p=[[0.25, 0.25], [0.25, 0.25]])
@@ -651,6 +662,17 @@ class TestRandomDist(object):
actual = random.choice(4, 2, p=p, replace=False)
assert actual.dtype == np.int64
+ def test_choice_large_sample(self):
+ import hashlib
+
+ choice_hash = '6395868be877d27518c832213c17977c'
+ random.brng.seed(self.seed)
+ actual = random.choice(10000, 5000, replace=False)
+ if sys.byteorder != 'little':
+ actual = actual.byteswap()
+ res = hashlib.md5(actual.view(np.int8)).hexdigest()
+ assert_(choice_hash == res)
+
def test_bytes(self):
random.brng.seed(self.seed)
actual = random.bytes(10)