summaryrefslogtreecommitdiff
path: root/numpy/random/tests
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2019-09-16 09:35:38 +0300
committerGitHub <noreply@github.com>2019-09-16 09:35:38 +0300
commitf8694923d0a4b241369726bfda8bbea29099a848 (patch)
tree08442be17c4303d109abf7a6e1edb8e7195a2298 /numpy/random/tests
parent33d5f9c23da18ed3b5ed5b8efc5c26297ed2d7ea (diff)
parentc99a16677bf72073af52ac0e028bd0f1487cc338 (diff)
downloadnumpy-f8694923d0a4b241369726bfda8bbea29099a848.tar.gz
Merge pull request #13829 from kianasun/add-permutation-axis
ENH: Add axis argument to random.permutation and random.shuffle
Diffstat (limited to 'numpy/random/tests')
-rw-r--r--numpy/random/tests/test_generator_mt19937.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index 853d86fba..20bc10cd0 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -732,6 +732,20 @@ class TestRandomDist(object):
desired = conv([4, 1, 9, 8, 0, 5, 3, 6, 2, 7])
assert_array_equal(actual, desired)
+ def test_shuffle_custom_axis(self):
+ random = Generator(MT19937(self.seed))
+ actual = np.arange(16).reshape((4, 4))
+ random.shuffle(actual, axis=1)
+ desired = np.array([[ 0, 3, 1, 2],
+ [ 4, 7, 5, 6],
+ [ 8, 11, 9, 10],
+ [12, 15, 13, 14]])
+ assert_array_equal(actual, desired)
+ random = Generator(MT19937(self.seed))
+ actual = np.arange(16).reshape((4, 4))
+ random.shuffle(actual, axis=-1)
+ assert_array_equal(actual, desired)
+
def test_shuffle_masked(self):
# gh-3263
a = np.ma.masked_values(np.reshape(range(20), (5, 4)) % 3 - 1, -1)
@@ -746,6 +760,16 @@ class TestRandomDist(object):
assert_equal(
sorted(b.data[~b.mask]), sorted(b_orig.data[~b_orig.mask]))
+ def test_shuffle_exceptions(self):
+ random = Generator(MT19937(self.seed))
+ arr = np.arange(10)
+ assert_raises(np.AxisError, random.shuffle, arr, 1)
+ arr = np.arange(9).reshape((3, 3))
+ assert_raises(np.AxisError, random.shuffle, arr, 3)
+ assert_raises(TypeError, random.shuffle, arr, slice(1, 2, None))
+ arr = [[1, 2, 3], [4, 5, 6]]
+ assert_raises(NotImplementedError, random.shuffle, arr, 1)
+
def test_permutation(self):
random = Generator(MT19937(self.seed))
alist = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
@@ -771,6 +795,27 @@ class TestRandomDist(object):
actual = random.permutation(integer_val)
assert_array_equal(actual, desired)
+ def test_permutation_custom_axis(self):
+ a = np.arange(16).reshape((4, 4))
+ desired = np.array([[ 0, 3, 1, 2],
+ [ 4, 7, 5, 6],
+ [ 8, 11, 9, 10],
+ [12, 15, 13, 14]])
+ random = Generator(MT19937(self.seed))
+ actual = random.permutation(a, axis=1)
+ assert_array_equal(actual, desired)
+ random = Generator(MT19937(self.seed))
+ actual = random.permutation(a, axis=-1)
+ assert_array_equal(actual, desired)
+
+ def test_permutation_exceptions(self):
+ random = Generator(MT19937(self.seed))
+ arr = np.arange(10)
+ assert_raises(np.AxisError, random.permutation, arr, 1)
+ arr = np.arange(9).reshape((3, 3))
+ assert_raises(np.AxisError, random.permutation, arr, 3)
+ assert_raises(TypeError, random.permutation, arr, slice(1, 2, None))
+
def test_beta(self):
random = Generator(MT19937(self.seed))
actual = random.beta(.1, .9, size=(3, 2))