summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMark Wiebe <mwwiebe@gmail.com>2011-08-07 14:08:34 -0700
committerCharles Harris <charlesr.harris@gmail.com>2011-08-27 07:26:53 -0600
commitc1ebc154bf4f9d41f28543483a902adfe347f9a3 (patch)
tree3cd7d0f72bb129500165f394924d9e5be2de161c /numpy
parent20cdef13b86606f2ad1090cc8cf41f6921072d21 (diff)
downloadnumpy-c1ebc154bf4f9d41f28543483a902adfe347f9a3.tar.gz
ENH: core: Rewrite PyArray_FillWithScalar to use array_assign_scalar
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/arrayobject.c1
-rw-r--r--numpy/core/src/multiarray/convert.c156
-rw-r--r--numpy/core/src/multiarray/na_mask.c3
-rw-r--r--numpy/core/src/multiarray/na_mask.h3
4 files changed, 116 insertions, 47 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c
index 8f02b50e8..e8a8f1d27 100644
--- a/numpy/core/src/multiarray/arrayobject.c
+++ b/numpy/core/src/multiarray/arrayobject.c
@@ -197,6 +197,7 @@ PyArray_CopyObject(PyArrayObject *dest, PyObject *src_object)
}
value = scalar_value(src_object, dtype);
if (value == NULL) {
+ Py_DECREF(dtype);
Py_DECREF(src_object);
return -1;
}
diff --git a/numpy/core/src/multiarray/convert.c b/numpy/core/src/multiarray/convert.c
index d5fd2a680..1dbb387b2 100644
--- a/numpy/core/src/multiarray/convert.c
+++ b/numpy/core/src/multiarray/convert.c
@@ -16,6 +16,7 @@
#include "mapping.h"
#include "lowlevel_strided_loops.h"
#include "array_assign.h"
+#include "scalartypes.h"
#include "convert.h"
@@ -309,64 +310,128 @@ PyArray_ToString(PyArrayObject *self, NPY_ORDER order)
NPY_NO_EXPORT int
PyArray_FillWithScalar(PyArrayObject *arr, PyObject *obj)
{
- PyArrayObject *newarr;
- int itemsize, swap;
- void *fromptr;
- PyArray_Descr *descr;
- intp size;
- PyArray_CopySwapFunc *copyswap;
-
- itemsize = PyArray_DESCR(arr)->elsize;
- if (PyArray_ISOBJECT(arr)) {
- fromptr = &obj;
- swap = 0;
- newarr = NULL;
+ PyArray_Descr *dtype = NULL;
+ npy_longlong value_buffer[4];
+ char *value = NULL;
+ int retcode = 0;
+
+ /*
+ * If 'arr' is an object array, copy the object as is unless
+ * 'obj' is a zero-dimensional array, in which case we copy
+ * the element in that array instead.
+ */
+ if (PyArray_DESCR(arr)->type_num == NPY_OBJECT &&
+ !(PyArray_Check(obj) &&
+ PyArray_NDIM((PyArrayObject *)obj) == 0)) {
+ value = (char *)&obj;
+
+ dtype = PyArray_DescrFromType(NPY_OBJECT);
+ if (dtype == NULL) {
+ return -1;
+ }
}
- else {
- descr = PyArray_DESCR(arr);
- Py_INCREF(descr);
- newarr = (PyArrayObject *)PyArray_FromAny(obj, descr,
- 0,0, NPY_ARRAY_ALIGNED, NULL);
- if (newarr == NULL) {
+ /* Use array_assign_scalar if 'obj' is a numpy scalar object */
+ else if (PyArray_IsScalar(obj, Generic)) {
+ dtype = PyArray_DescrFromScalar(obj);
+ if (dtype == NULL) {
+ return -1;
+ }
+ value = scalar_value(obj, dtype);
+ if (value == NULL) {
+ Py_DECREF(dtype);
return -1;
}
- fromptr = PyArray_DATA(newarr);
- swap = (PyArray_ISNOTSWAPPED(arr) != PyArray_ISNOTSWAPPED(newarr));
}
- size=PyArray_SIZE(arr);
- copyswap = PyArray_DESCR(arr)->f->copyswap;
- if (PyArray_ISONESEGMENT(arr)) {
- char *toptr=PyArray_DATA(arr);
- PyArray_FillWithScalarFunc* fillwithscalar =
- PyArray_DESCR(arr)->f->fillwithscalar;
- if (fillwithscalar && PyArray_ISALIGNED(arr)) {
- copyswap(fromptr, NULL, swap, newarr);
- fillwithscalar(toptr, size, fromptr, arr);
+ /* Python boolean */
+ else if (PyBool_Check(obj)) {
+ value = (char *)value_buffer;
+ *value = (obj == Py_True);
+
+ dtype = PyArray_DescrFromType(NPY_BOOL);
+ if (dtype == NULL) {
+ return -1;
}
- else {
- while (size--) {
- copyswap(toptr, fromptr, swap, arr);
- toptr += itemsize;
- }
+ }
+ /* Python integer */
+ else if (PyLong_Check(obj) || PyInt_Check(obj)) {
+ npy_longlong v = PyLong_AsLongLong(obj);
+ if (v == -1 && PyErr_Occurred()) {
+ return -1;
+ }
+ value = (char *)value_buffer;
+ *(npy_longlong *)value = v;
+
+ dtype = PyArray_DescrFromType(NPY_LONGLONG);
+ if (dtype == NULL) {
+ return -1;
}
}
+ /* Python float */
+ else if (PyFloat_Check(obj)) {
+ npy_double v = PyFloat_AsDouble(obj);
+ if (v == -1 && PyErr_Occurred()) {
+ return -1;
+ }
+ value = (char *)value_buffer;
+ *(npy_double *)value = v;
+
+ dtype = PyArray_DescrFromType(NPY_DOUBLE);
+ if (dtype == NULL) {
+ return -1;
+ }
+ }
+ /* Python complex */
+ else if (PyComplex_Check(obj)) {
+ npy_double re, im;
+
+ re = PyComplex_RealAsDouble(obj);
+ if (re == -1 && PyErr_Occurred()) {
+ return -1;
+ }
+ im = PyComplex_ImagAsDouble(obj);
+ if (im == -1 && PyErr_Occurred()) {
+ return -1;
+ }
+ value = (char *)value_buffer;
+ ((npy_double *)value)[0] = re;
+ ((npy_double *)value)[1] = im;
+
+ dtype = PyArray_DescrFromType(NPY_CDOUBLE);
+ if (dtype == NULL) {
+ return -1;
+ }
+ }
+
+ /* Use the value pointer we got if possible */
+ if (value != NULL) {
+ /* TODO: switch to SAME_KIND casting */
+ retcode = array_assign_scalar(arr, dtype, value,
+ NULL, NPY_UNSAFE_CASTING, 0, NULL);
+ Py_DECREF(dtype);
+ return retcode;
+ }
+ /* Otherwise convert to an array to do the assignment */
else {
- PyArrayIterObject *iter;
+ PyArrayObject *src_arr;
- iter = (PyArrayIterObject *)\
- PyArray_IterNew((PyObject *)arr);
- if (iter == NULL) {
- Py_XDECREF(newarr);
+ src_arr = (PyArrayObject *)PyArray_FromAny(obj, NULL, 0, 0,
+ NPY_ARRAY_ALLOWNA, NULL);
+ if (src_arr == NULL) {
return -1;
}
- while (size--) {
- copyswap(iter->dataptr, fromptr, swap, arr);
- PyArray_ITER_NEXT(iter);
+
+ if (PyArray_NDIM(src_arr) != 0) {
+ PyErr_SetString(PyExc_ValueError,
+ "Input object to FillWithScalar is not a scalar");
+ Py_DECREF(src_arr);
+ return -1;
}
- Py_DECREF(iter);
+
+ retcode = PyArray_CopyInto(arr, src_arr);
+
+ Py_DECREF(src_arr);
+ return retcode;
}
- Py_XDECREF(newarr);
- return 0;
}
/*NUMPY_API
@@ -406,6 +471,7 @@ PyArray_AssignZero(PyArrayObject *dst,
Py_DECREF(bool_dtype);
return retcode;
}
+
/*NUMPY_API
*
* Fills an array with ones.
diff --git a/numpy/core/src/multiarray/na_mask.c b/numpy/core/src/multiarray/na_mask.c
index 66dbb4c73..f9ffcba3a 100644
--- a/numpy/core/src/multiarray/na_mask.c
+++ b/numpy/core/src/multiarray/na_mask.c
@@ -615,7 +615,8 @@ _strided_bool_mask_inversion(char *dst, npy_intp dst_stride,
NPY_NO_EXPORT int
PyArray_GetMaskInversionFunction(
- npy_intp mask_stride, PyArray_Descr *mask_dtype,
+ npy_intp dst_mask_stride, npy_intp src_mask_stride,
+ PyArray_Descr *mask_dtype,
PyArray_StridedUnaryOp **out_unop, NpyAuxData **out_opdata)
{
/* Will use the opdata with the field version */
diff --git a/numpy/core/src/multiarray/na_mask.h b/numpy/core/src/multiarray/na_mask.h
index 39a26c685..fc6c40730 100644
--- a/numpy/core/src/multiarray/na_mask.h
+++ b/numpy/core/src/multiarray/na_mask.h
@@ -22,7 +22,8 @@ PyArray_IsNA(PyObject *obj);
* Gets a strided unary operation which inverts mask values.
*/
NPY_NO_EXPORT int
-PyArray_GetMaskInversionFunction(npy_intp mask_stride,
+PyArray_GetMaskInversionFunction(npy_intp dst_mask_stride,
+ npy_intp src_mask_stride,
PyArray_Descr *mask_dtype,
PyArray_StridedUnaryOp **out_unop,
NpyAuxData **out_opdata);