diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2019-11-18 16:13:06 -0800 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2019-11-21 17:06:02 -0800 |
commit | c5da77d47309b6bf11f13f3b2e760f9b21423427 (patch) | |
tree | 621cf8e0f25bb362788683dbbba2b89fb3366611 /numpy | |
parent | 9aeb7513019954718a3225fbab24bdbb98ca4fd5 (diff) | |
download | numpy-c5da77d47309b6bf11f13f3b2e760f9b21423427.tar.gz |
API: Use `ResultType` in `PyArray_ConvertToCommonType`
This slightly modifies the behaviour of `np.choose` (practically a
bug fix) and the public function itself. The function is not used within
e.g. SciPy, so the small performance hit of this implementation
is probably insignificant.
The change should help clean up dtypes a bit, since the whole "scalar
cast" logic is brittle and should be deprecated/removed, and this is
probably one of the few places actually using it.
The choose change means that:
```
np.choose([0], (1000, np.array([1], dtype=np.uint8)))
```
will actually return a value of 1000 (the dtype not being uint8).
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 105 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 9 |
2 files changed, 37 insertions, 77 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index 025c66013..11609092e 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -2115,15 +2115,19 @@ PyArray_ObjectType(PyObject *op, int minimum_type) /* Raises error when len(op) == 0 */ -/*NUMPY_API*/ +/*NUMPY_API + * + * This function is only used in one place within NumPy and should + * generally be avoided. It is provided mainly for backward compatibility. + * + * The user of the function has to free the returned array. + */ NPY_NO_EXPORT PyArrayObject ** PyArray_ConvertToCommonType(PyObject *op, int *retn) { - int i, n, allscalars = 0; + int i, n; + PyArray_Descr *common_descr = NULL; PyArrayObject **mps = NULL; - PyArray_Descr *intype = NULL, *stype = NULL; - PyArray_Descr *newtype = NULL; - NPY_SCALARKIND scalarkind = NPY_NOSCALAR, intypekind = NPY_NOSCALAR; *retn = n = PySequence_Length(op); if (n == 0) { @@ -2159,94 +2163,41 @@ PyArray_ConvertToCommonType(PyObject *op, int *retn) } for (i = 0; i < n; i++) { - PyObject *otmp = PySequence_GetItem(op, i); - if (otmp == NULL) { + /* Convert everything to an array, this could be optimized away */ + PyObject *tmp = PySequence_GetItem(op, i); + if (tmp == NULL) { goto fail; } - if (!PyArray_CheckAnyScalar(otmp)) { - newtype = PyArray_DescrFromObject(otmp, intype); - Py_DECREF(otmp); - Py_XDECREF(intype); - if (newtype == NULL) { - goto fail; - } - intype = newtype; - intypekind = PyArray_ScalarKind(intype->type_num, NULL); - } - else { - newtype = PyArray_DescrFromObject(otmp, stype); - Py_DECREF(otmp); - Py_XDECREF(stype); - if (newtype == NULL) { - goto fail; - } - stype = newtype; - scalarkind = PyArray_ScalarKind(newtype->type_num, NULL); - mps[i] = (PyArrayObject *)Py_None; - Py_INCREF(Py_None); - } - } - if (intype == NULL) { - /* all scalars */ - allscalars = 1; - intype = stype; - Py_INCREF(intype); - for (i = 0; i < n; i++) { - Py_XDECREF(mps[i]); - mps[i] = NULL; - } - } - else if ((stype != NULL) && (intypekind != scalarkind)) { - /* - * we need to upconvert to type that - * handles both intype and stype - * also don't forcecast the scalars. - */ - if (!PyArray_CanCoerceScalar(stype->type_num, - intype->type_num, - scalarkind)) { - newtype = PyArray_PromoteTypes(intype, stype); - Py_XDECREF(intype); - intype = newtype; - if (newtype == NULL) { - goto fail; - } - } - for (i = 0; i < n; i++) { - Py_XDECREF(mps[i]); - mps[i] = NULL; + + mps[i] = (PyArrayObject *)PyArray_FROM_O(tmp); + Py_DECREF(tmp); + if (mps[i] == NULL) { + goto fail; } } + common_descr = PyArray_ResultType(n, mps, 0, NULL); + if (common_descr == NULL) { + goto fail; + } - /* Make sure all arrays are actual array objects. */ + /* Make sure all arrays are contiguous and have the correct dtype. */ for (i = 0; i < n; i++) { int flags = NPY_ARRAY_CARRAY; - PyObject *otmp = PySequence_GetItem(op, i); + PyArrayObject *tmp = mps[i]; - if (otmp == NULL) { - goto fail; - } - if (!allscalars && ((PyObject *)(mps[i]) == Py_None)) { - /* forcecast scalars */ - flags |= NPY_ARRAY_FORCECAST; - Py_DECREF(Py_None); - } - Py_INCREF(intype); - mps[i] = (PyArrayObject*)PyArray_FromAny(otmp, intype, 0, 0, - flags, NULL); - Py_DECREF(otmp); + Py_INCREF(common_descr); + mps[i] = (PyArrayObject *)PyArray_FromArray(tmp, common_descr, flags); + Py_DECREF(tmp); if (mps[i] == NULL) { goto fail; } } - Py_DECREF(intype); - Py_XDECREF(stype); + Py_DECREF(common_descr); return mps; fail: - Py_XDECREF(intype); - Py_XDECREF(stype); + Py_XDECREF(common_descr); *retn = 0; for (i = 0; i < n; i++) { Py_XDECREF(mps[i]); diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 218106a63..6c496a028 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -6488,6 +6488,15 @@ class TestChoose(object): A = np.choose(self.ind, (self.x, self.y2)) assert_equal(A, [[2, 2, 3], [2, 2, 3]]) + @pytest.mark.parametrize("ops", + [(1000, np.array([1], dtype=np.uint8)), + (-1, np.array([1], dtype=np.uint8)), + (1., np.float32(3)), + (1., np.array([3], dtype=np.float32))],) + def test_output_dtype(self, ops): + expected_dt = np.result_type(*ops) + assert(np.choose([0], ops).dtype == expected_dt) + class TestRepeat(object): def setup(self): |