summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBlake Griffith <blake.a.griffith@gmail.com>2013-08-25 11:44:53 -0500
committerBlake Griffith <blake.a.griffith@gmail.com>2013-09-05 20:09:44 -0500
commit21976ca31eda537824ad877d432634eaf103567b (patch)
tree9704d3f03874734d23ebcb9d2b93ec31aeb745ef
parent536cd36c1ba45c42ad306fda5f9c4d12ee0f5afd (diff)
downloadnumpy-21976ca31eda537824ad877d432634eaf103567b.tar.gz
ENH: Add ufunc override functionality to ufuncs and dots.
-rw-r--r--numpy/core/blasdot/_dotblas.c28
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c22
-rw-r--r--numpy/core/src/umath/ufunc_object.c48
3 files changed, 95 insertions, 3 deletions
diff --git a/numpy/core/blasdot/_dotblas.c b/numpy/core/blasdot/_dotblas.c
index ae6b1b1e5..7a9f858e3 100644
--- a/numpy/core/blasdot/_dotblas.c
+++ b/numpy/core/blasdot/_dotblas.c
@@ -2,11 +2,14 @@
* This module provides a BLAS optimized\nmatrix multiply,
* inner product and dot for numpy arrays
*/
-#define NPY_NO_DEPRECATED_API NPY_API_VERSION
+#define NPY_NO_DEPRECATED_API NPY_API_VERSION
#include "Python.h"
-#include "npy_config.h"
+
#include "numpy/arrayobject.h"
+#include "npy_config.h"
+#include "npy_pycompat.h"
+#include "private/ufunc_override.h"
#ifndef CBLAS_HEADER
#define CBLAS_HEADER "cblas.h"
#endif
@@ -215,8 +218,12 @@ _bad_strides(PyArrayObject *ap)
static PyObject *
dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwargs)
{
+ static PyObject *cached_npy_dot = NULL;
+ PyObject *override = NULL;
+ PyObject *module;
PyObject *op1, *op2;
PyArrayObject *ap1 = NULL, *ap2 = NULL, *out = NULL, *ret = NULL;
+ int errval;
int j, l, lda, ldb, ldc;
int typenum, nd;
npy_intp ap1stride = 0;
@@ -232,6 +239,23 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
MatrixShape ap1shape, ap2shape;
char* kwords[] = {"a", "b", "out", NULL };
+ if (cached_npy_dot == NULL) {
+ module = PyImport_ImportModule("numpy.core._dotblas");
+ cached_npy_dot = PyDict_GetItemString(PyModule_GetDict(module), "dot");
+
+ Py_INCREF(cached_npy_dot);
+ Py_DECREF(module);
+ }
+
+ errval = PyUFunc_CheckOverride(cached_npy_dot, "__call__", args, kwargs,
+ &override, 2);
+ if (errval) {
+ return NULL;
+ }
+ else if (override) {
+ return override;
+ }
+
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO|O", kwords,
&op1, &op2, &out)) {
return NULL;
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 85dd8ab01..f4ceecde9 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -53,6 +53,7 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
#include "ctors.h"
#include "array_assign.h"
#include "common.h"
+#include "private/ufunc_override.h"
/* Only here for API compatibility */
NPY_NO_EXPORT PyTypeObject PyBigArray_Type;
@@ -2079,8 +2080,29 @@ array_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
static PyObject *
array_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwds)
{
+ int errval;
+ static PyObject *cached_npy_dot = NULL;
+ PyObject *override = NULL;
PyObject *v, *a, *o = NULL;
char* kwlist[] = {"a", "b", "out", NULL };
+ PyObject *module;
+
+ if (cached_npy_dot == NULL) {
+ module = PyImport_ImportModule("numpy.core.multiarray");
+ cached_npy_dot = PyDict_GetItemString(PyModule_GetDict(module), "dot");
+
+ Py_INCREF(cached_npy_dot);
+ Py_DECREF(module);
+ }
+
+ errval = PyUFunc_CheckOverride(cached_npy_dot, "__call__", args, kwds,
+ &override, 2);
+ if (errval) {
+ return NULL;
+ }
+ else if (override) {
+ return override;
+ }
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|O", kwlist, &a, &v, &o)) {
return NULL;
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index a71777ce3..062bf163b 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -44,6 +44,7 @@
#include "reduction.h"
#include "ufunc_object.h"
+#include "ufunc_override.h"
/********** PRINTF DEBUG TRACING **************/
#define NPY_UF_DBG_TRACING 0
@@ -707,7 +708,6 @@ fail:
return -1;
}
-
/********* GENERIC UFUNC USING ITERATOR *********/
/*
@@ -4043,6 +4043,7 @@ ufunc_generic_call(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
PyObject *retobj[NPY_MAXARGS];
PyObject *wraparr[NPY_MAXARGS];
PyObject *res;
+ PyObject *override = NULL;
int errval;
/*
@@ -4053,6 +4054,18 @@ ufunc_generic_call(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
mps[i] = NULL;
}
+ errval = PyUFunc_CheckOverride(ufunc, "__call__", args, kwds, &override,
+ ufunc->nin);
+ if (errval) {
+ return NULL;
+ }
+ else if (override) {
+ for (i = 0; i < ufunc->nargs; i++) {
+ PyArray_XDECREF_ERR(mps[i]);
+ }
+ return override;
+ }
+
errval = PyUFunc_GenericFunction(ufunc, args, kwds, mps);
if (errval < 0) {
for (i = 0; i < ufunc->nargs; i++) {
@@ -4834,18 +4847,51 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
static PyObject *
ufunc_reduce(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
{
+ int errval;
+ PyObject *override = NULL;
+
+ errval = PyUFunc_CheckOverride(ufunc, "reduce", args, kwds, &override,
+ ufunc->nin);
+ if (errval) {
+ return NULL;
+ }
+ else if (override) {
+ return override;
+ }
return PyUFunc_GenericReduction(ufunc, args, kwds, UFUNC_REDUCE);
}
static PyObject *
ufunc_accumulate(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
{
+ int errval;
+ PyObject *override = NULL;
+
+ errval = PyUFunc_CheckOverride(ufunc, "accumulate", args, kwds, &override,
+ ufunc->nin);
+ if (errval) {
+ return NULL;
+ }
+ else if (override) {
+ return override;
+ }
return PyUFunc_GenericReduction(ufunc, args, kwds, UFUNC_ACCUMULATE);
}
static PyObject *
ufunc_reduceat(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
{
+ int errval;
+ PyObject *override = NULL;
+
+ errval = PyUFunc_CheckOverride(ufunc, "reduceat", args, kwds, &override,
+ ufunc->nin);
+ if (errval) {
+ return NULL;
+ }
+ else if (override) {
+ return override;
+ }
return PyUFunc_GenericReduction(ufunc, args, kwds, UFUNC_REDUCEAT);
}