summaryrefslogtreecommitdiff
path: root/twodim_base.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2007-12-15 18:54:52 +0000
committerTravis Oliphant <oliphant@enthought.com>2007-12-15 18:54:52 +0000
commite76b5fa6896c09257181675bbf4cf47789d32927 (patch)
tree7174e22c68fc47df61e745ee18625ee9f4f5b88c /twodim_base.py
parent02ee35a7e1c722a1cdac8f3a60fe9ef7aa079a37 (diff)
downloadnumpy-e76b5fa6896c09257181675bbf4cf47789d32927.tar.gz
Create a branch for io work in NumPy
Diffstat (limited to 'twodim_base.py')
-rw-r--r--twodim_base.py184
1 files changed, 184 insertions, 0 deletions
diff --git a/twodim_base.py b/twodim_base.py
new file mode 100644
index 000000000..4852097f3
--- /dev/null
+++ b/twodim_base.py
@@ -0,0 +1,184 @@
+""" Basic functions for manipulating 2d arrays
+
+"""
+
+__all__ = ['diag','diagflat','eye','fliplr','flipud','rot90','tri','triu',
+ 'tril','vander','histogram2d']
+
+from numpy.core.numeric import asanyarray, equal, subtract, arange, \
+ zeros, arange, greater_equal, multiply, ones, asarray
+
+def fliplr(m):
+ """ returns an array m with the rows preserved and columns flipped
+ in the left/right direction. Works on the first two dimensions of m.
+ """
+ m = asanyarray(m)
+ if m.ndim < 2:
+ raise ValueError, "Input must be >= 2-d."
+ return m[:, ::-1]
+
+def flipud(m):
+ """ returns an array with the columns preserved and rows flipped in
+ the up/down direction. Works on the first dimension of m.
+ """
+ m = asanyarray(m)
+ if m.ndim < 1:
+ raise ValueError, "Input must be >= 1-d."
+ return m[::-1,...]
+
+def rot90(m, k=1):
+ """ returns the array found by rotating m by k*90
+ degrees in the counterclockwise direction. Works on the first two
+ dimensions of m.
+ """
+ m = asanyarray(m)
+ if m.ndim < 2:
+ raise ValueError, "Input must >= 2-d."
+ k = k % 4
+ if k == 0: return m
+ elif k == 1: return fliplr(m).swapaxes(0,1)
+ elif k == 2: return fliplr(flipud(m))
+ else: return fliplr(m.swapaxes(0,1)) # k==3
+
+def eye(N, M=None, k=0, dtype=float):
+ """ eye returns a N-by-M 2-d array where the k-th diagonal is all ones,
+ and everything else is zeros.
+ """
+ if M is None: M = N
+ m = equal(subtract.outer(arange(N), arange(M)),-k)
+ if m.dtype != dtype:
+ m = m.astype(dtype)
+ return m
+
+def diag(v, k=0):
+ """ returns a copy of the the k-th diagonal if v is a 2-d array
+ or returns a 2-d array with v as the k-th diagonal if v is a
+ 1-d array.
+ """
+ v = asarray(v)
+ s = v.shape
+ if len(s)==1:
+ n = s[0]+abs(k)
+ res = zeros((n,n), v.dtype)
+ if (k>=0):
+ i = arange(0,n-k)
+ fi = i+k+i*n
+ else:
+ i = arange(0,n+k)
+ fi = i+(i-k)*n
+ res.flat[fi] = v
+ return res
+ elif len(s)==2:
+ N1,N2 = s
+ if k >= 0:
+ M = min(N1,N2-k)
+ i = arange(0,M)
+ fi = i+k+i*N2
+ else:
+ M = min(N1+k,N2)
+ i = arange(0,M)
+ fi = i + (i-k)*N2
+ return v.flat[fi]
+ else:
+ raise ValueError, "Input must be 1- or 2-d."
+
+def diagflat(v,k=0):
+ try:
+ wrap = v.__array_wrap__
+ except AttributeError:
+ wrap = None
+ v = asarray(v).ravel()
+ s = len(v)
+ n = s + abs(k)
+ res = zeros((n,n), v.dtype)
+ if (k>=0):
+ i = arange(0,n-k)
+ fi = i+k+i*n
+ else:
+ i = arange(0,n+k)
+ fi = i+(i-k)*n
+ res.flat[fi] = v
+ if not wrap:
+ return res
+ return wrap(res)
+
+def tri(N, M=None, k=0, dtype=float):
+ """ returns a N-by-M array where all the diagonals starting from
+ lower left corner up to the k-th are all ones.
+ """
+ if M is None: M = N
+ m = greater_equal(subtract.outer(arange(N), arange(M)),-k)
+ return m.astype(dtype)
+
+def tril(m, k=0):
+ """ returns the elements on and below the k-th diagonal of m. k=0 is the
+ main diagonal, k > 0 is above and k < 0 is below the main diagonal.
+ """
+ m = asanyarray(m)
+ out = multiply(tri(m.shape[0], m.shape[1], k=k, dtype=int),m)
+ return out
+
+def triu(m, k=0):
+ """ returns the elements on and above the k-th diagonal of m. k=0 is the
+ main diagonal, k > 0 is above and k < 0 is below the main diagonal.
+ """
+ m = asanyarray(m)
+ out = multiply((1-tri(m.shape[0], m.shape[1], k-1, int)),m)
+ return out
+
+# borrowed from John Hunter and matplotlib
+def vander(x, N=None):
+ """
+ X = vander(x,N=None)
+
+ The Vandermonde matrix of vector x. The i-th column of X is the
+ the i-th power of x. N is the maximum power to compute; if N is
+ None it defaults to len(x).
+
+ """
+ x = asarray(x)
+ if N is None: N=len(x)
+ X = ones( (len(x),N), x.dtype)
+ for i in range(N-1):
+ X[:,i] = x**(N-i-1)
+ return X
+
+
+def histogram2d(x,y, bins=10, range=None, normed=False, weights=None):
+ """histogram2d(x,y, bins=10, range=None, normed=False) -> H, xedges, yedges
+
+ Compute the 2D histogram from samples x,y.
+
+ :Parameters:
+ - `x,y` : Sample arrays (1D).
+ - `bins` : Number of bins -or- [nbin x, nbin y] -or-
+ [bin edges] -or- [x bin edges, y bin edges].
+ - `range` : A sequence of lower and upper bin edges (default: [min, max]).
+ - `normed` : Boolean, if False, return the number of samples in each bin,
+ if True, returns the density.
+ - `weights` : An array of weights. The weights are normed only if normed
+ is True. Should weights.sum() not equal N, the total bin count \
+ will not be equal to the number of samples.
+
+ :Return:
+ - `hist` : Histogram array.
+ - `xedges, yedges` : Arrays defining the bin edges.
+
+ Example:
+ >>> x = random.randn(100,2)
+ >>> hist2d, xedges, yedges = histogram2d(x, bins = (6, 7))
+
+ :SeeAlso: histogramdd
+ """
+ from numpy import histogramdd
+
+ try:
+ N = len(bins)
+ except TypeError:
+ N = 1
+
+ if N != 1 and N != 2:
+ xedges = yedges = asarray(bins, float)
+ bins = [xedges, yedges]
+ hist, edges = histogramdd([x,y], bins, range, normed, weights)
+ return hist, edges[0], edges[1]