summaryrefslogtreecommitdiff
path: root/numpy/lib/tests
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/tests')
-rw-r--r--numpy/lib/tests/test_function_base.py12
-rw-r--r--numpy/lib/tests/test_stride_tricks.py65
-rw-r--r--numpy/lib/tests/test_utils.py19
3 files changed, 94 insertions, 2 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 7bddb941c..4c7c0480c 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -2023,6 +2023,12 @@ class TestCorrCoef:
assert_array_almost_equal(c, np.array([[1., -1.], [-1., 1.]]))
assert_(np.all(np.abs(c) <= 1.0))
+ @pytest.mark.parametrize("test_type", [np.half, np.single, np.double, np.longdouble])
+ def test_corrcoef_dtype(self, test_type):
+ cast_A = self.A.astype(test_type)
+ res = corrcoef(cast_A, dtype=test_type)
+ assert test_type == res.dtype
+
class TestCov:
x1 = np.array([[0, 2], [1, 1], [2, 0]]).T
@@ -2123,6 +2129,12 @@ class TestCov:
aweights=self.unit_weights),
self.res1)
+ @pytest.mark.parametrize("test_type", [np.half, np.single, np.double, np.longdouble])
+ def test_cov_dtype(self, test_type):
+ cast_x1 = self.x1.astype(test_type)
+ res = cov(cast_x1, dtype=test_type)
+ assert test_type == res.dtype
+
class Test_I0:
diff --git a/numpy/lib/tests/test_stride_tricks.py b/numpy/lib/tests/test_stride_tricks.py
index 9d95eb9d0..10d7a19ab 100644
--- a/numpy/lib/tests/test_stride_tricks.py
+++ b/numpy/lib/tests/test_stride_tricks.py
@@ -5,7 +5,8 @@ from numpy.testing import (
assert_raises_regex, assert_warns,
)
from numpy.lib.stride_tricks import (
- as_strided, broadcast_arrays, _broadcast_shape, broadcast_to
+ as_strided, broadcast_arrays, _broadcast_shape, broadcast_to,
+ broadcast_shapes,
)
def assert_shapes_correct(input_shapes, expected_shape):
@@ -274,7 +275,9 @@ def test_broadcast_to_raises():
def test_broadcast_shape():
- # broadcast_shape is already exercized indirectly by broadcast_arrays
+ # tests internal _broadcast_shape
+ # _broadcast_shape is already exercised indirectly by broadcast_arrays
+ # _broadcast_shape is also exercised by the public broadcast_shapes function
assert_equal(_broadcast_shape(), ())
assert_equal(_broadcast_shape([1, 2]), (2,))
assert_equal(_broadcast_shape(np.ones((1, 1))), (1, 1))
@@ -288,6 +291,64 @@ def test_broadcast_shape():
assert_raises(ValueError, lambda: _broadcast_shape(*bad_args))
+def test_broadcast_shapes_succeeds():
+ # tests public broadcast_shapes
+ data = [
+ [[], ()],
+ [[()], ()],
+ [[(7,)], (7,)],
+ [[(1, 2), (2,)], (1, 2)],
+ [[(1, 1)], (1, 1)],
+ [[(1, 1), (3, 4)], (3, 4)],
+ [[(6, 7), (5, 6, 1), (7,), (5, 1, 7)], (5, 6, 7)],
+ [[(5, 6, 1)], (5, 6, 1)],
+ [[(1, 3), (3, 1)], (3, 3)],
+ [[(1, 0), (0, 0)], (0, 0)],
+ [[(0, 1), (0, 0)], (0, 0)],
+ [[(1, 0), (0, 1)], (0, 0)],
+ [[(1, 1), (0, 0)], (0, 0)],
+ [[(1, 1), (1, 0)], (1, 0)],
+ [[(1, 1), (0, 1)], (0, 1)],
+ [[(), (0,)], (0,)],
+ [[(0,), (0, 0)], (0, 0)],
+ [[(0,), (0, 1)], (0, 0)],
+ [[(1,), (0, 0)], (0, 0)],
+ [[(), (0, 0)], (0, 0)],
+ [[(1, 1), (0,)], (1, 0)],
+ [[(1,), (0, 1)], (0, 1)],
+ [[(1,), (1, 0)], (1, 0)],
+ [[(), (1, 0)], (1, 0)],
+ [[(), (0, 1)], (0, 1)],
+ [[(1,), (3,)], (3,)],
+ [[2, (3, 2)], (3, 2)],
+ ]
+ for input_shapes, target_shape in data:
+ assert_equal(broadcast_shapes(*input_shapes), target_shape)
+
+ assert_equal(broadcast_shapes(*([(1, 2)] * 32)), (1, 2))
+ assert_equal(broadcast_shapes(*([(1, 2)] * 100)), (1, 2))
+
+ # regression tests for gh-5862
+ assert_equal(broadcast_shapes(*([(2,)] * 32)), (2,))
+
+
+def test_broadcast_shapes_raises():
+ # tests public broadcast_shapes
+ data = [
+ [(3,), (4,)],
+ [(2, 3), (2,)],
+ [(3,), (3,), (4,)],
+ [(1, 3, 4), (2, 3, 3)],
+ [(1, 2), (3,1), (3,2), (10, 5)],
+ [2, (2, 3)],
+ ]
+ for input_shapes in data:
+ assert_raises(ValueError, lambda: broadcast_shapes(*input_shapes))
+
+ bad_args = [(2,)] * 32 + [(3,)] * 32
+ assert_raises(ValueError, lambda: broadcast_shapes(*bad_args))
+
+
def test_as_strided():
a = np.array([None])
a_view = as_strided(a)
diff --git a/numpy/lib/tests/test_utils.py b/numpy/lib/tests/test_utils.py
index 261cfef5d..33951b92a 100644
--- a/numpy/lib/tests/test_utils.py
+++ b/numpy/lib/tests/test_utils.py
@@ -140,3 +140,22 @@ class TestByteBounds:
def test_assert_raises_regex_context_manager():
with assert_raises_regex(ValueError, 'no deprecation warning'):
raise ValueError('no deprecation warning')
+
+
+def test_info_method_heading():
+ # info(class) should only print "Methods:" heading if methods exist
+
+ class NoPublicMethods:
+ pass
+
+ class WithPublicMethods:
+ def first_method():
+ pass
+
+ def _has_method_heading(cls):
+ out = StringIO()
+ utils.info(cls, output=out)
+ return 'Methods:' in out.getvalue()
+
+ assert _has_method_heading(WithPublicMethods)
+ assert not _has_method_heading(NoPublicMethods)