diff options
-rw-r--r-- | numpy/core/src/arrayobject.c | 79 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 35 | ||||
-rw-r--r-- | numpy/core/tests/test_unicode.py | 8 |
3 files changed, 91 insertions, 31 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index 3a1e0c836..e118c6a22 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -932,6 +932,16 @@ PyArray_Scalar(void *data, PyArray_Descr *descr, PyObject *base) type = descr->typeobj; copyswap = descr->f->copyswap; swap = !PyArray_ISNBO(descr->byteorder); + if PyTypeNum_ISSTRING(type_num) { /* Eliminate NULL bytes */ + char *dptr = data; + dptr += itemsize-1; + while(*dptr-- == 0 && itemsize) itemsize--; + if (type_num == PyArray_UNICODE && itemsize) { + /* make sure itemsize is a multiple of 4 */ + /* so round up to nearest multiple */ + itemsize = (((itemsize-1) >> 2) + 1) << 2; + } + } if (type->tp_itemsize != 0) /* String type */ obj = type->tp_alloc(type, itemsize); else @@ -3491,8 +3501,10 @@ array_str(PyArrayObject *self) return s; } + + /*OBJECT_API -*/ + */ static int PyArray_CompareUCS4(PyArray_UCS4 *s1, PyArray_UCS4 *s2, register size_t len) { @@ -3507,25 +3519,70 @@ PyArray_CompareUCS4(PyArray_UCS4 *s1, PyArray_UCS4 *s2, register size_t len) return 0; } +static int +_myunincmp(PyArray_UCS4 *s1, PyArray_UCS4 *s2, int len1, int len2) +{ + PyArray_UCS4 *sptr; + int val; + + val = PyArray_CompareUCS4(s1, s2, MIN(len1, len2)); + if ((val != 0) || (len1 == len2)) return val; + if (len2 > len1) {sptr = s2+len1; val = -1;} + else {sptr = s1+len2; val = 1;} + if (*sptr != 0) return val; + + if (len2 > len1) {sptr = s2; val = -1;} + else {sptr = s1; val = 1;} + if (*sptr != 0) return val; + return 0; +} + + + +/* Compare s1 and s2 which are not necessarily NULL-terminated. + s1 is of length len1 + s2 is of length len2 + If they are NULL terminated, then stop comparison. +*/ +static int +_mystrncmp(char *s1, char *s2, int len1, int len2) +{ + char *sptr; + int val; + + val = strncmp(s1, s2, MIN(len1, len2)); + if ((val != 0) || (len1 == len2)) return val; + if (len2 > len1) {sptr = s2+len1; val = -1;} + else {sptr = s1+len2; val = 1;} + if (*sptr != 0) return val; + return 0; +} static int _compare_strings(PyObject *result, PyArrayMultiIterObject *multi, - int cmp_op, int len, void *func) + int cmp_op, void *func) { PyArrayIterObject *iself, *iother; Bool *dptr; intp size; int val; - int (*cmpfunc)(void *, void *, size_t); + int N1, N2; + int (*cmpfunc)(void *, void *, int, int); cmpfunc = func; dptr = (Bool *)PyArray_DATA(result); iself = multi->iters[0]; iother = multi->iters[1]; size = multi->size; + N1 = iself->ao->descr->elsize; + N2 = iother->ao->descr->elsize; + if ((void *)cmpfunc == (void *)_myunincmp) { + N1 >>= 2; + N2 >>= 2; + } while(size--) { val = cmpfunc((void *)iself->dataptr, (void *)iother->dataptr, - len); + N1, N2); switch (cmp_op) { case Py_EQ: *dptr = (val == 0); @@ -3562,7 +3619,7 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op) { PyObject *result; PyArrayMultiIterObject *mit; - int N, val; + int val; /* Cast arrays to a common type */ if (self->descr->type != other->descr->type) { @@ -3603,12 +3660,6 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op) Py_DECREF(other); if (mit == NULL) return NULL; - /* Choose largest size for result array */ - if (self->descr->elsize > other->descr->elsize) - N = self->descr->elsize; - else - N = other->descr->elsize; - result = PyArray_NewFromDescr(&PyArray_Type, PyArray_DescrFromType(PyArray_BOOL), mit->nd, @@ -3618,12 +3669,10 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op) if (result == NULL) goto finish; if (self->descr->type_num == PyArray_STRING) { - val = _compare_strings(result, mit, cmp_op, N, strncmp); + val = _compare_strings(result, mit, cmp_op, _mystrncmp); } else { - val = _compare_strings(result, mit, cmp_op, - N/sizeof(PyArray_UCS4), - PyArray_CompareUCS4); + val = _compare_strings(result, mit, cmp_op, _myunincmp); } if (val < 0) {Py_DECREF(result); result = NULL;} diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 877865546..49c34e8c6 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -261,22 +261,33 @@ class test_string_compare(ScipyTestCase): def check_string(self): g1 = array(["This","is","example"]) g2 = array(["This","was","example"]) - assert_array_equal(g1 == g2, [True, False, True]) - assert_array_equal(g1 != g2, [False, True, False]) - assert_array_equal(g1 <= g2, [True, True, True]) - assert_array_equal(g1 >= g2, [True, False, True]) - assert_array_equal(g1 < g2, [False, True, False]) - assert_array_equal(g1 > g2, [False, False, False]) + assert_array_equal(g1 == g2, [g1[i] == g2[i] for i in [0,1,2]]) + assert_array_equal(g1 != g2, [g1[i] != g2[i] for i in [0,1,2]]) + assert_array_equal(g1 <= g2, [g1[i] <= g2[i] for i in [0,1,2]]) + assert_array_equal(g1 >= g2, [g1[i] >= g2[i] for i in [0,1,2]]) + assert_array_equal(g1 < g2, [g1[i] < g2[i] for i in [0,1,2]]) + assert_array_equal(g1 > g2, [g1[i] > g2[i] for i in [0,1,2]]) + + def check_mixed(self): + g1 = array(["spam","spa","spammer","and eggs"]) + g2 = "spam" + assert_array_equal(g1 == g2, [x == g2 for x in g1]) + assert_array_equal(g1 != g2, [x != g2 for x in g1]) + assert_array_equal(g1 < g2, [x < g2 for x in g1]) + assert_array_equal(g1 > g2, [x > g2 for x in g1]) + assert_array_equal(g1 <= g2, [x <= g2 for x in g1]) + assert_array_equal(g1 >= g2, [x >= g2 for x in g1]) + def check_unicode(self): g1 = array([u"This",u"is",u"example"]) g2 = array([u"This",u"was",u"example"]) - assert_array_equal(g1 == g2, [True, False, True]) - assert_array_equal(g1 != g2, [False, True, False]) - assert_array_equal(g1 <= g2, [True, True, True]) - assert_array_equal(g1 >= g2, [True, False, True]) - assert_array_equal(g1 < g2, [False, True, False]) - assert_array_equal(g1 > g2, [False, False, False]) + assert_array_equal(g1 == g2, [g1[i] == g2[i] for i in [0,1,2]]) + assert_array_equal(g1 != g2, [g1[i] != g2[i] for i in [0,1,2]]) + assert_array_equal(g1 <= g2, [g1[i] <= g2[i] for i in [0,1,2]]) + assert_array_equal(g1 >= g2, [g1[i] >= g2[i] for i in [0,1,2]]) + assert_array_equal(g1 < g2, [g1[i] < g2[i] for i in [0,1,2]]) + assert_array_equal(g1 > g2, [g1[i] > g2[i] for i in [0,1,2]]) # Import tests from unicode diff --git a/numpy/core/tests/test_unicode.py b/numpy/core/tests/test_unicode.py index 8da329860..642019352 100644 --- a/numpy/core/tests/test_unicode.py +++ b/numpy/core/tests/test_unicode.py @@ -28,14 +28,14 @@ class create_zeros(ScipyTestCase): # Check the length of the data buffer self.assert_(len(ua.data) == nbytes) # Small check that data in array element is ok - self.assert_(ua_scalar == u'\x00'*self.ulen) + self.assert_(ua_scalar == u'') # Encode to ascii and double check - self.assert_(ua_scalar.encode('ascii') == '\x00'*self.ulen) + self.assert_(ua_scalar.encode('ascii') == '') # Check buffer lengths for scalars if ucs4: - self.assert_(len(buffer(ua_scalar)) == 4*self.ulen) + self.assert_(len(buffer(ua_scalar)) == 0) else: - self.assert_(len(buffer(ua_scalar)) == 2*self.ulen) + self.assert_(len(buffer(ua_scalar)) == 0) def check_zeros0D(self): """Check creation of 0-dimensional objects""" |