diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2021-01-03 11:11:15 -0500 |
---|---|---|
committer | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2021-02-22 10:37:32 -0500 |
commit | aeae93b6c0042f6ed8f45205545985cc194f84f3 (patch) | |
tree | 404f3e80245dabae30b884cbf64b4821b8bc0451 | |
parent | 68e4d56eb9b539cccd582de7e7fb09c373d37609 (diff) | |
download | numpy-aeae93b6c0042f6ed8f45205545985cc194f84f3.tar.gz |
API: make piecewise subclass safe using use zeros_like.
Subclass input of piecewise was already respected, so it seems more
logical to ensure the output is the same subclass (possibly just an
oversight that it was not done before).
-rw-r--r-- | doc/release/upcoming_changes/18110.change.rst | 5 | ||||
-rw-r--r-- | numpy/lib/function_base.py | 4 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 8 |
3 files changed, 15 insertions, 2 deletions
diff --git a/doc/release/upcoming_changes/18110.change.rst b/doc/release/upcoming_changes/18110.change.rst new file mode 100644 index 000000000..7dbf8e5b7 --- /dev/null +++ b/doc/release/upcoming_changes/18110.change.rst @@ -0,0 +1,5 @@ +`numpy.piecewise` output class now matches the input class +---------------------------------------------------------- +When `numpy.ndarray` subclasses are used on input to `numpy.piecewise`, +they are passed on to the functions. The output will now be of the +same subclass as well. diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index d33a0fa7d..c6db42ce4 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -8,7 +8,7 @@ import numpy as np import numpy.core.numeric as _nx from numpy.core import transpose from numpy.core.numeric import ( - ones, zeros, arange, concatenate, array, asarray, asanyarray, empty, + ones, zeros_like, arange, concatenate, array, asarray, asanyarray, empty, ndarray, around, floor, ceil, take, dot, where, intp, integer, isscalar, absolute ) @@ -606,7 +606,7 @@ def piecewise(x, condlist, funclist, *args, **kw): .format(n, n, n+1) ) - y = zeros(x.shape, x.dtype) + y = zeros_like(x) for cond, func in zip(condlist, funclist): if not isinstance(func, collections.abc.Callable): y[cond] = func diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 4c7c0480c..afcb81eff 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -2399,6 +2399,14 @@ class TestPiecewise: assert_array_equal(y, np.array([[-1., -1., -1.], [3., 3., 1.]])) + def test_subclasses(self): + class subclass(np.ndarray): + pass + x = np.arange(5.).view(subclass) + r = piecewise(x, [x<2., x>=4], [-1., 1., 0.]) + assert_equal(type(r), subclass) + assert_equal(r, [-1., -1., 0., 0., 1.]) + class TestBincount: |