summaryrefslogtreecommitdiff
path: root/numpy/f2py
diff options
context:
space:
mode:
authorPearu Peterson <pearu.peterson@gmail.com>2008-05-19 22:10:36 +0000
committerPearu Peterson <pearu.peterson@gmail.com>2008-05-19 22:10:36 +0000
commitebecbb1f048e77eea50a99f13e069e4cc81bdd88 (patch)
treeae46b457197fbfb3d88a27a55a9a7012f75d4334 /numpy/f2py
parent399147b0b20e32bbb0af62db9bc1fe1100770063 (diff)
downloadnumpy-ebecbb1f048e77eea50a99f13e069e4cc81bdd88.tar.gz
f2py: Allow expressions that contain max/min calls, be used as dimension specifications. Defined macros min/max that are needed when --lower is used. Typical usage case: real a(min(1,n)).
Diffstat (limited to 'numpy/f2py')
-rw-r--r--numpy/f2py/cfuncs.py2
-rwxr-xr-xnumpy/f2py/crackfortran.py14
2 files changed, 14 insertions, 2 deletions
diff --git a/numpy/f2py/cfuncs.py b/numpy/f2py/cfuncs.py
index 8c6275ae2..a1e8ebff3 100644
--- a/numpy/f2py/cfuncs.py
+++ b/numpy/f2py/cfuncs.py
@@ -228,6 +228,8 @@ cppmacros['PRINTPYOBJERR']="""\
\tfprintf(stderr,\"\\n\");
"""
cppmacros['MINMAX']="""\
+#define max(a,b) ((a > b) ? (a) : (b))
+#define min(a,b) ((a < b) ? (a) : (b))
#ifndef MAX
#define MAX(a,b) ((a > b) ? (a) : (b))
#endif
diff --git a/numpy/f2py/crackfortran.py b/numpy/f2py/crackfortran.py
index 8679b239c..1b3effc84 100755
--- a/numpy/f2py/crackfortran.py
+++ b/numpy/f2py/crackfortran.py
@@ -208,6 +208,7 @@ for n in ['int','double','float','char','short','long','void','case','while',
'struct','static','register','new','break','do','goto','switch',
'continue','else','inline','extern','delete','const','auto',
'len','rank','shape','index','slen','size','_i',
+ 'max', 'min',
'flen','fshape',
'string','complex_double','float_double','stdin','stderr','stdout',
'type','default']:
@@ -1732,7 +1733,7 @@ def myeval(e,g=None,l=None):
r = eval(e,g,l)
if type(r) in [type(0),type(0.0)]:
return r
- raise ValueError,'r=%r' % (r)
+ raise ValueError('r=%r' % (r))
getlincoef_re_1 = re.compile(r'\A\b\w+\b\Z',re.I)
def getlincoef(e,xset): # e = a*x+b ; x in xset
@@ -1745,6 +1746,9 @@ def getlincoef(e,xset): # e = a*x+b ; x in xset
len_e = len(e)
for x in xset:
if len(x)>len_e: continue
+ if re.search(r'\w\s*\([^)]*\b'+x+r'\b', e):
+ # skip function calls having x as an argument, e.g max(1, x)
+ continue
re_1 = re.compile(r'(?P<before>.*?)\b'+x+r'\b(?P<after>.*)',re.I)
m = re_1.match(e)
if m:
@@ -1764,7 +1768,13 @@ def getlincoef(e,xset): # e = a*x+b ; x in xset
ee = '%s(%s)%s'%(m1.group('before'),0.5,m1.group('after'))
m1 = re_1.match(ee)
c = myeval(ee,{},{})
- if (a*0.5+b==c):
+ # computing another point to be sure that expression is linear
+ m1 = re_1.match(e)
+ while m1:
+ ee = '%s(%s)%s'%(m1.group('before'),1.5,m1.group('after'))
+ m1 = re_1.match(ee)
+ c2 = myeval(ee,{},{})
+ if (a*0.5+b==c and a*1.5+b==c2):
return a,b,x
except: pass
break