diff options
author | Kevin Sheppard <kevin.k.sheppard@gmail.com> | 2019-04-13 15:35:52 +0100 |
---|---|---|
committer | mattip <matti.picus@gmail.com> | 2019-05-20 18:45:27 +0300 |
commit | b2f9bea1aaa1dfe7c1b007902160ee434762351c (patch) | |
tree | 4c7b4b0ff678059a26dbb4835285c3d99ff170d7 /numpy/random/tests | |
parent | 0f931b344b2723a21921b927d6bf5548ddd26d02 (diff) | |
download | numpy-b2f9bea1aaa1dfe7c1b007902160ee434762351c.tar.gz |
ENH: Improvce choice without replacement
Improve performance in all cases
Large improvement with size is small
xref numpy/numpy#5299
xref numpy/numpy#2764
xref numpy/numpy#9855
xref numpy/numpy#7810
Diffstat (limited to 'numpy/random/tests')
-rw-r--r-- | numpy/random/tests/test_against_numpy.py | 1 | ||||
-rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 32 |
2 files changed, 28 insertions, 5 deletions
diff --git a/numpy/random/tests/test_against_numpy.py b/numpy/random/tests/test_against_numpy.py index 9846ba38f..3c83932e8 100644 --- a/numpy/random/tests/test_against_numpy.py +++ b/numpy/random/tests/test_against_numpy.py @@ -328,6 +328,7 @@ class TestAgainstNumPy(object): g(100, np.array(p), size=(7, 23))) self._is_state_common() + @pytest.mark.xfail(reason='Stream broken for performance') def test_choice(self): self._set_common_state() self._is_state_common() diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index d76291d1a..d591ca339 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -542,25 +542,25 @@ class TestRandomDist(object): def test_choice_uniform_replace(self): random.brng.seed(self.seed) actual = random.choice(4, 4) - desired = np.array([2, 3, 2, 3]) + desired = np.array([2, 3, 2, 3], dtype=np.int64) assert_array_equal(actual, desired) def test_choice_nonuniform_replace(self): random.brng.seed(self.seed) actual = random.choice(4, 4, p=[0.4, 0.4, 0.1, 0.1]) - desired = np.array([1, 1, 2, 2]) + desired = np.array([1, 1, 2, 2], dtype=np.int64) assert_array_equal(actual, desired) def test_choice_uniform_noreplace(self): random.brng.seed(self.seed) actual = random.choice(4, 3, replace=False) - desired = np.array([0, 1, 3]) + desired = np.array([0, 2, 3], dtype=np.int64) assert_array_equal(actual, desired) def test_choice_nonuniform_noreplace(self): random.brng.seed(self.seed) actual = random.choice(4, 3, replace=False, p=[0.1, 0.3, 0.5, 0.1]) - desired = np.array([2, 3, 1]) + desired = np.array([2, 3, 1], dtype=np.int64) assert_array_equal(actual, desired) def test_choice_noninteger(self): @@ -569,11 +569,22 @@ class TestRandomDist(object): desired = np.array(['c', 'd', 'c', 'd']) assert_array_equal(actual, desired) + def test_choice_multidimensional_default_axis(self): + random.brng.seed(self.seed) + actual = random.choice([[0, 1], [2, 3], [4, 5], [6, 7]], 3) + desired = np.array([[4, 5], [6, 7], [4, 5]]) + assert_array_equal(actual, desired) + + def test_choice_multidimensional_custom_axis(self): + random.brng.seed(self.seed) + actual = random.choice([[0, 1], [2, 3], [4, 5], [6, 7]], 1, axis=1) + desired = np.array([[0], [2], [4], [6]]) + assert_array_equal(actual, desired) + def test_choice_exceptions(self): sample = random.choice assert_raises(ValueError, sample, -1, 3) assert_raises(ValueError, sample, 3., 3) - assert_raises(ValueError, sample, [[1, 2], [3, 4]], 3) assert_raises(ValueError, sample, [], 3) assert_raises(ValueError, sample, [1, 2, 3, 4], 3, p=[[0.25, 0.25], [0.25, 0.25]]) @@ -651,6 +662,17 @@ class TestRandomDist(object): actual = random.choice(4, 2, p=p, replace=False) assert actual.dtype == np.int64 + def test_choice_large_sample(self): + import hashlib + + choice_hash = '6395868be877d27518c832213c17977c' + random.brng.seed(self.seed) + actual = random.choice(10000, 5000, replace=False) + if sys.byteorder != 'little': + actual = actual.byteswap() + res = hashlib.md5(actual.view(np.int8)).hexdigest() + assert_(choice_hash == res) + def test_bytes(self): random.brng.seed(self.seed) actual = random.bytes(10) |