summaryrefslogtreecommitdiff
path: root/numpy/linalg/tests
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/linalg/tests')
-rw-r--r--numpy/linalg/tests/test_linalg.py90
1 files changed, 88 insertions, 2 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 34079bb87..63baa590d 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -12,7 +12,8 @@ import numpy as np
from numpy import array, single, double, csingle, cdouble, dot, identity
from numpy import multiply, atleast_2d, inf, asarray, matrix
from numpy import linalg
-from numpy.linalg import matrix_power, norm, matrix_rank
+from numpy.linalg import matrix_power, norm, matrix_rank, multi_dot
+from numpy.linalg.linalg import _multi_dot_matrix_chain_order
from numpy.testing import (
assert_, assert_equal, assert_raises, assert_array_equal,
assert_almost_equal, assert_allclose, run_module_suite,
@@ -207,7 +208,7 @@ for tgt, src in ((GENERALIZED_SQUARE_CASES, SQUARE_CASES),
for case in src:
if not isinstance(case.a, np.ndarray):
continue
-
+
a = np.array([case.a, 2*case.a, 3*case.a])
if case.b is None:
b = None
@@ -1192,5 +1193,90 @@ def test_xerbla_override():
raise SkipTest('Numpy xerbla not linked in.')
+class TestMultiDot(object):
+ def test_basic_function_with_three_arguments(self):
+ # multi_dot with three arguments uses a fast hand coded algorithm to
+ # determine the optimal order. Therefore test it separately.
+ A = np.random.random((6, 2))
+ B = np.random.random((2, 6))
+ C = np.random.random((6, 2))
+
+ assert_almost_equal(multi_dot([A, B, C]), A.dot(B).dot(C))
+ assert_almost_equal(multi_dot([A, B, C]), np.dot(A, np.dot(B, C)))
+
+ def test_basic_function_with_dynamic_programing_optimization(self):
+ # multi_dot with four or more arguments uses the dynamic programing
+ # optimization and therefore deserve a separate
+ A = np.random.random((6, 2))
+ B = np.random.random((2, 6))
+ C = np.random.random((6, 2))
+ D = np.random.random((2, 1))
+ assert_almost_equal(multi_dot([A, B, C, D]), A.dot(B).dot(C).dot(D))
+
+ def test_vector_as_first_argument(self):
+ # The first argument can be 1-D
+ A1d = np.random.random(2) # 1-D
+ B = np.random.random((2, 6))
+ C = np.random.random((6, 2))
+ D = np.random.random((2, 2))
+
+ # the result should be 1-D
+ assert_equal(multi_dot([A1d, B, C, D]).shape, (2,))
+
+ def test_vector_as_last_argument(self):
+ # The last argument can be 1-D
+ A = np.random.random((6, 2))
+ B = np.random.random((2, 6))
+ C = np.random.random((6, 2))
+ D1d = np.random.random(2) # 1-D
+
+ # the result should be 1-D
+ assert_equal(multi_dot([A, B, C, D1d]).shape, (6,))
+
+ def test_vector_as_first_and_last_argument(self):
+ # The first and last arguments can be 1-D
+ A1d = np.random.random(2) # 1-D
+ B = np.random.random((2, 6))
+ C = np.random.random((6, 2))
+ D1d = np.random.random(2) # 1-D
+
+ # the result should be a scalar
+ assert_equal(multi_dot([A1d, B, C, D1d]).shape, ())
+
+ def test_dynamic_programming_logic(self):
+ # Test for the dynamic programming part
+ # This test is directly taken from Cormen page 376.
+ arrays = [np.random.random((30, 35)),
+ np.random.random((35, 15)),
+ np.random.random((15, 5)),
+ np.random.random((5, 10)),
+ np.random.random((10, 20)),
+ np.random.random((20, 25))]
+ m_expected = np.array([[0., 15750., 7875., 9375., 11875., 15125.],
+ [0., 0., 2625., 4375., 7125., 10500.],
+ [0., 0., 0., 750., 2500., 5375.],
+ [0., 0., 0., 0., 1000., 3500.],
+ [0., 0., 0., 0., 0., 5000.],
+ [0., 0., 0., 0., 0., 0.]])
+ s_expected = np.array([[0, 1, 1, 3, 3, 3],
+ [0, 0, 2, 3, 3, 3],
+ [0, 0, 0, 3, 3, 3],
+ [0, 0, 0, 0, 4, 5],
+ [0, 0, 0, 0, 0, 5],
+ [0, 0, 0, 0, 0, 0]], dtype=np.int)
+ s_expected -= 1 # Cormen uses 1-based index, python does not.
+
+ s, m = _multi_dot_matrix_chain_order(arrays, return_costs=True)
+
+ # Only the upper triangular part (without the diagonal) is interesting.
+ assert_almost_equal(np.triu(s[:-1, 1:]),
+ np.triu(s_expected[:-1, 1:]))
+ assert_almost_equal(np.triu(m), np.triu(m_expected))
+
+ def test_too_few_input_arrays(self):
+ assert_raises(ValueError, multi_dot, [])
+ assert_raises(ValueError, multi_dot, [np.random.random((3, 3))])
+
+
if __name__ == "__main__":
run_module_suite()