summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Charles Harris <charlesr.harris@gmail.com> 2012-11-12 21:57:40 -0800
committer: Charles Harris <charlesr.harris@gmail.com> 2012-11-12 21:57:40 -0800
commit: 780c1a6bb6d18bc21fed7fa251faffb9ea51e861 (patch)
tree: 4cfef9eaead7922cf19b6b2e0b16c63910925623
parent: cb41689b6150b0d325755b469fba975ce105361f (diff)
parent: e31eb4bb807dc58f112abdf79056806cafac7f9c (diff)
download: numpy-780c1a6bb6d18bc21fed7fa251faffb9ea51e861.tar.gz
Merge pull request #2730 from leschef/faster_dot
ENH: Remove the need for temporary copies in numpy.dot
-rw-r--r--  numpy/core/blasdot/_dotblas.c      17
-rw-r--r--  numpy/core/tests/test_blasdot.py   62
2 files changed, 75 insertions, 4 deletions
diff --git a/numpy/core/blasdot/_dotblas.c b/numpy/core/blasdot/_dotblas.c
index c73dd6a47..29d781a35 100644
--- a/numpy/core/blasdot/_dotblas.c
+++ b/numpy/core/blasdot/_dotblas.c
@@ -769,8 +769,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
* We may be able to handle single-segment arrays here
* using appropriate values of Order, Trans1, and Trans2.
*/
-
- if (!PyArray_ISCONTIGUOUS(ap2)) {
+ if (!PyArray_IS_C_CONTIGUOUS(ap2) && !PyArray_IS_F_CONTIGUOUS(ap2)) {
PyObject *new = PyArray_Copy(ap2);
Py_DECREF(ap2);
@@ -779,7 +778,7 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
goto fail;
}
}
- if (!PyArray_ISCONTIGUOUS(ap1)) {
+ if (!PyArray_IS_C_CONTIGUOUS(ap1) && !PyArray_IS_F_CONTIGUOUS(ap1)) {
PyObject *new = PyArray_Copy(ap1);
Py_DECREF(ap1);
@@ -800,6 +799,18 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
lda = (PyArray_DIM(ap1, 1) > 1 ? PyArray_DIM(ap1, 1) : 1);
ldb = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1);
ldc = (PyArray_DIM(ret, 1) > 1 ? PyArray_DIM(ret, 1) : 1);
+
+ /*
+ * Avoid temporary copies for arrays in Fortran order
+ */
+ if (PyArray_IS_F_CONTIGUOUS(ap1)) {
+ Trans1 = CblasTrans;
+ lda = (PyArray_DIM(ap1, 0) > 1 ? PyArray_DIM(ap1, 0) : 1);
+ }
+ if (PyArray_IS_F_CONTIGUOUS(ap2)) {
+ Trans2 = CblasTrans;
+ ldb = (PyArray_DIM(ap2, 0) > 1 ? PyArray_DIM(ap2, 0) : 1);
+ }
if (typenum == NPY_DOUBLE) {
cblas_dgemm(Order, Trans1, Trans2,
L, N, M,
diff --git a/numpy/core/tests/test_blasdot.py b/numpy/core/tests/test_blasdot.py
index 73c3c4a05..3795def66 100644
--- a/numpy/core/tests/test_blasdot.py
+++ b/numpy/core/tests/test_blasdot.py
@@ -88,4 +88,64 @@ def test_dot_3args_errors():
r = np.empty((1024, 32), dtype=int)
assert_raises(ValueError, np.dot, f, v, r)
-
+def test_dot_array_order():
+ """ Test numpy dot with different order C, F
+
+ Comparing results with multiarray dot.
+ Double and single precisions array are compared using relative
+ precision of 7 and 5 decimals respectively.
+ Use 30 decimal when comparing exact operations like:
+ (a.b)' = b'.a'
+ """
+ _dot = np.core.multiarray.dot
+ a_dim, b_dim, c_dim = 10, 4, 7
+ orders = ["C", "F"]
+ dtypes_prec = {np.float64: 7, np.float32: 5}
+ np.random.seed(7)
+
+ for arr_type, prec in dtypes_prec.iteritems():
+ for a_order in orders:
+ a = np.asarray(np.random.randn(a_dim, a_dim),
+ dtype=arr_type, order=a_order)
+ assert_array_equal(np.dot(a, a), a.dot(a))
+ # (a.a)' = a'.a', note that mse~=1e-31 needs almost_equal
+ assert_almost_equal(a.dot(a), a.T.dot(a.T).T, decimal=30)
+
+ #
+ # Check with making explicit copy
+ #
+ a_T = a.T.copy(order=a_order)
+ assert_array_equal(a_T.dot(a_T), a.T.dot(a.T))
+ assert_array_equal(a.dot(a_T), a.dot(a.T))
+ assert_array_equal(a_T.dot(a), a.T.dot(a))
+
+ #
+ # Compare with multiarray dot
+ #
+ assert_almost_equal(a.dot(a), _dot(a, a), decimal=prec)
+ assert_almost_equal(a.T.dot(a), _dot(a.T, a), decimal=prec)
+ assert_almost_equal(a.dot(a.T), _dot(a, a.T), decimal=prec)
+ assert_almost_equal(a.T.dot(a.T), _dot(a.T, a.T), decimal=prec)
+ for res in a.dot(a), a.T.dot(a), a.dot(a.T), a.T.dot(a.T):
+ assert res.flags.c_contiguous
+
+ for b_order in orders:
+ b = np.asarray(np.random.randn(a_dim, b_dim),
+ dtype=arr_type, order=b_order)
+ b_T = b.T.copy(order=b_order)
+ assert_array_equal(a_T.dot(b), a.T.dot(b))
+ assert_array_equal(b_T.dot(a), b.T.dot(a))
+ # (b'.a)' = a'.b
+ assert_almost_equal(b.T.dot(a), a.T.dot(b).T, decimal=30)
+ assert_almost_equal(a.dot(b), _dot(a, b), decimal=prec)
+ assert_almost_equal(b.T.dot(a), _dot(b.T, a), decimal=prec)
+
+
+ for c_order in orders:
+ c = np.asarray(np.random.randn(b_dim, c_dim),
+ dtype=arr_type, order=c_order)
+ c_T = c.T.copy(order=c_order)
+ assert_array_equal(c.T.dot(b.T), c_T.dot(b_T))
+ assert_almost_equal(c.T.dot(b.T).T, b.dot(c), decimal=30)
+ assert_almost_equal(b.dot(c), _dot(b, c), decimal=prec)
+ assert_almost_equal(c.T.dot(b.T), _dot(c.T, b.T), decimal=prec)