diff options
| author | Jay Bourque <jay.bourque@continuum.io> | 2013-06-29 15:14:11 -0500 |
|---|---|---|
| committer | Jay Bourque <jay.bourque@continuum.io> | 2013-07-08 13:43:11 -0500 |
| commit | edd989eccb6283bc4cf3605e62ad6733a95316dc (patch) | |
| tree | a4bb2393fd791cf8fcc1aafba604be20edbf4fb9 | |
| parent | d4e70cf8091d64d137c12b1cb4d8ae726804ebad (diff) | |
| download | numpy-edd989eccb6283bc4cf3605e62ad6733a95316dc.tar.gz | |
BUG: Fix creation of string arrays from object types
| -rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 50 | ||||
| -rw-r--r-- | numpy/core/src/multiarray/ctors.c | 24 | ||||
| -rw-r--r-- | numpy/core/tests/test_api.py | 22 |
3 files changed, 79 insertions, 17 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index 8cdd65d7b..f701e49a1 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -226,24 +226,36 @@ PyArray_AdaptFlexibleDType(PyObject *data_obj, PyArray_Descr *data_dtype, break; case NPY_OBJECT: size = 64; - /* - * If we're adapting a string dtype for an array of string - * objects, call GetArrayParamsFromObject to figure out - * maximum string size, and use that as new dtype size. - */ if ((flex_type_num == NPY_STRING || flex_type_num == NPY_UNICODE) && data_obj != NULL) { - /* - * Convert data array to list of objects since - * GetArrayParamsFromObject won't iterate through - * items in an array. - */ - list = PyArray_ToList(data_obj); - if (list != NULL) { + if (PyArray_CheckScalar(data_obj)) { + PyObject *scalar = PyArray_ToList(data_obj); + if (scalar != NULL) { + PyObject *s = PyObject_Str(scalar); + if (s == NULL) { + Py_DECREF(scalar); + Py_DECREF(*flex_dtype); + *flex_dtype = NULL; + return; + } + else { + size = PyObject_Length(s); + Py_DECREF(s); + } + Py_DECREF(scalar); + } + } + else if (PyArray_Check(data_obj)) { + /* + * Convert data array to list of objects since + * GetArrayParamsFromObject won't iterator over + * array. + */ + list = PyArray_ToList(data_obj); result = PyArray_GetArrayParamsFromObject( list, - flex_dtype, + *flex_dtype, 0, &dtype, &ndim, dims, &arr, NULL); if (result == 0 && dtype != NULL) { @@ -256,6 +268,18 @@ PyArray_AdaptFlexibleDType(PyObject *data_obj, PyArray_Descr *data_dtype, } Py_DECREF(list); } + else if (PyArray_IsPythonScalar(data_obj)) { + PyObject *s = PyObject_Str(data_obj); + if (s == NULL) { + Py_DECREF(*flex_dtype); + *flex_dtype = NULL; + return; + } + else { + size = PyObject_Length(s); + Py_DECREF(s); + } + } } break; case NPY_STRING: diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c index 872f4e284..4c27c4cec 100644 --- a/numpy/core/src/multiarray/ctors.c +++ b/numpy/core/src/multiarray/ctors.c @@ -521,7 +521,7 @@ PyArray_AssignFromSequence(PyArrayObject *self, PyObject *v) */ static int -discover_itemsize(PyObject *s, int nd, int *itemsize) +discover_itemsize(PyObject *s, int nd, int *itemsize, int size_as_string) { int n, r, i; @@ -539,7 +539,19 @@ discover_itemsize(PyObject *s, int nd, int *itemsize) PyUnicode_Check(s)) { /* If an object has no length, leave it be */ - n = PyObject_Length(s); + if (size_as_string && s != NULL && !PyString_Check(s)) { + PyObject *s_string = PyObject_Str(s); + if (s_string) { + n = PyObject_Length(s_string); + Py_DECREF(s_string); + } + else { + n = -1; + } + } + else { + n = PyObject_Length(s); + } if (n == -1) { PyErr_Clear(); } @@ -557,7 +569,7 @@ discover_itemsize(PyObject *s, int nd, int *itemsize) return -1; } - r = discover_itemsize(e,nd-1,itemsize); + r = discover_itemsize(e,nd-1,itemsize,size_as_string); Py_DECREF(e); if (r == -1) { return -1; @@ -1528,7 +1540,11 @@ PyArray_GetArrayParamsFromObject(PyObject *op, if ((*out_dtype)->elsize == 0 && PyTypeNum_ISEXTENDED((*out_dtype)->type_num)) { int itemsize = 0; - if (discover_itemsize(op, *out_ndim, &itemsize) < 0) { + int size_as_string = 0; + if ((*out_dtype)->type_num == NPY_STRING || (*out_dtype)->type_num == NPY_UNICODE) { + size_as_string = 1; + } + if (discover_itemsize(op, *out_ndim, &itemsize, size_as_string) < 0) { Py_DECREF(*out_dtype); if (PyErr_Occurred() && PyErr_GivenExceptionMatches(PyErr_Occurred(), diff --git a/numpy/core/tests/test_api.py b/numpy/core/tests/test_api.py index c681b57ab..8036210d9 100644 --- a/numpy/core/tests/test_api.py +++ b/numpy/core/tests/test_api.py @@ -259,6 +259,28 @@ def test_array_astype(): assert_equal(a, b) assert_equal(b.dtype, np.dtype('U10')) + a = np.array(123456789012345678901234567890, dtype='O').astype('S') + assert_array_equal(a, np.array(b'123456789012345678901234567890', + dtype='S30')) + a = np.array(123456789012345678901234567890, dtype='O').astype('U') + assert_array_equal(a, np.array(sixu('123456789012345678901234567890'), + dtype='U30')) + + a = np.array([123456789012345678901234567890], dtype='O').astype('S') + assert_array_equal(a, np.array(b'123456789012345678901234567890', + dtype='S30')) + a = np.array([123456789012345678901234567890], dtype='O').astype('U') + assert_array_equal(a, np.array(sixu('123456789012345678901234567890'), + dtype='U30')) + + a = np.array(123456789012345678901234567890, dtype='S') + assert_array_equal(a, np.array(b'123456789012345678901234567890', + dtype='S30')) + a = np.array(123456789012345678901234567890, dtype='U') + assert_array_equal(a, np.array(sixu('123456789012345678901234567890'), + dtype='U30')) + + def test_copyto_fromscalar(): a = np.arange(6, dtype='f4').reshape(2,3) |
