summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Wiebe <mwwiebe@gmail.com>2011-03-04 11:16:12 -0800
committerMark Wiebe <mwwiebe@gmail.com>2011-03-04 11:16:12 -0800
commitcfff7508bc29e6bc0c44b2d42d7bb23e143d5bc3 (patch)
treef6cab29fe840dab231103b545a425fad1af47b8f
parent13212a5d7919f8668522a8251bea90b3a2b22894 (diff)
downloadnumpy-cfff7508bc29e6bc0c44b2d42d7bb23e143d5bc3.tar.gz
BUG: Fix CastToType to handle string->string casts (ticket #1748)
-rw-r--r--numpy/core/src/multiarray/convert_datatype.c54
-rw-r--r--numpy/core/tests/test_regression.py10
2 files changed, 38 insertions, 26 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index b24869f89..3359a5573 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -23,54 +23,56 @@
* Cast an array using typecode structure.
* steals reference to at --- cannot be NULL
*
- * This function always makes a copy of mp, even if the dtype
+ * This function always makes a copy of arr, even if the dtype
* doesn't change.
*/
NPY_NO_EXPORT PyObject *
-PyArray_CastToType(PyArrayObject *mp, PyArray_Descr *at, int fortran)
+PyArray_CastToType(PyArrayObject *arr, PyArray_Descr *dtype, int fortran)
{
PyObject *out;
- int ret;
- PyArray_Descr *mpd;
+ PyArray_Descr *arr_dtype;
- mpd = mp->descr;
+ arr_dtype = PyArray_DESCR(arr);
- if (at->elsize == 0) {
- PyArray_DESCR_REPLACE(at);
- if (at == NULL) {
+ if (dtype->elsize == 0) {
+ PyArray_DESCR_REPLACE(dtype);
+ if (dtype == NULL) {
return NULL;
}
- if (mpd->type_num == PyArray_STRING &&
- at->type_num == PyArray_UNICODE) {
- at->elsize = mpd->elsize << 2;
+
+ if (arr_dtype->type_num == dtype->type_num) {
+ dtype->elsize = arr_dtype->elsize;
+ }
+ else if (arr_dtype->type_num == NPY_STRING &&
+ dtype->type_num == NPY_UNICODE) {
+ dtype->elsize = arr_dtype->elsize * 4;
}
- if (mpd->type_num == PyArray_UNICODE &&
- at->type_num == PyArray_STRING) {
- at->elsize = mpd->elsize >> 2;
+ else if (arr_dtype->type_num == NPY_UNICODE &&
+ dtype->type_num == NPY_STRING) {
+ dtype->elsize = arr_dtype->elsize / 4;
}
- if (at->type_num == PyArray_VOID) {
- at->elsize = mpd->elsize;
+ else if (dtype->type_num == NPY_VOID) {
+ dtype->elsize = arr_dtype->elsize;
}
}
- out = PyArray_NewFromDescr(Py_TYPE(mp), at,
- mp->nd,
- mp->dimensions,
+ out = PyArray_NewFromDescr(Py_TYPE(arr), dtype,
+ arr->nd,
+ arr->dimensions,
NULL, NULL,
fortran,
- (PyObject *)mp);
+ (PyObject *)arr);
if (out == NULL) {
return NULL;
}
- ret = PyArray_CopyInto((PyArrayObject *)out, mp);
- if (ret != -1) {
- return out;
- }
- Py_DECREF(out);
- return NULL;
+ if (PyArray_CopyInto((PyArrayObject *)out, arr) < 0) {
+ Py_DECREF(out);
+ return NULL;
+ }
+ return out;
}
/*NUMPY_API
diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py
index 3fbbb71df..48a745bfe 100644
--- a/numpy/core/tests/test_regression.py
+++ b/numpy/core/tests/test_regression.py
@@ -1535,5 +1535,15 @@ class TestRegression(TestCase):
a[()] = np.array(4)
assert_equal(a, np.array(4))
+ def test_string_astype(self):
+ "Ticket #1748"
+ s1 = asbytes('black')
+ s2 = asbytes('white')
+ s3 = asbytes('other')
+ a = np.array([[s1],[s2],[s3]])
+ assert_equal(a.dtype, np.dtype('S5'))
+ b = a.astype('str')
+ assert_equal(b.dtype, np.dtype('S5'))
+
if __name__ == "__main__":
run_module_suite()