diff options
| author | czgdp1807 <gdp.1807@gmail.com> | 2021-06-05 11:46:03 +0530 |
|---|---|---|
| committer | czgdp1807 <gdp.1807@gmail.com> | 2021-06-05 11:46:03 +0530 |
| commit | b6cd5b2efb6d099d1d9be01697143c59bb1491eb (patch) | |
| tree | 6e648daf554c3b46c9b1e9bf7668a04c812eaba3 /numpy | |
| parent | c33cf14e144f8f3f8c226e81bbf693cdba7b3187 (diff) | |
| download | numpy-b6cd5b2efb6d099d1d9be01697143c59bb1491eb.tar.gz | |
tests for stacked inputs added
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/linalg/tests/test_linalg.py | 61 |
1 files changed, 61 insertions, 0 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index c6e8cdd03..6e03f8b9b 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -1,6 +1,7 @@ """ Test functions for linalg module """ +from numpy.core.fromnumeric import shape import os import sys import itertools @@ -11,6 +12,7 @@ import pytest import numpy as np from numpy import array, single, double, csingle, cdouble, dot, identity, matmul +from numpy.core import swapaxes from numpy import multiply, atleast_2d, inf, asarray from numpy import linalg from numpy.linalg import matrix_power, norm, matrix_rank, multi_dot, LinAlgError @@ -1709,7 +1711,66 @@ class TestQR: self.check_qr(m1) self.check_qr(m2) self.check_qr(m2.T) + + def check_qr_stacked(self, a): + # This test expects the argument `a` to be an ndarray or + # a subclass of an ndarray of inexact type. + a_type = type(a) + a_dtype = a.dtype + m, n = a.shape[-2:] + k = min(m, n) + + # mode == 'complete' + q, r = linalg.qr(a, mode='complete') + assert_(q.dtype == a_dtype) + assert_(r.dtype == a_dtype) + assert_(isinstance(q, a_type)) + assert_(isinstance(r, a_type)) + assert_(q.shape[-2:] == (m, m)) + assert_(r.shape[-2:] == (m, n)) + assert_almost_equal(matmul(q, r), a) + assert_almost_equal(swapaxes(q, -1, -2).conj(), np.linalg.inv(q)) + assert_almost_equal(np.triu(r[..., :, :]), r) + + # mode == 'reduced' + q1, r1 = linalg.qr(a, mode='reduced') + assert_(q1.dtype == a_dtype) + assert_(r1.dtype == a_dtype) + assert_(isinstance(q1, a_type)) + assert_(isinstance(r1, a_type)) + assert_(q1.shape[-2:] == (m, k)) + assert_(r1.shape[-2:] == (k, n)) + assert_almost_equal(matmul(q1, r1), a) + assert_almost_equal(swapaxes(q, -1, -2).conj(), np.linalg.inv(q)) + assert_almost_equal(np.triu(r1[..., :, :]), r1) + + # mode == 'r' + r2 = linalg.qr(a, mode='r') + assert_(r2.dtype == a_dtype) + assert_(isinstance(r2, a_type)) + assert_almost_equal(r2, r1) + def test_stacked_inputs(self): + + curr_state = np.random.get_state() + np.random.seed(0) + + normal = np.random.normal + sizes = [(3, 4), (4, 3), (4, 4), (3, 0), (0, 3)] + dts = [np.float32, np.float64, np.complex64] + for size in sizes: + for dt in dts: + a1, a2, a3, a4 = [normal(size=size), normal(size=size), + normal(size=size), normal(size=size)] + b1, b2, b3, b4 = [normal(size=size), normal(size=size), + normal(size=size), normal(size=size)] + A = np.asarray([[a1, a2], [a3, a4]], dtype=dt) + B = np.asarray([[b1, b2], [b3, b4]], dtype=dt) + self.check_qr_stacked(A) + self.check_qr_stacked(B) + self.check_qr_stacked(A + 1.j*B) + + np.random.set_state(curr_state) class TestCholesky: # TODO: are there no other tests for cholesky? |
