summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/numeric.py8
-rw-r--r--numpy/core/tests/test_maskna.py9
2 files changed, 13 insertions, 4 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 7a879719a..e01f24f0d 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1893,7 +1893,7 @@ def _maketup(descr, val):
res = [_maketup(fields[name][0],val) for name in dt.names]
return tuple(res)
-def identity(n, dtype=None):
+def identity(n, dtype=None, maskna=False):
"""
Return the identity array.
@@ -1906,6 +1906,8 @@ def identity(n, dtype=None):
Number of rows (and columns) in `n` x `n` output.
dtype : data-type, optional
Data-type of the output. Defaults to ``float``.
+ maskna : bool, optional
+ If this is true, the returned array will have an NA mask.
Returns
-------
@@ -1921,8 +1923,8 @@ def identity(n, dtype=None):
[ 0., 0., 1.]])
"""
- a = zeros((n,n), dtype=dtype)
- a.flat[::n+1] = 1
+ a = zeros((n,n), dtype=dtype, maskna=maskna)
+ a.diagonal()[...] = 1
return a
def allclose(a, b, rtol=1.e-5, atol=1.e-8):
diff --git a/numpy/core/tests/test_maskna.py b/numpy/core/tests/test_maskna.py
index 77ab22b29..5c4cb5264 100644
--- a/numpy/core/tests/test_maskna.py
+++ b/numpy/core/tests/test_maskna.py
@@ -1338,12 +1338,14 @@ def test_array_maskna_linspace_logspace():
assert_(not a.flags.maskna)
assert_(b.flags.maskna)
-def test_array_maskna_eye():
+def test_array_maskna_eye_identity():
# np.eye
# By default there should be no NA mask
a = np.eye(3)
assert_(not a.flags.maskna)
+ a = np.identity(3)
+ assert_(not a.flags.maskna)
a = np.eye(3, maskna=True)
assert_(a.flags.maskna)
@@ -1355,5 +1357,10 @@ def test_array_maskna_eye():
assert_(a.flags.ownmaskna)
assert_equal(a, np.eye(3, k=2))
+ a = np.identity(3, maskna=True)
+ assert_(a.flags.maskna)
+ assert_(a.flags.ownmaskna)
+ assert_equal(a, np.identity(3))
+
if __name__ == "__main__":
run_module_suite()