diff options
author | Kexuan Sun <me@kianasun.com> | 2019-09-14 16:55:54 -0700 |
---|---|---|
committer | Kexuan Sun <me@kianasun.com> | 2019-09-14 16:55:54 -0700 |
commit | c99a16677bf72073af52ac0e028bd0f1487cc338 (patch) | |
tree | ec0b7078de9af6044c86ea4b57b9214deca1bb6a /numpy/random/tests | |
parent | 31ffdecf07d18ed4dbb66b171cb0f998d4b190fa (diff) | |
download | numpy-c99a16677bf72073af52ac0e028bd0f1487cc338.tar.gz |
ENH: Improve permutation and shuffle functions on a given axis
Diffstat (limited to 'numpy/random/tests')
-rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index 853d86fba..20bc10cd0 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -732,6 +732,20 @@ class TestRandomDist(object): desired = conv([4, 1, 9, 8, 0, 5, 3, 6, 2, 7]) assert_array_equal(actual, desired) + def test_shuffle_custom_axis(self): + random = Generator(MT19937(self.seed)) + actual = np.arange(16).reshape((4, 4)) + random.shuffle(actual, axis=1) + desired = np.array([[ 0, 3, 1, 2], + [ 4, 7, 5, 6], + [ 8, 11, 9, 10], + [12, 15, 13, 14]]) + assert_array_equal(actual, desired) + random = Generator(MT19937(self.seed)) + actual = np.arange(16).reshape((4, 4)) + random.shuffle(actual, axis=-1) + assert_array_equal(actual, desired) + def test_shuffle_masked(self): # gh-3263 a = np.ma.masked_values(np.reshape(range(20), (5, 4)) % 3 - 1, -1) @@ -746,6 +760,16 @@ class TestRandomDist(object): assert_equal( sorted(b.data[~b.mask]), sorted(b_orig.data[~b_orig.mask])) + def test_shuffle_exceptions(self): + random = Generator(MT19937(self.seed)) + arr = np.arange(10) + assert_raises(np.AxisError, random.shuffle, arr, 1) + arr = np.arange(9).reshape((3, 3)) + assert_raises(np.AxisError, random.shuffle, arr, 3) + assert_raises(TypeError, random.shuffle, arr, slice(1, 2, None)) + arr = [[1, 2, 3], [4, 5, 6]] + assert_raises(NotImplementedError, random.shuffle, arr, 1) + def test_permutation(self): random = Generator(MT19937(self.seed)) alist = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0] @@ -771,6 +795,27 @@ class TestRandomDist(object): actual = random.permutation(integer_val) assert_array_equal(actual, desired) + def test_permutation_custom_axis(self): + a = np.arange(16).reshape((4, 4)) + desired = np.array([[ 0, 3, 1, 2], + [ 4, 7, 5, 6], + [ 8, 11, 9, 10], + [12, 15, 13, 14]]) + random = Generator(MT19937(self.seed)) + actual = random.permutation(a, axis=1) + assert_array_equal(actual, desired) + random = Generator(MT19937(self.seed)) + actual = random.permutation(a, axis=-1) + assert_array_equal(actual, desired) + + def test_permutation_exceptions(self): + random = Generator(MT19937(self.seed)) + arr = np.arange(10) + assert_raises(np.AxisError, random.permutation, arr, 1) + arr = np.arange(9).reshape((3, 3)) + assert_raises(np.AxisError, random.permutation, arr, 3) + assert_raises(TypeError, random.permutation, arr, slice(1, 2, None)) + def test_beta(self): random = Generator(MT19937(self.seed)) actual = random.beta(.1, .9, size=(3, 2)) |