summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2016-01-30 15:56:11 -0500
committerMarten van Kerkwijk <mhvk@astro.utoronto.ca>2016-01-31 16:24:42 -0500
commit0aa03bef711e57220ad5286f68363e6aca7cdfad (patch)
tree93dc6f24f01d68f1503b554491a0b4687ccdfb7d
parent0bba66540744223b3a10b76ef212780448bef874 (diff)
downloadnumpy-0aa03bef711e57220ad5286f68363e6aca7cdfad.tar.gz
Reascertain that linspace respects ndarray subclasses in start, stop.
-rw-r--r--numpy/core/function_base.py11
-rw-r--r--numpy/core/tests/test_function_base.py15
2 files changed, 22 insertions, 4 deletions
diff --git a/numpy/core/function_base.py b/numpy/core/function_base.py
index c82c9bb6b..21ca1af01 100644
--- a/numpy/core/function_base.py
+++ b/numpy/core/function_base.py
@@ -96,18 +96,23 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None):
y = _nx.arange(0, num, dtype=dt)
+ delta = stop - start
if num > 1:
- delta = stop - start
step = delta / div
if step == 0:
# Special handling for denormal numbers, gh-5437
y /= div
- y *= delta
+ y = y * delta
else:
- y *= step
+ # One might be tempted to use faster, in-place multiplication here,
+ # but this prevents step from overriding what class is produced,
+ # and thus prevents, e.g., use of Quantities; see gh-7142.
+ y = y * step
else:
# 0 and 1 item long sequences have an undefined step
step = NaN
+ # Multiply with delta to allow possible override of output class.
+ y = y * delta
y += start
diff --git a/numpy/core/tests/test_function_base.py b/numpy/core/tests/test_function_base.py
index 2df7ba3ea..6b5430611 100644
--- a/numpy/core/tests/test_function_base.py
+++ b/numpy/core/tests/test_function_base.py
@@ -1,7 +1,7 @@
from __future__ import division, absolute_import, print_function
from numpy import (logspace, linspace, dtype, array, finfo, typecodes, arange,
- isnan)
+ isnan, ndarray)
from numpy.testing import (
TestCase, run_module_suite, assert_, assert_equal, assert_raises,
assert_array_equal
@@ -115,6 +115,19 @@ class TestLinspace(TestCase):
b = PhysicalQuantity(1.0)
assert_equal(linspace(a, b), linspace(0.0, 1.0))
+ def test_subclass(self):
+ class PhysicalQuantity2(ndarray):
+ __array_priority__ = 10
+
+ a = array(0).view(PhysicalQuantity2)
+ b = array(1).view(PhysicalQuantity2)
+ ls = linspace(a, b)
+ assert type(ls) is PhysicalQuantity2
+ assert_equal(ls, linspace(0.0, 1.0))
+ ls = linspace(a, b, 1)
+ assert type(ls) is PhysicalQuantity2
+ assert_equal(ls, linspace(0.0, 1.0, 1))
+
def test_denormal_numbers(self):
# Regression test for gh-5437. Will probably fail when compiled
# with ICC, which flushes denormals to zero