summaryrefslogtreecommitdiff
path: root/numpy/lib/twodim_base.py
diff options
context:
space:
mode:
authorIllviljan <14371165+Illviljan@users.noreply.github.com>2021-01-11 22:31:20 +0100
committerGitHub <noreply@github.com>2021-01-11 22:31:20 +0100
commitb038c35ff03d70780711be1116db53b1c72b224c (patch)
tree4d45f6fec2f97492806177b42e41633967f54de6 /numpy/lib/twodim_base.py
parenta190258d4e90f2a17a9469e5dd9fb5f4b045aa90 (diff)
downloadnumpy-b038c35ff03d70780711be1116db53b1c72b224c.tar.gz
faster tril_indices
nonzero is slow. Use np.indices and np.broadcast_to to speed it up.
Diffstat (limited to 'numpy/lib/twodim_base.py')
-rw-r--r--numpy/lib/twodim_base.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py
index 2b4cbdfbb..8e008ba71 100644
--- a/numpy/lib/twodim_base.py
+++ b/numpy/lib/twodim_base.py
@@ -894,7 +894,10 @@ def tril_indices(n, k=0, m=None):
[-10, -10, -10, -10]])
"""
- return nonzero(tri(n, m, k=k, dtype=bool))
+ tri = np.tri(n, m=m, k=k, dtype=bool)
+
+ return tuple(np.broadcast_to(inds, tri.shape)[tri]
+ for inds in np.indices(tri.shape, sparse=True))
def _trilu_indices_form_dispatcher(arr, k=None):
@@ -1010,7 +1013,10 @@ def triu_indices(n, k=0, m=None):
[ 12, 13, 14, -1]])
"""
- return nonzero(~tri(n, m, k=k-1, dtype=bool))
+ tri = ~np.tri(n, m, k=k - 1, dtype=bool)
+
+ return tuple(np.broadcast_to(inds, tri.shape)[tri]
+ for inds in np.indices(tri.shape, sparse=True))
@array_function_dispatch(_trilu_indices_form_dispatcher)