diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-04-19 08:18:54 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-04-19 08:18:54 +0000 |
commit | 11273d2fd8193e4daffaee16cabdb073a9298048 (patch) | |
tree | cf2dd543ffd4ea8ad80bbb38afa743b1e18ad815 /numpy/core/src | |
parent | e7da22ef0036484ef52f1ef362d6b0554be25347 (diff) | |
download | numpy-11273d2fd8193e4daffaee16cabdb073a9298048.tar.gz |
When selecting from string arrays, strip trailing NULL bytes from the resulting string scalar. Fix string comparisons so that trailing NULL padding bytes are not compared inappropriately.
Diffstat (limited to 'numpy/core/src')
-rw-r--r-- | numpy/core/src/arrayobject.c | 79 |
1 files changed, 64 insertions, 15 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index 3a1e0c836..e118c6a22 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -932,6 +932,16 @@ PyArray_Scalar(void *data, PyArray_Descr *descr, PyObject *base) type = descr->typeobj; copyswap = descr->f->copyswap; swap = !PyArray_ISNBO(descr->byteorder); + if PyTypeNum_ISSTRING(type_num) { /* Eliminate NULL bytes */ + char *dptr = data; + dptr += itemsize-1; + while(*dptr-- == 0 && itemsize) itemsize--; + if (type_num == PyArray_UNICODE && itemsize) { + /* make sure itemsize is a multiple of 4 */ + /* so round up to nearest multiple */ + itemsize = (((itemsize-1) >> 2) + 1) << 2; + } + } if (type->tp_itemsize != 0) /* String type */ obj = type->tp_alloc(type, itemsize); else @@ -3491,8 +3501,10 @@ array_str(PyArrayObject *self) return s; } + + /*OBJECT_API -*/ + */ static int PyArray_CompareUCS4(PyArray_UCS4 *s1, PyArray_UCS4 *s2, register size_t len) { @@ -3507,25 +3519,70 @@ PyArray_CompareUCS4(PyArray_UCS4 *s1, PyArray_UCS4 *s2, register size_t len) return 0; } +static int +_myunincmp(PyArray_UCS4 *s1, PyArray_UCS4 *s2, int len1, int len2) +{ + PyArray_UCS4 *sptr; + int val; + + val = PyArray_CompareUCS4(s1, s2, MIN(len1, len2)); + if ((val != 0) || (len1 == len2)) return val; + if (len2 > len1) {sptr = s2+len1; val = -1;} + else {sptr = s1+len2; val = 1;} + if (*sptr != 0) return val; + + if (len2 > len1) {sptr = s2; val = -1;} + else {sptr = s1; val = 1;} + if (*sptr != 0) return val; + return 0; +} + + + +/* Compare s1 and s2 which are not necessarily NULL-terminated. + s1 is of length len1 + s2 is of length len2 + If they are NULL terminated, then stop comparison. 
+*/ +static int +_mystrncmp(char *s1, char *s2, int len1, int len2) +{ + char *sptr; + int val; + + val = strncmp(s1, s2, MIN(len1, len2)); + if ((val != 0) || (len1 == len2)) return val; + if (len2 > len1) {sptr = s2+len1; val = -1;} + else {sptr = s1+len2; val = 1;} + if (*sptr != 0) return val; + return 0; +} static int _compare_strings(PyObject *result, PyArrayMultiIterObject *multi, - int cmp_op, int len, void *func) + int cmp_op, void *func) { PyArrayIterObject *iself, *iother; Bool *dptr; intp size; int val; - int (*cmpfunc)(void *, void *, size_t); + int N1, N2; + int (*cmpfunc)(void *, void *, int, int); cmpfunc = func; dptr = (Bool *)PyArray_DATA(result); iself = multi->iters[0]; iother = multi->iters[1]; size = multi->size; + N1 = iself->ao->descr->elsize; + N2 = iother->ao->descr->elsize; + if ((void *)cmpfunc == (void *)_myunincmp) { + N1 >>= 2; + N2 >>= 2; + } while(size--) { val = cmpfunc((void *)iself->dataptr, (void *)iother->dataptr, - len); + N1, N2); switch (cmp_op) { case Py_EQ: *dptr = (val == 0); @@ -3562,7 +3619,7 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op) { PyObject *result; PyArrayMultiIterObject *mit; - int N, val; + int val; /* Cast arrays to a common type */ if (self->descr->type != other->descr->type) { @@ -3603,12 +3660,6 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op) Py_DECREF(other); if (mit == NULL) return NULL; - /* Choose largest size for result array */ - if (self->descr->elsize > other->descr->elsize) - N = self->descr->elsize; - else - N = other->descr->elsize; - result = PyArray_NewFromDescr(&PyArray_Type, PyArray_DescrFromType(PyArray_BOOL), mit->nd, @@ -3618,12 +3669,10 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op) if (result == NULL) goto finish; if (self->descr->type_num == PyArray_STRING) { - val = _compare_strings(result, mit, cmp_op, N, strncmp); + val = _compare_strings(result, mit, cmp_op, _mystrncmp); } else { - val 
= _compare_strings(result, mit, cmp_op, - N/sizeof(PyArray_UCS4), - PyArray_CompareUCS4); + val = _compare_strings(result, mit, cmp_op, _myunincmp); } if (val < 0) {Py_DECREF(result); result = NULL;} |