diff options
author | Matti Picus <matti.picus@gmail.com> | 2021-07-14 20:23:48 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-14 20:23:48 +0300 |
commit | a4e931f15e87cada0b9bb9ad80378115e72ba7ba (patch) | |
tree | 5b885d9f22a2f1eecda7067fa28e5f0bd7fedb67 /numpy/linalg/tests | |
parent | f3533711c854d05e2d767e3e8373c882d4d9f3ae (diff) | |
parent | 6e405d53a504d6f97c8b1227d7f4d3c3c1aa2834 (diff) | |
download | numpy-a4e931f15e87cada0b9bb9ad80378115e72ba7ba.tar.gz |
Merge pull request #19151 from czgdp1807/stack_mat
ENH: Vectorising np.linalg.qr
Diffstat (limited to 'numpy/linalg/tests')
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index c6e8cdd03..4c54c0b53 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -1,6 +1,7 @@ """ Test functions for linalg module """ +from numpy.core.fromnumeric import shape import os import sys import itertools @@ -11,6 +12,7 @@ import pytest import numpy as np from numpy import array, single, double, csingle, cdouble, dot, identity, matmul +from numpy.core import swapaxes from numpy import multiply, atleast_2d, inf, asarray from numpy import linalg from numpy.linalg import matrix_power, norm, matrix_rank, multi_dot, LinAlgError @@ -1710,6 +1712,66 @@ class TestQR: self.check_qr(m2) self.check_qr(m2.T) + def check_qr_stacked(self, a): + # This test expects the argument `a` to be an ndarray or + # a subclass of an ndarray of inexact type. + a_type = type(a) + a_dtype = a.dtype + m, n = a.shape[-2:] + k = min(m, n) + + # mode == 'complete' + q, r = linalg.qr(a, mode='complete') + assert_(q.dtype == a_dtype) + assert_(r.dtype == a_dtype) + assert_(isinstance(q, a_type)) + assert_(isinstance(r, a_type)) + assert_(q.shape[-2:] == (m, m)) + assert_(r.shape[-2:] == (m, n)) + assert_almost_equal(matmul(q, r), a) + I_mat = np.identity(q.shape[-1]) + stack_I_mat = np.broadcast_to(I_mat, + q.shape[:-2] + (q.shape[-1],)*2) + assert_almost_equal(matmul(swapaxes(q, -1, -2).conj(), q), stack_I_mat) + assert_almost_equal(np.triu(r[..., :, :]), r) + + # mode == 'reduced' + q1, r1 = linalg.qr(a, mode='reduced') + assert_(q1.dtype == a_dtype) + assert_(r1.dtype == a_dtype) + assert_(isinstance(q1, a_type)) + assert_(isinstance(r1, a_type)) + assert_(q1.shape[-2:] == (m, k)) + assert_(r1.shape[-2:] == (k, n)) + assert_almost_equal(matmul(q1, r1), a) + I_mat = np.identity(q1.shape[-1]) + stack_I_mat = np.broadcast_to(I_mat, + q1.shape[:-2] + (q1.shape[-1],)*2) + assert_almost_equal(matmul(swapaxes(q1, -1, -2).conj(), q1), + stack_I_mat) + assert_almost_equal(np.triu(r1[..., :, :]), r1) + + # mode == 'r' + r2 = linalg.qr(a, mode='r') + assert_(r2.dtype == a_dtype) + assert_(isinstance(r2, a_type)) + assert_almost_equal(r2, r1) + + @pytest.mark.parametrize("size", [ + (3, 4), (4, 3), (4, 4), + (3, 0), (0, 3)]) + @pytest.mark.parametrize("outer_size", [ + (2, 2), (2,), (2, 3, 4)]) + @pytest.mark.parametrize("dt", [ + np.single, np.double, + np.csingle, np.cdouble]) + def test_stacked_inputs(self, outer_size, size, dt): + + A = np.random.normal(size=outer_size + size).astype(dt) + B = np.random.normal(size=outer_size + size).astype(dt) + self.check_qr_stacked(A) + self.check_qr_stacked(A + 1.j*B) + class TestCholesky: # TODO: are there no other tests for cholesky? |