summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_packbits.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/tests/test_packbits.py')
-rw-r--r--numpy/lib/tests/test_packbits.py63
1 files changed, 63 insertions, 0 deletions
diff --git a/numpy/lib/tests/test_packbits.py b/numpy/lib/tests/test_packbits.py
index fde5c37f2..00d5ca827 100644
--- a/numpy/lib/tests/test_packbits.py
+++ b/numpy/lib/tests/test_packbits.py
@@ -266,3 +266,66 @@ def test_unpackbits_large():
assert_array_equal(np.packbits(np.unpackbits(d, axis=1), axis=1), d)
d = d.T.copy()
assert_array_equal(np.packbits(np.unpackbits(d, axis=0), axis=0), d)
+
+
+def test_unpackbits_count():
+ # test complete invertibility of packbits and unpackbits with count
+ x = np.array([
+ [1, 0, 1, 0, 0, 1, 0],
+ [0, 1, 1, 1, 0, 0, 0],
+ [0, 0, 1, 0, 0, 1, 1],
+ [1, 1, 0, 0, 0, 1, 1],
+ [1, 0, 1, 0, 1, 0, 1],
+ [0, 0, 1, 1, 1, 0, 0],
+ [0, 1, 0, 1, 0, 1, 0],
+ ], dtype=np.uint8)
+
+ padded1 = np.zeros(57, dtype=np.uint8)
+ padded1[:49] = x.ravel()
+
+ packed = np.packbits(x)
+ for count in range(58):
+ unpacked = np.unpackbits(packed, count=count)
+ assert_equal(unpacked.dtype, np.uint8)
+ assert_array_equal(unpacked, padded1[:count])
+ for count in range(-1, -57, -1):
+ unpacked = np.unpackbits(packed, count=count)
+ assert_equal(unpacked.dtype, np.uint8)
+ # count -1 because padded1 has 57 instead of 56 elements
+ assert_array_equal(unpacked, padded1[:count-1])
+ for kwargs in [{}, {'count': None}]:
+ unpacked = np.unpackbits(packed, **kwargs)
+ assert_equal(unpacked.dtype, np.uint8)
+ assert_array_equal(unpacked, padded1[:-1])
+ assert_raises(ValueError, np.unpackbits, packed, count=-57)
+
+ padded2 = np.zeros((9, 9), dtype=np.uint8)
+ padded2[:7, :7] = x
+
+ packed0 = np.packbits(x, axis=0)
+ packed1 = np.packbits(x, axis=1)
+ for count in range(10):
+ unpacked0 = np.unpackbits(packed0, axis=0, count=count)
+ assert_equal(unpacked0.dtype, np.uint8)
+ assert_array_equal(unpacked0, padded2[:count, :x.shape[1]])
+ unpacked1 = np.unpackbits(packed1, axis=1, count=count)
+ assert_equal(unpacked1.dtype, np.uint8)
+ assert_array_equal(unpacked1, padded2[:x.shape[1], :count])
+ for count in range(-1, -9, -1):
+ unpacked0 = np.unpackbits(packed0, axis=0, count=count)
+ assert_equal(unpacked0.dtype, np.uint8)
+ # count -1 because one extra zero of padding
+ assert_array_equal(unpacked0, padded2[:count-1, :x.shape[1]])
+ unpacked1 = np.unpackbits(packed1, axis=1, count=count)
+ assert_equal(unpacked1.dtype, np.uint8)
+ assert_array_equal(unpacked1, padded2[:x.shape[0], :count-1])
+ for kwargs in [{}, {'count': None}]:
+ unpacked0 = np.unpackbits(packed0, axis=0, **kwargs)
+ assert_equal(unpacked0.dtype, np.uint8)
+ assert_array_equal(unpacked0, padded2[:-1, :x.shape[1]])
+ unpacked1 = np.unpackbits(packed1, axis=1, **kwargs)
+ assert_equal(unpacked1.dtype, np.uint8)
+ assert_array_equal(unpacked1, padded2[:x.shape[0], :-1])
+ assert_raises(ValueError, np.unpackbits, packed0, axis=0, count=-9)
+ assert_raises(ValueError, np.unpackbits, packed1, axis=1, count=-9)
+