summaryrefslogtreecommitdiff
path: root/numpy/lib/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/utils.py')
-rw-r--r--numpy/lib/utils.py579
1 files changed, 579 insertions, 0 deletions
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py
new file mode 100644
index 000000000..048ffafc0
--- /dev/null
+++ b/numpy/lib/utils.py
@@ -0,0 +1,579 @@
+import compiler
+import os
+import sys
+import inspect
+import types
+from numpy.core.numerictypes import obj2sctype, generic
+from numpy.core.multiarray import dtype as _dtype
+from numpy.core import product, ndarray
+
+__all__ = ['issubclass_', 'get_numpy_include', 'issubsctype',
+ 'issubdtype', 'deprecate', 'deprecate_with_doc',
+ 'get_numarray_include',
+ 'get_include', 'info', 'source', 'who',
+ 'byte_bounds', 'may_share_memory', 'safe_eval']
+
+def issubclass_(arg1, arg2):
+ try:
+ return issubclass(arg1, arg2)
+ except TypeError:
+ return False
+
+def issubsctype(arg1, arg2):
+ return issubclass(obj2sctype(arg1), obj2sctype(arg2))
+
+def issubdtype(arg1, arg2):
+ if issubclass_(arg2, generic):
+ return issubclass(_dtype(arg1).type, arg2)
+ mro = _dtype(arg2).type.mro()
+ if len(mro) > 1:
+ val = mro[1]
+ else:
+ val = mro[0]
+ return issubclass(_dtype(arg1).type, val)
+
+def get_include():
+ """Return the directory in the package that contains the numpy/*.h header
+ files.
+
+ Extension modules that need to compile against numpy should use this
+ function to locate the appropriate include directory. Using distutils:
+
+ import numpy
+ Extension('extension_name', ...
+ include_dirs=[numpy.get_include()])
+ """
+ import numpy
+ if numpy.show_config is None:
+ # running from numpy source directory
+ d = os.path.join(os.path.dirname(numpy.__file__), 'core', 'include')
+ else:
+ # using installed numpy core headers
+ import numpy.core as core
+ d = os.path.join(os.path.dirname(core.__file__), 'include')
+ return d
+
+def get_numarray_include(type=None):
+ """Return the directory in the package that contains the numpy/*.h header
+ files.
+
+ Extension modules that need to compile against numpy should use this
+ function to locate the appropriate include directory. Using distutils:
+
+ import numpy
+ Extension('extension_name', ...
+ include_dirs=[numpy.get_numarray_include()])
+ """
+ from numpy.numarray import get_numarray_include_dirs
+ include_dirs = get_numarray_include_dirs()
+ if type is None:
+ return include_dirs[0]
+ else:
+ return include_dirs + [get_include()]
+
+
+if sys.version_info < (2, 4):
+ # Can't set __name__ in 2.3
+ import new
+ def _set_function_name(func, name):
+ func = new.function(func.func_code, func.func_globals,
+ name, func.func_defaults, func.func_closure)
+ return func
+else:
+ def _set_function_name(func, name):
+ func.__name__ = name
+ return func
+
+def deprecate(func, oldname=None, newname=None):
+ """Deprecate old functions.
+ Issues a DeprecationWarning, adds warning to oldname's docstring,
+ rebinds oldname.__name__ and returns new function object.
+
+ Example:
+ oldfunc = deprecate(newfunc, 'oldfunc', 'newfunc')
+
+ """
+
+ import warnings
+ if oldname is None:
+ oldname = func.func_name
+ if newname is None:
+ str1 = "%s is deprecated" % (oldname,)
+ depdoc = "%s is DEPRECATED!" % (oldname,)
+ else:
+ str1 = "%s is deprecated, use %s" % (oldname, newname),
+ depdoc = '%s is DEPRECATED! -- use %s instead' % (oldname, newname,)
+
+ def newfunc(*args,**kwds):
+ warnings.warn(str1, DeprecationWarning)
+ return func(*args, **kwds)
+
+ newfunc = _set_function_name(newfunc, oldname)
+ doc = func.__doc__
+ if doc is None:
+ doc = depdoc
+ else:
+ doc = '\n'.join([depdoc, doc])
+ newfunc.__doc__ = doc
+ try:
+ d = func.__dict__
+ except AttributeError:
+ pass
+ else:
+ newfunc.__dict__.update(d)
+ return newfunc
+
+def deprecate_with_doc(somestr):
+ """Decorator to deprecate functions and provide detailed documentation
+ with 'somestr' that is added to the functions docstring.
+
+ Example:
+ depmsg = 'function numpy.lib.foo has been merged into numpy.lib.io.foobar'
+ @deprecate_with_doc(depmsg)
+ def foo():
+ pass
+
+ """
+
+ def _decorator(func):
+ newfunc = deprecate(func)
+ newfunc.__doc__ += "\n" + somestr
+ return newfunc
+ return _decorator
+
+get_numpy_include = deprecate(get_include, 'get_numpy_include', 'get_include')
+
+
+#--------------------------------------------
+# Determine if two arrays can share memory
+#--------------------------------------------
+
+def byte_bounds(a):
+ """(low, high) are pointers to the end-points of an array
+
+ low is the first byte
+ high is just *past* the last byte
+
+ If the array is not single-segment, then it may not actually
+ use every byte between these bounds.
+
+ The array provided must conform to the Python-side of the array interface
+ """
+ ai = a.__array_interface__
+ a_data = ai['data'][0]
+ astrides = ai['strides']
+ ashape = ai['shape']
+ nd_a = len(ashape)
+ bytes_a = int(ai['typestr'][2:])
+
+ a_low = a_high = a_data
+ if astrides is None: # contiguous case
+ a_high += product(ashape, dtype=int)*bytes_a
+ else:
+ for shape, stride in zip(ashape, astrides):
+ if stride < 0:
+ a_low += (shape-1)*stride
+ else:
+ a_high += (shape-1)*stride
+ a_high += bytes_a
+ return a_low, a_high
+
+
+def may_share_memory(a, b):
+ """Determine if two arrays can share memory
+
+ The memory-bounds of a and b are computed. If they overlap then
+ this function returns True. Otherwise, it returns False.
+
+ A return of True does not necessarily mean that the two arrays
+ share any element. It just means that they *might*.
+ """
+ a_low, a_high = byte_bounds(a)
+ b_low, b_high = byte_bounds(b)
+ if b_low >= a_high or a_low >= b_high:
+ return False
+ return True
+
+#-----------------------------------------------------------------------------
+# Function for output and information on the variables used.
+#-----------------------------------------------------------------------------
+
+
+def who(vardict=None):
+ """Print the Numpy arrays in the given dictionary (or globals() if None).
+ """
+ if vardict is None:
+ frame = sys._getframe().f_back
+ vardict = frame.f_globals
+ sta = []
+ cache = {}
+ for name in vardict.keys():
+ if isinstance(vardict[name],ndarray):
+ var = vardict[name]
+ idv = id(var)
+ if idv in cache.keys():
+ namestr = name + " (%s)" % cache[idv]
+ original=0
+ else:
+ cache[idv] = name
+ namestr = name
+ original=1
+ shapestr = " x ".join(map(str, var.shape))
+ bytestr = str(var.itemsize*product(var.shape))
+ sta.append([namestr, shapestr, bytestr, var.dtype.name,
+ original])
+
+ maxname = 0
+ maxshape = 0
+ maxbyte = 0
+ totalbytes = 0
+ for k in range(len(sta)):
+ val = sta[k]
+ if maxname < len(val[0]):
+ maxname = len(val[0])
+ if maxshape < len(val[1]):
+ maxshape = len(val[1])
+ if maxbyte < len(val[2]):
+ maxbyte = len(val[2])
+ if val[4]:
+ totalbytes += int(val[2])
+
+ if len(sta) > 0:
+ sp1 = max(10,maxname)
+ sp2 = max(10,maxshape)
+ sp3 = max(10,maxbyte)
+ prval = "Name %s Shape %s Bytes %s Type" % (sp1*' ', sp2*' ', sp3*' ')
+ print prval + "\n" + "="*(len(prval)+5) + "\n"
+
+ for k in range(len(sta)):
+ val = sta[k]
+ print "%s %s %s %s %s %s %s" % (val[0], ' '*(sp1-len(val[0])+4),
+ val[1], ' '*(sp2-len(val[1])+5),
+ val[2], ' '*(sp3-len(val[2])+5),
+ val[3])
+ print "\nUpper bound on total bytes = %d" % totalbytes
+ return
+
+#-----------------------------------------------------------------------------
+
+
+# NOTE: pydoc defines a help function which works simliarly to this
+# except it uses a pager to take over the screen.
+
+# combine name and arguments and split to multiple lines of
+# width characters. End lines on a comma and begin argument list
+# indented with the rest of the arguments.
+def _split_line(name, arguments, width):
+ firstwidth = len(name)
+ k = firstwidth
+ newstr = name
+ sepstr = ", "
+ arglist = arguments.split(sepstr)
+ for argument in arglist:
+ if k == firstwidth:
+ addstr = ""
+ else:
+ addstr = sepstr
+ k = k + len(argument) + len(addstr)
+ if k > width:
+ k = firstwidth + 1 + len(argument)
+ newstr = newstr + ",\n" + " "*(firstwidth+2) + argument
+ else:
+ newstr = newstr + addstr + argument
+ return newstr
+
+_namedict = None
+_dictlist = None
+
+# Traverse all module directories underneath globals
+# to see if something is defined
+def _makenamedict(module='numpy'):
+ module = __import__(module, globals(), locals(), [])
+ thedict = {module.__name__:module.__dict__}
+ dictlist = [module.__name__]
+ totraverse = [module.__dict__]
+ while 1:
+ if len(totraverse) == 0:
+ break
+ thisdict = totraverse.pop(0)
+ for x in thisdict.keys():
+ if isinstance(thisdict[x],types.ModuleType):
+ modname = thisdict[x].__name__
+ if modname not in dictlist:
+ moddict = thisdict[x].__dict__
+ dictlist.append(modname)
+ totraverse.append(moddict)
+ thedict[modname] = moddict
+ return thedict, dictlist
+
+def info(object=None,maxwidth=76,output=sys.stdout,toplevel='numpy'):
+ """Get help information for a function, class, or module.
+
+ Example:
+ >>> from numpy import *
+ >>> info(polyval) # doctest: +SKIP
+
+ polyval(p, x)
+
+ Evaluate the polymnomial p at x.
+
+ Description:
+ If p is of length N, this function returns the value:
+ p[0]*(x**N-1) + p[1]*(x**N-2) + ... + p[N-2]*x + p[N-1]
+ """
+ global _namedict, _dictlist
+ import pydoc
+
+ if hasattr(object,'_ppimport_importer') or \
+ hasattr(object, '_ppimport_module'):
+ object = object._ppimport_module
+ elif hasattr(object, '_ppimport_attr'):
+ object = object._ppimport_attr
+
+ if object is None:
+ info(info)
+ elif isinstance(object, ndarray):
+ import numpy.numarray as nn
+ nn.info(object, output=output, numpy=1)
+ elif isinstance(object, str):
+ if _namedict is None:
+ _namedict, _dictlist = _makenamedict(toplevel)
+ numfound = 0
+ objlist = []
+ for namestr in _dictlist:
+ try:
+ obj = _namedict[namestr][object]
+ if id(obj) in objlist:
+ print >> output, "\n *** Repeat reference found in %s *** " % namestr
+ else:
+ objlist.append(id(obj))
+ print >> output, " *** Found in %s ***" % namestr
+ info(obj)
+ print >> output, "-"*maxwidth
+ numfound += 1
+ except KeyError:
+ pass
+ if numfound == 0:
+ print >> output, "Help for %s not found." % object
+ else:
+ print >> output, "\n *** Total of %d references found. ***" % numfound
+
+ elif inspect.isfunction(object):
+ name = object.func_name
+ arguments = inspect.formatargspec(*inspect.getargspec(object))
+
+ if len(name+arguments) > maxwidth:
+ argstr = _split_line(name, arguments, maxwidth)
+ else:
+ argstr = name + arguments
+
+ print >> output, " " + argstr + "\n"
+ print >> output, inspect.getdoc(object)
+
+ elif inspect.isclass(object):
+ name = object.__name__
+ arguments = "()"
+ try:
+ if hasattr(object, '__init__'):
+ arguments = inspect.formatargspec(*inspect.getargspec(object.__init__.im_func))
+ arglist = arguments.split(', ')
+ if len(arglist) > 1:
+ arglist[1] = "("+arglist[1]
+ arguments = ", ".join(arglist[1:])
+ except:
+ pass
+
+ if len(name+arguments) > maxwidth:
+ argstr = _split_line(name, arguments, maxwidth)
+ else:
+ argstr = name + arguments
+
+ print >> output, " " + argstr + "\n"
+ doc1 = inspect.getdoc(object)
+ if doc1 is None:
+ if hasattr(object,'__init__'):
+ print >> output, inspect.getdoc(object.__init__)
+ else:
+ print >> output, inspect.getdoc(object)
+
+ methods = pydoc.allmethods(object)
+ if methods != []:
+ print >> output, "\n\nMethods:\n"
+ for meth in methods:
+ if meth[0] == '_':
+ continue
+ thisobj = getattr(object, meth, None)
+ if thisobj is not None:
+ methstr, other = pydoc.splitdoc(inspect.getdoc(thisobj) or "None")
+ print >> output, " %s -- %s" % (meth, methstr)
+
+ elif type(object) is types.InstanceType: ## check for __call__ method
+ print >> output, "Instance of class: ", object.__class__.__name__
+ print >> output
+ if hasattr(object, '__call__'):
+ arguments = inspect.formatargspec(*inspect.getargspec(object.__call__.im_func))
+ arglist = arguments.split(', ')
+ if len(arglist) > 1:
+ arglist[1] = "("+arglist[1]
+ arguments = ", ".join(arglist[1:])
+ else:
+ arguments = "()"
+
+ if hasattr(object,'name'):
+ name = "%s" % object.name
+ else:
+ name = "<name>"
+ if len(name+arguments) > maxwidth:
+ argstr = _split_line(name, arguments, maxwidth)
+ else:
+ argstr = name + arguments
+
+ print >> output, " " + argstr + "\n"
+ doc = inspect.getdoc(object.__call__)
+ if doc is not None:
+ print >> output, inspect.getdoc(object.__call__)
+ print >> output, inspect.getdoc(object)
+
+ else:
+ print >> output, inspect.getdoc(object)
+
+ elif inspect.ismethod(object):
+ name = object.__name__
+ arguments = inspect.formatargspec(*inspect.getargspec(object.im_func))
+ arglist = arguments.split(', ')
+ if len(arglist) > 1:
+ arglist[1] = "("+arglist[1]
+ arguments = ", ".join(arglist[1:])
+ else:
+ arguments = "()"
+
+ if len(name+arguments) > maxwidth:
+ argstr = _split_line(name, arguments, maxwidth)
+ else:
+ argstr = name + arguments
+
+ print >> output, " " + argstr + "\n"
+ print >> output, inspect.getdoc(object)
+
+ elif hasattr(object, '__doc__'):
+ print >> output, inspect.getdoc(object)
+
+
+def source(object, output=sys.stdout):
+ """Write source for this object to output.
+ """
+ try:
+ print >> output, "In file: %s\n" % inspect.getsourcefile(object)
+ print >> output, inspect.getsource(object)
+ except:
+ print >> output, "Not available for this object."
+
+#-----------------------------------------------------------------------------
+
+# The following SafeEval class and company are adapted from Michael Spencer's
+# ASPN Python Cookbook recipe:
+# http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/364469
+# Accordingly it is mostly Copyright 2006 by Michael Spencer.
+# The recipe, like most of the other ASPN Python Cookbook recipes was made
+# available under the Python license.
+# http://www.python.org/license
+
+# It has been modified to:
+# * handle unary -/+
+# * support True/False/None
+# * raise SyntaxError instead of a custom exception.
+
+class SafeEval(object):
+
+ def visit(self, node, **kw):
+ cls = node.__class__
+ meth = getattr(self,'visit'+cls.__name__,self.default)
+ return meth(node, **kw)
+
+ def default(self, node, **kw):
+ raise SyntaxError("Unsupported source construct: %s" % node.__class__)
+
+ def visitExpression(self, node, **kw):
+ for child in node.getChildNodes():
+ return self.visit(child, **kw)
+
+ def visitConst(self, node, **kw):
+ return node.value
+
+ def visitDict(self, node,**kw):
+ return dict([(self.visit(k),self.visit(v)) for k,v in node.items])
+
+ def visitTuple(self, node, **kw):
+ return tuple([self.visit(i) for i in node.nodes])
+
+ def visitList(self, node, **kw):
+ return [self.visit(i) for i in node.nodes]
+
+ def visitUnaryAdd(self, node, **kw):
+ return +self.visit(node.getChildNodes()[0])
+
+ def visitUnarySub(self, node, **kw):
+ return -self.visit(node.getChildNodes()[0])
+
+ def visitName(self, node, **kw):
+ if node.name == 'False':
+ return False
+ elif node.name == 'True':
+ return True
+ elif node.name == 'None':
+ return None
+ else:
+ raise SyntaxError("Unknown name: %s" % node.name)
+
+def safe_eval(source):
+ """ Evaluate a string containing a Python literal expression without
+ allowing the execution of arbitrary non-literal code.
+
+ Parameters
+ ----------
+ source : str
+
+ Returns
+ -------
+ obj : object
+
+ Raises
+ ------
+ SyntaxError if the code is invalid Python expression syntax or if it
+ contains non-literal code.
+
+ Examples
+ --------
+ >>> from numpy.lib.utils import safe_eval
+ >>> safe_eval('1')
+ 1
+ >>> safe_eval('[1, 2, 3]')
+ [1, 2, 3]
+ >>> safe_eval('{"foo": ("bar", 10.0)}')
+ {'foo': ('bar', 10.0)}
+ >>> safe_eval('import os')
+ Traceback (most recent call last):
+ ...
+ SyntaxError: invalid syntax
+ >>> safe_eval('open("/home/user/.ssh/id_dsa").read()')
+ Traceback (most recent call last):
+ ...
+ SyntaxError: Unsupported source construct: compiler.ast.CallFunc
+ >>> safe_eval('dict')
+ Traceback (most recent call last):
+ ...
+ SyntaxError: Unknown name: dict
+ """
+ walker = SafeEval()
+ try:
+ ast = compiler.parse(source, "eval")
+ except SyntaxError, err:
+ raise
+ try:
+ return walker.visit(ast)
+ except SyntaxError, err:
+ raise
+
+#-----------------------------------------------------------------------------
+
+