diff options
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 54 | ||||
-rw-r--r-- | numpy/core/tests/test_regression.py | 10 |
2 files changed, 38 insertions, 26 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index b24869f89..3359a5573 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -23,54 +23,56 @@ * Cast an array using typecode structure. * steals reference to at --- cannot be NULL * - * This function always makes a copy of mp, even if the dtype + * This function always makes a copy of arr, even if the dtype * doesn't change. */ NPY_NO_EXPORT PyObject * -PyArray_CastToType(PyArrayObject *mp, PyArray_Descr *at, int fortran) +PyArray_CastToType(PyArrayObject *arr, PyArray_Descr *dtype, int fortran) { PyObject *out; - int ret; - PyArray_Descr *mpd; + PyArray_Descr *arr_dtype; - mpd = mp->descr; + arr_dtype = PyArray_DESCR(arr); - if (at->elsize == 0) { - PyArray_DESCR_REPLACE(at); - if (at == NULL) { + if (dtype->elsize == 0) { + PyArray_DESCR_REPLACE(dtype); + if (dtype == NULL) { return NULL; } - if (mpd->type_num == PyArray_STRING && - at->type_num == PyArray_UNICODE) { - at->elsize = mpd->elsize << 2; + + if (arr_dtype->type_num == dtype->type_num) { + dtype->elsize = arr_dtype->elsize; + } + else if (arr_dtype->type_num == NPY_STRING && + dtype->type_num == NPY_UNICODE) { + dtype->elsize = arr_dtype->elsize * 4; } - if (mpd->type_num == PyArray_UNICODE && - at->type_num == PyArray_STRING) { - at->elsize = mpd->elsize >> 2; + else if (arr_dtype->type_num == NPY_UNICODE && + dtype->type_num == NPY_STRING) { + dtype->elsize = arr_dtype->elsize / 4; } - if (at->type_num == PyArray_VOID) { - at->elsize = mpd->elsize; + else if (dtype->type_num == NPY_VOID) { + dtype->elsize = arr_dtype->elsize; } } - out = PyArray_NewFromDescr(Py_TYPE(mp), at, - mp->nd, - mp->dimensions, + out = PyArray_NewFromDescr(Py_TYPE(arr), dtype, + arr->nd, + arr->dimensions, NULL, NULL, fortran, - (PyObject *)mp); + (PyObject *)arr); if (out == NULL) { return NULL; } - ret = PyArray_CopyInto((PyArrayObject *)out, mp); - if (ret != -1) { - return out; - } - Py_DECREF(out); - return NULL; + if (PyArray_CopyInto((PyArrayObject *)out, arr) < 0) { + Py_DECREF(out); + return NULL; + } + return out; } /*NUMPY_API diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py index 3fbbb71df..48a745bfe 100644 --- a/numpy/core/tests/test_regression.py +++ b/numpy/core/tests/test_regression.py @@ -1535,5 +1535,15 @@ class TestRegression(TestCase): a[()] = np.array(4) assert_equal(a, np.array(4)) + def test_string_astype(self): + "Ticket #1748" + s1 = asbytes('black') + s2 = asbytes('white') + s3 = asbytes('other') + a = np.array([[s1],[s2],[s3]]) + assert_equal(a.dtype, np.dtype('S5')) + b = a.astype('str') + assert_equal(b.dtype, np.dtype('S5')) + if __name__ == "__main__": run_module_suite() |