diff options
author | Jonathan Helmus <jjhelmus@gmail.com> | 2013-06-28 12:19:25 -0500 |
---|---|---|
committer | Jonathan Helmus <jjhelmus@gmail.com> | 2013-09-13 16:25:12 -0500 |
commit | f8cdbbaee1967011d98aa454b14232488df451d3 (patch) | |
tree | 5cabdc04b3df90d020a80cbcd546546e008f7f89 /numpy | |
parent | ad3f60156d244471ad86d57bb92b2b4e32c22324 (diff) | |
download | numpy-f8cdbbaee1967011d98aa454b14232488df451d3.tar.gz |
ENH: Add dtype parameter to linspace and logspace functions.
Many NumPy functions such as arange allow users to define the dtype of the
returned type with a dtype parameter. This adds this same functionality
to the logspace and linspace functions.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/function_base.py | 36 | ||||
-rw-r--r-- | numpy/core/tests/test_function_base.py | 30 |
2 files changed, 50 insertions, 16 deletions
diff --git a/numpy/core/function_base.py b/numpy/core/function_base.py index f2c895608..3396adf3d 100644 --- a/numpy/core/function_base.py +++ b/numpy/core/function_base.py @@ -3,9 +3,10 @@ from __future__ import division, absolute_import, print_function __all__ = ['logspace', 'linspace'] from . import numeric as _nx -from .numeric import array +from .numeric import array, result_type -def linspace(start, stop, num=50, endpoint=True, retstep=False): + +def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None): """ Return evenly spaced numbers over a specified interval. @@ -31,6 +32,9 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False): retstep : bool, optional If True, return (`samples`, `step`), where `step` is the spacing between samples. + dtype : dtype + The type of the output array. If `dtype` is not given, infer the data + type from the other input arguments. Returns ------- @@ -74,23 +78,28 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False): """ num = int(num) + + if dtype is None: + dtype = result_type(start, stop, float(num)) + if num <= 0: - return array([], float) + return array([], dtype) if endpoint: if num == 1: - return array([float(start)]) + return array([start], dtype=dtype) step = (stop-start)/float((num-1)) - y = _nx.arange(0, num) * step + start + y = _nx.arange(0, num, dtype=dtype) * step + start y[-1] = stop else: step = (stop-start)/float(num) - y = _nx.arange(0, num) * step + start + y = _nx.arange(0, num, dtype=dtype) * step + start if retstep: - return y, step + return y.astype(dtype), step else: - return y + return y.astype(dtype) + -def logspace(start,stop,num=50,endpoint=True,base=10.0): +def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None): """ Return numbers spaced evenly on a log scale. @@ -116,6 +125,9 @@ def logspace(start,stop,num=50,endpoint=True,base=10.0): The base of the log space. The step size between the elements in ``ln(samples) / ln(base)`` (or ``log_base(samples)``) is uniform. Default is 10.0. + dtype : dtype + The type of the output array. If `dtype` is not given, infer the data + type from the other input arguments. Returns ------- @@ -136,7 +148,7 @@ def logspace(start,stop,num=50,endpoint=True,base=10.0): >>> y = np.linspace(start, stop, num=num, endpoint=endpoint) ... # doctest: +SKIP - >>> power(base, y) + >>> power(base, y).astype(dtype) ... # doctest: +SKIP Examples @@ -165,4 +177,6 @@ def logspace(start,stop,num=50,endpoint=True,base=10.0): """ y = linspace(start, stop, num=num, endpoint=endpoint) - return _nx.power(base, y) + if dtype is None: + return _nx.power(base, y) + return _nx.power(base, y).astype(dtype) diff --git a/numpy/core/tests/test_function_base.py b/numpy/core/tests/test_function_base.py index efca9ef8a..b43452e1c 100644 --- a/numpy/core/tests/test_function_base.py +++ b/numpy/core/tests/test_function_base.py @@ -1,23 +1,35 @@ from __future__ import division, absolute_import, print_function from numpy.testing import * -from numpy import logspace, linspace +from numpy import logspace, linspace, dtype + class TestLogspace(TestCase): + def test_basic(self): y = logspace(0, 6) - assert_(len(y)==50) + assert_(len(y) == 50) y = logspace(0, 6, num=100) - assert_(y[-1] == 10**6) + assert_(y[-1] == 10 ** 6) y = logspace(0, 6, endpoint=0) - assert_(y[-1] < 10**6) + assert_(y[-1] < 10 ** 6) y = logspace(0, 6, num=7) assert_array_equal(y, [1, 10, 100, 1e3, 1e4, 1e5, 1e6]) + def test_dtype(self): + y = logspace(0, 6, dtype='float32') + assert_equal(y.dtype, dtype('float32')) + y = logspace(0, 6, dtype='float64') + assert_equal(y.dtype, dtype('float64')) + y = logspace(0, 6, dtype='int32') + assert_equal(y.dtype, dtype('int32')) + + class TestLinspace(TestCase): + def test_basic(self): y = linspace(0, 10) - assert_(len(y)==50) + assert_(len(y) == 50) y = linspace(2, 10, num=100) assert_(y[-1] == 10) y = linspace(2, 10, endpoint=0) @@ -35,3 +47,11 @@ class TestLinspace(TestCase): t3 = linspace(0, 1, 2).dtype assert_equal(t1, t2) assert_equal(t2, t3) + + def test_dtype(self): + y = linspace(0, 6, dtype='float32') + assert_equal(y.dtype, dtype('float32')) + y = linspace(0, 6, dtype='float64') + assert_equal(y.dtype, dtype('float64')) + y = linspace(0, 6, dtype='int32') + assert_equal(y.dtype, dtype('int32')) |