summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/conversion_utils.c489
-rw-r--r--numpy/core/tests/test_conversion_utils.py44
-rw-r--r--numpy/core/tests/test_einsum.py2
-rw-r--r--numpy/core/tests/test_multiarray.py5
4 files changed, 269 insertions, 271 deletions
diff --git a/numpy/core/src/multiarray/conversion_utils.c b/numpy/core/src/multiarray/conversion_utils.c
index 260ae7080..14d546867 100644
--- a/numpy/core/src/multiarray/conversion_utils.c
+++ b/numpy/core/src/multiarray/conversion_utils.c
@@ -322,100 +322,112 @@ PyArray_BoolConverter(PyObject *object, npy_bool *val)
return NPY_SUCCEED;
}
-/*NUMPY_API
- * Convert object to endian
- */
-NPY_NO_EXPORT int
-PyArray_ByteorderConverter(PyObject *obj, char *endian)
+static int
+string_converter_helper(
+ PyObject *object,
+ void *out,
+ int (*str_func)(char const*, Py_ssize_t, void*),
+ char const *name,
+ char const *message)
{
- char *str;
- PyObject *tmp = NULL;
-
- if (PyUnicode_Check(obj)) {
- obj = tmp = PyUnicode_AsASCIIString(obj);
+ /* allow bytes for compatibility */
+ PyObject *str_object = NULL;
+ if (PyBytes_Check(object)) {
+ str_object = PyUnicode_FromEncodedObject(object, NULL, NULL);
+ if (str_object == NULL) {
+ PyErr_Format(PyExc_ValueError,
+ "%s %s (got %R)", name, message, object);
+ return NPY_FAIL;
+ }
}
-
- *endian = NPY_SWAP;
- str = PyBytes_AsString(obj);
- if (!str) {
- Py_XDECREF(tmp);
+ else if (PyUnicode_Check(object)) {
+ str_object = object;
+ Py_INCREF(str_object);
+ }
+ else {
+ PyErr_Format(PyExc_TypeError,
+ "%s must be str, not %s", name, Py_TYPE(object)->tp_name);
return NPY_FAIL;
}
- if (strlen(str) < 1) {
- PyErr_SetString(PyExc_ValueError,
- "Byteorder string must be at least length 1");
- Py_XDECREF(tmp);
+
+ Py_ssize_t length;
+ char const *str = PyUnicode_AsUTF8AndSize(str_object, &length);
+ if (str == NULL) {
+ Py_DECREF(str_object);
return NPY_FAIL;
}
- *endian = str[0];
- if (str[0] != NPY_BIG && str[0] != NPY_LITTLE
- && str[0] != NPY_NATIVE && str[0] != NPY_IGNORE) {
- if (str[0] == 'b' || str[0] == 'B') {
- *endian = NPY_BIG;
- }
- else if (str[0] == 'l' || str[0] == 'L') {
- *endian = NPY_LITTLE;
- }
- else if (str[0] == 'n' || str[0] == 'N') {
- *endian = NPY_NATIVE;
- }
- else if (str[0] == 'i' || str[0] == 'I') {
- *endian = NPY_IGNORE;
- }
- else if (str[0] == 's' || str[0] == 'S') {
- *endian = NPY_SWAP;
- }
- else {
+
+ int ret = str_func(str, length, out);
+ Py_DECREF(str_object);
+ if (ret < 0) {
PyErr_Format(PyExc_ValueError,
- "%s is an unrecognized byteorder",
- str);
- Py_XDECREF(tmp);
- return NPY_FAIL;
- }
+ "%s %s (got %R)", name, message, object);
+ return NPY_FAIL;
}
- Py_XDECREF(tmp);
return NPY_SUCCEED;
}
-/*NUMPY_API
- * Convert object to sort kind
- */
-NPY_NO_EXPORT int
-PyArray_SortkindConverter(PyObject *obj, NPY_SORTKIND *sortkind)
+static int byteorder_parser(char const *str, Py_ssize_t length, void *data)
{
- char *str;
- PyObject *tmp = NULL;
+ char *endian = (char *)data;
- if (obj == Py_None) {
- *sortkind = NPY_QUICKSORT;
- return NPY_SUCCEED;
+ if (length < 1) {
+ return -1;
}
-
- if (PyUnicode_Check(obj)) {
- obj = tmp = PyUnicode_AsASCIIString(obj);
- if (obj == NULL) {
- return NPY_FAIL;
- }
+ else if (str[0] == NPY_BIG || str[0] == NPY_LITTLE
+ || str[0] == NPY_NATIVE || str[0] == NPY_IGNORE) {
+ *endian = str[0];
+ return 0;
+ }
+ else if (str[0] == 'b' || str[0] == 'B') {
+ *endian = NPY_BIG;
+ return 0;
+ }
+ else if (str[0] == 'l' || str[0] == 'L') {
+ *endian = NPY_LITTLE;
+ return 0;
+ }
+ else if (str[0] == 'n' || str[0] == 'N') {
+ *endian = NPY_NATIVE;
+ return 0;
+ }
+ else if (str[0] == 'i' || str[0] == 'I') {
+ *endian = NPY_IGNORE;
+ return 0;
+ }
+ else if (str[0] == 's' || str[0] == 'S') {
+ *endian = NPY_SWAP;
+ return 0;
+ }
+ else {
+ return -1;
}
+}
- *sortkind = NPY_QUICKSORT;
+/*NUMPY_API
+ * Convert object to endian
+ */
+NPY_NO_EXPORT int
+PyArray_ByteorderConverter(PyObject *obj, char *endian)
+{
+ return string_converter_helper(
+ obj, (void *)endian, byteorder_parser, "byteorder", "not recognized");
+}
- str = PyBytes_AsString(obj);
- if (!str) {
- Py_XDECREF(tmp);
- return NPY_FAIL;
- }
- if (strlen(str) < 1) {
- PyErr_SetString(PyExc_ValueError,
- "Sort kind string must be at least length 1");
- Py_XDECREF(tmp);
- return NPY_FAIL;
+static int sortkind_parser(char const *str, Py_ssize_t length, void *data)
+{
+ NPY_SORTKIND *sortkind = (NPY_SORTKIND *)data;
+
+ if (length < 1) {
+ return -1;
}
if (str[0] == 'q' || str[0] == 'Q') {
*sortkind = NPY_QUICKSORT;
+ return 0;
}
else if (str[0] == 'h' || str[0] == 'H') {
*sortkind = NPY_HEAPSORT;
+ return 0;
}
else if (str[0] == 'm' || str[0] == 'M') {
/*
@@ -424,6 +436,7 @@ PyArray_SortkindConverter(PyObject *obj, NPY_SORTKIND *sortkind)
* allowing other types of stable sorts to be used.
*/
*sortkind = NPY_MERGESORT;
+ return 0;
}
else if (str[0] == 's' || str[0] == 'S') {
/*
@@ -435,16 +448,39 @@ PyArray_SortkindConverter(PyObject *obj, NPY_SORTKIND *sortkind)
* Which one is used depends on the data type.
*/
*sortkind = NPY_STABLESORT;
+ return 0;
}
else {
- PyErr_Format(PyExc_ValueError,
- "%s is an unrecognized kind of sort",
- str);
- Py_XDECREF(tmp);
- return NPY_FAIL;
+ return -1;
+ }
+}
+
+/*NUMPY_API
+ * Convert object to sort kind
+ */
+NPY_NO_EXPORT int
+PyArray_SortkindConverter(PyObject *obj, NPY_SORTKIND *sortkind)
+{
+ /* Leave the desired default from the caller for Py_None */
+ if (obj == Py_None) {
+ return NPY_SUCCEED;
+ }
+ return string_converter_helper(
+ obj, (void *)sortkind, sortkind_parser, "sort kind",
+ "must be one of 'quick', 'heap', or 'stable'");
+}
+
+static int selectkind_parser(char const *str, Py_ssize_t length, void *data)
+{
+ NPY_SELECTKIND *selectkind = (NPY_SELECTKIND *)data;
+
+ if (length == 11 && strcmp(str, "introselect") == 0) {
+ *selectkind = NPY_INTROSELECT;
+ return 0;
+ }
+ else {
+ return -1;
}
- Py_XDECREF(tmp);
- return NPY_SUCCEED;
}
/*NUMPY_API
@@ -453,40 +489,29 @@ PyArray_SortkindConverter(PyObject *obj, NPY_SORTKIND *sortkind)
NPY_NO_EXPORT int
PyArray_SelectkindConverter(PyObject *obj, NPY_SELECTKIND *selectkind)
{
- char *str;
- PyObject *tmp = NULL;
+ return string_converter_helper(
+ obj, (void *)selectkind, selectkind_parser, "select kind",
+ "must be 'introselect'");
+}
- if (PyUnicode_Check(obj)) {
- obj = tmp = PyUnicode_AsASCIIString(obj);
- if (obj == NULL) {
- return NPY_FAIL;
- }
- }
+static int searchside_parser(char const *str, Py_ssize_t length, void *data)
+{
+ NPY_SEARCHSIDE *side = (NPY_SEARCHSIDE *)data;
- *selectkind = NPY_INTROSELECT;
- str = PyBytes_AsString(obj);
- if (!str) {
- Py_XDECREF(tmp);
- return NPY_FAIL;
+ if (length < 1) {
+ return -1;
}
- if (strlen(str) < 1) {
- PyErr_SetString(PyExc_ValueError,
- "Select kind string must be at least length 1");
- Py_XDECREF(tmp);
- return NPY_FAIL;
+ else if (str[0] == 'l' || str[0] == 'L') {
+ *side = NPY_SEARCHLEFT;
+ return 0;
}
- if (strcmp(str, "introselect") == 0) {
- *selectkind = NPY_INTROSELECT;
+ else if (str[0] == 'r' || str[0] == 'R') {
+ *side = NPY_SEARCHRIGHT;
+ return 0;
}
else {
- PyErr_Format(PyExc_ValueError,
- "%s is an unrecognized kind of select",
- str);
- Py_XDECREF(tmp);
- return NPY_FAIL;
+ return -1;
}
- Py_XDECREF(tmp);
- return NPY_SUCCEED;
}
/*NUMPY_API
@@ -495,36 +520,36 @@ PyArray_SelectkindConverter(PyObject *obj, NPY_SELECTKIND *selectkind)
NPY_NO_EXPORT int
PyArray_SearchsideConverter(PyObject *obj, void *addr)
{
- NPY_SEARCHSIDE *side = (NPY_SEARCHSIDE *)addr;
- char *str;
- PyObject *tmp = NULL;
+ return string_converter_helper(
+ obj, addr, searchside_parser, "search side",
+ "must be 'left' or 'right'");
+}
- if (PyUnicode_Check(obj)) {
- obj = tmp = PyUnicode_AsASCIIString(obj);
+static int order_parser(char const *str, Py_ssize_t length, void *data)
+{
+ NPY_ORDER *val = (NPY_ORDER *)data;
+ if (length != 1) {
+ return -1;
}
-
- str = PyBytes_AsString(obj);
- if (!str || strlen(str) < 1) {
- PyErr_SetString(PyExc_ValueError,
- "expected nonempty string for keyword 'side'");
- Py_XDECREF(tmp);
- return NPY_FAIL;
+ if (str[0] == 'C' || str[0] == 'c') {
+ *val = NPY_CORDER;
+ return 0;
}
-
- if (str[0] == 'l' || str[0] == 'L') {
- *side = NPY_SEARCHLEFT;
+ else if (str[0] == 'F' || str[0] == 'f') {
+ *val = NPY_FORTRANORDER;
+ return 0;
}
- else if (str[0] == 'r' || str[0] == 'R') {
- *side = NPY_SEARCHRIGHT;
+ else if (str[0] == 'A' || str[0] == 'a') {
+ *val = NPY_ANYORDER;
+ return 0;
+ }
+ else if (str[0] == 'K' || str[0] == 'k') {
+ *val = NPY_KEEPORDER;
+ return 0;
}
else {
- PyErr_Format(PyExc_ValueError,
- "'%s' is an invalid value for keyword 'side'", str);
- Py_XDECREF(tmp);
- return NPY_FAIL;
+ return -1;
}
- Py_XDECREF(tmp);
- return NPY_SUCCEED;
}
/*NUMPY_API
@@ -533,59 +558,36 @@ PyArray_SearchsideConverter(PyObject *obj, void *addr)
NPY_NO_EXPORT int
PyArray_OrderConverter(PyObject *object, NPY_ORDER *val)
{
- char *str;
- /* Leave the desired default from the caller for NULL/Py_None */
- if (object == NULL || object == Py_None) {
+ /* Leave the desired default from the caller for Py_None */
+ if (object == Py_None) {
return NPY_SUCCEED;
}
- else if (PyUnicode_Check(object)) {
- PyObject *tmp;
- int ret;
- tmp = PyUnicode_AsASCIIString(object);
- if (tmp == NULL) {
- PyErr_SetString(PyExc_ValueError,
- "Invalid unicode string passed in for the array ordering. "
- "Please pass in 'C', 'F', 'A' or 'K' instead");
- return NPY_FAIL;
- }
- ret = PyArray_OrderConverter(tmp, val);
- Py_DECREF(tmp);
- return ret;
- }
- else if (!PyBytes_Check(object) || PyBytes_GET_SIZE(object) < 1) {
- PyErr_SetString(PyExc_ValueError,
- "Non-string object detected for the array ordering. "
- "Please pass in 'C', 'F', 'A', or 'K' instead");
- return NPY_FAIL;
+ return string_converter_helper(
+ object, (void *)val, order_parser, "order",
+ "must be one of 'C', 'F', 'A', or 'K'");
+}
+
+static int clipmode_parser(char const *str, Py_ssize_t length, void *data)
+{
+ NPY_CLIPMODE *val = (NPY_CLIPMODE *)data;
+ if (length < 1) {
+ return -1;
+ }
+ if (str[0] == 'C' || str[0] == 'c') {
+ *val = NPY_CLIP;
+ return 0;
+ }
+ else if (str[0] == 'W' || str[0] == 'w') {
+ *val = NPY_WRAP;
+ return 0;
+ }
+ else if (str[0] == 'R' || str[0] == 'r') {
+ *val = NPY_RAISE;
+ return 0;
}
else {
- str = PyBytes_AS_STRING(object);
- if (strlen(str) != 1) {
- PyErr_SetString(PyExc_ValueError,
- "Non-string object detected for the array ordering. "
- "Please pass in 'C', 'F', 'A', or 'K' instead");
- return NPY_FAIL;
- }
-
- if (str[0] == 'C' || str[0] == 'c') {
- *val = NPY_CORDER;
- }
- else if (str[0] == 'F' || str[0] == 'f') {
- *val = NPY_FORTRANORDER;
- }
- else if (str[0] == 'A' || str[0] == 'a') {
- *val = NPY_ANYORDER;
- }
- else if (str[0] == 'K' || str[0] == 'k') {
- *val = NPY_KEEPORDER;
- }
- else {
- PyErr_SetString(PyExc_TypeError,
- "order not understood");
- return NPY_FAIL;
- }
+ return -1;
}
- return NPY_SUCCEED;
}
/*NUMPY_API
@@ -597,36 +599,14 @@ PyArray_ClipmodeConverter(PyObject *object, NPY_CLIPMODE *val)
if (object == NULL || object == Py_None) {
*val = NPY_RAISE;
}
- else if (PyBytes_Check(object)) {
- char *str;
- str = PyBytes_AS_STRING(object);
- if (str[0] == 'C' || str[0] == 'c') {
- *val = NPY_CLIP;
- }
- else if (str[0] == 'W' || str[0] == 'w') {
- *val = NPY_WRAP;
- }
- else if (str[0] == 'R' || str[0] == 'r') {
- *val = NPY_RAISE;
- }
- else {
- PyErr_SetString(PyExc_TypeError,
- "clipmode not understood");
- return NPY_FAIL;
- }
- }
- else if (PyUnicode_Check(object)) {
- PyObject *tmp;
- int ret;
- tmp = PyUnicode_AsASCIIString(object);
- if (tmp == NULL) {
- return NPY_FAIL;
- }
- ret = PyArray_ClipmodeConverter(tmp, val);
- Py_DECREF(tmp);
- return ret;
+
+ else if (PyBytes_Check(object) || PyUnicode_Check(object)) {
+ return string_converter_helper(
+ object, (void *)val, clipmode_parser, "clipmode",
+ "must be one of 'clip', 'raise', or 'wrap'");
}
else {
+ /* For users passing `np.RAISE`, `np.WRAP`, `np.CLIP` */
int number = PyArray_PyIntAsInt(object);
if (error_converting(number)) {
goto fail;
@@ -636,7 +616,8 @@ PyArray_ClipmodeConverter(PyObject *object, NPY_CLIPMODE *val)
*val = (NPY_CLIPMODE) number;
}
else {
- goto fail;
+ PyErr_Format(PyExc_ValueError,
+ "integer clipmode must be np.RAISE, np.WRAP, or np.CLIP");
}
}
return NPY_SUCCEED;
@@ -690,66 +671,56 @@ PyArray_ConvertClipmodeSequence(PyObject *object, NPY_CLIPMODE *modes, int n)
return NPY_SUCCEED;
}
+static int casting_parser(char const *str, Py_ssize_t length, void *data)
+{
+ NPY_CASTING *casting = (NPY_CASTING *)data;
+ if (length < 2) {
+ return -1;
+ }
+ switch (str[2]) {
+ case 0:
+ if (length == 2 && strcmp(str, "no") == 0) {
+ *casting = NPY_NO_CASTING;
+ return 0;
+ }
+ break;
+ case 'u':
+ if (length == 5 && strcmp(str, "equiv") == 0) {
+ *casting = NPY_EQUIV_CASTING;
+ return 0;
+ }
+ break;
+ case 'f':
+ if (length == 4 && strcmp(str, "safe") == 0) {
+ *casting = NPY_SAFE_CASTING;
+ return 0;
+ }
+ break;
+ case 'm':
+ if (length == 9 && strcmp(str, "same_kind") == 0) {
+ *casting = NPY_SAME_KIND_CASTING;
+ return 0;
+ }
+ break;
+ case 's':
+ if (length == 6 && strcmp(str, "unsafe") == 0) {
+ *casting = NPY_UNSAFE_CASTING;
+ return 0;
+ }
+ break;
+ }
+ return -1;
+}
+
/*NUMPY_API
* Convert any Python object, *obj*, to an NPY_CASTING enum.
*/
NPY_NO_EXPORT int
PyArray_CastingConverter(PyObject *obj, NPY_CASTING *casting)
{
- char *str = NULL;
- Py_ssize_t length = 0;
-
- if (PyUnicode_Check(obj)) {
- PyObject *str_obj;
- int ret;
- str_obj = PyUnicode_AsASCIIString(obj);
- if (str_obj == NULL) {
- return 0;
- }
- ret = PyArray_CastingConverter(str_obj, casting);
- Py_DECREF(str_obj);
- return ret;
- }
-
- if (PyBytes_AsStringAndSize(obj, &str, &length) < 0) {
- return 0;
- }
-
- if (length >= 2) switch (str[2]) {
- case 0:
- if (strcmp(str, "no") == 0) {
- *casting = NPY_NO_CASTING;
- return 1;
- }
- break;
- case 'u':
- if (strcmp(str, "equiv") == 0) {
- *casting = NPY_EQUIV_CASTING;
- return 1;
- }
- break;
- case 'f':
- if (strcmp(str, "safe") == 0) {
- *casting = NPY_SAFE_CASTING;
- return 1;
- }
- break;
- case 'm':
- if (strcmp(str, "same_kind") == 0) {
- *casting = NPY_SAME_KIND_CASTING;
- return 1;
- }
- break;
- case 's':
- if (strcmp(str, "unsafe") == 0) {
- *casting = NPY_UNSAFE_CASTING;
- return 1;
- }
- break;
- }
-
- PyErr_SetString(PyExc_ValueError,
- "casting must be one of 'no', 'equiv', 'safe', "
+ return string_converter_helper(
+ obj, (void *)casting, casting_parser, "casting",
+ "must be one of 'no', 'equiv', 'safe', "
"'same_kind', or 'unsafe'");
return 0;
}
diff --git a/numpy/core/tests/test_conversion_utils.py b/numpy/core/tests/test_conversion_utils.py
index 9e80fdcbf..3c3f9e6e1 100644
--- a/numpy/core/tests/test_conversion_utils.py
+++ b/numpy/core/tests/test_conversion_utils.py
@@ -1,6 +1,8 @@
"""
Tests for numpy/core/src/multiarray/conversion_utils.c
"""
+import re
+
import pytest
import numpy as np
@@ -12,6 +14,11 @@ class StringConverterTestCase:
case_insensitive = True
exact_match = False
+ def _check_value_error(self, val):
+ pattern = r'\(got {}\)'.format(re.escape(repr(val)))
+ with pytest.raises(ValueError, match=pattern) as exc:
+ self.conv(val)
+
def _check(self, val, expected):
assert self.conv(val) == expected
@@ -23,8 +30,8 @@ class StringConverterTestCase:
if len(val) != 1:
if self.exact_match:
- with pytest.raises(ValueError):
- self.conv(val[:1])
+ self._check_value_error(val[:1])
+ self._check_value_error(val + '\0')
else:
assert self.conv(val[:1]) == expected
@@ -35,11 +42,28 @@ class StringConverterTestCase:
assert self.conv(val.upper()) == expected
else:
if val != val.lower():
- with pytest.raises(ValueError):
- self.conv(val.lower())
+ self._check_value_error(val.lower())
if val != val.upper():
- with pytest.raises(ValueError):
- self.conv(val.upper())
+ self._check_value_error(val.upper())
+
+ def test_wrong_type(self):
+ # common cases which apply to all the below
+ with pytest.raises(TypeError):
+ self.conv({})
+ with pytest.raises(TypeError):
+ self.conv([])
+
+ def test_wrong_value(self):
+ # nonsense strings
+ self._check_value_error('')
+ self._check_value_error('\N{greek small letter pi}')
+
+ if self.allow_bytes:
+ self._check_value_error(b'')
+ # bytes which can't be converted to strings via utf8
+ self._check_value_error(b"\xFF")
+ if self.exact_match:
+ self._check_value_error("there's no way this is supported")
class TestByteorderConverter(StringConverterTestCase):
@@ -95,6 +119,14 @@ class TestOrderConverter(StringConverterTestCase):
self._check('a', 'NPY_ANYORDER')
self._check('k', 'NPY_KEEPORDER')
+ def test_flatten_invalid_order(self):
+ # invalid after gh-14596
+ with pytest.raises(ValueError):
+ self.conv('Z')
+ for order in [False, True, 0, 8]:
+ with pytest.raises(TypeError):
+ self.conv(order)
+
class TestClipmodeConverter(StringConverterTestCase):
""" Tests of PyArray_ClipmodeConverter """
diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py
index fd0de8732..68491681a 100644
--- a/numpy/core/tests/test_einsum.py
+++ b/numpy/core/tests/test_einsum.py
@@ -27,7 +27,7 @@ class TestEinsum:
optimize=do_opt)
# order parameter must be a valid order
- assert_raises(TypeError, np.einsum, "", 0, order='W',
+ assert_raises(ValueError, np.einsum, "", 0, order='W',
optimize=do_opt)
# casting parameter must be a valid casting
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 4e2d2ad41..f36c27c6c 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -2758,11 +2758,6 @@ class TestMethods:
assert_equal(x1.flatten('F'), y1f)
assert_equal(x1.flatten('F'), x1.T.flatten())
- def test_flatten_invalid_order(self):
- # invalid after gh-14596
- for order in ['Z', 'c', False, True, 0, 8]:
- x = np.array([[1, 2, 3], [4, 5, 6]], np.int32)
- assert_raises(ValueError, x.flatten, {"order": order})
@pytest.mark.parametrize('func', (np.dot, np.matmul))
def test_arr_mult(self, func):