summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py147
1 files changed, 92 insertions, 55 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index b90b0a9c9..2c0205175 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -17,7 +17,7 @@ from .multiarray import (
inner, int_asbuffer, lexsort, matmul, may_share_memory,
min_scalar_type, ndarray, nditer, nested_iters, promote_types,
putmask, result_type, set_numeric_ops, shares_memory, vdot, where,
- zeros)
+ zeros, normalize_axis_index)
if sys.version_info[0] < 3:
from .multiarray import newbuffer, getbuffer
@@ -27,7 +27,7 @@ from .umath import (invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE,
ERR_DEFAULT, PINF, NAN)
from . import numerictypes
from .numerictypes import longlong, intc, int_, float_, complex_, bool_
-from ._internal import TooHardError
+from ._internal import TooHardError, AxisError
bitwise_not = invert
ufunc = type(sin)
@@ -65,7 +65,7 @@ __all__ = [
'True_', 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE',
'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like', 'matmul',
'shares_memory', 'may_share_memory', 'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT',
- 'TooHardError',
+ 'TooHardError', 'AxisError'
]
@@ -441,7 +441,7 @@ def count_nonzero(a, axis=None):
nullstr = a.dtype.type('')
return (a != nullstr).sum(axis=axis, dtype=np.intp)
- axis = asarray(_validate_axis(axis, a.ndim, 'axis'))
+ axis = asarray(normalize_axis_tuple(axis, a.ndim))
counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a)
if axis.size == 1:
@@ -1354,9 +1354,9 @@ def tensordot(a, b, axes=2):
a, b = asarray(a), asarray(b)
as_ = a.shape
- nda = len(a.shape)
+ nda = a.ndim
bs = b.shape
- ndb = len(b.shape)
+ ndb = b.ndim
equal = True
if na != nb:
equal = False
@@ -1460,16 +1460,14 @@ def roll(a, shift, axis=None):
return roll(a.ravel(), shift, 0).reshape(a.shape)
else:
+ axis = normalize_axis_tuple(axis, a.ndim, allow_duplicate=True)
broadcasted = broadcast(shift, axis)
- if len(broadcasted.shape) > 1:
+ if broadcasted.ndim > 1:
raise ValueError(
"'shift' and 'axis' should be scalars or 1D sequences")
shifts = {ax: 0 for ax in range(a.ndim)}
for sh, ax in broadcasted:
- if -a.ndim <= ax < a.ndim:
- shifts[ax % a.ndim] += sh
- else:
- raise ValueError("'axis' entry is out of bounds")
+ shifts[ax] += sh
rolls = [((slice(None), slice(None)),)] * a.ndim
for ax, offset in shifts.items():
@@ -1527,15 +1525,12 @@ def rollaxis(a, axis, start=0):
"""
n = a.ndim
- if axis < 0:
- axis += n
+ axis = normalize_axis_index(axis, n)
if start < 0:
start += n
msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
- if not (0 <= axis < n):
- raise ValueError(msg % ('axis', -n, 'axis', n, axis))
if not (0 <= start < n + 1):
- raise ValueError(msg % ('start', -n, 'start', n + 1, start))
+ raise AxisError(msg % ('start', -n, 'start', n + 1, start))
if axis < start:
# it's been removed
start -= 1
@@ -1547,17 +1542,59 @@ def rollaxis(a, axis, start=0):
return a.transpose(axes)
-def _validate_axis(axis, ndim, argname):
+def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
+ """
+ Normalizes an axis argument into a tuple of non-negative integer axes.
+
+ This handles shorthands such as ``1`` and converts them to ``(1,)``,
+ as well as performing the handling of negative indices covered by
+ `normalize_axis_index`.
+
+ By default, this forbids axes from being specified multiple times.
+
+ Used internally by multi-axis-checking logic.
+
+ .. versionadded:: 1.13.0
+
+ Parameters
+ ----------
+ axis : int, iterable of int
+ The un-normalized index or indices of the axis.
+ ndim : int
+ The number of dimensions of the array that `axis` should be normalized
+ against.
+ argname : str, optional
+ A prefix to put before the error message, typically the name of the
+ argument.
+ allow_duplicate : bool, optional
+ If False, the default, disallow an axis from being specified twice.
+
+ Returns
+ -------
+ normalized_axes : tuple of int
+ The normalized axis index, such that `0 <= normalized_axis < ndim`
+
+ Raises
+ ------
+ AxisError
+ If any axis provided is out of range
+ ValueError
+ If an axis is repeated
+
+ See also
+ --------
+ normalize_axis_index : normalizing a single scalar axis
+ """
try:
axis = [operator.index(axis)]
except TypeError:
- axis = list(axis)
- axis = [a + ndim if a < 0 else a for a in axis]
- if not builtins.all(0 <= a < ndim for a in axis):
- raise ValueError('invalid axis for this array in `%s` argument' %
- argname)
- if len(set(axis)) != len(axis):
- raise ValueError('repeated axis in `%s` argument' % argname)
+ axis = tuple(axis)
+ axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis)
+ if not allow_duplicate and len(set(axis)) != len(axis):
+ if argname:
+ raise ValueError('repeated axis in `{}` argument'.format(argname))
+ else:
+ raise ValueError('repeated axis')
return axis
@@ -1602,7 +1639,7 @@ def moveaxis(a, source, destination):
>>> np.transpose(x).shape
(5, 4, 3)
- >>> np.swapaxis(x, 0, -1).shape
+ >>> np.swapaxes(x, 0, -1).shape
(5, 4, 3)
>>> np.moveaxis(x, [0, 1], [-1, -2]).shape
(5, 4, 3)
@@ -1617,8 +1654,8 @@ def moveaxis(a, source, destination):
a = asarray(a)
transpose = a.transpose
- source = _validate_axis(source, a.ndim, 'source')
- destination = _validate_axis(destination, a.ndim, 'destination')
+ source = normalize_axis_tuple(source, a.ndim, 'source')
+ destination = normalize_axis_tuple(destination, a.ndim, 'destination')
if len(source) != len(destination):
raise ValueError('`source` and `destination` arguments must have '
'the same number of elements')
@@ -1755,11 +1792,9 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
a = asarray(a)
b = asarray(b)
# Check axisa and axisb are within bounds
- axis_msg = "'axis{0}' out of bounds"
- if axisa < -a.ndim or axisa >= a.ndim:
- raise ValueError(axis_msg.format('a'))
- if axisb < -b.ndim or axisb >= b.ndim:
- raise ValueError(axis_msg.format('b'))
+ axisa = normalize_axis_index(axisa, a.ndim, msg_prefix='axisa')
+ axisb = normalize_axis_index(axisb, b.ndim, msg_prefix='axisb')
+
# Move working axis to the end of the shape
a = rollaxis(a, axisa, a.ndim)
b = rollaxis(b, axisb, b.ndim)
@@ -1773,8 +1808,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
if a.shape[-1] == 3 or b.shape[-1] == 3:
shape += (3,)
# Check axisc is within bounds
- if axisc < -len(shape) or axisc >= len(shape):
- raise ValueError(axis_msg.format('c'))
+ axisc = normalize_axis_index(axisc, len(shape), msg_prefix='axisc')
dtype = promote_types(a.dtype, b.dtype)
cp = empty(shape, dtype)
@@ -1893,35 +1927,35 @@ def array_repr(arr, max_line_width=None, precision=None, suppress_small=None):
'array([ 0.000001, 0. , 2. , 3. ])'
"""
+ if type(arr) is not ndarray:
+ class_name = type(arr).__name__
+ else:
+ class_name = "array"
+
if arr.size > 0 or arr.shape == (0,):
lst = array2string(arr, max_line_width, precision, suppress_small,
- ', ', "array(")
+ ', ', class_name + "(")
else: # show zero-length shape unless it is (0,)
lst = "[], shape=%s" % (repr(arr.shape),)
- if arr.__class__ is not ndarray:
- cName = arr.__class__.__name__
- else:
- cName = "array"
-
skipdtype = (arr.dtype.type in _typelessdata) and arr.size > 0
if skipdtype:
- return "%s(%s)" % (cName, lst)
+ return "%s(%s)" % (class_name, lst)
else:
typename = arr.dtype.name
# Quote typename in the output if it is "complex".
if typename and not (typename[0].isalpha() and typename.isalnum()):
typename = "'%s'" % typename
- lf = ''
+ lf = ' '
if issubclass(arr.dtype.type, flexible):
if arr.dtype.names:
typename = "%s" % str(arr.dtype)
else:
typename = "'%s'" % str(arr.dtype)
- lf = '\n'+' '*len("array(")
- return cName + "(%s, %sdtype=%s)" % (lst, lf, typename)
+ lf = '\n'+' '*len(class_name + "(")
+ return "%s(%s,%sdtype=%s)" % (class_name, lst, lf, typename)
def array_str(a, max_line_width=None, precision=None, suppress_small=None):
@@ -2089,15 +2123,12 @@ def indices(dimensions, dtype=int):
"""
dimensions = tuple(dimensions)
N = len(dimensions)
- if N == 0:
- return array([], dtype=dtype)
+ shape = (1,)*N
res = empty((N,)+dimensions, dtype=dtype)
for i, dim in enumerate(dimensions):
- tmp = arange(dim, dtype=dtype)
- tmp.shape = (1,)*i + (dim,)+(1,)*(N-i-1)
- newdim = dimensions[:i] + (1,) + dimensions[i+1:]
- val = zeros(newdim, dtype)
- add(tmp, val, res[i])
+ res[i] = arange(dim, dtype=dtype).reshape(
+ shape[:i] + (dim,) + shape[i+1:]
+ )
return res
@@ -2114,8 +2145,8 @@ def fromfunction(function, shape, **kwargs):
The function is called with N parameters, where N is the rank of
`shape`. Each parameter represents the coordinates of the array
varying along a specific axis. For example, if `shape`
- were ``(2, 2)``, then the parameters in turn be (0, 0), (0, 1),
- (1, 0), (1, 1).
+ were ``(2, 2)``, then the parameters would be
+ ``array([[0, 0], [1, 1]])`` and ``array([[0, 1], [0, 1]])``
shape : (N,) tuple of ints
Shape of the output array, which also determines the shape of
the coordinate arrays passed to `function`.
@@ -2212,7 +2243,7 @@ def binary_repr(num, width=None):
designated form.
If the `width` value is insufficient, it will be ignored, and `num` will
- be returned in binary(`num` > 0) or two's complement (`num` < 0) form
+ be returned in binary (`num` > 0) or two's complement (`num` < 0) form
with its width equal to the minimum number of bits needed to represent
the number in the designated form. This behavior is deprecated and will
later raise an error.
@@ -2282,10 +2313,16 @@ def binary_repr(num, width=None):
else:
poswidth = len(bin(-num)[2:])
- twocomp = 2**(poswidth + 1) + num
+ # See gh-8679: remove extra digit
+ # for numbers at boundaries.
+ if 2**(poswidth - 1) == -num:
+ poswidth -= 1
+
+ twocomp = 2**(poswidth + 1) + num
binary = bin(twocomp)[2:]
binwidth = len(binary)
+
outwidth = max(binwidth, width)
warn_if_insufficient(width, binwidth)
return '1' * (outwidth - binwidth) + binary