summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/lib/tests/test_ufunclike.py18
-rw-r--r--numpy/lib/ufunclike.py4
2 files changed, 20 insertions, 2 deletions
diff --git a/numpy/lib/tests/test_ufunclike.py b/numpy/lib/tests/test_ufunclike.py
index 2762de5c6..5830e7175 100644
--- a/numpy/lib/tests/test_ufunclike.py
+++ b/numpy/lib/tests/test_ufunclike.py
@@ -58,10 +58,28 @@ array([ 2.169925 , 1.20163386, 2.70043972])
"""
from numpy.testing import *
+import numpy.core as nx
+import numpy.lib.ufunclike as ufl
def test():
return rundocs()
+def test_fix_with_subclass():
+ class MyArray(nx.ndarray):
+ def __new__(cls, data, metadata=None):
+ res = nx.array(data, copy=True).view(cls)
+ res.metadata = metadata
+ return res
+ def __array_wrap__(self, obj, context=None):
+ obj.metadata = self.metadata
+ return obj
+
+ a = nx.array([1.1, -1.1])
+ m = MyArray(a, metadata='foo')
+ f = ufl.fix(m)
+ assert_array_equal(f, nx.array([1,-1]))
+ assert isinstance(f, MyArray)
+ assert_equal(f.metadata, 'foo')
if __name__ == "__main__":
run_module_suite()
diff --git a/numpy/lib/ufunclike.py b/numpy/lib/ufunclike.py
index bb8ee7808..7432e91ae 100644
--- a/numpy/lib/ufunclike.py
+++ b/numpy/lib/ufunclike.py
@@ -41,10 +41,10 @@ def fix(x, y=None):
"""
x = nx.asanyarray(x)
- if y is None:
- y = nx.zeros_like(x)
y1 = nx.floor(x)
y2 = nx.ceil(x)
+ if y is None:
+ y = y1
y[...] = nx.where(x >= 0, y1, y2)
return y