summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/release/upcoming_changes/19356.change.rst7
-rw-r--r--numpy/lib/function_base.py45
-rw-r--r--numpy/lib/tests/test_function_base.py22
3 files changed, 53 insertions, 21 deletions
diff --git a/doc/release/upcoming_changes/19356.change.rst b/doc/release/upcoming_changes/19356.change.rst
new file mode 100644
index 000000000..3c5ef4a91
--- /dev/null
+++ b/doc/release/upcoming_changes/19356.change.rst
@@ -0,0 +1,7 @@
+`numpy.vectorize` functions now produce the same output class as the base function
+----------------------------------------------------------------------------------
+When a function that respects `numpy.ndarray` subclasses is vectorized using
+`numpy.vectorize`, the vectorized function will now be subclass-safe
+also for cases that a signature is given (i.e., when creating a ``gufunc``):
+the output class will be the same as that returned by the first call to
+the underlying function.
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 2e9ae6644..fdbe698d6 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -1522,7 +1522,7 @@ def unwrap(p, discont=None, axis=-1, *, period=2*pi):
p : array_like
Input array.
discont : float, optional
- Maximum discontinuity between values, default is ``period/2``.
+ Maximum discontinuity between values, default is ``period/2``.
Values below ``period/2`` are treated as if they were ``period/2``.
To have an effect different from the default, `discont` should be
larger than ``period/2``.
@@ -1531,7 +1531,7 @@ def unwrap(p, discont=None, axis=-1, *, period=2*pi):
period: float, optional
Size of the range over which the input wraps. By default, it is
``2 pi``.
-
+
.. versionadded:: 1.21.0
Returns
@@ -1545,8 +1545,8 @@ def unwrap(p, discont=None, axis=-1, *, period=2*pi):
Notes
-----
- If the discontinuity in `p` is smaller than ``period/2``,
- but larger than `discont`, no unwrapping is done because taking
+ If the discontinuity in `p` is smaller than ``period/2``,
+ but larger than `discont`, no unwrapping is done because taking
the complement would only make the discontinuity larger.
Examples
@@ -1579,7 +1579,7 @@ def unwrap(p, discont=None, axis=-1, *, period=2*pi):
slice1 = tuple(slice1)
dtype = np.result_type(dd, period)
if _nx.issubdtype(dtype, _nx.integer):
- interval_high, rem = divmod(period, 2)
+ interval_high, rem = divmod(period, 2)
boundary_ambiguous = rem == 0
else:
interval_high = period / 2
@@ -1943,11 +1943,19 @@ def _calculate_shapes(broadcast_shape, dim_sizes, list_of_core_dims):
for core_dims in list_of_core_dims]
-def _create_arrays(broadcast_shape, dim_sizes, list_of_core_dims, dtypes):
+def _create_arrays(broadcast_shape, dim_sizes, list_of_core_dims, dtypes,
+ results=None):
"""Helper for creating output arrays in vectorize."""
shapes = _calculate_shapes(broadcast_shape, dim_sizes, list_of_core_dims)
- arrays = tuple(np.empty(shape, dtype=dtype)
- for shape, dtype in zip(shapes, dtypes))
+ if dtypes is None:
+ dtypes = [None] * len(shapes)
+ if results is None:
+ arrays = tuple(np.empty(shape=shape, dtype=dtype)
+ for shape, dtype in zip(shapes, dtypes))
+ else:
+ arrays = tuple(np.empty_like(result, shape=shape, dtype=dtype)
+ for result, shape, dtype
+ in zip(results, shapes, dtypes))
return arrays
@@ -2293,11 +2301,8 @@ class vectorize:
for result, core_dims in zip(results, output_core_dims):
_update_dim_sizes(dim_sizes, result, core_dims)
- if otypes is None:
- otypes = [asarray(result).dtype for result in results]
-
outputs = _create_arrays(broadcast_shape, dim_sizes,
- output_core_dims, otypes)
+ output_core_dims, otypes, results)
for output, result in zip(outputs, results):
output[index] = result
@@ -4136,13 +4141,13 @@ def trapz(y, x=None, dx=1.0, axis=-1):
If `x` is provided, the integration happens in sequence along its
elements - they are not sorted.
-
+
Integrate `y` (`x`) along each 1d slice on the given axis, compute
:math:`\int y(x) dx`.
When `x` is specified, this integrates along the parametric curve,
computing :math:`\int_t y(t) dt =
\int_t y(t) \left.\frac{dx}{dt}\right|_{x=x(t)} dt`.
-
+
Parameters
----------
y : array_like
@@ -4163,7 +4168,7 @@ def trapz(y, x=None, dx=1.0, axis=-1):
a single axis by the trapezoidal rule. If 'y' is a 1-dimensional array,
then the result is a float. If 'n' is greater than 1, then the result
is an 'n-1' dimensional array.
-
+
See Also
--------
sum, cumsum
@@ -4192,16 +4197,16 @@ def trapz(y, x=None, dx=1.0, axis=-1):
8.0
>>> np.trapz([1,2,3], dx=2)
8.0
-
+
Using a decreasing `x` corresponds to integrating in reverse:
-
- >>> np.trapz([1,2,3], x=[8,6,4])
+
+ >>> np.trapz([1,2,3], x=[8,6,4])
-8.0
-
+
More generally `x` is used to integrate along a parametric curve.
This finds the area of a circle, noting we repeat the sample which closes
the curve:
-
+
>>> theta = np.linspace(0, 2 * np.pi, num=1000, endpoint=True)
>>> np.trapz(np.cos(theta), x=np.sin(theta))
3.141571941375841
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index a4f49a78b..e1b615223 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -1665,6 +1665,26 @@ class TestVectorize:
with assert_raises_regex(ValueError, 'new output dimensions'):
f(x)
+ def test_subclasses(self):
+ class subclass(np.ndarray):
+ pass
+
+ m = np.array([[1., 0., 0.],
+ [0., 0., 1.],
+ [0., 1., 0.]]).view(subclass)
+ v = np.array([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]).view(subclass)
+ # generalized (gufunc)
+ matvec = np.vectorize(np.matmul, signature='(m,m),(m)->(m)')
+ r = matvec(m, v)
+ assert_equal(type(r), subclass)
+ assert_equal(r, [[1., 3., 2.], [4., 6., 5.], [7., 9., 8.]])
+
+ # element-wise (ufunc)
+ mult = np.vectorize(lambda x, y: x*y)
+ r = mult(m, v)
+ assert_equal(type(r), subclass)
+ assert_equal(r, m * v)
+
class TestLeaks:
class A:
@@ -1798,7 +1818,7 @@ class TestUnwrap:
assert_array_equal(unwrap([1, 1 + 2 * np.pi]), [1, 1])
# check that unwrap maintains continuity
assert_(np.all(diff(unwrap(rand(10) * 100)) < np.pi))
-
+
def test_period(self):
# check that unwrap removes jumps greater that 255
assert_array_equal(unwrap([1, 1 + 256], period=255), [1, 2])