diff options
author | Robert Kern <robert.kern@gmail.com> | 2008-08-24 23:22:11 +0000 |
---|---|---|
committer | Robert Kern <robert.kern@gmail.com> | 2008-08-24 23:22:11 +0000 |
commit | 6f33684d500a6be10c47ca201cfe404a863f4abe (patch) | |
tree | 9905b1924a70621c6841403c9e4bb02ddf054795 /numpy/random | |
parent | 124149c8426556707a8fde26490a30f402f3e5b2 (diff) | |
download | numpy-6f33684d500a6be10c47ca201cfe404a863f4abe.tar.gz |
BUG: Logarithmic series needs to exclude p==0 and p==1. When the conversion of the result to C longs gives a negative number (i.e. out of bounds), reject the sample and try again until we do get something in bounds.
Diffstat (limited to 'numpy/random')
-rw-r--r-- | numpy/random/mtrand/distributions.c | 31 | ||||
-rw-r--r-- | numpy/random/mtrand/mtrand.c | 21 | ||||
-rw-r--r-- | numpy/random/mtrand/mtrand.pyx | 16 |
3 files changed, 43 insertions, 25 deletions
diff --git a/numpy/random/mtrand/distributions.c b/numpy/random/mtrand/distributions.c index 364af3da0..8cd508f7a 100644 --- a/numpy/random/mtrand/distributions.c +++ b/numpy/random/mtrand/distributions.c @@ -848,14 +848,29 @@ double rk_triangular(rk_state *state, double left, double mode, double right) long rk_logseries(rk_state *state, double p) { double q, r, U, V; + long result; r = log(1.0 - p); - - V = rk_double(state); - if (V >= p) return 1; - U = rk_double(state); - q = 1.0 - exp(r*U); - if (V <= q*q) return (long)floor(1 + log(V)/log(q)); - if (V <= q) return 1; - return 2; + + while (1) { + V = rk_double(state); + if (V >= p) { + return 1; + } + U = rk_double(state); + q = 1.0 - exp(r*U); + if (V <= q*q) { + result = (long)floor(1 + log(V)/log(q)); + if (result < 1) { + continue; + } + else { + return result; + } + } + if (V <= q) { + return 1; + } + return 2; + } } diff --git a/numpy/random/mtrand/mtrand.c b/numpy/random/mtrand/mtrand.c index 69756e38e..96812f620 100644 --- a/numpy/random/mtrand/mtrand.c +++ b/numpy/random/mtrand/mtrand.c @@ -1,4 +1,4 @@ -/* Generated by Pyrex 0.9.6.4 on Fri Aug 22 22:54:35 2008 */ +/* Generated by Pyrex 0.9.6.4 on Sun Aug 24 16:14:30 2008 */ #define PY_SSIZE_T_CLEAN #include "Python.h" @@ -8171,15 +8171,17 @@ static PyObject *__pyx_f_6mtrand_11RandomState_hypergeometric(PyObject *__pyx_v_ return __pyx_r; } +static PyObject *__pyx_n_greater_equal; + static PyObject *__pyx_k162p; static PyObject *__pyx_k163p; static PyObject *__pyx_k164p; static PyObject *__pyx_k165p; -static char __pyx_k162[] = "p < 0.0"; -static char __pyx_k163[] = "p > 1.0"; -static char __pyx_k164[] = "p < 0.0"; -static char __pyx_k165[] = "p > 1.0"; +static char __pyx_k162[] = "p <= 0.0"; +static char __pyx_k163[] = "p >= 1.0"; +static char __pyx_k164[] = "p <= 0.0"; +static char __pyx_k165[] = "p >= 1.0"; static PyObject *__pyx_f_6mtrand_11RandomState_logseries(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ static char __pyx_doc_6mtrand_11RandomState_logseries[] = "\n logseries(p, size=None)\n\n Logarithmic series distribution.\n\n "; @@ -8210,7 +8212,7 @@ static PyObject *__pyx_f_6mtrand_11RandomState_logseries(PyObject *__pyx_v_self, if (__pyx_1) { /* "/Users/rkern/svn/numpy/numpy/random/mtrand/mtrand.pyx":2435 */ - __pyx_1 = (__pyx_v_fp < 0.0); + __pyx_1 = (__pyx_v_fp <= 0.0); if (__pyx_1) { __pyx_2 = __Pyx_GetName(__pyx_b, __pyx_n_ValueError); if (!__pyx_2) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2436; goto __pyx_L1;} __pyx_3 = PyTuple_New(1); if (!__pyx_3) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2436; goto __pyx_L1;} @@ -8227,7 +8229,7 @@ static PyObject *__pyx_f_6mtrand_11RandomState_logseries(PyObject *__pyx_v_self, __pyx_L3:; /* "/Users/rkern/svn/numpy/numpy/random/mtrand/mtrand.pyx":2437 */ - __pyx_1 = (__pyx_v_fp > 1.0); + __pyx_1 = (__pyx_v_fp >= 1.0); if (__pyx_1) { __pyx_2 = __Pyx_GetName(__pyx_b, __pyx_n_ValueError); if (!__pyx_2) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2438; goto __pyx_L1;} __pyx_3 = PyTuple_New(1); if (!__pyx_3) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2438; goto __pyx_L1;} @@ -8267,7 +8269,7 @@ static PyObject *__pyx_f_6mtrand_11RandomState_logseries(PyObject *__pyx_v_self, __pyx_2 = PyObject_GetAttr(__pyx_4, __pyx_n_any); if (!__pyx_2) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2444; goto __pyx_L1;} Py_DECREF(__pyx_4); __pyx_4 = 0; __pyx_3 = __Pyx_GetName(__pyx_m, __pyx_n_np); if (!__pyx_3) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2444; goto __pyx_L1;} - __pyx_4 = PyObject_GetAttr(__pyx_3, __pyx_n_less); if (!__pyx_4) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2444; goto __pyx_L1;} + __pyx_4 = PyObject_GetAttr(__pyx_3, __pyx_n_less_equal); if (!__pyx_4) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2444; goto __pyx_L1;} Py_DECREF(__pyx_3); __pyx_3 = 0; __pyx_3 = PyFloat_FromDouble(0.0); if (!__pyx_3) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2444; goto __pyx_L1;} __pyx_5 = PyTuple_New(2); if (!__pyx_5) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2444; goto __pyx_L1;} @@ -8306,7 +8308,7 @@ static PyObject *__pyx_f_6mtrand_11RandomState_logseries(PyObject *__pyx_v_self, __pyx_3 = PyObject_GetAttr(__pyx_5, __pyx_n_any); if (!__pyx_3) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2446; goto __pyx_L1;} Py_DECREF(__pyx_5); __pyx_5 = 0; __pyx_2 = __Pyx_GetName(__pyx_m, __pyx_n_np); if (!__pyx_2) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2446; goto __pyx_L1;} - __pyx_4 = PyObject_GetAttr(__pyx_2, __pyx_n_greater); if (!__pyx_4) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2446; goto __pyx_L1;} + __pyx_4 = PyObject_GetAttr(__pyx_2, __pyx_n_greater_equal); if (!__pyx_4) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2446; goto __pyx_L1;} Py_DECREF(__pyx_2); __pyx_2 = 0; __pyx_5 = PyFloat_FromDouble(1.0); if (!__pyx_5) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2446; goto __pyx_L1;} __pyx_2 = PyTuple_New(2); if (!__pyx_2) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 2446; goto __pyx_L1;} @@ -9461,6 +9463,7 @@ static __Pyx_InternTabEntry __pyx_intern_tab[] = { {&__pyx_n_geometric, "geometric"}, {&__pyx_n_get_state, "get_state"}, {&__pyx_n_greater, "greater"}, + {&__pyx_n_greater_equal, "greater_equal"}, {&__pyx_n_gumbel, "gumbel"}, {&__pyx_n_hypergeometric, "hypergeometric"}, {&__pyx_n_int, "int"}, diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx index 5dda95ad8..1dd8b9f05 100644 --- a/numpy/random/mtrand/mtrand.pyx +++ b/numpy/random/mtrand/mtrand.pyx @@ -2432,19 +2432,19 @@ cdef class RandomState: fp = PyFloat_AsDouble(p) if not PyErr_Occurred(): - if fp < 0.0: - raise ValueError("p < 0.0") - if fp > 1.0: - raise ValueError("p > 1.0") + if fp <= 0.0: + raise ValueError("p <= 0.0") + if fp >= 1.0: + raise ValueError("p >= 1.0") return discd_array_sc(self.internal_state, rk_logseries, size, fp) PyErr_Clear() op = <ndarray>PyArray_FROM_OTF(p, NPY_DOUBLE, NPY_ALIGNED) - if np.any(np.less(op, 0.0)): - raise ValueError("p < 0.0") - if np.any(np.greater(op, 1.0)): - raise ValueError("p > 1.0") + if np.any(np.less_equal(op, 0.0)): + raise ValueError("p <= 0.0") + if np.any(np.greater_equal(op, 1.0)): + raise ValueError("p >= 1.0") return discd_array(self.internal_state, rk_logseries, size, op) # Multivariate distributions: |