summaryrefslogtreecommitdiff
path: root/numpy/random/tests
diff options
context:
space:
mode:
authorKexuan Sun <me@kianasun.com>2019-09-14 16:55:54 -0700
committerKexuan Sun <me@kianasun.com>2019-09-14 16:55:54 -0700
commitc99a16677bf72073af52ac0e028bd0f1487cc338 (patch)
treeec0b7078de9af6044c86ea4b57b9214deca1bb6a /numpy/random/tests
parent31ffdecf07d18ed4dbb66b171cb0f998d4b190fa (diff)
downloadnumpy-c99a16677bf72073af52ac0e028bd0f1487cc338.tar.gz
ENH: Improve permutation and shuffle functions on a given axis
Diffstat (limited to 'numpy/random/tests')
-rw-r--r--numpy/random/tests/test_generator_mt19937.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index 853d86fba..20bc10cd0 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -732,6 +732,20 @@ class TestRandomDist(object):
desired = conv([4, 1, 9, 8, 0, 5, 3, 6, 2, 7])
assert_array_equal(actual, desired)
+ def test_shuffle_custom_axis(self):
+ random = Generator(MT19937(self.seed))
+ actual = np.arange(16).reshape((4, 4))
+ random.shuffle(actual, axis=1)
+ desired = np.array([[ 0, 3, 1, 2],
+ [ 4, 7, 5, 6],
+ [ 8, 11, 9, 10],
+ [12, 15, 13, 14]])
+ assert_array_equal(actual, desired)
+ random = Generator(MT19937(self.seed))
+ actual = np.arange(16).reshape((4, 4))
+ random.shuffle(actual, axis=-1)
+ assert_array_equal(actual, desired)
+
def test_shuffle_masked(self):
# gh-3263
a = np.ma.masked_values(np.reshape(range(20), (5, 4)) % 3 - 1, -1)
@@ -746,6 +760,16 @@ class TestRandomDist(object):
assert_equal(
sorted(b.data[~b.mask]), sorted(b_orig.data[~b_orig.mask]))
+ def test_shuffle_exceptions(self):
+ random = Generator(MT19937(self.seed))
+ arr = np.arange(10)
+ assert_raises(np.AxisError, random.shuffle, arr, 1)
+ arr = np.arange(9).reshape((3, 3))
+ assert_raises(np.AxisError, random.shuffle, arr, 3)
+ assert_raises(TypeError, random.shuffle, arr, slice(1, 2, None))
+ arr = [[1, 2, 3], [4, 5, 6]]
+ assert_raises(NotImplementedError, random.shuffle, arr, 1)
+
def test_permutation(self):
random = Generator(MT19937(self.seed))
alist = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
@@ -771,6 +795,27 @@ class TestRandomDist(object):
actual = random.permutation(integer_val)
assert_array_equal(actual, desired)
+ def test_permutation_custom_axis(self):
+ a = np.arange(16).reshape((4, 4))
+ desired = np.array([[ 0, 3, 1, 2],
+ [ 4, 7, 5, 6],
+ [ 8, 11, 9, 10],
+ [12, 15, 13, 14]])
+ random = Generator(MT19937(self.seed))
+ actual = random.permutation(a, axis=1)
+ assert_array_equal(actual, desired)
+ random = Generator(MT19937(self.seed))
+ actual = random.permutation(a, axis=-1)
+ assert_array_equal(actual, desired)
+
+ def test_permutation_exceptions(self):
+ random = Generator(MT19937(self.seed))
+ arr = np.arange(10)
+ assert_raises(np.AxisError, random.permutation, arr, 1)
+ arr = np.arange(9).reshape((3, 3))
+ assert_raises(np.AxisError, random.permutation, arr, 3)
+ assert_raises(TypeError, random.permutation, arr, slice(1, 2, None))
+
def test_beta(self):
random = Generator(MT19937(self.seed))
actual = random.beta(.1, .9, size=(3, 2))