diff options
author | Darren Dale <dsdale24@gmail.com> | 2009-08-23 16:30:28 +0000 |
---|---|---|
committer | Darren Dale <dsdale24@gmail.com> | 2009-08-23 16:30:28 +0000 |
commit | 9e053da77d773fb22ee83219ad5595af6c73c953 (patch) | |
tree | 955f7f9eae0d972c8d12019b5b873811de136505 | |
parent | 856a9363bf28da036c6102fc77ea7fcdba5e777a (diff) | |
download | numpy-9e053da77d773fb22ee83219ad5595af6c73c953.tar.gz |
add support for __array_prepare__
-rw-r--r-- | doc/source/reference/arrays.classes.rst | 31 | ||||
-rw-r--r-- | doc/source/reference/ufuncs.rst | 21 | ||||
-rw-r--r-- | numpy/add_newdocs.py | 14 | ||||
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 45 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 173 | ||||
-rw-r--r-- | numpy/core/tests/test_umath.py | 32 | ||||
-rw-r--r-- | numpy/doc/subclassing.py | 38 | ||||
-rw-r--r-- | numpy/lib/shape_base.py | 18 | ||||
-rw-r--r-- | numpy/linalg/linalg.py | 2 |
9 files changed, 333 insertions, 41 deletions
diff --git a/doc/source/reference/arrays.classes.rst b/doc/source/reference/arrays.classes.rst index 671c95bbf..f9abfbafa 100644 --- a/doc/source/reference/arrays.classes.rst +++ b/doc/source/reference/arrays.classes.rst @@ -47,14 +47,29 @@ customize: update meta-information from the "parent." Subclasses inherit a default implementation of this method that does nothing. -.. function:: __array_wrap__(array) - - This method should return an instance of the subclass from the - :class:`ndarray` object passed in. For example, this is called - after every :ref:`ufunc <ufuncs.output-type>` for the object with - the highest array priority. The ufunc-computed array object is - passed in and whatever is returned is passed to the - user. Subclasses inherit a default implementation of this method. +.. function:: __array_prepare__(array, context=None) + + At the beginning of every :ref:`ufunc <ufuncs.output-type>`, this + method is called on the input object with the highest array + priority, or the output object if one was specified. The output + array is passed in and whatever is returned is passed to the ufunc. + Subclasses inherit a default implementation of this method which + simply returns the output array unmodified. Subclasses may opt to + use this method to transform the output array into an instance of + the subclass and update metadata before returning the array to the + ufunc for computation. + +.. function:: __array_wrap__(array, context=None) + + At the end of every :ref:`ufunc <ufuncs.output-type>`, this method + is called on the input object with the highest array priority, or + the output object if one was specified. The ufunc-computed array + is passed in and whatever is returned is passed to the user. + Subclasses inherit a default implementation of this method, which + transforms the array into a new instance of the object's class. Subclasses + may opt to use this method to transform the output array into an + instance of the subclass and update metadata before returning the + array to the user. .. data:: __array_priority__ diff --git a/doc/source/reference/ufuncs.rst b/doc/source/reference/ufuncs.rst index 09b13dc89..8096e1497 100644 --- a/doc/source/reference/ufuncs.rst +++ b/doc/source/reference/ufuncs.rst @@ -102,19 +102,24 @@ Output type determination The output of the ufunc (and its methods) is not necessarily an :class:`ndarray`, if all input arguments are not :class:`ndarrays <ndarray>`. -All output arrays will be passed to the :obj:`__array_wrap__` -method of the input (besides :class:`ndarrays <ndarray>`, and scalars) -that defines it **and** has the highest :obj:`__array_priority__` of -any other input to the universal function. The default -:obj:`__array_priority__` of the ndarray is 0.0, and the default -:obj:`__array_priority__` of a subtype is 1.0. Matrices have -:obj:`__array_priority__` equal to 10.0. +All output arrays will be passed to the :obj:`__array_prepare__` and +:obj:`__array_wrap__` methods of the input (besides +:class:`ndarrays <ndarray>`, and scalars) that defines it **and** has +the highest :obj:`__array_priority__` of any other input to the +universal function. The default :obj:`__array_priority__` of the +ndarray is 0.0, and the default :obj:`__array_priority__` of a subtype +is 1.0. Matrices have :obj:`__array_priority__` equal to 10.0. The ufuncs can also all take output arguments. The output will be cast if necessary to the provided output array. If a class with an :obj:`__array__` method is used for the output, results will be written to the object returned by :obj:`__array__`. Then, if the class -also has an :obj:`__array_wrap__` method, the returned +also has an :obj:`__array_prepare__` method, it is called so metadata +may be determined based on the context of the ufunc (the context +consisting of the ufunc itself, the arguments passed to the ufunc, and +the ufunc domain.) The array object returned by +:obj:`__array_prepare__` is passed to the ufunc for computation. +Finally, if the class also has an :obj:`__array_wrap__` method, the returned :class:`ndarray` result will be passed to that method just before passing control back to the caller. diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py index e541e7cb0..d588cbba0 100644 --- a/numpy/add_newdocs.py +++ b/numpy/add_newdocs.py @@ -1645,8 +1645,14 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('__array__', """)) +add_newdoc('numpy.core.multiarray', 'ndarray', ('__array_prepare__', + """a.__array_prepare__(obj) -> Object of same type as ndarray object obj. + + """)) + + add_newdoc('numpy.core.multiarray', 'ndarray', ('__array_wrap__', - """a.__array_wrap__(obj) -> Object of same type as a from ndarray obj. + """a.__array_wrap__(obj) -> Object of same type as ndarray object a. """)) @@ -3943,7 +3949,7 @@ add_newdoc('numpy.lib.index_tricks', 'ogrid', """) - + ############################################################################## # # Documentation for `generic` attributes and methods @@ -3955,7 +3961,7 @@ add_newdoc('numpy.core.numerictypes', 'generic', """) # Attributes - + add_newdoc('numpy.core.numerictypes', 'generic', ('T', """ """)) @@ -4246,7 +4252,7 @@ add_newdoc('numpy.core.numerictypes', 'generic', ('var', add_newdoc('numpy.core.numerictypes', 'generic', ('view', """ """)) - + ############################################################################## # diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index 8e9bf24e4..de99ca137 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -767,6 +767,49 @@ array_wraparray(PyArrayObject *self, PyObject *args) return NULL; } arr = PyTuple_GET_ITEM(args, 0); + if (arr == NULL) { + return NULL; + } + if (!PyArray_Check(arr)) { + PyErr_SetString(PyExc_TypeError, + "can only be called with ndarray object"); + return NULL; + } + + if (self->ob_type != arr->ob_type){ + Py_INCREF(PyArray_DESCR(arr)); + ret = PyArray_NewFromDescr(self->ob_type, + PyArray_DESCR(arr), + PyArray_NDIM(arr), + PyArray_DIMS(arr), + PyArray_STRIDES(arr), PyArray_DATA(arr), + PyArray_FLAGS(arr), (PyObject *)self); + if (ret == NULL) { + return NULL; + } + Py_INCREF(arr); + PyArray_BASE(ret) = arr; + return ret; + } else { + /*The type was set in __array_prepare__*/ + Py_INCREF(arr); + return arr; + } +} + + +static PyObject * +array_preparearray(PyArrayObject *self, PyObject *args) +{ + PyObject *arr; + PyObject *ret; + + if (PyTuple_Size(args) < 1) { + PyErr_SetString(PyExc_TypeError, + "only accepts 1 argument"); + return NULL; + } + arr = PyTuple_GET_ITEM(args, 0); if (!PyArray_Check(arr)) { PyErr_SetString(PyExc_TypeError, "can only be called with ndarray object"); @@ -2031,6 +2074,8 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = { /* for subtypes */ {"__array__", (PyCFunction)array_getarray, METH_VARARGS, NULL}, + {"__array_prepare__", (PyCFunction)array_preparearray, + METH_VARARGS, NULL}, {"__array_wrap__", (PyCFunction)array_wraparray, METH_VARARGS, NULL}, diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 537f7d768..74ba7acbf 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -234,6 +234,122 @@ static char *_types_msg = "function not supported for these types, " \ "and can't coerce safely to supported types"; /* + * This function analyzes the input arguments + * and determines an appropriate __array_prepare__ function to call + * for the outputs. + * + * If an output argument is provided, then it is wrapped + * with its own __array_prepare__ not with the one determined by + * the input arguments. + * + * if the provided output argument is already an ndarray, + * the wrapping function is None (which means no wrapping will + * be done --- not even PyArray_Return). + * + * A NULL is placed in output_wrap for outputs that + * should just have PyArray_Return called. + */ +static void +_find_array_prepare(PyObject *args, PyObject **output_wrap, int nin, int nout) +{ + Py_ssize_t nargs; + int i; + int np = 0; + PyObject *with_wrap[NPY_MAXARGS], *wraps[NPY_MAXARGS]; + PyObject *obj, *wrap = NULL; + + nargs = PyTuple_GET_SIZE(args); + for (i = 0; i < nin; i++) { + obj = PyTuple_GET_ITEM(args, i); + if (PyArray_CheckExact(obj) || PyArray_IsAnyScalar(obj)) { + continue; + } + wrap = PyObject_GetAttrString(obj, "__array_prepare__"); + if (wrap) { + if (PyCallable_Check(wrap)) { + with_wrap[np] = obj; + wraps[np] = wrap; + ++np; + } + else { + Py_DECREF(wrap); + wrap = NULL; + } + } + else { + PyErr_Clear(); + } + } + if (np > 0) { + /* If we have some wraps defined, find the one of highest priority */ + wrap = wraps[0]; + if (np > 1) { + double maxpriority = PyArray_GetPriority(with_wrap[0], + PyArray_SUBTYPE_PRIORITY); + for (i = 1; i < np; ++i) { + double priority = PyArray_GetPriority(with_wrap[i], + PyArray_SUBTYPE_PRIORITY); + if (priority > maxpriority) { + maxpriority = priority; + Py_DECREF(wrap); + wrap = wraps[i]; + } + else { + Py_DECREF(wraps[i]); + } + } + } + } + + /* + * Here wrap is the wrapping function determined from the + * input arrays (could be NULL). + * + * For all the output arrays decide what to do. + * + * 1) Use the wrap function determined from the input arrays + * This is the default if the output array is not + * passed in. + * + * 2) Use the __array_prepare__ method of the output object. + * This is special cased for + * exact ndarray so that no PyArray_Return is + * done in that case. + */ + for (i = 0; i < nout; i++) { + int j = nin + i; + int incref = 1; + output_wrap[i] = wrap; + if (j < nargs) { + obj = PyTuple_GET_ITEM(args, j); + if (obj == Py_None) { + continue; + } + if (PyArray_CheckExact(obj)) { + output_wrap[i] = Py_None; + } + else { + PyObject *owrap = PyObject_GetAttrString(obj, + "__array_prepare__"); + incref = 0; + if (!(owrap) || !(PyCallable_Check(owrap))) { + Py_XDECREF(owrap); + owrap = wrap; + incref = 1; + PyErr_Clear(); + } + output_wrap[i] = owrap; + } + } + if (incref) { + Py_XINCREF(output_wrap[i]); + } + } + Py_XDECREF(wrap); + return; +} + +/* * Called for non-NULL user-defined functions. * The object should be a CObject pointing to a linked-list of functions * storing the function, data, and signature of all user-defined functions. @@ -1054,6 +1170,7 @@ construct_arrays(PyUFuncLoopObject *loop, PyObject *args, PyArrayObject **mps, npy_intp temp_dims[NPY_MAXDIMS]; npy_intp *out_dims; int out_nd; + PyObject *wraparr[NPY_MAXARGS]; /* Check number of arguments */ nargs = PyTuple_Size(args); @@ -1332,13 +1449,57 @@ construct_arrays(PyUFuncLoopObject *loop, PyObject *args, PyArrayObject **mps, return -1; } - /* Recover mps[i]. */ - if (self->core_enabled) { - PyArrayObject *ao = mps[i]; - mps[i] = (PyArrayObject *)mps[i]->base; - Py_DECREF(ao); - } + /* Recover mps[i]. */ + if (self->core_enabled) { + PyArrayObject *ao = mps[i]; + mps[i] = (PyArrayObject *)mps[i]->base; + Py_DECREF(ao); + } + + } + + /* + * Use __array_prepare__ on all outputs + * if present on one of the input arguments. + * If present for multiple inputs: + * use __array_prepare__ of input object with largest + * __array_priority__ (default = 0.0) + * + * Exception: we should not wrap outputs for items already + * passed in as output-arguments. These items should either + * be left unwrapped or wrapped by calling their own __array_prepare__ + * routine. + * + * For each output argument, wrap will be either + * NULL --- call PyArray_Return() -- default if no output arguments given + * None --- array-object passed in don't call PyArray_Return + * method --- the __array_prepare__ method to call. + */ + _find_array_prepare(args, wraparr, loop->ufunc->nin, loop->ufunc->nout); + /* wrap outputs */ + for (i = 0; i < loop->ufunc->nout; i++) { + int j = loop->ufunc->nin+i; + PyObject *wrap; + wrap = wraparr[i]; + if (wrap != NULL) { + if (wrap == Py_None) { + Py_DECREF(wrap); + continue; + } + PyObject *res = PyObject_CallFunction(wrap, "O(OOi)", + mps[j], loop->ufunc, args, i); + Py_DECREF(wrap); + if ((res == NULL) || (res == Py_None)) { + if (!PyErr_Occurred()){ + PyErr_SetString(PyExc_TypeError, + "__array_prepare__ must return an ndarray or subclass thereof"); + } + return -1; + } + Py_DECREF(mps[j]); + mps[j] = (PyArrayObject *)res; + } } /* diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index c82e5af7c..abea0a222 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -484,6 +484,38 @@ class TestSpecialMethods(TestCase): a = A() self.failUnlessRaises(RuntimeError, ncu.maximum, a, a) + def test_default_prepare(self): + class with_wrap(object): + __array_priority__ = 10 + def __array__(self): + return np.zeros(1) + def __array_wrap__(self, arr, context): + return arr + a = with_wrap() + x = ncu.minimum(a, a) + assert_equal(x, np.zeros(1)) + assert_equal(type(x), np.ndarray) + + def test_prepare(self): + class with_prepare(np.ndarray): + __array_priority__ = 10 + def __array_prepare__(self, arr, context): + # make sure we can return a new + return np.array(arr).view(type=with_prepare) + a = np.array(1).view(type=with_prepare) + x = np.add(a, a) + assert_equal(x, np.array(2)) + assert_equal(type(x), with_prepare) + + def test_failing_prepare(self): + class A(object): + def __array__(self): + return np.zeros(1) + def __array_prepare__(self, arr, context=None): + raise RuntimeError + a = A() + self.failUnlessRaises(RuntimeError, ncu.maximum, a, a) + def test_array_with_context(self): class A(object): def __array__(self, dtype=None, context=None): diff --git a/numpy/doc/subclassing.py b/numpy/doc/subclassing.py index a6666217b..5f658d922 100644 --- a/numpy/doc/subclassing.py +++ b/numpy/doc/subclassing.py @@ -115,7 +115,7 @@ A brief Python primer on ``__new__`` and ``__init__`` ``__new__`` is a standard Python method, and, if present, is called before ``__init__`` when we create a class instance. See the `python __new__ documentation -<http://docs.python.org/reference/datamodel.html#object.__new__>`_ for more detail. +<http://docs.python.org/reference/datamodel.html#object.__new__>`_ for more detail. For example, consider the following Python code: @@ -229,7 +229,7 @@ where our object creation housekeeping usually goes. ``ndarray.__new__(MySubClass,...)``, or do view casting of an existing array (see below) * For view casting and new-from-template, the equivalent of - ``ndarray.__new__(MySubClass,...`` is called, at the C level. + ``ndarray.__new__(MySubClass,...`` is called, at the C level. The arguments that ``__array_finalize__`` recieves differ for the three methods of instance creation above. @@ -355,7 +355,7 @@ Using the object looks like this: >>> type(obj) <class 'InfoArray'> >>> obj.info is None - True + True >>> obj = InfoArray(shape=(3,), info='information') >>> obj.info 'information' @@ -364,8 +364,8 @@ Using the object looks like this: <class 'InfoArray'> >>> v.info 'information' - >>> arr = np.arange(10) - >>> cast_arr = arr.view(InfoArray) # view casting + >>> arr = np.arange(10) + >>> cast_arr = arr.view(InfoArray) # view casting >>> type(cast_arr) <class 'InfoArray'> >>> cast_arr.info is None @@ -381,7 +381,7 @@ Slightly more realistic example - attribute added to existing array ------------------------------------------------------------------- Here is a class that takes a standard ndarray that already exists, casts -as our type, and adds an extra attribute. +as our type, and adds an extra attribute. .. testcode:: @@ -403,7 +403,7 @@ as our type, and adds an extra attribute. if obj is None: return self.info = getattr(obj, 'info', None) - + So: >>> arr = np.arange(5) @@ -423,9 +423,9 @@ So: ``__array_wrap__`` for ufuncs ----------------------------- -``__array_wrap__`` gets called by numpy ufuncs and other numpy +``__array_wrap__`` gets called at the end of numpy ufuncs and other numpy functions, to allow a subclass to set the type of the return value -from - for example - ufuncs. Let's show how this works with an example. +and update attributes and metadata. Let's show how this works with an example. First we make the same subclass as above, but with a different name and some print statements: @@ -454,8 +454,8 @@ some print statements: # then just call the parent return np.ndarray.__array_wrap__(self, out_arr, context) -We run a ufunc on an instance of our new array: - +We run a ufunc on an instance of our new array: + >>> obj = MySubClass(np.arange(5), info='spam') In __array_finalize__: self is MySubClass([0, 1, 2, 3, 4]) @@ -473,8 +473,9 @@ MySubClass([1, 3, 5, 7, 9]) >>> ret.info 'spam' -Note that the ufunc (``np.add``) has called -``MySubClass.__array_wrap__`` with arguments ``self`` as ``obj``, and +Note that the ufunc (``np.add``) has called the ``__array_wrap__`` method of the +input with the highest ``__array_priority__`` value, in this case +``MySubClass.__array_wrap__``, with arguments ``self`` as ``obj``, and ``out_arr`` as the (ndarray) result of the addition. In turn, the default ``__array_wrap__`` (``ndarray.__array_wrap__``) has cast the result to class ``MySubClass``, and called ``__array_finalize__`` - @@ -505,6 +506,17 @@ domain of the ufunc). ``__array_wrap__`` should return an instance of its containing class. See the masked array subclass for an implementation. +In addition to ``__array_wrap__``, which is called on the way out of the +ufunc, there is also an ``__array_prepare__`` method which is called on +the way into the ufunc, after the output arrays are created but before any +computation has been performed. The default implementation does nothing +but pass through the array. ``__array_prepare__`` should not attempt to +access the array data or resize the array, it is intended for setting the +output array type, updating attributes and metadata, and performing any +checks based on the input that may be desired before computation begins. +Like ``__array_wrap__``, ``__array_prepare__`` must return an ndarray or +subclass thereof or raise an error. + Extra gotchas - custom ``__del__`` methods and ndarray.base ----------------------------------------------------------- diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 69ef0be4f..a5bf4d0ea 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -892,6 +892,19 @@ def dsplit(ary,indices_or_sections): raise ValueError, 'vsplit only works on arrays of 3 or more dimensions' return split(ary,indices_or_sections,2) +def get_array_prepare(*args): + """Find the wrapper for the array with the highest priority. + + In case of ties, leftmost wins. If no wrapper is found, return None + """ + wrappers = [(getattr(x, '__array_priority__', 0), -i, + x.__array_prepare__) for i, x in enumerate(args) + if hasattr(x, '__array_prepare__')] + wrappers.sort() + if wrappers: + return wrappers[-1][-1] + return None + def get_array_wrap(*args): """Find the wrapper for the array with the highest priority. @@ -975,7 +988,6 @@ def kron(a,b): True """ - wrapper = get_array_wrap(a, b) b = asanyarray(b) a = array(a,copy=False,subok=True,ndmin=b.ndim) ndb, nda = b.ndim, a.ndim @@ -998,6 +1010,10 @@ def kron(a,b): axis = nd-1 for _ in xrange(nd): result = concatenate(result, axis=axis) + wrapper = get_array_prepare(a, b) + if wrapper is not None: + result = wrapper(result) + wrapper = get_array_wrap(a, b) if wrapper is not None: result = wrapper(result) return result diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index df888b754..5878b909f 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -30,7 +30,7 @@ class LinAlgError(Exception): def _makearray(a): new = asarray(a) - wrap = getattr(a, "__array_wrap__", new.__array_wrap__) + wrap = getattr(a, "__array_prepare__", new.__array_wrap__) return new, wrap def isComplexType(t): |