summaryrefslogtreecommitdiff
path: root/numpy/numarray
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-08-10 11:55:33 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-08-10 11:55:33 +0000
commitb772c977e5d4d71c78919ef941858ad438ee4986 (patch)
tree2fbce2c189b29ec2f1deb33e820e70ba3ad56ff8 /numpy/numarray
parent4b1569e2208baf36a5ebd0de0877946bd86b2a38 (diff)
downloadnumpy-b772c977e5d4d71c78919ef941858ad438ee4986.tar.gz
Update C-API to add features needed for numarray compatibility. Output argument added for several functions and clipmode argument added for a few others.
Diffstat (limited to 'numpy/numarray')
-rw-r--r--numpy/numarray/__init__.py7
-rw-r--r--numpy/numarray/compat.py6
-rw-r--r--numpy/numarray/functions.py463
-rw-r--r--numpy/numarray/numerictypes.py1
-rw-r--r--numpy/numarray/session.py349
-rw-r--r--numpy/numarray/ufuncs.py5
6 files changed, 767 insertions, 64 deletions
diff --git a/numpy/numarray/__init__.py b/numpy/numarray/__init__.py
index e48401433..6a902ddee 100644
--- a/numpy/numarray/__init__.py
+++ b/numpy/numarray/__init__.py
@@ -2,21 +2,24 @@ from util import *
from numerictypes import *
from functions import *
from ufuncs import *
+from session import *
import util
import numerictypes
import functions
import ufuncs
import compat
+import session
-__all__ = util.__all__
+__all__ = ['session', 'numerictypes']
+__all__ += util.__all__
__all__ += numerictypes.__all__
__all__ += functions.__all__
__all__ += ufuncs.__all__
__all__ += compat.__all__
+__all__ += session.__all__
del util
-del numerictypes
del functions
del ufuncs
del compat
diff --git a/numpy/numarray/compat.py b/numpy/numarray/compat.py
index c9fd3ef86..e0d13a7c2 100644
--- a/numpy/numarray/compat.py
+++ b/numpy/numarray/compat.py
@@ -1,6 +1,4 @@
-__all__ = ['NewAxis']
+__all__ = ['NewAxis', 'ArrayType']
-from numpy import newaxis
-
-NewAxis = newaxis
+from numpy import newaxis as NewAxis, ndarray as ArrayType
diff --git a/numpy/numarray/functions.py b/numpy/numarray/functions.py
index d549f938f..1f358385c 100644
--- a/numpy/numarray/functions.py
+++ b/numpy/numarray/functions.py
@@ -1,71 +1,80 @@
# missing Numarray defined names (in from numarray import *)
-##__all__ = ['ArrayType', 'CLIP', 'ClassicUnpickler', 'Complex32_fromtype',
-## 'Complex64_fromtype', 'ComplexArray', 'EarlyEOFError', 'Error',
-## 'FileSeekWarning', 'MAX_ALIGN', 'MAX_INT_SIZE', 'MAX_LINE_WIDTH',
-## 'MathDomainError', 'NDArray', 'NewArray', 'NewAxis', 'NumArray',
-## 'NumError', 'NumOverflowError', 'PRECISION', 'Py2NumType',
+##__all__ = ['ClassicUnpickler', 'Complex32_fromtype',
+## 'Complex64_fromtype', 'ComplexArray', 'Error',
+## 'MAX_ALIGN', 'MAX_INT_SIZE', 'MAX_LINE_WIDTH',
+## 'NDArray', 'NewArray', 'NumArray',
+## 'NumError', 'PRECISION', 'Py2NumType',
## 'PyINT_TYPES', 'PyLevel2Type', 'PyNUMERIC_TYPES', 'PyREAL_TYPES',
-## 'RAISE', 'SLOPPY', 'STRICT', 'SUPPRESS_SMALL', 'SizeMismatchError',
-## 'SizeMismatchWarning', 'SuitableBuffer', 'USING_BLAS',
-## 'UnderflowError', 'UsesOpPriority', 'WARN', 'WRAP', 'all',
-## 'allclose', 'alltrue', 'and_', 'any', 'arange', 'argmax',
-## 'argmin', 'argsort', 'around', 'array2list', 'array_equal',
-## 'array_equiv', 'array_repr', 'array_str', 'arrayprint',
-## 'arrayrange', 'average', 'choose', 'clip',
-## 'codegenerator', 'compress', 'concatenate', 'conjugate',
-## 'copy', 'copy_reg', 'diagonal', 'divide_remainder',
-## 'dotblas', 'e', 'explicit_type', 'flush_caches', 'fromfile',
-## 'fromfunction', 'fromlist', 'fromstring', 'generic',
-## 'genericCoercions', 'genericPromotionExclusions', 'genericTypeRank',
-## 'getShape', 'getTypeObject', 'handleError', 'identity', 'indices',
-## 'info', 'innerproduct', 'inputarray', 'isBigEndian',
-## 'kroneckerproduct', 'lexsort', 'libnumarray', 'libnumeric',
-## 'load', 'make_ufuncs', 'math', 'memory',
-## 'numarrayall', 'numarraycore', 'numerictypes', 'numinclude',
-## 'operator', 'os', 'outerproduct', 'pi', 'put', 'putmask',
-## 'pythonTypeMap', 'pythonTypeRank', 'rank', 'repeat',
-## 'reshape', 'resize', 'round', 'safethread', 'save', 'scalarTypeMap',
-## 'scalarTypes', 'searchsorted', 'session', 'shape', 'sign', 'size',
-## 'sometrue', 'sort', 'swapaxes', 'sys', 'take', 'tcode',
-## 'tensormultiply', 'tname', 'trace', 'transpose', 'typeDict',
-## 'typecode', 'typecodes', 'typeconv', 'types', 'ufunc',
-## 'ufuncFactory', 'value', 'ieeemask', 'cumproduct', 'cumsum',
-## 'nonzero']
-
+## 'SUPPRESS_SMALL',
+## 'SuitableBuffer', 'USING_BLAS',
+## 'UsesOpPriority',
+## 'codegenerator', 'generic', 'libnumarray', 'libnumeric',
+## 'make_ufuncs', 'memory',
+## 'numarrayall', 'numarraycore', 'numinclude', 'safethread',
+## 'typecode', 'typecodes', 'typeconv', 'ufunc', 'ufuncFactory',
+## 'ieeemask']
__all__ = ['asarray', 'ones', 'zeros', 'array', 'where']
__all__ += ['vdot', 'dot', 'matrixmultiply', 'ravel', 'indices',
- 'arange', 'concatenate']
+ 'arange', 'concatenate', 'all', 'allclose', 'alltrue', 'and_',
+ 'any', 'argmax', 'argmin', 'argsort', 'around', 'array_equal',
+ 'array_equiv', 'arrayrange', 'array_str', 'array_repr',
+ 'array2list', 'average', 'choose', 'CLIP', 'RAISE', 'WRAP',
+ 'clip', 'compress', 'concatenate', 'copy', 'copy_reg',
+ 'diagonal', 'divide_remainder', 'e', 'explicit_type', 'pi',
+ 'flush_caches', 'fromfile', 'os', 'sys', 'STRICT',
+ 'SLOPPY', 'WARN', 'EarlyEOFError', 'SizeMismatchError',
+ 'SizeMismatchWarning', 'FileSeekWarning', 'fromstring',
+ 'fromfunction', 'fromlist', 'getShape', 'getTypeObject',
+ 'identity', 'indices', 'info', 'innerproduct', 'inputarray',
+ 'isBigEndian', 'kroneckerproduct', 'lexsort', 'math',
+ 'operator', 'outerproduct', 'put', 'putmask', 'rank',
+ 'repeat', 'reshape', 'resize', 'round', 'searchsorted',
+ 'shape', 'size', 'sometrue', 'sort', 'swapaxes', 'take',
+ 'tcode', 'tname', 'tensormultiply', 'trace', 'transpose',
+ 'types', 'value', 'cumsum', 'cumproduct', 'nonzero'
+ ]
-from numpy import dot as matrixmultiply, dot, vdot, ravel, concatenate
+import copy, copy_reg, types
+import os, sys, math, operator
-def array(sequence=None, typecode=None, copy=1, savespace=0,
- type=None, shape=None, dtype=None):
- dtype = type2dtype(typecode, type, dtype)
- if sequence is None:
- if shape is None:
- return None
- if dtype is None:
- dtype = 'l'
- return N.empty(shape, dtype)
- arr = N.array(sequence, dtype, copy=copy)
- if shape is not None:
- arr.shape = shape
- return arr
+from numpy import dot as matrixmultiply, dot, vdot, ravel, concatenate, all,\
+ allclose, any, argmax, argmin, around, argsort, array_equal, array_equiv,\
+ array_str, array_repr, average, CLIP, RAISE, WRAP, clip, concatenate, \
+ diagonal, e, pi, fromfunction, indices, inner as innerproduct, nonzero, \
+ outer as outerproduct, kron as kroneckerproduct, lexsort, putmask, rank, \
+ resize, searchsorted, shape, size, sort, swapaxes, trace, transpose
+import numpy as N
-def asarray(seq, type=None, typecode=None, dtype=None):
- if seq is None:
- return None
- dtype = type2dtype(typecode, type, dtype)
- return N.array(seq, dtype, copy=0)
+from numerictypes import typefrom
+
+isBigEndian = sys.byteorder != 'little'
+value = tcode = 'f'
+tname = 'Float32'
+# If dtype is not None, then it is used
+# If type is not None, then it is used
+# If typecode is not None then it is used
+# If use_default is True, then the default
+# data-type is returned if all are None
+def type2dtype(typecode, type, dtype, use_default=True):
+ if dtype is None:
+ if type is None:
+ if use_default or typecode is not None:
+ dtype = N.dtype(typecode)
+ else:
+ dtype = N.dtype(type)
+ if use_default and dtype is None:
+ dtype = N.dtype(None)
+ return dtype
+
def ones(shape, type=None, typecode=None, dtype=None):
- dtype = type2dtype(typecode, type, dtype)
+ dtype = type2dtype(typecode, type, dtype, 1)
return N.ones(shape, dtype)
def zeros(shape, type=None, typecode=None, dtype=None):
- dtype = type2dtype(typecode, type, dtype)
+ dtype = type2dtype(typecode, type, dtype, 1)
return N.zeros(shape, dtype)
def where(condition, x=None, y=None, out=None):
@@ -83,5 +92,349 @@ def indices(shape, type=None):
def arange(a1, a2=None, stride=1, type=None, shape=None,
typecode=None, dtype=None):
- dtype = type2dtype(typecode, type, dtype)
+ dtype = type2dtype(typecode, type, dtype, 0)
return N.arange(a1, a2, stride, dtype)
+
+arrayrange = arange
+
+def alltrue(x, axis=0):
+ return N.alltrue(x, axis)
+
+def and_(a, b):
+ """Same as a & b
+ """
+ return a & b
+
+def divide_remainder(a, b):
+ a, b = asarray(a), asarray(b)
+ return (a/b,a%b)
+
+def around(array, digits=0, output=None):
+ ret = N.around(array, digits, output)
+ if output is None:
+ return ret
+ return
+
+def array2list(arr):
+ return arr.tolist()
+
+
+def choose(selector, population, outarr=None, clipmode=RAISE):
+ a = N.asarray(selector)
+ ret = a.choose(population, out=outarr, mode=clipmode)
+ if outarr is None:
+ return ret
+ return
+
+def compress(condition, a, axis=0):
+ return N.compress(condition, a, axis)
+
+# only returns a view
+def explicit_type(a):
+ x = a.view()
+ return x
+
+# stub
+def flush_caches():
+ pass
+
+
+class EarlyEOFError(Exception):
+ "Raised in fromfile() if EOF unexpectedly occurs."
+ pass
+
+class SizeMismatchError(Exception):
+ "Raised in fromfile() if file size does not match shape."
+ pass
+
+class SizeMismatchWarning(Warning):
+ "Issued in fromfile() if file size does not match shape."
+ pass
+
+class FileSeekWarning(Warning):
+ "Issued in fromfile() if there is unused data and seek() fails"
+ pass
+
+
+STRICT, SLOPPY, WARN = range(3)
+
+_BLOCKSIZE=1024
+
+# taken and adapted directly from numarray
+def fromfile(infile, type=None, shape=None, sizing=STRICT,
+ typecode=None, dtype=None):
+ if isinstance(infile, (str, unicode)):
+ infile = open(infile, 'rb')
+ dtype = type2dtype(typecode, type, dtype, True)
+ if shape is None:
+ shape = (-1,)
+ if not isinstance(shape, tuple):
+ shape = (shape,)
+
+ if (list(shape).count(-1)>1):
+ raise ValueError("At most one unspecified dimension in shape")
+
+ if -1 not in shape:
+ if sizing != STRICT:
+ raise ValueError("sizing must be STRICT if size complete")
+ arr = N.empty(shape, dtype)
+ bytesleft=arr.nbytes
+ bytesread=0
+ while(bytesleft > _BLOCKSIZE):
+ data = infile.read(_BLOCKSIZE)
+ if len(data) != _BLOCKSIZE:
+ raise EarlyEOFError("Unexpected EOF reading data for size complete array")
+ arr.data[bytesread:bytesread+_BLOCKSIZE]=data
+ bytesread += _BLOCKSIZE
+ bytesleft -= _BLOCKSIZE
+ if bytesleft > 0:
+ data = infile.read(bytesleft)
+ if len(data) != bytesleft:
+ raise EarlyEOFError("Unexpected EOF reading data for size complete array")
+ arr.data[bytesread:bytesread+bytesleft]=data
+ return arr
+
+
+ ##shape is incompletely specified
+ ##read until EOF
+ ##implementation 1: naively use memory blocks
+ ##problematic because memory allocation can be double what is
+ ##necessary (!)
+
+ ##the most common case, namely reading in data from an unchanging
+ ##file whose size may be determined before allocation, should be
+ ##quick -- only one allocation will be needed.
+
+ recsize = dtype.itemsize * N.product([i for i in shape if i != -1])
+ blocksize = max(_BLOCKSIZE/recsize, 1)*recsize
+
+ ##try to estimate file size
+ try:
+ curpos=infile.tell()
+ infile.seek(0,2)
+ endpos=infile.tell()
+ infile.seek(curpos)
+ except (AttributeError, IOError):
+ initsize=blocksize
+ else:
+ initsize=max(1,(endpos-curpos)/recsize)*recsize
+
+ buf = N.newbuffer(initsize)
+
+ bytesread=0
+ while 1:
+ data=infile.read(blocksize)
+ if len(data) != blocksize: ##eof
+ break
+ ##do we have space?
+ if len(buf) < bytesread+blocksize:
+ buf=_resizebuf(buf,len(buf)+blocksize)
+ ## or rather a=resizebuf(a,2*len(a)) ?
+ assert len(buf) >= bytesread+blocksize
+ buf[bytesread:bytesread+blocksize]=data
+ bytesread += blocksize
+
+ if len(data) % recsize != 0:
+ if sizing == STRICT:
+ raise SizeMismatchError("Filesize does not match specified shape")
+ if sizing == WARN:
+ _warnings.warn("Filesize does not match specified shape",
+ SizeMismatchWarning)
+ try:
+ infile.seek(-(len(data) % recsize),1)
+ except AttributeError:
+ _warnings.warn("Could not rewind (no seek support)",
+ FileSeekWarning)
+ except IOError:
+ _warnings.warn("Could not rewind (IOError in seek)",
+ FileSeekWarning)
+ datasize = (len(data)/recsize) * recsize
+ if len(buf) != bytesread+datasize:
+ buf=_resizebuf(buf,bytesread+datasize)
+ buf[bytesread:bytesread+datasize]=data[:datasize]
+ ##deduce shape from len(buf)
+ shape = list(shape)
+ uidx = shape.index(-1)
+ shape[uidx]=len(buf) / recsize
+
+ a = N.ndarray(shape=shape, dtype=type, buffer=buf)
+ if a.dtype.char == '?':
+ N.not_equal(a, 0, a)
+ return a
+
+def fromstring(datastring, type=None, shape=None, typecode=None, dtype=None):
+ dtype = type2dtype(typecode, type, dtype, True)
+ if shape is None:
+ count = -1
+ else:
+ count = N.product(shape)*dtype.itemsize
+ res = N.fromstring(datastring, count=count)
+ if shape is not None:
+ res.shape = shape
+ return res
+
+
+# check_overflow is ignored
+def fromlist(seq, type=None, shape=None, check_overflow=0, typecode=None, dtype=None):
+ dtype = type2dtype(typecode, type, dtype, False)
+ return N.array(seq, dtype)
+
+def array(sequence=None, typecode=None, copy=1, savespace=0,
+ type=None, shape=None, dtype=None):
+ dtype = type2dtype(typecode, type, dtype, 0)
+ if sequence is None:
+ if shape is None:
+ return None
+ if dtype is None:
+ dtype = 'l'
+ return N.empty(shape, dtype)
+ if isinstance(sequence, file):
+ return fromfile(sequence, dtype=dtype, shape=shape)
+ if isinstance(sequence, str):
+ return fromstring(sequence, dtype=dtype, shape=shape)
+ if isinstance(sequence, buffer):
+ arr = N.frombuffer(sequence, dtype=dtype)
+ else:
+ arr = N.array(sequence, dtype, copy=copy)
+ if shape is not None:
+ arr.shape = shape
+ return arr
+
+def asarray(seq, type=None, typecode=None, dtype=None):
+ if isinstance(seq, N.ndarray) and type is None and \
+ typecode is None and dtype is None:
+ return seq
+ return array(seq, type=type, typecode=typecode, copy=0, dtype=dtype)
+
+inputarray = asarray
+
+
+def getTypeObject(sequence, type):
+ if type is not None:
+ return type
+ try:
+ return typefrom(N.array(sequence))
+ except:
+ raise TypeError("Can't determine a reasonable type from sequence")
+
+def getShape(shape, *args):
+ try:
+ if shape is () and not args:
+ return ()
+ if len(args) > 0:
+ shape = (shape, ) + args
+ else:
+ shape = tuple(shape)
+ dummy = N.array(shape)
+ if not issubclass(dummy.dtype.type, N.integer):
+ raise TypeError
+ if len(dummy) > N.MAXDIMS:
+ raise TypeError
+ except:
+ raise TypeError("Shape must be a sequence of integers")
+ return shape
+
+
+def identity(n, type=None, typecode=None, dtype=None):
+ dtype = type2dtype(typecode, type, dtype, True)
+ return N.identity(n, dtype)
+
+def info(obj):
+ print "class: ", type(obj)
+ print "shape: ", obj.shape
+ print "strides: ", obj.strides
+ print "byteoffset: 0"
+ print "bytestride: ", obj.strides[0]
+ print "itemsize: ", obj.itemsize
+ print "aligned: ", obj.flags.isaligned
+ print "contiguous: ", obj.flags.contiguous
+ print "buffer: ", obj.data
+ print "data pointer:", obj._as_paramater_, "(DEBUG ONLY)"
+ print "byteorder: ",
+ endian = obj.dtype.byteorder
+ if endian in ['|','=']:
+ print sys.byteorder
+ elif endian == '>':
+ print "big"
+ else:
+ print "little"
+ print "byteswap: ", not obj.dtype.isnative
+ print "type: ", typefrom(obj)
+
+#clipmode is ignored if axis is not 0 and array is not 1d
+def put(array, indices, values, axis=0, clipmode=RAISE):
+ if not isinstance(array, N.ndarray):
+ raise TypeError("put only works on subclass of ndarray")
+ work = asarray(array)
+ if axis == 0:
+ if array.ndim == 1:
+ work.put(indices, values, clipmode)
+ else:
+ work[indices] = values
+ elif isinstance(axis, (int, long, N.integer)):
+ work = work.swapaxes(0, axis)
+ work[indices] = values
+ work = work.swapaxes(0, axis)
+ else:
+ def_axes = range(work.ndim)
+ for x in axis:
+ def_axes.remove(x)
+ axis = list(axis)+def_axes
+ work = work.transpose(axis)
+ work[indices] = values
+ work = work.transpose(axis)
+
+def repeat(array, repeats, axis=0):
+ return N.repeat(array, repeats, axis)
+
+
+def reshape(array, shape, *args):
+ if len(args) > 0:
+ shape = (shape,) + args
+ return N.reshape(array, shape)
+
+
+import warnings as _warnings
+def round(*args, **keys):
+ _warnings.warn("round() is deprecated. Switch to around()",
+ DeprecationWarning)
+ return around(*args, **keys)
+
+def sometrue(array, axis=0):
+ return N.sometrue(array, axis)
+
+#clipmode is ignored if axis is not an integer
+def take(array, indices, axis=0, outarr=None, clipmode=RAISE):
+ array = N.asarray(array)
+ if isinstance(axis, (int, long, N.integer)):
+ res = array.take(indices, axis, outarr, clipmode)
+ if outarr is None:
+ return res
+ return
+ else:
+ def_axes = range(array.ndim)
+ for x in axis:
+ def_axes.remove(x)
+ axis = list(axis) + def_axes
+ work = array.transpose(axis)
+ res = work[indices]
+ if outarr is None:
+ return res
+ out[...] = res
+ return
+
+def tensormultiply(a1, a2):
+ a1, a2 = N.asarray(a1), N.asarray(a2)
+ if (a1.shape[-1] != a2.shape[0]):
+ raise ValueError("Unmatched dimensions")
+ shape = a1.shape[:-1] + a2.shape[1:]
+ return N.reshape(dot(N.reshape(a1, (-1, a1.shape[-1])),
+ N.reshape(a2, (a2.shape[0],-1))),
+ shape)
+
+def cumsum(a1, axis=0, out=None, type=None, dim=0):
+ return N.asarray(a1).cumsum(axis,dtype=type,out=out)
+
+def cumproduct(a1, axis=0, out=None, type=None, dim=0):
+ return N.asarray(a1).cumprod(axis,dtype=type,out=out)
+
diff --git a/numpy/numarray/numerictypes.py b/numpy/numarray/numerictypes.py
index 549fb8ca2..eadaaeee3 100644
--- a/numpy/numarray/numerictypes.py
+++ b/numpy/numarray/numerictypes.py
@@ -530,7 +530,6 @@ _scipy_dtypechar_inverse = {}
for key,value in _scipy_dtypechar.items():
_scipy_dtypechar_inverse[value] = key
-
def typefrom(obj):
return _scipy_dtypechar_inverse[obj.dtype.char]
diff --git a/numpy/numarray/session.py b/numpy/numarray/session.py
new file mode 100644
index 000000000..25520bd41
--- /dev/null
+++ b/numpy/numarray/session.py
@@ -0,0 +1,349 @@
+""" This module contains a "session saver" which saves the state of a
+NumPy session to a file. At a later time, a different Python
+process can be started and the saved session can be restored using
+load().
+
+The session saver relies on the Python pickle protocol to save and
+restore objects. Objects which are not themselves picklable (e.g.
+modules) can sometimes be saved by "proxy", particularly when they
+are global constants of some kind. If it's not known that proxying
+will work, a warning is issued at save time. If a proxy fails to
+reload properly (e.g. because it's not a global constant), a warning
+is issued at reload time and that name is bound to a _ProxyFailure
+instance which tries to identify what should have been restored.
+
+First, some unfortunate (probably unnecessary) concessions to doctest
+to keep the test run free of warnings.
+
+>>> del _PROXY_ALLOWED
+>>> del copy
+>>> del __builtins__
+
+By default, save() stores every variable in the caller's namespace:
+
+>>> import numpy as na
+>>> a = na.arange(10)
+>>> save()
+
+Alternately, save() can be passed a comma seperated string of variables:
+
+>>> save("a,na")
+
+Alternately, save() can be passed a dictionary, typically one you already
+have lying around somewhere rather than created inline as shown here:
+
+>>> save(dictionary={"a":a,"na":na})
+
+If both variables and a dictionary are specified, the variables to be
+saved are taken from the dictionary.
+
+>>> save(variables="a,na",dictionary={"a":a,"na":na})
+
+Remove names from the session namespace
+
+>>> del a, na
+
+By default, load() restores every variable/object in the session file
+to the caller's namespace.
+
+>>> load()
+
+load() can be passed a comma seperated string of variables to be
+restored from the session file to the caller's namespace:
+
+>>> load("a,na")
+
+load() can also be passed a dictionary to *restore to*:
+
+>>> d = {}
+>>> load(dictionary=d)
+
+load can be passed both a list variables of variables to restore and a
+dictionary to restore to:
+
+>>> load(variables="a,na", dictionary=d)
+
+>>> na.all(a == na.arange(10))
+1
+>>> na.__name__
+'numpy'
+
+NOTE: session saving is faked for modules using module proxy objects.
+Saved modules are re-imported at load time but any "state" in the module
+which is not restored by a simple import is lost.
+
+"""
+
+__all__ = ['load', 'save']
+
+import copy
+import sys
+import pickle
+
+SAVEFILE="session.dat"
+VERBOSE = False # global import-time override
+
+def _foo(): pass
+
+_PROXY_ALLOWED = (type(sys), # module
+ type(_foo), # function
+ type(None)) # None
+
+def _update_proxy_types():
+ """Suppress warnings for known un-picklables with working proxies."""
+ pass
+
+def _unknown(_type):
+ """returns True iff _type isn't known as OK to proxy"""
+ return (_type is not None) and (_type not in _PROXY_ALLOWED)
+
+# caller() from the following article with one extra f_back added.
+# from http://www.python.org/search/hypermail/python-1994q1/0506.html
+# SUBJECT: import ( how to put a symbol into caller's namespace )
+# SENDER: Steven D. Majewski (sdm7g@elvis.med.virginia.edu)
+# DATE: Thu, 24 Mar 1994 15:38:53 -0500
+
+def _caller():
+ """caller() returns the frame object of the function's caller."""
+ try:
+ 1 + '' # make an error happen
+ except: # and return the caller's caller's frame
+ return sys.exc_traceback.tb_frame.f_back.f_back.f_back
+
+def _callers_globals():
+ """callers_globals() returns the global dictionary of the caller."""
+ frame = _caller()
+ return frame.f_globals
+
+def _callers_modules():
+ """returns a list containing the names of all the modules in the caller's
+ global namespace."""
+ g = _callers_globals()
+ mods = []
+ for k,v in g.items():
+ if type(v) == type(sys):
+ mods.append(getattr(v,"__name__"))
+ return mods
+
+def _errout(*args):
+ for a in args:
+ print >>sys.stderr, a,
+ print >>sys.stderr
+
+def _verbose(*args):
+ if VERBOSE:
+ _errout(*args)
+
+class _ProxyingFailure:
+ """Object which is bound to a variable for a proxy pickle which failed to reload"""
+ def __init__(self, module, name, type=None):
+ self.module = module
+ self.name = name
+ self.type = type
+ def __repr__(self):
+ return "ProxyingFailure('%s','%s','%s')" % (self.module, self.name, self.type)
+
+class _ModuleProxy(object):
+ """Proxy object which fakes pickling a module"""
+ def __new__(_type, name, save=False):
+ if save:
+ _verbose("proxying module", name)
+ self = object.__new__(_type)
+ self.name = name
+ else:
+ _verbose("loading module proxy", name)
+ try:
+ self = _loadmodule(name)
+ except ImportError:
+ _errout("warning: module", name,"import failed.")
+ return self
+
+ def __getnewargs__(self):
+ return (self.name,)
+
+ def __getstate__(self):
+ return False
+
+def _loadmodule(module):
+ if not sys.modules.has_key(module):
+ modules = module.split(".")
+ s = ""
+ for i in range(len(modules)):
+ s = ".".join(modules[:i+1])
+ exec "import " + s
+ return sys.modules[module]
+
+class _ObjectProxy(object):
+ """Proxy object which fakes pickling an arbitrary object. Only global
+ constants can really be proxied."""
+ def __new__(_type, module, name, _type2, save=False):
+ if save:
+ if _unknown(_type2):
+ _errout("warning: proxying object", module + "." + name,
+ "of type", _type2, "because it wouldn't pickle...",
+ "it may not reload later.")
+ else:
+ _verbose("proxying object", module, name)
+ self = object.__new__(_type)
+ self.module, self.name, self.type = module, name, str(_type2)
+ else:
+ _verbose("loading object proxy", module, name)
+ try:
+ m = _loadmodule(module)
+ except (ImportError, KeyError):
+ _errout("warning: loading object proxy", module + "." + name,
+ "module import failed.")
+ return _ProxyingFailure(module,name,_type2)
+ try:
+ self = getattr(m, name)
+ except AttributeError:
+ _errout("warning: object proxy", module + "." + name,
+ "wouldn't reload from", m)
+ return _ProxyingFailure(module,name,_type2)
+ return self
+
+ def __getnewargs__(self):
+ return (self.module, self.name, self.type)
+
+ def __getstate__(self):
+ return False
+
+
+class _SaveSession(object):
+ """Tag object which marks the end of a save session and holds the
+ saved session variable names as a list of strings in the same
+ order as the session pickles."""
+ def __new__(_type, keys, save=False):
+ if save:
+ _verbose("saving session", keys)
+ else:
+ _verbose("loading session", keys)
+ self = object.__new__(_type)
+ self.keys = keys
+ return self
+
+ def __getnewargs__(self):
+ return (self.keys,)
+
+ def __getstate__(self):
+ return False
+
+class ObjectNotFound(RuntimeError):
+ pass
+
+def _locate(modules, object):
+ for mname in modules:
+ m = sys.modules[mname]
+ if m:
+ for k,v in m.__dict__.items():
+ if v is object:
+ return m.__name__, k
+ else:
+ raise ObjectNotFound(k)
+
+def save(variables=None, file=SAVEFILE, dictionary=None, verbose=False):
+
+ """saves variables from a numpy session to a file. Variables
+ which won't pickle are "proxied" if possible.
+
+ 'variables' a string of comma seperated variables: e.g. "a,b,c"
+ Defaults to dictionary.keys().
+
+ 'file' a filename or file object for the session file.
+
+ 'dictionary' the dictionary in which to look up the variables.
+ Defaults to the caller's globals()
+
+ 'verbose' print additional debug output when True.
+ """
+
+ global VERBOSE
+ VERBOSE = verbose
+
+ _update_proxy_types()
+
+ if isinstance(file, str):
+ file = open(file, "wb")
+
+ if dictionary is None:
+ dictionary = _callers_globals()
+
+ if variables is None:
+ keys = dictionary.keys()
+ else:
+ keys = variables.split(",")
+
+ source_modules = _callers_modules() + sys.modules.keys()
+
+ p = pickle.Pickler(file, protocol=2)
+
+ _verbose("variables:",keys)
+ for k in keys:
+ v = dictionary[k]
+ _verbose("saving", k, type(v))
+ try: # Try to write an ordinary pickle
+ p.dump(v)
+ _verbose("pickled", k)
+ except (pickle.PicklingError, TypeError, SystemError):
+ # Use proxies for stuff that won't pickle
+ if isinstance(v, type(sys)): # module
+ proxy = _ModuleProxy(v.__name__, save=True)
+ else:
+ try:
+ module, name = _locate(source_modules, v)
+ except ObjectNotFound:
+ _errout("warning: couldn't find object",k,
+ "in any module... skipping.")
+ continue
+ else:
+ proxy = _ObjectProxy(module, name, type(v), save=True)
+ p.dump(proxy)
+ o = _SaveSession(keys, save=True)
+ p.dump(o)
+ file.close()
+
+def load(variables=None, file=SAVEFILE, dictionary=None, verbose=False):
+
+ """load a numpy session from a file and store the specified
+ 'variables' into 'dictionary'.
+
+ 'variables' a string of comma seperated variables: e.g. "a,b,c"
+ Defaults to dictionary.keys().
+
+ 'file' a filename or file object for the session file.
+
+ 'dictionary' the dictionary in which to look up the variables.
+ Defaults to the caller's globals()
+
+ 'verbose' print additional debug output when True.
+ """
+
+ global VERBOSE
+ VERBOSE = verbose
+
+ if isinstance(file, str):
+ file = open(file, "rb")
+ if dictionary is None:
+ dictionary = _callers_globals()
+ values = []
+ p = pickle.Unpickler(file)
+ while 1:
+ o = p.load()
+ if isinstance(o, _SaveSession):
+ session = dict(zip(o.keys, values))
+ _verbose("updating dictionary with session variables.")
+ if variables is None:
+ keys = session.keys()
+ else:
+ keys = variables.split(",")
+ for k in keys:
+ dictionary[k] = session[k]
+ return None
+ else:
+ _verbose("unpickled object", str(o))
+ values.append(o)
+
+def test():
+ import doctest, numpy.numarray.session
+ return doctest.testmod(numpy.numarray.session)
+
diff --git a/numpy/numarray/ufuncs.py b/numpy/numarray/ufuncs.py
index 685b76587..3fb5671ce 100644
--- a/numpy/numarray/ufuncs.py
+++ b/numpy/numarray/ufuncs.py
@@ -8,7 +8,8 @@ __all__ = ['abs', 'absolute', 'add', 'arccos', 'arccosh', 'arcsin', 'arcsinh',
'logical_or', 'logical_xor', 'lshift', 'maximum', 'minimum',
'minus', 'multiply', 'negative', 'not_equal',
'power', 'product', 'remainder', 'rshift', 'sin', 'sinh', 'sqrt',
- 'subtract', 'sum', 'tan', 'tanh', 'true_divide']
+ 'subtract', 'sum', 'tan', 'tanh', 'true_divide',
+ 'conjugate', 'sign']
from numpy import absolute as abs, absolute, add, arccos, arccosh, arcsin, \
arcsinh, arctan, arctan2, arctanh, bitwise_and, invert as bitwise_not, \
@@ -18,4 +19,4 @@ from numpy import absolute as abs, absolute, add, arccos, arccosh, arcsin, \
logical_not, logical_or, logical_xor, left_shift as lshift, \
maximum, minimum, negative as minus, multiply, negative, \
not_equal, power, product, remainder, right_shift as rshift, sin, \
- sinh, sqrt, subtract, sum, tan, tanh, true_divide
+ sinh, sqrt, subtract, sum, tan, tanh, true_divide, conjugate, sign