summaryrefslogtreecommitdiff
path: root/docs/examples/tutorial/parallelization/median.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'docs/examples/tutorial/parallelization/median.pyx')
-rw-r--r--docs/examples/tutorial/parallelization/median.pyx34
1 files changed, 34 insertions, 0 deletions
diff --git a/docs/examples/tutorial/parallelization/median.pyx b/docs/examples/tutorial/parallelization/median.pyx
new file mode 100644
index 000000000..242cb6091
--- /dev/null
+++ b/docs/examples/tutorial/parallelization/median.pyx
@@ -0,0 +1,34 @@
+# distutils: language = c++
+
+from cython.parallel cimport parallel, prange
+from libcpp.vector cimport vector
+from libcpp.algorithm cimport nth_element
+cimport cython
+from cython.operator cimport dereference
+
+import numpy as np
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def median_along_axis0(const double[:,:] x):
+ cdef double[::1] out = np.empty(x.shape[1])
+ cdef Py_ssize_t i, j
+
+ cdef vector[double] *scratch
+ cdef vector[double].iterator median_it
+ with nogil, parallel():
+ # allocate scratch space per loop
+ scratch = new vector[double](x.shape[0])
+ try:
+ for i in prange(x.shape[1]):
+ # copy row into scratch space
+ for j in range(x.shape[0]):
+ dereference(scratch)[j] = x[j, i]
+ median_it = scratch.begin() + scratch.size()//2
+ nth_element(scratch.begin(), median_it, scratch.end())
+ # for the sake of a simple example, don't handle even lengths...
+ out[i] = dereference(median_it)
+ finally:
+ del scratch
+ return np.asarray(out)
+