diff options
author | sslivkoff <stormslivkoff@gmail.com> | 2020-04-29 14:08:27 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-04-30 00:08:27 +0300 |
commit | 4753652c605833e2398d7e46b754a71729a54713 (patch) | |
tree | 2d3e24c3d06a31053893cbae0b64b9401596617f /numpy | |
parent | 44e26f24621ce86ed6b24bfba4a04e82319d7bf0 (diff) | |
download | numpy-4753652c605833e2398d7e46b754a71729a54713.tar.gz |
ENH: update numpy.linalg.multi_dot to accept an `out` argument (#15715)
* ENH: update numpy.linalg.multi_dot to accept an `out` argument
* TST ensure value returned by numpy.linalg.multi_dot matches out
* DOC add note about initial call to numpy.linalg.multi_dot
* DOC add release note for #15715
Co-authored-by: Matti Picus <matti.picus@gmail.com>
Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
Co-authored-by: Eric Wieser <wieser.eric@gmail.com>
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/linalg/linalg.py | 36 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 35 |
2 files changed, 60 insertions, 11 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 6d3afdd49..5ee326f3c 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -2613,12 +2613,13 @@ def norm(x, ord=None, axis=None, keepdims=False): # multi_dot -def _multidot_dispatcher(arrays): - return arrays +def _multidot_dispatcher(arrays, *, out=None): + yield from arrays + yield out @array_function_dispatch(_multidot_dispatcher) -def multi_dot(arrays): +def multi_dot(arrays, *, out=None): """ Compute the dot product of two or more arrays in a single function call, while automatically selecting the fastest evaluation order. @@ -2642,6 +2643,15 @@ def multi_dot(arrays): If the first argument is 1-D it is treated as row vector. If the last argument is 1-D it is treated as column vector. The other arguments must be 2-D. + out : ndarray, optional + Output argument. This must have the exact kind that would be returned + if it was not used. In particular, it must have the right type, must be + C-contiguous, and its dtype must be the dtype that would be returned + for `dot(a, b)`. This is a performance feature. Therefore, if these + conditions are not met, an exception is raised, instead of attempting + to be flexible. + + .. versionadded:: 1.19.0 Returns ------- @@ -2699,7 +2709,7 @@ def multi_dot(arrays): if n < 2: raise ValueError("Expecting at least two arrays.") elif n == 2: - return dot(arrays[0], arrays[1]) + return dot(arrays[0], arrays[1], out=out) arrays = [asanyarray(a) for a in arrays] @@ -2715,10 +2725,10 @@ def multi_dot(arrays): # _multi_dot_three is much faster than _multi_dot_matrix_chain_order if n == 3: - result = _multi_dot_three(arrays[0], arrays[1], arrays[2]) + result = _multi_dot_three(arrays[0], arrays[1], arrays[2], out=out) else: order = _multi_dot_matrix_chain_order(arrays) - result = _multi_dot(arrays, order, 0, n - 1) + result = _multi_dot(arrays, order, 0, n - 1, out=out) # return proper shape if ndim_first == 1 and ndim_last == 1: @@ -2729,7 +2739,7 @@ def multi_dot(arrays): return result -def _multi_dot_three(A, B, C): +def _multi_dot_three(A, B, C, out=None): """ Find the best order for three arrays and do the multiplication. @@ -2745,9 +2755,9 @@ def _multi_dot_three(A, B, C): cost2 = a1b0 * c1 * (a0 + b1c0) if cost1 < cost2: - return dot(dot(A, B), C) + return dot(dot(A, B), C, out=out) else: - return dot(A, dot(B, C)) + return dot(A, dot(B, C), out=out) def _multi_dot_matrix_chain_order(arrays, return_costs=False): @@ -2791,10 +2801,14 @@ def _multi_dot_matrix_chain_order(arrays, return_costs=False): return (s, m) if return_costs else s -def _multi_dot(arrays, order, i, j): +def _multi_dot(arrays, order, i, j, out=None): """Actually do the multiplication with the given order.""" if i == j: + # the initial call with non-None out should never get here + assert out is None + return arrays[i] else: return dot(_multi_dot(arrays, order, i, order[i, j]), - _multi_dot(arrays, order, order[i, j] + 1, j)) + _multi_dot(arrays, order, order[i, j] + 1, j), + out=out) diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index dae4ef61e..3f3bf9f70 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -1930,6 +1930,41 @@ class TestMultiDot: # the result should be a scalar assert_equal(multi_dot([A1d, B, C, D1d]).shape, ()) + def test_three_arguments_and_out(self): + # multi_dot with three arguments uses a fast hand coded algorithm to + # determine the optimal order. Therefore test it separately. + A = np.random.random((6, 2)) + B = np.random.random((2, 6)) + C = np.random.random((6, 2)) + + out = np.zeros((6, 2)) + ret = multi_dot([A, B, C], out=out) + assert out is ret + assert_almost_equal(out, A.dot(B).dot(C)) + assert_almost_equal(out, np.dot(A, np.dot(B, C))) + + def test_two_arguments_and_out(self): + # separate code path with two arguments + A = np.random.random((6, 2)) + B = np.random.random((2, 6)) + out = np.zeros((6, 6)) + ret = multi_dot([A, B], out=out) + assert out is ret + assert_almost_equal(out, A.dot(B)) + assert_almost_equal(out, np.dot(A, B)) + + def test_dynamic_programing_optimization_and_out(self): + # multi_dot with four or more arguments uses the dynamic programing + # optimization and therefore deserve a separate test + A = np.random.random((6, 2)) + B = np.random.random((2, 6)) + C = np.random.random((6, 2)) + D = np.random.random((2, 1)) + out = np.zeros((6, 1)) + ret = multi_dot([A, B, C, D], out=out) + assert out is ret + assert_almost_equal(out, A.dot(B).dot(C).dot(D)) + def test_dynamic_programming_logic(self): # Test for the dynamic programming part # This test is directly taken from Cormen page 376. |