diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/lib/tests/test_twodim_base.py | 8 | ||||
-rw-r--r-- | numpy/lib/twodim_base.py | 12 |
2 files changed, 16 insertions, 4 deletions
diff --git a/numpy/lib/tests/test_twodim_base.py b/numpy/lib/tests/test_twodim_base.py index c1c5a1615..141f508fd 100644 --- a/numpy/lib/tests/test_twodim_base.py +++ b/numpy/lib/tests/test_twodim_base.py @@ -44,6 +44,12 @@ class TestEye: assert_equal(eye(3) == 1, eye(3, dtype=bool)) + def test_uint64(self): + # Regression test for gh-9982 + assert_equal(eye(np.uint64(2), dtype=int), array([[1, 0], [0, 1]])) + assert_equal(eye(np.uint64(2), M=np.uint64(4), k=np.uint64(1)), + array([[0, 1, 0, 0], [0, 0, 1, 0]])) + def test_diag(self): assert_equal(eye(4, k=1), array([[0, 1, 0, 0], @@ -382,7 +388,7 @@ def test_tril_triu_dtype(): assert_equal(np.triu(arr).dtype, arr.dtype) assert_equal(np.tril(arr).dtype, arr.dtype) - arr = np.zeros((3,3), dtype='f4,f4') + arr = np.zeros((3, 3), dtype='f4,f4') assert_equal(np.triu(arr).dtype, arr.dtype) assert_equal(np.tril(arr).dtype, arr.dtype) diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py index 3e5ad31ff..3d47abbfb 100644 --- a/numpy/lib/twodim_base.py +++ b/numpy/lib/twodim_base.py @@ -2,6 +2,7 @@ """ import functools +import operator from numpy.core.numeric import ( asanyarray, arange, zeros, greater_equal, multiply, ones, @@ -214,6 +215,11 @@ def eye(N, M=None, k=0, dtype=float, order='C', *, like=None): m = zeros((N, M), dtype=dtype, order=order) if k >= M: return m + # Ensure M and k are integers, so we don't get any surprise casting + # results in the expressions `M-k` and `M+1` used below. This avoids + # a problem with inputs with type (for example) np.uint64. + M = operator.index(M) + k = operator.index(k) if k >= 0: i = k else: @@ -494,8 +500,8 @@ def triu(m, k=0): Upper triangle of an array. Return a copy of an array with the elements below the `k`-th diagonal - zeroed. For arrays with ``ndim`` exceeding 2, `triu` will apply to the final - two axes. + zeroed. For arrays with ``ndim`` exceeding 2, `triu` will apply to the + final two axes. Please refer to the documentation for `tril` for further details. @@ -804,7 +810,7 @@ def histogram2d(x, y, bins=10, range=None, normed=None, weights=None, >>> plt.show() """ from numpy import histogramdd - + if len(x) != len(y): raise ValueError('x and y must have the same length.') |