diff options
author | Stefan van der Walt <stefan@sun.ac.za> | 2008-07-08 08:24:37 +0000 |
---|---|---|
committer | Stefan van der Walt <stefan@sun.ac.za> | 2008-07-08 08:24:37 +0000 |
commit | 048bc867ad2ba31dbd784eb6432a492a65686510 (patch) | |
tree | b16176d8225de0f4c1321d36ab519e167d5e090b /numpy/lib | |
parent | 757b1fbfd996c969eb4e76d6949a6ae242ddb3ae (diff) | |
download | numpy-048bc867ad2ba31dbd784eb6432a492a65686510.tar.gz |
Piecewise should not expose raw memory. Closes #798.
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/function_base.py | 19 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 46 |
2 files changed, 54 insertions, 11 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index e8df0b439..2204f1863 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -563,8 +563,11 @@ def piecewise(x, condlist, funclist, *args, **kw): """ x = asanyarray(x) n2 = len(funclist) - if not isinstance(condlist, type([])): + if isscalar(condlist) or \ + not (isinstance(condlist[0], list) or + isinstance(condlist[0], ndarray)): condlist = [condlist] + condlist = [asarray(c, dtype=bool) for c in condlist] n = len(condlist) if n == n2-1: # compute the "otherwise" condition. totlist = condlist[0] @@ -573,10 +576,11 @@ def piecewise(x, condlist, funclist, *args, **kw): condlist.append(~totlist) n += 1 if (n != n2): - raise ValueError, "function list and condition list must be the same" + raise ValueError, "function list and condition list " \ + "must be the same" zerod = False - # This is a hack to work around problems with NumPy's - # handling of 0-d arrays and boolean indexing with + # This is a hack to work around problems with NumPy's + # handling of 0-d arrays and boolean indexing with # numpy.bool_ scalars if x.ndim == 0: x = x[None] @@ -589,7 +593,8 @@ def piecewise(x, condlist, funclist, *args, **kw): condition = condlist[k] newcondlist.append(condition) condlist = newcondlist - y = empty(x.shape, x.dtype) + + y = zeros(x.shape, x.dtype) for k in range(n): item = funclist[k] if not callable(item): @@ -1090,7 +1095,7 @@ class vectorize(object): self.__doc__ = pyfunc.__doc__ else: self.__doc__ = doc - if isinstance(otypes, types.StringType): + if isinstance(otypes, str): self.otypes = otypes for char in self.otypes: if char not in typecodes['All']: @@ -1121,7 +1126,7 @@ class vectorize(object): for arg in args: newargs.append(asarray(arg).flat[0]) theout = self.thefunc(*newargs) - if isinstance(theout, types.TupleType): + if isinstance(theout, tuple): self.nout = len(theout) else: self.nout = 1 diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 0be1a22b7..f4d6c009d 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -615,16 +615,54 @@ class TestUnique(TestCase): x = array([5+6j, 1+1j, 1+10j, 10, 5+6j]) assert(all(unique(x) == [1+1j, 1+10j, 5+6j, 10])) -def compare_results(res,desired): - for i in range(len(desired)): - assert_array_equal(res[i],desired[i]) -class TestPiecewise(TestCase): +class TestPiecewise(NumpyTestCase): + def check_simple(self): + # Condition is single bool list + x = piecewise([0, 0], [True, False], [1]) + assert_array_equal(x, [1, 0]) + + # List of conditions: single bool list + x = piecewise([0, 0], [[True, False]], [1]) + assert_array_equal(x, [1, 0]) + + # Conditions is single bool array + x = piecewise([0, 0], array([True, False]), [1]) + assert_array_equal(x, [1, 0]) + + # Condition is single int array + x = piecewise([0, 0], array([1, 0]), [1]) + assert_array_equal(x, [1, 0]) + + # List of conditions: int array + x = piecewise([0, 0], [array([1, 0])], [1]) + assert_array_equal(x, [1, 0]) + + + x = piecewise([0, 0], [[False, True]], [lambda x: -1]) + assert_array_equal(x, [0, -1]) + + x = piecewise([1, 2], [[True, False], [False, True]], [3, 4]) + assert_array_equal(x, [3, 4]) + + def check_default(self): + # No value specified for x[1], should be 0 + x = piecewise([1, 2], [True, False], [2]) + assert_array_equal(x, [2, 0]) + + # Should set x[1] to 3 + x = piecewise([1, 2], [True, False], [2, 3]) + assert_array_equal(x, [2, 3]) + def test_0d(self): x = array(3) y = piecewise(x, x>3, [4, 0]) assert y.ndim == 0 assert y == 0 +def compare_results(res,desired): + for i in range(len(desired)): + assert_array_equal(res[i],desired[i]) + if __name__ == "__main__": run_module_suite() |