summaryrefslogtreecommitdiff
path: root/numpy/lib/tests
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/tests')
-rw-r--r--numpy/lib/tests/test__iotools.py18
-rw-r--r--numpy/lib/tests/test_arraysetops.py6
-rw-r--r--numpy/lib/tests/test_format.py32
-rw-r--r--numpy/lib/tests/test_function_base.py72
-rw-r--r--numpy/lib/tests/test_io.py52
-rw-r--r--numpy/lib/tests/test_nanfunctions.py16
-rw-r--r--numpy/lib/tests/test_packbits.py26
-rw-r--r--numpy/lib/tests/test_polynomial.py11
-rw-r--r--numpy/lib/tests/test_stride_tricks.py46
-rw-r--r--numpy/lib/tests/test_twodim_base.py34
10 files changed, 271 insertions, 42 deletions
diff --git a/numpy/lib/tests/test__iotools.py b/numpy/lib/tests/test__iotools.py
index 4db19382a..92ca1c973 100644
--- a/numpy/lib/tests/test__iotools.py
+++ b/numpy/lib/tests/test__iotools.py
@@ -7,7 +7,7 @@ from datetime import date
import numpy as np
from numpy.compat import asbytes, asbytes_nested
from numpy.testing import (
- run_module_suite, TestCase, assert_, assert_equal
+ run_module_suite, TestCase, assert_, assert_equal, assert_allclose
)
from numpy.lib._iotools import (
LineSplitter, NameValidator, StringConverter,
@@ -76,7 +76,7 @@ class TestLineSplitter(TestCase):
test = LineSplitter((6, 6, 9))(strg)
assert_equal(test, asbytes_nested(['1', '3 4', '5 6']))
-#-------------------------------------------------------------------------------
+# -----------------------------------------------------------------------------
class TestNameValidator(TestCase):
@@ -127,7 +127,7 @@ class TestNameValidator(TestCase):
assert_(validator(namelist) is None)
assert_equal(validator(namelist, nbfields=3), ['f0', 'f1', 'f2'])
-#-------------------------------------------------------------------------------
+# -----------------------------------------------------------------------------
def _bytes_to_date(s):
@@ -150,13 +150,17 @@ class TestStringConverter(TestCase):
"Tests the upgrade method."
converter = StringConverter()
assert_equal(converter._status, 0)
- converter.upgrade(asbytes('0'))
+ # test int
+ assert_equal(converter.upgrade(asbytes('0')), 0)
assert_equal(converter._status, 1)
- converter.upgrade(asbytes('0.'))
+ # test float
+ assert_allclose(converter.upgrade(asbytes('0.')), 0.0)
assert_equal(converter._status, 2)
- converter.upgrade(asbytes('0j'))
+ # test complex
+ assert_equal(converter.upgrade(asbytes('0j')), complex('0j'))
assert_equal(converter._status, 3)
- converter.upgrade(asbytes('a'))
+ # test str
+ assert_equal(converter.upgrade(asbytes('a')), asbytes('a'))
assert_equal(converter._status, len(converter._mapper) - 1)
def test_missing(self):
diff --git a/numpy/lib/tests/test_arraysetops.py b/numpy/lib/tests/test_arraysetops.py
index e83f8552e..39196f4bc 100644
--- a/numpy/lib/tests/test_arraysetops.py
+++ b/numpy/lib/tests/test_arraysetops.py
@@ -109,6 +109,12 @@ class TestSetOps(TestCase):
assert_array_equal(a2, unq)
assert_array_equal(a2_inv, inv)
+ # test for chararrays with return_inverse (gh-5099)
+ a = np.chararray(5)
+ a[...] = ''
+ a2, a2_inv = np.unique(a, return_inverse=True)
+ assert_array_equal(a2_inv, np.zeros(5))
+
def test_intersect1d(self):
# unique inputs
a = np.array([5, 7, 1, 2])
diff --git a/numpy/lib/tests/test_format.py b/numpy/lib/tests/test_format.py
index b266f1c15..ee77386bc 100644
--- a/numpy/lib/tests/test_format.py
+++ b/numpy/lib/tests/test_format.py
@@ -688,28 +688,28 @@ def test_bad_header():
def test_large_file_support():
from nose import SkipTest
+ if (sys.platform == 'win32' or sys.platform == 'cygwin'):
+ raise SkipTest("Unknown if Windows has sparse filesystems")
# try creating a large sparse file
- with tempfile.NamedTemporaryFile() as tf:
- try:
- # seek past end would work too, but linux truncate somewhat
- # increases the chances that we have a sparse filesystem and can
- # avoid actually writing 5GB
- import subprocess as sp
- sp.check_call(["truncate", "-s", "5368709120", tf.name])
- except:
- raise SkipTest("Could not create 5GB large file")
- # write a small array to the end
- f = open(tf.name, "wb")
+ tf_name = os.path.join(tempdir, 'sparse_file')
+ try:
+ # seek past end would work too, but linux truncate somewhat
+ # increases the chances that we have a sparse filesystem and can
+ # avoid actually writing 5GB
+ import subprocess as sp
+ sp.check_call(["truncate", "-s", "5368709120", tf_name])
+ except:
+ raise SkipTest("Could not create 5GB large file")
+ # write a small array to the end
+ with open(tf_name, "wb") as f:
f.seek(5368709120)
d = np.arange(5)
np.save(f, d)
- f.close()
- # read it back
- f = open(tf.name, "rb")
+ # read it back
+ with open(tf_name, "rb") as f:
f.seek(5368709120)
r = np.load(f)
- f.close()
- assert_array_equal(r, d)
+ assert_array_equal(r, d)
if __name__ == "__main__":
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index a3f805691..03521ca4c 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -124,6 +124,11 @@ class TestAverage(TestCase):
assert_array_equal(average(y1, weights=w2, axis=1), desired)
assert_equal(average(y1, weights=w2), 5.)
+ y3 = rand(5).astype(np.float32)
+ w3 = rand(5).astype(np.float64)
+
+ assert_(np.average(y3, weights=w3).dtype == np.result_type(y3, w3))
+
def test_returned(self):
y = np.array([[1, 2, 3], [4, 5, 6]])
@@ -312,6 +317,16 @@ class TestInsert(TestCase):
np.insert([0, 1, 2], x, [3, 4, 5])
assert_equal(x, np.array([1, 1, 1]))
+ def test_structured_array(self):
+ a = np.array([(1, 'a'), (2, 'b'), (3, 'c')],
+ dtype=[('foo', 'i'), ('bar', 'a1')])
+ val = (4, 'd')
+ b = np.insert(a, 0, val)
+ assert_array_equal(b[0], np.array(val, dtype=b.dtype))
+ val = [(4, 'd')] * 2
+ b = np.insert(a, [0, 2], val)
+ assert_array_equal(b[[0, 3]], np.array(val, dtype=b.dtype))
+
class TestAmax(TestCase):
def test_basic(self):
@@ -516,8 +531,18 @@ class TestGradient(TestCase):
def test_masked(self):
# Make sure that gradient supports subclasses like masked arrays
- x = np.ma.array([[1, 1], [3, 4]])
- assert_equal(type(gradient(x)[0]), type(x))
+ x = np.ma.array([[1, 1], [3, 4]],
+ mask=[[False, False], [False, False]])
+ out = gradient(x)[0]
+ assert_equal(type(out), type(x))
+ # And make sure that the output and input don't have aliased mask
+ # arrays
+ assert_(x.mask is not out.mask)
+ # Also check that edge_order=2 doesn't alter the original mask
+ x2 = np.ma.arange(5)
+ x2[2] = np.ma.masked
+ np.gradient(x2, edge_order=2)
+ assert_array_equal(x2.mask, [False, False, True, False, False])
def test_datetime64(self):
# Make sure gradient() can handle special types like datetime64
@@ -526,7 +551,7 @@ class TestGradient(TestCase):
'1910-10-12', '1910-12-12', '1912-12-12'],
dtype='datetime64[D]')
dx = np.array(
- [-7, -3, 0, 31, 61, 396, 1066],
+ [-5, -3, 0, 31, 61, 396, 731],
dtype='timedelta64[D]')
assert_array_equal(gradient(x), dx)
assert_(dx.dtype == np.dtype('timedelta64[D]'))
@@ -537,7 +562,7 @@ class TestGradient(TestCase):
[-5, -3, 10, 12, 61, 321, 300],
dtype='timedelta64[D]')
dx = np.array(
- [-3, 7, 7, 25, 154, 119, -161],
+ [2, 7, 7, 25, 154, 119, -21],
dtype='timedelta64[D]')
assert_array_equal(gradient(x), dx)
assert_(dx.dtype == np.dtype('timedelta64[D]'))
@@ -551,7 +576,7 @@ class TestGradient(TestCase):
dx = x[1] - x[0]
y = 2 * x ** 3 + 4 * x ** 2 + 2 * x
analytical = 6 * x ** 2 + 8 * x + 2
- num_error = np.abs((np.gradient(y, dx) / analytical) - 1)
+ num_error = np.abs((np.gradient(y, dx, edge_order=2) / analytical) - 1)
assert_(np.all(num_error < 0.03) == True)
@@ -836,6 +861,13 @@ class TestDigitize(TestCase):
bins = [1, 1, 0, 1]
assert_raises(ValueError, digitize, x, bins)
+ def test_casting_error(self):
+ x = [1, 2, 3+1.j]
+ bins = [1, 2, 3]
+ assert_raises(TypeError, digitize, x, bins)
+ x, bins = bins, x
+ assert_raises(TypeError, digitize, x, bins)
+
class TestUnwrap(TestCase):
def test_simple(self):
@@ -1072,6 +1104,13 @@ class TestHistogram(TestCase):
h, b = histogram(a, weights=np.ones(10, float))
assert_(issubdtype(h.dtype, float))
+ def test_f32_rounding(self):
+ # gh-4799, check that the rounding of the edges works with float32
+ x = np.array([276.318359 , -69.593948 , 21.329449], dtype=np.float32)
+ y = np.array([5005.689453, 4481.327637, 6010.369629], dtype=np.float32)
+ counts_hist, xedges, yedges = np.histogram2d(x, y, bins=100)
+ assert_equal(counts_hist.sum(), 3.)
+
def test_weights(self):
v = rand(100)
w = np.ones(100) * 5
@@ -1460,7 +1499,7 @@ class TestMeshgrid(TestCase):
# Test that meshgrid complains about invalid arguments
# Regression test for issue #4755:
# https://github.com/numpy/numpy/issues/4755
- assert_raises(TypeError, meshgrid,
+ assert_raises(TypeError, meshgrid,
[1, 2, 3], [4, 5, 6, 7], indices='ij')
@@ -1587,6 +1626,9 @@ class TestInterp(TestCase):
def test_exceptions(self):
assert_raises(ValueError, interp, 0, [], [])
assert_raises(ValueError, interp, 0, [0], [1, 2])
+ assert_raises(ValueError, interp, 0, [0, 1], [1, 2], period=0)
+ assert_raises(ValueError, interp, 0, [], [], period=360)
+ assert_raises(ValueError, interp, 0, [0], [1, 2], period=360)
def test_basic(self):
x = np.linspace(0, 1, 5)
@@ -1627,6 +1669,16 @@ class TestInterp(TestCase):
fp = np.sin(xp)
assert_almost_equal(np.interp(np.pi, xp, fp), 0.0)
+ def test_period(self):
+ x = [-180, -170, -185, 185, -10, -5, 0, 365]
+ xp = [190, -190, 350, -350]
+ fp = [5, 10, 3, 4]
+ y = [7.5, 5., 8.75, 6.25, 3., 3.25, 3.5, 3.75]
+ assert_almost_equal(np.interp(x, xp, fp, period=360), y)
+ x = np.array(x, order='F').reshape(2, -1)
+ y = np.array(y, order='C').reshape(2, -1)
+ assert_almost_equal(np.interp(x, xp, fp, period=360), y)
+
def compare_results(res, desired):
for i in range(len(desired)):
@@ -1860,6 +1912,14 @@ class TestScoreatpercentile(TestCase):
np.percentile(a, [50])
assert_equal(a, np.array([2, 3, 4, 1]))
+ def test_no_p_overwrite(self):
+ p = np.linspace(0., 100., num=5)
+ np.percentile(np.arange(100.), p, interpolation="midpoint")
+ assert_array_equal(p, np.linspace(0., 100., num=5))
+ p = np.linspace(0., 100., num=5).tolist()
+ np.percentile(np.arange(100.), p, interpolation="midpoint")
+ assert_array_equal(p, np.linspace(0., 100., num=5).tolist())
+
def test_percentile_overwrite(self):
a = np.array([2, 3, 4, 1])
b = np.percentile(a, [50], overwrite_input=True)
diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py
index 49ad1ba5b..68b2018cd 100644
--- a/numpy/lib/tests/test_io.py
+++ b/numpy/lib/tests/test_io.py
@@ -4,9 +4,7 @@ import sys
import gzip
import os
import threading
-import shutil
-import contextlib
-from tempfile import mkstemp, mkdtemp, NamedTemporaryFile
+from tempfile import mkstemp, NamedTemporaryFile
import time
import warnings
import gc
@@ -24,13 +22,7 @@ from numpy.ma.testutils import (
assert_raises, assert_raises_regex, run_module_suite
)
from numpy.testing import assert_warns, assert_, build_err_msg
-
-
-@contextlib.contextmanager
-def tempdir(change_dir=False):
- tmpdir = mkdtemp()
- yield tmpdir
- shutil.rmtree(tmpdir)
+from numpy.testing.utils import tempdir
class TextIO(BytesIO):
@@ -202,7 +194,7 @@ class TestSavezLoad(RoundtripTest, TestCase):
def test_big_arrays(self):
L = (1 << 31) + 100000
a = np.empty(L, dtype=np.uint8)
- with tempdir() as tmpdir:
+ with tempdir(prefix="numpy_test_big_arrays_") as tmpdir:
tmp = os.path.join(tmpdir, "file.npz")
np.savez(tmp, a=a)
del a
@@ -224,6 +216,17 @@ class TestSavezLoad(RoundtripTest, TestCase):
l = np.load(c)
assert_equal(a, l['file_a'])
assert_equal(b, l['file_b'])
+
+ def test_BagObj(self):
+ a = np.array([[1, 2], [3, 4]], float)
+ b = np.array([[1 + 2j, 2 + 7j], [3 - 6j, 4 + 12j]], complex)
+ c = BytesIO()
+ np.savez(c, file_a=a, file_b=b)
+ c.seek(0)
+ l = np.load(c)
+ assert_equal(sorted(dir(l.f)), ['file_a','file_b'])
+ assert_equal(a, l.f.file_a)
+ assert_equal(b, l.f.file_b)
def test_savez_filename_clashes(self):
# Test that issue #852 is fixed
@@ -311,7 +314,7 @@ class TestSavezLoad(RoundtripTest, TestCase):
# Check that zipfile owns file and can close it.
# This needs to pass a file name to load for the
# test.
- with tempdir() as tmpdir:
+ with tempdir(prefix="numpy_test_closing_zipfile_after_load_") as tmpdir:
fd, tmp = mkstemp(suffix='.npz', dir=tmpdir)
os.close(fd)
np.savez(tmp, lab='place holder')
@@ -1093,6 +1096,21 @@ M 33 21.99
control = np.array([2009., 23., 46],)
assert_equal(test, control)
+ def test_dtype_with_converters_and_usecols(self):
+ dstr = "1,5,-1,1:1\n2,8,-1,1:n\n3,3,-2,m:n\n"
+ dmap = {'1:1':0, '1:n':1, 'm:1':2, 'm:n':3}
+ dtyp = [('E1','i4'),('E2','i4'),('E3','i2'),('N', 'i1')]
+ conv = {0: int, 1: int, 2: int, 3: lambda r: dmap[r.decode()]}
+ test = np.recfromcsv(TextIO(dstr,), dtype=dtyp, delimiter=',',
+ names=None, converters=conv)
+ control = np.rec.array([[1,5,-1,0], [2,8,-1,1], [3,3,-2,3]], dtype=dtyp)
+ assert_equal(test, control)
+ dtyp = [('E1','i4'),('E2','i4'),('N', 'i1')]
+ test = np.recfromcsv(TextIO(dstr,), dtype=dtyp, delimiter=',',
+ usecols=(0,1,3), names=None, converters=conv)
+ control = np.rec.array([[1,5,0], [2,8,1], [3,3,3]], dtype=dtyp)
+ assert_equal(test, control)
+
def test_dtype_with_object(self):
"Test using an explicit dtype with an object"
from datetime import date
@@ -1308,6 +1326,16 @@ M 33 21.99
ctrl = np.array([(0, 3), (4, -999)], dtype=[(_, int) for _ in "ac"])
assert_equal(test, ctrl)
+ data2 = "1,2,*,4\n5,*,7,8\n"
+ test = np.genfromtxt(TextIO(data2), delimiter=',', dtype=int,
+ missing_values="*", filling_values=0)
+ ctrl = np.array([[1, 2, 0, 4], [5, 0, 7, 8]])
+ assert_equal(test, ctrl)
+ test = np.genfromtxt(TextIO(data2), delimiter=',', dtype=int,
+ missing_values="*", filling_values=-1)
+ ctrl = np.array([[1, 2, -1, 4], [5, -1, 7, 8]])
+ assert_equal(test, ctrl)
+
def test_withmissing_float(self):
data = TextIO('A,B\n0,1.5\n2,-999.00')
test = np.mafromtxt(data, dtype=None, delimiter=',',
diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py
index 3da6b5149..35ae86c20 100644
--- a/numpy/lib/tests/test_nanfunctions.py
+++ b/numpy/lib/tests/test_nanfunctions.py
@@ -645,6 +645,22 @@ class TestNanFunctions_Median(TestCase):
assert_raises(IndexError, np.nanmedian, d, axis=(0, 4))
assert_raises(ValueError, np.nanmedian, d, axis=(1, 1))
+ def test_float_special(self):
+ with warnings.catch_warnings(record=True):
+ warnings.simplefilter('ignore', RuntimeWarning)
+ a = np.array([[np.inf, np.nan], [np.nan, np.nan]])
+ assert_equal(np.nanmedian(a, axis=0), [np.inf, np.nan])
+ assert_equal(np.nanmedian(a, axis=1), [np.inf, np.nan])
+ assert_equal(np.nanmedian(a), np.inf)
+
+ # minimum fill value check
+ a = np.array([[np.nan, np.nan, np.inf], [np.nan, np.nan, np.inf]])
+ assert_equal(np.nanmedian(a, axis=1), np.inf)
+
+ # no mask path
+ a = np.array([[np.inf, np.inf], [np.inf, np.inf]])
+ assert_equal(np.nanmedian(a, axis=1), np.inf)
+
class TestNanFunctions_Percentile(TestCase):
diff --git a/numpy/lib/tests/test_packbits.py b/numpy/lib/tests/test_packbits.py
new file mode 100644
index 000000000..186e8960d
--- /dev/null
+++ b/numpy/lib/tests/test_packbits.py
@@ -0,0 +1,26 @@
+import numpy as np
+
+from numpy.testing import assert_array_equal, assert_equal, assert_raises
+
+
+def test_packbits():
+ # Copied from the docstring.
+ a = [[[1, 0, 1], [0, 1, 0]],
+ [[1, 1, 0], [0, 0, 1]]]
+ for dtype in [np.bool, np.uint8, np.int]:
+ arr = np.array(a, dtype=dtype)
+ b = np.packbits(arr, axis=-1)
+ assert_equal(b.dtype, np.uint8)
+ assert_array_equal(b, np.array([[[160], [64]], [[192], [32]]]))
+
+ assert_raises(TypeError, np.packbits, np.array(a, dtype=float))
+
+
+def test_unpackbits():
+ # Copied from the docstring.
+ a = np.array([[2], [7], [23]], dtype=np.uint8)
+ b = np.unpackbits(a, axis=1)
+ assert_equal(b.dtype, np.uint8)
+ assert_array_equal(b, np.array([[0, 0, 0, 0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 0, 1, 1, 1]]))
diff --git a/numpy/lib/tests/test_polynomial.py b/numpy/lib/tests/test_polynomial.py
index 02faa0283..5c15941e6 100644
--- a/numpy/lib/tests/test_polynomial.py
+++ b/numpy/lib/tests/test_polynomial.py
@@ -153,6 +153,9 @@ class TestDocs(TestCase):
assert_(p2[3] == Decimal("1.333333333333333333333333333"))
assert_(p2[2] == Decimal('1.5'))
assert_(np.issubdtype(p2.coeffs.dtype, np.object_))
+ p = np.poly([Decimal(1), Decimal(2)])
+ assert_equal(np.poly([Decimal(1), Decimal(2)]),
+ [1, Decimal(-3), Decimal(2)])
def test_complex(self):
p = np.poly1d([3j, 2j, 1j])
@@ -173,5 +176,13 @@ class TestDocs(TestCase):
except ValueError:
pass
+ def test_poly_int_overflow(self):
+ """
+ Regression test for gh-5096.
+ """
+ v = np.arange(1, 21)
+ assert_almost_equal(np.poly(v), np.poly(np.diag(v)))
+
+
if __name__ == "__main__":
run_module_suite()
diff --git a/numpy/lib/tests/test_stride_tricks.py b/numpy/lib/tests/test_stride_tricks.py
index cd0973300..bc7e30ca4 100644
--- a/numpy/lib/tests/test_stride_tricks.py
+++ b/numpy/lib/tests/test_stride_tricks.py
@@ -3,7 +3,7 @@ from __future__ import division, absolute_import, print_function
import numpy as np
from numpy.testing import (
run_module_suite, assert_equal, assert_array_equal,
- assert_raises
+ assert_raises, assert_
)
from numpy.lib.stride_tricks import as_strided, broadcast_arrays
@@ -234,5 +234,49 @@ def test_as_strided():
assert_array_equal(a_view, expected)
+class VerySimpleSubClass(np.ndarray):
+ def __new__(cls, *args, **kwargs):
+ kwargs['subok'] = True
+ return np.array(*args, **kwargs).view(cls)
+
+
+class SimpleSubClass(VerySimpleSubClass):
+ def __new__(cls, *args, **kwargs):
+ kwargs['subok'] = True
+ self = np.array(*args, **kwargs).view(cls)
+ self.info = 'simple'
+ return self
+
+ def __array_finalize__(self, obj):
+ self.info = getattr(obj, 'info', '') + ' finalized'
+
+
+def test_subclasses():
+ # test that subclass is preserved only if subok=True
+ a = VerySimpleSubClass([1, 2, 3, 4])
+ assert_(type(a) is VerySimpleSubClass)
+ a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,))
+ assert_(type(a_view) is np.ndarray)
+ a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,), subok=True)
+ assert_(type(a_view) is VerySimpleSubClass)
+ # test that if a subclass has __array_finalize__, it is used
+ a = SimpleSubClass([1, 2, 3, 4])
+ a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,), subok=True)
+ assert_(type(a_view) is SimpleSubClass)
+ assert_(a_view.info == 'simple finalized')
+
+ # similar tests for broadcast_arrays
+ b = np.arange(len(a)).reshape(-1, 1)
+ a_view, b_view = broadcast_arrays(a, b)
+ assert_(type(a_view) is np.ndarray)
+ assert_(type(b_view) is np.ndarray)
+ assert_(a_view.shape == b_view.shape)
+ a_view, b_view = broadcast_arrays(a, b, subok=True)
+ assert_(type(a_view) is SimpleSubClass)
+ assert_(a_view.info == 'simple finalized')
+ assert_(type(b_view) is np.ndarray)
+ assert_(a_view.shape == b_view.shape)
+
+
if __name__ == "__main__":
run_module_suite()
diff --git a/numpy/lib/tests/test_twodim_base.py b/numpy/lib/tests/test_twodim_base.py
index e9dbef70f..739061a5d 100644
--- a/numpy/lib/tests/test_twodim_base.py
+++ b/numpy/lib/tests/test_twodim_base.py
@@ -311,6 +311,40 @@ def test_tril_triu_ndim3():
yield assert_equal, a_triu_observed.dtype, a.dtype
yield assert_equal, a_tril_observed.dtype, a.dtype
+def test_tril_triu_with_inf():
+ # Issue 4859
+ arr = np.array([[1, 1, np.inf],
+ [1, 1, 1],
+ [np.inf, 1, 1]])
+ out_tril = np.array([[1, 0, 0],
+ [1, 1, 0],
+ [np.inf, 1, 1]])
+ out_triu = out_tril.T
+ assert_array_equal(np.triu(arr), out_triu)
+ assert_array_equal(np.tril(arr), out_tril)
+
+
+def test_tril_triu_dtype():
+ # Issue 4916
+ # tril and triu should return the same dtype as input
+ for c in np.typecodes['All']:
+ if c == 'V':
+ continue
+ arr = np.zeros((3, 3), dtype=c)
+ assert_equal(np.triu(arr).dtype, arr.dtype)
+ assert_equal(np.tril(arr).dtype, arr.dtype)
+
+ # check special cases
+ arr = np.array([['2001-01-01T12:00', '2002-02-03T13:56'],
+ ['2004-01-01T12:00', '2003-01-03T13:45']],
+ dtype='datetime64')
+ assert_equal(np.triu(arr).dtype, arr.dtype)
+ assert_equal(np.tril(arr).dtype, arr.dtype)
+
+ arr = np.zeros((3,3), dtype='f4,f4')
+ assert_equal(np.triu(arr).dtype, arr.dtype)
+ assert_equal(np.tril(arr).dtype, arr.dtype)
+
def test_mask_indices():
# simple test without offset