summary | refs | log | tree | commit | diff
path: root/numpy
diff options
context:
space:
mode:
author: Travis Oliphant <oliphant@enthought.com> 2006-04-19 08:18:54 +0000
committer: Travis Oliphant <oliphant@enthought.com> 2006-04-19 08:18:54 +0000
commit: 11273d2fd8193e4daffaee16cabdb073a9298048 (patch)
tree: cf2dd543ffd4ea8ad80bbb38afa743b1e18ad815 /numpy
parent: e7da22ef0036484ef52f1ef362d6b0554be25347 (diff)
download: numpy-11273d2fd8193e4daffaee16cabdb073a9298048.tar.gz
When selecting from string arrays, remove NULL bytes from string scalar. Fix string comparisons so NULL bytes are not compared inappropriately.
Diffstat (limited to 'numpy')
-rw-r--r-- numpy/core/src/arrayobject.c 79
-rw-r--r-- numpy/core/tests/test_multiarray.py 35
-rw-r--r-- numpy/core/tests/test_unicode.py 8
3 files changed, 91 insertions, 31 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index 3a1e0c836..e118c6a22 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -932,6 +932,16 @@ PyArray_Scalar(void *data, PyArray_Descr *descr, PyObject *base)
type = descr->typeobj;
copyswap = descr->f->copyswap;
swap = !PyArray_ISNBO(descr->byteorder);
+ if PyTypeNum_ISSTRING(type_num) { /* Eliminate NULL bytes */
+ char *dptr = data;
+ dptr += itemsize-1;
+ while(*dptr-- == 0 && itemsize) itemsize--;
+ if (type_num == PyArray_UNICODE && itemsize) {
+ /* make sure itemsize is a multiple of 4 */
+ /* so round up to nearest multiple */
+ itemsize = (((itemsize-1) >> 2) + 1) << 2;
+ }
+ }
if (type->tp_itemsize != 0) /* String type */
obj = type->tp_alloc(type, itemsize);
else
@@ -3491,8 +3501,10 @@ array_str(PyArrayObject *self)
return s;
}
+
+
/*OBJECT_API
-*/
+ */
static int
PyArray_CompareUCS4(PyArray_UCS4 *s1, PyArray_UCS4 *s2, register size_t len)
{
@@ -3507,25 +3519,70 @@ PyArray_CompareUCS4(PyArray_UCS4 *s1, PyArray_UCS4 *s2, register size_t len)
return 0;
}
+static int
+_myunincmp(PyArray_UCS4 *s1, PyArray_UCS4 *s2, int len1, int len2)
+{
+ PyArray_UCS4 *sptr;
+ int val;
+
+ val = PyArray_CompareUCS4(s1, s2, MIN(len1, len2));
+ if ((val != 0) || (len1 == len2)) return val;
+ if (len2 > len1) {sptr = s2+len1; val = -1;}
+ else {sptr = s1+len2; val = 1;}
+ if (*sptr != 0) return val;
+
+ if (len2 > len1) {sptr = s2; val = -1;}
+ else {sptr = s1; val = 1;}
+ if (*sptr != 0) return val;
+ return 0;
+}
+
+
+
+/* Compare s1 and s2 which are not necessarily NULL-terminated.
+ s1 is of length len1
+ s2 is of length len2
+ If they are NULL terminated, then stop comparison.
+*/
+static int
+_mystrncmp(char *s1, char *s2, int len1, int len2)
+{
+ char *sptr;
+ int val;
+
+ val = strncmp(s1, s2, MIN(len1, len2));
+ if ((val != 0) || (len1 == len2)) return val;
+ if (len2 > len1) {sptr = s2+len1; val = -1;}
+ else {sptr = s1+len2; val = 1;}
+ if (*sptr != 0) return val;
+ return 0;
+}
static int
_compare_strings(PyObject *result, PyArrayMultiIterObject *multi,
- int cmp_op, int len, void *func)
+ int cmp_op, void *func)
{
PyArrayIterObject *iself, *iother;
Bool *dptr;
intp size;
int val;
- int (*cmpfunc)(void *, void *, size_t);
+ int N1, N2;
+ int (*cmpfunc)(void *, void *, int, int);
cmpfunc = func;
dptr = (Bool *)PyArray_DATA(result);
iself = multi->iters[0];
iother = multi->iters[1];
size = multi->size;
+ N1 = iself->ao->descr->elsize;
+ N2 = iother->ao->descr->elsize;
+ if ((void *)cmpfunc == (void *)_myunincmp) {
+ N1 >>= 2;
+ N2 >>= 2;
+ }
while(size--) {
val = cmpfunc((void *)iself->dataptr, (void *)iother->dataptr,
- len);
+ N1, N2);
switch (cmp_op) {
case Py_EQ:
*dptr = (val == 0);
@@ -3562,7 +3619,7 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
{
PyObject *result;
PyArrayMultiIterObject *mit;
- int N, val;
+ int val;
/* Cast arrays to a common type */
if (self->descr->type != other->descr->type) {
@@ -3603,12 +3660,6 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
Py_DECREF(other);
if (mit == NULL) return NULL;
- /* Choose largest size for result array */
- if (self->descr->elsize > other->descr->elsize)
- N = self->descr->elsize;
- else
- N = other->descr->elsize;
-
result = PyArray_NewFromDescr(&PyArray_Type,
PyArray_DescrFromType(PyArray_BOOL),
mit->nd,
@@ -3618,12 +3669,10 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
if (result == NULL) goto finish;
if (self->descr->type_num == PyArray_STRING) {
- val = _compare_strings(result, mit, cmp_op, N, strncmp);
+ val = _compare_strings(result, mit, cmp_op, _mystrncmp);
}
else {
- val = _compare_strings(result, mit, cmp_op,
- N/sizeof(PyArray_UCS4),
- PyArray_CompareUCS4);
+ val = _compare_strings(result, mit, cmp_op, _myunincmp);
}
if (val < 0) {Py_DECREF(result); result = NULL;}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 877865546..49c34e8c6 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -261,22 +261,33 @@ class test_string_compare(ScipyTestCase):
def check_string(self):
g1 = array(["This","is","example"])
g2 = array(["This","was","example"])
- assert_array_equal(g1 == g2, [True, False, True])
- assert_array_equal(g1 != g2, [False, True, False])
- assert_array_equal(g1 <= g2, [True, True, True])
- assert_array_equal(g1 >= g2, [True, False, True])
- assert_array_equal(g1 < g2, [False, True, False])
- assert_array_equal(g1 > g2, [False, False, False])
+ assert_array_equal(g1 == g2, [g1[i] == g2[i] for i in [0,1,2]])
+ assert_array_equal(g1 != g2, [g1[i] != g2[i] for i in [0,1,2]])
+ assert_array_equal(g1 <= g2, [g1[i] <= g2[i] for i in [0,1,2]])
+ assert_array_equal(g1 >= g2, [g1[i] >= g2[i] for i in [0,1,2]])
+ assert_array_equal(g1 < g2, [g1[i] < g2[i] for i in [0,1,2]])
+ assert_array_equal(g1 > g2, [g1[i] > g2[i] for i in [0,1,2]])
+
+ def check_mixed(self):
+ g1 = array(["spam","spa","spammer","and eggs"])
+ g2 = "spam"
+ assert_array_equal(g1 == g2, [x == g2 for x in g1])
+ assert_array_equal(g1 != g2, [x != g2 for x in g1])
+ assert_array_equal(g1 < g2, [x < g2 for x in g1])
+ assert_array_equal(g1 > g2, [x > g2 for x in g1])
+ assert_array_equal(g1 <= g2, [x <= g2 for x in g1])
+ assert_array_equal(g1 >= g2, [x >= g2 for x in g1])
+
def check_unicode(self):
g1 = array([u"This",u"is",u"example"])
g2 = array([u"This",u"was",u"example"])
- assert_array_equal(g1 == g2, [True, False, True])
- assert_array_equal(g1 != g2, [False, True, False])
- assert_array_equal(g1 <= g2, [True, True, True])
- assert_array_equal(g1 >= g2, [True, False, True])
- assert_array_equal(g1 < g2, [False, True, False])
- assert_array_equal(g1 > g2, [False, False, False])
+ assert_array_equal(g1 == g2, [g1[i] == g2[i] for i in [0,1,2]])
+ assert_array_equal(g1 != g2, [g1[i] != g2[i] for i in [0,1,2]])
+ assert_array_equal(g1 <= g2, [g1[i] <= g2[i] for i in [0,1,2]])
+ assert_array_equal(g1 >= g2, [g1[i] >= g2[i] for i in [0,1,2]])
+ assert_array_equal(g1 < g2, [g1[i] < g2[i] for i in [0,1,2]])
+ assert_array_equal(g1 > g2, [g1[i] > g2[i] for i in [0,1,2]])
# Import tests from unicode
diff --git a/numpy/core/tests/test_unicode.py b/numpy/core/tests/test_unicode.py
index 8da329860..642019352 100644
--- a/numpy/core/tests/test_unicode.py
+++ b/numpy/core/tests/test_unicode.py
@@ -28,14 +28,14 @@ class create_zeros(ScipyTestCase):
# Check the length of the data buffer
self.assert_(len(ua.data) == nbytes)
# Small check that data in array element is ok
- self.assert_(ua_scalar == u'\x00'*self.ulen)
+ self.assert_(ua_scalar == u'')
# Encode to ascii and double check
- self.assert_(ua_scalar.encode('ascii') == '\x00'*self.ulen)
+ self.assert_(ua_scalar.encode('ascii') == '')
# Check buffer lengths for scalars
if ucs4:
- self.assert_(len(buffer(ua_scalar)) == 4*self.ulen)
+ self.assert_(len(buffer(ua_scalar)) == 0)
else:
- self.assert_(len(buffer(ua_scalar)) == 2*self.ulen)
+ self.assert_(len(buffer(ua_scalar)) == 0)
def check_zeros0D(self):
"""Check creation of 0-dimensional objects"""