diff options
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 147 |
1 files changed, 92 insertions, 55 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index b90b0a9c9..2c0205175 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -17,7 +17,7 @@ from .multiarray import ( inner, int_asbuffer, lexsort, matmul, may_share_memory, min_scalar_type, ndarray, nditer, nested_iters, promote_types, putmask, result_type, set_numeric_ops, shares_memory, vdot, where, - zeros) + zeros, normalize_axis_index) if sys.version_info[0] < 3: from .multiarray import newbuffer, getbuffer @@ -27,7 +27,7 @@ from .umath import (invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE, ERR_DEFAULT, PINF, NAN) from . import numerictypes from .numerictypes import longlong, intc, int_, float_, complex_, bool_ -from ._internal import TooHardError +from ._internal import TooHardError, AxisError bitwise_not = invert ufunc = type(sin) @@ -65,7 +65,7 @@ __all__ = [ 'True_', 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE', 'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like', 'matmul', 'shares_memory', 'may_share_memory', 'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT', - 'TooHardError', + 'TooHardError', 'AxisError' ] @@ -441,7 +441,7 @@ def count_nonzero(a, axis=None): nullstr = a.dtype.type('') return (a != nullstr).sum(axis=axis, dtype=np.intp) - axis = asarray(_validate_axis(axis, a.ndim, 'axis')) + axis = asarray(normalize_axis_tuple(axis, a.ndim)) counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a) if axis.size == 1: @@ -1354,9 +1354,9 @@ def tensordot(a, b, axes=2): a, b = asarray(a), asarray(b) as_ = a.shape - nda = len(a.shape) + nda = a.ndim bs = b.shape - ndb = len(b.shape) + ndb = b.ndim equal = True if na != nb: equal = False @@ -1460,16 +1460,14 @@ def roll(a, shift, axis=None): return roll(a.ravel(), shift, 0).reshape(a.shape) else: + axis = normalize_axis_tuple(axis, a.ndim, allow_duplicate=True) broadcasted = broadcast(shift, axis) - if len(broadcasted.shape) > 1: + if broadcasted.ndim > 1: raise ValueError( "'shift' and 'axis' should be scalars or 1D sequences") shifts = {ax: 0 for ax in range(a.ndim)} for sh, ax in broadcasted: - if -a.ndim <= ax < a.ndim: - shifts[ax % a.ndim] += sh - else: - raise ValueError("'axis' entry is out of bounds") + shifts[ax] += sh rolls = [((slice(None), slice(None)),)] * a.ndim for ax, offset in shifts.items(): @@ -1527,15 +1525,12 @@ def rollaxis(a, axis, start=0): """ n = a.ndim - if axis < 0: - axis += n + axis = normalize_axis_index(axis, n) if start < 0: start += n msg = "'%s' arg requires %d <= %s < %d, but %d was passed in" - if not (0 <= axis < n): - raise ValueError(msg % ('axis', -n, 'axis', n, axis)) if not (0 <= start < n + 1): - raise ValueError(msg % ('start', -n, 'start', n + 1, start)) + raise AxisError(msg % ('start', -n, 'start', n + 1, start)) if axis < start: # it's been removed start -= 1 @@ -1547,17 +1542,59 @@ def rollaxis(a, axis, start=0): return a.transpose(axes) -def _validate_axis(axis, ndim, argname): +def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False): + """ + Normalizes an axis argument into a tuple of non-negative integer axes. + + This handles shorthands such as ``1`` and converts them to ``(1,)``, + as well as performing the handling of negative indices covered by + `normalize_axis_index`. + + By default, this forbids axes from being specified multiple times. + + Used internally by multi-axis-checking logic. + + .. versionadded:: 1.13.0 + + Parameters + ---------- + axis : int, iterable of int + The un-normalized index or indices of the axis. + ndim : int + The number of dimensions of the array that `axis` should be normalized + against. + argname : str, optional + A prefix to put before the error message, typically the name of the + argument. + allow_duplicate : bool, optional + If False, the default, disallow an axis from being specified twice. + + Returns + ------- + normalized_axes : tuple of int + The normalized axis index, such that `0 <= normalized_axis < ndim` + + Raises + ------ + AxisError + If any axis provided is out of range + ValueError + If an axis is repeated + + See also + -------- + normalize_axis_index : normalizing a single scalar axis + """ try: axis = [operator.index(axis)] except TypeError: - axis = list(axis) - axis = [a + ndim if a < 0 else a for a in axis] - if not builtins.all(0 <= a < ndim for a in axis): - raise ValueError('invalid axis for this array in `%s` argument' % - argname) - if len(set(axis)) != len(axis): - raise ValueError('repeated axis in `%s` argument' % argname) + axis = tuple(axis) + axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis) + if not allow_duplicate and len(set(axis)) != len(axis): + if argname: + raise ValueError('repeated axis in `{}` argument'.format(argname)) + else: + raise ValueError('repeated axis') return axis @@ -1602,7 +1639,7 @@ def moveaxis(a, source, destination): >>> np.transpose(x).shape (5, 4, 3) - >>> np.swapaxis(x, 0, -1).shape + >>> np.swapaxes(x, 0, -1).shape (5, 4, 3) >>> np.moveaxis(x, [0, 1], [-1, -2]).shape (5, 4, 3) @@ -1617,8 +1654,8 @@ def moveaxis(a, source, destination): a = asarray(a) transpose = a.transpose - source = _validate_axis(source, a.ndim, 'source') - destination = _validate_axis(destination, a.ndim, 'destination') + source = normalize_axis_tuple(source, a.ndim, 'source') + destination = normalize_axis_tuple(destination, a.ndim, 'destination') if len(source) != len(destination): raise ValueError('`source` and `destination` arguments must have ' 'the same number of elements') @@ -1755,11 +1792,9 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): a = asarray(a) b = asarray(b) # Check axisa and axisb are within bounds - axis_msg = "'axis{0}' out of bounds" - if axisa < -a.ndim or axisa >= a.ndim: - raise ValueError(axis_msg.format('a')) - if axisb < -b.ndim or axisb >= b.ndim: - raise ValueError(axis_msg.format('b')) + axisa = normalize_axis_index(axisa, a.ndim, msg_prefix='axisa') + axisb = normalize_axis_index(axisb, b.ndim, msg_prefix='axisb') + # Move working axis to the end of the shape a = rollaxis(a, axisa, a.ndim) b = rollaxis(b, axisb, b.ndim) @@ -1773,8 +1808,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): if a.shape[-1] == 3 or b.shape[-1] == 3: shape += (3,) # Check axisc is within bounds - if axisc < -len(shape) or axisc >= len(shape): - raise ValueError(axis_msg.format('c')) + axisc = normalize_axis_index(axisc, len(shape), msg_prefix='axisc') dtype = promote_types(a.dtype, b.dtype) cp = empty(shape, dtype) @@ -1893,35 +1927,35 @@ def array_repr(arr, max_line_width=None, precision=None, suppress_small=None): 'array([ 0.000001, 0. , 2. , 3. ])' """ + if type(arr) is not ndarray: + class_name = type(arr).__name__ + else: + class_name = "array" + if arr.size > 0 or arr.shape == (0,): lst = array2string(arr, max_line_width, precision, suppress_small, - ', ', "array(") + ', ', class_name + "(") else: # show zero-length shape unless it is (0,) lst = "[], shape=%s" % (repr(arr.shape),) - if arr.__class__ is not ndarray: - cName = arr.__class__.__name__ - else: - cName = "array" - skipdtype = (arr.dtype.type in _typelessdata) and arr.size > 0 if skipdtype: - return "%s(%s)" % (cName, lst) + return "%s(%s)" % (class_name, lst) else: typename = arr.dtype.name # Quote typename in the output if it is "complex". if typename and not (typename[0].isalpha() and typename.isalnum()): typename = "'%s'" % typename - lf = '' + lf = ' ' if issubclass(arr.dtype.type, flexible): if arr.dtype.names: typename = "%s" % str(arr.dtype) else: typename = "'%s'" % str(arr.dtype) - lf = '\n'+' '*len("array(") - return cName + "(%s, %sdtype=%s)" % (lst, lf, typename) + lf = '\n'+' '*len(class_name + "(") + return "%s(%s,%sdtype=%s)" % (class_name, lst, lf, typename) def array_str(a, max_line_width=None, precision=None, suppress_small=None): @@ -2089,15 +2123,12 @@ def indices(dimensions, dtype=int): """ dimensions = tuple(dimensions) N = len(dimensions) - if N == 0: - return array([], dtype=dtype) + shape = (1,)*N res = empty((N,)+dimensions, dtype=dtype) for i, dim in enumerate(dimensions): - tmp = arange(dim, dtype=dtype) - tmp.shape = (1,)*i + (dim,)+(1,)*(N-i-1) - newdim = dimensions[:i] + (1,) + dimensions[i+1:] - val = zeros(newdim, dtype) - add(tmp, val, res[i]) + res[i] = arange(dim, dtype=dtype).reshape( + shape[:i] + (dim,) + shape[i+1:] + ) return res @@ -2114,8 +2145,8 @@ def fromfunction(function, shape, **kwargs): The function is called with N parameters, where N is the rank of `shape`. Each parameter represents the coordinates of the array varying along a specific axis. For example, if `shape` - were ``(2, 2)``, then the parameters in turn be (0, 0), (0, 1), - (1, 0), (1, 1). + were ``(2, 2)``, then the parameters would be + ``array([[0, 0], [1, 1]])`` and ``array([[0, 1], [0, 1]])`` shape : (N,) tuple of ints Shape of the output array, which also determines the shape of the coordinate arrays passed to `function`. @@ -2212,7 +2243,7 @@ def binary_repr(num, width=None): designated form. If the `width` value is insufficient, it will be ignored, and `num` will - be returned in binary(`num` > 0) or two's complement (`num` < 0) form + be returned in binary (`num` > 0) or two's complement (`num` < 0) form with its width equal to the minimum number of bits needed to represent the number in the designated form. This behavior is deprecated and will later raise an error. @@ -2282,10 +2313,16 @@ def binary_repr(num, width=None): else: poswidth = len(bin(-num)[2:]) - twocomp = 2**(poswidth + 1) + num + # See gh-8679: remove extra digit + # for numbers at boundaries. + if 2**(poswidth - 1) == -num: + poswidth -= 1 + + twocomp = 2**(poswidth + 1) + num binary = bin(twocomp)[2:] binwidth = len(binary) + outwidth = max(binwidth, width) warn_if_insufficient(width, binwidth) return '1' * (outwidth - binwidth) + binary |