diff options
-rw-r--r-- | numpy/core/fromnumeric.py | 2 | ||||
-rw-r--r-- | numpy/core/src/multiarray/common.c | 14 | ||||
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 40 | ||||
-rw-r--r-- | numpy/core/src/multiarray/lowlevel_strided_loops.c.src | 11 | ||||
-rw-r--r-- | numpy/core/tests/test_datetime.py | 2 | ||||
-rw-r--r-- | numpy/core/tests/test_indexing.py | 22 | ||||
-rw-r--r-- | numpy/lib/tests/test_twodim_base.py | 44 | ||||
-rw-r--r-- | numpy/lib/twodim_base.py | 62 | ||||
-rw-r--r-- | numpy/testing/tests/test_utils.py | 68 | ||||
-rw-r--r-- | numpy/testing/utils.py | 28 |
10 files changed, 232 insertions, 61 deletions
diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py index 728c95294..3de81305d 100644 --- a/numpy/core/fromnumeric.py +++ b/numpy/core/fromnumeric.py @@ -584,7 +584,7 @@ def partition(a, kth, axis=-1, kind='introselect', order=None): The various selection algorithms are characterized by their average speed, worst case performance, work space size, and whether they are stable. A stable sort keeps items with the same key in the same relative order. The - three available algorithms have the following properties: + available algorithms have the following properties: ================= ======= ============= ============ ======= kind speed worst case work space stable diff --git a/numpy/core/src/multiarray/common.c b/numpy/core/src/multiarray/common.c index 0e8a21394..1729d50e2 100644 --- a/numpy/core/src/multiarray/common.c +++ b/numpy/core/src/multiarray/common.c @@ -650,14 +650,12 @@ _IsAligned(PyArrayObject *ap) { unsigned int i; npy_uintp aligned; - const unsigned int alignment = PyArray_DESCR(ap)->alignment; - - /* The special casing for STRING and VOID types was removed - * in accordance with http://projects.scipy.org/numpy/ticket/1227 - * It used to be that IsAligned always returned True for these - * types, which is indeed the case when they are created using - * PyArray_DescrConverter(), but not necessarily when using - * PyArray_DescrAlignConverter(). */ + npy_uintp alignment = PyArray_DESCR(ap)->alignment; + + /* alignment 1 types should have a efficient alignment for copy loops */ + if (PyArray_ISFLEXIBLE(ap) || PyArray_ISSTRING(ap)) { + alignment = 16; + } if (alignment == 1) { return 1; diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index e4126109e..d6e0980c6 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -2235,41 +2235,51 @@ PyArray_Compress(PyArrayObject *self, PyObject *condition, int axis, } /* - * count number of nonzero bytes in 16 byte block + * count number of nonzero bytes in 48 byte block * w must be aligned to 8 bytes * * even though it uses 64 bit types its faster than the bytewise sum on 32 bit * but a 32 bit type version would make it even faster on these platforms */ -static NPY_INLINE int -count_nonzero_bytes_128(const npy_uint64 * w) +static NPY_INLINE npy_intp +count_nonzero_bytes_384(const npy_uint64 * w) { const npy_uint64 w1 = w[0]; const npy_uint64 w2 = w[1]; + const npy_uint64 w3 = w[2]; + const npy_uint64 w4 = w[3]; + const npy_uint64 w5 = w[4]; + const npy_uint64 w6 = w[5]; + npy_intp r; + + /* + * last part of sideways add popcount, first three bisections can be + * skipped as we are dealing with bytes. + * multiplication equivalent to (x + (x>>8) + (x>>16) + (x>>24)) & 0xFF + * multiplication overflow well defined for unsigned types. + * w1 + w2 guaranteed to not overflow as we only have 0 and 1 data. + */ + r = ((w1 + w2 + w3 + w4 + w5 + w6) * 0x0101010101010101ULL) >> 56ULL; /* * bytes not exclusively 0 or 1, sum them individually. * should only happen if one does weird stuff with views or external * buffers. + * Doing this after the optimistic computation allows saving registers and + * better pipelining */ - if (NPY_UNLIKELY(((w1 | w2) & 0xFEFEFEFEFEFEFEFEULL) != 0)) { + if (NPY_UNLIKELY( + ((w1 | w2 | w3 | w4 | w5 | w6) & 0xFEFEFEFEFEFEFEFEULL) != 0)) { /* reload from pointer to avoid a unnecessary stack spill with gcc */ const char * c = (const char *)w; npy_uintp i, count = 0; - for (i = 0; i < 16; i++) { + for (i = 0; i < 48; i++) { count += (c[i] != 0); } return count; } - /* - * last part of sideways add popcount, first three bisections can be - * skipped as we are dealing with bytes. - * multiplication equivalent to (x + (x>>8) + (x>>16) + (x>>24)) & 0xFF - * multiplication overflow well defined for unsigned types. - * w1 + w2 guaranteed to not overflow as we only have 0 and 1 data. - */ - return ((w1 + w2) * 0x0101010101010101ULL) >> 56ULL; + return r; } /* @@ -2311,9 +2321,9 @@ count_boolean_trues(int ndim, char *data, npy_intp *ashape, npy_intp *astrides) const char *e = data + shape[0]; if (NPY_CPU_HAVE_UNALIGNED_ACCESS || npy_is_aligned(d, sizeof(npy_uint64))) { - npy_uintp stride = 2 * sizeof(npy_uint64); + npy_uintp stride = 6 * sizeof(npy_uint64); for (; d < e - (shape[0] % stride); d += stride) { - count += count_nonzero_bytes_128((const npy_uint64 *)d); + count += count_nonzero_bytes_384((const npy_uint64 *)d); } } for (; d < e; ++d) { diff --git a/numpy/core/src/multiarray/lowlevel_strided_loops.c.src b/numpy/core/src/multiarray/lowlevel_strided_loops.c.src index e3d0c4b88..02920014b 100644 --- a/numpy/core/src/multiarray/lowlevel_strided_loops.c.src +++ b/numpy/core/src/multiarray/lowlevel_strided_loops.c.src @@ -1429,6 +1429,7 @@ mapiter_trivial_@name@(PyArrayObject *self, PyArrayObject *ind, default: #endif while (itersize--) { + assert(npy_is_aligned(ind_ptr, _ALIGN(npy_intp))); indval = *((npy_intp*)ind_ptr); #if @isget@ if (check_and_adjust_index(&indval, fancy_dim, 1, _save) < 0 ) { @@ -1443,6 +1444,8 @@ mapiter_trivial_@name@(PyArrayObject *self, PyArrayObject *ind, #if @isget@ #if @elsize@ + assert(npy_is_aligned(result_ptr, _ALIGN(@copytype@))); + assert(npy_is_aligned(self_ptr, _ALIGN(@copytype@))); *(@copytype@ *)result_ptr = *(@copytype@ *)self_ptr; #else copyswap(result_ptr, self_ptr, 0, self); @@ -1450,6 +1453,8 @@ mapiter_trivial_@name@(PyArrayObject *self, PyArrayObject *ind, #else /* !@isget@ */ #if @elsize@ + assert(npy_is_aligned(result_ptr, _ALIGN(@copytype@))); + assert(npy_is_aligned(self_ptr, _ALIGN(@copytype@))); *(@copytype@ *)self_ptr = *(@copytype@ *)result_ptr; #else copyswap(self_ptr, result_ptr, 0, self); @@ -1567,6 +1572,8 @@ mapiter_@name@(PyArrayMapIterObject *mit) while (count--) { self_ptr = baseoffset; for (i=0; i < @numiter@; i++) { + assert(npy_is_aligned(outer_ptrs[i], + _ALIGN(npy_intp))); indval = *((npy_intp*)outer_ptrs[i]); #if @isget@ && @one_iter@ @@ -1587,12 +1594,16 @@ mapiter_@name@(PyArrayMapIterObject *mit) #if @isget@ #if @elsize@ + assert(npy_is_aligned(outer_ptrs[i], _ALIGN(@copytype@))); + assert(npy_is_aligned(self_ptr, _ALIGN(@copytype@))); *(@copytype@ *)(outer_ptrs[i]) = *(@copytype@ *)self_ptr; #else copyswap(outer_ptrs[i], self_ptr, 0, array); #endif #else /* !@isget@ */ #if @elsize@ + assert(npy_is_aligned(outer_ptrs[i], _ALIGN(@copytype@))); + assert(npy_is_aligned(self_ptr, _ALIGN(@copytype@))); *(@copytype@ *)self_ptr = *(@copytype@ *)(outer_ptrs[i]); #else copyswap(self_ptr, outer_ptrs[i], 0, array); diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py index bfc7237a4..bf0ba6807 100644 --- a/numpy/core/tests/test_datetime.py +++ b/numpy/core/tests/test_datetime.py @@ -595,7 +595,7 @@ class TestDateTime(TestCase): def test_cast_overflow(self): # gh-4486 def cast(): - numpy.datetime64("1970-01-01 00:00:00.000000000000000").astype("<M8[D]") + numpy.datetime64("1971-01-01 00:00:00.000000000000000").astype("<M8[D]") assert_raises(OverflowError, cast) def cast2(): numpy.datetime64("2014").astype("<M8[fs]") diff --git a/numpy/core/tests/test_indexing.py b/numpy/core/tests/test_indexing.py index 736210722..f09940af7 100644 --- a/numpy/core/tests/test_indexing.py +++ b/numpy/core/tests/test_indexing.py @@ -336,6 +336,28 @@ class TestIndexing(TestCase): assert_equal(sys.getrefcount(np.dtype(np.intp)), refcount) + def test_unaligned(self): + v = (np.zeros(64, dtype=np.int8) + ord('a'))[1:-7] + d = v.view(np.dtype("S8")) + # unaligned source + x = (np.zeros(16, dtype=np.int8) + ord('a'))[1:-7] + x = x.view(np.dtype("S8")) + x[...] = np.array("b" * 8, dtype="S") + b = np.arange(d.size) + #trivial + assert_equal(d[b], d) + d[b] = x + # nontrivial + # unaligned index array + b = np.zeros(d.size + 1).view(np.int8)[1:-(np.intp(0).itemsize - 1)] + b = b.view(np.intp)[:d.size] + b[...] = np.arange(d.size) + assert_equal(d[b.astype(np.int16)], d) + d[b.astype(np.int16)] = x + # boolean + d[b % 2 == 0] + d[b % 2 == 0] = x[::2] + class TestFieldIndexing(TestCase): def test_scalar_return_type(self): diff --git a/numpy/lib/tests/test_twodim_base.py b/numpy/lib/tests/test_twodim_base.py index 022c45bd0..9e81cfe4b 100644 --- a/numpy/lib/tests/test_twodim_base.py +++ b/numpy/lib/tests/test_twodim_base.py @@ -286,6 +286,7 @@ def test_tril_triu_ndim2(): yield assert_equal, b.dtype, a.dtype yield assert_equal, c.dtype, a.dtype + def test_tril_triu_ndim3(): for dtype in np.typecodes['AllFloat'] + np.typecodes['AllInteger']: a = np.array([ @@ -324,16 +325,21 @@ def test_mask_indices(): def test_tril_indices(): # indices without and with offset il1 = tril_indices(4) - il2 = tril_indices(4, 2) + il2 = tril_indices(4, k=2) + il3 = tril_indices(4, m=5) + il4 = tril_indices(4, k=2, m=5) a = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]) + b = np.arange(1, 21).reshape(4, 5) # indexing: yield (assert_array_equal, a[il1], array([1, 5, 6, 9, 10, 11, 13, 14, 15, 16])) + yield (assert_array_equal, b[il3], + array([1, 6, 7, 11, 12, 13, 16, 17, 18, 19])) # And for assigning values: a[il1] = -1 @@ -342,7 +348,12 @@ def test_tril_indices(): [-1, -1, 7, 8], [-1, -1, -1, 12], [-1, -1, -1, -1]])) - + b[il3] = -1 + yield (assert_array_equal, b, + array([[-1, 2, 3, 4, 5], + [-1, -1, 8, 9, 10], + [-1, -1, -1, 14, 15], + [-1, -1, -1, -1, 20]])) # These cover almost the whole array (two diagonals right of the main one): a[il2] = -10 yield (assert_array_equal, a, @@ -350,21 +361,32 @@ def test_tril_indices(): [-10, -10, -10, -10], [-10, -10, -10, -10], [-10, -10, -10, -10]])) + b[il4] = -10 + yield (assert_array_equal, b, + array([[-10, -10, -10, 4, 5], + [-10, -10, -10, -10, 10], + [-10, -10, -10, -10, -10], + [-10, -10, -10, -10, -10]])) class TestTriuIndices(object): def test_triu_indices(self): iu1 = triu_indices(4) - iu2 = triu_indices(4, 2) + iu2 = triu_indices(4, k=2) + iu3 = triu_indices(4, m=5) + iu4 = triu_indices(4, k=2, m=5) a = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]) + b = np.arange(1, 21).reshape(4, 5) # Both for indexing: yield (assert_array_equal, a[iu1], array([1, 2, 3, 4, 6, 7, 8, 11, 12, 16])) + yield (assert_array_equal, b[iu3], + array([1, 2, 3, 4, 5, 7, 8, 9, 10, 13, 14, 15, 19, 20])) # And for assigning values: a[iu1] = -1 @@ -373,6 +395,12 @@ class TestTriuIndices(object): [5, -1, -1, -1], [9, 10, -1, -1], [13, 14, 15, -1]])) + b[iu3] = -1 + yield (assert_array_equal, b, + array([[-1, -1, -1, -1, -1], + [ 6, -1, -1, -1, -1], + [11, 12, -1, -1, -1], + [16, 17, 18, -1, -1]])) # These cover almost the whole array (two diagonals right of the # main one): @@ -382,20 +410,26 @@ class TestTriuIndices(object): [5, -1, -1, -10], [9, 10, -1, -1], [13, 14, 15, -1]])) + b[iu4] = -10 + yield (assert_array_equal, b, + array([[-1, -1, -10, -10, -10], + [6, -1, -1, -10, -10], + [11, 12, -1, -1, -10], + [16, 17, 18, -1, -1]])) class TestTrilIndicesFrom(object): def test_exceptions(self): assert_raises(ValueError, tril_indices_from, np.ones((2,))) assert_raises(ValueError, tril_indices_from, np.ones((2, 2, 2))) - assert_raises(ValueError, tril_indices_from, np.ones((2, 3))) + # assert_raises(ValueError, tril_indices_from, np.ones((2, 3))) class TestTriuIndicesFrom(object): def test_exceptions(self): assert_raises(ValueError, triu_indices_from, np.ones((2,))) assert_raises(ValueError, triu_indices_from, np.ones((2, 2, 2))) - assert_raises(ValueError, triu_indices_from, np.ones((2, 3))) + # assert_raises(ValueError, triu_indices_from, np.ones((2, 3))) class TestVander(object): diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py index d168e0fca..5a0c0e7ee 100644 --- a/numpy/lib/twodim_base.py +++ b/numpy/lib/twodim_base.py @@ -11,10 +11,11 @@ __all__ = ['diag', 'diagflat', 'eye', 'fliplr', 'flipud', 'rot90', 'tri', from numpy.core.numeric import ( asanyarray, subtract, arange, zeros, greater_equal, multiply, ones, - asarray, where, + asarray, where, dtype as np_dtype, less ) + def fliplr(m): """ Flip array in the left/right direction. @@ -372,6 +373,7 @@ def tri(N, M=None, k=0, dtype=float): dtype : dtype, optional Data type of the returned array. The default is float. + Returns ------- tri : ndarray of shape (N, M) @@ -393,8 +395,14 @@ def tri(N, M=None, k=0, dtype=float): """ if M is None: M = N - m = greater_equal(subtract.outer(arange(N), arange(M)), -k) - return m.astype(dtype) + + m = greater_equal.outer(arange(N), arange(-k, M-k)) + + # Avoid making a copy if the requested type is already bool + if np_dtype(dtype) != np_dtype(bool): + m = m.astype(dtype) + + return m def tril(m, k=0): @@ -430,8 +438,7 @@ def tril(m, k=0): """ m = asanyarray(m) - out = multiply(tri(m.shape[-2], m.shape[-1], k=k, dtype=m.dtype), m) - return out + return multiply(tri(*m.shape[-2:], k=k, dtype=bool), m, dtype=m.dtype) def triu(m, k=0): @@ -457,8 +464,7 @@ def triu(m, k=0): """ m = asanyarray(m) - out = multiply((1 - tri(m.shape[-2], m.shape[-1], k - 1, dtype=m.dtype)), m) - return out + return multiply(~tri(*m.shape[-2:], k=k-1, dtype=bool), m, dtype=m.dtype) # Originally borrowed from John Hunter and matplotlib @@ -757,17 +763,24 @@ def mask_indices(n, mask_func, k=0): return where(a != 0) -def tril_indices(n, k=0): +def tril_indices(n, k=0, m=None): """ - Return the indices for the lower-triangle of an (n, n) array. + Return the indices for the lower-triangle of an (n, m) array. Parameters ---------- n : int - The row dimension of the square arrays for which the returned + The row dimension of the arrays for which the returned indices will be valid. k : int, optional Diagonal offset (see `tril` for details). + m : int, optional + .. versionadded:: 1.9.0 + + The column dimension of the arrays for which the returned + arrays will be valid. + By default `m` is taken equal to `n`. + Returns ------- @@ -827,7 +840,7 @@ def tril_indices(n, k=0): [-10, -10, -10, -10]]) """ - return mask_indices(n, tril, k) + return where(tri(n, m, k=k, dtype=bool)) def tril_indices_from(arr, k=0): @@ -853,14 +866,14 @@ def tril_indices_from(arr, k=0): .. versionadded:: 1.4.0 """ - if not (arr.ndim == 2 and arr.shape[0] == arr.shape[1]): - raise ValueError("input array must be 2-d and square") - return tril_indices(arr.shape[0], k) + if arr.ndim != 2: + raise ValueError("input array must be 2-d") + return tril_indices(arr.shape[-2], k=k, m=arr.shape[-1]) -def triu_indices(n, k=0): +def triu_indices(n, k=0, m=None): """ - Return the indices for the upper-triangle of an (n, n) array. + Return the indices for the upper-triangle of an (n, m) array. Parameters ---------- @@ -869,6 +882,13 @@ def triu_indices(n, k=0): be valid. k : int, optional Diagonal offset (see `triu` for details). + m : int, optional + .. versionadded:: 1.9.0 + + The column dimension of the arrays for which the returned + arrays will be valid. + By default `m` is taken equal to `n`. + Returns ------- @@ -930,12 +950,12 @@ def triu_indices(n, k=0): [ 12, 13, 14, -1]]) """ - return mask_indices(n, triu, k) + return where(~tri(n, m, k=k-1, dtype=bool)) def triu_indices_from(arr, k=0): """ - Return the indices for the upper-triangle of a (N, N) array. + Return the indices for the upper-triangle of arr. See `triu_indices` for full details. @@ -960,6 +980,6 @@ def triu_indices_from(arr, k=0): .. versionadded:: 1.4.0 """ - if not (arr.ndim == 2 and arr.shape[0] == arr.shape[1]): - raise ValueError("input array must be 2-d and square") - return triu_indices(arr.shape[0], k) + if arr.ndim != 2: + raise ValueError("input array must be 2-d") + return triu_indices(arr.shape[-2], k=k, m=arr.shape[-1]) diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 5956a4294..aa0a2669f 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -134,6 +134,49 @@ class TestArrayEqual(_GenericTest, unittest.TestCase): self._test_not_equal(c, b) +class TestBuildErrorMessage(unittest.TestCase): + def test_build_err_msg_defaults(self): + x = np.array([1.00001, 2.00002, 3.00003]) + y = np.array([1.00002, 2.00003, 3.00004]) + err_msg = 'There is a mismatch' + + a = build_err_msg([x, y], err_msg) + b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array([ ' + '1.00001, 2.00002, 3.00003])\n DESIRED: array([ 1.00002, ' + '2.00003, 3.00004])') + self.assertEqual(a, b) + + def test_build_err_msg_no_verbose(self): + x = np.array([1.00001, 2.00002, 3.00003]) + y = np.array([1.00002, 2.00003, 3.00004]) + err_msg = 'There is a mismatch' + + a = build_err_msg([x, y], err_msg, verbose=False) + b = '\nItems are not equal: There is a mismatch' + self.assertEqual(a, b) + + def test_build_err_msg_custom_names(self): + x = np.array([1.00001, 2.00002, 3.00003]) + y = np.array([1.00002, 2.00003, 3.00004]) + err_msg = 'There is a mismatch' + + a = build_err_msg([x, y], err_msg, names=('FOO', 'BAR')) + b = ('\nItems are not equal: There is a mismatch\n FOO: array([ ' + '1.00001, 2.00002, 3.00003])\n BAR: array([ 1.00002, 2.00003, ' + '3.00004])') + self.assertEqual(a, b) + + def test_build_err_msg_custom_precision(self): + x = np.array([1.000000001, 2.00002, 3.00003]) + y = np.array([1.000000002, 2.00003, 3.00004]) + err_msg = 'There is a mismatch' + + a = build_err_msg([x, y], err_msg, precision=10) + b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array([ ' + '1.000000001, 2.00002 , 3.00003 ])\n DESIRED: array([ ' + '1.000000002, 2.00003 , 3.00004 ])') + self.assertEqual(a, b) + class TestEqual(TestArrayEqual): def setUp(self): self._assert_func = assert_equal @@ -239,6 +282,31 @@ class TestAlmostEqual(_GenericTest, unittest.TestCase): self._test_not_equal(x, y) self._test_not_equal(x, z) + def test_error_message(self): + """Check the message is formatted correctly for the decimal value""" + x = np.array([1.00000000001, 2.00000000002, 3.00003]) + y = np.array([1.00000000002, 2.00000000003, 3.00004]) + + # test with a different amount of decimal digits + # note that we only check for the formatting of the arrays themselves + b = ('x: array([ 1.00000000001, 2.00000000002, 3.00003 ' + ' ])\n y: array([ 1.00000000002, 2.00000000003, 3.00004 ])') + try: + self._assert_func(x, y, decimal=12) + except AssertionError as e: + # remove anything that's not the array string + self.assertEqual(str(e).split('%)\n ')[1], b) + + # with the default value of decimal digits, only the 3rd element differs + # note that we only check for the formatting of the arrays themselves + b = ('x: array([ 1. , 2. , 3.00003])\n y: array([ 1. , ' + '2. , 3.00004])') + try: + self._assert_func(x, y) + except AssertionError as e: + # remove anything that's not the array string + self.assertEqual(str(e).split('%)\n ')[1], b) + class TestApproxEqual(unittest.TestCase): def setUp(self): self._assert_func = assert_approx_equal diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 6dee4917b..70357d835 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -9,8 +9,9 @@ import sys import re import operator import warnings +from functools import partial from .nosetester import import_nose -from numpy.core import float32, empty, arange +from numpy.core import float32, empty, arange, array_repr, ndarray if sys.version_info[0] >= 3: from io import StringIO @@ -190,8 +191,7 @@ if os.name=='nt' and sys.version[:3] > '2.3': win32pdh.PDH_FMT_LONG, None) def build_err_msg(arrays, err_msg, header='Items are not equal:', - verbose=True, - names=('ACTUAL', 'DESIRED')): + verbose=True, names=('ACTUAL', 'DESIRED'), precision=8): msg = ['\n' + header] if err_msg: if err_msg.find('\n') == -1 and len(err_msg) < 79-len(header): @@ -200,8 +200,15 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:', msg.append(err_msg) if verbose: for i, a in enumerate(arrays): + + if isinstance(a, ndarray): + # precision argument is only needed if the objects are ndarrays + r_func = partial(array_repr, precision=precision) + else: + r_func = repr + try: - r = repr(a) + r = r_func(a) except: r = '[repr failed]' if r.count('\n') > 3: @@ -575,7 +582,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): raise AssertionError(msg) def assert_array_compare(comparison, x, y, err_msg='', verbose=True, - header=''): + header='', precision=6): from numpy.core import array, isnan, isinf, any, all, inf x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) @@ -592,7 +599,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, msg = build_err_msg([x, y], err_msg + '\nx and y %s location mismatch:' \ % (hasval), verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) raise AssertionError(msg) try: @@ -603,7 +610,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, + '\n(shapes %s, %s mismatch)' % (x.shape, y.shape), verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) if not cond : raise AssertionError(msg) @@ -648,7 +655,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, err_msg + '\n(mismatch %s%%)' % (match,), verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) if not cond : raise AssertionError(msg) except ValueError as e: @@ -657,7 +664,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header) msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) raise ValueError(msg) def assert_array_equal(x, y, err_msg='', verbose=True): @@ -825,7 +832,8 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): return around(z, decimal) <= 10.0**(-decimal) assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose, - header=('Arrays are not almost equal to %d decimals' % decimal)) + header=('Arrays are not almost equal to %d decimals' % decimal), + precision=decimal) def assert_array_less(x, y, err_msg='', verbose=True): |