summaryrefslogtreecommitdiff
path: root/numpy/linalg/tests
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2021-07-14 20:23:48 +0300
committerGitHub <noreply@github.com>2021-07-14 20:23:48 +0300
commita4e931f15e87cada0b9bb9ad80378115e72ba7ba (patch)
tree5b885d9f22a2f1eecda7067fa28e5f0bd7fedb67 /numpy/linalg/tests
parentf3533711c854d05e2d767e3e8373c882d4d9f3ae (diff)
parent6e405d53a504d6f97c8b1227d7f4d3c3c1aa2834 (diff)
downloadnumpy-a4e931f15e87cada0b9bb9ad80378115e72ba7ba.tar.gz
Merge pull request #19151 from czgdp1807/stack_mat
ENH: Vectorising np.linalg.qr
Diffstat (limited to 'numpy/linalg/tests')
-rw-r--r--numpy/linalg/tests/test_linalg.py62
1 files changed, 62 insertions, 0 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index c6e8cdd03..4c54c0b53 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -1,6 +1,7 @@
""" Test functions for linalg module
"""
+from numpy.core.fromnumeric import shape
import os
import sys
import itertools
@@ -11,6 +12,7 @@ import pytest
import numpy as np
from numpy import array, single, double, csingle, cdouble, dot, identity, matmul
+from numpy.core import swapaxes
from numpy import multiply, atleast_2d, inf, asarray
from numpy import linalg
from numpy.linalg import matrix_power, norm, matrix_rank, multi_dot, LinAlgError
@@ -1710,6 +1712,66 @@ class TestQR:
self.check_qr(m2)
self.check_qr(m2.T)
+ def check_qr_stacked(self, a):
+ # This test expects the argument `a` to be an ndarray or
+ # a subclass of an ndarray of inexact type.
+ a_type = type(a)
+ a_dtype = a.dtype
+ m, n = a.shape[-2:]
+ k = min(m, n)
+
+ # mode == 'complete'
+ q, r = linalg.qr(a, mode='complete')
+ assert_(q.dtype == a_dtype)
+ assert_(r.dtype == a_dtype)
+ assert_(isinstance(q, a_type))
+ assert_(isinstance(r, a_type))
+ assert_(q.shape[-2:] == (m, m))
+ assert_(r.shape[-2:] == (m, n))
+ assert_almost_equal(matmul(q, r), a)
+ I_mat = np.identity(q.shape[-1])
+ stack_I_mat = np.broadcast_to(I_mat,
+ q.shape[:-2] + (q.shape[-1],)*2)
+ assert_almost_equal(matmul(swapaxes(q, -1, -2).conj(), q), stack_I_mat)
+ assert_almost_equal(np.triu(r[..., :, :]), r)
+
+ # mode == 'reduced'
+ q1, r1 = linalg.qr(a, mode='reduced')
+ assert_(q1.dtype == a_dtype)
+ assert_(r1.dtype == a_dtype)
+ assert_(isinstance(q1, a_type))
+ assert_(isinstance(r1, a_type))
+ assert_(q1.shape[-2:] == (m, k))
+ assert_(r1.shape[-2:] == (k, n))
+ assert_almost_equal(matmul(q1, r1), a)
+ I_mat = np.identity(q1.shape[-1])
+ stack_I_mat = np.broadcast_to(I_mat,
+ q1.shape[:-2] + (q1.shape[-1],)*2)
+ assert_almost_equal(matmul(swapaxes(q1, -1, -2).conj(), q1),
+ stack_I_mat)
+ assert_almost_equal(np.triu(r1[..., :, :]), r1)
+
+ # mode == 'r'
+ r2 = linalg.qr(a, mode='r')
+ assert_(r2.dtype == a_dtype)
+ assert_(isinstance(r2, a_type))
+ assert_almost_equal(r2, r1)
+
+ @pytest.mark.parametrize("size", [
+ (3, 4), (4, 3), (4, 4),
+ (3, 0), (0, 3)])
+ @pytest.mark.parametrize("outer_size", [
+ (2, 2), (2,), (2, 3, 4)])
+ @pytest.mark.parametrize("dt", [
+ np.single, np.double,
+ np.csingle, np.cdouble])
+ def test_stacked_inputs(self, outer_size, size, dt):
+
+ A = np.random.normal(size=outer_size + size).astype(dt)
+ B = np.random.normal(size=outer_size + size).astype(dt)
+ self.check_qr_stacked(A)
+ self.check_qr_stacked(A + 1.j*B)
+
class TestCholesky:
# TODO: are there no other tests for cholesky?