diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-08-29 18:29:53 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-08-29 18:29:53 +0000 |
commit | f21c08b6b4eed501a54bcb4522877ce716726e75 (patch) | |
tree | 5a0f2dcfb85b668c1769eacb22e97e5026911cd6 /numpy/oldnumeric | |
parent | 63326185c04be2795a7a33028d427e9c8c31c900 (diff) | |
download | numpy-f21c08b6b4eed501a54bcb4522877ce716726e75.tar.gz |
Add a module/script to fix the default axis issue for code already converted to NumPy
Diffstat (limited to 'numpy/oldnumeric')
-rw-r--r-- | numpy/oldnumeric/fix_default_axis.py | 292 |
1 files changed, 292 insertions, 0 deletions
diff --git a/numpy/oldnumeric/fix_default_axis.py b/numpy/oldnumeric/fix_default_axis.py new file mode 100644 index 000000000..d9e1c9473 --- /dev/null +++ b/numpy/oldnumeric/fix_default_axis.py @@ -0,0 +1,292 @@ +""" +This module adds the default axis argument to code which did not specify it +for the functions where the default was changed in NumPy. + +The functions changed are + +add -1 ( all second argument) +====== +nansum +nanmax +nanmin +nanargmax +nanargmin +argmax +argmin +compress 3 + + +add 0 +====== +take 3 +repeat 3 +sum # might cause problems with builtin. +product +sometrue +alltrue +cumsum +cumproduct +average +ptp +cumprod +prod +std +mean +""" +__all__ = ['convertfile', 'convertall', 'converttree', + 'convertfile2','convertall2', 'converttree2'] + +import sys +import os +import re +import glob + + +_args3 = ['compress', 'take', 'repeat'] +_funcm1 = ['nansum', 'nanmax', 'nanmin', 'nanargmax', 'nanargmin', + 'argmax', 'argmin', 'compress'] +_func0 = ['take', 'repeat', 'sum', 'product', 'sometrue', 'alltrue', + 'cumsum', 'cumproduct', 'average', 'ptp', 'cumprod', 'prod', + 'std', 'mean'] + +_all = _func0 + _funcm1 +func_re = {} + +for name in _all: + _astr = r"""%s\s*[(]"""%name + func_re[name] = re.compile(_astr) + + +import string +disallowed = '_' + string.uppercase + string.lowercase + string.digits + +def _add_axis(fstr, name, repl): + alter = 0 + if name in _args3: + allowed_comma = 1 + else: + allowed_comma = 0 + newcode = "" + last = 0 + for obj in func_re[name].finditer(fstr): + nochange = 0 + start, end = obj.span() + if fstr[start-1] in disallowed: + continue + if fstr[start-1] == '.' \ + and fstr[start-6:start-1] != 'numpy' \ + and fstr[start-2:start-1] != 'N' \ + and fstr[start-9:start-1] != 'numarray' \ + and fstr[start-8:start-1] != 'numerix' \ + and fstr[start-8:start-1] != 'Numeric': + continue + if fstr[start-1] in ['\t',' ']: + k = start-2 + while fstr[k] in ['\t',' ']: + k -= 1 + if fstr[k-2:k+1] == 'def' or \ + fstr[k-4:k+1] == 'class': + continue + k = end + stack = 1 + ncommas = 0 + N = len(fstr) + while stack: + if k>=N: + nochange =1 + break + if fstr[k] == ')': + stack -= 1 + elif fstr[k] == '(': + stack += 1 + elif stack == 1 and fstr[k] == ',': + ncommas += 1 + if ncommas > allowed_comma: + nochange = 1 + break + k += 1 + if nochange: + continue + alter += 1 + newcode = "%s%s,%s)" % (newcode, fstr[last:k-1], repl) + last = k + if not alter: + newcode = fstr + else: + newcode = "%s%s" % (newcode, fstr[last:]) + return newcode, alter + +def _import_change(fstr, names): + # Four possibilities + # 1.) import numpy with subsequent use of numpy.<name> + # change this to import numpy.oldnumeric as numpy + # 2.) import numpy as XXXX with subsequent use of + # XXXX.<name> ==> import numpy.oldnumeric as XXXX + # 3.) from numpy import * + # with subsequent use of one of the names + # 4.) from numpy import ..., <name>, ... (could span multiple + # lines. ==> remove all names from list and + # add from numpy.oldnumeric import <name> + + num = 0 + # case 1 + importstr = "import numpy" + ind = fstr.find(importstr) + if (ind > 0): + found = 0 + for name in names: + ind2 = fstr.find("numpy.%s" % name, ind) + if (ind2 > 0): + found = 1 + break + if found: + fstr = "%s%s%s" % (fstr[:ind], "import numpy.oldnumeric as numpy", + fstr[ind+len(importstr):]) + num += 1 + + # case 2 + importre = re.compile("""import numpy as ([A-Za-z0-9_]+)""") + modules = importre.findall(fstr) + if len(modules) > 0: + for module in modules: + found = 0 + for name in names: + ind2 = fstr.find("%s.%s" % (module, name)) + if (ind2 > 0): + found = 1 + break + if found: + importstr = "import numpy as %s" % module + ind = fstr.find(importstr) + fstr = "%s%s%s" % (fstr[:ind], + "import numpy.oldnumeric as %s" % module, + fstr[ind+len(importstr):]) + num += 1 + + # case 3 + importstr = "from numpy import *" + ind = fstr.find(importstr) + if (ind > 0): + found = 0 + for name in names: + ind2 = fstr.find(name, ind) + if (ind2 > 0) and fstr[ind2-1] not in disallowed: + found = 1 + break + if found: + fstr = "%s%s%s" % (fstr[:ind], + "from numpy.oldnumeric import *", + fstr[ind+len(importstr):]) + num += 1 + + # case 4 + ind = 0 + importstr = "from numpy import" + N = len(importstr) + while 1: + ind = fstr.find(importstr, ind) + if (ind < 0): + break + ind += N + ptr = ind+1 + stack = 1 + while stack: + if fstr[ptr] == '\\': + stack += 1 + elif fstr[ptr] == '\n': + stack -= 1 + ptr += 1 + substr = fstr[ind:ptr] + found = 0 + substr = substr.replace('\n',' ') + substr = substr.replace('\\','') + importnames = [x.strip() for x in substr.split(',')] + # determine if any of names are in importnames + addnames = [] + for name in names: + if name in importnames: + importnames.remove(name) + addnames.append(name) + if len(addnames) > 0: + fstr = "%s%s\n%s\n%s" % \ + (fstr[:ind], + "from numpy import %s" % \ + ", ".join(importnames), + "from numpy.oldnumeric import %s" % \ + ", ".join(addnames), + fstr[ptr:]) + num += 1 + + return fstr, num + +def add_axis(fstr, import_change=False): + total = 0 + if not import_change: + for name in _funcm1: + fstr, num = _add_axis(fstr, name, 'axis=-1') + total += num + for name in _func0: + fstr, num = _add_axis(fstr, name, 'axis=0') + total += num + return fstr, total + else: + fstr, num = _import_change(fstr, _funcm1+_func0) + return fstr, num + + +def makenewfile(name, filestr): + fid = file(name, 'w') + fid.write(filestr) + fid.close() + +def getfile(name): + fid = file(name) + filestr = fid.read() + fid.close() + return filestr + +def copyfile(name, fstr): + base, ext = os.path.splitext(name) + makenewfile(base+'.orig', fstr) + return + +def convertfile(filename, import_change=False): + """Convert the filename given from using Numeric to using NumPy + + Copies the file to filename.orig and then over-writes the file + with the updated code + """ + filestr = getfile(filename) + newstr, total = add_axis(filestr, import_change) + if total > 0: + print "Changing ", filename + copyfile(filename, filestr) + makenewfile(filename, newstr) + sys.stdout.flush() + +def fromargs(args): + filename = args[1] + convertfile(filename) + +def convertall(direc=os.path.curdir, import_change=False): + """Convert all .py files in the directory given + + For each file, a backup of <usesnumeric>.py is made as + <usesnumeric>.py.orig. A new file named <usesnumeric>.py + is then written with the updated code. + """ + files = glob.glob(os.path.join(direc,'*.py')) + for afile in files: + convertfile(afile, import_change) + +def _func(arg, dirname, fnames): + convertall(dirname, import_change=arg) + +def converttree(direc=os.path.curdir, import_change=False): + """Convert all .py files in the tree given + + """ + os.path.walk(direc, _func, import_change) + +if __name__ == '__main__': + fromargs(sys.argv) |