diff options
author | Stefan Otte <stefan.otte@gmail.com> | 2014-08-20 11:40:02 +0200 |
---|---|---|
committer | Stefan Otte <stefan.otte@gmail.com> | 2014-11-10 11:56:43 +0100 |
commit | 1b12c394548e2d23bba83c0eccda958a28998293 (patch) | |
tree | 6cb0d45d3af6437b66a9b630da217f2286a98558 /numpy | |
parent | 96714918d64ebf64e0e133a385da061408e4a03b (diff) | |
download | numpy-1b12c394548e2d23bba83c0eccda958a28998293.tar.gz |
ENH: add `multi_dot`: dot with multiple arguments.
`np.linalg.multi_dot` computes the dot product of two or more arrays in
a single function call, while automatically selecting the fastest evaluation
order.
The algorithm for selecting the fastest evaluation order uses dynamic
programming and closely follows:
Cormen, "Introduction to Algorithms", Chapter 15.2, p. 370-378
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/linalg/linalg.py | 187 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 90 |
2 files changed, 273 insertions, 4 deletions
# Code added to ``numpy/linalg/linalg.py`` by this commit.  The commit also
# appends 'multi_dot' to ``__all__`` and adds ``atleast_2d``, ``intp`` and
# ``asanyarray`` to the ``numpy.core`` imports at the top of the file.


def multi_dot(arrays):
    """
    Compute the dot product of two or more arrays in a single function call,
    while automatically selecting the fastest evaluation order.

    `multi_dot` chains `numpy.dot` and uses an optimal parenthesization
    of the matrices [1]_ [2]_.  Depending on the shapes of the matrices,
    this can speed up the multiplication a lot.

    The first and last argument can be 1-D and are treated respectively as
    row and column vector.  The other arguments must be 2-D.

    Think of `multi_dot` as::

        def multi_dot(arrays): return functools.reduce(np.dot, arrays)

    Parameters
    ----------
    arrays : sequence of array_like
        First and last argument can be 1-D and are treated respectively as
        row and column vector, the other arguments must be 2-D.

    Returns
    -------
    output : ndarray
        Returns the dot product of the supplied arrays.

    Raises
    ------
    ValueError
        If fewer than two arrays are supplied.

    See Also
    --------
    dot : dot multiplication with two arguments.

    References
    ----------
    .. [1] Cormen, "Introduction to Algorithms", Chapter 15.2, p. 370-378
    .. [2] http://en.wikipedia.org/wiki/Matrix_chain_multiplication

    Examples
    --------
    `multi_dot` allows you to write::

        >>> import numpy as np
        >>> # Prepare some data
        >>> A = np.random.random((10000, 100))
        >>> B = np.random.random((100, 1000))
        >>> C = np.random.random((1000, 5))
        >>> D = np.random.random((5, 333))
        >>> # the actual dot multiplication
        >>> multi_dot([A, B, C, D])

    instead of::

        >>> np.dot(np.dot(np.dot(A, B), C), D)
        >>> # or
        >>> A.dot(B).dot(C).dot(D)

    Notes
    -----
    The cost for a matrix multiplication can be calculated with the
    following function::

        def cost(A, B):
            return A.shape[0] * A.shape[1] * B.shape[1]

    Let's assume we have three matrices
    :math:`A_{10x100}, B_{100x5}, C_{5x50}`.

    The costs for the two different parenthesizations are as follows::

        cost((AB)C) = 10*100*5 + 10*5*50   = 5000 + 2500   = 7500
        cost(A(BC)) = 10*100*50 + 100*5*50 = 50000 + 25000 = 75000

    """
    n = len(arrays)
    # optimization only makes sense for len(arrays) > 2
    if n < 2:
        raise ValueError("Expecting at least two arrays.")
    elif n == 2:
        # Nothing to optimize with only one product.
        return dot(arrays[0], arrays[1])

    arrays = [asanyarray(a) for a in arrays]

    # save original ndim to reshape the result array into the proper form
    # later
    ndim_first, ndim_last = arrays[0].ndim, arrays[-1].ndim
    # Explicitly convert vectors to 2D arrays to keep the logic of the
    # internal _multi_dot_* functions as simple as possible.
    if arrays[0].ndim == 1:
        arrays[0] = atleast_2d(arrays[0])      # row vector
    if arrays[-1].ndim == 1:
        arrays[-1] = atleast_2d(arrays[-1]).T  # column vector
    _assertRank2(*arrays)

    # _multi_dot_three is much faster than _multi_dot_matrix_chain_order
    if n == 3:
        result = _multi_dot_three(arrays[0], arrays[1], arrays[2])
    else:
        order = _multi_dot_matrix_chain_order(arrays)
        result = _multi_dot(arrays, order, 0, n - 1)

    # undo the 2-D promotion above so the caller gets the natural shape
    if ndim_first == 1 and ndim_last == 1:
        return result[0, 0]  # scalar
    elif ndim_first == 1 or ndim_last == 1:
        return result.ravel()  # 1-D
    else:
        return result


def _multi_dot_three(A, B, C):
    """
    Find the best ordering for three arrays and do the multiplication.

    Doing it manually instead of using dynamic programming is
    approximately 15 times faster.

    """
    # cost1 = cost((AB)C)
    cost1 = (A.shape[0] * A.shape[1] * B.shape[1] +  # (AB)
             A.shape[0] * B.shape[1] * C.shape[1])   # (--)C
    # cost2 = cost(A(BC))  -- the original comment wrongly repeated cost1
    cost2 = (B.shape[0] * B.shape[1] * C.shape[1] +  # (BC)
             A.shape[0] * A.shape[1] * C.shape[1])   # A(--)

    if cost1 < cost2:
        return dot(dot(A, B), C)
    else:
        return dot(A, dot(B, C))


def _multi_dot_matrix_chain_order(arrays, return_costs=False):
    """
    Return a np.array that encodes the optimal order of multiplications.

    The optimal order array is then used by `_multi_dot()` to do the
    multiplication.

    Also return the cost matrix if `return_costs` is `True`.

    The implementation CLOSELY follows Cormen, "Introduction to Algorithms",
    Chapter 15.2, p. 370-378.  Note that Cormen uses 1-based indices.

        cost[i, j] = min([
            cost[prefix] + cost[suffix] + cost_mult(prefix, suffix)
            for k in range(i, j)])

    """
    n = len(arrays)
    # p stores the dimensions of the matrices
    # Example for p: A_{10x100}, B_{100x5}, C_{5x50} --> p = [10, 100, 5, 50]
    p = [a.shape[0] for a in arrays] + [arrays[-1].shape[1]]
    # m is a matrix of costs of the subproblems
    # m[i, j]: min number of scalar multiplications needed to compute A_{i..j}
    m = zeros((n, n), dtype=double)
    # s is the actual ordering
    # s[i, j] is the value of k at which we split the product A_i..A_j
    s = empty((n, n), dtype=intp)

    for l in range(1, n):
        for i in range(n - l):
            j = i + l
            m[i, j] = float("inf")
            for k in range(i, j):
                q = m[i, k] + m[k+1, j] + p[i]*p[k+1]*p[j+1]
                if q < m[i, j]:
                    m[i, j] = q
                    s[i, j] = k  # Note that Cormen uses 1-based index

    return (s, m) if return_costs else s


def _multi_dot(arrays, order, i, j):
    """Actually do the multiplication with the given order."""
    if i == j:
        # A single matrix: nothing left to multiply.
        return arrays[i]
    else:
        return dot(_multi_dot(arrays, order, i, order[i, j]),
                   _multi_dot(arrays, order, order[i, j] + 1, j))
class TestMultiDot(object):
    # Tests for numpy.linalg.multi_dot and its dynamic-programming helper.

    def test_basic_function_with_three_arguments(self):
        # multi_dot with three arguments uses a fast hand coded algorithm to
        # determine the optimal order.  Therefore test it separately.
        A = np.random.random((6, 2))
        B = np.random.random((2, 6))
        C = np.random.random((6, 2))

        assert_almost_equal(multi_dot([A, B, C]), A.dot(B).dot(C))
        assert_almost_equal(multi_dot([A, B, C]), np.dot(A, np.dot(B, C)))

    def test_basic_function_with_dynamic_programing_optimization(self):
        # multi_dot with four or more arguments uses the dynamic programming
        # optimization and therefore deserves a separate test.
        A = np.random.random((6, 2))
        B = np.random.random((2, 6))
        C = np.random.random((6, 2))
        D = np.random.random((2, 1))
        assert_almost_equal(multi_dot([A, B, C, D]), A.dot(B).dot(C).dot(D))

    def test_vector_as_first_argument(self):
        # The first argument can be 1-D
        A1d = np.random.random(2)  # 1-D
        B = np.random.random((2, 6))
        C = np.random.random((6, 2))
        D = np.random.random((2, 2))

        # the result should be 1-D
        assert_equal(multi_dot([A1d, B, C, D]).shape, (2,))

    def test_vector_as_last_argument(self):
        # The last argument can be 1-D
        A = np.random.random((6, 2))
        B = np.random.random((2, 6))
        C = np.random.random((6, 2))
        D1d = np.random.random(2)  # 1-D

        # the result should be 1-D
        assert_equal(multi_dot([A, B, C, D1d]).shape, (6,))

    def test_vector_as_first_and_last_argument(self):
        # The first and last arguments can be 1-D
        A1d = np.random.random(2)  # 1-D
        B = np.random.random((2, 6))
        C = np.random.random((6, 2))
        D1d = np.random.random(2)  # 1-D

        # the result should be a scalar
        assert_equal(multi_dot([A1d, B, C, D1d]).shape, ())

    def test_dynamic_programming_logic(self):
        # Test for the dynamic programming part
        # This test is directly taken from Cormen page 376.
        arrays = [np.random.random((30, 35)),
                  np.random.random((35, 15)),
                  np.random.random((15, 5)),
                  np.random.random((5, 10)),
                  np.random.random((10, 20)),
                  np.random.random((20, 25))]
        m_expected = np.array([[0., 15750., 7875., 9375., 11875., 15125.],
                               [0.,     0., 2625., 4375.,  7125., 10500.],
                               [0.,     0.,    0.,  750.,  2500.,  5375.],
                               [0.,     0.,    0.,    0.,  1000.,  3500.],
                               [0.,     0.,    0.,    0.,     0.,  5000.],
                               [0.,     0.,    0.,    0.,     0.,     0.]])
        # np.int was a deprecated alias for the builtin int (removed in
        # numpy 1.24); use int directly.
        s_expected = np.array([[0, 1, 1, 3, 3, 3],
                               [0, 0, 2, 3, 3, 3],
                               [0, 0, 0, 3, 3, 3],
                               [0, 0, 0, 0, 4, 5],
                               [0, 0, 0, 0, 0, 5],
                               [0, 0, 0, 0, 0, 0]], dtype=int)
        s_expected -= 1  # Cormen uses 1-based index, python does not.

        s, m = _multi_dot_matrix_chain_order(arrays, return_costs=True)

        # Only the upper triangular part (without the diagonal) is
        # interesting.
        assert_almost_equal(np.triu(s[:-1, 1:]),
                            np.triu(s_expected[:-1, 1:]))
        assert_almost_equal(np.triu(m), np.triu(m_expected))

    def test_too_few_input_arrays(self):
        # multi_dot requires at least two arrays.
        assert_raises(ValueError, multi_dot, [])
        assert_raises(ValueError, multi_dot, [np.random.random((3, 3))])


if __name__ == "__main__":
    run_module_suite()