diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/multiarray/conversion_utils.c | 489 | ||||
-rw-r--r-- | numpy/core/tests/test_conversion_utils.py | 44 | ||||
-rw-r--r-- | numpy/core/tests/test_einsum.py | 2 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 5 |
4 files changed, 269 insertions, 271 deletions
diff --git a/numpy/core/src/multiarray/conversion_utils.c b/numpy/core/src/multiarray/conversion_utils.c index 260ae7080..14d546867 100644 --- a/numpy/core/src/multiarray/conversion_utils.c +++ b/numpy/core/src/multiarray/conversion_utils.c @@ -322,100 +322,112 @@ PyArray_BoolConverter(PyObject *object, npy_bool *val) return NPY_SUCCEED; } -/*NUMPY_API - * Convert object to endian - */ -NPY_NO_EXPORT int -PyArray_ByteorderConverter(PyObject *obj, char *endian) +static int +string_converter_helper( + PyObject *object, + void *out, + int (*str_func)(char const*, Py_ssize_t, void*), + char const *name, + char const *message) { - char *str; - PyObject *tmp = NULL; - - if (PyUnicode_Check(obj)) { - obj = tmp = PyUnicode_AsASCIIString(obj); + /* allow bytes for compatibility */ + PyObject *str_object = NULL; + if (PyBytes_Check(object)) { + str_object = PyUnicode_FromEncodedObject(object, NULL, NULL); + if (str_object == NULL) { + PyErr_Format(PyExc_ValueError, + "%s %s (got %R)", name, message, object); + return NPY_FAIL; + } } - - *endian = NPY_SWAP; - str = PyBytes_AsString(obj); - if (!str) { - Py_XDECREF(tmp); + else if (PyUnicode_Check(object)) { + str_object = object; + Py_INCREF(str_object); + } + else { + PyErr_Format(PyExc_TypeError, + "%s must be str, not %s", name, Py_TYPE(object)->tp_name); return NPY_FAIL; } - if (strlen(str) < 1) { - PyErr_SetString(PyExc_ValueError, - "Byteorder string must be at least length 1"); - Py_XDECREF(tmp); + + Py_ssize_t length; + char const *str = PyUnicode_AsUTF8AndSize(str_object, &length); + if (str == NULL) { + Py_DECREF(str_object); return NPY_FAIL; } - *endian = str[0]; - if (str[0] != NPY_BIG && str[0] != NPY_LITTLE - && str[0] != NPY_NATIVE && str[0] != NPY_IGNORE) { - if (str[0] == 'b' || str[0] == 'B') { - *endian = NPY_BIG; - } - else if (str[0] == 'l' || str[0] == 'L') { - *endian = NPY_LITTLE; - } - else if (str[0] == 'n' || str[0] == 'N') { - *endian = NPY_NATIVE; - } - else if (str[0] == 'i' || str[0] == 'I') { - *endian = NPY_IGNORE; - } - else if (str[0] == 's' || str[0] == 'S') { - *endian = NPY_SWAP; - } - else { + + int ret = str_func(str, length, out); + Py_DECREF(str_object); + if (ret < 0) { PyErr_Format(PyExc_ValueError, - "%s is an unrecognized byteorder", - str); - Py_XDECREF(tmp); - return NPY_FAIL; - } + "%s %s (got %R)", name, message, object); + return NPY_FAIL; } - Py_XDECREF(tmp); return NPY_SUCCEED; } -/*NUMPY_API - * Convert object to sort kind - */ -NPY_NO_EXPORT int -PyArray_SortkindConverter(PyObject *obj, NPY_SORTKIND *sortkind) +static int byteorder_parser(char const *str, Py_ssize_t length, void *data) { - char *str; - PyObject *tmp = NULL; + char *endian = (char *)data; - if (obj == Py_None) { - *sortkind = NPY_QUICKSORT; - return NPY_SUCCEED; + if (length < 1) { + return -1; } - - if (PyUnicode_Check(obj)) { - obj = tmp = PyUnicode_AsASCIIString(obj); - if (obj == NULL) { - return NPY_FAIL; - } + else if (str[0] == NPY_BIG || str[0] == NPY_LITTLE + || str[0] == NPY_NATIVE || str[0] == NPY_IGNORE) { + *endian = str[0]; + return 0; + } + else if (str[0] == 'b' || str[0] == 'B') { + *endian = NPY_BIG; + return 0; + } + else if (str[0] == 'l' || str[0] == 'L') { + *endian = NPY_LITTLE; + return 0; + } + else if (str[0] == 'n' || str[0] == 'N') { + *endian = NPY_NATIVE; + return 0; + } + else if (str[0] == 'i' || str[0] == 'I') { + *endian = NPY_IGNORE; + return 0; + } + else if (str[0] == 's' || str[0] == 'S') { + *endian = NPY_SWAP; + return 0; + } + else { + return -1; } +} - *sortkind = NPY_QUICKSORT; +/*NUMPY_API + * Convert object to endian + */ +NPY_NO_EXPORT int +PyArray_ByteorderConverter(PyObject *obj, char *endian) +{ + return string_converter_helper( + obj, (void *)endian, byteorder_parser, "byteorder", "not recognized"); +} - str = PyBytes_AsString(obj); - if (!str) { - Py_XDECREF(tmp); - return NPY_FAIL; - } - if (strlen(str) < 1) { - PyErr_SetString(PyExc_ValueError, - "Sort kind string must be at least length 1"); - Py_XDECREF(tmp); - return NPY_FAIL; +static int sortkind_parser(char const *str, Py_ssize_t length, void *data) +{ + NPY_SORTKIND *sortkind = (NPY_SORTKIND *)data; + + if (length < 1) { + return -1; } if (str[0] == 'q' || str[0] == 'Q') { *sortkind = NPY_QUICKSORT; + return 0; } else if (str[0] == 'h' || str[0] == 'H') { *sortkind = NPY_HEAPSORT; + return 0; } else if (str[0] == 'm' || str[0] == 'M') { /* @@ -424,6 +436,7 @@ PyArray_SortkindConverter(PyObject *obj, NPY_SORTKIND *sortkind) * allowing other types of stable sorts to be used. */ *sortkind = NPY_MERGESORT; + return 0; } else if (str[0] == 's' || str[0] == 'S') { /* @@ -435,16 +448,39 @@ PyArray_SortkindConverter(PyObject *obj, NPY_SORTKIND *sortkind) * Which one is used depends on the data type. */ *sortkind = NPY_STABLESORT; + return 0; } else { - PyErr_Format(PyExc_ValueError, - "%s is an unrecognized kind of sort", - str); - Py_XDECREF(tmp); - return NPY_FAIL; + return -1; + } +} + +/*NUMPY_API + * Convert object to sort kind + */ +NPY_NO_EXPORT int +PyArray_SortkindConverter(PyObject *obj, NPY_SORTKIND *sortkind) +{ + /* Leave the desired default from the caller for Py_None */ + if (obj == Py_None) { + return NPY_SUCCEED; + } + return string_converter_helper( + obj, (void *)sortkind, sortkind_parser, "sort kind", + "must be one of 'quick', 'heap', or 'stable'"); +} + +static int selectkind_parser(char const *str, Py_ssize_t length, void *data) +{ + NPY_SELECTKIND *selectkind = (NPY_SELECTKIND *)data; + + if (length == 11 && strcmp(str, "introselect") == 0) { + *selectkind = NPY_INTROSELECT; + return 0; + } + else { + return -1; } - Py_XDECREF(tmp); - return NPY_SUCCEED; } /*NUMPY_API @@ -453,40 +489,29 @@ PyArray_SortkindConverter(PyObject *obj, NPY_SORTKIND *sortkind) NPY_NO_EXPORT int PyArray_SelectkindConverter(PyObject *obj, NPY_SELECTKIND *selectkind) { - char *str; - PyObject *tmp = NULL; + return string_converter_helper( + obj, (void *)selectkind, selectkind_parser, "select kind", + "must be 'introselect'"); +} - if (PyUnicode_Check(obj)) { - obj = tmp = PyUnicode_AsASCIIString(obj); - if (obj == NULL) { - return NPY_FAIL; - } - } +static int searchside_parser(char const *str, Py_ssize_t length, void *data) +{ + NPY_SEARCHSIDE *side = (NPY_SEARCHSIDE *)data; - *selectkind = NPY_INTROSELECT; - str = PyBytes_AsString(obj); - if (!str) { - Py_XDECREF(tmp); - return NPY_FAIL; + if (length < 1) { + return -1; } - if (strlen(str) < 1) { - PyErr_SetString(PyExc_ValueError, - "Select kind string must be at least length 1"); - Py_XDECREF(tmp); - return NPY_FAIL; + else if (str[0] == 'l' || str[0] == 'L') { + *side = NPY_SEARCHLEFT; + return 0; } - if (strcmp(str, "introselect") == 0) { - *selectkind = NPY_INTROSELECT; + else if (str[0] == 'r' || str[0] == 'R') { + *side = NPY_SEARCHRIGHT; + return 0; } else { - PyErr_Format(PyExc_ValueError, - "%s is an unrecognized kind of select", - str); - Py_XDECREF(tmp); - return NPY_FAIL; + return -1; } - Py_XDECREF(tmp); - return NPY_SUCCEED; } /*NUMPY_API @@ -495,36 +520,36 @@ PyArray_SelectkindConverter(PyObject *obj, NPY_SELECTKIND *selectkind) NPY_NO_EXPORT int PyArray_SearchsideConverter(PyObject *obj, void *addr) { - NPY_SEARCHSIDE *side = (NPY_SEARCHSIDE *)addr; - char *str; - PyObject *tmp = NULL; + return string_converter_helper( + obj, addr, searchside_parser, "search side", + "must be 'left' or 'right'"); +} - if (PyUnicode_Check(obj)) { - obj = tmp = PyUnicode_AsASCIIString(obj); +static int order_parser(char const *str, Py_ssize_t length, void *data) +{ + NPY_ORDER *val = (NPY_ORDER *)data; + if (length != 1) { + return -1; } - - str = PyBytes_AsString(obj); - if (!str || strlen(str) < 1) { - PyErr_SetString(PyExc_ValueError, - "expected nonempty string for keyword 'side'"); - Py_XDECREF(tmp); - return NPY_FAIL; + if (str[0] == 'C' || str[0] == 'c') { + *val = NPY_CORDER; + return 0; } - - if (str[0] == 'l' || str[0] == 'L') { - *side = NPY_SEARCHLEFT; + else if (str[0] == 'F' || str[0] == 'f') { + *val = NPY_FORTRANORDER; + return 0; } - else if (str[0] == 'r' || str[0] == 'R') { - *side = NPY_SEARCHRIGHT; + else if (str[0] == 'A' || str[0] == 'a') { + *val = NPY_ANYORDER; + return 0; + } + else if (str[0] == 'K' || str[0] == 'k') { + *val = NPY_KEEPORDER; + return 0; } else { - PyErr_Format(PyExc_ValueError, - "'%s' is an invalid value for keyword 'side'", str); - Py_XDECREF(tmp); - return NPY_FAIL; + return -1; } - Py_XDECREF(tmp); - return NPY_SUCCEED; } /*NUMPY_API @@ -533,59 +558,36 @@ PyArray_SearchsideConverter(PyObject *obj, void *addr) NPY_NO_EXPORT int PyArray_OrderConverter(PyObject *object, NPY_ORDER *val) { - char *str; - /* Leave the desired default from the caller for NULL/Py_None */ - if (object == NULL || object == Py_None) { + /* Leave the desired default from the caller for Py_None */ + if (object == Py_None) { return NPY_SUCCEED; } - else if (PyUnicode_Check(object)) { - PyObject *tmp; - int ret; - tmp = PyUnicode_AsASCIIString(object); - if (tmp == NULL) { - PyErr_SetString(PyExc_ValueError, - "Invalid unicode string passed in for the array ordering. " - "Please pass in 'C', 'F', 'A' or 'K' instead"); - return NPY_FAIL; - } - ret = PyArray_OrderConverter(tmp, val); - Py_DECREF(tmp); - return ret; - } - else if (!PyBytes_Check(object) || PyBytes_GET_SIZE(object) < 1) { - PyErr_SetString(PyExc_ValueError, - "Non-string object detected for the array ordering. " - "Please pass in 'C', 'F', 'A', or 'K' instead"); - return NPY_FAIL; + return string_converter_helper( + object, (void *)val, order_parser, "order", + "must be one of 'C', 'F', 'A', or 'K'"); +} + +static int clipmode_parser(char const *str, Py_ssize_t length, void *data) +{ + NPY_CLIPMODE *val = (NPY_CLIPMODE *)data; + if (length < 1) { + return -1; + } + if (str[0] == 'C' || str[0] == 'c') { + *val = NPY_CLIP; + return 0; + } + else if (str[0] == 'W' || str[0] == 'w') { + *val = NPY_WRAP; + return 0; + } + else if (str[0] == 'R' || str[0] == 'r') { + *val = NPY_RAISE; + return 0; } else { - str = PyBytes_AS_STRING(object); - if (strlen(str) != 1) { - PyErr_SetString(PyExc_ValueError, - "Non-string object detected for the array ordering. " - "Please pass in 'C', 'F', 'A', or 'K' instead"); - return NPY_FAIL; - } - - if (str[0] == 'C' || str[0] == 'c') { - *val = NPY_CORDER; - } - else if (str[0] == 'F' || str[0] == 'f') { - *val = NPY_FORTRANORDER; - } - else if (str[0] == 'A' || str[0] == 'a') { - *val = NPY_ANYORDER; - } - else if (str[0] == 'K' || str[0] == 'k') { - *val = NPY_KEEPORDER; - } - else { - PyErr_SetString(PyExc_TypeError, - "order not understood"); - return NPY_FAIL; - } + return -1; } - return NPY_SUCCEED; } /*NUMPY_API @@ -597,36 +599,14 @@ PyArray_ClipmodeConverter(PyObject *object, NPY_CLIPMODE *val) if (object == NULL || object == Py_None) { *val = NPY_RAISE; } - else if (PyBytes_Check(object)) { - char *str; - str = PyBytes_AS_STRING(object); - if (str[0] == 'C' || str[0] == 'c') { - *val = NPY_CLIP; - } - else if (str[0] == 'W' || str[0] == 'w') { - *val = NPY_WRAP; - } - else if (str[0] == 'R' || str[0] == 'r') { - *val = NPY_RAISE; - } - else { - PyErr_SetString(PyExc_TypeError, - "clipmode not understood"); - return NPY_FAIL; - } - } - else if (PyUnicode_Check(object)) { - PyObject *tmp; - int ret; - tmp = PyUnicode_AsASCIIString(object); - if (tmp == NULL) { - return NPY_FAIL; - } - ret = PyArray_ClipmodeConverter(tmp, val); - Py_DECREF(tmp); - return ret; + + else if (PyBytes_Check(object) || PyUnicode_Check(object)) { + return string_converter_helper( + object, (void *)val, clipmode_parser, "clipmode", + "must be one of 'clip', 'raise', or 'wrap'"); } else { + /* For users passing `np.RAISE`, `np.WRAP`, `np.CLIP` */ int number = PyArray_PyIntAsInt(object); if (error_converting(number)) { goto fail; @@ -636,7 +616,8 @@ PyArray_ClipmodeConverter(PyObject *object, NPY_CLIPMODE *val) *val = (NPY_CLIPMODE) number; } else { - goto fail; + PyErr_Format(PyExc_ValueError, + "integer clipmode must be np.RAISE, np.WRAP, or np.CLIP"); } } return NPY_SUCCEED; @@ -690,66 +671,56 @@ PyArray_ConvertClipmodeSequence(PyObject *object, NPY_CLIPMODE *modes, int n) return NPY_SUCCEED; } +static int casting_parser(char const *str, Py_ssize_t length, void *data) +{ + NPY_CASTING *casting = (NPY_CASTING *)data; + if (length < 2) { + return -1; + } + switch (str[2]) { + case 0: + if (length == 2 && strcmp(str, "no") == 0) { + *casting = NPY_NO_CASTING; + return 0; + } + break; + case 'u': + if (length == 5 && strcmp(str, "equiv") == 0) { + *casting = NPY_EQUIV_CASTING; + return 0; + } + break; + case 'f': + if (length == 4 && strcmp(str, "safe") == 0) { + *casting = NPY_SAFE_CASTING; + return 0; + } + break; + case 'm': + if (length == 9 && strcmp(str, "same_kind") == 0) { + *casting = NPY_SAME_KIND_CASTING; + return 0; + } + break; + case 's': + if (length == 6 && strcmp(str, "unsafe") == 0) { + *casting = NPY_UNSAFE_CASTING; + return 0; + } + break; + } + return -1; +} + /*NUMPY_API * Convert any Python object, *obj*, to an NPY_CASTING enum. */ NPY_NO_EXPORT int PyArray_CastingConverter(PyObject *obj, NPY_CASTING *casting) { - char *str = NULL; - Py_ssize_t length = 0; - - if (PyUnicode_Check(obj)) { - PyObject *str_obj; - int ret; - str_obj = PyUnicode_AsASCIIString(obj); - if (str_obj == NULL) { - return 0; - } - ret = PyArray_CastingConverter(str_obj, casting); - Py_DECREF(str_obj); - return ret; - } - - if (PyBytes_AsStringAndSize(obj, &str, &length) < 0) { - return 0; - } - - if (length >= 2) switch (str[2]) { - case 0: - if (strcmp(str, "no") == 0) { - *casting = NPY_NO_CASTING; - return 1; - } - break; - case 'u': - if (strcmp(str, "equiv") == 0) { - *casting = NPY_EQUIV_CASTING; - return 1; - } - break; - case 'f': - if (strcmp(str, "safe") == 0) { - *casting = NPY_SAFE_CASTING; - return 1; - } - break; - case 'm': - if (strcmp(str, "same_kind") == 0) { - *casting = NPY_SAME_KIND_CASTING; - return 1; - } - break; - case 's': - if (strcmp(str, "unsafe") == 0) { - *casting = NPY_UNSAFE_CASTING; - return 1; - } - break; - } - - PyErr_SetString(PyExc_ValueError, - "casting must be one of 'no', 'equiv', 'safe', " + return string_converter_helper( + obj, (void *)casting, casting_parser, "casting", + "must be one of 'no', 'equiv', 'safe', " "'same_kind', or 'unsafe'"); return 0; } diff --git a/numpy/core/tests/test_conversion_utils.py b/numpy/core/tests/test_conversion_utils.py index 9e80fdcbf..3c3f9e6e1 100644 --- a/numpy/core/tests/test_conversion_utils.py +++ b/numpy/core/tests/test_conversion_utils.py @@ -1,6 +1,8 @@ """ Tests for numpy/core/src/multiarray/conversion_utils.c """ +import re + import pytest import numpy as np @@ -12,6 +14,11 @@ class StringConverterTestCase: case_insensitive = True exact_match = False + def _check_value_error(self, val): + pattern = r'\(got {}\)'.format(re.escape(repr(val))) + with pytest.raises(ValueError, match=pattern) as exc: + self.conv(val) + def _check(self, val, expected): assert self.conv(val) == expected @@ -23,8 +30,8 @@ class StringConverterTestCase: if len(val) != 1: if self.exact_match: - with pytest.raises(ValueError): - self.conv(val[:1]) + self._check_value_error(val[:1]) + self._check_value_error(val + '\0') else: assert self.conv(val[:1]) == expected @@ -35,11 +42,28 @@ class StringConverterTestCase: assert self.conv(val.upper()) == expected else: if val != val.lower(): - with pytest.raises(ValueError): - self.conv(val.lower()) + self._check_value_error(val.lower()) if val != val.upper(): - with pytest.raises(ValueError): - self.conv(val.upper()) + self._check_value_error(val.upper()) + + def test_wrong_type(self): + # common cases which apply to all the below + with pytest.raises(TypeError): + self.conv({}) + with pytest.raises(TypeError): + self.conv([]) + + def test_wrong_value(self): + # nonsense strings + self._check_value_error('') + self._check_value_error('\N{greek small letter pi}') + + if self.allow_bytes: + self._check_value_error(b'') + # bytes which can't be converted to strings via utf8 + self._check_value_error(b"\xFF") + if self.exact_match: + self._check_value_error("there's no way this is supported") class TestByteorderConverter(StringConverterTestCase): @@ -95,6 +119,14 @@ class TestOrderConverter(StringConverterTestCase): self._check('a', 'NPY_ANYORDER') self._check('k', 'NPY_KEEPORDER') + def test_flatten_invalid_order(self): + # invalid after gh-14596 + with pytest.raises(ValueError): + self.conv('Z') + for order in [False, True, 0, 8]: + with pytest.raises(TypeError): + self.conv(order) + class TestClipmodeConverter(StringConverterTestCase): """ Tests of PyArray_ClipmodeConverter """ diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py index fd0de8732..68491681a 100644 --- a/numpy/core/tests/test_einsum.py +++ b/numpy/core/tests/test_einsum.py @@ -27,7 +27,7 @@ class TestEinsum: optimize=do_opt) # order parameter must be a valid order - assert_raises(TypeError, np.einsum, "", 0, order='W', + assert_raises(ValueError, np.einsum, "", 0, order='W', optimize=do_opt) # casting parameter must be a valid casting diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 4e2d2ad41..f36c27c6c 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2758,11 +2758,6 @@ class TestMethods: assert_equal(x1.flatten('F'), y1f) assert_equal(x1.flatten('F'), x1.T.flatten()) - def test_flatten_invalid_order(self): - # invalid after gh-14596 - for order in ['Z', 'c', False, True, 0, 8]: - x = np.array([[1, 2, 3], [4, 5, 6]], np.int32) - assert_raises(ValueError, x.flatten, {"order": order}) @pytest.mark.parametrize('func', (np.dot, np.matmul)) def test_arr_mult(self, func): |