diff options
-rw-r--r-- | numpy/random/_generator.pyx | 4 | ||||
-rw-r--r-- | numpy/random/mtrand.pyx | 4 | ||||
-rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 6 | ||||
-rw-r--r-- | numpy/random/tests/test_random.py | 6 |
4 files changed, 16 insertions, 4 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx index fdc6c3be0..ed5137a46 100644 --- a/numpy/random/_generator.pyx +++ b/numpy/random/_generator.pyx @@ -3641,12 +3641,12 @@ cdef class Generator: cdef int64_t ni cdef np.broadcast it - if np.ndim(pvals) != 1: - raise ValueError("pvals must be 1d array") d = len(pvals) on = <np.ndarray>np.PyArray_FROM_OTF(n, np.NPY_INT64, np.NPY_ALIGNED) parr = <np.ndarray>np.PyArray_FROM_OTF( pvals, np.NPY_DOUBLE, np.NPY_ALIGNED | np.NPY_ARRAY_C_CONTIGUOUS) + if np.PyArray_NDIM(parr) != 1: + raise ValueError("pvals must be 1d array") pix = <double*>np.PyArray_DATA(parr) check_array_constraint(parr, 'pvals', CONS_BOUNDED_0_1) if kahan_sum(pix, d-1) > (1.0 + 1e-12): diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx index 9eabfd848..6d36dda36 100644 --- a/numpy/random/mtrand.pyx +++ b/numpy/random/mtrand.pyx @@ -4198,11 +4198,11 @@ cdef class RandomState: cdef long *mnix cdef long ni - if np.ndim(pvals) != 1: - raise ValueError("pvals must be 1d array") d = len(pvals) parr = <np.ndarray>np.PyArray_FROM_OTF( pvals, np.NPY_DOUBLE, np.NPY_ALIGNED | np.NPY_ARRAY_C_CONTIGUOUS) + if np.PyArray_NDIM(parr) != 1: + raise ValueError("pvals must be 1d array") pix = <double*>np.PyArray_DATA(parr) check_array_constraint(parr, 'pvals', CONS_BOUNDED_0_1) if kahan_sum(pix, d-1) > (1.0 + 1e-12): diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index 6f4407373..2a30d39e2 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -115,6 +115,12 @@ class TestMultinomial: contig = random.multinomial(100, pvals=np.ascontiguousarray(pvals)) assert_array_equal(non_contig, contig) + def test_multidimensional_pvals(self): + assert_raises(ValueError, random.multinomial, 10, [[0, 1], [1, 0]]) + assert_raises(ValueError, random.multinomial, 10, [[0], [1]]) + assert_raises(ValueError, random.multinomial, 10, [[[0], [1]], [[1], [0]]]) + assert_raises(ValueError, random.multinomial, 10, np.array([[0, 1], [1, 0]])) + class TestMultivariateHypergeometric: diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py index a9aa15083..935afbf71 100644 --- a/numpy/random/tests/test_random.py +++ b/numpy/random/tests/test_random.py @@ -91,6 +91,12 @@ class TestMultinomial: assert_raises(TypeError, np.random.multinomial, 1, p, float(1)) + def test_multidimensional_pvals(self): + assert_raises(ValueError, np.random.multinomial, 10, [[0, 1], [1, 0]]) + assert_raises(ValueError, np.random.multinomial, 10, [[0], [1]]) + assert_raises(ValueError, np.random.multinomial, 10, [[[0], [1]], [[1], [0]]]) + assert_raises(ValueError, np.random.multinomial, 10, np.array([[0, 1], [1, 0]])) + class TestSetState: def setup(self): |