summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/function_base.py22
-rw-r--r--numpy/lib/tests/test_function_base.py77
2 files changed, 92 insertions, 7 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 7eeed7825..d2859d94d 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -2026,7 +2026,7 @@ class vectorize:
self.pyfunc = pyfunc
self.cache = cache
self.signature = signature
- self._ufunc = None # Caching to improve default performance
+ self._ufunc = {} # Caching to improve default performance
if doc is None:
self.__doc__ = pyfunc.__doc__
@@ -2091,14 +2091,22 @@ class vectorize:
if self.otypes is not None:
otypes = self.otypes
- nout = len(otypes)
- # Note logic here: We only *use* self._ufunc if func is self.pyfunc
- # even though we set self._ufunc regardless.
- if func is self.pyfunc and self._ufunc is not None:
- ufunc = self._ufunc
+ # self._ufunc is a dictionary whose keys are the number of
+ # arguments (i.e. len(args)) and whose values are ufuncs created
+ # by frompyfunc. len(args) can be different for different calls if
+ # self.pyfunc has parameters with default values. We only use the
+ # cache when func is self.pyfunc, which occurs when the call uses
+ # only positional arguments and no arguments are excluded.
+
+ nin = len(args)
+ nout = len(self.otypes)
+ if func is not self.pyfunc or nin not in self._ufunc:
+ ufunc = frompyfunc(func, nin, nout)
else:
- ufunc = self._ufunc = frompyfunc(func, len(args), nout)
+ ufunc = None # We'll get it from self._ufunc
+ if func is self.pyfunc:
+ ufunc = self._ufunc.setdefault(nin, ufunc)
else:
# Get number of outputs and output types by calling the function on
# the first entries of args. We also cache the result to prevent
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 23bf3296d..9470a355d 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -1273,6 +1273,83 @@ class TestVectorize:
r2 = np.array([3, 4, 5])
assert_array_equal(r1, r2)
+ def test_keywords_with_otypes_order1(self):
+ # gh-1620: The second call of f would crash with
+ # `ValueError: invalid number of arguments`.
+
+ import math
+
+ def foo(x, y=1.0):
+ return y*math.floor(x)
+
+ f = vectorize(foo, otypes=[float])
+ r1 = f(np.arange(3.0), 1.0)
+ r2 = f(np.arange(3.0))
+ assert_array_equal(r1, r2)
+
+ def test_keywords_with_otypes_order2(self):
+ # gh-1620: The second call of f would crash with
+ # `ValueError: non-broadcastable output operand with shape ()
+ # doesn't match the broadcast shape (3,)`.
+
+ import math
+
+ def foo(x, y=1.0):
+ return y*math.floor(x)
+
+ f = vectorize(foo, otypes=[float])
+ r1 = f(np.arange(3.0))
+ r2 = f(np.arange(3.0), 1.0)
+ assert_array_equal(r1, r2)
+
+ def test_keywords_with_otypes_order3(self):
+ # gh-1620: The third call of f would crash with
+ # `ValueError: invalid number of arguments`.
+
+ import math
+
+ def foo(x, y=1.0):
+ return y*math.floor(x)
+
+ f = vectorize(foo, otypes=[float])
+ r1 = f(np.arange(3.0))
+ r2 = f(np.arange(3.0), y=1.0)
+ r3 = f(np.arange(3.0))
+ assert_array_equal(r1, r2)
+ assert_array_equal(r1, r3)
+
+ def test_keywords_with_otypes_several_kwd_args1(self):
+ # gh-1620 Make sure different uses of keyword arguments
+ # don't break the vectorized function.
+
+ import math
+
+ def foo(x, y=1.0, z=0.0):
+ return y*math.floor(x) + z
+
+ f = vectorize(foo, otypes=[float])
+ r1 = f(10.4, z=100)
+ r2 = f(10.4, y=-1)
+ r3 = f(10.4)
+ assert_equal(r1, foo(10.4, z=100))
+ assert_equal(r2, foo(10.4, y=-1))
+ assert_equal(r3, foo(10.4))
+
+ def test_keywords_with_otypes_several_kwd_args2(self):
+ # gh-1620 Make sure different uses of keyword arguments
+ # don't break the vectorized function.
+
+ import math
+
+ def foo(x, y=1.0, z=0.0):
+ return y*math.floor(x) + z
+
+ f = vectorize(foo, otypes=[float])
+ r1 = f(z=100, x=10.4, y=-1)
+ r2 = f(1, 2, 3)
+ assert_equal(r1, foo(z=100, x=10.4, y=-1))
+ assert_equal(r2, foo(1, 2, 3))
+
def test_keywords_no_func_code(self):
# This needs to test a function that has keywords but
# no func_code attribute, since otherwise vectorize will