summaryrefslogtreecommitdiff
path: root/numpy/random/tests
diff options
context:
space:
mode:
authorChris Jordan-Squire <cjordan1@uw.edu>2011-09-01 15:14:05 -0500
committerCharles Harris <charlesr.harris@gmail.com>2011-12-16 20:35:44 -0700
commit8b6c850cf3e6ec664705e9823014e59edc1a26e7 (patch)
treed9e0a95456fffbf978da84edd83ea11b13f69a22 /numpy/random/tests
parent3d0b348450addcc33bd30e9c0b3ea5b10106ab4d (diff)
downloadnumpy-8b6c850cf3e6ec664705e9823014e59edc1a26e7.tar.gz
ENH: New sample function, bugs in tests fixed
Diffstat (limited to 'numpy/random/tests')
-rw-r--r--numpy/random/tests/test_random.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py
index 2782cd91d..25277b6ba 100644
--- a/numpy/random/tests/test_random.py
+++ b/numpy/random/tests/test_random.py
@@ -116,6 +116,51 @@ class TestRandomDist(TestCase):
[ 0.4575674820298663 , 0.7781880808593471 ]])
np.testing.assert_array_almost_equal(actual, desired, decimal=15)
+ def test_sample_uniform_replace(self):
+ np.random.seed(self.seed)
+ actual = np.random.sample(4, 4)
+ desired = np.array([2, 3, 2, 3])
+ np.testing.assert_array_equal(actual, desired)
+
+ def test_sample_nonuniform_replace(self):
+ np.random.seed(self.seed)
+ actual = np.random.sample(4, 4, p=[0.4, 0.4, 0.1, 0.1])
+ desired = np.array([1, 0, 3, 0])
+ np.testing.assert_array_equal(actual, desired)
+
+ def test_sample_uniform_noreplace(self):
+ np.random.seed(self.seed)
+ actual = np.random.sample(4, 3, replace=False)
+ desired = np.array([0, 1, 3])
+ np.testing.assert_array_equal(actual, desired)
+
+ def test_sample_nonuniform_noreplace(self):
+ np.random.seed(self.seed)
+ actual = np.random.sample(4, 3, replace=False,
+ p=[0.1, 0.3, 0.5, 0.1])
+ desired = np.array([2, 1, 3])
+ np.testing.assert_array_equal(actual, desired)
+
+ def test_sample_noninteger(self):
+ np.random.seed(self.seed)
+ actual = np.random.sample(['a', 'b', 'c', 'd'], 4)
+ desired = np.array(['c', 'd', 'c', 'd'])
+ np.testing.assert_array_equal(actual, desired)
+
+ def test_sample_exceptions(self):
+ sample = np.random.sample
+ assert_raises(ValueError, sample, -1,3)
+ assert_raises(ValueError, sample, [[1,2],[3,4]], 3)
+ assert_raises(ValueError, sample, [], 3)
+ assert_raises(ValueError, sample, [1,2,3,4], 3,
+ p=[[0.25,0.25],[0.25,0.25]])
+ assert_raises(ValueError, sample, [1,2], 3, p=[0.4,0.4,0.2])
+ assert_raises(ValueError, sample, [1,2], 3, p=[1.1,-0.1])
+ assert_raises(ValueError, sample, [1,2], 3, p=[0.4,0.4])
+ assert_raises(ValueError, sample, [1,2,3], 4, replace=False)
+ assert_raises(ValueError, sample, [1,2,3], 2, replace=False,
+ p=[1,0,0])
+
def test_bytes(self):
np.random.seed(self.seed)
actual = np.random.bytes(10)