summaryrefslogtreecommitdiff
path: root/numpy/linalg/tests/test_gufuncs_linalg.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/linalg/tests/test_gufuncs_linalg.py')
-rw-r--r--numpy/linalg/tests/test_gufuncs_linalg.py500
1 files changed, 0 insertions, 500 deletions
diff --git a/numpy/linalg/tests/test_gufuncs_linalg.py b/numpy/linalg/tests/test_gufuncs_linalg.py
deleted file mode 100644
index 40f8c4058..000000000
--- a/numpy/linalg/tests/test_gufuncs_linalg.py
+++ /dev/null
@@ -1,500 +0,0 @@
-"""
-Test functions for gufuncs_linalg module
-Heavily inspired (ripped in part) test_linalg
-"""
-from __future__ import division, print_function
-
-################################################################################
-# The following functions are implemented in the module "gufuncs_linalg"
-#
-# category "linalg"
-# - inv (TestInv)
-# - poinv (TestPoinv)
-# - det (TestDet)
-# - slogdet (TestDet)
-# - eig (TestEig)
-# - eigh (TestEigh)
-# - eigvals (TestEigvals)
-# - eigvalsh (TestEigvalsh)
-# - cholesky
-# - solve (TestSolve)
-# - chosolve (TestChosolve)
-# - svd (TestSVD)
-
-# ** unimplemented **
-# - qr
-# - matrix_power
-# - matrix_rank
-# - pinv
-# - lstsq
-# - tensorinv
-# - tensorsolve
-# - norm
-# - cond
-#
-# category "inspired by pdl"
-# - quadratic_form
-# - matrix_multiply3
-# - add3 (TestAdd3)
-# - multiply3 (TestMultiply3)
-# - multiply3_add (TestMultiply3Add)
-# - multiply_add (TestMultiplyAdd)
-# - multiply_add2 (TestMultiplyAdd2)
-# - multiply4 (TestMultiply4)
-# - multiply4_add (TestMultiply4Add)
-#
-# category "others"
-# - convolve
-# - inner1d
-# - innerwt
-# - matrix_multiply
-
-from nose.plugins.skip import Skip, SkipTest
-import numpy as np
-
-from numpy.testing import (TestCase, assert_, assert_equal, assert_raises,
- assert_array_equal, assert_almost_equal,
- run_module_suite)
-
-from numpy import array, single, double, csingle, cdouble, dot, identity
-from numpy import multiply, inf
-import numpy.linalg._gufuncs_linalg as gula
-
-old_assert_almost_equal = assert_almost_equal
-
-def assert_almost_equal(a, b, **kw):
- if a.dtype.type in (single, csingle):
- decimal = 5
- else:
- decimal = 10
- old_assert_almost_equal(a, b, decimal = decimal, **kw)
-
-
-def assert_valid_eigen_no_broadcast(M, w, v, **kw):
- lhs = gula.matrix_multiply(M, v)
- rhs = w*v
- assert_almost_equal(lhs, rhs, **kw)
-
-
-def assert_valid_eigen_recurse(M, w, v, **kw):
- """check that w and v are valid eigenvalues/eigenvectors for matrix M
- broadcast"""
- if len(M.shape) > 2:
- for i in range(M.shape[0]):
- assert_valid_eigen_recurse(M[i], w[i], v[i], **kw)
- else:
- if len(M.shape) == 2:
- assert_valid_eigen_no_broadcast(M, w, v, **kw)
- else:
- raise AssertionError('Not enough dimensions')
-
-
-def assert_valid_eigen(M, w, v, **kw):
- if np.any(np.isnan(M)):
- raise AssertionError('nan found in matrix')
- if np.any(np.isnan(w)):
- raise AssertionError('nan found in eigenvalues')
- if np.any(np.isnan(v)):
- raise AssertionError('nan found in eigenvectors')
-
- assert_valid_eigen_recurse(M, w, v, **kw)
-
-
-def assert_valid_eigenvals_no_broadcast(M, w, **kw):
- ident = np.eye(M.shape[0], dtype=M.dtype)
- for i in range(w.shape[0]):
- assert_almost_equal(gula.det(M - w[i]*ident), 0.0, **kw)
-
-
-def assert_valid_eigenvals_recurse(M, w, **kw):
- if len(M.shape) > 2:
- for i in range(M.shape[0]):
- assert_valid_eigenvals_recurse(M[i], w[i], **kw)
- else:
- if len(M.shape) == 2:
- assert_valid_eigenvals_no_broadcast(M, w, **kw)
- else:
- raise AssertionError('Not enough dimensions')
-
-
-def assert_valid_eigenvals(M, w, **kw):
- if np.any(np.isnan(M)):
- raise AssertionError('nan found in matrix')
- if np.any(np.isnan(w)):
- raise AssertionError('nan found in eigenvalues')
- assert_valid_eigenvals_recurse(M, w, **kw)
-
-
-class MatrixGenerator(object):
- def real_matrices(self):
- a = [[1, 2],
- [3, 4]]
-
- b = [[4, 3],
- [2, 1]]
-
- return a, b
-
- def real_symmetric_matrices(self):
- a = [[ 2, -1],
- [-1, 2]]
-
- b = [[4, 3],
- [2, 1]]
-
- return a, b
-
- def complex_matrices(self):
- a = [[1+2j, 2+3j],
- [3+4j, 4+5j]]
-
- b = [[4+3j, 3+2j],
- [2+1j, 1+0j]]
-
- return a, b
-
- def complex_hermitian_matrices(self):
- a = [[2, -1],
- [-1, 2]]
-
- b = [[4+3j, 3+2j],
- [2-1j, 1+0j]]
-
- return a, b
-
- def real_matrices_vector(self):
- a, b = self.real_matrices()
- return [a], [b]
-
- def real_symmetric_matrices_vector(self):
- a, b = self.real_symmetric_matrices()
- return [a], [b]
-
- def complex_matrices_vector(self):
- a, b = self.complex_matrices()
- return [a], [b]
-
- def complex_hermitian_matrices_vector(self):
- a, b = self.complex_hermitian_matrices()
- return [a], [b]
-
-
-class GeneralTestCase(MatrixGenerator):
- def test_single(self):
- a, b = self.real_matrices()
- self.do(array(a, dtype=single),
- array(b, dtype=single))
-
- def test_double(self):
- a, b = self.real_matrices()
- self.do(array(a, dtype=double),
- array(b, dtype=double))
-
- def test_csingle(self):
- a, b = self.complex_matrices()
- self.do(array(a, dtype=csingle),
- array(b, dtype=csingle))
-
- def test_cdouble(self):
- a, b = self.complex_matrices()
- self.do(array(a, dtype=cdouble),
- array(b, dtype=cdouble))
-
- def test_vector_single(self):
- a, b = self.real_matrices_vector()
- self.do(array(a, dtype=single),
- array(b, dtype=single))
-
- def test_vector_double(self):
- a, b = self.real_matrices_vector()
- self.do(array(a, dtype=double),
- array(b, dtype=double))
-
- def test_vector_csingle(self):
- a, b = self.complex_matrices_vector()
- self.do(array(a, dtype=csingle),
- array(b, dtype=csingle))
-
- def test_vector_cdouble(self):
- a, b = self.complex_matrices_vector()
- self.do(array(a, dtype=cdouble),
- array(b, dtype=cdouble))
-
-
-class HermitianTestCase(MatrixGenerator):
- def test_single(self):
- a, b = self.real_symmetric_matrices()
- self.do(array(a, dtype=single),
- array(b, dtype=single))
-
- def test_double(self):
- a, b = self.real_symmetric_matrices()
- self.do(array(a, dtype=double),
- array(b, dtype=double))
-
- def test_csingle(self):
- a, b = self.complex_hermitian_matrices()
- self.do(array(a, dtype=csingle),
- array(b, dtype=csingle))
-
- def test_cdouble(self):
- a, b = self.complex_hermitian_matrices()
- self.do(array(a, dtype=cdouble),
- array(b, dtype=cdouble))
-
- def test_vector_single(self):
- a, b = self.real_symmetric_matrices_vector()
- self.do(array(a, dtype=single),
- array(b, dtype=single))
-
- def test_vector_double(self):
- a, b = self.real_symmetric_matrices_vector()
- self.do(array(a, dtype=double),
- array(b, dtype=double))
-
- def test_vector_csingle(self):
- a, b = self.complex_hermitian_matrices_vector()
- self.do(array(a, dtype=csingle),
- array(b, dtype=csingle))
-
- def test_vector_cdouble(self):
- a, b = self.complex_hermitian_matrices_vector()
- self.do(array(a, dtype=cdouble),
- array(b, dtype=cdouble))
-
-
-class TestMatrixMultiply(GeneralTestCase):
- def do(self, a, b):
- res = gula.matrix_multiply(a, b)
- if a.ndim == 2:
- assert_almost_equal(res, np.dot(a, b))
- else:
- assert_almost_equal(res[0], np.dot(a[0], b[0]))
-
- def test_column_matrix(self):
- A = np.arange(2*2).reshape((2, 2))
- B = np.arange(2*1).reshape((2, 1))
- res = gula.matrix_multiply(A, B)
- assert_almost_equal(res, np.dot(A, B))
-
-class TestInv(GeneralTestCase, TestCase):
- def do(self, a, b):
- a_inv = gula.inv(a)
- ident = identity(a.shape[-1])
- if 3 == len(a.shape):
- ident = ident.reshape((1, ident.shape[0], ident.shape[1]))
- assert_almost_equal(gula.matrix_multiply(a, a_inv), ident)
-
-
-class TestPoinv(HermitianTestCase, TestCase):
- def do(self, a, b):
- a_inv = gula.poinv(a)
- ident = identity(a.shape[-1])
- if 3 == len(a.shape):
- ident = ident.reshape((1, ident.shape[0], ident.shape[1]))
-
- assert_almost_equal(a_inv, gula.inv(a))
- assert_almost_equal(gula.matrix_multiply(a, a_inv), ident)
-
-
-class TestDet(GeneralTestCase, TestCase):
- def do(self, a, b):
- d = gula.det(a)
- s, ld = gula.slogdet(a)
- assert_almost_equal(s * np.exp(ld), d)
-
- if np.csingle == a.dtype.type or np.single == a.dtype.type:
- cmp_type=np.csingle
- else:
- cmp_type=np.cdouble
-
- ev = gula.eigvals(a.astype(cmp_type))
- assert_almost_equal(d.astype(cmp_type),
- multiply.reduce(ev.astype(cmp_type),
- axis=(ev.ndim-1)))
- if s != 0:
- assert_almost_equal(np.abs(s), 1)
- else:
- assert_equal(ld, -inf)
-
- def test_zero(self):
- assert_equal(gula.det(array([[0.0]], dtype=single)), 0.0)
- assert_equal(gula.det(array([[0.0]], dtype=double)), 0.0)
- assert_equal(gula.det(array([[0.0]], dtype=csingle)), 0.0)
- assert_equal(gula.det(array([[0.0]], dtype=cdouble)), 0.0)
-
- assert_equal(gula.slogdet(array([[0.0]], dtype=single)), (0.0, -inf))
- assert_equal(gula.slogdet(array([[0.0]], dtype=double)), (0.0, -inf))
- assert_equal(gula.slogdet(array([[0.0]], dtype=csingle)), (0.0, -inf))
- assert_equal(gula.slogdet(array([[0.0]], dtype=cdouble)), (0.0, -inf))
-
- def test_types(self):
- for typ in [(single, single),
- (double, double),
- (csingle, single),
- (cdouble, double)]:
- for x in [ [0], [[0]], [[[0]]] ]:
- assert_equal(gula.det(array(x, dtype=typ[0])).dtype, typ[0])
- assert_equal(gula.slogdet(array(x, dtype=typ[0]))[0].dtype, typ[0])
- assert_equal(gula.slogdet(array(x, dtype=typ[0]))[1].dtype, typ[1])
-
-
-class TestEig(GeneralTestCase, TestCase):
- def do(self, a, b):
- evalues, evectors = gula.eig(a)
- assert_valid_eigenvals(a, evalues)
- assert_valid_eigen(a, evalues, evectors)
- ev = gula.eigvals(a)
- assert_valid_eigenvals(a, evalues)
- assert_almost_equal(ev, evalues)
-
-
-class TestEigh(HermitianTestCase, TestCase):
- def do(self, a, b):
- evalues_lo, evectors_lo = gula.eigh(a, UPLO='L')
- evalues_up, evectors_up = gula.eigh(a, UPLO='U')
-
- assert_valid_eigenvals(a, evalues_lo)
- assert_valid_eigenvals(a, evalues_up)
- assert_valid_eigen(a, evalues_lo, evectors_lo)
- assert_valid_eigen(a, evalues_up, evectors_up)
- assert_almost_equal(evalues_lo, evalues_up)
- assert_almost_equal(evectors_lo, evectors_up)
-
- ev_lo = gula.eigvalsh(a, UPLO='L')
- ev_up = gula.eigvalsh(a, UPLO='U')
- assert_valid_eigenvals(a, ev_lo)
- assert_valid_eigenvals(a, ev_up)
- assert_almost_equal(ev_lo, evalues_lo)
- assert_almost_equal(ev_up, evalues_up)
-
-
-class TestSolve(GeneralTestCase, TestCase):
- def do(self, a, b):
- x = gula.solve(a, b)
- assert_almost_equal(b, gula.matrix_multiply(a, x))
-
-
-class TestChosolve(HermitianTestCase, TestCase):
- def do(self, a, b):
- """
- inner1d not defined for complex types.
- todo: implement alternative test
- """
- if csingle == a.dtype or cdouble == a.dtype:
- raise SkipTest
-
- x_lo = gula.chosolve(a, b, UPLO='L')
- x_up = gula.chosolve(a, b, UPLO='U')
- assert_almost_equal(x_lo, x_up)
- # inner1d not defined for complex types
- # todo: implement alternative test
- assert_almost_equal(b, gula.matrix_multiply(a, x_lo))
- assert_almost_equal(b, gula.matrix_multiply(a, x_up))
-
-
-class TestSVD(GeneralTestCase, TestCase):
- def do(self, a, b):
- """ still work in progress """
- raise SkipTest
- u, s, vt = gula.svd(a, 0)
- assert_almost_equal(a, dot(multiply(u, s), vt))
-
-"""
-class TestCholesky(HermitianTestCase, TestCase):
- def do(self, a, b):
- pass
-"""
-
-################################################################################
-# ufuncs inspired by pdl
-# - add3
-# - multiply3
-# - multiply3_add
-# - multiply_add
-# - multiply_add2
-# - multiply4
-# - multiply4_add
-
-class UfuncTestCase(object):
- parameter = range(0, 10)
-
- def _check_for_type(self, typ):
- a = np.array(self.__class__.parameter, dtype=typ)
- self.do(a)
-
- def _check_for_type_vector(self, typ):
- parameter = self.__class__.parameter
- a = np.array([parameter, parameter], dtype=typ)
- self.do(a)
-
- def test_single(self):
- self._check_for_type(single)
-
- def test_double(self):
- self._check_for_type(double)
-
- def test_csingle(self):
- self._check_for_type(csingle)
-
- def test_cdouble(self):
- self._check_for_type(cdouble)
-
- def test_single_vector(self):
- self._check_for_type_vector(single)
-
- def test_double_vector(self):
- self._check_for_type_vector(double)
-
- def test_csingle_vector(self):
- self._check_for_type_vector(csingle)
-
- def test_cdouble_vector(self):
- self._check_for_type_vector(cdouble)
-
-
-class TestAdd3(UfuncTestCase, TestCase):
- def do(self, a):
- r = gula.add3(a, a, a)
- assert_almost_equal(r, a+a+a)
-
-
-class TestMultiply3(UfuncTestCase, TestCase):
- def do(self, a):
- r = gula.multiply3(a, a, a)
- assert_almost_equal(r, a*a*a)
-
-
-class TestMultiply3Add(UfuncTestCase, TestCase):
- def do(self, a):
- r = gula.multiply3_add(a, a, a, a)
- assert_almost_equal(r, a*a*a+a)
-
-
-class TestMultiplyAdd(UfuncTestCase, TestCase):
- def do(self, a):
- r = gula.multiply_add(a, a, a)
- assert_almost_equal(r, a*a+a)
-
-
-class TestMultiplyAdd2(UfuncTestCase, TestCase):
- def do(self, a):
- r = gula.multiply_add2(a, a, a, a)
- assert_almost_equal(r, a*a+a+a)
-
-
-class TestMultiply4(UfuncTestCase, TestCase):
- def do(self, a):
- r = gula.multiply4(a, a, a, a)
- assert_almost_equal(r, a*a*a*a)
-
-
-class TestMultiply4_add(UfuncTestCase, TestCase):
- def do(self, a):
- r = gula.multiply4_add(a, a, a, a, a)
- assert_almost_equal(r, a*a*a*a+a)
-
-
-if __name__ == "__main__":
- print('testing gufuncs_linalg; gufuncs version: %s' % gula._impl.__version__)
- run_module_suite()