diff options
-rw-r--r-- | numpy/core/src/multiarray/ctors.c | 51 | ||||
-rw-r--r-- | numpy/core/tests/test_api.py | 2 |
2 files changed, 45 insertions, 8 deletions
diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c index 9388c41a7..6100fe7ee 100644 --- a/numpy/core/src/multiarray/ctors.c +++ b/numpy/core/src/multiarray/ctors.c @@ -2799,7 +2799,9 @@ PyArray_CopyInto(PyArrayObject *dst, PyArrayObject *src) else { PyArrayObject *op[2]; npy_uint32 op_flags[2]; + PyArray_Descr *op_dtypes_values[2], **op_dtypes = NULL; NpyIter *iter; + npy_intp src_size; NpyIter_IterNextFunc *iternext; char **dataptr; @@ -2811,7 +2813,7 @@ PyArray_CopyInto(PyArrayObject *dst, PyArrayObject *src) op[0] = dst; op[1] = src; /* - * TODO: In NumPy 2.0, renable NPY_ITER_NO_BROADCAST. This + * TODO: In NumPy 2.0, reenable NPY_ITER_NO_BROADCAST. This * was removed during NumPy 1.6 testing for compatibility * with NumPy 1.5, as per Travis's -10 veto power. */ @@ -2819,14 +2821,30 @@ PyArray_CopyInto(PyArrayObject *dst, PyArrayObject *src) op_flags[0] = NPY_ITER_WRITEONLY; op_flags[1] = NPY_ITER_READONLY; + /* + * If 'src' is being broadcast to 'dst', and it is smaller + * than the default NumPy buffer size, allow the iterator to + * make a copy of 'src' with the 'dst' dtype if necessary. + * + * This is a performance operation, to allow fewer casts followed + * by more plain copies. + */ + src_size = PyArray_SIZE(src); + if (src_size <= NPY_BUFSIZE && src_size < PyArray_SIZE(dst)) { + op_flags[1] |= NPY_ITER_COPY; + op_dtypes = op_dtypes_values; + op_dtypes_values[0] = NULL; + op_dtypes_values[1] = PyArray_DESCR(dst); + } + iter = NpyIter_MultiNew(2, op, NPY_ITER_EXTERNAL_LOOP| NPY_ITER_REFS_OK| NPY_ITER_ZEROSIZE_OK, NPY_KEEPORDER, - NPY_NO_CASTING, + NPY_UNSAFE_CASTING, op_flags, - NULL); + op_dtypes); if (iter == NULL) { return -1; } @@ -2852,7 +2870,7 @@ PyArray_CopyInto(PyArrayObject *dst, PyArrayObject *src) if (PyArray_GetDTypeTransferFunction( PyArray_ISALIGNED(src) && PyArray_ISALIGNED(dst), stride[1], stride[0], - PyArray_DESCR(src), PyArray_DESCR(dst), + NpyIter_GetDescrArray(iter)[1], PyArray_DESCR(dst), 0, &stransfer, &transferdata, &needs_api) != NPY_SUCCEED) { @@ -2983,7 +3001,9 @@ PyArray_MaskedCopyInto(PyArrayObject *dst, PyArrayObject *src, else { PyArrayObject *op[3]; npy_uint32 op_flags[3]; + PyArray_Descr *op_dtypes_values[3], **op_dtypes = NULL; NpyIter *iter; + npy_intp src_size; NpyIter_IterNextFunc *iternext; char **dataptr; @@ -3005,14 +3025,31 @@ PyArray_MaskedCopyInto(PyArrayObject *dst, PyArrayObject *src, op_flags[1] = NPY_ITER_READONLY; op_flags[2] = NPY_ITER_READONLY; + /* + * If 'src' is being broadcast to 'dst', and it is smaller + * than the default NumPy buffer size, allow the iterator to + * make a copy of 'src' with the 'dst' dtype if necessary. + * + * This is a performance operation, to allow fewer casts followed + * by more plain copies. + */ + src_size = PyArray_SIZE(src); + if (src_size <= NPY_BUFSIZE && src_size < PyArray_SIZE(dst)) { + op_flags[1] |= NPY_ITER_COPY; + op_dtypes = op_dtypes_values; + op_dtypes_values[0] = NULL; + op_dtypes_values[1] = PyArray_DESCR(dst); + op_dtypes_values[2] = NULL; + } + iter = NpyIter_MultiNew(3, op, NPY_ITER_EXTERNAL_LOOP| NPY_ITER_REFS_OK| NPY_ITER_ZEROSIZE_OK, NPY_KEEPORDER, - NPY_NO_CASTING, + NPY_UNSAFE_CASTING, op_flags, - NULL); + op_dtypes); if (iter == NULL) { return -1; } @@ -3038,7 +3075,7 @@ PyArray_MaskedCopyInto(PyArrayObject *dst, PyArrayObject *src, if (PyArray_GetMaskedDTypeTransferFunction( PyArray_ISALIGNED(src) && PyArray_ISALIGNED(dst), stride[1], stride[0], stride[2], - PyArray_DESCR(src), + NpyIter_GetDescrArray(iter)[1], PyArray_DESCR(dst), PyArray_DESCR(mask), 0, diff --git a/numpy/core/tests/test_api.py b/numpy/core/tests/test_api.py index 7ebcb932b..d2d8241f2 100644 --- a/numpy/core/tests/test_api.py +++ b/numpy/core/tests/test_api.py @@ -110,7 +110,7 @@ def test_copyto(): assert_raises(TypeError, np.copyto, a, 3.5, where=[True,False,True]) # Lists of integer 0's and 1's is ok too - np.copyto(a, 4, where=[[0,1,1], [1,0,0]]) + np.copyto(a, 4.0, casting='unsafe', where=[[0,1,1], [1,0,0]]) assert_equal(a, [[3,4,4], [4,1,3]]) # Overlapping copy with mask should work |