diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2022-01-10 20:18:30 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-01-10 20:18:30 -0600 |
commit | b2a97ba5d257c4ea8312bb23e1c9a4b9d41336a9 (patch) | |
tree | 73238f35c8beca93c77b2327812cce45d251ec88 /numpy | |
parent | abd47fc824ca9cdd2c03063b5d294039a2f6f0d1 (diff) | |
parent | 84fb5f4f28fe115f9e354113c6c9a2b69d4142d3 (diff) | |
download | numpy-b2a97ba5d257c4ea8312bb23e1c9a4b9d41336a9.tar.gz |
Merge pull request #20643 from bashtage/experiment-array-cons-check
PERF: Optimize array check for bounded 0,1 values
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/random/_common.pyx | 22 |
1 files changed, 19 insertions, 3 deletions
diff --git a/numpy/random/_common.pyx b/numpy/random/_common.pyx index 86828e200..864150458 100644 --- a/numpy/random/_common.pyx +++ b/numpy/random/_common.pyx @@ -359,6 +359,24 @@ cdef object float_fill_from_double(void *func, bitgen_t *state, object size, obj out_array_data[i] = <float>random_func(state) return out_array +cdef int _check_array_cons_bounded_0_1(np.ndarray val, object name) except -1: + cdef double *val_data + cdef np.npy_intp i + cdef bint err = 0 + + if not np.PyArray_ISONESEGMENT(val) or np.PyArray_TYPE(val) != np.NPY_DOUBLE: + # slow path for non-contiguous arrays or any non-double dtypes + err = not np.all(np.greater_equal(val, 0)) or not np.all(np.less_equal(val, 1)) + else: + val_data = <double *>np.PyArray_DATA(val) + for i in range(np.PyArray_SIZE(val)): + err = (not (val_data[i] >= 0)) or (not val_data[i] <= 1) + if err: + break + if err: + raise ValueError(f"{name} < 0, {name} > 1 or {name} contains NaNs") + + return 0 cdef int check_array_constraint(np.ndarray val, object name, constraint_type cons) except -1: if cons == CONS_NON_NEGATIVE: @@ -370,9 +388,7 @@ cdef int check_array_constraint(np.ndarray val, object name, constraint_type con elif np.any(np.less_equal(val, 0)): raise ValueError(name + " <= 0") elif cons == CONS_BOUNDED_0_1: - if not np.all(np.greater_equal(val, 0)) or \ - not np.all(np.less_equal(val, 1)): - raise ValueError("{0} < 0, {0} > 1 or {0} contains NaNs".format(name)) + return _check_array_cons_bounded_0_1(val, name) elif cons == CONS_BOUNDED_GT_0_1: if not np.all(np.greater(val, 0)) or not np.all(np.less_equal(val, 1)): raise ValueError("{0} <= 0, {0} > 1 or {0} contains NaNs".format(name)) |