diff options
author | Mark Wiebe <mwwiebe@gmail.com> | 2010-12-17 17:27:23 -0800 |
---|---|---|
committer | Mark Wiebe <mwwiebe@gmail.com> | 2011-01-09 01:54:58 -0800 |
commit | 6db2223b7c8e6ff0ba338c96a0ac382430930472 (patch) | |
tree | ee2dd167fa8275c9e2e22f2e158977a78dce12c2 | |
parent | b07ad2bb98cdda94a25b7c9941efd37c9efa82e4 (diff) | |
download | numpy-6db2223b7c8e6ff0ba338c96a0ac382430930472.tar.gz |
BUG: iter: Fixed finding of common dtype with outputs factored in as well
-rw-r--r-- | numpy/core/src/multiarray/new_iterator.c.src | 41 | ||||
-rw-r--r-- | numpy/core/src/multiarray/new_iterator_pywrap.c | 6 | ||||
-rw-r--r-- | numpy/core/tests/test_new_iterator.py | 18 |
3 files changed, 48 insertions, 17 deletions
diff --git a/numpy/core/src/multiarray/new_iterator.c.src b/numpy/core/src/multiarray/new_iterator.c.src index 3dea7d574..c4b80b7d7 100644 --- a/numpy/core/src/multiarray/new_iterator.c.src +++ b/numpy/core/src/multiarray/new_iterator.c.src @@ -191,8 +191,9 @@ static void npyiter_shrink_ndim(NpyIter *iter, npy_intp new_ndim); static PyArray_Descr * -npyiter_get_output_type(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim, - char *op_itflags, PyArray_Descr **op_dtype); +npyiter_get_common_type(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim, + char *op_itflags, PyArray_Descr **op_dtype, + int only_inputs); static int npyiter_promote_types(int type1, int type2); @@ -519,9 +520,11 @@ NpyIter_MultiNew(npy_intp niter, PyArrayObject **op_in, npy_uint32 flags, */ if (any_missing_dtypes || (flags&NPY_ITER_COMMON_DATA_TYPE)) { PyArray_Descr *dtype; + int only_inputs = !(flags&NPY_ITER_COMMON_DATA_TYPE); - dtype = npyiter_get_output_type(niter, op, op_ndim, - op_itflags, op_dtype); + dtype = npyiter_get_common_type(niter, op, op_ndim, + op_itflags, op_dtype, + only_inputs); if (dtype == NULL) { NpyIter_Deallocate(iter); return NULL; @@ -1783,20 +1786,20 @@ npyiter_check_casting(npy_intp niter, PyArrayObject **op, if ((op_itflags[iiter]&NPY_OP_ITFLAG_READ) && !npyiter_can_cast(op_flags[iiter], PyArray_DESCR(op[iiter]), op_dtype[iiter])) { - PyErr_SetString(PyExc_TypeError, - "Iterator input dtype could not be cast " + PyErr_Format(PyExc_TypeError, + "Iterator operand %d dtype could not be cast " "to the requested dtype, according to " - "the casting flags given"); + "the casting flags given", (int)iiter); return 0; } /* Check write (temp -> op) casting */ if ((op_itflags[iiter]&NPY_OP_ITFLAG_WRITE) && !npyiter_can_cast(op_flags[iiter], op_dtype[iiter], PyArray_DESCR(op[iiter]))) { - PyErr_SetString(PyExc_TypeError, + PyErr_Format(PyExc_TypeError, "Iterator requested dtype could not be cast " - "to the input dtype, according to " - "the casting flags given"); + "to the operand %d dtype, according to " + "the casting flags given", (int)iiter); return 0; } @@ -2734,9 +2737,15 @@ npyiter_new_temp_array(NpyIter *iter, PyTypeObject *subtype, return ret; } +/* + * Calculates a dtype that all the types can be promoted to, using the + * ufunc rules. If only_inputs is 1, it leaves any operands that + * are not read from out of the calculation. + */ static PyArray_Descr * -npyiter_get_output_type(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim, - char *op_itflags, PyArray_Descr **op_dtype) +npyiter_get_common_type(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim, + char *op_itflags, PyArray_Descr **op_dtype, + int only_inputs) { PyArray_Descr *scalar_dtype = NULL, *array_dtype = NULL; int scalar_kind = NPY_NOSCALAR; @@ -2744,8 +2753,8 @@ npyiter_get_output_type(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim, /* First promote all the scalar dtypes */ for (iiter = 0; iiter < niter; ++iiter) { - if (op_ndim[iiter] == 0 && - (op_itflags[iiter]&NPY_OP_ITFLAG_READ)) { + if (op_ndim[iiter] == 0 && op_dtype[iiter] != NULL && + (!only_inputs || (op_itflags[iiter]&NPY_OP_ITFLAG_READ))) { if (scalar_dtype == NULL) { scalar_dtype = op_dtype[iiter]; Py_INCREF(scalar_dtype); @@ -2797,8 +2806,8 @@ npyiter_get_output_type(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim, /* Then promote all the array dtypes */ for (iiter = 0; iiter < niter; ++iiter) { - if (op_ndim[iiter] > 0 && - (op_itflags[iiter]&NPY_OP_ITFLAG_READ)) { + if (op_ndim[iiter] > 0 && op_dtype[iiter] != NULL && + (!only_inputs || (op_itflags[iiter]&NPY_OP_ITFLAG_READ))) { if (array_dtype == NULL) { array_dtype = op_dtype[iiter]; Py_INCREF(array_dtype); diff --git a/numpy/core/src/multiarray/new_iterator_pywrap.c b/numpy/core/src/multiarray/new_iterator_pywrap.c index 23f9070ce..ea0ebfd0f 100644 --- a/numpy/core/src/multiarray/new_iterator_pywrap.c +++ b/numpy/core/src/multiarray/new_iterator_pywrap.c @@ -211,8 +211,12 @@ npyiter_init(NewNpyArrayIterObject *self, PyObject *args, PyObject *kwds) /* op_flags */ if (op_flags_in == NULL) { for (iiter = 0; iiter < niter; ++iiter) { + /* + * By default, make writeable arrays readwrite, and anything + * else readonly. + */ if (op[iiter] != NULL && PyArray_Check(op[iiter]) && - PyArray_CHKFLAGS(op[iiter], NPY_WRITEABLE)) { + PyArray_CHKFLAGS(op[iiter], NPY_WRITEABLE)) { op_flags[iiter] = NPY_ITER_READWRITE; } else { diff --git a/numpy/core/tests/test_new_iterator.py b/numpy/core/tests/test_new_iterator.py index 1ad7d95e7..f82127112 100644 --- a/numpy/core/tests/test_new_iterator.py +++ b/numpy/core/tests/test_new_iterator.py @@ -832,6 +832,24 @@ def test_iter_common_data_type(): assert_equal(i.dtypes[3], np.dtype('c16')); assert_equal(i.value, (3,-12,2j,9)) + # When allocating outputs, other outputs aren't factored in + i = newiter([array([3],dtype='i4'),None,array([2j],dtype='c16')], [], + [['readonly','copy','safe_casts'], + ['writeonly','allocate'], + ['writeonly']]) + assert_equal(i.dtypes[0], np.dtype('i4')); + assert_equal(i.dtypes[1], np.dtype('i4')); + assert_equal(i.dtypes[2], np.dtype('c16')); + # But, if common data types are requested, they are + i = newiter([array([3],dtype='i4'),None,array([2j],dtype='c16')], + ['common_data_type'], + [['readonly','copy','safe_casts'], + ['writeonly','allocate'], + ['writeonly']]) + assert_equal(i.dtypes[0], np.dtype('c16')); + assert_equal(i.dtypes[1], np.dtype('c16')); + assert_equal(i.dtypes[2], np.dtype('c16')); + def test_iter_op_axes(): # Check that custom axes work |