summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/numeric.py34
-rw-r--r--numpy/core/tests/test_numeric.py4
-rw-r--r--numpy/lib/function_base.py2
-rw-r--r--numpy/linalg/linalg.py54
-rw-r--r--numpy/linalg/umath_linalg.c.src100
-rw-r--r--numpy/ma/extras.py30
6 files changed, 136 insertions, 88 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 7a0fa4b62..ea2d4d0a2 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1478,8 +1478,9 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
axisb : int, optional
Axis of `b` that defines the vector(s). By default, the last axis.
axisc : int, optional
- Axis of `c` containing the cross product vector(s). By default, the
- last axis.
+ Axis of `c` containing the cross product vector(s). Ignored if
+ both input vectors have dimension 2, as the return is scalar.
+ By default, the last axis.
axis : int, optional
If defined, the axis of `a`, `b` and `c` that defines the vector(s)
and cross product(s). Overrides `axisa`, `axisb` and `axisc`.
@@ -1570,6 +1571,12 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
axisa, axisb, axisc = (axis,) * 3
a = asarray(a)
b = asarray(b)
+ # Check axisa and axisb are within bounds
+ axis_msg = "'axis{0}' out of bounds"
+ if axisa < -a.ndim or axisa >= a.ndim:
+ raise ValueError(axis_msg.format('a'))
+ if axisb < -b.ndim or axisb >= b.ndim:
+ raise ValueError(axis_msg.format('b'))
# Move working axis to the end of the shape
a = rollaxis(a, axisa, a.ndim)
b = rollaxis(b, axisb, b.ndim)
@@ -1578,10 +1585,13 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3):
raise ValueError(msg)
- # Create the output array
+ # Create the output array
shape = broadcast(a[..., 0], b[..., 0]).shape
if a.shape[-1] == 3 or b.shape[-1] == 3:
shape += (3,)
+ # Check axisc is within bounds
+ if axisc < -len(shape) or axisc >= len(shape):
+ raise ValueError(axis_msg.format('c'))
dtype = promote_types(a.dtype, b.dtype)
cp = empty(shape, dtype)
@@ -1604,12 +1614,9 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
# a0 * b1 - a1 * b0
multiply(a0, b1, out=cp)
cp -= a1 * b0
- if cp.ndim == 0:
- return cp
- else:
- # This works because we are moving the last axis
- return rollaxis(cp, -1, axisc)
+ return cp
else:
+ assert b.shape[-1] == 3
# cp0 = a1 * b2 - 0 (a2 = 0)
# cp1 = 0 - a0 * b2 (a2 = 0)
# cp2 = a0 * b1 - a1 * b0
@@ -1618,7 +1625,8 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
negative(cp1, out=cp1)
multiply(a0, b1, out=cp2)
cp2 -= a1 * b0
- elif a.shape[-1] == 3:
+ else:
+ assert a.shape[-1] == 3
if b.shape[-1] == 3:
# cp0 = a1 * b2 - a2 * b1
# cp1 = a2 * b0 - a0 * b2
@@ -1633,6 +1641,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
multiply(a1, b0, out=tmp)
cp2 -= tmp
else:
+ assert b.shape[-1] == 2
# cp0 = 0 - a2 * b1 (b2 = 0)
# cp1 = a2 * b0 - 0 (b2 = 0)
# cp2 = a0 * b1 - a1 * b0
@@ -1642,11 +1651,8 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
multiply(a0, b1, out=cp2)
cp2 -= a1 * b0
- if cp.ndim == 1:
- return cp
- else:
- # This works because we are moving the last axis
- return rollaxis(cp, -1, axisc)
+ # This works because we are moving the last axis
+ return rollaxis(cp, -1, axisc)
#Use numarray's printing function
from .arrayprint import array2string, get_printoptions, set_printoptions
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 9f0fb47e5..ee304a7af 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -2221,6 +2221,10 @@ class TestCross(TestCase):
assert_equal(np.cross(u, v, axisa=1, axisc=2).shape, (10, 5, 3, 7))
assert_raises(ValueError, np.cross, u, v, axisa=-5, axisb=2)
assert_raises(ValueError, np.cross, u, v, axisa=1, axisb=-4)
+ # gh-5885
+ u = np.ones((3, 4, 2))
+ for axisc in range(-2, 2):
+ assert_equal(np.cross(u, u, axisc=axisc).shape, (3, 4))
def test_outer_out_param():
arr1 = np.ones((5,))
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 04063755c..94d63c027 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -2320,7 +2320,7 @@ def hanning(M):
.. math:: w(n) = 0.5 - 0.5cos\\left(\\frac{2\\pi{n}}{M-1}\\right)
\\qquad 0 \\leq n \\leq M-1
- The Hanning was named for Julius van Hann, an Austrian meteorologist.
+ The Hanning was named for Julius von Hann, an Austrian meteorologist.
It is also known as the Cosine Bell. Some authors prefer that it be
called a Hann window, to help avoid confusion with the very similar
Hamming window.
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 30180f24a..d2e786970 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -23,7 +23,7 @@ from numpy.core import (
csingle, cdouble, inexact, complexfloating, newaxis, ravel, all, Inf, dot,
add, multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size,
finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax, product, abs,
- broadcast, atleast_2d, intp, asanyarray
+ broadcast, atleast_2d, intp, asanyarray, isscalar
)
from numpy.lib import triu, asfarray
from numpy.linalg import lapack_lite, _umath_linalg
@@ -382,7 +382,7 @@ def solve(a, b):
extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
r = gufunc(a, b, signature=signature, extobj=extobj)
- return wrap(r.astype(result_t))
+ return wrap(r.astype(result_t, copy=False))
def tensorinv(a, ind=2):
@@ -522,7 +522,7 @@ def inv(a):
signature = 'D->D' if isComplexType(t) else 'd->d'
extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj)
- return wrap(ainv.astype(result_t))
+ return wrap(ainv.astype(result_t, copy=False))
# Cholesky decomposition
@@ -606,7 +606,8 @@ def cholesky(a):
_assertNdSquareness(a)
t, result_t = _commonType(a)
signature = 'D->D' if isComplexType(t) else 'd->d'
- return wrap(gufunc(a, signature=signature, extobj=extobj).astype(result_t))
+ r = gufunc(a, signature=signature, extobj=extobj)
+ return wrap(r.astype(result_t, copy=False))
# QR decompostion
@@ -781,7 +782,7 @@ def qr(a, mode='reduced'):
if mode == 'economic':
if t != result_t :
- a = a.astype(result_t)
+ a = a.astype(result_t, copy=False)
return wrap(a.T)
# generate q from a
@@ -908,7 +909,7 @@ def eigvals(a):
else:
result_t = _complexType(result_t)
- return w.astype(result_t)
+ return w.astype(result_t, copy=False)
def eigvalsh(a, UPLO='L'):
"""
@@ -978,7 +979,7 @@ def eigvalsh(a, UPLO='L'):
t, result_t = _commonType(a)
signature = 'D->d' if isComplexType(t) else 'd->d'
w = gufunc(a, signature=signature, extobj=extobj)
- return w.astype(_realType(result_t))
+ return w.astype(_realType(result_t), copy=False)
def _convertarray(a):
t, result_t = _commonType(a)
@@ -1124,8 +1125,8 @@ def eig(a):
else:
result_t = _complexType(result_t)
- vt = vt.astype(result_t)
- return w.astype(result_t), wrap(vt)
+ vt = vt.astype(result_t, copy=False)
+ return w.astype(result_t, copy=False), wrap(vt)
def eigh(a, UPLO='L'):
@@ -1232,8 +1233,8 @@ def eigh(a, UPLO='L'):
signature = 'D->dD' if isComplexType(t) else 'd->dd'
w, vt = gufunc(a, signature=signature, extobj=extobj)
- w = w.astype(_realType(result_t))
- vt = vt.astype(result_t)
+ w = w.astype(_realType(result_t), copy=False)
+ vt = vt.astype(result_t, copy=False)
return w, wrap(vt)
@@ -1344,9 +1345,9 @@ def svd(a, full_matrices=1, compute_uv=1):
signature = 'D->DdD' if isComplexType(t) else 'd->ddd'
u, s, vt = gufunc(a, signature=signature, extobj=extobj)
- u = u.astype(result_t)
- s = s.astype(_realType(result_t))
- vt = vt.astype(result_t)
+ u = u.astype(result_t, copy=False)
+ s = s.astype(_realType(result_t), copy=False)
+ vt = vt.astype(result_t, copy=False)
return wrap(u), s, wrap(vt)
else:
if m < n:
@@ -1356,7 +1357,7 @@ def svd(a, full_matrices=1, compute_uv=1):
signature = 'D->d' if isComplexType(t) else 'd->d'
s = gufunc(a, signature=signature, extobj=extobj)
- s = s.astype(_realType(result_t))
+ s = s.astype(_realType(result_t), copy=False)
return s
def cond(x, p=None):
@@ -1695,7 +1696,15 @@ def slogdet(a):
real_t = _realType(result_t)
signature = 'D->Dd' if isComplexType(t) else 'd->dd'
sign, logdet = _umath_linalg.slogdet(a, signature=signature)
- return sign.astype(result_t), logdet.astype(real_t)
+ if isscalar(sign):
+ sign = sign.astype(result_t)
+ else:
+ sign = sign.astype(result_t, copy=False)
+ if isscalar(logdet):
+ logdet = logdet.astype(real_t)
+ else:
+ logdet = logdet.astype(real_t, copy=False)
+ return sign, logdet
def det(a):
"""
@@ -1749,7 +1758,12 @@ def det(a):
_assertNdSquareness(a)
t, result_t = _commonType(a)
signature = 'D->D' if isComplexType(t) else 'd->d'
- return _umath_linalg.det(a, signature=signature).astype(result_t)
+ r = _umath_linalg.det(a, signature=signature)
+ if isscalar(r):
+ r = r.astype(result_t)
+ else:
+ r = r.astype(result_t, copy=False)
+ return r
# Linear Least Squares
@@ -1905,12 +1919,12 @@ def lstsq(a, b, rcond=-1):
if results['rank'] == n and m > n:
if isComplexType(t):
resids = sum(abs(transpose(bstar)[n:,:])**2, axis=0).astype(
- result_real_t)
+ result_real_t, copy=False)
else:
resids = sum((transpose(bstar)[n:,:])**2, axis=0).astype(
- result_real_t)
+ result_real_t, copy=False)
- st = s[:min(n, m)].copy().astype(result_real_t)
+ st = s[:min(n, m)].astype(result_real_t, copy=True)
return wrap(x), wrap(resids), results['rank'], st
diff --git a/numpy/linalg/umath_linalg.c.src b/numpy/linalg/umath_linalg.c.src
index a09d2c10a..60cada325 100644
--- a/numpy/linalg/umath_linalg.c.src
+++ b/numpy/linalg/umath_linalg.c.src
@@ -1128,6 +1128,7 @@ static void
npy_uint8 *tmp_buff = NULL;
size_t matrix_size;
size_t pivot_size;
+ size_t safe_m;
/* notes:
* matrix will need to be copied always, as factorization in lapack is
* made inplace
@@ -1138,8 +1139,9 @@ static void
*/
INIT_OUTER_LOOP_3
m = (fortran_int) dimensions[0];
- matrix_size = m*m*sizeof(@typ@);
- pivot_size = m*sizeof(fortran_int);
+ safe_m = m;
+ matrix_size = safe_m * safe_m * sizeof(@typ@);
+ pivot_size = safe_m * sizeof(fortran_int);
tmp_buff = (npy_uint8 *)malloc(matrix_size + pivot_size);
if (tmp_buff)
@@ -1172,6 +1174,7 @@ static void
npy_uint8 *tmp_buff;
size_t matrix_size;
size_t pivot_size;
+ size_t safe_m;
/* notes:
* matrix will need to be copied always, as factorization in lapack is
* made inplace
@@ -1182,8 +1185,9 @@ static void
*/
INIT_OUTER_LOOP_2
m = (fortran_int) dimensions[0];
- matrix_size = m*m*sizeof(@typ@);
- pivot_size = m*sizeof(fortran_int);
+ safe_m = m;
+ matrix_size = safe_m * safe_m * sizeof(@typ@);
+ pivot_size = safe_m * sizeof(fortran_int);
tmp_buff = (npy_uint8 *)malloc(matrix_size + pivot_size);
if (tmp_buff)
@@ -1252,14 +1256,15 @@ init_@lapack_func@(EIGH_PARAMS_t* params, char JOBZ, char UPLO,
fortran_int liwork = -1;
fortran_int info;
npy_uint8 *a, *w, *work, *iwork;
- size_t alloc_size = N*(N+1)*sizeof(@typ@);
+ size_t safe_N = N;
+ size_t alloc_size = safe_N * (safe_N + 1) * sizeof(@typ@);
mem_buff = malloc(alloc_size);
if (!mem_buff)
goto error;
a = mem_buff;
- w = mem_buff + N*N*sizeof(@typ@);
+ w = mem_buff + safe_N * safe_N * sizeof(@typ@);
LAPACK(@lapack_func@)(&JOBZ, &UPLO, &N,
(@ftyp@*)a, &N, (@ftyp@*)w,
&query_work_size, &lwork,
@@ -1344,12 +1349,14 @@ init_@lapack_func@(EIGH_PARAMS_t *params,
fortran_int liwork = -1;
npy_uint8 *a, *w, *work, *rwork, *iwork;
fortran_int info;
+ size_t safe_N = N;
- mem_buff = malloc(N*N*sizeof(@typ@)+N*sizeof(@basetyp@));
+ mem_buff = malloc(safe_N * safe_N * sizeof(@typ@) +
+ safe_N * sizeof(@basetyp@));
if (!mem_buff)
goto error;
a = mem_buff;
- w = mem_buff+N*N*sizeof(@typ@);
+ w = mem_buff + safe_N * safe_N * sizeof(@typ@);
LAPACK(@lapack_func@)(&JOBZ, &UPLO, &N,
(@ftyp@*)a, &N, (@fbasetyp@*)w,
@@ -1581,14 +1588,16 @@ init_@lapack_func@(GESV_PARAMS_t *params, fortran_int N, fortran_int NRHS)
{
npy_uint8 *mem_buff = NULL;
npy_uint8 *a, *b, *ipiv;
- mem_buff = malloc(N*N*sizeof(@ftyp@) +
- N*NRHS*sizeof(@ftyp@) +
- N*sizeof(fortran_int));
+ size_t safe_N = N;
+ size_t safe_NRHS = NRHS;
+ mem_buff = malloc(safe_N * safe_N * sizeof(@ftyp@) +
+ safe_N * safe_NRHS*sizeof(@ftyp@) +
+ safe_N * sizeof(fortran_int));
if (!mem_buff)
goto error;
a = mem_buff;
- b = a + N*N*sizeof(@ftyp@);
- ipiv = b + N*NRHS*sizeof(@ftyp@);
+ b = a + safe_N * safe_N * sizeof(@ftyp@);
+ ipiv = b + safe_N * safe_NRHS * sizeof(@ftyp@);
params->A = a;
params->B = b;
@@ -1759,8 +1768,9 @@ init_@lapack_func@(POTR_PARAMS_t *params, char UPLO, fortran_int N)
{
npy_uint8 *mem_buff = NULL;
npy_uint8 *a;
+ size_t safe_N = N;
- mem_buff = malloc(N*N*sizeof(@ftyp@));
+ mem_buff = malloc(safe_N * safe_N * sizeof(@ftyp@));
if (!mem_buff)
goto error;
@@ -1924,11 +1934,12 @@ init_@lapack_func@(GEEV_PARAMS_t *params, char jobvl, char jobvr, fortran_int n)
npy_uint8 *mem_buff=NULL;
npy_uint8 *mem_buff2=NULL;
npy_uint8 *a, *wr, *wi, *vlr, *vrr, *work, *w, *vl, *vr;
- size_t a_size = n*n*sizeof(@typ@);
- size_t wr_size = n*sizeof(@typ@);
- size_t wi_size = n*sizeof(@typ@);
- size_t vlr_size = jobvl=='V' ? n*n*sizeof(@typ@) : 0;
- size_t vrr_size = jobvr=='V' ? n*n*sizeof(@typ@) : 0;
+ size_t safe_n = n;
+ size_t a_size = safe_n * safe_n * sizeof(@typ@);
+ size_t wr_size = safe_n * sizeof(@typ@);
+ size_t wi_size = safe_n * sizeof(@typ@);
+ size_t vlr_size = jobvl=='V' ? safe_n * safe_n * sizeof(@typ@) : 0;
+ size_t vrr_size = jobvr=='V' ? safe_n * safe_n * sizeof(@typ@) : 0;
size_t w_size = wr_size*2;
size_t vl_size = vlr_size*2;
size_t vr_size = vrr_size*2;
@@ -2120,11 +2131,12 @@ init_@lapack_func@(GEEV_PARAMS_t* params,
npy_uint8 *mem_buff = NULL;
npy_uint8 *mem_buff2 = NULL;
npy_uint8 *a, *w, *vl, *vr, *work, *rwork;
- size_t a_size = n*n*sizeof(@ftyp@);
- size_t w_size = n*sizeof(@ftyp@);
- size_t vl_size = jobvl=='V'? n*n*sizeof(@ftyp@) : 0;
- size_t vr_size = jobvr=='V'? n*n*sizeof(@ftyp@) : 0;
- size_t rwork_size = 2*n*sizeof(@realtyp@);
+ size_t safe_n = n;
+ size_t a_size = safe_n * safe_n * sizeof(@ftyp@);
+ size_t w_size = safe_n * sizeof(@ftyp@);
+ size_t vl_size = jobvl=='V'? safe_n * safe_n * sizeof(@ftyp@) : 0;
+ size_t vr_size = jobvr=='V'? safe_n * safe_n * sizeof(@ftyp@) : 0;
+ size_t rwork_size = 2 * safe_n * sizeof(@realtyp@);
size_t work_count = 0;
@typ@ work_size_query;
fortran_int do_size_query = -1;
@@ -2446,20 +2458,27 @@ init_@lapack_func@(GESDD_PARAMS_t *params,
npy_uint8 *mem_buff = NULL;
npy_uint8 *mem_buff2 = NULL;
npy_uint8 *a, *s, *u, *vt, *work, *iwork;
- size_t a_size = (size_t)m*(size_t)n*sizeof(@ftyp@);
+ size_t safe_m = m;
+ size_t safe_n = n;
+ size_t a_size = safe_m * safe_n * sizeof(@ftyp@);
fortran_int min_m_n = m<n?m:n;
- size_t s_size = ((size_t)min_m_n)*sizeof(@ftyp@);
- fortran_int u_row_count, vt_column_count;
+ size_t safe_min_m_n = min_m_n;
+ size_t s_size = safe_min_m_n * sizeof(@ftyp@);
+ fortran_int u_row_count, vt_column_count;
+ size_t safe_u_row_count, safe_vt_column_count;
size_t u_size, vt_size;
fortran_int work_count;
size_t work_size;
- size_t iwork_size = 8*((size_t)min_m_n)*sizeof(fortran_int);
+ size_t iwork_size = 8 * safe_min_m_n * sizeof(fortran_int);
if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count))
goto error;
- u_size = ((size_t)u_row_count)*m*sizeof(@ftyp@);
- vt_size = n*((size_t)vt_column_count)*sizeof(@ftyp@);
+ safe_u_row_count = u_row_count;
+ safe_vt_column_count = vt_column_count;
+
+ u_size = safe_u_row_count * safe_m * sizeof(@ftyp@);
+ vt_size = safe_n * safe_vt_column_count * sizeof(@ftyp@);
mem_buff = malloc(a_size + s_size + u_size + vt_size + iwork_size);
@@ -2557,21 +2576,28 @@ init_@lapack_func@(GESDD_PARAMS_t *params,
npy_uint8 *mem_buff = NULL, *mem_buff2 = NULL;
npy_uint8 *a,*s, *u, *vt, *work, *rwork, *iwork;
size_t a_size, s_size, u_size, vt_size, work_size, rwork_size, iwork_size;
+ size_t safe_u_row_count, safe_vt_column_count;
fortran_int u_row_count, vt_column_count, work_count;
+ size_t safe_m = m;
+ size_t safe_n = n;
fortran_int min_m_n = m<n?m:n;
+ size_t safe_min_m_n = min_m_n;
if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count))
goto error;
- a_size = ((size_t)m)*((size_t)n)*sizeof(@ftyp@);
- s_size = ((size_t)min_m_n)*sizeof(@frealtyp@);
- u_size = ((size_t)u_row_count)*m*sizeof(@ftyp@);
- vt_size = n*((size_t)vt_column_count)*sizeof(@ftyp@);
+ safe_u_row_count = u_row_count;
+ safe_vt_column_count = vt_column_count;
+
+ a_size = safe_m * safe_n * sizeof(@ftyp@);
+ s_size = safe_min_m_n * sizeof(@frealtyp@);
+ u_size = safe_u_row_count * safe_m * sizeof(@ftyp@);
+ vt_size = safe_n * safe_vt_column_count * sizeof(@ftyp@);
rwork_size = 'N'==jobz?
- 7*((size_t)min_m_n) :
- (5*(size_t)min_m_n*(size_t)min_m_n + 5*(size_t)min_m_n);
+ (7 * safe_min_m_n) :
+ (5*safe_min_m_n * safe_min_m_n + 5*safe_min_m_n);
rwork_size *= sizeof(@ftyp@);
- iwork_size = 8*((size_t)min_m_n)*sizeof(fortran_int);
+ iwork_size = 8 * safe_min_m_n* sizeof(fortran_int);
mem_buff = malloc(a_size +
s_size +
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index 64a9844cf..3c10af24c 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -718,7 +718,7 @@ def median(a, axis=None, out=None, overwrite_input=False):
#..............................................................................
def compress_nd(x, axis=None):
"""Supress slices from multiple dimensions which contain masked values.
-
+
Parameters
----------
x : array_like, MaskedArray
@@ -731,29 +731,27 @@ def compress_nd(x, axis=None):
- If axis is a tuple of ints, those are the axes to supress slices from.
- If axis is an int, then that is the only axis to supress slices from.
- If axis is None, all axis are selected.
-
+
Returns
-------
compress_array : ndarray
- The compressed array.
- """
+ The compressed array.
+ """
x = asarray(x)
m = getmask(x)
# Set axis to tuple of ints
- if isinstance(axis, tuple):
- axis = tuple(ax % x.ndim for ax in axis)
- elif isinstance(axis, int):
- axis = (axis % x.ndim,)
+ if isinstance(axis, int):
+ axis = (axis,)
elif axis is None:
axis = tuple(range(x.ndim))
- else:
- raise ValueError('Invalid type for axis argument')
+ elif not isinstance(axis, tuple):
+ raise ValueError('Invalid type for axis argument')
# Check axis input
- for ax in axis:
- if not (0 <= ax < x.ndim):
- raise ValueError('axis %d is out of range' % ax)
- if not len(axis) == len(set(axis)):
- raise ValueError('axis cannot have dupliate entries')
+ axis = [ax + x.ndim if ax < 0 else ax for ax in axis]
+ if not all(0 <= ax < x.ndim for ax in axis):
+ raise ValueError("'axis' entry is out of bounds")
+ if len(axis) != len(set(axis)):
+ raise ValueError("duplicate value in 'axis'")
# Nothing is masked: return x
if m is nomask or not m.any():
return x._data
@@ -766,7 +764,7 @@ def compress_nd(x, axis=None):
axes = tuple(list(range(ax)) + list(range(ax + 1, x.ndim)))
data = data[(slice(None),)*ax + (~m.any(axis=axes),)]
return data
-
+
def compress_rowcols(x, axis=None):
"""
Suppress the rows and/or columns of a 2-D array that contain