diff options
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/lib/function_base.py | 58 | ||||
| -rw-r--r-- | numpy/lib/tests/test_function_base.py | 30 |
2 files changed, 76 insertions, 12 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 27ed79044..cba04765a 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -2119,8 +2119,8 @@ def _create_arrays(broadcast_shape, dim_sizes, list_of_core_dims, dtypes, @set_module('numpy') class vectorize: """ - vectorize(pyfunc, otypes=None, doc=None, excluded=None, cache=False, - signature=None) + vectorize(pyfunc=np._NoValue, otypes=None, doc=None, excluded=None, + cache=False, signature=None) Generalized function class. @@ -2136,8 +2136,9 @@ class vectorize: Parameters ---------- - pyfunc : callable + pyfunc : callable, optional A python function or method. + Can be omitted to produce a decorator with keyword arguments. otypes : str or list of dtypes, optional The output data type. It must be specified as either a string of typecode characters or a list of data type specifiers. There should @@ -2169,8 +2170,9 @@ class vectorize: Returns ------- - vectorized : callable - Vectorized function. + out : callable + A vectorized function if ``pyfunc`` was provided, + a decorator otherwise. See Also -------- @@ -2267,19 +2269,36 @@ class vectorize: [0., 0., 1., 2., 1., 0.], [0., 0., 0., 1., 2., 1.]]) + Decorator syntax is supported. The decorator can be called as + a function to provide keyword arguments. + >>>@np.vectorize + ...def identity(x): + ... return x + ... + >>>identity([0, 1, 2]) + array([0, 1, 2]) + >>>@np.vectorize(otypes=[float]) + ...def as_float(x): + ... return x + ... + >>>as_float([0, 1, 2]) + array([0., 1., 2.]) """ - def __init__(self, pyfunc, otypes=None, doc=None, excluded=None, - cache=False, signature=None): + def __init__(self, pyfunc=np._NoValue, otypes=None, doc=None, + excluded=None, cache=False, signature=None): self.pyfunc = pyfunc self.cache = cache self.signature = signature - self.__name__ = pyfunc.__name__ - self._ufunc = {} # Caching to improve default performance + if pyfunc != np._NoValue: + self.__name__ = pyfunc.__name__ + self._ufunc = {} # Caching to improve default performance + self._doc = None + self.__doc__ = doc if doc is None: self.__doc__ = pyfunc.__doc__ else: - self.__doc__ = doc + self._doc = doc if isinstance(otypes, str): for char in otypes: @@ -2301,7 +2320,15 @@ class vectorize: else: self._in_and_out_core_dims = None - def __call__(self, *args, **kwargs): + def _init_stage_2(self, pyfunc, *args, **kwargs): + self.__name__ = pyfunc.__name__ + self.pyfunc = pyfunc + if self._doc is None: + self.__doc__ = pyfunc.__doc__ + else: + self.__doc__ = self._doc + + def _call_as_normal(self, *args, **kwargs): """ Return arrays with the results of `pyfunc` broadcast (vectorized) over `args` and `kwargs` not in `excluded`. @@ -2331,6 +2358,13 @@ class vectorize: return self._vectorize_call(func=func, args=vargs) + def __call__(self, *args, **kwargs): + if self.pyfunc is np._NoValue: + self._init_stage_2(*args, **kwargs) + return self + + return self._call_as_normal(*args, **kwargs) + def _get_ufunc_and_otypes(self, func, args): """Return (ufunc, otypes).""" # frompyfunc will fail if args is empty @@ -2457,7 +2491,7 @@ class vectorize: if outputs is None: for result, core_dims in zip(results, output_core_dims): - _update_dim_sizes(dim_sizes, result, core_dims) + _update_dim_sizes(dim_sizes, result, core_dims) outputs = _create_arrays(broadcast_shape, dim_sizes, output_core_dims, otypes, results) diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 69e4c4848..169484a09 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -1788,6 +1788,36 @@ class TestVectorize: assert f2.__name__ == 'f2' + def test_decorator(self): + @vectorize + def addsubtract(a, b): + if a > b: + return a - b + else: + return a + b + + r = addsubtract([0, 3, 6, 9], [1, 3, 5, 7]) + assert_array_equal(r, [1, 6, 1, 2]) + + def test_signature_otypes_decorator(self): + @vectorize(signature='(n)->(n)', otypes=['float64']) + def f(x): + return x + + r = f([1, 2, 3]) + assert_equal(r.dtype, np.dtype('float64')) + assert_array_equal(r, [1, 2, 3]) + assert f.__name__ == 'f' + + def test_positional_regression_9477(self): + # This supplies the first keyword argument as a positional, + # to ensure that they are still properly forwarded after the + # enhancement for #9477 + f = vectorize((lambda x: x), ['float64']) + r = f([2]) + assert_equal(r.dtype, np.dtype('float64')) + + class TestLeaks: class A: iters = 20 |
