summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/convert_datatype.c37
-rw-r--r--numpy/core/tests/test_api.py22
2 files changed, 59 insertions, 0 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index e1483f4e1..957a74fde 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -141,6 +141,12 @@ PyArray_AdaptFlexibleDType(PyObject *data_obj, PyArray_Descr *data_dtype,
{
PyArray_DatetimeMetaData *meta;
int flex_type_num;
+ PyArrayObject *arr = NULL, *ret;
+ PyArray_Descr *dtype = NULL;
+ int ndim = 0;
+ npy_intp dims[NPY_MAXDIMS];
+ PyObject *list = NULL;
+ int result;
if (*flex_dtype == NULL) {
if (!PyErr_Occurred()) {
@@ -220,6 +226,37 @@ PyArray_AdaptFlexibleDType(PyObject *data_obj, PyArray_Descr *data_dtype,
break;
case NPY_OBJECT:
size = 64;
+ /*
+ * If we're adapting a string dtype for an array of string
+ * objects, call GetArrayParamsFromObject to figure out
+ * maximum string size, and use that as new dtype size.
+ */
+ if ((flex_type_num == NPY_STRING ||
+ flex_type_num == NPY_UNICODE) &&
+ data_obj != NULL) {
+ /*
+ * Convert data array to list of objects since
+ * GetArrayParamsFromObject won't iterate through
+ * items in an array.
+ */
+ list = PyArray_ToList(data_obj);
+ if (list != NULL) {
+ result = PyArray_GetArrayParamsFromObject(
+ list,
+ flex_dtype,
+ 0, &dtype,
+ &ndim, dims, &arr, NULL);
+ if (result == 0 && dtype != NULL) {
+ if (flex_type_num == NPY_UNICODE) {
+ size = dtype->elsize / 4;
+ }
+ else {
+ size = dtype->elsize;
+ }
+ }
+ Py_DECREF(list);
+ }
+ }
break;
case NPY_STRING:
case NPY_VOID:
diff --git a/numpy/core/tests/test_api.py b/numpy/core/tests/test_api.py
index 1d4b93b0f..8ab48f2d1 100644
--- a/numpy/core/tests/test_api.py
+++ b/numpy/core/tests/test_api.py
@@ -6,6 +6,7 @@ import numpy as np
from numpy.testing import *
from numpy.testing.utils import WarningManager
import warnings
+from numpy.compat import sixu
# Switch between new behaviour when NPY_RELAXED_STRIDES_CHECKING is set.
NPY_RELAXED_STRIDES_CHECKING = np.ones((10,1), order='C').flags.f_contiguous
@@ -89,6 +90,27 @@ def test_array_astype():
assert_(not (a is b))
assert_(type(b) != np.matrix)
+ # Make sure converting from string object to fixed length string
+ # does not truncate.
+ a = np.array([b'a'*100], dtype='O')
+ b = a.astype('S')
+ assert_equal(a, b)
+ assert_equal(b.dtype, np.dtype('S100'))
+ a = np.array([sixu('a')*100], dtype='O')
+ b = a.astype('U')
+ assert_equal(a, b)
+ assert_equal(b.dtype, np.dtype('U100'))
+
+ # Same test as above but for strings shorter than 64 characters
+ a = np.array([b'a'*10], dtype='O')
+ b = a.astype('S')
+ assert_equal(a, b)
+ assert_equal(b.dtype, np.dtype('S10'))
+ a = np.array([sixu('a')*10], dtype='O')
+ b = a.astype('U')
+ assert_equal(a, b)
+ assert_equal(b.dtype, np.dtype('U10'))
+
def test_copyto_fromscalar():
a = np.arange(6, dtype='f4').reshape(2,3)