summaryrefslogtreecommitdiff
path: root/numpy/random
diff options
context:
space:
mode:
authorStefan van der Walt <stefan@sun.ac.za>2007-11-23 15:20:41 +0000
committerStefan van der Walt <stefan@sun.ac.za>2007-11-23 15:20:41 +0000
commita9f446da7d3613743a2f191446519eeef2d8f8d7 (patch)
treefe79e00ea209a117c89203ed80cf44778cabcbaa /numpy/random
parent8b60ca4182b76904e49fa80ab568d5f77cbe6d54 (diff)
downloadnumpy-a9f446da7d3613743a2f191446519eeef2d8f8d7.tar.gz
Fix randint for negative interval.
Diffstat (limited to 'numpy/random')
-rw-r--r--numpy/random/mtrand/mtrand.c4
-rw-r--r--numpy/random/mtrand/mtrand.pyx4
-rw-r--r--numpy/random/tests/test_random.py6
3 files changed, 10 insertions, 4 deletions
diff --git a/numpy/random/mtrand/mtrand.c b/numpy/random/mtrand/mtrand.c
index cc6c48b02..92f8cd75e 100644
--- a/numpy/random/mtrand/mtrand.c
+++ b/numpy/random/mtrand/mtrand.c
@@ -1,4 +1,4 @@
-/* Generated by Pyrex 0.9.5.1a on Wed Aug 29 00:25:24 2007 */
+/* Generated by Pyrex 0.9.5.1a on Fri Nov 23 17:16:35 2007 */
#include "Python.h"
#include "structmember.h"
@@ -2438,7 +2438,7 @@ static PyObject *__pyx_f_6mtrand_11RandomState_randint(PyObject *__pyx_v_self, P
if (__pyx_1) {
/* "/home/stefan/work/scipy/numpy.patch/numpy/random/mtrand/mtrand.pyx":603 */
- __pyx_3 = PyLong_FromUnsignedLong((rk_interval(__pyx_v_diff,((struct __pyx_obj_6mtrand_RandomState *)__pyx_v_self)->internal_state) + __pyx_v_lo)); if (!__pyx_3) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 603; goto __pyx_L1;}
+ __pyx_3 = PyInt_FromLong((((long )rk_interval(__pyx_v_diff,((struct __pyx_obj_6mtrand_RandomState *)__pyx_v_self)->internal_state)) + __pyx_v_lo)); if (!__pyx_3) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 603; goto __pyx_L1;}
__pyx_r = __pyx_3;
__pyx_3 = 0;
goto __pyx_L0;
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx
index 48a31d3ce..f069f11b9 100644
--- a/numpy/random/mtrand/mtrand.pyx
+++ b/numpy/random/mtrand/mtrand.pyx
@@ -598,9 +598,9 @@ cdef class RandomState:
diff = hi - lo - 1
if diff < 0:
raise ValueError("low >= high")
-
+
if size is None:
- return rk_interval(diff, self.internal_state) + lo
+ return <long>rk_interval(diff, self.internal_state) + lo
else:
array = <ndarray>_sp.empty(size, int)
length = PyArray_SIZE(array)
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py
index 9b9e5d828..633f6840f 100644
--- a/numpy/random/tests/test_random.py
+++ b/numpy/random/tests/test_random.py
@@ -9,5 +9,11 @@ class TestMultinomial(NumpyTestCase):
def test_zero_probability(self):
random.multinomial(100, [0.2, 0.8, 0.0, 0.0, 0.0])
+ def test_int_negative_interval(self):
+ assert -5 <= random.randint(-5,-1) < -1
+ x = random.randint(-5,-1,5)
+ assert N.all(-5 <= x)
+ assert N.all(x < -1)
+
if __name__ == "__main__":
NumpyTest().run()