summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2020-07-14 22:41:48 +0300
committerGitHub <noreply@github.com>2020-07-14 22:41:48 +0300
commitf8ba31164868eae4c1b4ce78d7768dee1d58b443 (patch)
tree2eae8341c782e8545ec95e9612b50f4b922dc1ac
parent0f7e8316be0523c0e08f746aaf3ad53aca0c01a8 (diff)
parent89ecc557fc9205e06c9816ce2122f42e3adced02 (diff)
downloadnumpy-f8ba31164868eae4c1b4ce78d7768dee1d58b443.tar.gz
Merge pull request #16868 from bashtage/check-output-size-omial
BUG: Validate output size in bin- and multinomial
-rw-r--r--numpy/random/_common.pxd2
-rw-r--r--numpy/random/_generator.pyx5
-rw-r--r--numpy/random/mtrand.pyx2
-rw-r--r--numpy/random/tests/test_generator_mt19937.py10
-rw-r--r--numpy/random/tests/test_randomstate.py10
5 files changed, 28 insertions, 1 deletions
diff --git a/numpy/random/_common.pxd b/numpy/random/_common.pxd
index 588f613ae..4f404b7a1 100644
--- a/numpy/random/_common.pxd
+++ b/numpy/random/_common.pxd
@@ -77,6 +77,8 @@ cdef object wrap_int(object val, object bits)
cdef np.ndarray int_to_array(object value, object name, object bits, object uint_size)
+cdef validate_output_shape(iter_shape, np.ndarray output)
+
cdef object cont(void *func, void *state, object size, object lock, int narg,
object a, object a_name, constraint_type a_constraint,
object b, object b_name, constraint_type b_constraint,
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx
index cc2852da7..66847043b 100644
--- a/numpy/random/_generator.pyx
+++ b/numpy/random/_generator.pyx
@@ -25,6 +25,7 @@ from ._common cimport (POISSON_LAM_MAX, CONS_POSITIVE, CONS_NONE,
CONS_GT_1, CONS_POSITIVE_NOT_NAN, CONS_POISSON,
double_fill, cont, kahan_sum, cont_broadcast_3, float_fill, cont_f,
check_array_constraint, check_constraint, disc, discrete_broadcast_iii,
+ validate_output_shape
)
np.import_array()
@@ -2809,6 +2810,7 @@ cdef class Generator:
cnt = np.PyArray_SIZE(randoms)
it = np.PyArray_MultiIterNew3(randoms, p_arr, n_arr)
+ validate_output_shape(it.shape, randoms)
with self.lock, nogil:
for i in range(cnt):
_dp = (<double*>np.PyArray_MultiIter_DATA(it, 1))[0]
@@ -3606,7 +3608,7 @@ cdef class Generator:
Now, do one experiment throwing the dice 10 time, and 10 times again,
and another throwing the dice 20 times, and 20 times again:
- >>> rng.multinomial([[10], [20]], [1/6.]*6, size=2)
+ >>> rng.multinomial([[10], [20]], [1/6.]*6, size=(2, 2))
array([[[2, 4, 0, 1, 2, 1],
[1, 3, 0, 3, 1, 2]],
[[1, 4, 4, 4, 4, 3],
@@ -3661,6 +3663,7 @@ cdef class Generator:
temp = np.empty(size, dtype=np.int8)
temp_arr = <np.ndarray>temp
it = np.PyArray_MultiIterNew2(on, temp_arr)
+ validate_output_shape(it.shape, temp_arr)
shape = it.shape + (d,)
multin = np.zeros(shape, dtype=np.int64)
mnarr = <np.ndarray>multin
diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx
index 8820a6e09..df305e689 100644
--- a/numpy/random/mtrand.pyx
+++ b/numpy/random/mtrand.pyx
@@ -22,6 +22,7 @@ from ._common cimport (POISSON_LAM_MAX, CONS_POSITIVE, CONS_NONE,
CONS_GT_1, LEGACY_CONS_POISSON,
double_fill, cont, kahan_sum, cont_broadcast_3,
check_array_constraint, check_constraint, disc, discrete_broadcast_iii,
+ validate_output_shape
)
cdef extern from "numpy/random/distributions.h":
@@ -3374,6 +3375,7 @@ cdef class RandomState:
cnt = np.PyArray_SIZE(randoms)
it = np.PyArray_MultiIterNew3(randoms, p_arr, n_arr)
+ validate_output_shape(it.shape, randoms)
with self.lock, nogil:
for i in range(cnt):
_dp = (<double*>np.PyArray_MultiIter_DATA(it, 1))[0]
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index 332b63198..bb6d25ef1 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -2423,6 +2423,16 @@ def test_broadcast_size_error():
with pytest.raises(ValueError):
random.standard_gamma(shape, out=out)
+ # 2 arg
+ with pytest.raises(ValueError):
+ random.binomial(1, [0.3, 0.7], size=(2, 1))
+ with pytest.raises(ValueError):
+ random.binomial([1, 2], 0.3, size=(2, 1))
+ with pytest.raises(ValueError):
+ random.binomial([1, 2], [0.3, 0.7], size=(2, 1))
+ with pytest.raises(ValueError):
+ random.multinomial([2, 2], [.3, .7], size=(2, 1))
+
# 3 arg
a = random.chisquare(5, size=3)
b = random.chisquare(5, size=(4, 3))
diff --git a/numpy/random/tests/test_randomstate.py b/numpy/random/tests/test_randomstate.py
index edd7811bf..23dbbed6a 100644
--- a/numpy/random/tests/test_randomstate.py
+++ b/numpy/random/tests/test_randomstate.py
@@ -1989,3 +1989,13 @@ def test_integer_repeat(int_func):
val = val.byteswap()
res = hashlib.md5(val.view(np.int8)).hexdigest()
assert_(res == md5)
+
+
+def test_broadcast_size_error():
+ # GH-16833
+ with pytest.raises(ValueError):
+ random.binomial(1, [0.3, 0.7], size=(2, 1))
+ with pytest.raises(ValueError):
+ random.binomial([1, 2], 0.3, size=(2, 1))
+ with pytest.raises(ValueError):
+ random.binomial([1, 2], [0.3, 0.7], size=(2, 1))