diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-01-05 01:14:35 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-01-05 01:14:35 +0000 |
commit | 2f01cc8b3c368242224f7ff63e1e5343cf890e9c (patch) | |
tree | a39e9c290cea3277ffbc8bb4420aec4374879dcf /numpy/_import_tools.py | |
parent | f115fdf57332da17fd43202a1512b02902a77048 (diff) | |
download | numpy-2f01cc8b3c368242224f7ff63e1e5343cf890e9c.tar.gz |
Move package loader to _import_tools.py
Diffstat (limited to 'numpy/_import_tools.py')
-rw-r--r-- | numpy/_import_tools.py | 265 |
1 files changed, 265 insertions, 0 deletions
diff --git a/numpy/_import_tools.py b/numpy/_import_tools.py index f2feab88a..3139c1d63 100644 --- a/numpy/_import_tools.py +++ b/numpy/_import_tools.py @@ -146,3 +146,268 @@ class PackageImport: print >> sys.stderr, msg return self._format_titles(titles) + +class PackageLoader: + def __init__(self): + """ Manages loading NumPy packages. + """ + + self.parent_frame = frame = sys._getframe(1) + self.parent_name = eval('__name__',frame.f_globals,frame.f_locals) + self.parent_path = eval('__path__',frame.f_globals,frame.f_locals) + if not frame.f_locals.has_key('__all__'): + exec('__all__ = []',frame.f_globals,frame.f_locals) + self.parent_export_names = eval('__all__',frame.f_globals,frame.f_locals) + + self.info_modules = None + self.imported_packages = [] + self.verbose = None + + def _get_info_files(self, package_dir, parent_path, parent_package=None): + """ Return list of (package name,info.py file) from parent_path subdirectories. + """ + from glob import glob + files = glob(os.path.join(parent_path,package_dir,'info.py')) + for info_file in glob(os.path.join(parent_path,package_dir,'info.pyc')): + if info_file[:-1] not in files: + files.append(info_file) + info_files = [] + for info_file in files: + package_name = os.path.dirname(info_file[len(parent_path)+1:])\ + .replace(os.sep,'.') + if parent_package: + package_name = parent_package + '.' + package_name + info_files.append((package_name,info_file)) + info_files.extend(self._get_info_files('*', + os.path.dirname(info_file), + package_name)) + return info_files + + def _init_info_modules(self, packages=None): + """Initialize info_modules = {<package_name>: <package info.py module>}. + """ + import imp + info_files = [] + if packages is None: + for path in self.parent_path: + info_files.extend(self._get_info_files('*',path)) + else: + for package_name in packages: + package_dir = os.path.join(*package_name.split('.')) + for path in self.parent_path: + names_files = self._get_info_files(package_dir, path) + if names_files: + info_files.extend(names_files) + break + else: + self.warn('Package %r does not have info.py file. Ignoring.'\ + % package_name) + + info_modules = self.info_modules + for package_name,info_file in info_files: + if info_modules.has_key(package_name): + continue + fullname = self.parent_name +'.'+ package_name + if info_file[-1]=='c': + filedescriptor = ('.pyc','rb',2) + else: + filedescriptor = ('.py','U',1) + + try: + info_module = imp.load_module(fullname+'.info', + open(info_file,filedescriptor[1]), + info_file, + filedescriptor) + except Exception,msg: + self.error(msg) + info_module = None + + if info_module is None or getattr(info_module,'ignore',False): + info_modules.pop(package_name,None) + else: + self._init_info_modules(getattr(info_module,'depends',[])) + info_modules[package_name] = info_module + + return + + def _get_sorted_names(self): + """ Return package names sorted in the order as they should be + imported due to dependence relations between packages. + """ + + depend_dict = {} + for name,info_module in self.info_modules.items(): + depend_dict[name] = getattr(info_module,'depends',[]) + package_names = [] + + for name in depend_dict.keys(): + if not depend_dict[name]: + package_names.append(name) + del depend_dict[name] + + while depend_dict: + for name, lst in depend_dict.items(): + new_lst = [n for n in lst if depend_dict.has_key(n)] + if not new_lst: + package_names.append(name) + del depend_dict[name] + else: + depend_dict[name] = new_lst + + return package_names + + def __call__(self,*packages, **options): + """Load one or more packages into numpy's top-level namespace. + + Usage: + + This function is intended to shorten the need to import many of numpy's + submodules constantly with statements such as + + import numpy.linalg, numpy.fft, numpy.etc... + + Instead, you can say: + + import numpy + numpy.pkgload('linalg','fft',...) + + or + + numpy.pkgload() + + to load all of them in one call. + + If a name which doesn't exist in numpy's namespace is + given, an exception [[WHAT? ImportError, probably?]] is raised. + [NotImplemented] + + Inputs: + + - the names (one or more strings) of all the numpy modules one wishes to + load into the top-level namespace. + + Optional keyword inputs: + + - verbose - integer specifying verbosity level [default: 0]. + - force - when True, force reloading loaded packages [default: False]. + - postpone - when True, don't load packages [default: False] + + If no input arguments are given, then all of numpy's subpackages are + imported. + + + Outputs: + + The function returns a tuple with all the names of the modules which + were actually imported. [NotImplemented] + + """ + frame = self.parent_frame + self.info_modules = {} + if options.get('force',False): + self.imported_packages = [] + self.verbose = verbose = options.get('verbose',False) + postpone = options.get('postpone',False) + + self._init_info_modules(packages or None) + + self.log('Imports to %r namespace\n----------------------------'\ + % self.parent_name) + + for package_name in self._get_sorted_names(): + if package_name in self.imported_packages: + continue + info_module = self.info_modules[package_name] + global_symbols = getattr(info_module,'global_symbols',[]) + if postpone and not global_symbols: + self.log('__all__.append(%r)' % (package_name)) + if '.' not in package_name: + self.parent_export_names.append(package_name) + continue + + old_object = frame.f_locals.get(package_name,None) + + cmdstr = 'import '+package_name + if self._execcmd(cmdstr): + continue + self.imported_packages.append(package_name) + + if verbose!=-1: + new_object = frame.f_locals.get(package_name) + if old_object is not None and old_object is not new_object: + self.warn('Overwriting %s=%s (was %s)' \ + % (package_name,self._obj2str(new_object), + self._obj2str(old_object))) + + if '.' not in package_name: + self.parent_export_names.append(package_name) + + for symbol in global_symbols: + if symbol=='*': + symbols = eval('getattr(%s,"__all__",None)'\ + % (package_name), + frame.f_globals,frame.f_locals) + if symbols is None: + symbols = eval('dir(%s)' % (package_name), + frame.f_globals,frame.f_locals) + symbols = filter(lambda s:not s.startswith('_'),symbols) + else: + symbols = [symbol] + + if verbose!=-1: + old_objects = {} + for s in symbols: + if frame.f_locals.has_key(s): + old_objects[s] = frame.f_locals[s] + + cmdstr = 'from '+package_name+' import '+symbol + if self._execcmd(cmdstr): + continue + + if verbose!=-1: + for s,old_object in old_objects.items(): + new_object = frame.f_locals[s] + if new_object is not old_object: + self.warn('Overwriting %s=%s (was %s)' \ + % (s,self._obj2repr(new_object), + self._obj2repr(old_object))) + + if symbol=='*': + self.parent_export_names.extend(symbols) + else: + self.parent_export_names.append(symbol) + + return + + def _execcmd(self,cmdstr): + """ Execute command in parent_frame.""" + frame = self.parent_frame + try: + exec (cmdstr, frame.f_globals,frame.f_locals) + except Exception,msg: + self.error('%s -> failed: %s' % (cmdstr,msg)) + return True + else: + self.log('%s -> success' % (cmdstr)) + return + + def _obj2repr(self,obj): + """ Return repr(obj) with""" + module = getattr(obj,'__module__',None) + file = getattr(obj,'__file__',None) + if module is not None: + return repr(obj) + ' from ' + module + if file is not None: + return repr(obj) + ' from ' + file + return repr(obj) + + def log(self,mess): + if self.verbose>1: + print >> sys.stderr, str(mess) + def warn(self,mess): + if self.verbose>=0: + print >> sys.stderr, str(mess) + def error(self,mess): + if self.verbose!=-1: + print >> sys.stderr, str(mess) + |