summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorKevin Sheppard <kevin.k.sheppard@gmail.com>2021-12-22 12:26:04 +0000
committerKevin Sheppard <kevin.k.sheppard@gmail.com>2021-12-22 13:28:12 +0000
commit84fb5f4f28fe115f9e354113c6c9a2b69d4142d3 (patch)
tree1bfb48392a0e1e10b1ec215519e9220d95f2aed7 /numpy
parent6d5bc233442c034d54e0c0b15fdd7ed27e36fcac (diff)
downloadnumpy-84fb5f4f28fe115f9e354113c6c9a2b69d4142d3.tar.gz
PERF: Optimize array check for bounded 0,1 values
Optimize frequent check for probabilities when they are doubles
Diffstat (limited to 'numpy')
-rw-r--r--numpy/random/_common.pyx23
1 files changed, 20 insertions, 3 deletions
diff --git a/numpy/random/_common.pyx b/numpy/random/_common.pyx
index 4c34c503c..d352598c8 100644
--- a/numpy/random/_common.pyx
+++ b/numpy/random/_common.pyx
@@ -5,6 +5,7 @@ from cpython cimport PyFloat_AsDouble
import sys
import numpy as np
cimport numpy as np
+cimport numpy.math as npmath
from libc.stdint cimport uintptr_t
@@ -358,6 +359,24 @@ cdef object float_fill_from_double(void *func, bitgen_t *state, object size, obj
out_array_data[i] = <float>random_func(state)
return out_array
+cdef int _check_array_cons_bounded_0_1(np.ndarray val, object name) except -1:
+ cdef double *val_data
+ cdef np.npy_intp i
+ cdef bint err = 0
+
+ if not np.PyArray_ISONESEGMENT(val) or np.PyArray_TYPE(val) != np.NPY_DOUBLE:
+ # slow path for non-contiguous arrays or any non-double dtypes
+ err = not np.all(np.greater_equal(val, 0)) or not np.all(np.less_equal(val, 1))
+ else:
+ val_data = <double *>np.PyArray_DATA(val)
+ for i in range(np.PyArray_SIZE(val)):
+ err = (not (val_data[i] >= 0)) or (not val_data[i] <= 1)
+ if err:
+ break
+ if err:
+ raise ValueError(f"{name} < 0, {name} > 1 or {name} contains NaNs")
+
+ return 0
cdef int check_array_constraint(np.ndarray val, object name, constraint_type cons) except -1:
if cons == CONS_NON_NEGATIVE:
@@ -369,9 +388,7 @@ cdef int check_array_constraint(np.ndarray val, object name, constraint_type con
elif np.any(np.less_equal(val, 0)):
raise ValueError(name + " <= 0")
elif cons == CONS_BOUNDED_0_1:
- if not np.all(np.greater_equal(val, 0)) or \
- not np.all(np.less_equal(val, 1)):
- raise ValueError("{0} < 0, {0} > 1 or {0} contains NaNs".format(name))
+ return _check_array_cons_bounded_0_1(val, name)
elif cons == CONS_BOUNDED_GT_0_1:
if not np.all(np.greater(val, 0)) or not np.all(np.less_equal(val, 1)):
raise ValueError("{0} <= 0, {0} > 1 or {0} contains NaNs".format(name))