summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/records.py6
-rw-r--r--numpy/core/src/multiarray/cblasfuncs.c97
-rw-r--r--numpy/core/src/multiarray/iterators.c12
-rw-r--r--numpy/core/tests/test_indexing.py5
-rw-r--r--numpy/core/tests/test_numeric.py11
-rw-r--r--numpy/core/tests/test_records.py9
-rw-r--r--numpy/lib/stride_tricks.py3
7 files changed, 124 insertions, 19 deletions
diff --git a/numpy/core/records.py b/numpy/core/records.py
index ca6070cf7..9f5dcc811 100644
--- a/numpy/core/records.py
+++ b/numpy/core/records.py
@@ -425,7 +425,7 @@ class recarray(ndarray):
def __array_finalize__(self, obj):
if self.dtype.type is not record:
- # if self.dtype is not np.record, invoke __setattr__ which will
+ # if self.dtype is not np.record, invoke __setattr__ which will
# convert it to a record if it is a void dtype.
self.dtype = self.dtype
@@ -496,13 +496,13 @@ class recarray(ndarray):
return self.setfield(val, *res)
def __getitem__(self, indx):
- obj = ndarray.__getitem__(self, indx)
+ obj = super(recarray, self).__getitem__(indx)
# copy behavior of getattr, except that here
# we might also be returning a single element
if isinstance(obj, ndarray):
if obj.dtype.fields:
- obj = obj.view(recarray)
+ obj = obj.view(type(self))
if issubclass(obj.dtype.type, nt.void):
return obj.view(dtype=(self.dtype.type, obj.dtype))
return obj
diff --git a/numpy/core/src/multiarray/cblasfuncs.c b/numpy/core/src/multiarray/cblasfuncs.c
index 67f325ba1..516c6e8ae 100644
--- a/numpy/core/src/multiarray/cblasfuncs.c
+++ b/numpy/core/src/multiarray/cblasfuncs.c
@@ -111,6 +111,74 @@ gemv(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans,
}
+/*
+ * Helper: dispatch to appropriate cblas_?syrk for typenum.
+ */
+static void
+syrk(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans,
+ int n, int k,
+ PyArrayObject *A, int lda, PyArrayObject *R)
+{
+ const void *Adata = PyArray_DATA(A);
+ void *Rdata = PyArray_DATA(R);
+ int ldc = PyArray_DIM(R, 1) > 1 ? PyArray_DIM(R, 1) : 1;
+
+ npy_intp i;
+ npy_intp j;
+
+ switch (typenum) {
+ case NPY_DOUBLE:
+ cblas_dsyrk(order, CblasUpper, trans, n, k, 1.,
+ Adata, lda, 0., Rdata, ldc);
+
+ for (i = 0; i < n; i++)
+ {
+ for (j = i + 1; j < n; j++)
+ {
+ *((npy_double*)PyArray_GETPTR2(R, j, i)) = *((npy_double*)PyArray_GETPTR2(R, i, j));
+ }
+ }
+ break;
+ case NPY_FLOAT:
+ cblas_ssyrk(order, CblasUpper, trans, n, k, 1.f,
+ Adata, lda, 0.f, Rdata, ldc);
+
+ for (i = 0; i < n; i++)
+ {
+ for (j = i + 1; j < n; j++)
+ {
+ *((npy_float*)PyArray_GETPTR2(R, j, i)) = *((npy_float*)PyArray_GETPTR2(R, i, j));
+ }
+ }
+ break;
+ case NPY_CDOUBLE:
+ cblas_zsyrk(order, CblasUpper, trans, n, k, oneD,
+ Adata, lda, zeroD, Rdata, ldc);
+
+ for (i = 0; i < n; i++)
+ {
+ for (j = i + 1; j < n; j++)
+ {
+ *((npy_cdouble*)PyArray_GETPTR2(R, j, i)) = *((npy_cdouble*)PyArray_GETPTR2(R, i, j));
+ }
+ }
+ break;
+ case NPY_CFLOAT:
+ cblas_csyrk(order, CblasUpper, trans, n, k, oneF,
+ Adata, lda, zeroF, Rdata, ldc);
+
+ for (i = 0; i < n; i++)
+ {
+ for (j = i + 1; j < n; j++)
+ {
+ *((npy_cfloat*)PyArray_GETPTR2(R, j, i)) = *((npy_cfloat*)PyArray_GETPTR2(R, i, j));
+ }
+ }
+ break;
+ }
+}
+
+
typedef enum {_scalar, _column, _row, _matrix} MatrixShape;
@@ -647,7 +715,34 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
Trans2 = CblasTrans;
ldb = (PyArray_DIM(ap2, 0) > 1 ? PyArray_DIM(ap2, 0) : 1);
}
- gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret);
+
+ /*
+ * Use syrk if we have a case of a matrix times its transpose.
+ * Otherwise, use gemm for all other cases.
+ */
+ if (
+ (PyArray_BYTES(ap1) == PyArray_BYTES(ap2)) &&
+ (PyArray_DIM(ap1, 0) == PyArray_DIM(ap2, 1)) &&
+ (PyArray_DIM(ap1, 1) == PyArray_DIM(ap2, 0)) &&
+ (PyArray_STRIDE(ap1, 0) == PyArray_STRIDE(ap2, 1)) &&
+ (PyArray_STRIDE(ap1, 1) == PyArray_STRIDE(ap2, 0)) &&
+ ((Trans1 == CblasTrans) ^ (Trans2 == CblasTrans)) &&
+ ((Trans1 == CblasNoTrans) ^ (Trans2 == CblasNoTrans))
+ )
+ {
+ if (Trans1 == CblasNoTrans)
+ {
+ syrk(typenum, Order, Trans1, N, M, ap1, lda, ret);
+ }
+ else
+ {
+ syrk(typenum, Order, Trans1, N, M, ap2, ldb, ret);
+ }
+ }
+ else
+ {
+ gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb, ret);
+ }
NPY_END_ALLOW_THREADS;
}
diff --git a/numpy/core/src/multiarray/iterators.c b/numpy/core/src/multiarray/iterators.c
index 702f9e21a..5099e3e19 100644
--- a/numpy/core/src/multiarray/iterators.c
+++ b/numpy/core/src/multiarray/iterators.c
@@ -1456,9 +1456,9 @@ PyArray_MultiIterFromObjects(PyObject **mps, int n, int nadd, ...)
int i, ntot, err=0;
ntot = n + nadd;
- if (ntot < 2 || ntot > NPY_MAXARGS) {
+ if (ntot < 1 || ntot > NPY_MAXARGS) {
PyErr_Format(PyExc_ValueError,
- "Need at least 2 and at most %d "
+ "Need at least 1 and at most %d "
"array objects.", NPY_MAXARGS);
return NULL;
}
@@ -1522,9 +1522,9 @@ PyArray_MultiIterNew(int n, ...)
int i, err = 0;
- if (n < 2 || n > NPY_MAXARGS) {
+ if (n < 1 || n > NPY_MAXARGS) {
PyErr_Format(PyExc_ValueError,
- "Need at least 2 and at most %d "
+ "Need at least 1 and at most %d "
"array objects.", NPY_MAXARGS);
return NULL;
}
@@ -1603,12 +1603,12 @@ arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *k
++n;
}
}
- if (n < 2 || n > NPY_MAXARGS) {
+ if (n < 1 || n > NPY_MAXARGS) {
if (PyErr_Occurred()) {
return NULL;
}
PyErr_Format(PyExc_ValueError,
- "Need at least 2 and at most %d "
+ "Need at least 1 and at most %d "
"array objects.", NPY_MAXARGS);
return NULL;
}
diff --git a/numpy/core/tests/test_indexing.py b/numpy/core/tests/test_indexing.py
index 38280d05e..deb2130b7 100644
--- a/numpy/core/tests/test_indexing.py
+++ b/numpy/core/tests/test_indexing.py
@@ -895,10 +895,7 @@ class TestMultiIndexingAutomated(TestCase):
+ arr.shape[ax + len(indx[1:]):]))
# Check if broadcasting works
- if len(indx[1:]) != 1:
- res = np.broadcast(*indx[1:]) # raises ValueError...
- else:
- res = indx[1]
+ res = np.broadcast(*indx[1:])
# unfortunately the indices might be out of bounds. So check
# that first, and use mode='wrap' then. However only if
# there are any indices...
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index b7e146b5a..d63118080 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -2207,11 +2207,20 @@ class TestBroadcast(TestCase):
for a, ia in zip(arrs, mit.iters):
assert_(a is ia.base)
+ def test_broadcast_single_arg(self):
+ # gh-6899
+ arrs = [np.empty((5, 6, 7))]
+ mit = np.broadcast(*arrs)
+ assert_equal(mit.shape, (5, 6, 7))
+ assert_equal(mit.nd, 3)
+ assert_equal(mit.numiter, 1)
+ assert_(arrs[0] is mit.iters[0].base)
+
def test_number_of_arguments(self):
arr = np.empty((5,))
for j in range(35):
arrs = [arr] * j
- if j < 2 or j > 32:
+ if j < 1 or j > 32:
assert_raises(ValueError, np.broadcast, *arrs)
else:
mit = np.broadcast(*arrs)
diff --git a/numpy/core/tests/test_records.py b/numpy/core/tests/test_records.py
index e0f0a3a8f..9fbdf51d6 100644
--- a/numpy/core/tests/test_records.py
+++ b/numpy/core/tests/test_records.py
@@ -122,13 +122,20 @@ class TestFromrecords(TestCase):
assert_equal(rv.dtype.type, np.record)
#check that getitem also preserves np.recarray and np.record
- r = np.rec.array(np.ones(4, dtype=[('a', 'i4'), ('b', 'i4'),
+ r = np.rec.array(np.ones(4, dtype=[('a', 'i4'), ('b', 'i4'),
('c', 'i4,i4')]))
assert_equal(r['c'].dtype.type, np.record)
assert_equal(type(r['c']), np.recarray)
assert_equal(r[['a', 'b']].dtype.type, np.record)
assert_equal(type(r[['a', 'b']]), np.recarray)
+ #and that it preserves subclasses (gh-6949)
+ class C(np.recarray):
+ pass
+
+ c = r.view(C)
+ assert_equal(type(c['c']), C)
+
# check that accessing nested structures keep record type, but
# not for subarrays, non-void structures, non-structured voids
test_dtype = [('a', 'f4,f4'), ('b', 'V8'), ('c', ('f4',2)),
diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py
index f4b43a5a9..4c23ab355 100644
--- a/numpy/lib/stride_tricks.py
+++ b/numpy/lib/stride_tricks.py
@@ -121,9 +121,6 @@ def _broadcast_shape(*args):
"""
if not args:
raise ValueError('must provide at least one argument')
- if len(args) == 1:
- # a single argument does not work with np.broadcast
- return np.asarray(args[0]).shape
# use the old-iterator because np.nditer does not handle size 0 arrays
# consistently
b = np.broadcast(*args[:32])