summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py23
1 files changed, 8 insertions, 15 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 3a73409fc..573516f3e 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -1323,21 +1323,15 @@ def piecewise(x, condlist, funclist, *args, **kw):
"""
x = asanyarray(x)
n2 = len(funclist)
- if (isscalar(condlist) or not (isinstance(condlist[0], list) or
- isinstance(condlist[0], ndarray))):
- if not isscalar(condlist) and x.size == 1 and x.ndim == 0:
- condlist = [[c] for c in condlist]
- else:
- condlist = [condlist]
+
+ # undocumented: single condition is promoted to a list of one condition
+ if isscalar(condlist) or (
+ not isinstance(condlist[0], (list, ndarray)) and x.ndim != 0):
+ condlist = [condlist]
+
condlist = array(condlist, dtype=bool)
n = len(condlist)
- # This is a hack to work around problems with NumPy's
- # handling of 0-d arrays and boolean indexing with
- # numpy.bool_ scalars
- zerod = False
- if x.ndim == 0:
- x = x[None]
- zerod = True
+
if n == n2 - 1: # compute the "otherwise" condition.
condelse = ~np.any(condlist, axis=0, keepdims=True)
condlist = np.concatenate([condlist, condelse], axis=0)
@@ -1352,8 +1346,7 @@ def piecewise(x, condlist, funclist, *args, **kw):
vals = x[condlist[k]]
if vals.size > 0:
y[condlist[k]] = item(vals, *args, **kw)
- if zerod:
- y = y.squeeze()
+
return y