summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/release/upcoming_changes/18110.change.rst5
-rw-r--r--numpy/lib/function_base.py4
-rw-r--r--numpy/lib/tests/test_function_base.py8
3 files changed, 15 insertions, 2 deletions
diff --git a/doc/release/upcoming_changes/18110.change.rst b/doc/release/upcoming_changes/18110.change.rst
new file mode 100644
index 000000000..7dbf8e5b7
--- /dev/null
+++ b/doc/release/upcoming_changes/18110.change.rst
@@ -0,0 +1,5 @@
+`numpy.piecewise` output class now matches the input class
+----------------------------------------------------------
+When `numpy.ndarray` subclasses are used on input to `numpy.piecewise`,
+they are passed on to the functions. The output will now be of the
+same subclass as well.
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index d33a0fa7d..c6db42ce4 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -8,7 +8,7 @@ import numpy as np
import numpy.core.numeric as _nx
from numpy.core import transpose
from numpy.core.numeric import (
- ones, zeros, arange, concatenate, array, asarray, asanyarray, empty,
+ ones, zeros_like, arange, concatenate, array, asarray, asanyarray, empty,
ndarray, around, floor, ceil, take, dot, where, intp,
integer, isscalar, absolute
)
@@ -606,7 +606,7 @@ def piecewise(x, condlist, funclist, *args, **kw):
.format(n, n, n+1)
)
- y = zeros(x.shape, x.dtype)
+ y = zeros_like(x)
for cond, func in zip(condlist, funclist):
if not isinstance(func, collections.abc.Callable):
y[cond] = func
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 4c7c0480c..afcb81eff 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -2399,6 +2399,14 @@ class TestPiecewise:
assert_array_equal(y, np.array([[-1., -1., -1.],
[3., 3., 1.]]))
+ def test_subclasses(self):
+ class subclass(np.ndarray):
+ pass
+ x = np.arange(5.).view(subclass)
+ r = piecewise(x, [x<2., x>=4], [-1., 1., 0.])
+ assert_equal(type(r), subclass)
+ assert_equal(r, [-1., -1., 0., 0., 1.])
+
class TestBincount: