summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/release/upcoming_changes/19459.new_feature.rst4
-rw-r--r--doc/source/reference/routines.other.rst9
-rw-r--r--numpy/__init__.pyi9
-rw-r--r--numpy/core/_exceptions.py96
-rw-r--r--numpy/core/tests/test__exceptions.py30
-rw-r--r--numpy/typing/tests/data/fail/warnings_and_errors.py8
-rw-r--r--numpy/typing/tests/data/pass/warnings_and_errors.py3
-rw-r--r--numpy/typing/tests/data/reveal/warnings_and_errors.py3
8 files changed, 139 insertions, 23 deletions
diff --git a/doc/release/upcoming_changes/19459.new_feature.rst b/doc/release/upcoming_changes/19459.new_feature.rst
new file mode 100644
index 000000000..aecae670f
--- /dev/null
+++ b/doc/release/upcoming_changes/19459.new_feature.rst
@@ -0,0 +1,4 @@
+The ``ndim`` and ``axis`` attributes have been added to `numpy.AxisError`
+-------------------------------------------------------------------------
+The ``ndim`` and ``axis`` parameters are now also stored as attributes
+within each `numpy.AxisError` instance.
diff --git a/doc/source/reference/routines.other.rst b/doc/source/reference/routines.other.rst
index aefd680bb..339857409 100644
--- a/doc/source/reference/routines.other.rst
+++ b/doc/source/reference/routines.other.rst
@@ -55,4 +55,11 @@ Matlab-like Functions
:toctree: generated/
who
- disp \ No newline at end of file
+ disp
+
+Exceptions
+----------
+.. autosummary::
+ :toctree: generated/
+
+ AxisError
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index cafa296eb..413a32569 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -3681,9 +3681,12 @@ class RankWarning(UserWarning): ...
class TooHardError(RuntimeError): ...
class AxisError(ValueError, IndexError):
- def __init__(
- self, axis: int, ndim: Optional[int] = ..., msg_prefix: Optional[str] = ...
- ) -> None: ...
+ axis: None | int
+ ndim: None | int
+ @overload
+ def __init__(self, axis: str, ndim: None = ..., msg_prefix: None = ...) -> None: ...
+ @overload
+ def __init__(self, axis: int, ndim: int, msg_prefix: None | str = ...) -> None: ...
_CallType = TypeVar("_CallType", bound=Union[_ErrFunc, _SupportsWrite])
diff --git a/numpy/core/_exceptions.py b/numpy/core/_exceptions.py
index 77aa2f6e1..3cd8042ce 100644
--- a/numpy/core/_exceptions.py
+++ b/numpy/core/_exceptions.py
@@ -122,20 +122,94 @@ class TooHardError(RuntimeError):
@set_module('numpy')
class AxisError(ValueError, IndexError):
- """ Axis supplied was invalid. """
- def __init__(self, axis, ndim=None, msg_prefix=None):
- # single-argument form just delegates to base class
- if ndim is None and msg_prefix is None:
- msg = axis
+ """Axis supplied was invalid.
+
+ This is raised whenever an ``axis`` parameter is specified that is larger
+ than the number of array dimensions.
+ For compatibility with code written against older numpy versions, which
+ raised a mixture of `ValueError` and `IndexError` for this situation, this
+ exception subclasses both to ensure that ``except ValueError`` and
+ ``except IndexError`` statements continue to catch `AxisError`.
+
+ .. versionadded:: 1.13
+
+ Parameters
+ ----------
+ axis : int or str
+ The out of bounds axis or a custom exception message.
+ If an axis is provided, then `ndim` should be specified as well.
+ ndim : int, optional
+ The number of array dimensions.
+ msg_prefix : str, optional
+ A prefix for the exception message.
+
+ Attributes
+ ----------
+ axis : int, optional
+ The out of bounds axis or ``None`` if a custom exception
+ message was provided. This should be the axis as passed by
+ the user, before any normalization to resolve negative indices.
+
+ .. versionadded:: 1.22
+ ndim : int, optional
+ The number of array dimensions or ``None`` if a custom exception
+ message was provided.
+
+ .. versionadded:: 1.22
+
+
+ Examples
+ --------
+ >>> array_1d = np.arange(10)
+ >>> np.cumsum(array_1d, axis=1)
+ Traceback (most recent call last):
+ ...
+ numpy.AxisError: axis 1 is out of bounds for array of dimension 1
+
+ Negative axes are preserved:
+
+ >>> np.cumsum(array_1d, axis=-2)
+ Traceback (most recent call last):
+ ...
+ numpy.AxisError: axis -2 is out of bounds for array of dimension 1
- # do the string formatting here, to save work in the C code
+ The class constructor generally takes the axis and arrays'
+ dimensionality as arguments:
+
+ >>> print(np.AxisError(2, 1, msg_prefix='error'))
+ error: axis 2 is out of bounds for array of dimension 1
+
+ Alternatively, a custom exception message can be passed:
+
+ >>> print(np.AxisError('Custom error message'))
+ Custom error message
+
+ """
+
+ __slots__ = ("axis", "ndim", "_msg")
+
+ def __init__(self, axis, ndim=None, msg_prefix=None):
+ if ndim is msg_prefix is None:
+ # single-argument form: directly set the error message
+ self._msg = axis
+ self.axis = None
+ self.ndim = None
else:
- msg = ("axis {} is out of bounds for array of dimension {}"
- .format(axis, ndim))
- if msg_prefix is not None:
- msg = "{}: {}".format(msg_prefix, msg)
+ self._msg = msg_prefix
+ self.axis = axis
+ self.ndim = ndim
- super().__init__(msg)
+ def __str__(self):
+ axis = self.axis
+ ndim = self.ndim
+
+ if axis is ndim is None:
+ return self._msg
+ else:
+ msg = f"axis {axis} is out of bounds for array of dimension {ndim}"
+ if self._msg is not None:
+ msg = f"{self._msg}: {msg}"
+ return msg
@_display_as_base
diff --git a/numpy/core/tests/test__exceptions.py b/numpy/core/tests/test__exceptions.py
index 51c056936..c87412aa4 100644
--- a/numpy/core/tests/test__exceptions.py
+++ b/numpy/core/tests/test__exceptions.py
@@ -4,6 +4,7 @@ Tests of the ._exceptions module. Primarily for exercising the __str__ methods.
import pickle
+import pytest
import numpy as np
_ArrayMemoryError = np.core._exceptions._ArrayMemoryError
@@ -56,3 +57,32 @@ class TestUFuncNoLoopError:
def test_pickling(self):
""" Test that _UFuncNoLoopError can be pickled """
assert isinstance(pickle.dumps(_UFuncNoLoopError), bytes)
+
+
+@pytest.mark.parametrize("args", [
+ (2, 1, None),
+ (2, 1, "test_prefix"),
+ ("test message",),
+])
+class TestAxisError:
+ def test_attr(self, args):
+ """Validate attribute types."""
+ exc = np.AxisError(*args)
+ if len(args) == 1:
+ assert exc.axis is None
+ assert exc.ndim is None
+ else:
+ axis, ndim, *_ = args
+ assert exc.axis == axis
+ assert exc.ndim == ndim
+
+ def test_pickling(self, args):
+ """Test that `AxisError` can be pickled."""
+ exc = np.AxisError(*args)
+ exc2 = pickle.loads(pickle.dumps(exc))
+
+ assert type(exc) is type(exc2)
+ for name in ("axis", "ndim", "args"):
+ attr1 = getattr(exc, name)
+ attr2 = getattr(exc2, name)
+ assert attr1 == attr2, name
diff --git a/numpy/typing/tests/data/fail/warnings_and_errors.py b/numpy/typing/tests/data/fail/warnings_and_errors.py
index 7390cc45f..f4fa38293 100644
--- a/numpy/typing/tests/data/fail/warnings_and_errors.py
+++ b/numpy/typing/tests/data/fail/warnings_and_errors.py
@@ -1,7 +1,5 @@
import numpy as np
-np.AxisError(1.0) # E: Argument 1 to "AxisError" has incompatible type
-np.AxisError(1, ndim=2.0) # E: Argument "ndim" to "AxisError" has incompatible type
-np.AxisError(
- 2, msg_prefix=404 # E: Argument "msg_prefix" to "AxisError" has incompatible type
-)
+np.AxisError(1.0) # E: No overload variant
+np.AxisError(1, ndim=2.0) # E: No overload variant
+np.AxisError(2, msg_prefix=404) # E: No overload variant
diff --git a/numpy/typing/tests/data/pass/warnings_and_errors.py b/numpy/typing/tests/data/pass/warnings_and_errors.py
index 5b6ec2626..a556bf6bc 100644
--- a/numpy/typing/tests/data/pass/warnings_and_errors.py
+++ b/numpy/typing/tests/data/pass/warnings_and_errors.py
@@ -1,7 +1,6 @@
import numpy as np
-np.AxisError(1)
+np.AxisError("test")
np.AxisError(1, ndim=2)
-np.AxisError(1, ndim=None)
np.AxisError(1, ndim=2, msg_prefix="error")
np.AxisError(1, ndim=2, msg_prefix=None)
diff --git a/numpy/typing/tests/data/reveal/warnings_and_errors.py b/numpy/typing/tests/data/reveal/warnings_and_errors.py
index c428deb7a..3f20a0135 100644
--- a/numpy/typing/tests/data/reveal/warnings_and_errors.py
+++ b/numpy/typing/tests/data/reveal/warnings_and_errors.py
@@ -7,4 +7,5 @@ reveal_type(np.VisibleDeprecationWarning()) # E: numpy.VisibleDeprecationWarnin
reveal_type(np.ComplexWarning()) # E: numpy.ComplexWarning
reveal_type(np.RankWarning()) # E: numpy.RankWarning
reveal_type(np.TooHardError()) # E: numpy.TooHardError
-reveal_type(np.AxisError(1)) # E: numpy.AxisError
+reveal_type(np.AxisError("test")) # E: numpy.AxisError
+reveal_type(np.AxisError(5, 1)) # E: numpy.AxisError