diff options
author | Matti Picus <matti.picus@gmail.com> | 2019-03-10 12:04:10 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-03-10 12:04:10 +0200 |
commit | a090e46b9e3b5046f08e46b08df103d00985e47c (patch) | |
tree | 94417188cc82d9ad83ebc6dd2f7c39cce346f27f | |
parent | cbf3a081271a43e980e3c2f76625deb43fd53922 (diff) | |
parent | b3fdef08b5c586fff50fdabb0967e53ccaaab3c4 (diff) | |
download | numpy-a090e46b9e3b5046f08e46b08df103d00985e47c.tar.gz |
Merge pull request #13092 from mhvk/linspace-object-dtype-bug
BUG: ensure linspace works on object input.
-rw-r--r-- | numpy/core/function_base.py | 4 | ||||
-rw-r--r-- | numpy/core/tests/test_function_base.py | 6 |
2 files changed, 8 insertions, 2 deletions
diff --git a/numpy/core/function_base.py b/numpy/core/function_base.py index f8800b83e..296213823 100644 --- a/numpy/core/function_base.py +++ b/numpy/core/function_base.py @@ -6,7 +6,7 @@ import operator from . import numeric as _nx from .numeric import (result_type, NaN, shares_memory, MAY_SHARE_BOUNDS, - TooHardError, asanyarray) + TooHardError, asanyarray, ndim) from numpy.core.multiarray import add_docstring from numpy.core import overrides @@ -140,7 +140,7 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, dtype = dt delta = stop - start - y = _nx.arange(0, num, dtype=dt).reshape((-1,) + (1,) * delta.ndim) + y = _nx.arange(0, num, dtype=dt).reshape((-1,) + (1,) * ndim(delta)) # In-place multiplication y *= delta/div is faster, but prevents the multiplicant # from overriding what class is produced, and thus prevents, e.g. use of Quantities, # see gh-7142. Hence, we multiply in place only for standard scalar types. diff --git a/numpy/core/tests/test_function_base.py b/numpy/core/tests/test_function_base.py index 459bacab0..8b820bd75 100644 --- a/numpy/core/tests/test_function_base.py +++ b/numpy/core/tests/test_function_base.py @@ -362,3 +362,9 @@ class TestLinspace(object): assert_(isinstance(y, tuple) and len(y) == 2 and len(y[0]) == num and isnan(y[1]), 'num={0}, endpoint={1}'.format(num, ept)) + + def test_object(self): + start = array(1, dtype='O') + stop = array(2, dtype='O') + y = linspace(start, stop, 3) + assert_array_equal(y, array([1., 1.5, 2.])) |