diff options
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/function_base.py | 23 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 2 |
2 files changed, 9 insertions, 16 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 3a73409fc..573516f3e 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -1323,21 +1323,15 @@ def piecewise(x, condlist, funclist, *args, **kw): """ x = asanyarray(x) n2 = len(funclist) - if (isscalar(condlist) or not (isinstance(condlist[0], list) or - isinstance(condlist[0], ndarray))): - if not isscalar(condlist) and x.size == 1 and x.ndim == 0: - condlist = [[c] for c in condlist] - else: - condlist = [condlist] + + # undocumented: single condition is promoted to a list of one condition + if isscalar(condlist) or ( + not isinstance(condlist[0], (list, ndarray)) and x.ndim != 0): + condlist = [condlist] + condlist = array(condlist, dtype=bool) n = len(condlist) - # This is a hack to work around problems with NumPy's - # handling of 0-d arrays and boolean indexing with - # numpy.bool_ scalars - zerod = False - if x.ndim == 0: - x = x[None] - zerod = True + if n == n2 - 1: # compute the "otherwise" condition. condelse = ~np.any(condlist, axis=0, keepdims=True) condlist = np.concatenate([condlist, condelse], axis=0) @@ -1352,8 +1346,7 @@ def piecewise(x, condlist, funclist, *args, **kw): vals = x[condlist[k]] if vals.size > 0: y[condlist[k]] = item(vals, *args, **kw) - if zerod: - y = y.squeeze() + return y diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 39edc18b4..2a42c44e6 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -2538,7 +2538,7 @@ class TestPiecewise(object): assert_(y == 0) x = 5 - y = piecewise(x, [[True], [False]], [1, 0]) + y = piecewise(x, [True, False], [1, 0]) assert_(y.ndim == 0) assert_(y == 1) |