summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorLars Buitinck <larsmans@gmail.com>2014-11-30 22:20:24 +0100
committerLars Buitinck <larsmans@gmail.com>2014-11-30 22:20:24 +0100
commit24effb6b7a075e23d85ea0b60ed8a607fe218c14 (patch)
tree975ff344e21ca66b1834adcdf096cc56dfab7536 /numpy/lib
parent6ce98831797729d7fb8aa525ddda017aceffa5e3 (diff)
downloadnumpy-24effb6b7a075e23d85ea0b60ed8a607fe218c14.tar.gz
ENH ensure np.packbits works on np.bool dtype
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/src/_compiled_base.c4
-rw-r--r--numpy/lib/tests/test_packbits.py16
2 files changed, 12 insertions, 8 deletions
diff --git a/numpy/lib/src/_compiled_base.c b/numpy/lib/src/_compiled_base.c
index 0cf034fca..99aec38f0 100644
--- a/numpy/lib/src/_compiled_base.c
+++ b/numpy/lib/src/_compiled_base.c
@@ -1376,9 +1376,9 @@ pack_bits(PyObject *input, int axis)
if (inp == NULL) {
return NULL;
}
- if (!PyArray_ISINTEGER(inp)) {
+ if (!PyArray_ISBOOL(inp) && !PyArray_ISINTEGER(inp)) {
PyErr_SetString(PyExc_TypeError,
- "Expected an input array of integer data type");
+ "Expected an input array of integer or boolean data type");
goto fail;
}
diff --git a/numpy/lib/tests/test_packbits.py b/numpy/lib/tests/test_packbits.py
index 5551de794..186e8960d 100644
--- a/numpy/lib/tests/test_packbits.py
+++ b/numpy/lib/tests/test_packbits.py
@@ -1,15 +1,19 @@
import numpy as np
-from numpy.testing import assert_array_equal, assert_equal
+from numpy.testing import assert_array_equal, assert_equal, assert_raises
def test_packbits():
# Copied from the docstring.
- a = np.array([[[1, 0, 1], [0, 1, 0]],
- [[1, 1, 0], [0, 0, 1]]])
- b = np.packbits(a, axis=-1)
- assert_equal(b.dtype, np.uint8)
- assert_array_equal(b, np.array([[[160], [64]], [[192], [32]]]))
+ a = [[[1, 0, 1], [0, 1, 0]],
+ [[1, 1, 0], [0, 0, 1]]]
+ for dtype in [np.bool, np.uint8, np.int]:
+ arr = np.array(a, dtype=dtype)
+ b = np.packbits(arr, axis=-1)
+ assert_equal(b.dtype, np.uint8)
+ assert_array_equal(b, np.array([[[160], [64]], [[192], [32]]]))
+
+ assert_raises(TypeError, np.packbits, np.array(a, dtype=float))
def test_unpackbits():