summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_arraypad.py
diff options
context:
space:
mode:
authorLars GrĂ¼ter <lagru@users.noreply.github.com>2018-11-15 23:23:26 +0100
committerMatti Picus <matti.picus@gmail.com>2018-11-15 14:23:26 -0800
commita4b96ad7649281de2c3a41292fcbab4c77c0743d (patch)
treedf39b13e69f51d79acd5a7e2647153a4c468c5ff /numpy/lib/tests/test_arraypad.py
parent7ada0c13a3e0d003670f421e8533cbb5388f705c (diff)
downloadnumpy-a4b96ad7649281de2c3a41292fcbab4c77c0743d.tar.gz
MAINT: Rewrite shape normalization in pad function (#11966)
Diffstat (limited to 'numpy/lib/tests/test_arraypad.py')
-rw-r--r--numpy/lib/tests/test_arraypad.py86
1 files changed, 86 insertions, 0 deletions
diff --git a/numpy/lib/tests/test_arraypad.py b/numpy/lib/tests/test_arraypad.py
index e62fccaa0..20f6e4a1b 100644
--- a/numpy/lib/tests/test_arraypad.py
+++ b/numpy/lib/tests/test_arraypad.py
@@ -9,6 +9,91 @@ import numpy as np
from numpy.testing import (assert_array_equal, assert_raises, assert_allclose,
assert_equal)
from numpy.lib import pad
+from numpy.lib.arraypad import _as_pairs
+
+
+class TestAsPairs(object):
+
+ def test_single_value(self):
+ """Test casting for a single value."""
+ expected = np.array([[3, 3]] * 10)
+ for x in (3, [3], [[3]]):
+ result = _as_pairs(x, 10)
+ assert_equal(result, expected)
+ # Test with dtype=object
+ obj = object()
+ assert_equal(
+ _as_pairs(obj, 10),
+ np.array([[obj, obj]] * 10)
+ )
+
+ def test_two_values(self):
+ """Test proper casting for two different values."""
+ # Broadcasting in the first dimension with numbers
+ expected = np.array([[3, 4]] * 10)
+ for x in ([3, 4], [[3, 4]]):
+ result = _as_pairs(x, 10)
+ assert_equal(result, expected)
+ # and with dtype=object
+ obj = object()
+ assert_equal(
+ _as_pairs(["a", obj], 10),
+ np.array([["a", obj]] * 10)
+ )
+
+ # Broadcasting in the second / last dimension with numbers
+ assert_equal(
+ _as_pairs([[3], [4]], 2),
+ np.array([[3, 3], [4, 4]])
+ )
+ # and with dtype=object
+ assert_equal(
+ _as_pairs([["a"], [obj]], 2),
+ np.array([["a", "a"], [obj, obj]])
+ )
+
+ def test_with_none(self):
+ expected = ((None, None), (None, None), (None, None))
+ assert_equal(
+ _as_pairs(None, 3, as_index=False),
+ expected
+ )
+ assert_equal(
+ _as_pairs(None, 3, as_index=True),
+ expected
+ )
+
+ def test_pass_through(self):
+ """Test if `x` already matching desired output are passed through."""
+ expected = np.arange(12).reshape((6, 2))
+ assert_equal(
+ _as_pairs(expected, 6),
+ expected
+ )
+
+ def test_as_index(self):
+ """Test results if `as_index=True`."""
+ assert_equal(
+ _as_pairs([2.6, 3.3], 10, as_index=True),
+ np.array([[3, 3]] * 10, dtype=np.intp)
+ )
+ assert_equal(
+ _as_pairs([2.6, 4.49], 10, as_index=True),
+ np.array([[3, 4]] * 10, dtype=np.intp)
+ )
+ for x in (-3, [-3], [[-3]], [-3, 4], [3, -4], [[-3, 4]], [[4, -3]],
+ [[1, 2]] * 9 + [[1, -2]]):
+ with pytest.raises(ValueError, match="negative values"):
+ _as_pairs(x, 10, as_index=True)
+
+ def test_exceptions(self):
+ """Ensure faulty usage is discovered."""
+ with pytest.raises(ValueError, match="more dimensions than allowed"):
+ _as_pairs([[[3]]], 10)
+ with pytest.raises(ValueError, match="could not be broadcast"):
+ _as_pairs([[1, 2], [3, 4]], 3)
+ with pytest.raises(ValueError, match="could not be broadcast"):
+ _as_pairs(np.ones((2, 3)), 3)
class TestConditionalShortcuts(object):
@@ -535,6 +620,7 @@ class TestConstant(object):
assert_array_equal(arr, expected)
+
class TestLinearRamp(object):
def test_check_simple(self):
a = np.arange(100).astype('f')