diff options
Diffstat (limited to 'numpy/lib/twodim_base.py')
-rw-r--r-- | numpy/lib/twodim_base.py | 19 |
1 files changed, 17 insertions, 2 deletions
diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py index 5a0c0e7ee..20b5cdd67 100644 --- a/numpy/lib/twodim_base.py +++ b/numpy/lib/twodim_base.py @@ -11,10 +11,24 @@ __all__ = ['diag', 'diagflat', 'eye', 'fliplr', 'flipud', 'rot90', 'tri', from numpy.core.numeric import ( asanyarray, subtract, arange, zeros, greater_equal, multiply, ones, - asarray, where, dtype as np_dtype, less + asarray, where, dtype as np_dtype, less, int8, int16, int32, int64 ) +from numpy.core import iinfo +i1 = iinfo(int8) +i2 = iinfo(int16) +i4 = iinfo(int32) +def _min_int(low, high): + """ get small int that fits the range """ + if high <= i1.max and low >= i1.min: + return int8 + if high <= i2.max and low >= i2.min: + return int16 + if high <= i4.max and low >= i4.min: + return int32 + return int64 + def fliplr(m): """ @@ -396,7 +410,8 @@ def tri(N, M=None, k=0, dtype=float): if M is None: M = N - m = greater_equal.outer(arange(N), arange(-k, M-k)) + m = greater_equal.outer(arange(N, dtype=_min_int(0, N)), + arange(-k, M-k, dtype=_min_int(-k, M - k))) # Avoid making a copy if the requested type is already bool if np_dtype(dtype) != np_dtype(bool): |