diff options
author | Travis Oliphant <oliphant@enthought.com> | 2005-09-14 22:28:28 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2005-09-14 22:28:28 +0000 |
commit | 61b48697e440f76b2337c790ec5ca763cd55200b (patch) | |
tree | da64ece2ba0b6b97deb51c36ca320c64102e9baa /scipy/base/numeric.py | |
parent | 575d373479c63a42bc4a729a058da31a74e75d3e (diff) | |
download | numpy-61b48697e440f76b2337c790ec5ca763cd55200b.tar.gz |
Moving things to live under scipy
Diffstat (limited to 'scipy/base/numeric.py')
-rw-r--r-- | scipy/base/numeric.py | 291 |
1 files changed, 291 insertions, 0 deletions
diff --git a/scipy/base/numeric.py b/scipy/base/numeric.py new file mode 100644 index 000000000..05e9f1355 --- /dev/null +++ b/scipy/base/numeric.py @@ -0,0 +1,291 @@ + +import sys +import multiarray +from umath import * +from numerictypes import * + +#import _numpy # for freeze dependency resolution (at least on Mac) + +import types, math + +newaxis = None + +arange = multiarray.arange +array = multiarray.array +zeros = multiarray.zeros +empty = multiarray.empty +fromstring = multiarray.fromstring +fromfile = multiarray.fromfile +frombuffer = multiarray.frombuffer +where = multiarray.where +concatenate = multiarray.concatenate +#def where(condition, x=None, y=None): +# """where(condition,x,y) is shaped like condition and has elements of x and +# y where condition is respectively true or false. +# """ +# if (x is None) or (y is None): +# return nonzero(condition) +# return choose(not_equal(condition, 0), (y, x)) + +def asarray(a, dtype=None): + """asarray(a,dtype=None) returns a as a NumPy array. Unlike array(), + no copy is performed if a is already an array. + """ + return array(a, dtype, copy=0) + +_mode_from_name_dict = {'v': 0, + 's' : 1, + 'f' : 2} + +def _mode_from_name(mode): + if isinstance(mode, type("")): + return _mode_from_name_dict[mode.lower()[0]] + return mode + +def correlate(a,v,mode='valid'): + mode = _mode_from_name(mode) + return multiarray.correlate(a,v,mode) + + +def convolve(a,v,mode='full'): + """Returns the discrete, linear convolution of 1-D + sequences a and v; mode can be 0 (valid), 1 (same), or 2 (full) + to specify size of the resulting sequence. + """ + if (len(v) > len(a)): + a, v = v, a + mode = _mode_from_name(mode) + return correlate(a,asarray(v)[::-1],mode) + +ndarray = multiarray.ndarray +ufunc = type(sin) + +inner = multiarray.inner +dot = multiarray.dot + +def outer(a,b): + """outer(a,b) returns the outer product of two vectors. + result(i,j) = a(i)*b(j) when a and b are vectors + Will accept any arguments that can be made into vectors. + """ + a = asarray(a) + b = asarray(b) + return a.ravel()[:,newaxis]*b.ravel()[newaxis,:] + +def vdot(a, b): + """Returns the dot product of 2 vectors (or anything that can be made into + a vector). NB: this is not the same as `dot`, as it takes the conjugate + of its first argument if complex and always returns a scalar.""" + return dot(asarray(a).ravel().conj(), asarray(b).ravel()) + +# try to import blas optimized dot if available +try: + # importing this changes the dot function for basic 4 types + # to blas-optimized versions. + from blasdot import dot, vdot, inner +except ImportError: + pass + +def _move_axis_to_0(a, axis): + if axis == 0: + return a + n = a.ndim + if axis < 0: + axis += n + axes = range(1, axis+1) + [0,] + range(axis+1, n) + return a.transpose(axes) + +def cross(a, b, axisa=-1, axisb=-1, axisc=-1): + """Return the cross product of two (arrays of) vectors. + + The cross product is performed over the last axis of a and b by default, + and can handle axes with dimensions 2 and 3. For a dimension of 2, + the z-component of the equivalent three-dimensional cross product is + returned. + """ + a = _move_axis_to_0(asarray(a), axisa) + b = _move_axis_to_0(asarray(b), axisb) + msg = "incompatible dimensions for cross product\n"\ + "(dimension must be 2 or 3)" + if (a.shape[0] not in [2,3]) or (b.shape[0] not in [2,3]): + raise ValueError(msg) + if a.shape[0] == 2: + if (b.shape[0] == 2): + cp = a[0]*b[1] - a[1]*b[0] + if cp.ndim == 0: + return cp + else: + return cp.swapaxes(0,axisc) + else: + x = a[1]*b[2] + y = -a[0]*b[2] + z = a[0]*b[1] - a[1]*b[0] + elif a.shape[0] == 3: + if (b.shape[0] == 3): + x = a[1]*b[2] - a[2]*b[1] + y = a[2]*b[0] - a[0]*b[2] + z = a[0]*b[1] - a[1]*b[0] + else: + x = -a[2]*b[1] + y = a[2]*b[0] + z = a[0]*b[1] - a[1]*b[0] + cp = array([x,y,z]) + if cp.ndim == 1: + return cp + else: + return cp.swapaxes(0,axisc) + + +#Use numarray's printing function +from arrayprint import array2string, get_printoptions, set_printoptions + +_typelessdata = [int, float, complex] +if issubclass(intc, pyint): + _typelessdata.append(intc) + +def array_repr(arr, max_line_width=None, precision=None, suppress_small=None): + if arr.size > 0 or arr.shape==(0,): + lst = array2string(arr, max_line_width, precision, suppress_small, + ', ', "array(") + else: # show zero-length shape unless it is (0,) + lst = "[], shape=%s" % (repr(arr.shape),) + typeless = arr.dtype in _typelessdata + + if arr.__class__ is not ndarray: + cName= arr.__class__.__name__ + else: + cName = "array" + if typeless and arr.size: + return cName + "(%s)" % lst + else: + typename=arr.dtype.__name__[:-8] + return cName + "(%s, dtype=%s)" % (lst, typename) + +def array_str(a, max_line_width = None, precision = None, suppress_small = None): + return array2string(a, max_line_width, precision, suppress_small, ' ', "") + +set_string_function = multiarray.set_string_function +set_string_function(array_str, 0) +set_string_function(array_repr, 1) + +little_endian = (sys.byteorder == 'little') + +def indices(dimensions, dtype=intp): + """indices(dimensions,dtype=intp) returns an array representing a grid + of indices with row-only, and column-only variation. + """ + tmp = ones(dimensions, dtype) + lst = [] + for i in range(len(dimensions)): + lst.append( add.accumulate(tmp, i, )-1 ) + return array(lst) + +def fromfunction(function, dimensions): + """fromfunction(function, dimensions) returns an array constructed by + calling function on a tuple of number grids. The function should + accept as many arguments as there are dimensions which is a list of + numbers indicating the length of the desired output for each axis. + """ + args = indices(dimensions) + return function(*args) + + +from cPickle import load, loads +_cload = load +_file = file + +def load(file): + if isinstance(file, type("")): + file = _file(file,"rb") + return _cload(file) + +# These are all essentially abbreviations +# These might wind up in a special abbreviations module + +def ones(shape, dtype=intp, fortran=0): + """ones(shape, dtype=intp) returns an array of the given + dimensions which is initialized to all ones. + """ + a=zeros(shape, dtype, fortran) + a+=1 + ### a[...]=1 -- slower + return a + +def identity(n,dtype=intp): + """identity(n) returns the identity matrix of shape n x n. + """ + a = array([1]+n*[0],dtype=dtype) + b = empty((n,n),dtype=dtype) + b.flat = a + return b + +def allclose (a, b, rtol=1.e-5, atol=1.e-8): + """ allclose(a,b,rtol=1.e-5,atol=1.e-8) + Returns true if all components of a and b are equal + subject to given tolerances. + The relative error rtol must be positive and << 1.0 + The absolute error atol comes into play for those elements + of y that are very small or zero; it says how small x must be also. + """ + x = array(a, copy=0) + y = array(b, copy=0) + d = less(absolute(x-y), atol + rtol * absolute(y)) + return alltrue(ravel(d)) + + +# Now a method.... +##def setflags(arr, write=None, swap=None, uic=None, align=None): +## if not isinstance(arr, ndarray): +## raise ValueError, "first argument must be an array" +## sdict = {} +## if write is not None: +## sdict['WRITEABLE'] = not not write +## if swap is not None: +## sdict['NOTSWAPPED'] = not swap +## if uic is not None: +## if (uic): +## raise ValueError, "Can only set UPDATEIFCOPY flag to False" +## sdict['UPDATEIFCOPY'] = False +## if align is not None: +## sdict['ALIGNED'] = not not align +## arr.flags = sdict + +_errdict = {"ignore":ERR_IGNORE, + "warn":ERR_WARN, + "raise":ERR_RAISE, + "call":ERR_CALL} + +_errdict_rev = {} +for key in _errdict.keys(): + _errdict_rev[_errdict[key]] = key + +def seterr(divide="ignore", over="ignore", under="ignore", invalid="ignore"): + maskvalue = (_errdict[divide] << SHIFT_DIVIDEBYZERO) + \ + (_errdict[over] << SHIFT_OVERFLOW ) + \ + (_errdict[under] << SHIFT_UNDERFLOW) + \ + (_errdict[invalid] << SHIFT_INVALID) + frame = sys._getframe().f_back + frame.f_locals[UFUNC_ERRMASK_NAME] = maskvalue + return + +def geterr(): + frame = sys._getframe().f_back + try: + maskvalue = frame.f_locals[UFUNC_ERRMASK_NAME] + except KeyError: + maskvalue = ERR_DEFAULT + + mask = 3 + res = {} + val = (maskvalue >> SHIFT_DIVIDEBYZERO) & mask + res['divide'] = _errdict_rev[val] + val = (maskvalue >> SHIFT_OVERFLOW) & mask + res['over'] = _errdict_rev[val] + val = (maskvalue >> SHIFT_UNDERFLOW) & mask + res['under'] = _errdict_rev[val] + val = (maskvalue >> SHIFT_INVALID) & mask + res['invalid'] = _errdict_rev[val] + return res + + + |