diff options
-rw-r--r-- | numpy/f2py/lib/statements.py | 217 | ||||
-rw-r--r-- | numpy/f2py/lib/test_parser.py | 78 |
2 files changed, 240 insertions, 55 deletions
diff --git a/numpy/f2py/lib/statements.py b/numpy/f2py/lib/statements.py index 346678654..990fb1de7 100644 --- a/numpy/f2py/lib/statements.py +++ b/numpy/f2py/lib/statements.py @@ -8,11 +8,17 @@ from base_classes import Statement is_name = re.compile(r'\w+\Z').match -def split_comma(line, item): +def split_comma(line, item = None, comma=','): + items = [] + if item is None: + for s in line.split(comma): + s = s.strip() + if not s: continue + items.append(s) + return items newitem = item.copy(line, True) apply_map = newitem.apply_map - items = [] - for s in newitem.get_line().split(','): + for s in newitem.get_line().split(comma): s = apply_map(s).strip() if not s: continue items.append(s) @@ -37,14 +43,20 @@ class StatementWithNamelist(Statement): """ def process_item(self): assert not self.item.has_map() - clsname = self.__class__.__name__.lower() + if hasattr(self,'stmtname'): + clsname = self.stmtname + else: + clsname = self.__class__.__name__.lower() line = self.item.get_line()[len(clsname):].lstrip() if line.startswith('::'): line = line[2:].lstrip() self.items = [s.strip() for s in line.split(',')] return def __str__(self): - clsname = self.__class__.__name__.upper() + if hasattr(self,'stmtname'): + clsname = self.stmtname.upper() + else: + clsname = self.__class__.__name__.upper() s = ', '.join(self.items) if s: s = ' ' + s @@ -1174,28 +1186,68 @@ class Forall(Statement): def process_item(self): line = self.item.get_line()[6:].lstrip() i = line.index(')') - self.specs = line[1:i].strip() + + line0 = line[1:i] line = line[i+1:].lstrip() - stmt = GeneralAssignment(self, self.item.copy(line)) + stmt = GeneralAssignment(self, self.item.copy(line, True)) if stmt.isvalid: self.content = [stmt] else: self.isvalid = False + return + + specs = [] + mask = '' + for l in split_comma(line0,self.item): + j = l.find('=') + if j==-1: + assert not mask,`mask,l` + mask = l + continue + assert j!=-1,`l` + index = l[:j].rstrip() + it = self.item.copy(l[j+1:].lstrip()) + l = it.get_line() + k = l.split(':') + if len(k)==3: + s1, s2, s3 = map(it.apply_map, + [k[0].strip(),k[1].strip(),k[2].strip()]) + else: + assert len(k)==2,`k` + s1, s2 = map(it.apply_map, + [k[0].strip(),k[1].strip()]) + s3 = '1' + specs.append((index,s1,s2,s3)) + + self.specs = specs + self.mask = mask return + def __str__(self): tab = self.get_indent_tab() - return tab + 'FORALL (%s) %s' % (self.specs, str(self.content[0]).lstrip()) + l = [] + for index,s1,s2,s3 in self.specs: + s = '%s = %s : %s' % (index,s1,s2) + if s3!='1': + s += ' : %s' % (s3) + l.append(s) + s = ', '.join(l) + if self.mask: + s += ', ' + self.mask + return tab + 'FORALL (%s) %s' % \ + (s, str(self.content[0]).lstrip()) ForallStmt = Forall class SpecificBinding(Statement): """ - PROCEDURE [ (<interface-name>) ] [ [ , <binding-attr-list> ] :: ] <binding-name> [ => <procedure-name> ] + PROCEDURE [ ( <interface-name> ) ] [ [ , <binding-attr-list> ] :: ] <binding-name> [ => <procedure-name> ] <binding-attr> = PASS [ ( <arg-name> ) ] | NOPASS | NON_OVERRIDABLE | DEFERRED | <access-spec> + <access-spec> = PUBLIC | PRIVATE """ match = re.compile(r'procedure\b',re.I).match def process_item(self): @@ -1206,46 +1258,69 @@ class SpecificBinding(Statement): line = line[i+1:].lstrip() else: name = '' - self.interface_name = name + self.iname = name if line.startswith(','): line = line[1:].lstrip() i = line.find('::') if i != -1: - attrs = line[:i].rstrip() + attrs = split_comma(line[:i], self.item) line = line[i+2:].lstrip() else: - attrs = '' - self.attrs = attrs - self.rest = line + attrs = [] + attrs1 = [] + for attr in attrs: + if is_name(attr): + attr = attr.upper() + else: + i = attr.find('(') + assert i!=-1 and attr.endswith(')'),`attr` + attr = '%s (%s)' % (attr[:i].rstrip().upper(), attr[i+1:-1].strip()) + attrs1.append(attr) + self.attrs = attrs1 + i = line.find('=') + if i==-1: + self.name = line + self.bname = '' + else: + self.name = line[:i].rstrip() + self.bname = line[i+1:].lstrip()[1:].lstrip() return def __str__(self): tab = self.get_indent_tab() s = 'PROCEDURE ' - if self.interface_name: - s += ' (' + self.interface_name + ')' + if self.iname: + s += '(' + self.iname + ') ' if self.attrs: - s += ' , ' + self.attrs + ' :: ' - return tab + s + rest + s += ', ' + ', '.join(self.attrs) + ' :: ' + if self.bname: + s += '%s => %s' % (self.name, self.bname) + else: + s += self.name + return tab + s class GenericBinding(Statement): """ GENERIC [ , <access-spec> ] :: <generic-spec> => <binding-name-list> """ - match = re.compile(r'generic\b.*::.*=.*\Z', re.I).match + match = re.compile(r'generic\b.*::.*=\>.*\Z', re.I).match def process_item(self): line = self.item.get_line()[7:].lstrip() if line.startswith(','): line = line[1:].lstrip() i = line.index('::') - self.specs = line[:i].lstrip() - self.rest = line[i+2:].lstrip() + self.aspec = line[:i].rstrip().upper() + line = line[i+2:].lstrip() + i = line.index('=>') + self.spec = self.item.apply_map(line[:i].rstrip()) + self.items = split_comma(line[i+2:].lstrip()) return + def __str__(self): tab = self.get_indent_tab() s = 'GENERIC' - if self.specs: - s += ', '+self.specs - s += ' :: ' + self.rest + if self.aspec: + s += ', '+self.aspec + s += ' :: ' + self.spec + ' => ' + ', '.join(self.items) return tab + s @@ -1253,6 +1328,7 @@ class FinalBinding(StatementWithNamelist): """ FINAL [ :: ] <final-subroutine-name-list> """ + stmtname = 'final' match = re.compile(r'final\b', re.I).match class Allocatable(Statement): @@ -1264,14 +1340,10 @@ class Allocatable(Statement): line = self.item.get_line()[11:].lstrip() if line.startswith('::'): line = line[2:].lstrip() - items = [] - for s in line.split(','): - s = s.strip() - items.append(s) - self.items = items + self.items = split_comma(line, self.item) return def __str__(self): - return self.get_tab_indent() + 'ALLOCATABLE ' + ', '.join(self.items) + return self.get_indent_tab() + 'ALLOCATABLE ' + ', '.join(self.items) class Asynchronous(StatementWithNamelist): """ @@ -1285,12 +1357,30 @@ class Bind(Statement): <language-binding-spec> = BIND ( C [ , NAME = <scalar-char-initialization-expr> ] ) <bind-entity> = <entity-name> | / <common-block-name> / """ - match = re.compile(r'bind\s*\(.*\)\Z',re.I).match + match = re.compile(r'bind\s*\(.*\)',re.I).match def process_item(self): - self.value = self.item.get_line()[4].lstrip()[1:-1].strip() + line = self.item.get_line()[4:].lstrip() + specs = [] + for spec in specs_split_comma(line[1:line.index(')')].strip(), self.item): + if is_name(spec): + specs.append(spec.upper()) + else: + specs.append(spec) + self.specs = specs + line = line[line.index(')')+1:].lstrip() + if line.startswith('::'): + line = line[2:].lstrip() + items = [] + for item in split_comma(line, self.item): + if item.startswith('/'): + assert item.endswith('/'),`item` + item = '/ ' + item[1:-1].strip() + ' /' + items.append(item) + self.items = items return def __str__(self): - return self.get_indent_tab() + 'BIND (%s)' % (self.value) + return self.get_indent_tab() + 'BIND (%s) %s' %\ + (', '.join(self.specs), ', '.join(self.items)) # IF construct statements @@ -1303,22 +1393,25 @@ class Else(Statement): def process_item(self): item = self.item self.name = item.get_line()[4:].strip() - if self.name and not self.name==self.parent.name: + parent_name = getattr(self.parent,'name','') + if self.name and self.name!=parent_name: message = self.reader.format_message(\ 'WARNING', 'expected if-construct-name %r but got %r, skipping.'\ - % (self.parent.name, self.name), + % (parent_name, self.name), item.span[0],item.span[1]) print >> sys.stderr, message self.isvalid = False return def __str__(self): - return self.get_indent_tab(deindent=True) + 'ELSE ' + self.name + if self.name: + return self.get_indent_tab(deindent=True) + 'ELSE ' + self.name + return self.get_indent_tab(deindent=True) + 'ELSE' class ElseIf(Statement): """ - ELSE IF ( <scalar-logical-expr> ) THEN [<if-construct-name>] + ELSE IF ( <scalar-logical-expr> ) THEN [ <if-construct-name> ] """ match = re.compile(r'else\s*if\s*\(.*\)\s*then\s*\w*\s*\Z',re.I).match @@ -1327,21 +1420,25 @@ class ElseIf(Statement): line = item.get_line()[4:].lstrip()[2:].lstrip() i = line.find(')') assert line[0]=='(' - self.expr = line[1:i] + self.expr = item.apply_map(line[1:i]) self.name = line[i+1:].lstrip()[4:].strip() - if self.name and not self.name==self.parent.name: + parent_name = getattr(self.parent,'name','') + if self.name and self.name!=parent_name: message = self.reader.format_message(\ 'WARNING', 'expected if-construct-name %r but got %r, skipping.'\ - % (self.parent.name, self.name), + % (parent_name, self.name), item.span[0],item.span[1]) - print >> sys.stderr, message + self.show_message(message) self.isvalid = False return def __str__(self): - return self.get_indent_tab(deindent=True) + 'ELSE IF (%s) THEN %s' \ - % (self.expr, self.name) + s = '' + if self.name: + s = ' ' + self.name + return self.get_indent_tab(deindent=True) + 'ELSE IF (%s) THEN%s' \ + % (self.expr, s) # SelectCase construct statements @@ -1357,27 +1454,49 @@ class Case(Statement): """ match = re.compile(r'case\b\s*(\(.*\)|DEFAULT)\s*\w*\Z',re.I).match def process_item(self): - assert self.parent.__class__.__name__=='Select',`self.parent.__class__` + #assert self.parent.__class__.__name__=='Select',`self.parent.__class__` line = self.item.get_line()[4:].lstrip() if line.startswith('('): i = line.find(')') - self.ranges = line[1:i].strip() + items = split_comma(line[1:i].strip(), self.item) line = line[i+1:].lstrip() else: assert line.startswith('default'),`line` - self.ranges = '' + items = [] line = line[7:].lstrip() + for i in range(len(items)): + it = self.item.copy(items[i]) + rl = [] + for r in it.get_line().split(':'): + rl.append(it.apply_map(r.strip())) + items[i] = rl + self.items = items self.name = line - if self.name and not self.name==self.parent.name: + parent_name = getattr(self.parent, 'name', '') + if self.name and self.name!=parent_name: message = self.reader.format_message(\ 'WARNING', 'expected case-construct-name %r but got %r, skipping.'\ - % (self.parent.name, self.name), + % (parent_name, self.name), self.item.span[0],self.item.span[1]) - print >> sys.stderr, message + self.show_message(message) self.isvalid = False return + def __str__(self): + tab = self.get_indent_tab() + s = 'CASE' + if self.items: + l = [] + for item in self.items: + l.append((' : '.join(item)).strip()) + s += ' ( %s )' % (', '.join(l)) + else: + s += ' DEFAULT' + if self.name: + s += ' ' + self.name + return s + # Where construct statements class Where(Statement): diff --git a/numpy/f2py/lib/test_parser.py b/numpy/f2py/lib/test_parser.py index 1765f50ad..1bd8476e2 100644 --- a/numpy/f2py/lib/test_parser.py +++ b/numpy/f2py/lib/test_parser.py @@ -3,18 +3,15 @@ from numpy.testing import * from block_statements import * from readfortran import Line, FortranStringReader -def toLine(line, label=''): + +def parse(cls, line, label=''): if label: line = label + ' : ' + line reader = FortranStringReader(line, True, False) - return reader.next() - -def parse(cls, line, label=''): - item = toLine(line, label=label) + item = reader.next() if not cls.match(item.get_line()): raise ValueError, '%r does not match %s pattern' % (line, cls.__name__) stmt = cls(item, item) - if stmt.isvalid: return str(stmt) raise ValueError, 'parsing %r with %s pattern failed' % (line, cls.__name__) @@ -306,5 +303,74 @@ class test_Statements(NumpyTestCase): assert_equal(parse(Import,'import::a'),'IMPORT a') assert_equal(parse(Import,'import a , b'),'IMPORT a, b') + def check_forall(self): + assert_equal(parse(ForallStmt,'forall (i = 1:n(k,:) : 2) a(i) = i*i*b(i)'), + 'FORALL (i = 1 : n(k,:) : 2) a(i) = i*i*b(i)') + assert_equal(parse(ForallStmt,'forall (i=1:n,j=2:3) a(i) = b(i,i)'), + 'FORALL (i = 1 : n, j = 2 : 3) a(i) = b(i,i)') + assert_equal(parse(ForallStmt,'forall (i=1:n,j=2:3, 1+a(1,2)) a(i) = b(i,i)'), + 'FORALL (i = 1 : n, j = 2 : 3, 1+a(1,2)) a(i) = b(i,i)') + + def check_specificbinding(self): + assert_equal(parse(SpecificBinding,'procedure a'),'PROCEDURE a') + assert_equal(parse(SpecificBinding,'procedure :: a'),'PROCEDURE a') + assert_equal(parse(SpecificBinding,'procedure , NOPASS :: a'),'PROCEDURE , NOPASS :: a') + assert_equal(parse(SpecificBinding,'procedure , public, pass(x ) :: a'),'PROCEDURE , PUBLIC, PASS (x) :: a') + assert_equal(parse(SpecificBinding,'procedure(n) a'),'PROCEDURE (n) a') + assert_equal(parse(SpecificBinding,'procedure(n),pass :: a'), + 'PROCEDURE (n) , PASS :: a') + assert_equal(parse(SpecificBinding,'procedure(n) :: a'), + 'PROCEDURE (n) a') + assert_equal(parse(SpecificBinding,'procedure a= >b'),'PROCEDURE a => b') + assert_equal(parse(SpecificBinding,'procedure(n),pass :: a =>c'), + 'PROCEDURE (n) , PASS :: a => c') + + def check_genericbinding(self): + assert_equal(parse(GenericBinding,'generic :: a=>b'),'GENERIC :: a => b') + assert_equal(parse(GenericBinding,'generic, public :: a=>b'),'GENERIC, PUBLIC :: a => b') + assert_equal(parse(GenericBinding,'generic, public :: a(1,2)=>b ,c'), + 'GENERIC, PUBLIC :: a(1,2) => b, c') + + def check_finalbinding(self): + assert_equal(parse(FinalBinding,'final a'),'FINAL a') + assert_equal(parse(FinalBinding,'final::a'),'FINAL a') + assert_equal(parse(FinalBinding,'final a , b'),'FINAL a, b') + + def check_allocatable(self): + assert_equal(parse(Allocatable,'allocatable a'),'ALLOCATABLE a') + assert_equal(parse(Allocatable,'allocatable :: a'),'ALLOCATABLE a') + assert_equal(parse(Allocatable,'allocatable a (1,2)'),'ALLOCATABLE a (1,2)') + assert_equal(parse(Allocatable,'allocatable a (1,2) ,b'),'ALLOCATABLE a (1,2), b') + + def check_asynchronous(self): + assert_equal(parse(Asynchronous,'asynchronous a'),'ASYNCHRONOUS a') + assert_equal(parse(Asynchronous,'asynchronous::a'),'ASYNCHRONOUS a') + assert_equal(parse(Asynchronous,'asynchronous a , b'),'ASYNCHRONOUS a, b') + + def check_bind(self): + assert_equal(parse(Bind,'bind(c) a'),'BIND (C) a') + assert_equal(parse(Bind,'bind(c) :: a'),'BIND (C) a') + assert_equal(parse(Bind,'bind(c) a ,b'),'BIND (C) a, b') + assert_equal(parse(Bind,'bind(c) /a/'),'BIND (C) / a /') + assert_equal(parse(Bind,'bind(c) /a/ ,b'),'BIND (C) / a /, b') + assert_equal(parse(Bind,'bind(c,name="hey") a'),'BIND (C, NAME = "hey") a') + + def check_else(self): + assert_equal(parse(Else,'else'),'ELSE') + assert_equal(parse(ElseIf,'else if (a) then'),'ELSE IF (a) THEN') + assert_equal(parse(ElseIf,'else if (a.eq.b(1,2)) then'), + 'ELSE IF (a.eq.b(1,2)) THEN') + + def check_case(self): + assert_equal(parse(Case,'case (1)'),'CASE ( 1 )') + assert_equal(parse(Case,'case (1:)'),'CASE ( 1 : )') + assert_equal(parse(Case,'case (:1)'),'CASE ( : 1 )') + assert_equal(parse(Case,'case (1:2)'),'CASE ( 1 : 2 )') + assert_equal(parse(Case,'case (a(1,2))'),'CASE ( a(1,2) )') + assert_equal(parse(Case,'case ("ab")'),'CASE ( "ab" )') + assert_equal(parse(Case,'case default'),'CASE DEFAULT') + assert_equal(parse(Case,'case (1:2 ,3:4)'),'CASE ( 1 : 2, 3 : 4 )') + assert_equal(parse(Case,'case (a(1,:):)'),'CASE ( a(1,:) : )') + if __name__ == "__main__": NumpyTest().run() |