diff options
| author | Pauli Virtanen <pav@iki.fi> | 2019-12-01 15:19:00 +0200 |
|---|---|---|
| committer | Pauli Virtanen <pav@iki.fi> | 2019-12-01 16:37:50 +0200 |
| commit | d57739d3152c366a43f0d17694e2ea8d5db142d7 (patch) | |
| tree | 90e92e5540ea49e1ff5e68577ef78dd4baa1a2bf | |
| parent | a0e6571670d1af81f277df36e0d4ce6191add043 (diff) | |
| download | numpy-d57739d3152c366a43f0d17694e2ea8d5db142d7.tar.gz | |
TST: linalg: add smoke test for 64-bit blas
| -rw-r--r-- | numpy/linalg/tests/test_linalg.py | 16 | ||||
| -rw-r--r-- | numpy/testing/_private/utils.py | 4 |
2 files changed, 18 insertions, 2 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 173e81e9c..e1590f1e7 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -20,8 +20,9 @@ from numpy.linalg.linalg import _multi_dot_matrix_chain_order from numpy.testing import ( assert_, assert_equal, assert_raises, assert_array_equal, assert_almost_equal, assert_allclose, suppress_warnings, - assert_raises_regex, + assert_raises_regex, HAS_LAPACK64, ) +from numpy.testing._private.utils import requires_memory def consistent_subclass(out, in_): @@ -2002,3 +2003,16 @@ def test_unsupported_commontype(): arr = np.array([[1, -2], [2, 5]], dtype='float16') with assert_raises_regex(TypeError, "unsupported in linalg"): linalg.cholesky(arr) + + +@pytest.mark.slow +@pytest.mark.xfail(not HAS_LAPACK64, run=False, + reason="Numpy not compiled with 64-bit BLAS/LAPACK") +@requires_memory(16e9) +def test_blas64_dot(): + n = 2**32 + a = np.zeros([1, n], dtype=np.float32) + b = np.ones([1, 1], dtype=np.float32) + a[0,-1] = 1 + c = np.dot(b, a) + assert_equal(c[0,-1], 1) diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index 409ed142f..4642cc0f8 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -21,6 +21,7 @@ import pprint from numpy.core import( intp, float32, empty, arange, array_repr, ndarray, isnat, array) +import numpy.__config__ if sys.version_info[0] >= 3: from io import StringIO @@ -39,7 +40,7 @@ __all__ = [ 'SkipTest', 'KnownFailureException', 'temppath', 'tempdir', 'IS_PYPY', 'HAS_REFCOUNT', 'suppress_warnings', 'assert_array_compare', '_assert_valid_refcount', '_gen_alignment_data', 'assert_no_gc_cycles', - 'break_cycles', + 'break_cycles', 'HAS_LAPACK64' ] @@ -53,6 +54,7 @@ verbose = 0 IS_PYPY = platform.python_implementation() == 'PyPy' HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None +HAS_LAPACK64 = hasattr(numpy.__config__, 'lapack64__opt_info') def import_nose(): |
