summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsslivkoff <stormslivkoff@gmail.com>2020-04-29 14:08:27 -0700
committerGitHub <noreply@github.com>2020-04-30 00:08:27 +0300
commit4753652c605833e2398d7e46b754a71729a54713 (patch)
tree2d3e24c3d06a31053893cbae0b64b9401596617f
parent44e26f24621ce86ed6b24bfba4a04e82319d7bf0 (diff)
downloadnumpy-4753652c605833e2398d7e46b754a71729a54713.tar.gz
ENH: update numpy.linalg.multi_dot to accept an `out` argument (#15715)
* ENH: update numpy.linalg.multi_dot to accept an `out` argument * TST ensure value returned by numpy.linalg.multi_dot matches out * DOC add note about initial call to numpy.linalg.multi_dot * DOC add release note for #15715 Co-authored-by: Matti Picus <matti.picus@gmail.com> Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net> Co-authored-by: Eric Wieser <wieser.eric@gmail.com>
-rw-r--r--doc/release/upcoming_changes/15715.new_feature.rst5
-rw-r--r--numpy/linalg/linalg.py36
-rw-r--r--numpy/linalg/tests/test_linalg.py35
3 files changed, 65 insertions, 11 deletions
diff --git a/doc/release/upcoming_changes/15715.new_feature.rst b/doc/release/upcoming_changes/15715.new_feature.rst
new file mode 100644
index 000000000..a568aee64
--- /dev/null
+++ b/doc/release/upcoming_changes/15715.new_feature.rst
@@ -0,0 +1,5 @@
+`numpy.linalg.multi_dot` now accepts an ``out`` argument
+--------------------------------------------------------
+
+``out`` can be used to avoid creating unnecessary copies of the final product
+computed by `numpy.linalg.multi_dot`.
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 6d3afdd49..5ee326f3c 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -2613,12 +2613,13 @@ def norm(x, ord=None, axis=None, keepdims=False):
# multi_dot
-def _multidot_dispatcher(arrays):
- return arrays
+def _multidot_dispatcher(arrays, *, out=None):
+ yield from arrays
+ yield out
@array_function_dispatch(_multidot_dispatcher)
-def multi_dot(arrays):
+def multi_dot(arrays, *, out=None):
"""
Compute the dot product of two or more arrays in a single function call,
while automatically selecting the fastest evaluation order.
@@ -2642,6 +2643,15 @@ def multi_dot(arrays):
If the first argument is 1-D it is treated as row vector.
If the last argument is 1-D it is treated as column vector.
The other arguments must be 2-D.
+ out : ndarray, optional
+ Output argument. This must have the exact kind that would be returned
+ if it was not used. In particular, it must have the right type, must be
+ C-contiguous, and its dtype must be the dtype that would be returned
+ for `dot(a, b)`. This is a performance feature. Therefore, if these
+ conditions are not met, an exception is raised, instead of attempting
+ to be flexible.
+
+ .. versionadded:: 1.19.0
Returns
-------
@@ -2699,7 +2709,7 @@ def multi_dot(arrays):
if n < 2:
raise ValueError("Expecting at least two arrays.")
elif n == 2:
- return dot(arrays[0], arrays[1])
+ return dot(arrays[0], arrays[1], out=out)
arrays = [asanyarray(a) for a in arrays]
@@ -2715,10 +2725,10 @@ def multi_dot(arrays):
# _multi_dot_three is much faster than _multi_dot_matrix_chain_order
if n == 3:
- result = _multi_dot_three(arrays[0], arrays[1], arrays[2])
+ result = _multi_dot_three(arrays[0], arrays[1], arrays[2], out=out)
else:
order = _multi_dot_matrix_chain_order(arrays)
- result = _multi_dot(arrays, order, 0, n - 1)
+ result = _multi_dot(arrays, order, 0, n - 1, out=out)
# return proper shape
if ndim_first == 1 and ndim_last == 1:
@@ -2729,7 +2739,7 @@ def multi_dot(arrays):
return result
-def _multi_dot_three(A, B, C):
+def _multi_dot_three(A, B, C, out=None):
"""
Find the best order for three arrays and do the multiplication.
@@ -2745,9 +2755,9 @@ def _multi_dot_three(A, B, C):
cost2 = a1b0 * c1 * (a0 + b1c0)
if cost1 < cost2:
- return dot(dot(A, B), C)
+ return dot(dot(A, B), C, out=out)
else:
- return dot(A, dot(B, C))
+ return dot(A, dot(B, C), out=out)
def _multi_dot_matrix_chain_order(arrays, return_costs=False):
@@ -2791,10 +2801,14 @@ def _multi_dot_matrix_chain_order(arrays, return_costs=False):
return (s, m) if return_costs else s
-def _multi_dot(arrays, order, i, j):
+def _multi_dot(arrays, order, i, j, out=None):
"""Actually do the multiplication with the given order."""
if i == j:
+ # the initial call with non-None out should never get here
+ assert out is None
+
return arrays[i]
else:
return dot(_multi_dot(arrays, order, i, order[i, j]),
- _multi_dot(arrays, order, order[i, j] + 1, j))
+ _multi_dot(arrays, order, order[i, j] + 1, j),
+ out=out)
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index dae4ef61e..3f3bf9f70 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -1930,6 +1930,41 @@ class TestMultiDot:
# the result should be a scalar
assert_equal(multi_dot([A1d, B, C, D1d]).shape, ())
+ def test_three_arguments_and_out(self):
+ # multi_dot with three arguments uses a fast hand coded algorithm to
+ # determine the optimal order. Therefore test it separately.
+ A = np.random.random((6, 2))
+ B = np.random.random((2, 6))
+ C = np.random.random((6, 2))
+
+ out = np.zeros((6, 2))
+ ret = multi_dot([A, B, C], out=out)
+ assert out is ret
+ assert_almost_equal(out, A.dot(B).dot(C))
+ assert_almost_equal(out, np.dot(A, np.dot(B, C)))
+
+ def test_two_arguments_and_out(self):
+ # separate code path with two arguments
+ A = np.random.random((6, 2))
+ B = np.random.random((2, 6))
+ out = np.zeros((6, 6))
+ ret = multi_dot([A, B], out=out)
+ assert out is ret
+ assert_almost_equal(out, A.dot(B))
+ assert_almost_equal(out, np.dot(A, B))
+
+ def test_dynamic_programing_optimization_and_out(self):
+ # multi_dot with four or more arguments uses the dynamic programming
+ # optimization and therefore deserves a separate test
+ A = np.random.random((6, 2))
+ B = np.random.random((2, 6))
+ C = np.random.random((6, 2))
+ D = np.random.random((2, 1))
+ out = np.zeros((6, 1))
+ ret = multi_dot([A, B, C, D], out=out)
+ assert out is ret
+ assert_almost_equal(out, A.dot(B).dot(C).dot(D))
+
def test_dynamic_programming_logic(self):
# Test for the dynamic programming part
# This test is directly taken from Cormen page 376.