summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/f2py/lib/statements.py217
-rw-r--r--numpy/f2py/lib/test_parser.py78
2 files changed, 240 insertions, 55 deletions
diff --git a/numpy/f2py/lib/statements.py b/numpy/f2py/lib/statements.py
index 346678654..990fb1de7 100644
--- a/numpy/f2py/lib/statements.py
+++ b/numpy/f2py/lib/statements.py
@@ -8,11 +8,17 @@ from base_classes import Statement
is_name = re.compile(r'\w+\Z').match
-def split_comma(line, item):
+def split_comma(line, item = None, comma=','):
+ items = []
+ if item is None:
+ for s in line.split(comma):
+ s = s.strip()
+ if not s: continue
+ items.append(s)
+ return items
newitem = item.copy(line, True)
apply_map = newitem.apply_map
- items = []
- for s in newitem.get_line().split(','):
+ for s in newitem.get_line().split(comma):
s = apply_map(s).strip()
if not s: continue
items.append(s)
@@ -37,14 +43,20 @@ class StatementWithNamelist(Statement):
"""
def process_item(self):
assert not self.item.has_map()
- clsname = self.__class__.__name__.lower()
+ if hasattr(self,'stmtname'):
+ clsname = self.stmtname
+ else:
+ clsname = self.__class__.__name__.lower()
line = self.item.get_line()[len(clsname):].lstrip()
if line.startswith('::'):
line = line[2:].lstrip()
self.items = [s.strip() for s in line.split(',')]
return
def __str__(self):
- clsname = self.__class__.__name__.upper()
+ if hasattr(self,'stmtname'):
+ clsname = self.stmtname.upper()
+ else:
+ clsname = self.__class__.__name__.upper()
s = ', '.join(self.items)
if s:
s = ' ' + s
@@ -1174,28 +1186,68 @@ class Forall(Statement):
def process_item(self):
line = self.item.get_line()[6:].lstrip()
i = line.index(')')
- self.specs = line[1:i].strip()
+
+ line0 = line[1:i]
line = line[i+1:].lstrip()
- stmt = GeneralAssignment(self, self.item.copy(line))
+ stmt = GeneralAssignment(self, self.item.copy(line, True))
if stmt.isvalid:
self.content = [stmt]
else:
self.isvalid = False
+ return
+
+ specs = []
+ mask = ''
+ for l in split_comma(line0,self.item):
+ j = l.find('=')
+ if j==-1:
+ assert not mask,`mask,l`
+ mask = l
+ continue
+ assert j!=-1,`l`
+ index = l[:j].rstrip()
+ it = self.item.copy(l[j+1:].lstrip())
+ l = it.get_line()
+ k = l.split(':')
+ if len(k)==3:
+ s1, s2, s3 = map(it.apply_map,
+ [k[0].strip(),k[1].strip(),k[2].strip()])
+ else:
+ assert len(k)==2,`k`
+ s1, s2 = map(it.apply_map,
+ [k[0].strip(),k[1].strip()])
+ s3 = '1'
+ specs.append((index,s1,s2,s3))
+
+ self.specs = specs
+ self.mask = mask
return
+
def __str__(self):
tab = self.get_indent_tab()
- return tab + 'FORALL (%s) %s' % (self.specs, str(self.content[0]).lstrip())
+ l = []
+ for index,s1,s2,s3 in self.specs:
+ s = '%s = %s : %s' % (index,s1,s2)
+ if s3!='1':
+ s += ' : %s' % (s3)
+ l.append(s)
+ s = ', '.join(l)
+ if self.mask:
+ s += ', ' + self.mask
+ return tab + 'FORALL (%s) %s' % \
+ (s, str(self.content[0]).lstrip())
ForallStmt = Forall
class SpecificBinding(Statement):
"""
- PROCEDURE [ (<interface-name>) ] [ [ , <binding-attr-list> ] :: ] <binding-name> [ => <procedure-name> ]
+ PROCEDURE [ ( <interface-name> ) ] [ [ , <binding-attr-list> ] :: ] <binding-name> [ => <procedure-name> ]
<binding-attr> = PASS [ ( <arg-name> ) ]
| NOPASS
| NON_OVERRIDABLE
| DEFERRED
| <access-spec>
+ <access-spec> = PUBLIC | PRIVATE
"""
match = re.compile(r'procedure\b',re.I).match
def process_item(self):
@@ -1206,46 +1258,69 @@ class SpecificBinding(Statement):
line = line[i+1:].lstrip()
else:
name = ''
- self.interface_name = name
+ self.iname = name
if line.startswith(','):
line = line[1:].lstrip()
i = line.find('::')
if i != -1:
- attrs = line[:i].rstrip()
+ attrs = split_comma(line[:i], self.item)
line = line[i+2:].lstrip()
else:
- attrs = ''
- self.attrs = attrs
- self.rest = line
+ attrs = []
+ attrs1 = []
+ for attr in attrs:
+ if is_name(attr):
+ attr = attr.upper()
+ else:
+ i = attr.find('(')
+ assert i!=-1 and attr.endswith(')'),`attr`
+ attr = '%s (%s)' % (attr[:i].rstrip().upper(), attr[i+1:-1].strip())
+ attrs1.append(attr)
+ self.attrs = attrs1
+ i = line.find('=')
+ if i==-1:
+ self.name = line
+ self.bname = ''
+ else:
+ self.name = line[:i].rstrip()
+ self.bname = line[i+1:].lstrip()[1:].lstrip()
return
def __str__(self):
tab = self.get_indent_tab()
s = 'PROCEDURE '
- if self.interface_name:
- s += ' (' + self.interface_name + ')'
+ if self.iname:
+ s += '(' + self.iname + ') '
if self.attrs:
- s += ' , ' + self.attrs + ' :: '
- return tab + s + rest
+ s += ', ' + ', '.join(self.attrs) + ' :: '
+ if self.bname:
+ s += '%s => %s' % (self.name, self.bname)
+ else:
+ s += self.name
+ return tab + s
class GenericBinding(Statement):
"""
GENERIC [ , <access-spec> ] :: <generic-spec> => <binding-name-list>
"""
- match = re.compile(r'generic\b.*::.*=.*\Z', re.I).match
+ match = re.compile(r'generic\b.*::.*=\>.*\Z', re.I).match
def process_item(self):
line = self.item.get_line()[7:].lstrip()
if line.startswith(','):
line = line[1:].lstrip()
i = line.index('::')
- self.specs = line[:i].lstrip()
- self.rest = line[i+2:].lstrip()
+ self.aspec = line[:i].rstrip().upper()
+ line = line[i+2:].lstrip()
+ i = line.index('=>')
+ self.spec = self.item.apply_map(line[:i].rstrip())
+ self.items = split_comma(line[i+2:].lstrip())
return
+
def __str__(self):
tab = self.get_indent_tab()
s = 'GENERIC'
- if self.specs:
- s += ', '+self.specs
- s += ' :: ' + self.rest
+ if self.aspec:
+ s += ', '+self.aspec
+ s += ' :: ' + self.spec + ' => ' + ', '.join(self.items)
return tab + s
@@ -1253,6 +1328,7 @@ class FinalBinding(StatementWithNamelist):
"""
FINAL [ :: ] <final-subroutine-name-list>
"""
+ stmtname = 'final'
match = re.compile(r'final\b', re.I).match
class Allocatable(Statement):
@@ -1264,14 +1340,10 @@ class Allocatable(Statement):
line = self.item.get_line()[11:].lstrip()
if line.startswith('::'):
line = line[2:].lstrip()
- items = []
- for s in line.split(','):
- s = s.strip()
- items.append(s)
- self.items = items
+ self.items = split_comma(line, self.item)
return
def __str__(self):
- return self.get_tab_indent() + 'ALLOCATABLE ' + ', '.join(self.items)
+ return self.get_indent_tab() + 'ALLOCATABLE ' + ', '.join(self.items)
class Asynchronous(StatementWithNamelist):
"""
@@ -1285,12 +1357,30 @@ class Bind(Statement):
<language-binding-spec> = BIND ( C [ , NAME = <scalar-char-initialization-expr> ] )
<bind-entity> = <entity-name> | / <common-block-name> /
"""
- match = re.compile(r'bind\s*\(.*\)\Z',re.I).match
+ match = re.compile(r'bind\s*\(.*\)',re.I).match
def process_item(self):
- self.value = self.item.get_line()[4].lstrip()[1:-1].strip()
+ line = self.item.get_line()[4:].lstrip()
+ specs = []
+ for spec in specs_split_comma(line[1:line.index(')')].strip(), self.item):
+ if is_name(spec):
+ specs.append(spec.upper())
+ else:
+ specs.append(spec)
+ self.specs = specs
+ line = line[line.index(')')+1:].lstrip()
+ if line.startswith('::'):
+ line = line[2:].lstrip()
+ items = []
+ for item in split_comma(line, self.item):
+ if item.startswith('/'):
+ assert item.endswith('/'),`item`
+ item = '/ ' + item[1:-1].strip() + ' /'
+ items.append(item)
+ self.items = items
return
def __str__(self):
- return self.get_indent_tab() + 'BIND (%s)' % (self.value)
+ return self.get_indent_tab() + 'BIND (%s) %s' %\
+ (', '.join(self.specs), ', '.join(self.items))
# IF construct statements
@@ -1303,22 +1393,25 @@ class Else(Statement):
def process_item(self):
item = self.item
self.name = item.get_line()[4:].strip()
- if self.name and not self.name==self.parent.name:
+ parent_name = getattr(self.parent,'name','')
+ if self.name and self.name!=parent_name:
message = self.reader.format_message(\
'WARNING',
'expected if-construct-name %r but got %r, skipping.'\
- % (self.parent.name, self.name),
+ % (parent_name, self.name),
item.span[0],item.span[1])
print >> sys.stderr, message
self.isvalid = False
return
def __str__(self):
- return self.get_indent_tab(deindent=True) + 'ELSE ' + self.name
+ if self.name:
+ return self.get_indent_tab(deindent=True) + 'ELSE ' + self.name
+ return self.get_indent_tab(deindent=True) + 'ELSE'
class ElseIf(Statement):
"""
- ELSE IF ( <scalar-logical-expr> ) THEN [<if-construct-name>]
+ ELSE IF ( <scalar-logical-expr> ) THEN [ <if-construct-name> ]
"""
match = re.compile(r'else\s*if\s*\(.*\)\s*then\s*\w*\s*\Z',re.I).match
@@ -1327,21 +1420,25 @@ class ElseIf(Statement):
line = item.get_line()[4:].lstrip()[2:].lstrip()
i = line.find(')')
assert line[0]=='('
- self.expr = line[1:i]
+ self.expr = item.apply_map(line[1:i])
self.name = line[i+1:].lstrip()[4:].strip()
- if self.name and not self.name==self.parent.name:
+ parent_name = getattr(self.parent,'name','')
+ if self.name and self.name!=parent_name:
message = self.reader.format_message(\
'WARNING',
'expected if-construct-name %r but got %r, skipping.'\
- % (self.parent.name, self.name),
+ % (parent_name, self.name),
item.span[0],item.span[1])
- print >> sys.stderr, message
+ self.show_message(message)
self.isvalid = False
return
def __str__(self):
- return self.get_indent_tab(deindent=True) + 'ELSE IF (%s) THEN %s' \
- % (self.expr, self.name)
+ s = ''
+ if self.name:
+ s = ' ' + self.name
+ return self.get_indent_tab(deindent=True) + 'ELSE IF (%s) THEN%s' \
+ % (self.expr, s)
# SelectCase construct statements
@@ -1357,27 +1454,49 @@ class Case(Statement):
"""
match = re.compile(r'case\b\s*(\(.*\)|DEFAULT)\s*\w*\Z',re.I).match
def process_item(self):
- assert self.parent.__class__.__name__=='Select',`self.parent.__class__`
+ #assert self.parent.__class__.__name__=='Select',`self.parent.__class__`
line = self.item.get_line()[4:].lstrip()
if line.startswith('('):
i = line.find(')')
- self.ranges = line[1:i].strip()
+ items = split_comma(line[1:i].strip(), self.item)
line = line[i+1:].lstrip()
else:
assert line.startswith('default'),`line`
- self.ranges = ''
+ items = []
line = line[7:].lstrip()
+ for i in range(len(items)):
+ it = self.item.copy(items[i])
+ rl = []
+ for r in it.get_line().split(':'):
+ rl.append(it.apply_map(r.strip()))
+ items[i] = rl
+ self.items = items
self.name = line
- if self.name and not self.name==self.parent.name:
+ parent_name = getattr(self.parent, 'name', '')
+ if self.name and self.name!=parent_name:
message = self.reader.format_message(\
'WARNING',
'expected case-construct-name %r but got %r, skipping.'\
- % (self.parent.name, self.name),
+ % (parent_name, self.name),
self.item.span[0],self.item.span[1])
- print >> sys.stderr, message
+ self.show_message(message)
self.isvalid = False
return
+ def __str__(self):
+ tab = self.get_indent_tab()
+ s = 'CASE'
+ if self.items:
+ l = []
+ for item in self.items:
+ l.append((' : '.join(item)).strip())
+ s += ' ( %s )' % (', '.join(l))
+ else:
+ s += ' DEFAULT'
+ if self.name:
+ s += ' ' + self.name
+ return s
+
# Where construct statements
class Where(Statement):
diff --git a/numpy/f2py/lib/test_parser.py b/numpy/f2py/lib/test_parser.py
index 1765f50ad..1bd8476e2 100644
--- a/numpy/f2py/lib/test_parser.py
+++ b/numpy/f2py/lib/test_parser.py
@@ -3,18 +3,15 @@ from numpy.testing import *
from block_statements import *
from readfortran import Line, FortranStringReader
-def toLine(line, label=''):
+
+def parse(cls, line, label=''):
if label:
line = label + ' : ' + line
reader = FortranStringReader(line, True, False)
- return reader.next()
-
-def parse(cls, line, label=''):
- item = toLine(line, label=label)
+ item = reader.next()
if not cls.match(item.get_line()):
raise ValueError, '%r does not match %s pattern' % (line, cls.__name__)
stmt = cls(item, item)
-
if stmt.isvalid:
return str(stmt)
raise ValueError, 'parsing %r with %s pattern failed' % (line, cls.__name__)
@@ -306,5 +303,74 @@ class test_Statements(NumpyTestCase):
assert_equal(parse(Import,'import::a'),'IMPORT a')
assert_equal(parse(Import,'import a , b'),'IMPORT a, b')
+ def check_forall(self):
+ assert_equal(parse(ForallStmt,'forall (i = 1:n(k,:) : 2) a(i) = i*i*b(i)'),
+ 'FORALL (i = 1 : n(k,:) : 2) a(i) = i*i*b(i)')
+ assert_equal(parse(ForallStmt,'forall (i=1:n,j=2:3) a(i) = b(i,i)'),
+ 'FORALL (i = 1 : n, j = 2 : 3) a(i) = b(i,i)')
+ assert_equal(parse(ForallStmt,'forall (i=1:n,j=2:3, 1+a(1,2)) a(i) = b(i,i)'),
+ 'FORALL (i = 1 : n, j = 2 : 3, 1+a(1,2)) a(i) = b(i,i)')
+
+ def check_specificbinding(self):
+ assert_equal(parse(SpecificBinding,'procedure a'),'PROCEDURE a')
+ assert_equal(parse(SpecificBinding,'procedure :: a'),'PROCEDURE a')
+ assert_equal(parse(SpecificBinding,'procedure , NOPASS :: a'),'PROCEDURE , NOPASS :: a')
+ assert_equal(parse(SpecificBinding,'procedure , public, pass(x ) :: a'),'PROCEDURE , PUBLIC, PASS (x) :: a')
+ assert_equal(parse(SpecificBinding,'procedure(n) a'),'PROCEDURE (n) a')
+ assert_equal(parse(SpecificBinding,'procedure(n),pass :: a'),
+ 'PROCEDURE (n) , PASS :: a')
+ assert_equal(parse(SpecificBinding,'procedure(n) :: a'),
+ 'PROCEDURE (n) a')
+ assert_equal(parse(SpecificBinding,'procedure a= >b'),'PROCEDURE a => b')
+ assert_equal(parse(SpecificBinding,'procedure(n),pass :: a =>c'),
+ 'PROCEDURE (n) , PASS :: a => c')
+
+ def check_genericbinding(self):
+ assert_equal(parse(GenericBinding,'generic :: a=>b'),'GENERIC :: a => b')
+ assert_equal(parse(GenericBinding,'generic, public :: a=>b'),'GENERIC, PUBLIC :: a => b')
+ assert_equal(parse(GenericBinding,'generic, public :: a(1,2)=>b ,c'),
+ 'GENERIC, PUBLIC :: a(1,2) => b, c')
+
+ def check_finalbinding(self):
+ assert_equal(parse(FinalBinding,'final a'),'FINAL a')
+ assert_equal(parse(FinalBinding,'final::a'),'FINAL a')
+ assert_equal(parse(FinalBinding,'final a , b'),'FINAL a, b')
+
+ def check_allocatable(self):
+ assert_equal(parse(Allocatable,'allocatable a'),'ALLOCATABLE a')
+ assert_equal(parse(Allocatable,'allocatable :: a'),'ALLOCATABLE a')
+ assert_equal(parse(Allocatable,'allocatable a (1,2)'),'ALLOCATABLE a (1,2)')
+ assert_equal(parse(Allocatable,'allocatable a (1,2) ,b'),'ALLOCATABLE a (1,2), b')
+
+ def check_asynchronous(self):
+ assert_equal(parse(Asynchronous,'asynchronous a'),'ASYNCHRONOUS a')
+ assert_equal(parse(Asynchronous,'asynchronous::a'),'ASYNCHRONOUS a')
+ assert_equal(parse(Asynchronous,'asynchronous a , b'),'ASYNCHRONOUS a, b')
+
+ def check_bind(self):
+ assert_equal(parse(Bind,'bind(c) a'),'BIND (C) a')
+ assert_equal(parse(Bind,'bind(c) :: a'),'BIND (C) a')
+ assert_equal(parse(Bind,'bind(c) a ,b'),'BIND (C) a, b')
+ assert_equal(parse(Bind,'bind(c) /a/'),'BIND (C) / a /')
+ assert_equal(parse(Bind,'bind(c) /a/ ,b'),'BIND (C) / a /, b')
+ assert_equal(parse(Bind,'bind(c,name="hey") a'),'BIND (C, NAME = "hey") a')
+
+ def check_else(self):
+ assert_equal(parse(Else,'else'),'ELSE')
+ assert_equal(parse(ElseIf,'else if (a) then'),'ELSE IF (a) THEN')
+ assert_equal(parse(ElseIf,'else if (a.eq.b(1,2)) then'),
+ 'ELSE IF (a.eq.b(1,2)) THEN')
+
+ def check_case(self):
+ assert_equal(parse(Case,'case (1)'),'CASE ( 1 )')
+ assert_equal(parse(Case,'case (1:)'),'CASE ( 1 : )')
+ assert_equal(parse(Case,'case (:1)'),'CASE ( : 1 )')
+ assert_equal(parse(Case,'case (1:2)'),'CASE ( 1 : 2 )')
+ assert_equal(parse(Case,'case (a(1,2))'),'CASE ( a(1,2) )')
+ assert_equal(parse(Case,'case ("ab")'),'CASE ( "ab" )')
+ assert_equal(parse(Case,'case default'),'CASE DEFAULT')
+ assert_equal(parse(Case,'case (1:2 ,3:4)'),'CASE ( 1 : 2, 3 : 4 )')
+ assert_equal(parse(Case,'case (a(1,:):)'),'CASE ( a(1,:) : )')
+
if __name__ == "__main__":
NumpyTest().run()