summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPan Jan <rumcajsgajos@gmail.com>2020-03-31 11:10:15 +0200
committerPan Jan <rumcajsgajos@gmail.com>2020-03-31 11:13:55 +0200
commit385a969d02e5e80383826d0a0b9f3c00a5e8d1f9 (patch)
treee8b3cd930d8c8e004746920b28e1a0a91516a6ec
parentf0a74b2d4e50a1b4f9d9189c6b2e31b920913a8b (diff)
downloadnumpy-385a969d02e5e80383826d0a0b9f3c00a5e8d1f9.tar.gz
BUG: add check if pvals is 1d array in numpy.random.*.multinomial
-rw-r--r--numpy/random/_generator.pyx2
-rw-r--r--numpy/random/mtrand.pyx2
2 files changed, 4 insertions, 0 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx
index 7766f8b8c..fdc6c3be0 100644
--- a/numpy/random/_generator.pyx
+++ b/numpy/random/_generator.pyx
@@ -3641,6 +3641,8 @@ cdef class Generator:
cdef int64_t ni
cdef np.broadcast it
+ if np.ndim(pvals) != 1:
+ raise ValueError("pvals must be 1d array")
d = len(pvals)
on = <np.ndarray>np.PyArray_FROM_OTF(n, np.NPY_INT64, np.NPY_ALIGNED)
parr = <np.ndarray>np.PyArray_FROM_OTF(
diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx
index 6387814c8..9eabfd848 100644
--- a/numpy/random/mtrand.pyx
+++ b/numpy/random/mtrand.pyx
@@ -4198,6 +4198,8 @@ cdef class RandomState:
cdef long *mnix
cdef long ni
+ if np.ndim(pvals) != 1:
+ raise ValueError("pvals must be 1d array")
d = len(pvals)
parr = <np.ndarray>np.PyArray_FROM_OTF(
pvals, np.NPY_DOUBLE, np.NPY_ALIGNED | np.NPY_ARRAY_C_CONTIGUOUS)