summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/random/_generator.pyx4
-rw-r--r--numpy/random/mtrand.pyx4
-rw-r--r--numpy/random/tests/test_generator_mt19937.py6
-rw-r--r--numpy/random/tests/test_random.py6
4 files changed, 16 insertions, 4 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx
index fdc6c3be0..ed5137a46 100644
--- a/numpy/random/_generator.pyx
+++ b/numpy/random/_generator.pyx
@@ -3641,12 +3641,12 @@ cdef class Generator:
cdef int64_t ni
cdef np.broadcast it
- if np.ndim(pvals) != 1:
- raise ValueError("pvals must be 1d array")
d = len(pvals)
on = <np.ndarray>np.PyArray_FROM_OTF(n, np.NPY_INT64, np.NPY_ALIGNED)
parr = <np.ndarray>np.PyArray_FROM_OTF(
pvals, np.NPY_DOUBLE, np.NPY_ALIGNED | np.NPY_ARRAY_C_CONTIGUOUS)
+ if np.PyArray_NDIM(parr) != 1:
+ raise ValueError("pvals must be 1d array")
pix = <double*>np.PyArray_DATA(parr)
check_array_constraint(parr, 'pvals', CONS_BOUNDED_0_1)
if kahan_sum(pix, d-1) > (1.0 + 1e-12):
diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx
index 9eabfd848..6d36dda36 100644
--- a/numpy/random/mtrand.pyx
+++ b/numpy/random/mtrand.pyx
@@ -4198,11 +4198,11 @@ cdef class RandomState:
cdef long *mnix
cdef long ni
- if np.ndim(pvals) != 1:
- raise ValueError("pvals must be 1d array")
d = len(pvals)
parr = <np.ndarray>np.PyArray_FROM_OTF(
pvals, np.NPY_DOUBLE, np.NPY_ALIGNED | np.NPY_ARRAY_C_CONTIGUOUS)
+ if np.PyArray_NDIM(parr) != 1:
+ raise ValueError("pvals must be 1d array")
pix = <double*>np.PyArray_DATA(parr)
check_array_constraint(parr, 'pvals', CONS_BOUNDED_0_1)
if kahan_sum(pix, d-1) > (1.0 + 1e-12):
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index 6f4407373..2a30d39e2 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -115,6 +115,12 @@ class TestMultinomial:
contig = random.multinomial(100, pvals=np.ascontiguousarray(pvals))
assert_array_equal(non_contig, contig)
+ def test_multidimensional_pvals(self):
+ assert_raises(ValueError, random.multinomial, 10, [[0, 1], [1, 0]])
+ assert_raises(ValueError, random.multinomial, 10, [[0], [1]])
+ assert_raises(ValueError, random.multinomial, 10, [[[0], [1]], [[1], [0]]])
+ assert_raises(ValueError, random.multinomial, 10, np.array([[0, 1], [1, 0]]))
+
class TestMultivariateHypergeometric:
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py
index a9aa15083..935afbf71 100644
--- a/numpy/random/tests/test_random.py
+++ b/numpy/random/tests/test_random.py
@@ -91,6 +91,12 @@ class TestMultinomial:
assert_raises(TypeError, np.random.multinomial, 1, p,
float(1))
+ def test_multidimensional_pvals(self):
+ assert_raises(ValueError, np.random.multinomial, 10, [[0, 1], [1, 0]])
+ assert_raises(ValueError, np.random.multinomial, 10, [[0], [1]])
+ assert_raises(ValueError, np.random.multinomial, 10, [[[0], [1]], [[1], [0]]])
+ assert_raises(ValueError, np.random.multinomial, 10, np.array([[0, 1], [1, 0]]))
+
class TestSetState:
def setup(self):