summaryrefslogtreecommitdiff
path: root/numpy/lib/twodim_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/twodim_base.py')
-rw-r--r--numpy/lib/twodim_base.py62
1 files changed, 41 insertions, 21 deletions
diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py
index d168e0fca..5a0c0e7ee 100644
--- a/numpy/lib/twodim_base.py
+++ b/numpy/lib/twodim_base.py
@@ -11,10 +11,11 @@ __all__ = ['diag', 'diagflat', 'eye', 'fliplr', 'flipud', 'rot90', 'tri',
from numpy.core.numeric import (
asanyarray, subtract, arange, zeros, greater_equal, multiply, ones,
- asarray, where,
+ asarray, where, dtype as np_dtype, less
)
+
def fliplr(m):
"""
Flip array in the left/right direction.
@@ -372,6 +373,7 @@ def tri(N, M=None, k=0, dtype=float):
dtype : dtype, optional
Data type of the returned array. The default is float.
+
Returns
-------
tri : ndarray of shape (N, M)
@@ -393,8 +395,14 @@ def tri(N, M=None, k=0, dtype=float):
"""
if M is None:
M = N
- m = greater_equal(subtract.outer(arange(N), arange(M)), -k)
- return m.astype(dtype)
+
+ m = greater_equal.outer(arange(N), arange(-k, M-k))
+
+ # Avoid making a copy if the requested type is already bool
+ if np_dtype(dtype) != np_dtype(bool):
+ m = m.astype(dtype)
+
+ return m
def tril(m, k=0):
@@ -430,8 +438,7 @@ def tril(m, k=0):
"""
m = asanyarray(m)
- out = multiply(tri(m.shape[-2], m.shape[-1], k=k, dtype=m.dtype), m)
- return out
+ return multiply(tri(*m.shape[-2:], k=k, dtype=bool), m, dtype=m.dtype)
def triu(m, k=0):
@@ -457,8 +464,7 @@ def triu(m, k=0):
"""
m = asanyarray(m)
- out = multiply((1 - tri(m.shape[-2], m.shape[-1], k - 1, dtype=m.dtype)), m)
- return out
+ return multiply(~tri(*m.shape[-2:], k=k-1, dtype=bool), m, dtype=m.dtype)
# Originally borrowed from John Hunter and matplotlib
@@ -757,17 +763,24 @@ def mask_indices(n, mask_func, k=0):
return where(a != 0)
-def tril_indices(n, k=0):
+def tril_indices(n, k=0, m=None):
"""
- Return the indices for the lower-triangle of an (n, n) array.
+ Return the indices for the lower-triangle of an (n, m) array.
Parameters
----------
n : int
- The row dimension of the square arrays for which the returned
+ The row dimension of the arrays for which the returned
indices will be valid.
k : int, optional
Diagonal offset (see `tril` for details).
+ m : int, optional
+ .. versionadded:: 1.9.0
+
+ The column dimension of the arrays for which the returned
+ arrays will be valid.
+ By default `m` is taken equal to `n`.
+
Returns
-------
@@ -827,7 +840,7 @@ def tril_indices(n, k=0):
[-10, -10, -10, -10]])
"""
- return mask_indices(n, tril, k)
+ return where(tri(n, m, k=k, dtype=bool))
def tril_indices_from(arr, k=0):
@@ -853,14 +866,14 @@ def tril_indices_from(arr, k=0):
.. versionadded:: 1.4.0
"""
- if not (arr.ndim == 2 and arr.shape[0] == arr.shape[1]):
- raise ValueError("input array must be 2-d and square")
- return tril_indices(arr.shape[0], k)
+ if arr.ndim != 2:
+ raise ValueError("input array must be 2-d")
+ return tril_indices(arr.shape[-2], k=k, m=arr.shape[-1])
-def triu_indices(n, k=0):
+def triu_indices(n, k=0, m=None):
"""
- Return the indices for the upper-triangle of an (n, n) array.
+ Return the indices for the upper-triangle of an (n, m) array.
Parameters
----------
@@ -869,6 +882,13 @@ def triu_indices(n, k=0):
be valid.
k : int, optional
Diagonal offset (see `triu` for details).
+ m : int, optional
+ .. versionadded:: 1.9.0
+
+ The column dimension of the arrays for which the returned
+ arrays will be valid.
+ By default `m` is taken equal to `n`.
+
Returns
-------
@@ -930,12 +950,12 @@ def triu_indices(n, k=0):
[ 12, 13, 14, -1]])
"""
- return mask_indices(n, triu, k)
+ return where(~tri(n, m, k=k-1, dtype=bool))
def triu_indices_from(arr, k=0):
"""
- Return the indices for the upper-triangle of a (N, N) array.
+ Return the indices for the upper-triangle of arr.
See `triu_indices` for full details.
@@ -960,6 +980,6 @@ def triu_indices_from(arr, k=0):
.. versionadded:: 1.4.0
"""
- if not (arr.ndim == 2 and arr.shape[0] == arr.shape[1]):
- raise ValueError("input array must be 2-d and square")
- return triu_indices(arr.shape[0], k)
+ if arr.ndim != 2:
+ raise ValueError("input array must be 2-d")
+ return triu_indices(arr.shape[-2], k=k, m=arr.shape[-1])