summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/add_newdocs.py14
-rw-r--r--numpy/core/src/multiarray/methods.c45
-rw-r--r--numpy/core/src/umath/ufunc_object.c173
-rw-r--r--numpy/core/tests/test_umath.py32
-rw-r--r--numpy/doc/subclassing.py38
-rw-r--r--numpy/lib/shape_base.py18
-rw-r--r--numpy/linalg/linalg.py2
7 files changed, 297 insertions, 25 deletions
diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py
index e541e7cb0..d588cbba0 100644
--- a/numpy/add_newdocs.py
+++ b/numpy/add_newdocs.py
@@ -1645,8 +1645,14 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('__array__',
"""))
+add_newdoc('numpy.core.multiarray', 'ndarray', ('__array_prepare__',
+ """a.__array_prepare__(obj) -> Object of same type as ndarray object obj.
+
+ """))
+
+
add_newdoc('numpy.core.multiarray', 'ndarray', ('__array_wrap__',
- """a.__array_wrap__(obj) -> Object of same type as a from ndarray obj.
+ """a.__array_wrap__(obj) -> Object of same type as ndarray object a.
"""))
@@ -3943,7 +3949,7 @@ add_newdoc('numpy.lib.index_tricks', 'ogrid',
""")
-
+
##############################################################################
#
# Documentation for `generic` attributes and methods
@@ -3955,7 +3961,7 @@ add_newdoc('numpy.core.numerictypes', 'generic',
""")
# Attributes
-
+
add_newdoc('numpy.core.numerictypes', 'generic', ('T',
"""
"""))
@@ -4246,7 +4252,7 @@ add_newdoc('numpy.core.numerictypes', 'generic', ('var',
add_newdoc('numpy.core.numerictypes', 'generic', ('view',
"""
"""))
-
+
##############################################################################
#
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index 8e9bf24e4..de99ca137 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -767,6 +767,49 @@ array_wraparray(PyArrayObject *self, PyObject *args)
return NULL;
}
arr = PyTuple_GET_ITEM(args, 0);
+ if (arr == NULL) {
+ return NULL;
+ }
+ if (!PyArray_Check(arr)) {
+ PyErr_SetString(PyExc_TypeError,
+ "can only be called with ndarray object");
+ return NULL;
+ }
+
+ if (self->ob_type != arr->ob_type){
+ Py_INCREF(PyArray_DESCR(arr));
+ ret = PyArray_NewFromDescr(self->ob_type,
+ PyArray_DESCR(arr),
+ PyArray_NDIM(arr),
+ PyArray_DIMS(arr),
+ PyArray_STRIDES(arr), PyArray_DATA(arr),
+ PyArray_FLAGS(arr), (PyObject *)self);
+ if (ret == NULL) {
+ return NULL;
+ }
+ Py_INCREF(arr);
+ PyArray_BASE(ret) = arr;
+ return ret;
+ } else {
+ /*The type was set in __array_prepare__*/
+ Py_INCREF(arr);
+ return arr;
+ }
+}
+
+
+static PyObject *
+array_preparearray(PyArrayObject *self, PyObject *args)
+{
+ PyObject *arr;
+ PyObject *ret;
+
+ if (PyTuple_Size(args) < 1) {
+ PyErr_SetString(PyExc_TypeError,
+ "only accepts 1 argument");
+ return NULL;
+ }
+ arr = PyTuple_GET_ITEM(args, 0);
if (!PyArray_Check(arr)) {
PyErr_SetString(PyExc_TypeError,
"can only be called with ndarray object");
@@ -2031,6 +2074,8 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
/* for subtypes */
{"__array__", (PyCFunction)array_getarray,
METH_VARARGS, NULL},
+ {"__array_prepare__", (PyCFunction)array_preparearray,
+ METH_VARARGS, NULL},
{"__array_wrap__", (PyCFunction)array_wraparray,
METH_VARARGS, NULL},
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 537f7d768..74ba7acbf 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -234,6 +234,122 @@ static char *_types_msg = "function not supported for these types, " \
"and can't coerce safely to supported types";
/*
+ * This function analyzes the input arguments
+ * and determines an appropriate __array_prepare__ function to call
+ * for the outputs.
+ *
+ * If an output argument is provided, then it is wrapped
+ * with its own __array_prepare__ not with the one determined by
+ * the input arguments.
+ *
+ * if the provided output argument is already an ndarray,
+ * the wrapping function is None (which means no wrapping will
+ * be done --- not even PyArray_Return).
+ *
+ * A NULL is placed in output_wrap for outputs that
+ * should just have PyArray_Return called.
+ */
+static void
+_find_array_prepare(PyObject *args, PyObject **output_wrap, int nin, int nout)
+{
+ Py_ssize_t nargs;
+ int i;
+ int np = 0;
+ PyObject *with_wrap[NPY_MAXARGS], *wraps[NPY_MAXARGS];
+ PyObject *obj, *wrap = NULL;
+
+ nargs = PyTuple_GET_SIZE(args);
+ for (i = 0; i < nin; i++) {
+ obj = PyTuple_GET_ITEM(args, i);
+ if (PyArray_CheckExact(obj) || PyArray_IsAnyScalar(obj)) {
+ continue;
+ }
+ wrap = PyObject_GetAttrString(obj, "__array_prepare__");
+ if (wrap) {
+ if (PyCallable_Check(wrap)) {
+ with_wrap[np] = obj;
+ wraps[np] = wrap;
+ ++np;
+ }
+ else {
+ Py_DECREF(wrap);
+ wrap = NULL;
+ }
+ }
+ else {
+ PyErr_Clear();
+ }
+ }
+ if (np > 0) {
+ /* If we have some wraps defined, find the one of highest priority */
+ wrap = wraps[0];
+ if (np > 1) {
+ double maxpriority = PyArray_GetPriority(with_wrap[0],
+ PyArray_SUBTYPE_PRIORITY);
+ for (i = 1; i < np; ++i) {
+ double priority = PyArray_GetPriority(with_wrap[i],
+ PyArray_SUBTYPE_PRIORITY);
+ if (priority > maxpriority) {
+ maxpriority = priority;
+ Py_DECREF(wrap);
+ wrap = wraps[i];
+ }
+ else {
+ Py_DECREF(wraps[i]);
+ }
+ }
+ }
+ }
+
+ /*
+ * Here wrap is the wrapping function determined from the
+ * input arrays (could be NULL).
+ *
+ * For all the output arrays decide what to do.
+ *
+ * 1) Use the wrap function determined from the input arrays
+ * This is the default if the output array is not
+ * passed in.
+ *
+ * 2) Use the __array_prepare__ method of the output object.
+ * This is special cased for
+ * exact ndarray so that no PyArray_Return is
+ * done in that case.
+ */
+ for (i = 0; i < nout; i++) {
+ int j = nin + i;
+ int incref = 1;
+ output_wrap[i] = wrap;
+ if (j < nargs) {
+ obj = PyTuple_GET_ITEM(args, j);
+ if (obj == Py_None) {
+ continue;
+ }
+ if (PyArray_CheckExact(obj)) {
+ output_wrap[i] = Py_None;
+ }
+ else {
+ PyObject *owrap = PyObject_GetAttrString(obj,
+ "__array_prepare__");
+ incref = 0;
+ if (!(owrap) || !(PyCallable_Check(owrap))) {
+ Py_XDECREF(owrap);
+ owrap = wrap;
+ incref = 1;
+ PyErr_Clear();
+ }
+ output_wrap[i] = owrap;
+ }
+ }
+ if (incref) {
+ Py_XINCREF(output_wrap[i]);
+ }
+ }
+ Py_XDECREF(wrap);
+ return;
+}
+
+/*
* Called for non-NULL user-defined functions.
* The object should be a CObject pointing to a linked-list of functions
* storing the function, data, and signature of all user-defined functions.
@@ -1054,6 +1170,7 @@ construct_arrays(PyUFuncLoopObject *loop, PyObject *args, PyArrayObject **mps,
npy_intp temp_dims[NPY_MAXDIMS];
npy_intp *out_dims;
int out_nd;
+ PyObject *wraparr[NPY_MAXARGS];
/* Check number of arguments */
nargs = PyTuple_Size(args);
@@ -1332,13 +1449,57 @@ construct_arrays(PyUFuncLoopObject *loop, PyObject *args, PyArrayObject **mps,
return -1;
}
- /* Recover mps[i]. */
- if (self->core_enabled) {
- PyArrayObject *ao = mps[i];
- mps[i] = (PyArrayObject *)mps[i]->base;
- Py_DECREF(ao);
- }
+ /* Recover mps[i]. */
+ if (self->core_enabled) {
+ PyArrayObject *ao = mps[i];
+ mps[i] = (PyArrayObject *)mps[i]->base;
+ Py_DECREF(ao);
+ }
+
+ }
+
+ /*
+ * Use __array_prepare__ on all outputs
+ * if present on one of the input arguments.
+ * If present for multiple inputs:
+ * use __array_prepare__ of input object with largest
+ * __array_priority__ (default = 0.0)
+ *
+ * Exception: we should not wrap outputs for items already
+ * passed in as output-arguments. These items should either
+ * be left unwrapped or wrapped by calling their own __array_prepare__
+ * routine.
+ *
+ * For each output argument, wrap will be either
+ * NULL --- call PyArray_Return() -- default if no output arguments given
+ * None --- array-object passed in don't call PyArray_Return
+ * method --- the __array_prepare__ method to call.
+ */
+ _find_array_prepare(args, wraparr, loop->ufunc->nin, loop->ufunc->nout);
+ /* wrap outputs */
+ for (i = 0; i < loop->ufunc->nout; i++) {
+ int j = loop->ufunc->nin+i;
+ PyObject *wrap;
+ wrap = wraparr[i];
+ if (wrap != NULL) {
+ if (wrap == Py_None) {
+ Py_DECREF(wrap);
+ continue;
+ }
+ PyObject *res = PyObject_CallFunction(wrap, "O(OOi)",
+ mps[j], loop->ufunc, args, i);
+ Py_DECREF(wrap);
+ if ((res == NULL) || (res == Py_None)) {
+ if (!PyErr_Occurred()){
+ PyErr_SetString(PyExc_TypeError,
+ "__array_prepare__ must return an ndarray or subclass thereof");
+ }
+ return -1;
+ }
+ Py_DECREF(mps[j]);
+ mps[j] = (PyArrayObject *)res;
+ }
}
/*
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index c82e5af7c..abea0a222 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -484,6 +484,38 @@ class TestSpecialMethods(TestCase):
a = A()
self.failUnlessRaises(RuntimeError, ncu.maximum, a, a)
+ def test_default_prepare(self):
+ class with_wrap(object):
+ __array_priority__ = 10
+ def __array__(self):
+ return np.zeros(1)
+ def __array_wrap__(self, arr, context):
+ return arr
+ a = with_wrap()
+ x = ncu.minimum(a, a)
+ assert_equal(x, np.zeros(1))
+ assert_equal(type(x), np.ndarray)
+
+ def test_prepare(self):
+ class with_prepare(np.ndarray):
+ __array_priority__ = 10
+ def __array_prepare__(self, arr, context):
+ # make sure we can return a new
+ return np.array(arr).view(type=with_prepare)
+ a = np.array(1).view(type=with_prepare)
+ x = np.add(a, a)
+ assert_equal(x, np.array(2))
+ assert_equal(type(x), with_prepare)
+
+ def test_failing_prepare(self):
+ class A(object):
+ def __array__(self):
+ return np.zeros(1)
+ def __array_prepare__(self, arr, context=None):
+ raise RuntimeError
+ a = A()
+ self.failUnlessRaises(RuntimeError, ncu.maximum, a, a)
+
def test_array_with_context(self):
class A(object):
def __array__(self, dtype=None, context=None):
diff --git a/numpy/doc/subclassing.py b/numpy/doc/subclassing.py
index a6666217b..5f658d922 100644
--- a/numpy/doc/subclassing.py
+++ b/numpy/doc/subclassing.py
@@ -115,7 +115,7 @@ A brief Python primer on ``__new__`` and ``__init__``
``__new__`` is a standard Python method, and, if present, is called
before ``__init__`` when we create a class instance. See the `python
__new__ documentation
-<http://docs.python.org/reference/datamodel.html#object.__new__>`_ for more detail.
+<http://docs.python.org/reference/datamodel.html#object.__new__>`_ for more detail.
For example, consider the following Python code:
@@ -229,7 +229,7 @@ where our object creation housekeeping usually goes.
``ndarray.__new__(MySubClass,...)``, or do view casting of an existing
array (see below)
* For view casting and new-from-template, the equivalent of
- ``ndarray.__new__(MySubClass,...`` is called, at the C level.
+ ``ndarray.__new__(MySubClass,...`` is called, at the C level.
The arguments that ``__array_finalize__`` recieves differ for the three
methods of instance creation above.
@@ -355,7 +355,7 @@ Using the object looks like this:
>>> type(obj)
<class 'InfoArray'>
>>> obj.info is None
- True
+ True
>>> obj = InfoArray(shape=(3,), info='information')
>>> obj.info
'information'
@@ -364,8 +364,8 @@ Using the object looks like this:
<class 'InfoArray'>
>>> v.info
'information'
- >>> arr = np.arange(10)
- >>> cast_arr = arr.view(InfoArray) # view casting
+ >>> arr = np.arange(10)
+ >>> cast_arr = arr.view(InfoArray) # view casting
>>> type(cast_arr)
<class 'InfoArray'>
>>> cast_arr.info is None
@@ -381,7 +381,7 @@ Slightly more realistic example - attribute added to existing array
-------------------------------------------------------------------
Here is a class that takes a standard ndarray that already exists, casts
-as our type, and adds an extra attribute.
+as our type, and adds an extra attribute.
.. testcode::
@@ -403,7 +403,7 @@ as our type, and adds an extra attribute.
if obj is None: return
self.info = getattr(obj, 'info', None)
-
+
So:
>>> arr = np.arange(5)
@@ -423,9 +423,9 @@ So:
``__array_wrap__`` for ufuncs
-----------------------------
-``__array_wrap__`` gets called by numpy ufuncs and other numpy
+``__array_wrap__`` gets called at the end of numpy ufuncs and other numpy
functions, to allow a subclass to set the type of the return value
-from - for example - ufuncs. Let's show how this works with an example.
+and update attributes and metadata. Let's show how this works with an example.
First we make the same subclass as above, but with a different name and
some print statements:
@@ -454,8 +454,8 @@ some print statements:
# then just call the parent
return np.ndarray.__array_wrap__(self, out_arr, context)
-We run a ufunc on an instance of our new array:
-
+We run a ufunc on an instance of our new array:
+
>>> obj = MySubClass(np.arange(5), info='spam')
In __array_finalize__:
self is MySubClass([0, 1, 2, 3, 4])
@@ -473,8 +473,9 @@ MySubClass([1, 3, 5, 7, 9])
>>> ret.info
'spam'
-Note that the ufunc (``np.add``) has called
-``MySubClass.__array_wrap__`` with arguments ``self`` as ``obj``, and
+Note that the ufunc (``np.add``) has called the ``__array_wrap__`` method of the
+input with the highest ``__array_priority__`` value, in this case
+``MySubClass.__array_wrap__``, with arguments ``self`` as ``obj``, and
``out_arr`` as the (ndarray) result of the addition. In turn, the
default ``__array_wrap__`` (``ndarray.__array_wrap__``) has cast the
result to class ``MySubClass``, and called ``__array_finalize__`` -
@@ -505,6 +506,17 @@ domain of the ufunc). ``__array_wrap__`` should return an instance of
its containing class. See the masked array subclass for an
implementation.
+In addition to ``__array_wrap__``, which is called on the way out of the
+ufunc, there is also an ``__array_prepare__`` method which is called on
+the way into the ufunc, after the output arrays are created but before any
+computation has been performed. The default implementation does nothing
+but pass through the array. ``__array_prepare__`` should not attempt to
+access the array data or resize the array, it is intended for setting the
+output array type, updating attributes and metadata, and performing any
+checks based on the input that may be desired before computation begins.
+Like ``__array_wrap__``, ``__array_prepare__`` must return an ndarray or
+subclass thereof or raise an error.
+
Extra gotchas - custom ``__del__`` methods and ndarray.base
-----------------------------------------------------------
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 69ef0be4f..a5bf4d0ea 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -892,6 +892,19 @@ def dsplit(ary,indices_or_sections):
raise ValueError, 'vsplit only works on arrays of 3 or more dimensions'
return split(ary,indices_or_sections,2)
+def get_array_prepare(*args):
+ """Find the wrapper for the array with the highest priority.
+
+ In case of ties, leftmost wins. If no wrapper is found, return None
+ """
+ wrappers = [(getattr(x, '__array_priority__', 0), -i,
+ x.__array_prepare__) for i, x in enumerate(args)
+ if hasattr(x, '__array_prepare__')]
+ wrappers.sort()
+ if wrappers:
+ return wrappers[-1][-1]
+ return None
+
def get_array_wrap(*args):
"""Find the wrapper for the array with the highest priority.
@@ -975,7 +988,6 @@ def kron(a,b):
True
"""
- wrapper = get_array_wrap(a, b)
b = asanyarray(b)
a = array(a,copy=False,subok=True,ndmin=b.ndim)
ndb, nda = b.ndim, a.ndim
@@ -998,6 +1010,10 @@ def kron(a,b):
axis = nd-1
for _ in xrange(nd):
result = concatenate(result, axis=axis)
+ wrapper = get_array_prepare(a, b)
+ if wrapper is not None:
+ result = wrapper(result)
+ wrapper = get_array_wrap(a, b)
if wrapper is not None:
result = wrapper(result)
return result
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index df888b754..5878b909f 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -30,7 +30,7 @@ class LinAlgError(Exception):
def _makearray(a):
new = asarray(a)
- wrap = getattr(a, "__array_wrap__", new.__array_wrap__)
+ wrap = getattr(a, "__array_prepare__", new.__array_wrap__)
return new, wrap
def isComplexType(t):