summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorczgdp1807 <gdp.1807@gmail.com>2021-06-05 11:46:03 +0530
committerczgdp1807 <gdp.1807@gmail.com>2021-06-05 11:46:03 +0530
commitb6cd5b2efb6d099d1d9be01697143c59bb1491eb (patch)
tree6e648daf554c3b46c9b1e9bf7668a04c812eaba3 /numpy
parentc33cf14e144f8f3f8c226e81bbf693cdba7b3187 (diff)
downloadnumpy-b6cd5b2efb6d099d1d9be01697143c59bb1491eb.tar.gz
tests for stacked inputs added
Diffstat (limited to 'numpy')
-rw-r--r--numpy/linalg/tests/test_linalg.py61
1 files changed, 61 insertions, 0 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index c6e8cdd03..6e03f8b9b 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -1,6 +1,7 @@
""" Test functions for linalg module
"""
+from numpy.core.fromnumeric import shape
import os
import sys
import itertools
@@ -11,6 +12,7 @@ import pytest
import numpy as np
from numpy import array, single, double, csingle, cdouble, dot, identity, matmul
+from numpy.core import swapaxes
from numpy import multiply, atleast_2d, inf, asarray
from numpy import linalg
from numpy.linalg import matrix_power, norm, matrix_rank, multi_dot, LinAlgError
@@ -1709,7 +1711,66 @@ class TestQR:
self.check_qr(m1)
self.check_qr(m2)
self.check_qr(m2.T)
+
+ def check_qr_stacked(self, a):
+ # This test expects the argument `a` to be an ndarray or
+ # a subclass of an ndarray of inexact type.
+ a_type = type(a)
+ a_dtype = a.dtype
+ m, n = a.shape[-2:]
+ k = min(m, n)
+
+ # mode == 'complete'
+ q, r = linalg.qr(a, mode='complete')
+ assert_(q.dtype == a_dtype)
+ assert_(r.dtype == a_dtype)
+ assert_(isinstance(q, a_type))
+ assert_(isinstance(r, a_type))
+ assert_(q.shape[-2:] == (m, m))
+ assert_(r.shape[-2:] == (m, n))
+ assert_almost_equal(matmul(q, r), a)
+ assert_almost_equal(swapaxes(q, -1, -2).conj(), np.linalg.inv(q))
+ assert_almost_equal(np.triu(r[..., :, :]), r)
+
+ # mode == 'reduced'
+ q1, r1 = linalg.qr(a, mode='reduced')
+ assert_(q1.dtype == a_dtype)
+ assert_(r1.dtype == a_dtype)
+ assert_(isinstance(q1, a_type))
+ assert_(isinstance(r1, a_type))
+ assert_(q1.shape[-2:] == (m, k))
+ assert_(r1.shape[-2:] == (k, n))
+ assert_almost_equal(matmul(q1, r1), a)
+ assert_almost_equal(swapaxes(q, -1, -2).conj(), np.linalg.inv(q))
+ assert_almost_equal(np.triu(r1[..., :, :]), r1)
+
+ # mode == 'r'
+ r2 = linalg.qr(a, mode='r')
+ assert_(r2.dtype == a_dtype)
+ assert_(isinstance(r2, a_type))
+ assert_almost_equal(r2, r1)
+ def test_stacked_inputs(self):
+
+ curr_state = np.random.get_state()
+ np.random.seed(0)
+
+ normal = np.random.normal
+ sizes = [(3, 4), (4, 3), (4, 4), (3, 0), (0, 3)]
+ dts = [np.float32, np.float64, np.complex64]
+ for size in sizes:
+ for dt in dts:
+ a1, a2, a3, a4 = [normal(size=size), normal(size=size),
+ normal(size=size), normal(size=size)]
+ b1, b2, b3, b4 = [normal(size=size), normal(size=size),
+ normal(size=size), normal(size=size)]
+ A = np.asarray([[a1, a2], [a3, a4]], dtype=dt)
+ B = np.asarray([[b1, b2], [b3, b4]], dtype=dt)
+ self.check_qr_stacked(A)
+ self.check_qr_stacked(B)
+ self.check_qr_stacked(A + 1.j*B)
+
+ np.random.set_state(curr_state)
class TestCholesky:
# TODO: are there no other tests for cholesky?