diff options
author | Lars Buitinck <larsmans@gmail.com> | 2014-11-30 22:20:24 +0100 |
---|---|---|
committer | Lars Buitinck <larsmans@gmail.com> | 2014-11-30 22:20:24 +0100 |
commit | 24effb6b7a075e23d85ea0b60ed8a607fe218c14 (patch) | |
tree | 975ff344e21ca66b1834adcdf096cc56dfab7536 /numpy/lib | |
parent | 6ce98831797729d7fb8aa525ddda017aceffa5e3 (diff) | |
download | numpy-24effb6b7a075e23d85ea0b60ed8a607fe218c14.tar.gz |
ENH ensure np.packbits works on np.bool dtype
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/src/_compiled_base.c | 4 | ||||
-rw-r--r-- | numpy/lib/tests/test_packbits.py | 16 |
2 files changed, 12 insertions, 8 deletions
diff --git a/numpy/lib/src/_compiled_base.c b/numpy/lib/src/_compiled_base.c index 0cf034fca..99aec38f0 100644 --- a/numpy/lib/src/_compiled_base.c +++ b/numpy/lib/src/_compiled_base.c @@ -1376,9 +1376,9 @@ pack_bits(PyObject *input, int axis) if (inp == NULL) { return NULL; } - if (!PyArray_ISINTEGER(inp)) { + if (!PyArray_ISBOOL(inp) && !PyArray_ISINTEGER(inp)) { PyErr_SetString(PyExc_TypeError, - "Expected an input array of integer data type"); + "Expected an input array of integer or boolean data type"); goto fail; } diff --git a/numpy/lib/tests/test_packbits.py b/numpy/lib/tests/test_packbits.py index 5551de794..186e8960d 100644 --- a/numpy/lib/tests/test_packbits.py +++ b/numpy/lib/tests/test_packbits.py @@ -1,15 +1,19 @@ import numpy as np -from numpy.testing import assert_array_equal, assert_equal +from numpy.testing import assert_array_equal, assert_equal, assert_raises def test_packbits(): # Copied from the docstring. - a = np.array([[[1, 0, 1], [0, 1, 0]], - [[1, 1, 0], [0, 0, 1]]]) - b = np.packbits(a, axis=-1) - assert_equal(b.dtype, np.uint8) - assert_array_equal(b, np.array([[[160], [64]], [[192], [32]]])) + a = [[[1, 0, 1], [0, 1, 0]], + [[1, 1, 0], [0, 0, 1]]] + for dtype in [np.bool, np.uint8, np.int]: + arr = np.array(a, dtype=dtype) + b = np.packbits(arr, axis=-1) + assert_equal(b.dtype, np.uint8) + assert_array_equal(b, np.array([[[160], [64]], [[192], [32]]])) + + assert_raises(TypeError, np.packbits, np.array(a, dtype=float)) def test_unpackbits(): |