summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJaime <jaime.frio@gmail.com>2016-05-15 09:42:17 +0200
committerJaime <jaime.frio@gmail.com>2016-05-15 09:42:17 +0200
commite96ebfc233412e8da98ae96f5c640d00f5411f6e (patch)
tree6f47621ee6d8f785fd2a2989c91c9efadbd8cab2
parent730af6f800d7e9b57ae7a2233aa7ca563603c960 (diff)
parent49f70f7c3a447c8e0bc9ba0d1be7e4df34f04f58 (diff)
downloadnumpy-e96ebfc233412e8da98ae96f5c640d00f5411f6e.tar.gz
Merge pull request #7631 from perimosocordiae/patch-2
MAINT: linalg: fix comment, simplify math
-rw-r--r--numpy/linalg/linalg.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index f4acd0ef3..1f3ba7c54 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -2346,12 +2346,12 @@ def _multi_dot_three(A, B, C):
than `_multi_dot_matrix_chain_order`
"""
- # cost1 = cost((AB)C)
- cost1 = (A.shape[0] * A.shape[1] * B.shape[1] + # (AB)
- A.shape[0] * B.shape[1] * C.shape[1]) # (--)C
- # cost2 = cost((AB)C)
- cost2 = (B.shape[0] * B.shape[1] * C.shape[1] + # (BC)
- A.shape[0] * A.shape[1] * C.shape[1]) # A(--)
+ a0, a1b0 = A.shape
+ b1c0, c1 = C.shape
+ # cost1 = cost((AB)C) = a0*a1b0*b1c0 + a0*b1c0*c1
+ cost1 = a0 * b1c0 * (a1b0 + c1)
+ # cost2 = cost(A(BC)) = a1b0*b1c0*c1 + a0*a1b0*c1
+ cost2 = a1b0 * c1 * (a0 + b1c0)
if cost1 < cost2:
return dot(dot(A, B), C)