diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2016-09-20 19:24:59 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-09-20 19:24:59 -0500 |
commit | 866bc024ab4f23ae96f0d1370c348eccbe543f53 (patch) | |
tree | f2506ce8fc900afecd9c58e6548cd70cbd6d8e87 | |
parent | 55ece5839d3e9327de7a23ed346a12ec9526899d (diff) | |
parent | 7fdfa6b2c0a5da5115c523c6869898873050e41c (diff) | |
download | numpy-866bc024ab4f23ae96f0d1370c348eccbe543f53.tar.gz |
Merge pull request #8071 from gfyoung/randint-tempita
MAINT: Add Tempita to randint helpers
-rw-r--r-- | .travis.yml | 1 | ||||
-rw-r--r-- | numpy/random/mtrand/mtrand.pyx | 1 | ||||
-rw-r--r-- | numpy/random/mtrand/randint_helpers.pxi.in | 77 | ||||
-rwxr-xr-x | tools/cythonize.py | 57 |
4 files changed, 124 insertions, 12 deletions
diff --git a/.travis.yml b/.travis.yml index 7d503b8c5..47a6bc171 100644 --- a/.travis.yml +++ b/.travis.yml @@ -52,6 +52,7 @@ matrix: apt: packages: - *common_packages + - cython3-dbg - python3-dbg - python3-dev - python3-nose diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx index 138f0e39a..eab8e59b3 100644 --- a/numpy/random/mtrand/mtrand.pyx +++ b/numpy/random/mtrand/mtrand.pyx @@ -22,6 +22,7 @@ # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. include "Python.pxi" +include "randint_helpers.pxi" include "numpy.pxd" include "cpython/pycapsule.pxd" diff --git a/numpy/random/mtrand/randint_helpers.pxi.in b/numpy/random/mtrand/randint_helpers.pxi.in new file mode 100644 index 000000000..4bd7cd356 --- /dev/null +++ b/numpy/random/mtrand/randint_helpers.pxi.in @@ -0,0 +1,77 @@ +""" +Template for each `dtype` helper function in `np.random.randint`. +""" + +{{py: + +dtypes = ( + ('bool', 'bool', 'bool_'), + ('int8', 'uint8', 'int8'), + ('int16', 'uint16', 'int16'), + ('int32', 'uint32', 'int32'), + ('int64', 'uint64', 'int64'), + ('uint8', 'uint8', 'uint8'), + ('uint16', 'uint16', 'uint16'), + ('uint32', 'uint32', 'uint32'), + ('uint64', 'uint64', 'uint64'), +) + +def get_dispatch(dtypes): + for npy_dt, npy_udt, np_dt in dtypes: + yield npy_dt, npy_udt, np_dt +}} + +{{for npy_dt, npy_udt, np_dt in get_dispatch(dtypes)}} + +def _rand_{{npy_dt}}(low, high, size, rngstate): + """ + _rand_{{npy_dt}}(low, high, size, rngstate) + + Return random np.{{np_dt}} integers between ``low`` and ``high``, inclusive. + + Return random integers from the "discrete uniform" distribution in the + closed interval [``low``, ``high``). On entry the arguments are presumed + to have been validated for size and order for the np.{{np_dt}} type. + + Parameters + ---------- + low : int + Lowest (signed) integer to be drawn from the distribution. + high : int + Highest (signed) integer to be drawn from the distribution. + size : int or tuple of ints + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + rngstate : encapsulated pointer to rk_state + The specific type depends on the python version. In Python 2 it is + a PyCObject, in Python 3 a PyCapsule object. + + Returns + ------- + out : python integer or ndarray of np.{{np_dt}} + `size`-shaped array of random integers from the appropriate + distribution, or a single such random int if `size` not provided. + + """ + cdef npy_{{npy_udt}} off, rng, buf + cdef npy_{{npy_udt}} *out + cdef ndarray array "arrayObject" + cdef npy_intp cnt + cdef rk_state *state = <rk_state *>PyCapsule_GetPointer(rngstate, NULL) + + rng = <npy_{{npy_udt}}>(high - low) + off = <npy_{{npy_udt}}>(<npy_{{npy_dt}}>low) + + if size is None: + rk_random_{{npy_udt}}(off, rng, 1, &buf, state) + return np.{{np_dt}}(<npy_{{npy_dt}}>buf) + else: + array = <ndarray>np.empty(size, np.{{np_dt}}) + cnt = PyArray_SIZE(array) + array_data = <npy_{{npy_udt}} *>PyArray_DATA(array) + with nogil: + rk_random_{{npy_udt}}(off, rng, cnt, array_data, state) + return array + +{{endfor}} diff --git a/tools/cythonize.py b/tools/cythonize.py index 2db0cbd52..1085f2a91 100755 --- a/tools/cythonize.py +++ b/tools/cythonize.py @@ -100,6 +100,24 @@ def process_tempita_pyx(fromfile, tofile): f.write(pyxcontent) process_pyx(pyxfile, tofile) + +def process_tempita_pxi(fromfile, tofile): + try: + try: + from Cython import Tempita as tempita + except ImportError: + import tempita + except ImportError: + raise Exception('Building %s requires Tempita: ' + 'pip install --user Tempita' % VENDOR) + assert fromfile.endswith('.pxi.in') + assert tofile.endswith('.pxi') + with open(fromfile, "r") as f: + tmpl = f.read() + pyxcontent = tempita.sub(tmpl) + with open(tofile, "w") as f: + f.write(pyxcontent) + rules = { # fromext : function '.pyx' : process_pyx, @@ -170,22 +188,37 @@ def process(path, fromfile, tofile, processor_function, hash_db): def find_process_files(root_dir): hash_db = load_hashes(HASH_FILE) for cur_dir, dirs, files in os.walk(root_dir): + # .pxi or .pxi.in files are most likely dependencies for + # .pyx files, so we need to process them first + files.sort(key=lambda name: (name.endswith('.pxi') or + name.endswith('.pxi.in')), + reverse=True) + for filename in files: in_file = os.path.join(cur_dir, filename + ".in") if filename.endswith('.pyx') and os.path.isfile(in_file): continue - for fromext, function in rules.items(): - if filename.endswith(fromext): - toext = ".c" - with open(os.path.join(cur_dir, filename), 'rb') as f: - data = f.read() - m = re.search(br"^\s*#\s*distutils:\s*language\s*=\s*c\+\+\s*$", data, re.I|re.M) - if m: - toext = ".cxx" - fromfile = filename - tofile = filename[:-len(fromext)] + toext - process(cur_dir, fromfile, tofile, function, hash_db) - save_hashes(hash_db, HASH_FILE) + elif filename.endswith('.pxi.in'): + toext = '.pxi' + fromext = '.pxi.in' + fromfile = filename + function = process_tempita_pxi + tofile = filename[:-len(fromext)] + toext + process(cur_dir, fromfile, tofile, function, hash_db) + save_hashes(hash_db, HASH_FILE) + else: + for fromext, function in rules.items(): + if filename.endswith(fromext): + toext = ".c" + with open(os.path.join(cur_dir, filename), 'rb') as f: + data = f.read() + m = re.search(br"^\s*#\s*distutils:\s*language\s*=\s*c\+\+\s*$", data, re.I|re.M) + if m: + toext = ".cxx" + fromfile = filename + tofile = filename[:-len(fromext)] + toext + process(cur_dir, fromfile, tofile, function, hash_db) + save_hashes(hash_db, HASH_FILE) def main(): try: |