summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Wiebe <mwwiebe@gmail.com>2010-12-17 17:27:23 -0800
committerMark Wiebe <mwwiebe@gmail.com>2011-01-09 01:54:58 -0800
commit6db2223b7c8e6ff0ba338c96a0ac382430930472 (patch)
treeee2dd167fa8275c9e2e22f2e158977a78dce12c2
parentb07ad2bb98cdda94a25b7c9941efd37c9efa82e4 (diff)
downloadnumpy-6db2223b7c8e6ff0ba338c96a0ac382430930472.tar.gz
BUG: iter: Fixed finding of common dtype with outputs factored in as well
-rw-r--r--numpy/core/src/multiarray/new_iterator.c.src41
-rw-r--r--numpy/core/src/multiarray/new_iterator_pywrap.c6
-rw-r--r--numpy/core/tests/test_new_iterator.py18
3 files changed, 48 insertions, 17 deletions
diff --git a/numpy/core/src/multiarray/new_iterator.c.src b/numpy/core/src/multiarray/new_iterator.c.src
index 3dea7d574..c4b80b7d7 100644
--- a/numpy/core/src/multiarray/new_iterator.c.src
+++ b/numpy/core/src/multiarray/new_iterator.c.src
@@ -191,8 +191,9 @@ static void
npyiter_shrink_ndim(NpyIter *iter, npy_intp new_ndim);
static PyArray_Descr *
-npyiter_get_output_type(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim,
- char *op_itflags, PyArray_Descr **op_dtype);
+npyiter_get_common_type(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim,
+ char *op_itflags, PyArray_Descr **op_dtype,
+ int only_inputs);
static int
npyiter_promote_types(int type1, int type2);
@@ -519,9 +520,11 @@ NpyIter_MultiNew(npy_intp niter, PyArrayObject **op_in, npy_uint32 flags,
*/
if (any_missing_dtypes || (flags&NPY_ITER_COMMON_DATA_TYPE)) {
PyArray_Descr *dtype;
+ int only_inputs = !(flags&NPY_ITER_COMMON_DATA_TYPE);
- dtype = npyiter_get_output_type(niter, op, op_ndim,
- op_itflags, op_dtype);
+ dtype = npyiter_get_common_type(niter, op, op_ndim,
+ op_itflags, op_dtype,
+ only_inputs);
if (dtype == NULL) {
NpyIter_Deallocate(iter);
return NULL;
@@ -1783,20 +1786,20 @@ npyiter_check_casting(npy_intp niter, PyArrayObject **op,
if ((op_itflags[iiter]&NPY_OP_ITFLAG_READ) &&
!npyiter_can_cast(op_flags[iiter],
PyArray_DESCR(op[iiter]), op_dtype[iiter])) {
- PyErr_SetString(PyExc_TypeError,
- "Iterator input dtype could not be cast "
+ PyErr_Format(PyExc_TypeError,
+ "Iterator operand %d dtype could not be cast "
"to the requested dtype, according to "
- "the casting flags given");
+ "the casting flags given", (int)iiter);
return 0;
}
/* Check write (temp -> op) casting */
if ((op_itflags[iiter]&NPY_OP_ITFLAG_WRITE) &&
!npyiter_can_cast(op_flags[iiter],
op_dtype[iiter], PyArray_DESCR(op[iiter]))) {
- PyErr_SetString(PyExc_TypeError,
+ PyErr_Format(PyExc_TypeError,
"Iterator requested dtype could not be cast "
- "to the input dtype, according to "
- "the casting flags given");
+ "to the operand %d dtype, according to "
+ "the casting flags given", (int)iiter);
return 0;
}
@@ -2734,9 +2737,15 @@ npyiter_new_temp_array(NpyIter *iter, PyTypeObject *subtype,
return ret;
}
+/*
+ * Calculates a dtype that all the types can be promoted to, using the
+ * ufunc rules. If only_inputs is 1, it leaves any operands that
+ * are not read from out of the calculation.
+ */
static PyArray_Descr *
-npyiter_get_output_type(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim,
- char *op_itflags, PyArray_Descr **op_dtype)
+npyiter_get_common_type(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim,
+ char *op_itflags, PyArray_Descr **op_dtype,
+ int only_inputs)
{
PyArray_Descr *scalar_dtype = NULL, *array_dtype = NULL;
int scalar_kind = NPY_NOSCALAR;
@@ -2744,8 +2753,8 @@ npyiter_get_output_type(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim,
/* First promote all the scalar dtypes */
for (iiter = 0; iiter < niter; ++iiter) {
- if (op_ndim[iiter] == 0 &&
- (op_itflags[iiter]&NPY_OP_ITFLAG_READ)) {
+ if (op_ndim[iiter] == 0 && op_dtype[iiter] != NULL &&
+ (!only_inputs || (op_itflags[iiter]&NPY_OP_ITFLAG_READ))) {
if (scalar_dtype == NULL) {
scalar_dtype = op_dtype[iiter];
Py_INCREF(scalar_dtype);
@@ -2797,8 +2806,8 @@ npyiter_get_output_type(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim,
/* Then promote all the array dtypes */
for (iiter = 0; iiter < niter; ++iiter) {
- if (op_ndim[iiter] > 0 &&
- (op_itflags[iiter]&NPY_OP_ITFLAG_READ)) {
+ if (op_ndim[iiter] > 0 && op_dtype[iiter] != NULL &&
+ (!only_inputs || (op_itflags[iiter]&NPY_OP_ITFLAG_READ))) {
if (array_dtype == NULL) {
array_dtype = op_dtype[iiter];
Py_INCREF(array_dtype);
diff --git a/numpy/core/src/multiarray/new_iterator_pywrap.c b/numpy/core/src/multiarray/new_iterator_pywrap.c
index 23f9070ce..ea0ebfd0f 100644
--- a/numpy/core/src/multiarray/new_iterator_pywrap.c
+++ b/numpy/core/src/multiarray/new_iterator_pywrap.c
@@ -211,8 +211,12 @@ npyiter_init(NewNpyArrayIterObject *self, PyObject *args, PyObject *kwds)
/* op_flags */
if (op_flags_in == NULL) {
for (iiter = 0; iiter < niter; ++iiter) {
+ /*
+ * By default, make writeable arrays readwrite, and anything
+ * else readonly.
+ */
if (op[iiter] != NULL && PyArray_Check(op[iiter]) &&
- PyArray_CHKFLAGS(op[iiter], NPY_WRITEABLE)) {
+ PyArray_CHKFLAGS(op[iiter], NPY_WRITEABLE)) {
op_flags[iiter] = NPY_ITER_READWRITE;
}
else {
diff --git a/numpy/core/tests/test_new_iterator.py b/numpy/core/tests/test_new_iterator.py
index 1ad7d95e7..f82127112 100644
--- a/numpy/core/tests/test_new_iterator.py
+++ b/numpy/core/tests/test_new_iterator.py
@@ -832,6 +832,24 @@ def test_iter_common_data_type():
assert_equal(i.dtypes[3], np.dtype('c16'));
assert_equal(i.value, (3,-12,2j,9))
+ # When allocating outputs, other outputs aren't factored in
+ i = newiter([array([3],dtype='i4'),None,array([2j],dtype='c16')], [],
+ [['readonly','copy','safe_casts'],
+ ['writeonly','allocate'],
+ ['writeonly']])
+ assert_equal(i.dtypes[0], np.dtype('i4'));
+ assert_equal(i.dtypes[1], np.dtype('i4'));
+ assert_equal(i.dtypes[2], np.dtype('c16'));
+ # But, if common data types are requested, they are
+ i = newiter([array([3],dtype='i4'),None,array([2j],dtype='c16')],
+ ['common_data_type'],
+ [['readonly','copy','safe_casts'],
+ ['writeonly','allocate'],
+ ['writeonly']])
+ assert_equal(i.dtypes[0], np.dtype('c16'));
+ assert_equal(i.dtypes[1], np.dtype('c16'));
+ assert_equal(i.dtypes[2], np.dtype('c16'));
+
def test_iter_op_axes():
# Check that custom axes work