summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorDaniel da Silva <ddasilva@users.noreply.github.com>2022-11-21 17:17:52 -0500
committerGitHub <noreply@github.com>2022-11-21 15:17:52 -0700
commit04d0e2155704ad939980a4eafefe3d817076fa39 (patch)
treedf9054cd895630afdaf4b179ea1375a578e622de /numpy
parent5d32b8d63ab9a376c095f142b86bb55cdfbb95ff (diff)
downloadnumpy-04d0e2155704ad939980a4eafefe3d817076fa39.tar.gz
ENH: raise TypeError when arange() is called with string dtype (#22087)
* ENH: raise TypeError when arange() is called with string dtype * Add release note for dtype=str change to arange() * DOC: Minor wording/formatting touchups to release note. * Update numpy/core/tests/test_multiarray.py Co-authored-by: Ross Barnowski <rossbar@berkeley.edu> * Move check to PyArray_ArangeObj * remove old code * BUG,MAINT: Clean out arange string error and other paths * BUGS: Fixup and cleanup arange code a bit * DOC: Update release note to new message * BUG: Fix refcounting and simplify arange * MAINT: Use SETREF to make arange dtype discovery more compact * MAINT: Update numpy/core/src/multiarray/ctors.c Co-authored-by: Ross Barnowski <rossbar@berkeley.edu> Co-authored-by: Sebastian Berg <sebastianb@nvidia.com> Co-authored-by: Charles Harris <charlesr.harris@gmail.com>
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/arraytypes.c.src13
-rw-r--r--numpy/core/src/multiarray/ctors.c132
-rw-r--r--numpy/core/tests/test_multiarray.py56
3 files changed, 140 insertions, 61 deletions
diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src
index 34694aac6..c03d09784 100644
--- a/numpy/core/src/multiarray/arraytypes.c.src
+++ b/numpy/core/src/multiarray/arraytypes.c.src
@@ -3922,7 +3922,18 @@ OBJECT_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp
*/
-#define BOOL_fill NULL
+/* Boolean fill never works, but define it so that it works up to length 2 */
+static int
+BOOL_fill(PyObject **buffer, npy_intp length, void *NPY_UNUSED(ignored))
+{
+ NPY_ALLOW_C_API_DEF;
+ NPY_ALLOW_C_API;
+ PyErr_SetString(PyExc_TypeError,
+ "arange() is only supported for booleans when the result has at "
+ "most length 2.");
+ NPY_DISABLE_C_API;
+ return -1;
+}
/* this requires buffer to be filled with objects or NULL */
static int
diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c
index 05575cad7..fc3942e91 100644
--- a/numpy/core/src/multiarray/ctors.c
+++ b/numpy/core/src/multiarray/ctors.c
@@ -3244,9 +3244,9 @@ PyArray_ArangeObj(PyObject *start, PyObject *stop, PyObject *step, PyArray_Descr
{
PyArrayObject *range;
PyArray_ArrFuncs *funcs;
- PyObject *next, *err;
- npy_intp length;
+ PyObject *next = NULL;
PyArray_Descr *native = NULL;
+ npy_intp length;
int swap;
NPY_BEGIN_THREADS_DEF;
@@ -3259,94 +3259,100 @@ PyArray_ArangeObj(PyObject *start, PyObject *stop, PyObject *step, PyArray_Descr
return (PyObject *)datetime_arange(start, stop, step, dtype);
}
- if (!dtype) {
- PyArray_Descr *deftype;
- PyArray_Descr *newtype;
+ /* We need to replace many of these, so hold on for easier cleanup */
+ Py_XINCREF(start);
+ Py_XINCREF(stop);
+ Py_XINCREF(step);
+ Py_XINCREF(dtype);
+ if (!dtype) {
/* intentionally made to be at least NPY_LONG */
- deftype = PyArray_DescrFromType(NPY_LONG);
- newtype = PyArray_DescrFromObject(start, deftype);
- Py_DECREF(deftype);
- if (newtype == NULL) {
- return NULL;
+ dtype = PyArray_DescrFromType(NPY_LONG);
+ Py_SETREF(dtype, PyArray_DescrFromObject(start, dtype));
+ if (dtype == NULL) {
+ goto fail;
}
- deftype = newtype;
if (stop && stop != Py_None) {
- newtype = PyArray_DescrFromObject(stop, deftype);
- Py_DECREF(deftype);
- if (newtype == NULL) {
- return NULL;
+ Py_SETREF(dtype, PyArray_DescrFromObject(stop, dtype));
+ if (dtype == NULL) {
+ goto fail;
}
- deftype = newtype;
}
if (step && step != Py_None) {
- newtype = PyArray_DescrFromObject(step, deftype);
- Py_DECREF(deftype);
- if (newtype == NULL) {
- return NULL;
+ Py_SETREF(dtype, PyArray_DescrFromObject(step, dtype));
+ if (dtype == NULL) {
+ goto fail;
}
- deftype = newtype;
}
- dtype = deftype;
+ }
+
+ /*
+ * If dtype is not in native byte-order then get native-byte
+ * order version. And then swap on the way out.
+ */
+ if (!PyArray_ISNBO(dtype->byteorder)) {
+ native = PyArray_DescrNewByteorder(dtype, NPY_NATBYTE);
+ if (native == NULL) {
+ goto fail;
+ }
+ swap = 1;
}
else {
Py_INCREF(dtype);
+ native = dtype;
+ swap = 0;
}
- if (!step || step == Py_None) {
- step = PyLong_FromLong(1);
+
+ funcs = native->f;
+ if (!funcs->fill) {
+ /* This effectively forbids subarray types as well... */
+ PyErr_Format(PyExc_TypeError,
+ "arange() not supported for inputs with DType %S.",
+ Py_TYPE(dtype));
+ goto fail;
}
- else {
- Py_XINCREF(step);
+
+ if (!step || step == Py_None) {
+ Py_XSETREF(step, PyLong_FromLong(1));
+ if (step == NULL) {
+ goto fail;
+ }
}
if (!stop || stop == Py_None) {
- stop = start;
+ Py_XSETREF(stop, start);
start = PyLong_FromLong(0);
+ if (start == NULL) {
+ goto fail;
+ }
}
- else {
- Py_INCREF(start);
- }
+
/* calculate the length and next = start + step*/
length = _calc_length(start, stop, step, &next,
PyTypeNum_ISCOMPLEX(dtype->type_num));
- err = PyErr_Occurred();
+ PyObject *err = PyErr_Occurred();
if (err) {
- Py_DECREF(dtype);
- if (err && PyErr_GivenExceptionMatches(err, PyExc_OverflowError)) {
+ if (PyErr_GivenExceptionMatches(err, PyExc_OverflowError)) {
PyErr_SetString(PyExc_ValueError, "Maximum allowed size exceeded");
}
goto fail;
}
if (length <= 0) {
length = 0;
- range = (PyArrayObject *)PyArray_SimpleNewFromDescr(1, &length, dtype);
- Py_DECREF(step);
- Py_DECREF(start);
- return (PyObject *)range;
- }
-
- /*
- * If dtype is not in native byte-order then get native-byte
- * order version. And then swap on the way out.
- */
- if (!PyArray_ISNBO(dtype->byteorder)) {
- native = PyArray_DescrNewByteorder(dtype, NPY_NATBYTE);
- swap = 1;
- }
- else {
- native = dtype;
- swap = 0;
}
+ Py_INCREF(native);
range = (PyArrayObject *)PyArray_SimpleNewFromDescr(1, &length, native);
if (range == NULL) {
goto fail;
}
+ if (length == 0) {
+ goto finish;
+ }
/*
* place start in the buffer and the next value in the second position
* if length > 2, then call the inner loop, otherwise stop
*/
- funcs = PyArray_DESCR(range)->f;
if (funcs->setitem(start, PyArray_DATA(range), range) < 0) {
goto fail;
}
@@ -3360,11 +3366,7 @@ PyArray_ArangeObj(PyObject *start, PyObject *stop, PyObject *step, PyArray_Descr
if (length == 2) {
goto finish;
}
- if (!funcs->fill) {
- PyErr_SetString(PyExc_ValueError, "no fill-function for data-type.");
- Py_DECREF(range);
- goto fail;
- }
+
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(range));
funcs->fill(PyArray_DATA(range), length, range);
NPY_END_THREADS;
@@ -3376,19 +3378,29 @@ PyArray_ArangeObj(PyObject *start, PyObject *stop, PyObject *step, PyArray_Descr
if (swap) {
PyObject *new;
new = PyArray_Byteswap(range, 1);
+ if (new == NULL) {
+ goto fail;
+ }
Py_DECREF(new);
+ /* Replace dtype after swapping in-place above: */
Py_DECREF(PyArray_DESCR(range));
- /* steals the reference */
+ Py_INCREF(dtype);
((PyArrayObject_fields *)range)->descr = dtype;
}
+ Py_DECREF(dtype);
+ Py_DECREF(native);
Py_DECREF(start);
+ Py_DECREF(stop);
Py_DECREF(step);
- Py_DECREF(next);
+ Py_XDECREF(next);
return (PyObject *)range;
fail:
- Py_DECREF(start);
- Py_DECREF(step);
+ Py_XDECREF(dtype);
+ Py_XDECREF(native);
+ Py_XDECREF(start);
+ Py_XDECREF(stop);
+ Py_XDECREF(step);
Py_XDECREF(next);
return NULL;
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 15619bcb3..ad1d9bb04 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -9,6 +9,7 @@ import functools
import ctypes
import os
import gc
+import re
import weakref
import pytest
from contextlib import contextmanager
@@ -9324,6 +9325,61 @@ class TestArange:
assert len(keyword_start_stop) == 6
assert_array_equal(keyword_stop, keyword_zerotostop)
+ def test_arange_booleans(self):
+ # Arange makes some sense for booleans and works up to length 2.
+ # But it is weird since `arange(2, 4, dtype=bool)` works.
+ # Arguably, much or all of this could be deprecated/removed.
+ res = np.arange(False, dtype=bool)
+ assert_array_equal(res, np.array([], dtype="bool"))
+
+ res = np.arange(True, dtype="bool")
+ assert_array_equal(res, [False])
+
+ res = np.arange(2, dtype="bool")
+ assert_array_equal(res, [False, True])
+
+ # This case is especially weird, but drops out without special case:
+ res = np.arange(6, 8, dtype="bool")
+ assert_array_equal(res, [True, True])
+
+ with pytest.raises(TypeError):
+ np.arange(3, dtype="bool")
+
+ @pytest.mark.parametrize("dtype", ["S3", "U", "5i"])
+ def test_rejects_bad_dtypes(self, dtype):
+ dtype = np.dtype(dtype)
+ DType_name = re.escape(str(type(dtype)))
+ with pytest.raises(TypeError,
+ match=rf"arange\(\) not supported for inputs .* {DType_name}"):
+ np.arange(2, dtype=dtype)
+
+ def test_rejects_strings(self):
+ # Explicitly test error for strings which may call "b" - "a":
+ DType_name = re.escape(str(type(np.array("a").dtype)))
+ with pytest.raises(TypeError,
+ match=rf"arange\(\) not supported for inputs .* {DType_name}"):
+ np.arange("a", "b")
+
+ def test_byteswapped(self):
+ res_be = np.arange(1, 1000, dtype=">i4")
+ res_le = np.arange(1, 1000, dtype="<i4")
+ assert res_be.dtype == ">i4"
+ assert res_le.dtype == "<i4"
+ assert_array_equal(res_le, res_be)
+
+ @pytest.mark.parametrize("which", [0, 1, 2])
+ def test_error_paths_and_promotion(self, which):
+ args = [0, 1, 2] # start, stop, and step
+ args[which] = np.float64(2.) # should ensure float64 output
+
+ assert np.arange(*args).dtype == np.float64
+
+ # Cover stranger error path, test only to achieve code coverage!
+ args[which] = [None, []]
+ with pytest.raises(ValueError):
+ # Fails discovering start dtype
+ np.arange(*args)
+
class TestArrayFinalize:
""" Tests __array_finalize__ """