diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2014-09-25 17:38:38 -0600 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2014-09-25 17:38:38 -0600 |
commit | 002b0de24d7c88cb97d16b9e8df6948fe4f07d5d (patch) | |
tree | 4b9c13e5255eea1c19e7e0211f983ef1610f9445 /numpy | |
parent | caeb88871516933b671e547bcaabde173a220664 (diff) | |
parent | c7fae327a79b95671833e641612ca87e1d6a6b48 (diff) | |
download | numpy-002b0de24d7c88cb97d16b9e8df6948fe4f07d5d.tar.gz |
Merge pull request #5122 from juliantaylor/small-dot
ENH: add small kernel correlate function
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/bscript | 2 | ||||
-rw-r--r-- | numpy/core/setup.py | 3 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 17 | ||||
-rw-r--r-- | numpy/core/src/multiarray/templ_common.h.src | 71 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 17 |
5 files changed, 105 insertions, 5 deletions
diff --git a/numpy/core/bscript b/numpy/core/bscript index da9ff2799..af947b305 100644 --- a/numpy/core/bscript +++ b/numpy/core/bscript @@ -440,6 +440,7 @@ def pre_build(context): "src/multiarray/arraytypes.c.src", "src/multiarray/nditer_templ.c.src", "src/multiarray/lowlevel_strided_loops.c.src", + "src/multiarray/templ_common.h.src", "src/multiarray/einsum.c.src"] bld(target="multiarray_templates", source=multiarray_templates) if ENABLE_SEPARATE_COMPILATION: @@ -453,6 +454,7 @@ def pre_build(context): pjoin('src', 'multiarray', 'buffer.c'), pjoin('src', 'multiarray', 'calculation.c'), pjoin('src', 'multiarray', 'common.c'), + pjoin('src', 'multiarray', 'templ_common.h.src'), pjoin('src', 'multiarray', 'conversion_utils.c'), pjoin('src', 'multiarray', 'convert.c'), pjoin('src', 'multiarray', 'convert_datatype.c'), diff --git a/numpy/core/setup.py b/numpy/core/setup.py index 2da2e837a..3ad912f40 100644 --- a/numpy/core/setup.py +++ b/numpy/core/setup.py @@ -751,6 +751,7 @@ def configuration(parent_package='',top_path=None): join('src', 'multiarray', 'buffer.h'), join('src', 'multiarray', 'calculation.h'), join('src', 'multiarray', 'common.h'), + join('src', 'multiarray', 'templ_common.h.src'), join('src', 'multiarray', 'convert_datatype.h'), join('src', 'multiarray', 'convert.h'), join('src', 'multiarray', 'conversion_utils.h'), @@ -826,6 +827,7 @@ def configuration(parent_package='',top_path=None): join('src', 'multiarray', 'mapping.c'), join('src', 'multiarray', 'methods.c'), join('src', 'multiarray', 'multiarraymodule.c'), + join('src', 'multiarray', 'templ_common.h.src'), join('src', 'multiarray', 'nditer_templ.c.src'), join('src', 'multiarray', 'nditer_api.c'), join('src', 'multiarray', 'nditer_constr.c'), @@ -854,6 +856,7 @@ def configuration(parent_package='',top_path=None): multiarray_deps.extend(multiarray_src) multiarray_src = [join('src', 'multiarray', 'multiarraymodule_onefile.c')] multiarray_src.append(generate_multiarray_templated_sources) + multiarray_src.append(join('src', 'multiarray', 'templ_common.h.src')) config.add_extension('multiarray', diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 55ae16515..dff901e89 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -59,6 +59,7 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0; #include "multiarraymodule.h" #include "cblasfuncs.h" #include "vdot.h" +#include "templ_common.h" /* Only here for API compatibility */ NPY_NO_EXPORT PyTypeObject PyBigArray_Type; @@ -1221,10 +1222,18 @@ _pyarray_correlate(PyArrayObject *ap1, PyArrayObject *ap2, int typenum, ip2 -= is2; op += os; } - for (i = 0; i < (n1 - n2 + 1); i++) { - dot(ip1, is1, ip2, is2, op, n, ret); - ip1 += is1; - op += os; + if (small_correlate(ip1, is1, n1 - n2 + 1, PyArray_TYPE(ap1), + ip2, is2, n, PyArray_TYPE(ap2), + op, os)) { + ip1 += is1 * (n1 - n2 + 1); + op += os * (n1 - n2 + 1); + } + else { + for (i = 0; i < (n1 - n2 + 1); i++) { + dot(ip1, is1, ip2, is2, op, n, ret); + ip1 += is1; + op += os; + } } for (i = 0; i < n_right; i++) { n--; diff --git a/numpy/core/src/multiarray/templ_common.h.src b/numpy/core/src/multiarray/templ_common.h.src new file mode 100644 index 000000000..d16afcef8 --- /dev/null +++ b/numpy/core/src/multiarray/templ_common.h.src @@ -0,0 +1,71 @@ +#ifndef __NPY_TYPED_COMMON_INC +#define __NPY_TYPED_COMMON_INC + +/* utility functions that profit from templates */ + +#include "lowlevel_strided_loops.h" +#include "numpy/npy_common.h" +#include "numpy/ndarraytypes.h" + + +/* + * Compute correlation of data with with small kernels + * Calling a BLAS dot product for the inner loop of the correlation is overkill + * for small kernels. It is faster to compute it directly. + * Intended to be used by _pyarray_correlate so no input verifications is done + * especially it does not handle the boundaries, they should be handled by the + * caller. + * Returns 0 if kernel is considered too large or types are not supported, then + * the regular array dot should be used to process the data. + * + * d_, dstride, nd, dtype: data pointer, its stride in bytes, number of + * elements and type of data + * k_, kstride, nk, ktype: kernel pointer, its stride in bytes, number of + * elements and type of data + * out_, ostride: output data pointer and its stride in bytes + */ +static int +small_correlate(const char * d_, npy_intp dstride, + npy_intp nd, enum NPY_TYPES dtype, + const char * k_, npy_intp kstride, + npy_intp nk, enum NPY_TYPES ktype, + char * out_, npy_intp ostride) +{ + /* only handle small kernels and uniform types */ + if (nk > 11 || dtype != ktype) { + return 0; + } + + switch (dtype) { +/**begin repeat + * Float types + * #type = npy_float, npy_double# + * #TYPE = NPY_FLOAT, NPY_DOUBLE# + */ + case @TYPE@: + { + npy_intp i; + const @type@ * d = (@type@*)d_; + const @type@ * k = (@type@*)k_; + @type@ * out = (@type@*)out_; + dstride /= sizeof(@type@); + kstride /= sizeof(@type@); + ostride /= sizeof(@type@); + for (i = 0; i < nd; i++) { + npy_intp j; + @type@ s = d[i * dstride] * k[0 * kstride]; + for (j = 1; j < nk; j++) { + s += d[(i + j) * dstride] * k[j * kstride]; + } + out[i * ostride] = s; + } + return 1; + } +/**end repeat**/ + default: + return 0; + } +} + + +#endif diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index be82d449f..46e864495 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -1902,16 +1902,30 @@ class TestLikeFuncs(TestCase): class _TestCorrelate(TestCase): def _setup(self, dt): self.x = np.array([1, 2, 3, 4, 5], dtype=dt) + self.xs = np.arange(1, 20)[::3] self.y = np.array([-1, -2, -3], dtype=dt) self.z1 = np.array([ -3., -8., -14., -20., -26., -14., -5.], dtype=dt) - self.z2 = np.array([ -5., -14., -26., -20., -14., -8., -3.], dtype=dt) + self.z1_4 = np.array([-2., -5., -8., -11., -14., -5.], dtype=dt) + self.z1r = np.array([-15., -22., -22., -16., -10., -4., -1.], dtype=dt) + self.z2 = np.array([-5., -14., -26., -20., -14., -8., -3.], dtype=dt) + self.z2r = np.array([-1., -4., -10., -16., -22., -22., -15.], dtype=dt) + self.zs = np.array([-3., -14., -30., -48., -66., -84., + -102., -54., -19.], dtype=dt) def test_float(self): self._setup(np.float) z = np.correlate(self.x, self.y, 'full', old_behavior=self.old_behavior) assert_array_almost_equal(z, self.z1) + z = np.correlate(self.x, self.y[:-1], 'full', old_behavior=self.old_behavior) + assert_array_almost_equal(z, self.z1_4) z = np.correlate(self.y, self.x, 'full', old_behavior=self.old_behavior) assert_array_almost_equal(z, self.z2) + z = np.correlate(self.x[::-1], self.y, 'full', old_behavior=self.old_behavior) + assert_array_almost_equal(z, self.z1r) + z = np.correlate(self.y, self.x[::-1], 'full', old_behavior=self.old_behavior) + assert_array_almost_equal(z, self.z2r) + z = np.correlate(self.xs, self.y, 'full', old_behavior=self.old_behavior) + assert_array_almost_equal(z, self.zs) def test_object(self): self._setup(Decimal) @@ -1935,6 +1949,7 @@ class TestCorrelate(_TestCorrelate): # as well _TestCorrelate._setup(self, dt) self.z2 = self.z1 + self.z2r = self.z1r @dec.deprecated() def test_complex(self): |