summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorMatteo Raso <matteo_luigi_raso@protonmail.com>2023-02-08 20:39:55 -0500
committerMatteo Raso <matteo_luigi_raso@protonmail.com>2023-02-08 20:39:55 -0500
commit7a2ded1522305cfbab4e34a18198f3cbcae7755c (patch)
treea35eb8a086354d49e4249ad959622373bc940e87 /numpy/lib
parent84596aeec1682de93d82c60b53726838b29ad311 (diff)
downloadnumpy-7a2ded1522305cfbab4e34a18198f3cbcae7755c.tar.gz
Added a test for positional args (PR-23061)
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/function_base.py19
-rw-r--r--numpy/lib/tests/test_function_base.py11
2 files changed, 30 insertions, 0 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index f7bc09166..30349f9e5 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -2286,6 +2286,25 @@ class vectorize:
"""
def __init__(self, pyfunc=np._NoValue, otypes=None, doc=None,
excluded=None, cache=False, signature=None):
+
+ if not callable(pyfunc):
+ p_temp = pyfunc
+ pyfunc = np._NoValue
+ if p_temp is not None and p_temp is not np._NoValue:
+ o_temp = otypes
+ otypes = p_temp
+ if o_temp is not None:
+ d_temp = doc
+ doc = o_temp
+ if d_temp is not None:
+ e_temp = excluded
+ excluded = d_temp
+ if e_temp is True or e_temp is False:
+ c_temp = cache
+ cache = e_temp
+ if c_temp is not None:
+ signature = c_temp
+
self.pyfunc = pyfunc
self.cache = cache
self.signature = signature
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 6499b78b7..b56d776cb 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -1817,6 +1817,17 @@ class TestVectorize:
assert_array_equal(r, [1, 2, 3])
assert f.__name__ == 'f'
+ def test_decorator_positional_args(self):
+ A = np.vectorize(abs, ['float64'], "return float absolute value")
+ y1 = A([1, -1, 0, 0.3, 5])
+
+ @vectorize(['float64'], "return float absolute value")
+ def myabs(a):
+ return abs(a)
+
+ y2 = myabs([1, -1, 0, 0.3, 5])
+ assert_array_equal(y1, y2)
+
def test_positional_regression_9477(self):
# This supplies the first keyword argument as a positional,
# to ensure that they are still properly forwarded after the