diff options
author | Blake Griffith <blake.a.griffith@gmail.com> | 2013-08-25 11:44:53 -0500 |
---|---|---|
committer | Blake Griffith <blake.a.griffith@gmail.com> | 2013-09-05 20:09:44 -0500 |
commit | 21976ca31eda537824ad877d432634eaf103567b (patch) | |
tree | 9704d3f03874734d23ebcb9d2b93ec31aeb745ef | |
parent | 536cd36c1ba45c42ad306fda5f9c4d12ee0f5afd (diff) | |
download | numpy-21976ca31eda537824ad877d432634eaf103567b.tar.gz |
ENH: Add ufunc override functionality to ufuncs and dots.
-rw-r--r-- | numpy/core/blasdot/_dotblas.c | 28 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 22 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 48 |
3 files changed, 95 insertions, 3 deletions
diff --git a/numpy/core/blasdot/_dotblas.c b/numpy/core/blasdot/_dotblas.c index ae6b1b1e5..7a9f858e3 100644 --- a/numpy/core/blasdot/_dotblas.c +++ b/numpy/core/blasdot/_dotblas.c @@ -2,11 +2,14 @@ * This module provides a BLAS optimized\nmatrix multiply, * inner product and dot for numpy arrays */ -#define NPY_NO_DEPRECATED_API NPY_API_VERSION +#define NPY_NO_DEPRECATED_API NPY_API_VERSION #include "Python.h" -#include "npy_config.h" + #include "numpy/arrayobject.h" +#include "npy_config.h" +#include "npy_pycompat.h" +#include "private/ufunc_override.h" #ifndef CBLAS_HEADER #define CBLAS_HEADER "cblas.h" #endif @@ -215,8 +218,12 @@ _bad_strides(PyArrayObject *ap) static PyObject * dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwargs) { + static PyObject *cached_npy_dot = NULL; + PyObject *override = NULL; + PyObject *module; PyObject *op1, *op2; PyArrayObject *ap1 = NULL, *ap2 = NULL, *out = NULL, *ret = NULL; + int errval; int j, l, lda, ldb, ldc; int typenum, nd; npy_intp ap1stride = 0; @@ -232,6 +239,23 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa MatrixShape ap1shape, ap2shape; char* kwords[] = {"a", "b", "out", NULL }; + if (cached_npy_dot == NULL) { + module = PyImport_ImportModule("numpy.core._dotblas"); + cached_npy_dot = PyDict_GetItemString(PyModule_GetDict(module), "dot"); + + Py_INCREF(cached_npy_dot); + Py_DECREF(module); + } + + errval = PyUFunc_CheckOverride(cached_npy_dot, "__call__", args, kwargs, + &override, 2); + if (errval) { + return NULL; + } + else if (override) { + return override; + } + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO|O", kwords, &op1, &op2, &out)) { return NULL; diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 85dd8ab01..f4ceecde9 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -53,6 +53,7 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0; #include "ctors.h" #include "array_assign.h" #include "common.h" +#include "private/ufunc_override.h" /* Only here for API compatibility */ NPY_NO_EXPORT PyTypeObject PyBigArray_Type; @@ -2079,8 +2080,29 @@ array_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args) static PyObject * array_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwds) { + int errval; + static PyObject *cached_npy_dot = NULL; + PyObject *override = NULL; PyObject *v, *a, *o = NULL; char* kwlist[] = {"a", "b", "out", NULL }; + PyObject *module; + + if (cached_npy_dot == NULL) { + module = PyImport_ImportModule("numpy.core.multiarray"); + cached_npy_dot = PyDict_GetItemString(PyModule_GetDict(module), "dot"); + + Py_INCREF(cached_npy_dot); + Py_DECREF(module); + } + + errval = PyUFunc_CheckOverride(cached_npy_dot, "__call__", args, kwds, + &override, 2); + if (errval) { + return NULL; + } + else if (override) { + return override; + } if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|O", kwlist, &a, &v, &o)) { return NULL; diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index a71777ce3..062bf163b 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -44,6 +44,7 @@ #include "reduction.h" #include "ufunc_object.h" +#include "ufunc_override.h" /********** PRINTF DEBUG TRACING **************/ #define NPY_UF_DBG_TRACING 0 @@ -707,7 +708,6 @@ fail: return -1; } - /********* GENERIC UFUNC USING ITERATOR *********/ /* @@ -4043,6 +4043,7 @@ ufunc_generic_call(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) PyObject *retobj[NPY_MAXARGS]; PyObject *wraparr[NPY_MAXARGS]; PyObject *res; + PyObject *override = NULL; int errval; /* @@ -4053,6 +4054,18 @@ ufunc_generic_call(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) mps[i] = NULL; } + errval = PyUFunc_CheckOverride(ufunc, "__call__", args, kwds, &override, + ufunc->nin); + if (errval) { + return NULL; + } + else if (override) { + for (i = 0; i < ufunc->nargs; i++) { + PyArray_XDECREF_ERR(mps[i]); + } + return override; + } + errval = PyUFunc_GenericFunction(ufunc, args, kwds, mps); if (errval < 0) { for (i = 0; i < ufunc->nargs; i++) { @@ -4834,18 +4847,51 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) static PyObject * ufunc_reduce(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) { + int errval; + PyObject *override = NULL; + + errval = PyUFunc_CheckOverride(ufunc, "reduce", args, kwds, &override, + ufunc->nin); + if (errval) { + return NULL; + } + else if (override) { + return override; + } return PyUFunc_GenericReduction(ufunc, args, kwds, UFUNC_REDUCE); } static PyObject * ufunc_accumulate(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) { + int errval; + PyObject *override = NULL; + + errval = PyUFunc_CheckOverride(ufunc, "accumulate", args, kwds, &override, + ufunc->nin); + if (errval) { + return NULL; + } + else if (override) { + return override; + } return PyUFunc_GenericReduction(ufunc, args, kwds, UFUNC_ACCUMULATE); } static PyObject * ufunc_reduceat(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) { + int errval; + PyObject *override = NULL; + + errval = PyUFunc_CheckOverride(ufunc, "reduceat", args, kwds, &override, + ufunc->nin); + if (errval) { + return NULL; + } + else if (override) { + return override; + } return PyUFunc_GenericReduction(ufunc, args, kwds, UFUNC_REDUCEAT); } |