summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2016-09-20 19:24:59 -0500
committerGitHub <noreply@github.com>2016-09-20 19:24:59 -0500
commit866bc024ab4f23ae96f0d1370c348eccbe543f53 (patch)
treef2506ce8fc900afecd9c58e6548cd70cbd6d8e87
parent55ece5839d3e9327de7a23ed346a12ec9526899d (diff)
parent7fdfa6b2c0a5da5115c523c6869898873050e41c (diff)
downloadnumpy-866bc024ab4f23ae96f0d1370c348eccbe543f53.tar.gz
Merge pull request #8071 from gfyoung/randint-tempita
MAINT: Add Tempita to randint helpers
-rw-r--r--.travis.yml1
-rw-r--r--numpy/random/mtrand/mtrand.pyx1
-rw-r--r--numpy/random/mtrand/randint_helpers.pxi.in77
-rwxr-xr-xtools/cythonize.py57
4 files changed, 124 insertions, 12 deletions
diff --git a/.travis.yml b/.travis.yml
index 7d503b8c5..47a6bc171 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -52,6 +52,7 @@ matrix:
apt:
packages:
- *common_packages
+ - cython3-dbg
- python3-dbg
- python3-dev
- python3-nose
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx
index 138f0e39a..eab8e59b3 100644
--- a/numpy/random/mtrand/mtrand.pyx
+++ b/numpy/random/mtrand/mtrand.pyx
@@ -22,6 +22,7 @@
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
include "Python.pxi"
+include "randint_helpers.pxi"
include "numpy.pxd"
include "cpython/pycapsule.pxd"
diff --git a/numpy/random/mtrand/randint_helpers.pxi.in b/numpy/random/mtrand/randint_helpers.pxi.in
new file mode 100644
index 000000000..4bd7cd356
--- /dev/null
+++ b/numpy/random/mtrand/randint_helpers.pxi.in
@@ -0,0 +1,77 @@
+"""
+Template for each `dtype` helper function in `np.random.randint`.
+"""
+
+{{py:
+
+dtypes = (
+ ('bool', 'bool', 'bool_'),
+ ('int8', 'uint8', 'int8'),
+ ('int16', 'uint16', 'int16'),
+ ('int32', 'uint32', 'int32'),
+ ('int64', 'uint64', 'int64'),
+ ('uint8', 'uint8', 'uint8'),
+ ('uint16', 'uint16', 'uint16'),
+ ('uint32', 'uint32', 'uint32'),
+ ('uint64', 'uint64', 'uint64'),
+)
+
+def get_dispatch(dtypes):
+ for npy_dt, npy_udt, np_dt in dtypes:
+ yield npy_dt, npy_udt, np_dt
+}}
+
+{{for npy_dt, npy_udt, np_dt in get_dispatch(dtypes)}}
+
+def _rand_{{npy_dt}}(low, high, size, rngstate):
+ """
+ _rand_{{npy_dt}}(low, high, size, rngstate)
+
+ Return random np.{{np_dt}} integers between ``low`` and ``high``, inclusive.
+
+ Return random integers from the "discrete uniform" distribution in the
+ closed interval [``low``, ``high``). On entry the arguments are presumed
+ to have been validated for size and order for the np.{{np_dt}} type.
+
+ Parameters
+ ----------
+ low : int
+ Lowest (signed) integer to be drawn from the distribution.
+ high : int
+ Highest (signed) integer to be drawn from the distribution.
+ size : int or tuple of ints
+ Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+ ``m * n * k`` samples are drawn. Default is None, in which case a
+ single value is returned.
+ rngstate : encapsulated pointer to rk_state
+ The specific type depends on the python version. In Python 2 it is
+ a PyCObject, in Python 3 a PyCapsule object.
+
+ Returns
+ -------
+ out : python integer or ndarray of np.{{np_dt}}
+ `size`-shaped array of random integers from the appropriate
+ distribution, or a single such random int if `size` not provided.
+
+ """
+ cdef npy_{{npy_udt}} off, rng, buf
+ cdef npy_{{npy_udt}} *out
+ cdef ndarray array "arrayObject"
+ cdef npy_intp cnt
+ cdef rk_state *state = <rk_state *>PyCapsule_GetPointer(rngstate, NULL)
+
+ rng = <npy_{{npy_udt}}>(high - low)
+ off = <npy_{{npy_udt}}>(<npy_{{npy_dt}}>low)
+
+ if size is None:
+ rk_random_{{npy_udt}}(off, rng, 1, &buf, state)
+ return np.{{np_dt}}(<npy_{{npy_dt}}>buf)
+ else:
+ array = <ndarray>np.empty(size, np.{{np_dt}})
+ cnt = PyArray_SIZE(array)
+ array_data = <npy_{{npy_udt}} *>PyArray_DATA(array)
+ with nogil:
+ rk_random_{{npy_udt}}(off, rng, cnt, array_data, state)
+ return array
+
+{{endfor}}
diff --git a/tools/cythonize.py b/tools/cythonize.py
index 2db0cbd52..1085f2a91 100755
--- a/tools/cythonize.py
+++ b/tools/cythonize.py
@@ -100,6 +100,24 @@ def process_tempita_pyx(fromfile, tofile):
f.write(pyxcontent)
process_pyx(pyxfile, tofile)
+
+def process_tempita_pxi(fromfile, tofile):
+ try:
+ try:
+ from Cython import Tempita as tempita
+ except ImportError:
+ import tempita
+ except ImportError:
+ raise Exception('Building %s requires Tempita: '
+ 'pip install --user Tempita' % VENDOR)
+ assert fromfile.endswith('.pxi.in')
+ assert tofile.endswith('.pxi')
+ with open(fromfile, "r") as f:
+ tmpl = f.read()
+ pyxcontent = tempita.sub(tmpl)
+ with open(tofile, "w") as f:
+ f.write(pyxcontent)
+
rules = {
# fromext : function
'.pyx' : process_pyx,
@@ -170,22 +188,37 @@ def process(path, fromfile, tofile, processor_function, hash_db):
def find_process_files(root_dir):
hash_db = load_hashes(HASH_FILE)
for cur_dir, dirs, files in os.walk(root_dir):
+ # .pxi or .pxi.in files are most likely dependencies for
+ # .pyx files, so we need to process them first
+ files.sort(key=lambda name: (name.endswith('.pxi') or
+ name.endswith('.pxi.in')),
+ reverse=True)
+
for filename in files:
in_file = os.path.join(cur_dir, filename + ".in")
if filename.endswith('.pyx') and os.path.isfile(in_file):
continue
- for fromext, function in rules.items():
- if filename.endswith(fromext):
- toext = ".c"
- with open(os.path.join(cur_dir, filename), 'rb') as f:
- data = f.read()
- m = re.search(br"^\s*#\s*distutils:\s*language\s*=\s*c\+\+\s*$", data, re.I|re.M)
- if m:
- toext = ".cxx"
- fromfile = filename
- tofile = filename[:-len(fromext)] + toext
- process(cur_dir, fromfile, tofile, function, hash_db)
- save_hashes(hash_db, HASH_FILE)
+ elif filename.endswith('.pxi.in'):
+ toext = '.pxi'
+ fromext = '.pxi.in'
+ fromfile = filename
+ function = process_tempita_pxi
+ tofile = filename[:-len(fromext)] + toext
+ process(cur_dir, fromfile, tofile, function, hash_db)
+ save_hashes(hash_db, HASH_FILE)
+ else:
+ for fromext, function in rules.items():
+ if filename.endswith(fromext):
+ toext = ".c"
+ with open(os.path.join(cur_dir, filename), 'rb') as f:
+ data = f.read()
+ m = re.search(br"^\s*#\s*distutils:\s*language\s*=\s*c\+\+\s*$", data, re.I|re.M)
+ if m:
+ toext = ".cxx"
+ fromfile = filename
+ tofile = filename[:-len(fromext)] + toext
+ process(cur_dir, fromfile, tofile, function, hash_db)
+ save_hashes(hash_db, HASH_FILE)
def main():
try: