-rw-r--r--  .gitignore | 1
-rw-r--r--  benchmarks/benchmarks/bench_ma.py | 146
-rw-r--r--  doc/release/upcoming_changes/22707.compatibility.rst | 4
-rw-r--r--  doc/release/upcoming_changes/22707.expired.rst | 13
-rw-r--r--  doc/release/upcoming_changes/22707.improvement.rst | 8
-rw-r--r--  numpy/__init__.py | 4
-rw-r--r--  numpy/core/_exceptions.py | 27
-rw-r--r--  numpy/core/_internal.py | 7
-rw-r--r--  numpy/core/code_generators/generate_umath.py | 2
-rw-r--r--  numpy/core/meson.build | 1
-rw-r--r--  numpy/core/setup.py | 1
-rw-r--r--  numpy/core/src/multiarray/arrayobject.c | 218
-rw-r--r--  numpy/core/src/multiarray/common_dtype.c | 5
-rw-r--r--  numpy/core/src/multiarray/convert_datatype.c | 3
-rw-r--r--  numpy/core/src/multiarray/convert_datatype.h | 2
-rw-r--r--  numpy/core/src/multiarray/datetime.c | 6
-rw-r--r--  numpy/core/src/multiarray/dtypemeta.c | 12
-rw-r--r--  numpy/core/src/multiarray/experimental_public_dtype_api.c | 8
-rw-r--r--  numpy/core/src/multiarray/multiarraymodule.c | 44
-rw-r--r--  numpy/core/src/umath/dispatching.c | 31
-rw-r--r--  numpy/core/src/umath/loops.c.src | 19
-rw-r--r--  numpy/core/src/umath/loops.h.src | 37
-rw-r--r--  numpy/core/src/umath/loops_unary.dispatch.c.src | 364
-rw-r--r--  numpy/core/src/umath/simd.inc.src | 67
-rw-r--r--  numpy/core/src/umath/ufunc_type_resolution.c | 19
-rw-r--r--  numpy/core/tests/test_deprecations.py | 74
-rw-r--r--  numpy/core/tests/test_multiarray.py | 117
-rw-r--r--  numpy/core/tests/test_regression.py | 5
-rw-r--r--  numpy/core/tests/test_umath.py | 32
-rw-r--r--  numpy/core/tests/test_unicode.py | 8
-rw-r--r--  numpy/exceptions.py | 52
-rw-r--r--  numpy/lib/function_base.py | 51
-rw-r--r--  numpy/lib/nanfunctions.py | 14
-rw-r--r--  numpy/lib/tests/test_function_base.py | 50
-rw-r--r--  numpy/lib/tests/test_nanfunctions.py | 58
-rw-r--r--  numpy/linalg/tests/test_regression.py | 5
-rw-r--r--  numpy/ma/bench.py | 130
-rw-r--r--  numpy/ma/extras.py | 6
-rw-r--r--  numpy/ma/tests/test_extras.py | 29
-rw-r--r--  numpy/tests/test_public_api.py | 2
40 files changed, 1169 insertions(+), 513 deletions(-)
diff --git a/.gitignore b/.gitignore
index 6f63498e0..9851fcc77 100644
--- a/.gitignore
+++ b/.gitignore
@@ -216,6 +216,7 @@ numpy/core/src/_simd/_simd.dispatch.c
numpy/core/src/_simd/_simd_data.inc
numpy/core/src/_simd/_simd_inc.h
# umath module
+numpy/core/src/umath/loops_unary.dispatch.c
numpy/core/src/umath/loops_unary_fp.dispatch.c
numpy/core/src/umath/loops_arithm_fp.dispatch.c
numpy/core/src/umath/loops_arithmetic.dispatch.c
diff --git a/benchmarks/benchmarks/bench_ma.py b/benchmarks/benchmarks/bench_ma.py
index 0247065a7..49ccf92fe 100644
--- a/benchmarks/benchmarks/bench_ma.py
+++ b/benchmarks/benchmarks/bench_ma.py
@@ -119,3 +119,149 @@ class Concatenate(Benchmark):
def time_it(self, mode, n):
np.ma.concatenate(self.args)
+
+
+class MAFunctions1v(Benchmark):
+ param_names = ['mtype', 'func', 'msize']
+ params = [['np', 'np.ma'],
+ ['sin', 'log', 'sqrt'],
+ ['small', 'big']]
+
+ def setup(self, mtype, func, msize):
+ xs = np.random.uniform(-1, 1, 6).reshape(2, 3)
+ m1 = [[True, False, False], [False, False, True]]
+ xl = np.random.uniform(-1, 1, 100*100).reshape(100, 100)
+ maskx = xl > 0.8
+ self.nmxs = np.ma.array(xs, mask=m1)
+ self.nmxl = np.ma.array(xl, mask=maskx)
+
+ def time_functions_1v(self, mtype, func, msize):
+ # fun = {'np.ma.sin': np.ma.sin, 'np.sin': np.sin}[func]
+ fun = eval(f"{mtype}.{func}")
+ if msize == 'small':
+ fun(self.nmxs)
+ elif msize == 'big':
+ fun(self.nmxl)
+
+
+class MAMethod0v(Benchmark):
+ param_names = ['method', 'msize']
+ params = [['ravel', 'transpose', 'compressed', 'conjugate'],
+ ['small', 'big']]
+
+ def setup(self, method, msize):
+ xs = np.random.uniform(-1, 1, 6).reshape(2, 3)
+ m1 = [[True, False, False], [False, False, True]]
+ xl = np.random.uniform(-1, 1, 100*100).reshape(100, 100)
+ maskx = xl > 0.8
+ self.nmxs = np.ma.array(xs, mask=m1)
+ self.nmxl = np.ma.array(xl, mask=maskx)
+
+ def time_methods_0v(self, method, msize):
+ if msize == 'small':
+ mdat = self.nmxs
+ elif msize == 'big':
+ mdat = self.nmxl
+ getattr(mdat, method)()
+
+
+class MAFunctions2v(Benchmark):
+ param_names = ['mtype', 'func', 'msize']
+ params = [['np', 'np.ma'],
+ ['multiply', 'divide', 'power'],
+ ['small', 'big']]
+
+ def setup(self, mtype, func, msize):
+ # Small arrays
+ xs = np.random.uniform(-1, 1, 6).reshape(2, 3)
+ ys = np.random.uniform(-1, 1, 6).reshape(2, 3)
+ m1 = [[True, False, False], [False, False, True]]
+ m2 = [[True, False, True], [False, False, True]]
+ self.nmxs = np.ma.array(xs, mask=m1)
+ self.nmys = np.ma.array(ys, mask=m2)
+ # Big arrays
+ xl = np.random.uniform(-1, 1, 100*100).reshape(100, 100)
+ yl = np.random.uniform(-1, 1, 100*100).reshape(100, 100)
+ maskx = xl > 0.8
+ masky = yl < -0.8
+ self.nmxl = np.ma.array(xl, mask=maskx)
+ self.nmyl = np.ma.array(yl, mask=masky)
+
+ def time_functions_2v(self, mtype, func, msize):
+ fun = eval(f"{mtype}.{func}")
+ if msize == 'small':
+ fun(self.nmxs, self.nmys)
+ elif msize == 'big':
+ fun(self.nmxl, self.nmyl)
+
+
+class MAMethodGetItem(Benchmark):
+ param_names = ['margs', 'msize']
+ params = [[0, (0, 0), [0, -1]],
+ ['small', 'big']]
+
+ def setup(self, margs, msize):
+ xs = np.random.uniform(-1, 1, 6).reshape(2, 3)
+ m1 = [[True, False, False], [False, False, True]]
+ xl = np.random.uniform(-1, 1, 100*100).reshape(100, 100)
+ maskx = xl > 0.8
+ self.nmxs = np.ma.array(xs, mask=m1)
+ self.nmxl = np.ma.array(xl, mask=maskx)
+
+ def time_methods_getitem(self, margs, msize):
+ if msize == 'small':
+ mdat = self.nmxs
+ elif msize == 'big':
+ mdat = self.nmxl
+ getattr(mdat, '__getitem__')(margs)
+
+
+class MAMethodSetItem(Benchmark):
+ param_names = ['margs', 'mset', 'msize']
+ params = [[0, (0, 0), (-1, 0)],
+ [17, np.ma.masked],
+ ['small', 'big']]
+
+ def setup(self, margs, mset, msize):
+ xs = np.random.uniform(-1, 1, 6).reshape(2, 3)
+ m1 = [[True, False, False], [False, False, True]]
+ xl = np.random.uniform(-1, 1, 100*100).reshape(100, 100)
+ maskx = xl > 0.8
+ self.nmxs = np.ma.array(xs, mask=m1)
+ self.nmxl = np.ma.array(xl, mask=maskx)
+
+ def time_methods_setitem(self, margs, mset, msize):
+ if msize == 'small':
+ mdat = self.nmxs
+ elif msize == 'big':
+ mdat = self.nmxl
+ getattr(mdat, '__setitem__')(margs, mset)
+
+
+class Where(Benchmark):
+ param_names = ['mtype', 'msize']
+ params = [['np', 'np.ma'],
+ ['small', 'big']]
+
+ def setup(self, mtype, msize):
+ # Small arrays
+ xs = np.random.uniform(-1, 1, 6).reshape(2, 3)
+ ys = np.random.uniform(-1, 1, 6).reshape(2, 3)
+ m1 = [[True, False, False], [False, False, True]]
+ m2 = [[True, False, True], [False, False, True]]
+ self.nmxs = np.ma.array(xs, mask=m1)
+ self.nmys = np.ma.array(ys, mask=m2)
+ # Big arrays
+ xl = np.random.uniform(-1, 1, 100*100).reshape(100, 100)
+ yl = np.random.uniform(-1, 1, 100*100).reshape(100, 100)
+ maskx = xl > 0.8
+ masky = yl < -0.8
+ self.nmxl = np.ma.array(xl, mask=maskx)
+ self.nmyl = np.ma.array(yl, mask=masky)
+
+ def time_where(self, mtype, msize):
+ fun = eval(f"{mtype}.where")
+ if msize == 'small':
+ fun(self.nmxs > 2, self.nmxs, self.nmys)
+ elif msize == 'big':
+ fun(self.nmxl > 2, self.nmxl, self.nmyl)
diff --git a/doc/release/upcoming_changes/22707.compatibility.rst b/doc/release/upcoming_changes/22707.compatibility.rst
new file mode 100644
index 000000000..8c9805f37
--- /dev/null
+++ b/doc/release/upcoming_changes/22707.compatibility.rst
@@ -0,0 +1,4 @@
+* When comparing datetimes and timedeltas using ``np.equal`` or ``np.not_equal``,
+ numpy previously allowed the comparison with ``casting="unsafe"``.
+ This operation now fails. Forcing the output dtype using the ``dtype``
+ kwarg can make the operation succeed, but we do not recommend it.
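A minimal sketch of the behavior change described in this note (array values are
illustrative only; the failure is a ``TypeError`` subclass)::

    import numpy as np

    dt = np.array(["2000-01-01T00:00:00"], dtype="M8[s]")
    td = np.array([5], dtype="m8[s]")

    try:
        np.equal(dt, td, casting="unsafe")   # previously allowed, now fails
    except TypeError as exc:
        print(type(exc).__name__)

    # Per the note above, np.equal(dt, td, dtype=bool) can still make the
    # call succeed, but it is not recommended.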
diff --git a/doc/release/upcoming_changes/22707.expired.rst b/doc/release/upcoming_changes/22707.expired.rst
new file mode 100644
index 000000000..496752e8d
--- /dev/null
+++ b/doc/release/upcoming_changes/22707.expired.rst
@@ -0,0 +1,13 @@
+``==`` and ``!=`` warnings finalized
+------------------------------------
+The ``==`` and ``!=`` operators on arrays now always:
+
+* raise errors that occur during comparisons such as when the arrays
+ have incompatible shapes (``np.array([1, 2]) == np.array([1, 2, 3])``).
+* return an array of all ``True`` or all ``False`` when values are
+ fundamentally not comparable (e.g. have different dtypes). An example
+ is ``np.array(["a"]) == np.array([1])``.
+
+This mimics the Python behavior of returning ``False`` and ``True``
+when comparing incompatible types like ``"a" == 1`` and ``"a" != 1``.
+For a long time these gave ``DeprecationWarning`` or ``FutureWarning``.
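A short sketch of the finalized behavior described above (values illustrative)::

    import numpy as np

    # Fundamentally non-comparable values give an all-False / all-True array:
    print(np.array(["a"]) == np.array([1]))   # [False]
    print(np.array(["a"]) != np.array([1]))   # [ True]

    # Errors raised during the comparison itself (e.g. broadcasting) propagate:
    try:
        np.array([1, 2]) == np.array([1, 2, 3])
    except ValueError as exc:
        print(type(exc).__name__)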
diff --git a/doc/release/upcoming_changes/22707.improvement.rst b/doc/release/upcoming_changes/22707.improvement.rst
new file mode 100644
index 000000000..1b8d4f844
--- /dev/null
+++ b/doc/release/upcoming_changes/22707.improvement.rst
@@ -0,0 +1,8 @@
+New ``DTypePromotionError``
+---------------------------
+NumPy now has a new ``DTypePromotionError`` which is used when two
+dtypes cannot be promoted to a common one, for example::
+
+ np.result_type("M8[s]", np.complex128)
+
+raises this new exception.
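A sketch of catching the new exception; since ``DTypePromotionError`` subclasses
``TypeError``, existing ``except TypeError`` handlers keep working::

    import numpy as np

    try:
        np.result_type("M8[s]", np.complex128)
    except np.exceptions.DTypePromotionError as exc:
        print("promotion failed:", exc)

    # DTypePromotionError subclasses TypeError, so pre-existing
    # ``except TypeError`` handlers continue to work unchanged.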
diff --git a/numpy/__init__.py b/numpy/__init__.py
index 1e41dc8bf..9f8e60a07 100644
--- a/numpy/__init__.py
+++ b/numpy/__init__.py
@@ -107,6 +107,8 @@ import sys
import warnings
from ._globals import _NoValue, _CopyMode
+from . import exceptions
+# Note that the following names are imported explicitly for backcompat:
from .exceptions import (
ComplexWarning, ModuleDeprecationWarning, VisibleDeprecationWarning,
TooHardError, AxisError)
@@ -130,7 +132,7 @@ else:
raise ImportError(msg) from e
__all__ = [
- 'ModuleDeprecationWarning', 'VisibleDeprecationWarning',
+ 'exceptions', 'ModuleDeprecationWarning', 'VisibleDeprecationWarning',
'ComplexWarning', 'TooHardError', 'AxisError']
# mapping of {name: (value, deprecation_msg)}
diff --git a/numpy/core/_exceptions.py b/numpy/core/_exceptions.py
index 62579ed0d..87d4213a6 100644
--- a/numpy/core/_exceptions.py
+++ b/numpy/core/_exceptions.py
@@ -36,36 +36,35 @@ class UFuncTypeError(TypeError):
@_display_as_base
-class _UFuncBinaryResolutionError(UFuncTypeError):
- """ Thrown when a binary resolution fails """
+class _UFuncNoLoopError(UFuncTypeError):
+ """ Thrown when a ufunc loop cannot be found """
def __init__(self, ufunc, dtypes):
super().__init__(ufunc)
self.dtypes = tuple(dtypes)
- assert len(self.dtypes) == 2
def __str__(self):
return (
- "ufunc {!r} cannot use operands with types {!r} and {!r}"
+ "ufunc {!r} did not contain a loop with signature matching types "
+ "{!r} -> {!r}"
).format(
- self.ufunc.__name__, *self.dtypes
+ self.ufunc.__name__,
+ _unpack_tuple(self.dtypes[:self.ufunc.nin]),
+ _unpack_tuple(self.dtypes[self.ufunc.nin:])
)
@_display_as_base
-class _UFuncNoLoopError(UFuncTypeError):
- """ Thrown when a ufunc loop cannot be found """
+class _UFuncBinaryResolutionError(_UFuncNoLoopError):
+ """ Thrown when a binary resolution fails """
def __init__(self, ufunc, dtypes):
- super().__init__(ufunc)
- self.dtypes = tuple(dtypes)
+ super().__init__(ufunc, dtypes)
+ assert len(self.dtypes) == 2
def __str__(self):
return (
- "ufunc {!r} did not contain a loop with signature matching types "
- "{!r} -> {!r}"
+ "ufunc {!r} cannot use operands with types {!r} and {!r}"
).format(
- self.ufunc.__name__,
- _unpack_tuple(self.dtypes[:self.ufunc.nin]),
- _unpack_tuple(self.dtypes[self.ufunc.nin:])
+ self.ufunc.__name__, *self.dtypes
)
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index 85076f3e1..c78385880 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -9,6 +9,7 @@ import re
import sys
import warnings
+from ..exceptions import DTypePromotionError
from .multiarray import dtype, array, ndarray, promote_types
try:
import ctypes
@@ -454,7 +455,8 @@ def _promote_fields(dt1, dt2):
"""
# Both must be structured and have the same names in the same order
if (dt1.names is None or dt2.names is None) or dt1.names != dt2.names:
- raise TypeError("invalid type promotion")
+ raise DTypePromotionError(
+ f"field names `{dt1.names}` and `{dt2.names}` mismatch.")
# if both are identical, we can (maybe!) just return the same dtype.
identical = dt1 is dt2
@@ -467,7 +469,8 @@ def _promote_fields(dt1, dt2):
# Check that the titles match (if given):
if field1[2:] != field2[2:]:
- raise TypeError("invalid type promotion")
+ raise DTypePromotionError(
+ f"field titles of field '{name}' mismatch")
if len(field1) == 2:
new_fields.append((name, new_descr))
else:
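An illustrative sketch of the Python-level effect of this hunk (dtypes chosen
arbitrarily): promoting structured dtypes with mismatching field names now raises
the dedicated exception instead of a bare ``TypeError``::

    import numpy as np

    a = np.dtype([("x", "i4")])
    b = np.dtype([("y", "i4")])
    try:
        np.promote_types(a, b)
    except np.exceptions.DTypePromotionError as exc:
        print(exc)   # e.g. field names `('x',)` and `('y',)` mismatch.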
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py
index 40382b8ae..768c8deee 100644
--- a/numpy/core/code_generators/generate_umath.py
+++ b/numpy/core/code_generators/generate_umath.py
@@ -426,7 +426,7 @@ defdict = {
Ufunc(1, 1, None,
docstrings.get('numpy.core.umath.negative'),
'PyUFunc_NegativeTypeResolver',
- TD(ints+flts+timedeltaonly, simd=[('avx2', ints)]),
+ TD(ints+flts+timedeltaonly, dispatch=[('loops_unary', ints+'fdg')]),
TD(cmplx, f='neg'),
TD(O, f='PyNumber_Negative'),
),
diff --git a/numpy/core/meson.build b/numpy/core/meson.build
index 2b23d3f14..50cd8ccc5 100644
--- a/numpy/core/meson.build
+++ b/numpy/core/meson.build
@@ -747,6 +747,7 @@ src_umath = [
src_file.process('src/umath/loops_modulo.dispatch.c.src'),
src_file.process('src/umath/loops_trigonometric.dispatch.c.src'),
src_file.process('src/umath/loops_umath_fp.dispatch.c.src'),
+ src_file.process('src/umath/loops_unary.dispatch.c.src'),
src_file.process('src/umath/loops_unary_fp.dispatch.c.src'),
src_file.process('src/umath/matmul.c.src'),
src_file.process('src/umath/matmul.h.src'),
diff --git a/numpy/core/setup.py b/numpy/core/setup.py
index 4b61bd855..da5bc64c0 100644
--- a/numpy/core/setup.py
+++ b/numpy/core/setup.py
@@ -1005,6 +1005,7 @@ def configuration(parent_package='',top_path=None):
join('src', 'umath', 'loops.h.src'),
join('src', 'umath', 'loops_utils.h.src'),
join('src', 'umath', 'loops.c.src'),
+ join('src', 'umath', 'loops_unary.dispatch.c.src'),
join('src', 'umath', 'loops_unary_fp.dispatch.c.src'),
join('src', 'umath', 'loops_arithm_fp.dispatch.c.src'),
join('src', 'umath', 'loops_arithmetic.dispatch.c.src'),
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c
index ceafffd51..08e2cc683 100644
--- a/numpy/core/src/multiarray/arrayobject.c
+++ b/numpy/core/src/multiarray/arrayobject.c
@@ -875,108 +875,6 @@ DEPRECATE_silence_error(const char *msg) {
return 0;
}
-/*
- * Comparisons can fail, but we do not always want to pass on the exception
- * (see comment in array_richcompare below), but rather return NotImplemented.
- * Here, an exception should be set on entrance.
- * Returns either NotImplemented with the exception cleared, or NULL
- * with the exception set.
- * Raises deprecation warnings for cases where behaviour is meant to change
- * (2015-05-14, 1.10)
- */
-
-NPY_NO_EXPORT PyObject *
-_failed_comparison_workaround(PyArrayObject *self, PyObject *other, int cmp_op)
-{
- PyObject *exc, *val, *tb;
- PyArrayObject *array_other;
- int other_is_flexible, ndim_other;
- int self_is_flexible = PyTypeNum_ISFLEXIBLE(PyArray_DESCR(self)->type_num);
-
- PyErr_Fetch(&exc, &val, &tb);
- /*
- * Determine whether other has a flexible dtype; here, inconvertible
- * is counted as inflexible. (This repeats work done in the ufunc,
- * but OK to waste some time in an unlikely path.)
- */
- array_other = (PyArrayObject *)PyArray_FROM_O(other);
- if (array_other) {
- other_is_flexible = PyTypeNum_ISFLEXIBLE(
- PyArray_DESCR(array_other)->type_num);
- ndim_other = PyArray_NDIM(array_other);
- Py_DECREF(array_other);
- }
- else {
- PyErr_Clear(); /* we restore the original error if needed */
- other_is_flexible = 0;
- ndim_other = 0;
- }
- if (cmp_op == Py_EQ || cmp_op == Py_NE) {
- /*
- * note: for == and !=, a structured dtype self cannot get here,
- * but a string can. Other can be string or structured.
- */
- if (other_is_flexible || self_is_flexible) {
- /*
- * For scalars, returning NotImplemented is correct.
- * For arrays, we emit a future deprecation warning.
- * When this warning is removed, a correctly shaped
- * array of bool should be returned.
- */
- if (ndim_other != 0 || PyArray_NDIM(self) != 0) {
- /* 2015-05-14, 1.10 */
- if (DEPRECATE_FUTUREWARNING(
- "elementwise comparison failed; returning scalar "
- "instead, but in the future will perform "
- "elementwise comparison") < 0) {
- goto fail;
- }
- }
- }
- else {
- /*
- * If neither self nor other had a flexible dtype, the error cannot
- * have been caused by a lack of implementation in the ufunc.
- *
- * 2015-05-14, 1.10
- */
- if (DEPRECATE(
- "elementwise comparison failed; "
- "this will raise an error in the future.") < 0) {
- goto fail;
- }
- }
- Py_XDECREF(exc);
- Py_XDECREF(val);
- Py_XDECREF(tb);
- Py_INCREF(Py_NotImplemented);
- return Py_NotImplemented;
- }
- else if (other_is_flexible || self_is_flexible) {
- /*
- * For LE, LT, GT, GE and a flexible self or other, we return
- * NotImplemented, which is the correct answer since the ufuncs do
- * not in fact implement loops for those. This will get us the
- * desired TypeError.
- */
- Py_XDECREF(exc);
- Py_XDECREF(val);
- Py_XDECREF(tb);
- Py_INCREF(Py_NotImplemented);
- return Py_NotImplemented;
- }
- else {
- /* LE, LT, GT, or GE with non-flexible other; just pass on error */
- goto fail;
- }
-
-fail:
- /*
- * Reraise the original exception, possibly chaining with a new one.
- */
- npy_PyErr_ChainExceptionsCause(exc, val, tb);
- return NULL;
-}
NPY_NO_EXPORT PyObject *
array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
@@ -1074,33 +972,99 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
- if (result == NULL) {
+
+ /*
+ * At this point `self` can take control of the operation by converting
+ * `other` to an array (it would have a chance to take control).
+ * If we are not in `==` and `!=`, this is an error and we hope that
+ * the existing error makes sense and derives from `TypeError` (which
+ * python would raise for `NotImplemented`) when it should.
+ *
+ * However, if the issue is no matching loop for the given dtypes and
+ * we are inside == and !=, then returning an array of True or False
+ * makes sense (following Python behavior for `==` and `!=`).
+ * Effectively: Both *dtypes* told us that they cannot be compared.
+ *
+ * In theory, the error could be raised from within an object loop, the
+ * solution to that could be pushing this into the ufunc (where we can
+ * distinguish the two easily). In practice, it seems like it should not
+ * be a huge problem: The ufunc loop will itself call `==` which should
+ * probably never raise a UFuncNoLoopError.
+ *
+ * TODO: If/once we correctly push structured comparisons into the ufunc
+ * we could consider pushing this path into the ufunc itself as a
+ * fallback loop (which ignores the input arrays).
+ * This would have the advantage that subclasses implementing
+ * `__array_ufunc__` do not explicitly need `__eq__` and `__ne__`.
+ */
+ if (result == NULL
+ && (cmp_op == Py_EQ || cmp_op == Py_NE)
+ && PyErr_ExceptionMatches(npy_UFuncNoLoopError)) {
+ PyErr_Clear();
+
+ PyArrayObject *array_other = (PyArrayObject *)PyArray_FROM_O(other);
+ if (PyArray_TYPE(array_other) == NPY_VOID) {
+ /*
+ * Void arrays are currently not handled by ufuncs, so if the other
+ * is a void array, we defer to it (will raise a TypeError).
+ */
+ Py_DECREF(array_other);
+ Py_RETURN_NOTIMPLEMENTED;
+ }
+
+ if (PyArray_NDIM(self) == 0 && PyArray_NDIM(array_other) == 0) {
+ /*
+ * (seberg) not sure that this is best, but we preserve Python
+ * bool result for "scalar" inputs for now by returning
+ * `NotImplemented`.
+ */
+ Py_DECREF(array_other);
+ Py_RETURN_NOTIMPLEMENTED;
+ }
+
+ /* Hack warning: using NpyIter to allocate broadcasted result. */
+ PyArrayObject *ops[3] = {self, array_other, NULL};
+ npy_uint32 flags = NPY_ITER_ZEROSIZE_OK | NPY_ITER_REFS_OK;
+ npy_uint32 op_flags[3] = {
+ NPY_ITER_READONLY, NPY_ITER_READONLY,
+ NPY_ITER_ALLOCATE | NPY_ITER_WRITEONLY};
+
+ PyArray_Descr *bool_descr = PyArray_DescrFromType(NPY_BOOL);
+ PyArray_Descr *op_descrs[3] = {
+ PyArray_DESCR(self), PyArray_DESCR(array_other), bool_descr};
+
+ NpyIter *iter = NpyIter_MultiNew(
+ 3, ops, flags, NPY_KEEPORDER, NPY_NO_CASTING,
+ op_flags, op_descrs);
+
+ Py_CLEAR(bool_descr);
+ Py_CLEAR(array_other);
+ if (iter == NULL) {
+ return NULL;
+ }
+ PyArrayObject *res = NpyIter_GetOperandArray(iter)[2];
+ Py_INCREF(res);
+ if (NpyIter_Deallocate(iter) != NPY_SUCCEED) {
+ Py_DECREF(res);
+ return NULL;
+ }
+
/*
- * 2015-05-14, 1.10; updated 2018-06-18, 1.16.
- *
- * Comparisons can raise errors when element-wise comparison is not
- * possible. Some of these, though, should not be passed on.
- * In particular, the ufuncs do not have loops for flexible dtype,
- * so those should be treated separately. Furthermore, for EQ and NE,
- * we should never fail.
- *
- * Our ideal behaviour would be:
- *
- * 1. For EQ and NE:
- * - If self and other are scalars, return NotImplemented,
- * so that python can assign True of False as appropriate.
- * - If either is an array, return an array of False or True.
- *
- * 2. For LT, LE, GE, GT:
- * - If self or other was flexible, return NotImplemented
- * (as is in fact the case), so python can raise a TypeError.
- * - If other is not convertible to an array, pass on the error
- * (MHvK, 2018-06-18: not sure about this, but it's what we have).
- *
- * However, for backwards compatibility, we cannot yet return arrays,
- * so we raise warnings instead.
+ * The array is guaranteed to be newly allocated and thus contiguous,
+ * so simply fill it with 0 or 1.
*/
- result = _failed_comparison_workaround(self, other, cmp_op);
+ memset(PyArray_BYTES(res), cmp_op == Py_EQ ? 0 : 1, PyArray_NBYTES(res));
+
+ /* Ensure basic subclass support by wrapping: */
+ if (!PyArray_CheckExact(self)) {
+ /*
+ * If other is also a subclass (with higher priority) we would
+ * already have deferred. So use `self` for wrapping. If users
+ * need more, they need to override `==` and `!=`.
+ */
+ Py_SETREF(res, PyArray_SubclassWrap(self, res));
+ }
+ return (PyObject *)res;
}
return result;
}
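As the comments above describe, the resulting Python-level behavior is roughly
the following (sketch; values illustrative)::

    import numpy as np

    # 0-d comparisons still defer, so Python's bool singletons are returned:
    print((np.array(0) == "a") is False)   # True
    print((np.int16(0) != "a") is True)    # True

    # Array comparisons now return a correctly shaped bool array instead of
    # warning and returning a scalar:
    print(np.zeros(2) == "a")              # [False False]
    print(np.zeros(2) != "a")              # [ True  True]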
diff --git a/numpy/core/src/multiarray/common_dtype.c b/numpy/core/src/multiarray/common_dtype.c
index 3561a905a..38a130221 100644
--- a/numpy/core/src/multiarray/common_dtype.c
+++ b/numpy/core/src/multiarray/common_dtype.c
@@ -8,6 +8,7 @@
#include "numpy/arrayobject.h"
#include "common_dtype.h"
+#include "convert_datatype.h"
#include "dtypemeta.h"
#include "abstractdtypes.h"
@@ -61,7 +62,7 @@ PyArray_CommonDType(PyArray_DTypeMeta *dtype1, PyArray_DTypeMeta *dtype2)
}
if (common_dtype == (PyArray_DTypeMeta *)Py_NotImplemented) {
Py_DECREF(Py_NotImplemented);
- PyErr_Format(PyExc_TypeError,
+ PyErr_Format(npy_DTypePromotionError,
"The DTypes %S and %S do not have a common DType. "
"For example they cannot be stored in a single array unless "
"the dtype is `object`.", dtype1, dtype2);
@@ -288,7 +289,7 @@ PyArray_PromoteDTypeSequence(
Py_INCREF(dtypes_in[l]);
PyTuple_SET_ITEM(dtypes_in_tuple, l, (PyObject *)dtypes_in[l]);
}
- PyErr_Format(PyExc_TypeError,
+ PyErr_Format(npy_DTypePromotionError,
"The DType %S could not be promoted by %S. This means that "
"no common DType exists for the given inputs. "
"For example they cannot be stored in a single array unless "
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index eeb42df66..3973fc795 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -50,6 +50,9 @@ NPY_NO_EXPORT npy_intp REQUIRED_STR_LEN[] = {0, 3, 5, 10, 10, 20, 20, 20, 20};
*/
NPY_NO_EXPORT int npy_promotion_state = NPY_USE_LEGACY_PROMOTION;
NPY_NO_EXPORT PyObject *NO_NEP50_WARNING_CTX = NULL;
+NPY_NO_EXPORT PyObject *npy_DTypePromotionError = NULL;
+NPY_NO_EXPORT PyObject *npy_UFuncNoLoopError = NULL;
+
static PyObject *
PyArray_GetGenericToVoidCastingImpl(void);
diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h
index b6bc7d8a7..1a23965f8 100644
--- a/numpy/core/src/multiarray/convert_datatype.h
+++ b/numpy/core/src/multiarray/convert_datatype.h
@@ -14,6 +14,8 @@ extern NPY_NO_EXPORT npy_intp REQUIRED_STR_LEN[];
#define NPY_USE_WEAK_PROMOTION_AND_WARN 2
extern NPY_NO_EXPORT int npy_promotion_state;
extern NPY_NO_EXPORT PyObject *NO_NEP50_WARNING_CTX;
+extern NPY_NO_EXPORT PyObject *npy_DTypePromotionError;
+extern NPY_NO_EXPORT PyObject *npy_UFuncNoLoopError;
NPY_NO_EXPORT int
npy_give_promotion_warnings(void);
diff --git a/numpy/core/src/multiarray/datetime.c b/numpy/core/src/multiarray/datetime.c
index 2abd68ca2..695b696c2 100644
--- a/numpy/core/src/multiarray/datetime.c
+++ b/numpy/core/src/multiarray/datetime.c
@@ -1622,6 +1622,12 @@ compute_datetime_metadata_greatest_common_divisor(
return 0;
+ /*
+ * We do not use `DTypePromotionError` below. The reason for this is that a
+ * `DTypePromotionError` indicates that `arr_dt1 != arr_dt2` for
+ * all values, but this is wrong for "0". This could be changed but
+ * for now we consider them errors that occur _while_ promoting.
+ */
incompatible_units: {
PyObject *umeta1 = metastr_to_unicode(meta1, 0);
if (umeta1 == NULL) {
diff --git a/numpy/core/src/multiarray/dtypemeta.c b/numpy/core/src/multiarray/dtypemeta.c
index 6c33da729..edc07bc92 100644
--- a/numpy/core/src/multiarray/dtypemeta.c
+++ b/numpy/core/src/multiarray/dtypemeta.c
@@ -404,7 +404,7 @@ void_common_instance(PyArray_Descr *descr1, PyArray_Descr *descr2)
if (descr1->subarray == NULL && descr1->names == NULL &&
descr2->subarray == NULL && descr2->names == NULL) {
if (descr1->elsize != descr2->elsize) {
- PyErr_SetString(PyExc_TypeError,
+ PyErr_SetString(npy_DTypePromotionError,
"Invalid type promotion with void datatypes of different "
"lengths. Use the `np.bytes_` datatype instead to pad the "
"shorter value with trailing zero bytes.");
@@ -443,7 +443,7 @@ void_common_instance(PyArray_Descr *descr1, PyArray_Descr *descr2)
return NULL;
}
if (!cmp) {
- PyErr_SetString(PyExc_TypeError,
+ PyErr_SetString(npy_DTypePromotionError,
"invalid type promotion with subarray datatypes "
"(shape mismatch).");
return NULL;
@@ -473,7 +473,7 @@ void_common_instance(PyArray_Descr *descr1, PyArray_Descr *descr2)
return new_descr;
}
- PyErr_SetString(PyExc_TypeError,
+ PyErr_SetString(npy_DTypePromotionError,
"invalid type promotion with structured datatype(s).");
return NULL;
}
@@ -617,6 +617,12 @@ string_unicode_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
static PyArray_DTypeMeta *
datetime_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
{
+ /*
+ * Timedelta/datetime shouldn't actually promote at all. That they
+ * currently do means that we need additional hacks in the comparison
+ * type resolver. For comparisons we have to make sure we reject it
+ * nicely in order to return an array of True/False values.
+ */
if (cls->type_num == NPY_DATETIME && other->type_num == NPY_TIMEDELTA) {
/*
* TODO: We actually currently do allow promotion here. This is
diff --git a/numpy/core/src/multiarray/experimental_public_dtype_api.c b/numpy/core/src/multiarray/experimental_public_dtype_api.c
index 79261a9a7..84507b481 100644
--- a/numpy/core/src/multiarray/experimental_public_dtype_api.c
+++ b/numpy/core/src/multiarray/experimental_public_dtype_api.c
@@ -258,6 +258,14 @@ PyArrayInitDTypeMeta_FromSpec(
/*
* And now, register all the casts that are currently defined!
*/
+ if (spec->casts == NULL) {
+ PyErr_SetString(
+ PyExc_RuntimeError,
+ "DType must at least provide a function to cast (or just copy) "
+ "between its own instances!");
+ return -1;
+ }
+
PyArrayMethod_Spec **next_meth_spec = spec->casts;
while (1) {
PyArrayMethod_Spec *meth_spec = *next_meth_spec;
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index ca4fdfeca..5da3d66df 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -4811,6 +4811,38 @@ intern_strings(void)
return 0;
}
+
+/*
+ * Initializes global constants. At some point these need to be cleaned
+ * up, and sometimes we also import them where they are needed. But for
+ * some things, adding an `npy_cache_import` everywhere seems inconvenient.
+ *
+ * These globals should not need the C-layer at all and will be imported
+ * before anything on the C-side is initialized.
+ */
+static int
+initialize_static_globals(void)
+{
+ assert(npy_DTypePromotionError == NULL);
+ npy_cache_import(
+ "numpy.exceptions", "DTypePromotionError",
+ &npy_DTypePromotionError);
+ if (npy_DTypePromotionError == NULL) {
+ return -1;
+ }
+
+ assert(npy_UFuncNoLoopError == NULL);
+ npy_cache_import(
+ "numpy.core._exceptions", "_UFuncNoLoopError",
+ &npy_UFuncNoLoopError);
+ if (npy_UFuncNoLoopError == NULL) {
+ return -1;
+ }
+
+ return 0;
+}
+
+
static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
"_multiarray_umath",
@@ -4861,6 +4893,14 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) {
goto err;
}
+ if (intern_strings() < 0) {
+ goto err;
+ }
+
+ if (initialize_static_globals() < 0) {
+ goto err;
+ }
+
if (PyType_Ready(&PyUFunc_Type) < 0) {
goto err;
}
@@ -5033,10 +5073,6 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) {
goto err;
}
- if (intern_strings() < 0) {
- goto err;
- }
-
if (set_typeinfo(d) != 0) {
goto err;
}
diff --git a/numpy/core/src/umath/dispatching.c b/numpy/core/src/umath/dispatching.c
index 2de5a5670..6d6c481fb 100644
--- a/numpy/core/src/umath/dispatching.c
+++ b/numpy/core/src/umath/dispatching.c
@@ -43,6 +43,7 @@
#include <convert_datatype.h>
#include "numpy/ndarraytypes.h"
+#include "numpy/npy_3kcompat.h"
#include "common.h"
#include "dispatching.h"
@@ -947,7 +948,7 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
int cacheable = 1; /* unused, as we modify the original `op_dtypes` */
if (legacy_promote_using_legacy_type_resolver(ufunc,
ops, signature, op_dtypes, &cacheable, NPY_FALSE) < 0) {
- return NULL;
+ goto handle_error;
}
}
@@ -959,10 +960,7 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
npy_promotion_state = old_promotion_state;
if (info == NULL) {
- if (!PyErr_Occurred()) {
- raise_no_loop_found_error(ufunc, (PyObject **)op_dtypes);
- }
- return NULL;
+ goto handle_error;
}
PyArrayMethodObject *method = (PyArrayMethodObject *)PyTuple_GET_ITEM(info, 1);
@@ -984,7 +982,7 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
/* Reset the promotion state: */
npy_promotion_state = NPY_USE_WEAK_PROMOTION_AND_WARN;
if (res < 0) {
- return NULL;
+ goto handle_error;
}
}
@@ -1018,12 +1016,29 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
* If signature is forced the cache may contain an incompatible
* loop found via promotion (signature not enforced). Reject it.
*/
- raise_no_loop_found_error(ufunc, (PyObject **)op_dtypes);
- return NULL;
+ goto handle_error;
}
}
return method;
+
+ handle_error:
+ /* We only set the "no loop found" error here */
+ if (!PyErr_Occurred()) {
+ raise_no_loop_found_error(ufunc, (PyObject **)op_dtypes);
+ }
+ /*
+ * Otherwise an error occurred, but if the error was DTypePromotionError
+ * then we chain it, because DTypePromotionError effectively means that there
+ * is no loop available. (We failed to find a loop by using promotion.)
+ */
+ else if (PyErr_ExceptionMatches(npy_DTypePromotionError)) {
+ PyObject *err_type = NULL, *err_value = NULL, *err_traceback = NULL;
+ PyErr_Fetch(&err_type, &err_value, &err_traceback);
+ raise_no_loop_found_error(ufunc, (PyObject **)op_dtypes);
+ npy_PyErr_ChainExceptionsCause(err_type, err_value, err_traceback);
+ }
+ return NULL;
}
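A hedged sketch of how this chaining surfaces in Python (the concrete no-loop
exception is the internal ``_UFuncNoLoopError`` used in the tests below; the
operand pair is assumed to fail promotion, as in the release notes)::

    import numpy as np

    a = np.zeros(1, dtype="M8[s]")
    b = np.zeros(1, dtype=np.complex128)
    try:
        np.maximum(a, b)                      # promotion fails -> no loop found
    except TypeError as exc:
        print(type(exc).__name__)             # the internal no-loop UFuncTypeError
        print(type(exc.__cause__).__name__)   # DTypePromotionError, chained as cause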
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index fe5aa9374..0b4856847 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -601,14 +601,6 @@ NPY_NO_EXPORT NPY_GCC_OPT_3 @ATTR@ void
#if @CHK@
NPY_NO_EXPORT NPY_GCC_OPT_3 @ATTR@ void
-@TYPE@_negative@isa@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
-{
- UNARY_LOOP_FAST(@type@, @type@, *out = -in);
-}
-#endif
-
-#if @CHK@
-NPY_NO_EXPORT NPY_GCC_OPT_3 @ATTR@ void
@TYPE@_logical_not@isa@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
{
UNARY_LOOP_FAST(@type@, npy_bool, *out = !in);
@@ -1546,17 +1538,6 @@ NPY_NO_EXPORT void
}
NPY_NO_EXPORT void
-@TYPE@_negative(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
-{
- if (!run_unary_simd_negative_@TYPE@(args, dimensions, steps)) {
- UNARY_LOOP {
- const @type@ in1 = *(@type@ *)ip1;
- *((@type@ *)op1) = -in1;
- }
- }
-}
-
-NPY_NO_EXPORT void
@TYPE@_positive(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
{
UNARY_LOOP {
diff --git a/numpy/core/src/umath/loops.h.src b/numpy/core/src/umath/loops.h.src
index 424e204c1..e3a410968 100644
--- a/numpy/core/src/umath/loops.h.src
+++ b/numpy/core/src/umath/loops.h.src
@@ -140,9 +140,6 @@ NPY_NO_EXPORT void
@S@@TYPE@_conjugate@isa@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
NPY_NO_EXPORT void
-@S@@TYPE@_negative@isa@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
-
-NPY_NO_EXPORT void
@S@@TYPE@_logical_not@isa@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
NPY_NO_EXPORT void
@@ -206,6 +203,23 @@ NPY_NO_EXPORT void
/**end repeat**/
+
+#ifndef NPY_DISABLE_OPTIMIZATION
+ #include "loops_unary.dispatch.h"
+#endif
+/**begin repeat
+ * #TYPE = UBYTE, USHORT, UINT, ULONG, ULONGLONG,
+ * BYTE, SHORT, INT, LONG, LONGLONG#
+ */
+/**begin repeat1
+ * #kind = negative#
+ */
+NPY_CPU_DISPATCH_DECLARE(NPY_NO_EXPORT void @TYPE@_@kind@,
+ (char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(data)))
+/**end repeat1**/
+/**end repeat**/
+
+
/*
*****************************************************************************
** FLOAT LOOPS **
@@ -226,6 +240,20 @@ NPY_CPU_DISPATCH_DECLARE(NPY_NO_EXPORT void @TYPE@_@kind@,
/**end repeat**/
#ifndef NPY_DISABLE_OPTIMIZATION
+ #include "loops_unary.dispatch.h"
+#endif
+/**begin repeat
+ * #TYPE = FLOAT, DOUBLE, LONGDOUBLE#
+ */
+/**begin repeat1
+ * #kind = negative#
+ */
+NPY_CPU_DISPATCH_DECLARE(NPY_NO_EXPORT void @TYPE@_@kind@,
+ (char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(data)))
+/**end repeat1**/
+/**end repeat**/
+
+#ifndef NPY_DISABLE_OPTIMIZATION
#include "loops_arithm_fp.dispatch.h"
#endif
/**begin repeat
@@ -362,6 +390,7 @@ NPY_CPU_DISPATCH_DECLARE(NPY_NO_EXPORT void @TYPE@_@kind@, (
* #TYPE = HALF, FLOAT, DOUBLE, LONGDOUBLE#
* #c = f, f, , l#
* #C = F, F, , L#
+ * #half = 1, 0, 0, 0#
*/
/**begin repeat1
@@ -440,8 +469,10 @@ NPY_NO_EXPORT void
NPY_NO_EXPORT void
@TYPE@_absolute(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
+#if @half@
NPY_NO_EXPORT void
@TYPE@_negative(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
+#endif
NPY_NO_EXPORT void
@TYPE@_positive(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
diff --git a/numpy/core/src/umath/loops_unary.dispatch.c.src b/numpy/core/src/umath/loops_unary.dispatch.c.src
new file mode 100644
index 000000000..1e2a81d20
--- /dev/null
+++ b/numpy/core/src/umath/loops_unary.dispatch.c.src
@@ -0,0 +1,364 @@
+/*@targets
+ ** $maxopt baseline
+ ** neon asimd
+ ** sse2 avx2 avx512_skx
+ ** vsx2
+ ** vx vxe
+ **/
+
+#define _UMATHMODULE
+#define _MULTIARRAYMODULE
+#define NPY_NO_DEPRECATED_API NPY_API_VERSION
+
+#include "numpy/npy_math.h"
+#include "simd/simd.h"
+#include "loops_utils.h"
+#include "loops.h"
+#include "lowlevel_strided_loops.h"
+// Provides the various *_LOOP macros
+#include "fast_loop_macros.h"
+
+/*******************************************************************************
+ ** Scalar ops
+ ******************************************************************************/
+#define scalar_negative(X) (-X)
+
+/*******************************************************************************
+ ** extra SIMD intrinsics
+ ******************************************************************************/
+
+#if NPY_SIMD
+
+/**begin repeat
+ * #sfx = s8, u8, s16, u16, s32, u32, s64, u64#
+ * #ssfx = 8, 8, 16, 16, 32, 32, 64, 64#
+ */
+static NPY_INLINE npyv_@sfx@
+npyv_negative_@sfx@(npyv_@sfx@ v)
+{
+#if defined(NPY_HAVE_NEON) && (defined(__aarch64__) || @ssfx@ < 64)
+ return npyv_reinterpret_@sfx@_s@ssfx@(vnegq_s@ssfx@(npyv_reinterpret_s@ssfx@_@sfx@(v)));
+#else
+ // (x ^ -1) + 1
+ const npyv_@sfx@ m1 = npyv_setall_@sfx@((npyv_lanetype_@sfx@)-1);
+ return npyv_sub_@sfx@(npyv_xor_@sfx@(v, m1), m1);
+#endif
+}
+/**end repeat**/
+
+/**begin repeat
+ * #sfx = f32, f64#
+ * #VCHK = NPY_SIMD_F32, NPY_SIMD_F64#
+ * #fd = f, #
+ */
+#if @VCHK@
+static NPY_INLINE npyv_@sfx@
+npyv_negative_@sfx@(npyv_@sfx@ v)
+{
+#if defined(NPY_HAVE_NEON)
+ return vnegq_@sfx@(v);
+#else
+ // (v ^ signmask)
+ const npyv_@sfx@ signmask = npyv_setall_@sfx@(-0.@fd@);
+ return npyv_xor_@sfx@(v, signmask);
+#endif
+}
+#endif // @VCHK@
+/**end repeat**/
+
+#endif // NPY_SIMD
+
+/********************************************************************************
+ ** Defining the SIMD kernels
+ ********************************************************************************/
+/**begin repeat
+ * #sfx = s8, u8, s16, u16, s32, u32, s64, u64, f32, f64#
+ * #simd_chk = NPY_SIMD*8, NPY_SIMD_F32, NPY_SIMD_F64#
+ * #is_fp = 0*8, 1*2#
+ * #supports_ncontig = 0*4,1*6#
+ */
+/**begin repeat1
+ * #kind = negative#
+ * #intrin = negative#
+ * #unroll = 4#
+ */
+#if @simd_chk@
+#if @unroll@ < 1
+#error "Unroll must be at least 1"
+#elif NPY_SIMD != 128 && @unroll@ > 2
+// Avoid memory bandwidth bottleneck for larger SIMD
+#define UNROLL 2
+#else
+#define UNROLL @unroll@
+#endif
+// contiguous inputs and output.
+static NPY_INLINE void
+simd_unary_cc_@intrin@_@sfx@(const npyv_lanetype_@sfx@ *ip,
+ npyv_lanetype_@sfx@ *op,
+ npy_intp len)
+{
+ const int vstep = npyv_nlanes_@sfx@;
+ const int wstep = vstep * UNROLL;
+
+ // unrolled vector loop
+ for (; len >= wstep; len -= wstep, ip += wstep, op += wstep) {
+ /**begin repeat2
+ * #U = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+ #if UNROLL > @U@
+ npyv_@sfx@ v_@U@ = npyv_load_@sfx@(ip + @U@ * vstep);
+ npyv_@sfx@ r_@U@ = npyv_@intrin@_@sfx@(v_@U@);
+ npyv_store_@sfx@(op + @U@ * vstep, r_@U@);
+ #endif
+ /**end repeat2**/
+ }
+ // single vector loop
+ for (; len >= vstep; len -= vstep, ip += vstep, op +=vstep) {
+ npyv_@sfx@ v = npyv_load_@sfx@(ip);
+ npyv_@sfx@ r = npyv_@intrin@_@sfx@(v);
+ npyv_store_@sfx@(op, r);
+ }
+ // scalar finish up any remaining iterations
+ for (; len > 0; --len, ++ip, ++op) {
+ *op = scalar_@intrin@(*ip);
+ }
+}
+
+#if @supports_ncontig@
+// contiguous input, non-contiguous output
+static NPY_INLINE void
+simd_unary_cn_@intrin@_@sfx@(const npyv_lanetype_@sfx@ *ip,
+ npyv_lanetype_@sfx@ *op, npy_intp ostride,
+ npy_intp len)
+{
+ const int vstep = npyv_nlanes_@sfx@;
+ const int wstep = vstep * UNROLL;
+
+ // unrolled vector loop
+ for (; len >= wstep; len -= wstep, ip += wstep, op += ostride*wstep) {
+ /**begin repeat2
+ * #U = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+ #if UNROLL > @U@
+ npyv_@sfx@ v_@U@ = npyv_load_@sfx@(ip + @U@ * vstep);
+ npyv_@sfx@ r_@U@ = npyv_@intrin@_@sfx@(v_@U@);
+ npyv_storen_@sfx@(op + @U@ * vstep * ostride, ostride, r_@U@);
+ #endif
+ /**end repeat2**/
+ }
+ // single vector loop
+ for (; len >= vstep; len -= vstep, ip += vstep, op += ostride*vstep) {
+ npyv_@sfx@ v = npyv_load_@sfx@(ip);
+ npyv_@sfx@ r = npyv_@intrin@_@sfx@(v);
+ npyv_storen_@sfx@(op, ostride, r);
+ }
+ // scalar finish up any remaining iterations
+ for (; len > 0; --len, ++ip, op += ostride) {
+ *op = scalar_@intrin@(*ip);
+ }
+}
+// non-contiguous input, contiguous output
+static NPY_INLINE void
+simd_unary_nc_@intrin@_@sfx@(const npyv_lanetype_@sfx@ *ip, npy_intp istride,
+ npyv_lanetype_@sfx@ *op,
+ npy_intp len)
+{
+ const int vstep = npyv_nlanes_@sfx@;
+ const int wstep = vstep * UNROLL;
+
+ // unrolled vector loop
+ for (; len >= wstep; len -= wstep, ip += istride*wstep, op += wstep) {
+ /**begin repeat2
+ * #U = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+ #if UNROLL > @U@
+ npyv_@sfx@ v_@U@ = npyv_loadn_@sfx@(ip + @U@ * vstep * istride, istride);
+ npyv_@sfx@ r_@U@ = npyv_@intrin@_@sfx@(v_@U@);
+ npyv_store_@sfx@(op + @U@ * vstep, r_@U@);
+ #endif
+ /**end repeat2**/
+ }
+ // single vector loop
+ for (; len >= vstep; len -= vstep, ip += istride*vstep, op += vstep) {
+ npyv_@sfx@ v = npyv_loadn_@sfx@(ip, istride);
+ npyv_@sfx@ r = npyv_@intrin@_@sfx@(v);
+ npyv_store_@sfx@(op, r);
+ }
+ // scalar finish up any remaining iterations
+ for (; len > 0; --len, ip += istride, ++op) {
+ *op = scalar_@intrin@(*ip);
+ }
+}
+// non-contiguous input and output
+// limit unroll to 2x
+#if UNROLL > 2
+#undef UNROLL
+#define UNROLL 2
+#endif
+static NPY_INLINE void
+simd_unary_nn_@intrin@_@sfx@(const npyv_lanetype_@sfx@ *ip, npy_intp istride,
+ npyv_lanetype_@sfx@ *op, npy_intp ostride,
+ npy_intp len)
+{
+ const int vstep = npyv_nlanes_@sfx@;
+ const int wstep = vstep * UNROLL;
+
+ // unrolled vector loop
+ for (; len >= wstep; len -= wstep, ip += istride*wstep, op += ostride*wstep) {
+ /**begin repeat2
+ * #U = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+ #if UNROLL > @U@
+ npyv_@sfx@ v_@U@ = npyv_loadn_@sfx@(ip + @U@ * vstep * istride, istride);
+ npyv_@sfx@ r_@U@ = npyv_@intrin@_@sfx@(v_@U@);
+ npyv_storen_@sfx@(op + @U@ * vstep * ostride, ostride, r_@U@);
+ #endif
+ /**end repeat2**/
+ }
+ // single vector loop
+ for (; len >= vstep; len -= vstep, ip += istride*vstep, op += ostride*vstep) {
+ npyv_@sfx@ v = npyv_loadn_@sfx@(ip, istride);
+ npyv_@sfx@ r = npyv_@intrin@_@sfx@(v);
+ npyv_storen_@sfx@(op, ostride, r);
+ }
+ // scalar finish up any remaining iterations
+ for (; len > 0; --len, ip += istride, op += ostride) {
+ *op = scalar_@intrin@(*ip);
+ }
+}
+#endif // @supports_ncontig@
+#undef UNROLL
+#endif // @simd_chk@
+/**end repeat1**/
+/**end repeat**/
+
+/********************************************************************************
+ ** Defining ufunc inner functions
+ ********************************************************************************/
+/**begin repeat
+ * #TYPE = UBYTE, USHORT, UINT, ULONG, ULONGLONG,
+ * BYTE, SHORT, INT, LONG, LONGLONG,
+ * FLOAT, DOUBLE, LONGDOUBLE#
+ *
+ * #BTYPE = BYTE, SHORT, INT, LONG, LONGLONG,
+ * BYTE, SHORT, INT, LONG, LONGLONG,
+ * FLOAT, DOUBLE, LONGDOUBLE#
+ * #type = npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
+ * npy_byte, npy_short, npy_int, npy_long, npy_longlong,
+ * npy_float, npy_double, npy_longdouble#
+ *
+ * #is_fp = 0*10, 1*3#
+ * #is_unsigned = 1*5, 0*5, 0*3#
+ * #supports_ncontig = 0*2, 1*3, 0*2, 1*3, 1*3#
+ */
+#undef TO_SIMD_SFX
+#if 0
+/**begin repeat1
+ * #len = 8, 16, 32, 64#
+ */
+#elif NPY_SIMD && NPY_BITSOF_@BTYPE@ == @len@
+ #if @is_fp@
+ #define TO_SIMD_SFX(X) X##_f@len@
+ #if NPY_BITSOF_@BTYPE@ == 32 && !NPY_SIMD_F32
+ #undef TO_SIMD_SFX
+ #endif
+ #if NPY_BITSOF_@BTYPE@ == 64 && !NPY_SIMD_F64
+ #undef TO_SIMD_SFX
+ #endif
+ #elif @is_unsigned@
+ #define TO_SIMD_SFX(X) X##_u@len@
+ #else
+ #define TO_SIMD_SFX(X) X##_s@len@
+ #endif
+/**end repeat1**/
+#endif
+
+/**begin repeat1
+ * #kind = negative#
+ * #intrin = negative#
+ */
+NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(@TYPE@_@kind@)
+(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
+{
+ char *ip = args[0], *op = args[1];
+ npy_intp istep = steps[0], ostep = steps[1],
+ len = dimensions[0];
+#ifdef TO_SIMD_SFX
+ #undef STYPE
+ #define STYPE TO_SIMD_SFX(npyv_lanetype)
+ if (!is_mem_overlap(ip, istep, op, ostep, len)) {
+ if (IS_UNARY_CONT(@type@, @type@)) {
+ // no overlap and operands are contiguous
+ TO_SIMD_SFX(simd_unary_cc_@intrin@)(
+ (STYPE*)ip, (STYPE*)op, len
+ );
+ goto clear;
+ }
+ #if @supports_ncontig@
+ const npy_intp istride = istep / sizeof(STYPE);
+ const npy_intp ostride = ostep / sizeof(STYPE);
+ if (TO_SIMD_SFX(npyv_loadable_stride)(istride) &&
+ TO_SIMD_SFX(npyv_storable_stride)(ostride))
+ {
+ if (istride == 1 && ostride != 1) {
+ // contiguous input, non-contiguous output
+ TO_SIMD_SFX(simd_unary_cn_@intrin@)(
+ (STYPE*)ip, (STYPE*)op, ostride, len
+ );
+ goto clear;
+ }
+ else if (istride != 1 && ostride == 1) {
+ // non-contiguous input, contiguous output
+ TO_SIMD_SFX(simd_unary_nc_@intrin@)(
+ (STYPE*)ip, istride, (STYPE*)op, len
+ );
+ goto clear;
+ }
+ // SSE2 does better with unrolled scalar for heavy non-contiguous
+ #if !defined(NPY_HAVE_SSE2)
+ else if (istride != 1 && ostride != 1) {
+ // non-contiguous input and output
+ TO_SIMD_SFX(simd_unary_nn_@intrin@)(
+ (STYPE*)ip, istride, (STYPE*)op, ostride, len
+ );
+ goto clear;
+ }
+ #endif
+ }
+ #endif // @supports_ncontig@
+ }
+#endif // TO_SIMD_SFX
+#ifndef NPY_DISABLE_OPTIMIZATION
+ /*
+ * scalar unrolls
+ * 8x unroll performed best on
+ * - Apple M1 Native / arm64
+ * - Apple M1 Rosetta / SSE42
+ * - iMacPro / AVX512
+ */
+ #define UNROLL 8
+ for (; len >= UNROLL; len -= UNROLL, ip += istep*UNROLL, op += ostep*UNROLL) {
+ /**begin repeat2
+ * #U = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+ #if UNROLL > @U@
+ const @type@ in_@U@ = *((const @type@ *)(ip + @U@ * istep));
+ *((@type@ *)(op + @U@ * ostep)) = scalar_@intrin@(in_@U@);
+ #endif
+ /**end repeat2**/
+ }
+#endif // NPY_DISABLE_OPTIMIZATION
+ for (; len > 0; --len, ip += istep, op += ostep) {
+ *((@type@ *)op) = scalar_@intrin@(*(const @type@ *)ip);
+ }
+#ifdef TO_SIMD_SFX
+clear:
+ npyv_cleanup();
+#endif
+#if @is_fp@
+ npy_clear_floatstatus_barrier((char*)dimensions);
+#endif
+}
+/**end repeat**/
+
+#undef NEGATIVE_CONTIG_ONLY
diff --git a/numpy/core/src/umath/simd.inc.src b/numpy/core/src/umath/simd.inc.src
index 5351ec1fa..6fc1501c9 100644
--- a/numpy/core/src/umath/simd.inc.src
+++ b/numpy/core/src/umath/simd.inc.src
@@ -129,39 +129,9 @@ run_@func@_avx512_skx_@TYPE@(char **args, npy_intp const *dimensions, npy_intp c
* #vector = 1, 1, 0#
* #VECTOR = NPY_SIMD, NPY_SIMD_F64, 0 #
*/
-
-/**begin repeat1
- * #func = negative#
- * #check = IS_BLOCKABLE_UNARY#
- * #name = unary#
- */
-
-#if @vector@ && defined NPY_HAVE_SSE2_INTRINSICS
-
-/* prototypes */
-static void
-sse2_@func@_@TYPE@(@type@ *, @type@ *, const npy_intp n);
-
-#endif
-
-static inline int
-run_@name@_simd_@func@_@TYPE@(char **args, npy_intp const *dimensions, npy_intp const *steps)
-{
-#if @vector@ && defined NPY_HAVE_SSE2_INTRINSICS
- if (@check@(sizeof(@type@), VECTOR_SIZE_BYTES)) {
- sse2_@func@_@TYPE@((@type@*)args[1], (@type@*)args[0], dimensions[0]);
- return 1;
- }
-#endif
- return 0;
-}
-
-/**end repeat1**/
-
/**begin repeat1
* #kind = isnan, isfinite, isinf, signbit#
*/
-
#if @vector@ && defined NPY_HAVE_SSE2_INTRINSICS
static void
@@ -181,9 +151,7 @@ run_@kind@_simd_@TYPE@(char **args, npy_intp const *dimensions, npy_intp const *
#endif
return 0;
}
-
/**end repeat1**/
-
/**end repeat**/
/*
@@ -426,41 +394,6 @@ sse2_@kind@_@TYPE@(npy_bool * op, @type@ * ip1, npy_intp n)
}
/**end repeat1**/
-
-static void
-sse2_negative_@TYPE@(@type@ * op, @type@ * ip, const npy_intp n)
-{
- /*
- * get 0x7FFFFFFF mask (everything but signbit set)
- * float & ~mask will remove the sign, float ^ mask flips the sign
- * this is equivalent to how the compiler implements fabs on amd64
- */
- const @vtype@ mask = @vpre@_set1_@vsuf@(-0.@c@);
-
- /* align output to VECTOR_SIZE_BYTES bytes */
- LOOP_BLOCK_ALIGN_VAR(op, @type@, VECTOR_SIZE_BYTES) {
- op[i] = -ip[i];
- }
- assert((npy_uintp)n < (VECTOR_SIZE_BYTES / sizeof(@type@)) ||
- npy_is_aligned(&op[i], VECTOR_SIZE_BYTES));
- if (npy_is_aligned(&ip[i], VECTOR_SIZE_BYTES)) {
- LOOP_BLOCKED(@type@, VECTOR_SIZE_BYTES) {
- @vtype@ a = @vpre@_load_@vsuf@(&ip[i]);
- @vpre@_store_@vsuf@(&op[i], @vpre@_xor_@vsuf@(mask, a));
- }
- }
- else {
- LOOP_BLOCKED(@type@, VECTOR_SIZE_BYTES) {
- @vtype@ a = @vpre@_loadu_@vsuf@(&ip[i]);
- @vpre@_store_@vsuf@(&op[i], @vpre@_xor_@vsuf@(mask, a));
- }
- }
- LOOP_BLOCKED_END {
- op[i] = -ip[i];
- }
-}
-/**end repeat1**/
-
/**end repeat**/
/* bunch of helper functions used in ISA_exp/log_FLOAT*/
diff --git a/numpy/core/src/umath/ufunc_type_resolution.c b/numpy/core/src/umath/ufunc_type_resolution.c
index 707f39e94..a0a16a0f9 100644
--- a/numpy/core/src/umath/ufunc_type_resolution.c
+++ b/numpy/core/src/umath/ufunc_type_resolution.c
@@ -358,13 +358,24 @@ PyUFunc_SimpleBinaryComparisonTypeResolver(PyUFuncObject *ufunc,
}
if (type_tup == NULL) {
+ if (PyArray_ISDATETIME(operands[0])
+ && PyArray_ISDATETIME(operands[1])
+ && type_num1 != type_num2) {
+ /*
+ * Reject mixed datetime and timedelta explicitly; this was always
+ * implicitly rejected because casting fails (except with
+ * `casting="unsafe"` admittedly).
+ * This is required to ensure that `==` and `!=` can correctly
+ * detect that they should return a result array of False/True.
+ */
+ return raise_binary_type_reso_error(ufunc, operands);
+ }
/*
- * DEPRECATED NumPy 1.20, 2020-12.
- * This check is required to avoid the FutureWarning that
- * ResultType will give for number->string promotions.
+ * This check is required to avoid a potential FutureWarning that
+ * ResultType would give for number->string promotions.
* (We never supported flexible dtypes here.)
*/
- if (!PyArray_ISFLEXIBLE(operands[0]) &&
+ else if (!PyArray_ISFLEXIBLE(operands[0]) &&
!PyArray_ISFLEXIBLE(operands[1])) {
out_dtypes[0] = PyArray_ResultType(2, operands, 0, NULL);
if (out_dtypes[0] == NULL) {
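A short sketch of the comparison behavior this enables (see also
``test_no_loop_gives_all_true_or_false`` in the test changes below)::

    import numpy as np

    dt = np.zeros(3, dtype="M8[s]")
    td = np.zeros(3, dtype="m8[s]")

    print(dt == td)   # [False False False]
    print(dt != td)   # [ True  True  True]

    try:
        dt < td        # ordered comparisons have no loop and raise
    except TypeError as exc:
        print(type(exc).__name__)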
diff --git a/numpy/core/tests/test_deprecations.py b/numpy/core/tests/test_deprecations.py
index 3a8db40df..4ec1f83d4 100644
--- a/numpy/core/tests/test_deprecations.py
+++ b/numpy/core/tests/test_deprecations.py
@@ -138,80 +138,6 @@ class _VisibleDeprecationTestCase(_DeprecationTestCase):
warning_cls = np.VisibleDeprecationWarning
-class TestComparisonDeprecations(_DeprecationTestCase):
- """This tests the deprecation, for non-element-wise comparison logic.
- This used to mean that when an error occurred during element-wise comparison
- (i.e. broadcasting) NotImplemented was returned, but also in the comparison
- itself, False was given instead of the error.
-
- Also test FutureWarning for the None comparison.
- """
-
- message = "elementwise.* comparison failed; .*"
-
- def test_normal_types(self):
- for op in (operator.eq, operator.ne):
- # Broadcasting errors:
- self.assert_deprecated(op, args=(np.zeros(3), []))
- a = np.zeros(3, dtype='i,i')
- # (warning is issued a couple of times here)
- self.assert_deprecated(op, args=(a, a[:-1]), num=None)
-
- # ragged array comparison returns True/False
- a = np.array([1, np.array([1,2,3])], dtype=object)
- b = np.array([1, np.array([1,2,3])], dtype=object)
- self.assert_deprecated(op, args=(a, b), num=None)
-
- def test_string(self):
- # For two string arrays, strings always raised the broadcasting error:
- a = np.array(['a', 'b'])
- b = np.array(['a', 'b', 'c'])
- assert_warns(FutureWarning, lambda x, y: x == y, a, b)
-
- # The empty list is not cast to string, and this used to pass due
- # to dtype mismatch; now (2018-06-21) it correctly leads to a
- # FutureWarning.
- assert_warns(FutureWarning, lambda: a == [])
-
- def test_void_dtype_equality_failures(self):
- class NotArray:
- def __array__(self):
- raise TypeError
-
- # Needed so Python 3 does not raise DeprecationWarning twice.
- def __ne__(self, other):
- return NotImplemented
-
- self.assert_deprecated(lambda: np.arange(2) == NotArray())
- self.assert_deprecated(lambda: np.arange(2) != NotArray())
-
- def test_array_richcompare_legacy_weirdness(self):
- # It doesn't really work to use assert_deprecated here, b/c part of
- # the point of assert_deprecated is to check that when warnings are
- # set to "error" mode then the error is propagated -- which is good!
- # But here we are testing a bunch of code that is deprecated *because*
- # it has the habit of swallowing up errors and converting them into
- # different warnings. So assert_warns will have to be sufficient.
- assert_warns(FutureWarning, lambda: np.arange(2) == "a")
- assert_warns(FutureWarning, lambda: np.arange(2) != "a")
- # No warning for scalar comparisons
- with warnings.catch_warnings():
- warnings.filterwarnings("error")
- assert_(not (np.array(0) == "a"))
- assert_(np.array(0) != "a")
- assert_(not (np.int16(0) == "a"))
- assert_(np.int16(0) != "a")
-
- for arg1 in [np.asarray(0), np.int16(0)]:
- struct = np.zeros(2, dtype="i4,i4")
- for arg2 in [struct, "a"]:
- for f in [operator.lt, operator.le, operator.gt, operator.ge]:
- with warnings.catch_warnings() as l:
- warnings.filterwarnings("always")
- assert_raises(TypeError, f, arg1, arg2)
- assert_(not l)
-
-
class TestDatetime64Timezone(_DeprecationTestCase):
"""Parsing of datetime64 with timezones deprecated in 1.11.0, because
datetime64 is now timezone naive rather than UTC only.
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 4d3996d86..63ac32f20 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -1271,6 +1271,13 @@ class TestStructured:
a = np.zeros((1, 0, 1), [('a', '<f8', (1, 1))])
assert_equal(a, a)
+ @pytest.mark.parametrize("op", [operator.eq, operator.ne])
+ def test_structured_array_comparison_bad_broadcasts(self, op):
+ a = np.zeros(3, dtype='i,i')
+ b = np.array([], dtype="i,i")
+ with pytest.raises(ValueError):
+ op(a, b)
+
def test_structured_comparisons_with_promotion(self):
# Check that structured arrays can be compared so long as their
# dtypes promote fine:
@@ -1291,7 +1298,10 @@ class TestStructured:
assert_equal(a == b, [False, True])
assert_equal(a != b, [True, False])
- def test_void_comparison_failures(self):
+ @pytest.mark.parametrize("op", [
+ operator.eq, lambda x, y: operator.eq(y, x),
+ operator.ne, lambda x, y: operator.ne(y, x)])
+ def test_void_comparison_failures(self, op):
# In principle, one could decide to return an array of False for some
# if comparisons are impossible. But right now we return TypeError
# when "void" dtypes are involved.
@@ -1299,18 +1309,18 @@ class TestStructured:
y = np.zeros(3)
# Cannot compare non-structured to structured:
with pytest.raises(TypeError):
- x == y
+ op(x, y)
# Added title prevents promotion, but casts are OK:
y = np.zeros(3, dtype=[(('title', 'a'), 'i1')])
assert np.can_cast(y.dtype, x.dtype)
with pytest.raises(TypeError):
- x == y
+ op(x, y)
x = np.zeros(3, dtype="V7")
y = np.zeros(3, dtype="V8")
with pytest.raises(TypeError):
- x == y
+ op(x, y)
def test_casting(self):
# Check that casting a structured array to change its byte order
@@ -9493,6 +9503,105 @@ def test_equal_override():
assert_equal(array != my_always_equal, 'ne')
+@pytest.mark.parametrize("op", [operator.eq, operator.ne])
+@pytest.mark.parametrize(["dt1", "dt2"], [
+ ([("f", "i")], [("f", "i")]), # structured comparison (successfull)
+ ("M8", "d"), # impossible comparison: result is all True or False
+ ("d", "d"), # valid comparison
+ ])
+def test_equal_subclass_no_override(op, dt1, dt2):
+ # Test how the three different possible code-paths deal with subclasses
+
+ class MyArr(np.ndarray):
+ called_wrap = 0
+
+ def __array_wrap__(self, new):
+ type(self).called_wrap += 1
+ return super().__array_wrap__(new)
+
+ numpy_arr = np.zeros(5, dtype=dt1)
+ my_arr = np.zeros(5, dtype=dt2).view(MyArr)
+
+ assert type(op(numpy_arr, my_arr)) is MyArr
+ assert type(op(my_arr, numpy_arr)) is MyArr
+ # We expect 2 calls (more if there were more fields):
+ assert MyArr.called_wrap == 2
+
+
+@pytest.mark.parametrize(["dt1", "dt2"], [
+ ("M8[ns]", "d"),
+ ("M8[s]", "l"),
+ ("m8[ns]", "d"),
+ # Missing: ("m8[ns]", "l") as timedelta currently promotes ints
+ ("M8[s]", "m8[s]"),
+ ("S5", "U5"),
+ # Structured/void dtypes have explicit paths not tested here.
+])
+def test_no_loop_gives_all_true_or_false(dt1, dt2):
+ # Make sure they broadcast to the test result shape; use random values,
+ # since the actual values should be ignored
+ arr1 = np.random.randint(5, size=100).astype(dt1)
+ arr2 = np.random.randint(5, size=99)[:, np.newaxis].astype(dt2)
+
+ res = arr1 == arr2
+ assert res.shape == (99, 100)
+ assert res.dtype == bool
+ assert not res.any()
+
+ res = arr1 != arr2
+ assert res.shape == (99, 100)
+ assert res.dtype == bool
+ assert res.all()
+
+ # incompatible shapes raise though
+ arr2 = np.random.randint(5, size=99).astype(dt2)
+ with pytest.raises(ValueError):
+ arr1 == arr2
+
+ with pytest.raises(ValueError):
+ arr1 != arr2
+
+ # Basic test with another operation:
+ with pytest.raises(np.core._exceptions._UFuncNoLoopError):
+ arr1 > arr2
+
+
+@pytest.mark.parametrize("op", [
+ operator.eq, operator.ne, operator.le, operator.lt, operator.ge,
+ operator.gt])
+def test_comparisons_forwards_error(op):
+ class NotArray:
+ def __array__(self):
+ raise TypeError("run you fools")
+
+ with pytest.raises(TypeError, match="run you fools"):
+ op(np.arange(2), NotArray())
+
+ with pytest.raises(TypeError, match="run you fools"):
+ op(NotArray(), np.arange(2))
+
+
+def test_richcompare_scalar_boolean_singleton_return():
+ # These are currently guaranteed to be the boolean singletons, but maybe
+ # returning NumPy booleans would also be OK:
+ assert (np.array(0) == "a") is False
+ assert (np.array(0) != "a") is True
+ assert (np.int16(0) == "a") is False
+ assert (np.int16(0) != "a") is True
+
+
+@pytest.mark.parametrize("op", [
+ operator.eq, operator.ne, operator.le, operator.lt, operator.ge,
+ operator.gt])
+def test_ragged_comparison_fails(op):
+ # This needs to convert the internal array to True/False, which fails:
+ a = np.array([1, np.array([1, 2, 3])], dtype=object)
+ b = np.array([1, np.array([1, 2, 3])], dtype=object)
+
+ with pytest.raises(ValueError, match="The truth value.*ambiguous"):
+ op(a, b)
+
+
@pytest.mark.parametrize(
["fun", "npfun"],
[
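A minimal sketch of the comparison semantics the new tests above lock in; it assumes a NumPy build that already contains this change:

import numpy as np

dates = np.zeros(3, dtype="M8[s]")   # datetime64
floats = np.zeros(3)                 # float64: no comparison loop exists

# Equality falls back to all-False / all-True boolean arrays ...
assert not (dates == floats).any()
assert (dates != floats).all()

# ... ordered comparisons still raise (no loop is found) ...
try:
    dates > floats
except TypeError:
    pass

# ... and shapes that cannot broadcast raise even for == and !=.
try:
    dates == np.zeros(4)
except ValueError:
    pass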
diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py
index 160e4a3a4..f638284de 100644
--- a/numpy/core/tests/test_regression.py
+++ b/numpy/core/tests/test_regression.py
@@ -128,10 +128,7 @@ class TestRegression:
assert_(a[1] == 'auto')
assert_(a[0] != 'auto')
b = np.linspace(0, 10, 11)
- # This should return true for now, but will eventually raise an error:
- with suppress_warnings() as sup:
- sup.filter(FutureWarning)
- assert_(b != 'auto')
+ assert_array_equal(b != 'auto', np.ones(11, dtype=bool))
assert_(b[0] != 'auto')
def test_unicode_swapping(self):
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index 1160eca54..88ab7e014 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -1069,8 +1069,9 @@ class TestPower:
assert_complex_equal(np.power(zero, -p), cnan)
assert_complex_equal(np.power(zero, -1+0.2j), cnan)
- # Testing 0^{Non-zero} issue 18378
+ @pytest.mark.skipif(IS_WASM, reason="fp errors don't work in wasm")
def test_zero_power_nonzero(self):
+ # Test 0 ** non-zero (see issue 18378)
zero = np.array([0.0+0.0j])
cnan = np.array([complex(np.nan, np.nan)])
@@ -2634,6 +2635,35 @@ class TestAbsoluteNegative:
np.abs(d, out=d)
np.abs(np.ones_like(d), out=d)
+ @pytest.mark.parametrize("dtype", ['d', 'f', 'int32', 'int64'])
+ @pytest.mark.parametrize("big", [True, False])
+ def test_noncontiguous(self, dtype, big):
+ data = np.array([-1.0, 1.0, -0.0, 0.0, 2.2251e-308, -2.5, 2.5, -6,
+ 6, -2.2251e-308, -8, 10], dtype=dtype)
+ expect = np.array([1.0, -1.0, 0.0, -0.0, -2.2251e-308, 2.5, -2.5, 6,
+ -6, 2.2251e-308, 8, -10], dtype=dtype)
+ if big:
+ data = np.repeat(data, 10)
+ expect = np.repeat(expect, 10)
+ out = np.ndarray(data.shape, dtype=dtype)
+ ncontig_in = data[1::2]
+ ncontig_out = out[1::2]
+ contig_in = np.array(ncontig_in)
+ # contig in, contig out
+ assert_array_equal(np.negative(contig_in), expect[1::2])
+ # contig in, ncontig out
+ assert_array_equal(np.negative(contig_in, out=ncontig_out),
+ expect[1::2])
+ # ncontig in, contig out
+ assert_array_equal(np.negative(ncontig_in), expect[1::2])
+ # ncontig in, ncontig out
+ assert_array_equal(np.negative(ncontig_in, out=ncontig_out),
+ expect[1::2])
+ # contig in, contig out, nd stride
+ data_split = np.array(np.array_split(data, 2))
+ expect_split = np.array(np.array_split(expect, 2))
+ assert_equal(np.negative(data_split), expect_split)
+
class TestPositive:
def test_valid(self):
diff --git a/numpy/core/tests/test_unicode.py b/numpy/core/tests/test_unicode.py
index 2d7c2818e..e5454bd48 100644
--- a/numpy/core/tests/test_unicode.py
+++ b/numpy/core/tests/test_unicode.py
@@ -36,10 +36,10 @@ def test_string_cast():
uni_arr1 = str_arr.astype('>U')
uni_arr2 = str_arr.astype('<U')
- with pytest.warns(FutureWarning):
- assert str_arr != uni_arr1
- with pytest.warns(FutureWarning):
- assert str_arr != uni_arr2
+ assert_array_equal(str_arr != uni_arr1, np.ones(2, dtype=bool))
+ assert_array_equal(uni_arr1 != str_arr, np.ones(2, dtype=bool))
+ assert_array_equal(str_arr == uni_arr1, np.zeros(2, dtype=bool))
+ assert_array_equal(uni_arr1 == str_arr, np.zeros(2, dtype=bool))
assert_array_equal(uni_arr1, uni_arr2)
diff --git a/numpy/exceptions.py b/numpy/exceptions.py
index 81a2f3c65..721b8102e 100644
--- a/numpy/exceptions.py
+++ b/numpy/exceptions.py
@@ -25,8 +25,9 @@ Exceptions
.. autosummary::
:toctree: generated/
- AxisError Given when an axis was invalid.
- TooHardError Error specific to `numpy.shares_memory`.
+ AxisError Given when an axis was invalid.
+ DTypePromotionError Given when no common dtype could be found.
+ TooHardError Error specific to `numpy.shares_memory`.
"""
@@ -35,7 +36,7 @@ from ._utils import set_module as _set_module
__all__ = [
"ComplexWarning", "VisibleDeprecationWarning",
- "TooHardError", "AxisError"]
+ "TooHardError", "AxisError", "DTypePromotionError"]
# Disallow reloading this module so as to preserve the identities of the
@@ -195,3 +196,48 @@ class AxisError(ValueError, IndexError):
if self._msg is not None:
msg = f"{self._msg}: {msg}"
return msg
+
+
+class DTypePromotionError(TypeError):
+ """Multiple DTypes could not be converted to a common one.
+
+ This exception derives from ``TypeError`` and is raised whenever dtypes
+ cannot be converted to a single common one. This can be because they
+ belong to different categories/classes or are incompatible instances of
+ the same class (see Examples).
+
+ Notes
+ -----
+ Many functions will use promotion to find the correct result and
+ implementation. For these functions the error will typically be chained
+ with a more specific error indicating that no implementation was found
+ for the input dtypes.
+
+ Typically promotion should be considered "invalid" between the dtypes of
+ two arrays when `arr1 == arr2` can safely return all ``False`` because the
+ dtypes are fundamentally different.
+
+ Examples
+ --------
+ Datetimes and complex numbers are incompatible classes and cannot be
+ promoted:
+
+ >>> np.result_type(np.dtype("M8[s]"), np.complex128)
+ DTypePromotionError: The DType <class 'numpy.dtype[datetime64]'> could not
+ be promoted by <class 'numpy.dtype[complex128]'>. This means that no common
+ DType exists for the given inputs. For example they cannot be stored in a
+ single array unless the dtype is `object`. The full list of DTypes is:
+ (<class 'numpy.dtype[datetime64]'>, <class 'numpy.dtype[complex128]'>)
+
+ Structured dtypes can also fail to promote; for example, the same
+ ``DTypePromotionError`` is raised when two structured dtypes differ in
+ their number of fields:
+
+ >>> dtype1 = np.dtype([("field1", np.float64), ("field2", np.int64)])
+ >>> dtype2 = np.dtype([("field1", np.float64)])
+ >>> np.promote_types(dtype1, dtype2)
+ DTypePromotionError: field names `('field1', 'field2')` and `('field1',)`
+ mismatch.
+
+ """
+ pass
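A short usage sketch for the new exception; importing it from ``numpy.exceptions`` assumes this patch is applied:

import numpy as np
from numpy.exceptions import DTypePromotionError

try:
    np.result_type(np.dtype("M8[s]"), np.complex128)
except DTypePromotionError as exc:
    # The exception subclasses TypeError, so existing handlers that
    # catch TypeError keep working unchanged.
    assert isinstance(exc, TypeError)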
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 0ab49fa11..35a3b3543 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -3689,7 +3689,7 @@ def msort(a):
return b
-def _ureduce(a, func, **kwargs):
+def _ureduce(a, func, keepdims=False, **kwargs):
"""
Internal Function.
Call `func` with `a` as first argument swapping the axes to use extended
@@ -3717,13 +3717,20 @@ def _ureduce(a, func, **kwargs):
"""
a = np.asanyarray(a)
axis = kwargs.get('axis', None)
+ out = kwargs.get('out', None)
+
+ if keepdims is np._NoValue:
+ keepdims = False
+
+ nd = a.ndim
if axis is not None:
- keepdim = list(a.shape)
- nd = a.ndim
axis = _nx.normalize_axis_tuple(axis, nd)
- for ax in axis:
- keepdim[ax] = 1
+ if keepdims:
+ if out is not None:
+ index_out = tuple(
+ 0 if i in axis else slice(None) for i in range(nd))
+ kwargs['out'] = out[(Ellipsis, ) + index_out]
if len(axis) == 1:
kwargs['axis'] = axis[0]
@@ -3736,12 +3743,27 @@ def _ureduce(a, func, **kwargs):
# merge reduced axis
a = a.reshape(a.shape[:nkeep] + (-1,))
kwargs['axis'] = -1
- keepdim = tuple(keepdim)
else:
- keepdim = (1,) * a.ndim
+ if keepdims:
+ if out is not None:
+ index_out = (0, ) * nd
+ kwargs['out'] = out[(Ellipsis, ) + index_out]
r = func(a, **kwargs)
- return r, keepdim
+
+ if out is not None:
+ return out
+
+ if keepdims:
+ if axis is None:
+ index_r = (np.newaxis, ) * nd
+ else:
+ index_r = tuple(
+ np.newaxis if i in axis else slice(None)
+ for i in range(nd))
+ r = r[(Ellipsis, ) + index_r]
+
+ return r
def _median_dispatcher(
@@ -3831,12 +3853,8 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
>>> assert not np.all(a==b)
"""
- r, k = _ureduce(a, func=_median, axis=axis, out=out,
+ return _ureduce(a, func=_median, keepdims=keepdims, axis=axis, out=out,
overwrite_input=overwrite_input)
- if keepdims:
- return r.reshape(k)
- else:
- return r
def _median(a, axis=None, out=None, overwrite_input=False):
@@ -4452,17 +4470,14 @@ def _quantile_unchecked(a,
method="linear",
keepdims=False):
"""Assumes that q is in [0, 1], and is an ndarray"""
- r, k = _ureduce(a,
+ return _ureduce(a,
func=_quantile_ureduce_func,
q=q,
+ keepdims=keepdims,
axis=axis,
out=out,
overwrite_input=overwrite_input,
method=method)
- if keepdims:
- return r.reshape(q.shape + k)
- else:
- return r
def _quantile_is_valid(q):
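The practical effect of routing ``keepdims`` and ``out`` through ``_ureduce`` is that reductions such as ``np.median`` and ``np.percentile`` can now write directly into a ``keepdims``-shaped output array; ``nanmedian``/``nanpercentile`` below get the same treatment. A small sketch, assuming this patch is applied:

import numpy as np

a = np.arange(105.0).reshape(3, 5, 7)
out = np.empty((3, 1, 1))            # keepdims-shaped output for axis=(1, 2)

res = np.median(a, axis=(1, 2), keepdims=True, out=out)
assert res is out                    # the provided array is returned
assert res.shape == (3, 1, 1)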
diff --git a/numpy/lib/nanfunctions.py b/numpy/lib/nanfunctions.py
index 3814c0727..ae2dfa165 100644
--- a/numpy/lib/nanfunctions.py
+++ b/numpy/lib/nanfunctions.py
@@ -1214,12 +1214,9 @@ def nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=np._NoValu
if a.size == 0:
return np.nanmean(a, axis, out=out, keepdims=keepdims)
- r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
+ return function_base._ureduce(a, func=_nanmedian, keepdims=keepdims,
+ axis=axis, out=out,
overwrite_input=overwrite_input)
- if keepdims and keepdims is not np._NoValue:
- return r.reshape(k)
- else:
- return r
def _nanpercentile_dispatcher(
@@ -1556,17 +1553,14 @@ def _nanquantile_unchecked(
# so deal them upfront
if a.size == 0:
return np.nanmean(a, axis, out=out, keepdims=keepdims)
- r, k = function_base._ureduce(a,
+ return function_base._ureduce(a,
func=_nanquantile_ureduce_func,
q=q,
+ keepdims=keepdims,
axis=axis,
out=out,
overwrite_input=overwrite_input,
method=method)
- if keepdims and keepdims is not np._NoValue:
- return r.reshape(q.shape + k)
- else:
- return r
def _nanquantile_ureduce_func(a, q, axis=None, out=None, overwrite_input=False,
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index c5b31ebf4..1bb4c4efa 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -25,6 +25,7 @@ from numpy.lib import (
i0, insert, interp, kaiser, meshgrid, msort, piecewise, place, rot90,
select, setxor1d, sinc, trapz, trim_zeros, unwrap, unique, vectorize
)
+from numpy.core.numeric import normalize_axis_tuple
def get_mat(n):
@@ -3331,6 +3332,32 @@ class TestPercentile:
assert_equal(np.percentile(d, [1, 7], axis=(0, 3),
keepdims=True).shape, (2, 1, 5, 7, 1))
+ @pytest.mark.parametrize('q', [7, [1, 7]])
+ @pytest.mark.parametrize(
+ argnames='axis',
+ argvalues=[
+ None,
+ 1,
+ (1,),
+ (0, 1),
+ (-3, -1),
+ ]
+ )
+ def test_keepdims_out(self, q, axis):
+ d = np.ones((3, 5, 7, 11))
+ if axis is None:
+ shape_out = (1,) * d.ndim
+ else:
+ axis_norm = normalize_axis_tuple(axis, d.ndim)
+ shape_out = tuple(
+ 1 if i in axis_norm else d.shape[i] for i in range(d.ndim))
+ shape_out = np.shape(q) + shape_out
+
+ out = np.empty(shape_out)
+ result = np.percentile(d, q, axis=axis, keepdims=True, out=out)
+ assert result is out
+ assert_equal(result.shape, shape_out)
+
def test_out(self):
o = np.zeros((4,))
d = np.ones((3, 4))
@@ -3843,6 +3870,29 @@ class TestMedian:
assert_equal(np.median(d, axis=(0, 1, 3), keepdims=True).shape,
(1, 1, 7, 1))
+ @pytest.mark.parametrize(
+ argnames='axis',
+ argvalues=[
+ None,
+ 1,
+ (1, ),
+ (0, 1),
+ (-3, -1),
+ ]
+ )
+ def test_keepdims_out(self, axis):
+ d = np.ones((3, 5, 7, 11))
+ if axis is None:
+ shape_out = (1,) * d.ndim
+ else:
+ axis_norm = normalize_axis_tuple(axis, d.ndim)
+ shape_out = tuple(
+ 1 if i in axis_norm else d.shape[i] for i in range(d.ndim))
+ out = np.empty(shape_out)
+ result = np.median(d, axis=axis, keepdims=True, out=out)
+ assert result is out
+ assert_equal(result.shape, shape_out)
+
class TestAdd_newdoc_ufunc:
diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py
index 733a077ea..64464edcc 100644
--- a/numpy/lib/tests/test_nanfunctions.py
+++ b/numpy/lib/tests/test_nanfunctions.py
@@ -3,6 +3,7 @@ import pytest
import inspect
import numpy as np
+from numpy.core.numeric import normalize_axis_tuple
from numpy.lib.nanfunctions import _nan_mask, _replace_nan
from numpy.testing import (
assert_, assert_equal, assert_almost_equal, assert_raises,
@@ -807,6 +808,33 @@ class TestNanFunctions_Median:
res = np.nanmedian(d, axis=(0, 1, 3), keepdims=True)
assert_equal(res.shape, (1, 1, 7, 1))
+ @pytest.mark.parametrize(
+ argnames='axis',
+ argvalues=[
+ None,
+ 1,
+ (1, ),
+ (0, 1),
+ (-3, -1),
+ ]
+ )
+ def test_keepdims_out(self, axis):
+ d = np.ones((3, 5, 7, 11))
+ # Randomly set some elements to NaN:
+ w = np.random.random((4, 200)) * np.array(d.shape)[:, None]
+ w = w.astype(np.intp)
+ d[tuple(w)] = np.nan
+ if axis is None:
+ shape_out = (1,) * d.ndim
+ else:
+ axis_norm = normalize_axis_tuple(axis, d.ndim)
+ shape_out = tuple(
+ 1 if i in axis_norm else d.shape[i] for i in range(d.ndim))
+ out = np.empty(shape_out)
+ result = np.nanmedian(d, axis=axis, keepdims=True, out=out)
+ assert result is out
+ assert_equal(result.shape, shape_out)
+
def test_out(self):
mat = np.random.rand(3, 3)
nan_mat = np.insert(mat, [0, 2], np.nan, axis=1)
@@ -982,6 +1010,36 @@ class TestNanFunctions_Percentile:
res = np.nanpercentile(d, 90, axis=(0, 1, 3), keepdims=True)
assert_equal(res.shape, (1, 1, 7, 1))
+ @pytest.mark.parametrize('q', [7, [1, 7]])
+ @pytest.mark.parametrize(
+ argnames='axis',
+ argvalues=[
+ None,
+ 1,
+ (1,),
+ (0, 1),
+ (-3, -1),
+ ]
+ )
+ def test_keepdims_out(self, q, axis):
+ d = np.ones((3, 5, 7, 11))
+ # Randomly set some elements to NaN:
+ w = np.random.random((4, 200)) * np.array(d.shape)[:, None]
+ w = w.astype(np.intp)
+ d[tuple(w)] = np.nan
+ if axis is None:
+ shape_out = (1,) * d.ndim
+ else:
+ axis_norm = normalize_axis_tuple(axis, d.ndim)
+ shape_out = tuple(
+ 1 if i in axis_norm else d.shape[i] for i in range(d.ndim))
+ shape_out = np.shape(q) + shape_out
+
+ out = np.empty(shape_out)
+ result = np.nanpercentile(d, q, axis=axis, keepdims=True, out=out)
+ assert result is out
+ assert_equal(result.shape, shape_out)
+
def test_out(self):
mat = np.random.rand(3, 3)
nan_mat = np.insert(mat, [0, 2], np.nan, axis=1)
diff --git a/numpy/linalg/tests/test_regression.py b/numpy/linalg/tests/test_regression.py
index 7ed932bc9..af38443a9 100644
--- a/numpy/linalg/tests/test_regression.py
+++ b/numpy/linalg/tests/test_regression.py
@@ -107,10 +107,7 @@ class TestRegression:
assert_raises(ValueError, linalg.norm, testvector, ord='nuc')
assert_raises(ValueError, linalg.norm, testvector, ord=np.inf)
assert_raises(ValueError, linalg.norm, testvector, ord=-np.inf)
- with warnings.catch_warnings():
- warnings.simplefilter("error", DeprecationWarning)
- assert_raises((AttributeError, DeprecationWarning),
- linalg.norm, testvector, ord=0)
+ assert_raises(ValueError, linalg.norm, testvector, ord=0)
assert_raises(ValueError, linalg.norm, testvector, ord=-1)
assert_raises(ValueError, linalg.norm, testvector, ord=-2)
diff --git a/numpy/ma/bench.py b/numpy/ma/bench.py
deleted file mode 100644
index 56865683d..000000000
--- a/numpy/ma/bench.py
+++ /dev/null
@@ -1,130 +0,0 @@
-#!/usr/bin/env python3
-
-import timeit
-import numpy
-
-
-###############################################################################
-# Global variables #
-###############################################################################
-
-
-# Small arrays
-xs = numpy.random.uniform(-1, 1, 6).reshape(2, 3)
-ys = numpy.random.uniform(-1, 1, 6).reshape(2, 3)
-zs = xs + 1j * ys
-m1 = [[True, False, False], [False, False, True]]
-m2 = [[True, False, True], [False, False, True]]
-nmxs = numpy.ma.array(xs, mask=m1)
-nmys = numpy.ma.array(ys, mask=m2)
-nmzs = numpy.ma.array(zs, mask=m1)
-
-# Big arrays
-xl = numpy.random.uniform(-1, 1, 100*100).reshape(100, 100)
-yl = numpy.random.uniform(-1, 1, 100*100).reshape(100, 100)
-zl = xl + 1j * yl
-maskx = xl > 0.8
-masky = yl < -0.8
-nmxl = numpy.ma.array(xl, mask=maskx)
-nmyl = numpy.ma.array(yl, mask=masky)
-nmzl = numpy.ma.array(zl, mask=maskx)
-
-
-###############################################################################
-# Functions #
-###############################################################################
-
-
-def timer(s, v='', nloop=500, nrep=3):
- units = ["s", "ms", "µs", "ns"]
- scaling = [1, 1e3, 1e6, 1e9]
- print("%s : %-50s : " % (v, s), end=' ')
- varnames = ["%ss,nm%ss,%sl,nm%sl" % tuple(x*4) for x in 'xyz']
- setup = 'from __main__ import numpy, ma, %s' % ','.join(varnames)
- Timer = timeit.Timer(stmt=s, setup=setup)
- best = min(Timer.repeat(nrep, nloop)) / nloop
- if best > 0.0:
- order = min(-int(numpy.floor(numpy.log10(best)) // 3), 3)
- else:
- order = 3
- print("%d loops, best of %d: %.*g %s per loop" % (nloop, nrep,
- 3,
- best * scaling[order],
- units[order]))
-
-
-def compare_functions_1v(func, nloop=500,
- xs=xs, nmxs=nmxs, xl=xl, nmxl=nmxl):
- funcname = func.__name__
- print("-"*50)
- print(f'{funcname} on small arrays')
- module, data = "numpy.ma", "nmxs"
- timer("%(module)s.%(funcname)s(%(data)s)" % locals(), v="%11s" % module, nloop=nloop)
-
- print("%s on large arrays" % funcname)
- module, data = "numpy.ma", "nmxl"
- timer("%(module)s.%(funcname)s(%(data)s)" % locals(), v="%11s" % module, nloop=nloop)
- return
-
-def compare_methods(methodname, args, vars='x', nloop=500, test=True,
- xs=xs, nmxs=nmxs, xl=xl, nmxl=nmxl):
- print("-"*50)
- print(f'{methodname} on small arrays')
- data, ver = f'nm{vars}l', 'numpy.ma'
- timer("%(data)s.%(methodname)s(%(args)s)" % locals(), v=ver, nloop=nloop)
-
- print("%s on large arrays" % methodname)
- data, ver = "nm%sl" % vars, 'numpy.ma'
- timer("%(data)s.%(methodname)s(%(args)s)" % locals(), v=ver, nloop=nloop)
- return
-
-def compare_functions_2v(func, nloop=500, test=True,
- xs=xs, nmxs=nmxs,
- ys=ys, nmys=nmys,
- xl=xl, nmxl=nmxl,
- yl=yl, nmyl=nmyl):
- funcname = func.__name__
- print("-"*50)
- print(f'{funcname} on small arrays')
- module, data = "numpy.ma", "nmxs,nmys"
- timer("%(module)s.%(funcname)s(%(data)s)" % locals(), v="%11s" % module, nloop=nloop)
-
- print(f'{funcname} on large arrays')
- module, data = "numpy.ma", "nmxl,nmyl"
- timer("%(module)s.%(funcname)s(%(data)s)" % locals(), v="%11s" % module, nloop=nloop)
- return
-
-
-if __name__ == '__main__':
- compare_functions_1v(numpy.sin)
- compare_functions_1v(numpy.log)
- compare_functions_1v(numpy.sqrt)
-
- compare_functions_2v(numpy.multiply)
- compare_functions_2v(numpy.divide)
- compare_functions_2v(numpy.power)
-
- compare_methods('ravel', '', nloop=1000)
- compare_methods('conjugate', '', 'z', nloop=1000)
- compare_methods('transpose', '', nloop=1000)
- compare_methods('compressed', '', nloop=1000)
- compare_methods('__getitem__', '0', nloop=1000)
- compare_methods('__getitem__', '(0,0)', nloop=1000)
- compare_methods('__getitem__', '[0,-1]', nloop=1000)
- compare_methods('__setitem__', '0, 17', nloop=1000, test=False)
- compare_methods('__setitem__', '(0,0), 17', nloop=1000, test=False)
-
- print("-"*50)
- print("__setitem__ on small arrays")
- timer('nmxs.__setitem__((-1,0),numpy.ma.masked)', 'numpy.ma ', nloop=10000)
-
- print("-"*50)
- print("__setitem__ on large arrays")
- timer('nmxl.__setitem__((-1,0),numpy.ma.masked)', 'numpy.ma ', nloop=10000)
-
- print("-"*50)
- print("where on small arrays")
- timer('numpy.ma.where(nmxs>2,nmxs,nmys)', 'numpy.ma ', nloop=1000)
- print("-"*50)
- print("where on large arrays")
- timer('numpy.ma.where(nmxl>2,nmxl,nmyl)', 'numpy.ma ', nloop=100)
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index 4e7f8e85e..41bce0f22 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -732,12 +732,8 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
else:
return m
- r, k = _ureduce(a, func=_median, axis=axis, out=out,
+ return _ureduce(a, func=_median, keepdims=keepdims, axis=axis, out=out,
overwrite_input=overwrite_input)
- if keepdims:
- return r.reshape(k)
- else:
- return r
def _median(a, axis=None, out=None, overwrite_input=False):
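``np.ma.median`` follows the same pattern; a brief sketch mirroring the test added below, again assuming this patch:

import numpy as np
import numpy.ma as ma

data = np.arange(60.0).reshape(3, 4, 5)
d = ma.masked_array(data, mask=(data % 7 == 0))
out = ma.masked_array(np.empty((1, 1, 5)))   # keepdims shape for axis=(0, 1)

res = ma.median(d, axis=(0, 1), keepdims=True, out=out)
assert res is out
assert res.shape == (1, 1, 5)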
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index 3c95e25ea..38603fb84 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -12,6 +12,7 @@ import itertools
import pytest
import numpy as np
+from numpy.core.numeric import normalize_axis_tuple
from numpy.testing import (
assert_warns, suppress_warnings
)
@@ -989,6 +990,34 @@ class TestMedian:
assert_(r is out)
assert_(type(r) is MaskedArray)
+ @pytest.mark.parametrize(
+ argnames='axis',
+ argvalues=[
+ None,
+ 1,
+ (1, ),
+ (0, 1),
+ (-3, -1),
+ ]
+ )
+ def test_keepdims_out(self, axis):
+ mask = np.zeros((3, 5, 7, 11), dtype=bool)
+ # Randomly set some elements to True:
+ w = np.random.random((4, 200)) * np.array(mask.shape)[:, None]
+ w = w.astype(np.intp)
+ mask[tuple(w)] = True
+ d = masked_array(np.ones(mask.shape), mask=mask)
+ if axis is None:
+ shape_out = (1,) * d.ndim
+ else:
+ axis_norm = normalize_axis_tuple(axis, d.ndim)
+ shape_out = tuple(
+ 1 if i in axis_norm else d.shape[i] for i in range(d.ndim))
+ out = masked_array(np.empty(shape_out))
+ result = median(d, axis=axis, keepdims=True, out=out)
+ assert result is out
+ assert_equal(result.shape, shape_out)
+
def test_single_non_masked_value_on_axis(self):
data = [[1., 0.],
[0., 3.],
diff --git a/numpy/tests/test_public_api.py b/numpy/tests/test_public_api.py
index 4cd602510..28ed2f1df 100644
--- a/numpy/tests/test_public_api.py
+++ b/numpy/tests/test_public_api.py
@@ -280,7 +280,6 @@ PRIVATE_BUT_PRESENT_MODULES = ['numpy.' + s for s in [
"lib.utils",
"linalg.lapack_lite",
"linalg.linalg",
- "ma.bench",
"ma.core",
"ma.testutils",
"ma.timer_comparison",
@@ -361,6 +360,7 @@ SKIP_LIST_2 = [
'numpy.matlib.char',
'numpy.matlib.rec',
'numpy.matlib.emath',
+ 'numpy.matlib.exceptions',
'numpy.matlib.math',
'numpy.matlib.linalg',
'numpy.matlib.fft',