summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRoy Smart <roytsmart@gmail.com>2023-02-24 17:16:32 -0700
committerRoy Smart <roytsmart@gmail.com>2023-03-06 02:43:57 -0700
commit3b89f5227e7d8f8eddd2959982efb6739efdb729 (patch)
treeed8fcef1a659acb52d791fcc82dc12b737c72def
parentb4f6b0b9d2af3e5a7ed0bfb51c8d692298bc2dde (diff)
downloadnumpy-3b89f5227e7d8f8eddd2959982efb6739efdb729.tar.gz
ENH: Modify `numpy.logspace` so that the `base` argument broadcasts correctly against `start` and `stop`.
-rw-r--r--doc/release/upcoming_changes/23275.improvement.rst4
-rw-r--r--numpy/core/function_base.py24
-rw-r--r--numpy/core/tests/test_function_base.py28
3 files changed, 51 insertions, 5 deletions
diff --git a/doc/release/upcoming_changes/23275.improvement.rst b/doc/release/upcoming_changes/23275.improvement.rst
new file mode 100644
index 000000000..14ed5d9ad
--- /dev/null
+++ b/doc/release/upcoming_changes/23275.improvement.rst
@@ -0,0 +1,4 @@
+``numpy.logspace`` now supports a non-scalar ``base`` argument
+--------------------------------------------------------------
+The ``base`` argument of ``numpy.logspace`` can now be array-like if it's
+broadcastable against the ``start`` and ``stop`` arguments. \ No newline at end of file
diff --git a/numpy/core/function_base.py b/numpy/core/function_base.py
index 3dc51a81b..c82b495b1 100644
--- a/numpy/core/function_base.py
+++ b/numpy/core/function_base.py
@@ -3,6 +3,7 @@ import warnings
import operator
import types
+import numpy as np
from . import numeric as _nx
from .numeric import result_type, NaN, asanyarray, ndim
from numpy.core.multiarray import add_docstring
@@ -183,7 +184,7 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
def _logspace_dispatcher(start, stop, num=None, endpoint=None, base=None,
dtype=None, axis=None):
- return (start, stop)
+ return (start, stop, base)
@array_function_dispatch(_logspace_dispatcher)
@@ -199,6 +200,9 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
.. versionchanged:: 1.16.0
Non-scalar `start` and `stop` are now supported.
+ .. versionchanged:: 1.25.0
+ Non-scalar 'base` is now supported
+
Parameters
----------
start : array_like
@@ -223,9 +227,10 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
an integer; `float` is chosen even if the arguments would produce an
array of integers.
axis : int, optional
- The axis in the result to store the samples. Relevant only if start
- or stop are array-like. By default (0), the samples will be along a
- new axis inserted at the beginning. Use -1 to get an axis at the end.
+ The axis in the result to store the samples. Relevant only if start,
+ stop, or base are array-like. By default (0), the samples will be
+ along a new axis inserted at the beginning. Use -1 to get an axis at
+ the end.
.. versionadded:: 1.16.0
@@ -247,7 +252,7 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
Notes
-----
- Logspace is equivalent to the code
+ If base is a scalar, logspace is equivalent to the code
>>> y = np.linspace(start, stop, num=num, endpoint=endpoint)
... # doctest: +SKIP
@@ -262,6 +267,9 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
array([100. , 177.827941 , 316.22776602, 562.34132519])
>>> np.logspace(2.0, 3.0, num=4, base=2.0)
array([4. , 5.0396842 , 6.34960421, 8. ])
+ >>> np.logspace(2.0, 3.0, num=4, base=[2.0, 3.0], axis=-1)
+ array([[ 4. , 5.0396842 , 6.34960421, 8. ],
+ [ 9. , 12.98024613, 18.72075441, 27. ]])
Graphical illustration:
@@ -279,7 +287,13 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
>>> plt.show()
"""
+ ndmax = np.broadcast(start, stop, base).ndim
+ start, stop, base = (
+ np.array(a, copy=False, subok=True, ndmin=ndmax)
+ for a in (start, stop, base)
+ )
y = linspace(start, stop, num=num, endpoint=endpoint, axis=axis)
+ base = np.expand_dims(base, axis=axis)
if dtype is None:
return _nx.power(base, y)
return _nx.power(base, y).astype(dtype, copy=False)
diff --git a/numpy/core/tests/test_function_base.py b/numpy/core/tests/test_function_base.py
index 21583dd44..79f1ecfc9 100644
--- a/numpy/core/tests/test_function_base.py
+++ b/numpy/core/tests/test_function_base.py
@@ -1,3 +1,4 @@
+import pytest
from numpy import (
logspace, linspace, geomspace, dtype, array, sctypes, arange, isnan,
ndarray, sqrt, nextafter, stack, errstate
@@ -65,6 +66,33 @@ class TestLogspace:
t5 = logspace(start, stop, 6, axis=-1)
assert_equal(t5, t2.T)
+ @pytest.mark.parametrize("axis", [0, 1, -1])
+ def test_base_array(self, axis: int):
+ start = 1
+ stop = 2
+ num = 6
+ base = array([1, 2])
+ t1 = logspace(start, stop, num=num, base=base, axis=axis)
+ t2 = stack(
+ [logspace(start, stop, num=num, base=_base) for _base in base],
+ axis=(axis + 1) % t1.ndim,
+ )
+ assert_equal(t1, t2)
+
+ @pytest.mark.parametrize("axis", [0, 1, -1])
+ def test_stop_base_array(self, axis: int):
+ start = 1
+ stop = array([2, 3])
+ num = 6
+ base = array([1, 2])
+ t1 = logspace(start, stop, num=num, base=base, axis=axis)
+ t2 = stack(
+ [logspace(start, _stop, num=num, base=_base)
+ for _stop, _base in zip(stop, base)],
+ axis=(axis + 1) % t1.ndim,
+ )
+ assert_equal(t1, t2)
+
def test_dtype(self):
y = logspace(0, 6, dtype='float32')
assert_equal(y.dtype, dtype('float32'))