summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/function_base.py23
-rw-r--r--numpy/lib/tests/test_function_base.py2
2 files changed, 9 insertions, 16 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 3a73409fc..573516f3e 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -1323,21 +1323,15 @@ def piecewise(x, condlist, funclist, *args, **kw):
"""
x = asanyarray(x)
n2 = len(funclist)
- if (isscalar(condlist) or not (isinstance(condlist[0], list) or
- isinstance(condlist[0], ndarray))):
- if not isscalar(condlist) and x.size == 1 and x.ndim == 0:
- condlist = [[c] for c in condlist]
- else:
- condlist = [condlist]
+
+ # undocumented: single condition is promoted to a list of one condition
+ if isscalar(condlist) or (
+ not isinstance(condlist[0], (list, ndarray)) and x.ndim != 0):
+ condlist = [condlist]
+
condlist = array(condlist, dtype=bool)
n = len(condlist)
- # This is a hack to work around problems with NumPy's
- # handling of 0-d arrays and boolean indexing with
- # numpy.bool_ scalars
- zerod = False
- if x.ndim == 0:
- x = x[None]
- zerod = True
+
if n == n2 - 1: # compute the "otherwise" condition.
condelse = ~np.any(condlist, axis=0, keepdims=True)
condlist = np.concatenate([condlist, condelse], axis=0)
@@ -1352,8 +1346,7 @@ def piecewise(x, condlist, funclist, *args, **kw):
vals = x[condlist[k]]
if vals.size > 0:
y[condlist[k]] = item(vals, *args, **kw)
- if zerod:
- y = y.squeeze()
+
return y
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 39edc18b4..2a42c44e6 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -2538,7 +2538,7 @@ class TestPiecewise(object):
assert_(y == 0)
x = 5
- y = piecewise(x, [[True], [False]], [1, 0])
+ y = piecewise(x, [True, False], [1, 0])
assert_(y.ndim == 0)
assert_(y == 1)