summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJay Bourque <jay.bourque@continuum.io>2013-06-29 15:14:11 -0500
committerJay Bourque <jay.bourque@continuum.io>2013-07-08 13:43:11 -0500
commitedd989eccb6283bc4cf3605e62ad6733a95316dc (patch)
treea4bb2393fd791cf8fcc1aafba604be20edbf4fb9
parentd4e70cf8091d64d137c12b1cb4d8ae726804ebad (diff)
downloadnumpy-edd989eccb6283bc4cf3605e62ad6733a95316dc.tar.gz
BUG: Fix creation of string arrays from object types
-rw-r--r--numpy/core/src/multiarray/convert_datatype.c50
-rw-r--r--numpy/core/src/multiarray/ctors.c24
-rw-r--r--numpy/core/tests/test_api.py22
3 files changed, 79 insertions, 17 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index 8cdd65d7b..f701e49a1 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -226,24 +226,36 @@ PyArray_AdaptFlexibleDType(PyObject *data_obj, PyArray_Descr *data_dtype,
break;
case NPY_OBJECT:
size = 64;
- /*
- * If we're adapting a string dtype for an array of string
- * objects, call GetArrayParamsFromObject to figure out
- * maximum string size, and use that as new dtype size.
- */
if ((flex_type_num == NPY_STRING ||
flex_type_num == NPY_UNICODE) &&
data_obj != NULL) {
- /*
- * Convert data array to list of objects since
- * GetArrayParamsFromObject won't iterate through
- * items in an array.
- */
- list = PyArray_ToList(data_obj);
- if (list != NULL) {
+ if (PyArray_CheckScalar(data_obj)) {
+ PyObject *scalar = PyArray_ToList(data_obj);
+ if (scalar != NULL) {
+ PyObject *s = PyObject_Str(scalar);
+ if (s == NULL) {
+ Py_DECREF(scalar);
+ Py_DECREF(*flex_dtype);
+ *flex_dtype = NULL;
+ return;
+ }
+ else {
+ size = PyObject_Length(s);
+ Py_DECREF(s);
+ }
+ Py_DECREF(scalar);
+ }
+ }
+ else if (PyArray_Check(data_obj)) {
+ /*
+ * Convert data array to list of objects since
+ * GetArrayParamsFromObject won't iterator over
+ * array.
+ */
+ list = PyArray_ToList(data_obj);
result = PyArray_GetArrayParamsFromObject(
list,
- flex_dtype,
+ *flex_dtype,
0, &dtype,
&ndim, dims, &arr, NULL);
if (result == 0 && dtype != NULL) {
@@ -256,6 +268,18 @@ PyArray_AdaptFlexibleDType(PyObject *data_obj, PyArray_Descr *data_dtype,
}
Py_DECREF(list);
}
+ else if (PyArray_IsPythonScalar(data_obj)) {
+ PyObject *s = PyObject_Str(data_obj);
+ if (s == NULL) {
+ Py_DECREF(*flex_dtype);
+ *flex_dtype = NULL;
+ return;
+ }
+ else {
+ size = PyObject_Length(s);
+ Py_DECREF(s);
+ }
+ }
}
break;
case NPY_STRING:
diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c
index 872f4e284..4c27c4cec 100644
--- a/numpy/core/src/multiarray/ctors.c
+++ b/numpy/core/src/multiarray/ctors.c
@@ -521,7 +521,7 @@ PyArray_AssignFromSequence(PyArrayObject *self, PyObject *v)
*/
static int
-discover_itemsize(PyObject *s, int nd, int *itemsize)
+discover_itemsize(PyObject *s, int nd, int *itemsize, int size_as_string)
{
int n, r, i;
@@ -539,7 +539,19 @@ discover_itemsize(PyObject *s, int nd, int *itemsize)
PyUnicode_Check(s)) {
/* If an object has no length, leave it be */
- n = PyObject_Length(s);
+ if (size_as_string && s != NULL && !PyString_Check(s)) {
+ PyObject *s_string = PyObject_Str(s);
+ if (s_string) {
+ n = PyObject_Length(s_string);
+ Py_DECREF(s_string);
+ }
+ else {
+ n = -1;
+ }
+ }
+ else {
+ n = PyObject_Length(s);
+ }
if (n == -1) {
PyErr_Clear();
}
@@ -557,7 +569,7 @@ discover_itemsize(PyObject *s, int nd, int *itemsize)
return -1;
}
- r = discover_itemsize(e,nd-1,itemsize);
+ r = discover_itemsize(e,nd-1,itemsize,size_as_string);
Py_DECREF(e);
if (r == -1) {
return -1;
@@ -1528,7 +1540,11 @@ PyArray_GetArrayParamsFromObject(PyObject *op,
if ((*out_dtype)->elsize == 0 &&
PyTypeNum_ISEXTENDED((*out_dtype)->type_num)) {
int itemsize = 0;
- if (discover_itemsize(op, *out_ndim, &itemsize) < 0) {
+ int size_as_string = 0;
+ if ((*out_dtype)->type_num == NPY_STRING || (*out_dtype)->type_num == NPY_UNICODE) {
+ size_as_string = 1;
+ }
+ if (discover_itemsize(op, *out_ndim, &itemsize, size_as_string) < 0) {
Py_DECREF(*out_dtype);
if (PyErr_Occurred() &&
PyErr_GivenExceptionMatches(PyErr_Occurred(),
diff --git a/numpy/core/tests/test_api.py b/numpy/core/tests/test_api.py
index c681b57ab..8036210d9 100644
--- a/numpy/core/tests/test_api.py
+++ b/numpy/core/tests/test_api.py
@@ -259,6 +259,28 @@ def test_array_astype():
assert_equal(a, b)
assert_equal(b.dtype, np.dtype('U10'))
+ a = np.array(123456789012345678901234567890, dtype='O').astype('S')
+ assert_array_equal(a, np.array(b'123456789012345678901234567890',
+ dtype='S30'))
+ a = np.array(123456789012345678901234567890, dtype='O').astype('U')
+ assert_array_equal(a, np.array(sixu('123456789012345678901234567890'),
+ dtype='U30'))
+
+ a = np.array([123456789012345678901234567890], dtype='O').astype('S')
+ assert_array_equal(a, np.array(b'123456789012345678901234567890',
+ dtype='S30'))
+ a = np.array([123456789012345678901234567890], dtype='O').astype('U')
+ assert_array_equal(a, np.array(sixu('123456789012345678901234567890'),
+ dtype='U30'))
+
+ a = np.array(123456789012345678901234567890, dtype='S')
+ assert_array_equal(a, np.array(b'123456789012345678901234567890',
+ dtype='S30'))
+ a = np.array(123456789012345678901234567890, dtype='U')
+ assert_array_equal(a, np.array(sixu('123456789012345678901234567890'),
+ dtype='U30'))
+
+
def test_copyto_fromscalar():
a = np.arange(6, dtype='f4').reshape(2,3)