summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/add_newdocs.py31
-rw-r--r--numpy/core/src/multiarray/convert_datatype.c297
-rw-r--r--numpy/core/src/multiarray/ctors.c198
-rw-r--r--numpy/core/src/multiarray/methods.c27
-rw-r--r--numpy/core/tests/test_multiarray.py17
-rw-r--r--numpy/matrixlib/tests/test_defmatrix.py2
6 files changed, 227 insertions, 345 deletions
diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py
index 2faf3ba16..5e6a07bb0 100644
--- a/numpy/add_newdocs.py
+++ b/numpy/add_newdocs.py
@@ -2642,6 +2642,37 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('itemset',
"""))
+add_newdoc('numpy.core.multiarray', 'ndarray', ('setasflat',
+ """
+ a.setasflat(arr)
+
+ Equivalent to a.flat = arr.flat, but is generally more efficient.
+ This function does not check for overlap, so if ``arr`` and ``a``
+ are viewing the same data with different strides, the results will
+ be unpredictable.
+
+ Parameters
+ ----------
+ arr : array_like
+ The array to copy into a.
+
+ Examples
+ --------
+ >>> a = np.arange(2*4).reshape(2,4)[:,:-1]; a
+ array([[0, 1, 2],
+ [4, 5, 6]])
+ >>> b = np.arange(3*3, dtype='f4').reshape(3,3).T[::-1,:-1]; b
+ array([[ 2., 5.],
+ [ 1., 4.],
+ [ 0., 3.]], dtype=float32)
+ >>> a.setasflat(b)
+ >>> a
+ array([[2, 5, 1],
+ [4, 0, 3]])
+
+ """))
+
+
add_newdoc('numpy.core.multiarray', 'ndarray', ('max',
"""
a.max(axis=None, out=None)
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index 52053d063..f03821859 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -141,161 +141,6 @@ PyArray_GetCastFunc(PyArray_Descr *descr, int type_num)
}
/*
- * Reference counts:
- * copyswapn is used which increases and decreases reference counts for OBJECT arrays.
- * All that needs to happen is for any reference counts in the buffers to be
- * decreased when completely finished with the buffers.
- *
- * buffers[0] is the destination
- * buffers[1] is the source
- */
-static void
-_strided_buffered_cast(char *dptr, intp dstride, int delsize, int dswap,
- PyArray_CopySwapNFunc *dcopyfunc,
- char *sptr, intp sstride, int selsize, int sswap,
- PyArray_CopySwapNFunc *scopyfunc,
- intp N, char **buffers, int bufsize,
- PyArray_VectorUnaryFunc *castfunc,
- PyArrayObject *dest, PyArrayObject *src)
-{
- int i;
- if (N <= bufsize) {
- /*
- * 1. copy input to buffer and swap
- * 2. cast input to output
- * 3. swap output if necessary and copy from output buffer
- */
- scopyfunc(buffers[1], selsize, sptr, sstride, N, sswap, src);
- castfunc(buffers[1], buffers[0], N, src, dest);
- dcopyfunc(dptr, dstride, buffers[0], delsize, N, dswap, dest);
- return;
- }
-
- /* otherwise we need to divide up into bufsize pieces */
- i = 0;
- while (N > 0) {
- int newN = MIN(N, bufsize);
-
- _strided_buffered_cast(dptr+i*dstride, dstride, delsize,
- dswap, dcopyfunc,
- sptr+i*sstride, sstride, selsize,
- sswap, scopyfunc,
- newN, buffers, bufsize, castfunc, dest, src);
- i += newN;
- N -= bufsize;
- }
- return;
-}
-
-static int
-_broadcast_cast(PyArrayObject *out, PyArrayObject *in,
- PyArray_VectorUnaryFunc *castfunc, int iswap, int oswap)
-{
- int delsize, selsize, maxaxis, i, N;
- PyArrayMultiIterObject *multi;
- intp maxdim, ostrides, istrides;
- char *buffers[2];
- PyArray_CopySwapNFunc *ocopyfunc, *icopyfunc;
- char *obptr;
- NPY_BEGIN_THREADS_DEF;
-
- delsize = PyArray_ITEMSIZE(out);
- selsize = PyArray_ITEMSIZE(in);
- multi = (PyArrayMultiIterObject *)PyArray_MultiIterNew(2, out, in);
- if (multi == NULL) {
- return -1;
- }
-
- if (multi->size != PyArray_SIZE(out)) {
- PyErr_SetString(PyExc_ValueError,
- "array dimensions are not "\
- "compatible for copy");
- Py_DECREF(multi);
- return -1;
- }
-
- icopyfunc = in->descr->f->copyswapn;
- ocopyfunc = out->descr->f->copyswapn;
- maxaxis = PyArray_RemoveSmallest(multi);
- if (maxaxis < 0) {
- /* cast 1 0-d array to another */
- N = 1;
- maxdim = 1;
- ostrides = delsize;
- istrides = selsize;
- }
- else {
- maxdim = multi->dimensions[maxaxis];
- N = (int) (MIN(maxdim, PyArray_BUFSIZE));
- ostrides = multi->iters[0]->strides[maxaxis];
- istrides = multi->iters[1]->strides[maxaxis];
-
- }
- buffers[0] = _pya_malloc(N*delsize);
- if (buffers[0] == NULL) {
- PyErr_NoMemory();
- return -1;
- }
- buffers[1] = _pya_malloc(N*selsize);
- if (buffers[1] == NULL) {
- _pya_free(buffers[0]);
- PyErr_NoMemory();
- return -1;
- }
- if (PyDataType_FLAGCHK(out->descr, NPY_NEEDS_INIT)) {
- memset(buffers[0], 0, N*delsize);
- }
- if (PyDataType_FLAGCHK(in->descr, NPY_NEEDS_INIT)) {
- memset(buffers[1], 0, N*selsize);
- }
-
-#if NPY_ALLOW_THREADS
- if (PyArray_ISNUMBER(in) && PyArray_ISNUMBER(out)) {
- NPY_BEGIN_THREADS;
- }
-#endif
-
- while (multi->index < multi->size) {
- _strided_buffered_cast(multi->iters[0]->dataptr,
- ostrides,
- delsize, oswap, ocopyfunc,
- multi->iters[1]->dataptr,
- istrides,
- selsize, iswap, icopyfunc,
- maxdim, buffers, N,
- castfunc, out, in);
- PyArray_MultiIter_NEXT(multi);
- }
-#if NPY_ALLOW_THREADS
- if (PyArray_ISNUMBER(in) && PyArray_ISNUMBER(out)) {
- NPY_END_THREADS;
- }
-#endif
- Py_DECREF(multi);
- if (PyDataType_REFCHK(in->descr)) {
- obptr = buffers[1];
- for (i = 0; i < N; i++, obptr+=selsize) {
- PyArray_Item_XDECREF(obptr, in->descr);
- }
- }
- if (PyDataType_REFCHK(out->descr)) {
- obptr = buffers[0];
- for (i = 0; i < N; i++, obptr+=delsize) {
- PyArray_Item_XDECREF(obptr, out->descr);
- }
- }
- _pya_free(buffers[0]);
- _pya_free(buffers[1]);
- if (PyErr_Occurred()) {
- return -1;
- }
-
- return 0;
-}
-
-
-
-/*
* Must be broadcastable.
* This code is very similar to PyArray_CopyInto/PyArray_MoveInto
* except casting is done --- PyArray_BUFSIZE is used
@@ -312,110 +157,6 @@ PyArray_CastTo(PyArrayObject *out, PyArrayObject *mp)
return PyArray_CopyInto(out, mp);
}
-
-static int
-_bufferedcast(PyArrayObject *out, PyArrayObject *in,
- PyArray_VectorUnaryFunc *castfunc)
-{
- char *inbuffer, *bptr, *optr;
- char *outbuffer=NULL;
- PyArrayIterObject *it_in = NULL, *it_out = NULL;
- intp i, index;
- intp ncopies = PyArray_SIZE(out) / PyArray_SIZE(in);
- int elsize=in->descr->elsize;
- int nels = PyArray_BUFSIZE;
- int el;
- int inswap, outswap = 0;
- int obuf=!PyArray_ISCARRAY(out);
- int oelsize = out->descr->elsize;
- PyArray_CopySwapFunc *in_csn;
- PyArray_CopySwapFunc *out_csn;
- int retval = -1;
-
- in_csn = in->descr->f->copyswap;
- out_csn = out->descr->f->copyswap;
-
- /*
- * If the input or output is STRING, UNICODE, or VOID
- * then getitem and setitem are used for the cast
- * and byteswapping is handled by those methods
- */
-
- inswap = !(PyArray_ISFLEXIBLE(in) || PyArray_ISNOTSWAPPED(in));
-
- inbuffer = PyDataMem_NEW(PyArray_BUFSIZE*elsize);
- if (inbuffer == NULL) {
- return -1;
- }
- if (PyArray_ISOBJECT(in)) {
- memset(inbuffer, 0, PyArray_BUFSIZE*elsize);
- }
- it_in = (PyArrayIterObject *)PyArray_IterNew((PyObject *)in);
- if (it_in == NULL) {
- goto exit;
- }
- if (obuf) {
- outswap = !(PyArray_ISFLEXIBLE(out) ||
- PyArray_ISNOTSWAPPED(out));
- outbuffer = PyDataMem_NEW(PyArray_BUFSIZE*oelsize);
- if (outbuffer == NULL) {
- goto exit;
- }
- if (PyArray_ISOBJECT(out)) {
- memset(outbuffer, 0, PyArray_BUFSIZE*oelsize);
- }
- it_out = (PyArrayIterObject *)PyArray_IterNew((PyObject *)out);
- if (it_out == NULL) {
- goto exit;
- }
- nels = MIN(nels, PyArray_BUFSIZE);
- }
-
- optr = (obuf) ? outbuffer: out->data;
- bptr = inbuffer;
- el = 0;
- while (ncopies--) {
- index = it_in->size;
- PyArray_ITER_RESET(it_in);
- while (index--) {
- in_csn(bptr, it_in->dataptr, inswap, in);
- bptr += elsize;
- PyArray_ITER_NEXT(it_in);
- el += 1;
- if ((el == nels) || (index == 0)) {
- /* buffer filled, do cast */
- castfunc(inbuffer, optr, el, in, out);
- if (obuf) {
- /* Copy from outbuffer to array */
- for (i = 0; i < el; i++) {
- out_csn(it_out->dataptr,
- optr, outswap,
- out);
- optr += oelsize;
- PyArray_ITER_NEXT(it_out);
- }
- optr = outbuffer;
- }
- else {
- optr += out->descr->elsize * nels;
- }
- el = 0;
- bptr = inbuffer;
- }
- }
- }
- retval = 0;
-
- exit:
- Py_XDECREF(it_in);
- PyDataMem_FREE(inbuffer);
- PyDataMem_FREE(outbuffer);
- if (obuf) {
- Py_XDECREF(it_out);
- }
- return retval;
-}
-
/*NUMPY_API
* Cast to an already created array. Arrays don't have to be "broadcastable"
* Only requirement is they have the same number of elements.
@@ -423,42 +164,8 @@ _bufferedcast(PyArrayObject *out, PyArrayObject *in,
NPY_NO_EXPORT int
PyArray_CastAnyTo(PyArrayObject *out, PyArrayObject *mp)
{
- int simple;
- PyArray_VectorUnaryFunc *castfunc = NULL;
- npy_intp mpsize = PyArray_SIZE(mp);
-
- if (mpsize == 0) {
- return 0;
- }
- if (!PyArray_ISWRITEABLE(out)) {
- PyErr_SetString(PyExc_ValueError, "output array is not writeable");
- return -1;
- }
-
- if (!(mpsize == PyArray_SIZE(out))) {
- PyErr_SetString(PyExc_ValueError,
- "arrays must have the same number of"
- " elements for the cast.");
- return -1;
- }
-
- castfunc = PyArray_GetCastFunc(mp->descr, out->descr->type_num);
- if (castfunc == NULL) {
- return -1;
- }
- simple = ((PyArray_ISCARRAY_RO(mp) && PyArray_ISCARRAY(out)) ||
- (PyArray_ISFARRAY_RO(mp) && PyArray_ISFARRAY(out)));
- if (simple) {
- castfunc(mp->data, out->data, mpsize, mp, out);
- return 0;
- }
- if (PyArray_SAMESHAPE(out, mp)) {
- int iswap, oswap;
- iswap = PyArray_ISBYTESWAPPED(mp) && !PyArray_ISFLEXIBLE(mp);
- oswap = PyArray_ISBYTESWAPPED(out) && !PyArray_ISFLEXIBLE(out);
- return _broadcast_cast(out, mp, castfunc, iswap, oswap);
- }
- return _bufferedcast(out, mp, castfunc);
+ /* CopyAnyInto handles the casting now */
+ return PyArray_CopyAnyInto(out, mp);
}
/*NUMPY_API
diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c
index 8eb3dec37..7dcfd7666 100644
--- a/numpy/core/src/multiarray/ctors.c
+++ b/numpy/core/src/multiarray/ctors.c
@@ -2497,79 +2497,174 @@ PyArray_EnsureAnyArray(PyObject *op)
* Copy an Array into another array -- memory must not overlap
* Does not require src and dest to have "broadcastable" shapes
* (only the same number of elements).
+ *
+ * TODO: For NumPy 2.0, this could accept an order parameter which
+ * only allows NPY_CORDER and NPY_FORDER. Could also rename
+ * this to CopyAsFlat to make the name more intuitive.
+ *
+ * Returns 0 on success, -1 on error.
*/
NPY_NO_EXPORT int
-PyArray_CopyAnyInto(PyArrayObject *dest, PyArrayObject *src)
+PyArray_CopyAnyInto(PyArrayObject *dst, PyArrayObject *src)
{
- int elsize, simple;
- PyArrayIterObject *idest, *isrc;
- void (*myfunc)(char *, npy_intp, char *, npy_intp, npy_intp, int);
+ PyArray_StridedTransferFn *stransfer = NULL;
+ void *transferdata = NULL;
+ NpyIter *dst_iter, *src_iter;
NPY_BEGIN_THREADS_DEF;
- if (!PyArray_EquivArrTypes(dest, src)) {
- return PyArray_CastAnyTo(dest, src);
- }
- if (!PyArray_ISWRITEABLE(dest)) {
+ NpyIter_IterNext_Fn dst_iternext, src_iternext;
+ char **dst_dataptr, **src_dataptr;
+ npy_intp dst_stride, src_stride;
+ npy_intp *dst_countptr, *src_countptr;
+
+ char *dst_data, *src_data;
+ npy_intp dst_count, src_count, count;
+ npy_intp src_itemsize;
+ int needs_api;
+
+ if (!PyArray_ISWRITEABLE(dst)) {
PyErr_SetString(PyExc_RuntimeError,
"cannot write to array");
return -1;
}
- if (PyArray_SIZE(dest) != PyArray_SIZE(src)) {
+
+ /* If the shapes match, use the more efficient CopyInto */
+ if (PyArray_NDIM(dst) == PyArray_NDIM(src) &&
+ PyArray_CompareLists(PyArray_DIMS(dst), PyArray_DIMS(src),
+ PyArray_NDIM(dst))) {
+ return PyArray_CopyInto(dst, src);
+ }
+
+ if (PyArray_SIZE(dst) != PyArray_SIZE(src)) {
PyErr_SetString(PyExc_ValueError,
"arrays must have the same number of elements"
" for copy");
return -1;
}
- simple = ((PyArray_ISCARRAY_RO(src) && PyArray_ISCARRAY(dest)) ||
- (PyArray_ISFARRAY_RO(src) && PyArray_ISFARRAY(dest)));
- if (simple) {
- /* Refcount note: src and dest have the same size */
- PyArray_INCREF(src);
- PyArray_XDECREF(dest);
+ /*
+ * This copy is based on matching C-order traversals of src and dst.
+ * By using two iterators, we can find maximal sub-chunks that
+ * can be processed at once.
+ */
+ dst_iter = NpyIter_New(dst, NPY_ITER_WRITEONLY|
+ NPY_ITER_NO_INNER_ITERATION|
+ NPY_ITER_REFS_OK,
+ NPY_CORDER,
+ NPY_NO_CASTING,
+ NULL, 0, NULL, 0);
+ if (dst_iter == NULL) {
+ return -1;
+ }
+ src_iter = NpyIter_New(src, NPY_ITER_READONLY|
+ NPY_ITER_NO_INNER_ITERATION|
+ NPY_ITER_REFS_OK,
+ NPY_CORDER,
+ NPY_NO_CASTING,
+ NULL, 0, NULL, 0);
+ if (src_iter == NULL) {
+ NpyIter_Deallocate(dst_iter);
+ return -1;
+ }
+
+ /* Get all the values needed for the inner loop */
+ dst_iternext = NpyIter_GetIterNext(dst_iter, NULL);
+ dst_dataptr = NpyIter_GetDataPtrArray(dst_iter);
+ /* Since buffering is disabled, we can cache the stride */
+ dst_stride = *NpyIter_GetInnerStrideArray(dst_iter);
+ dst_countptr = NpyIter_GetInnerLoopSizePtr(dst_iter);
+
+ src_iternext = NpyIter_GetIterNext(src_iter, NULL);
+ src_dataptr = NpyIter_GetDataPtrArray(src_iter);
+ /* Since buffering is disabled, we can cache the stride */
+ src_stride = *NpyIter_GetInnerStrideArray(src_iter);
+ src_countptr = NpyIter_GetInnerLoopSizePtr(src_iter);
+
+ if (dst_iternext == NULL || src_iternext == NULL) {
+ NpyIter_Deallocate(dst_iter);
+ NpyIter_Deallocate(src_iter);
+ return -1;
+ }
+
+ src_itemsize = PyArray_DESCR(src)->elsize;
+
+ needs_api = NpyIter_IterationNeedsAPI(dst_iter) ||
+ NpyIter_IterationNeedsAPI(src_iter);
+
+ /*
+ * Because buffering is disabled in the iterator, the inner loop
+ * strides will be the same throughout the iteration loop. Thus,
+ * we can pass them to this function to take advantage of
+ * contiguous strides, etc.
+ */
+ if (PyArray_GetDTypeTransferFunction(
+ PyArray_ISALIGNED(src) && PyArray_ISALIGNED(dst),
+ src_stride, dst_stride,
+ PyArray_DESCR(src), PyArray_DESCR(dst),
+ 0,
+ &stransfer, &transferdata,
+ &needs_api) != NPY_SUCCEED) {
+ NpyIter_Deallocate(dst_iter);
+ NpyIter_Deallocate(src_iter);
+ return -1;
+ }
+
+
+ if (!needs_api) {
NPY_BEGIN_THREADS;
- memcpy(dest->data, src->data, PyArray_NBYTES(dest));
- NPY_END_THREADS;
- return 0;
}
- if (PyArray_SAMESHAPE(dest, src)) {
- int swap;
+ dst_count = *dst_countptr;
+ src_count = *src_countptr;
+ dst_data = *dst_dataptr;
+ src_data = *src_dataptr;
+ /*
+ * The tests did not trigger this code, so added a new function
+ * ndarray.setasflat to the Python exposure in order to test it.
+ */
+ for(;;) {
+ /* Transfer the biggest amount that fits both */
+ count = (src_count < dst_count) ? src_count : dst_count;
+ stransfer(dst_data, dst_stride,
+ src_data, src_stride,
+ count, src_itemsize, transferdata);
+
+ /* If we exhausted the dst block, refresh it */
+ if (dst_count == count) {
+ if (!dst_iternext(dst_iter)) {
+ break;
+ }
+ dst_count = *dst_countptr;
+ dst_data = *dst_dataptr;
+ }
+ else {
+ dst_count -= count;
+ dst_data += count*dst_stride;
+ }
- if (PyArray_SAFEALIGNEDCOPY(dest) && PyArray_SAFEALIGNEDCOPY(src)) {
- myfunc = _strided_byte_copy;
+ /* If we exhausted the src block, refresh it */
+ if (src_count == count) {
+ if (!src_iternext(src_iter)) {
+ break;
+ }
+ src_count = *src_countptr;
+ src_data = *src_dataptr;
}
else {
- myfunc = _unaligned_strided_byte_copy;
+ src_count -= count;
+ src_data += count*src_stride;
}
- swap = PyArray_ISNOTSWAPPED(dest) != PyArray_ISNOTSWAPPED(src);
- return _copy_from_same_shape(dest, src, myfunc, swap);
}
- /* Otherwise we have to do an iterator-based copy */
- idest = (PyArrayIterObject *)PyArray_IterNew((PyObject *)dest);
- if (idest == NULL) {
- return -1;
- }
- isrc = (PyArrayIterObject *)PyArray_IterNew((PyObject *)src);
- if (isrc == NULL) {
- Py_DECREF(idest);
- return -1;
- }
- elsize = dest->descr->elsize;
- /* Refcount note: src and dest have the same size */
- PyArray_INCREF(src);
- PyArray_XDECREF(dest);
- NPY_BEGIN_THREADS;
- while(idest->index < idest->size) {
- memcpy(idest->dataptr, isrc->dataptr, elsize);
- PyArray_ITER_NEXT(idest);
- PyArray_ITER_NEXT(isrc);
+ if (!needs_api) {
+ NPY_END_THREADS;
}
- NPY_END_THREADS;
- Py_DECREF(idest);
- Py_DECREF(isrc);
- return 0;
+
+ PyArray_FreeStridedTransferData(transferdata);
+ NpyIter_Deallocate(dst_iter);
+ NpyIter_Deallocate(src_iter);
+
+ return PyErr_Occurred() ? -1 : 0;
}
/*NUMPY_API
@@ -2648,6 +2743,7 @@ PyArray_CopyInto(PyArrayObject *dst, PyArrayObject *src)
npy_intp *stride;
npy_intp *countptr;
npy_intp src_itemsize;
+ int needs_api;
op[0] = dst;
op[1] = src;
@@ -2666,12 +2762,16 @@ PyArray_CopyInto(PyArrayObject *dst, PyArrayObject *src)
}
iternext = NpyIter_GetIterNext(iter, NULL);
+ if (iternext == NULL) {
+ NpyIter_Deallocate(iter);
+ return -1;
+ }
dataptr = NpyIter_GetDataPtrArray(iter);
stride = NpyIter_GetInnerStrideArray(iter);
countptr = NpyIter_GetInnerLoopSizePtr(iter);
src_itemsize = PyArray_DESCR(src)->elsize;
- int needs_api = NpyIter_IterationNeedsAPI(iter);
+ needs_api = NpyIter_IterationNeedsAPI(iter);
/*
* Because buffering is disabled in the iterator, the inner loop
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index 6ccac9f6e..d3b4dfe6e 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -744,6 +744,30 @@ array_setscalar(PyArrayObject *self, PyObject *args) {
return Py_None;
}
+/* Sets the array values from another array as if they were flat */
+static PyObject *
+array_setasflat(PyArrayObject *self, PyObject *args)
+{
+ PyObject *arr_in;
+ PyArrayObject *arr;
+
+ if (!PyArg_ParseTuple(args, "O", &arr_in)) {
+ return NULL;
+ }
+
+ arr = PyArray_FromAny(arr_in, NULL, 0, 0, 0, NULL);
+ if (arr == NULL) {
+ return NULL;
+ }
+
+ if (PyArray_CopyAnyInto(self, arr) != 0) {
+ Py_DECREF(arr);
+ return NULL;
+ }
+
+ Py_DECREF(arr);
+ Py_RETURN_NONE;
+}
static PyObject *
array_cast(PyArrayObject *self, PyObject *args)
@@ -2231,6 +2255,9 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
{"itemset",
(PyCFunction) array_setscalar,
METH_VARARGS, NULL},
+ {"setasflat",
+ (PyCFunction) array_setasflat,
+ METH_VARARGS, NULL},
{"max",
(PyCFunction)array_max,
METH_VARARGS | METH_KEYWORDS, NULL},
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index b3bf209a1..d30ac13e8 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -632,6 +632,23 @@ class TestMethods(TestCase):
assert_equal(a.ravel(), a.reshape(-1))
assert_equal(a.ravel(order='A'), a.reshape(-1, order='A'))
+ def test_setasflat(self):
+ # In this case, setasflat can treat a as a flat array,
+ # and must treat b in chunks of 3
+ a = np.arange(3*3*4).reshape(3,3,4)
+ b = np.arange(3*4*3, dtype='f4').reshape(3,4,3).T
+
+ assert_(not np.all(a.ravel() == b.ravel()))
+ a.setasflat(b)
+ assert_equal(a.ravel(), b.ravel())
+
+ # A case where the strides of neither a nor b can be collapsed
+ a = np.arange(3*2*4).reshape(3,2,4)[:,:,:-1]
+ b = np.arange(3*3*3, dtype='f4').reshape(3,3,3).T[:,:,:-1]
+
+ assert_(not np.all(a.ravel() == b.ravel()))
+ a.setasflat(b)
+ assert_equal(a.ravel(), b.ravel())
class TestSubscripting(TestCase):
def test_test_zero_rank(self):
diff --git a/numpy/matrixlib/tests/test_defmatrix.py b/numpy/matrixlib/tests/test_defmatrix.py
index 65d79df0b..ccb68f0e7 100644
--- a/numpy/matrixlib/tests/test_defmatrix.py
+++ b/numpy/matrixlib/tests/test_defmatrix.py
@@ -263,7 +263,7 @@ class TestMatrixReturn(TestCase):
'searchsorted', 'setflags', 'setfield', 'sort', 'take',
'tofile', 'tolist', 'tostring', 'all', 'any', 'sum',
'argmax', 'argmin', 'min', 'max', 'mean', 'var', 'ptp',
- 'prod', 'std', 'ctypes', 'itemset'
+ 'prod', 'std', 'ctypes', 'itemset', 'setasflat'
]
for attrib in dir(a):
if attrib.startswith('_') or attrib in excluded_methods: