diff options
-rw-r--r-- | numpy/core/numeric.py | 8 | ||||
-rw-r--r-- | numpy/core/tests/test_maskna.py | 9 |
2 files changed, 13 insertions, 4 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 7a879719a..e01f24f0d 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1893,7 +1893,7 @@ def _maketup(descr, val): res = [_maketup(fields[name][0],val) for name in dt.names] return tuple(res) -def identity(n, dtype=None): +def identity(n, dtype=None, maskna=False): """ Return the identity array. @@ -1906,6 +1906,8 @@ def identity(n, dtype=None): Number of rows (and columns) in `n` x `n` output. dtype : data-type, optional Data-type of the output. Defaults to ``float``. + maskna : bool, optional + If this is true, the returned array will have an NA mask. Returns ------- @@ -1921,8 +1923,8 @@ def identity(n, dtype=None): [ 0., 0., 1.]]) """ - a = zeros((n,n), dtype=dtype) - a.flat[::n+1] = 1 + a = zeros((n,n), dtype=dtype, maskna=maskna) + a.diagonal()[...] = 1 return a def allclose(a, b, rtol=1.e-5, atol=1.e-8): diff --git a/numpy/core/tests/test_maskna.py b/numpy/core/tests/test_maskna.py index 77ab22b29..5c4cb5264 100644 --- a/numpy/core/tests/test_maskna.py +++ b/numpy/core/tests/test_maskna.py @@ -1338,12 +1338,14 @@ def test_array_maskna_linspace_logspace(): assert_(not a.flags.maskna) assert_(b.flags.maskna) -def test_array_maskna_eye(): +def test_array_maskna_eye_identity(): # np.eye # By default there should be no NA mask a = np.eye(3) assert_(not a.flags.maskna) + a = np.identity(3) + assert_(not a.flags.maskna) a = np.eye(3, maskna=True) assert_(a.flags.maskna) @@ -1355,5 +1357,10 @@ def test_array_maskna_eye(): assert_(a.flags.ownmaskna) assert_equal(a, np.eye(3, k=2)) + a = np.identity(3, maskna=True) + assert_(a.flags.maskna) + assert_(a.flags.ownmaskna) + assert_equal(a, np.identity(3)) + if __name__ == "__main__": run_module_suite() |