diff options
Diffstat (limited to 'numpy')
| -rwxr-xr-x | numpy/f2py/crackfortran.py | 7 | ||||
| -rw-r--r-- | numpy/f2py/symbolic.py | 141 | ||||
| -rw-r--r-- | numpy/f2py/tests/test_symbolic.py | 38 |
3 files changed, 157 insertions, 29 deletions
diff --git a/numpy/f2py/crackfortran.py b/numpy/f2py/crackfortran.py index 2adca9fba..306665720 100755 --- a/numpy/f2py/crackfortran.py +++ b/numpy/f2py/crackfortran.py @@ -2701,11 +2701,6 @@ def analyzevars(block): 'required' if is_required else 'optional') if v_attr: vars[v]['attrspec'] = v_attr - else: - # n is output or hidden argument, hence it - # will depend on all variables in d - n_deps.extend(coeffs_and_deps) - if coeffs_and_deps is not None: # extend v dependencies with ones specified in attrspec for v, (solver, deps) in coeffs_and_deps.items(): @@ -2716,6 +2711,8 @@ def analyzevars(block): v_deps.extend(aa[7:-1].split(',')) if v_deps: vars[v]['depend'] = list(set(v_deps)) + if n not in v_deps: + n_deps.append(v) elif isstring(vars[n]): if 'charselector' in vars[n]: if '*' in vars[n]['charselector']: diff --git a/numpy/f2py/symbolic.py b/numpy/f2py/symbolic.py index 0a2c3ba04..b747a75f9 100644 --- a/numpy/f2py/symbolic.py +++ b/numpy/f2py/symbolic.py @@ -15,7 +15,6 @@ References: # # TODO: support logical constants (Op.BOOLEAN) # TODO: support logical operators (.AND., ...) -# TODO: support relational operators (<, >, ..., .LT., ...) # TODO: support defined operators (.MYOP., ...) # __all__ = ['Expr'] @@ -50,12 +49,43 @@ class Op(Enum): APPLY = 200 INDEXING = 210 CONCAT = 220 + RELATIONAL = 300 TERMS = 1000 FACTORS = 2000 REF = 3000 DEREF = 3001 +class RelOp(Enum): + """ + Used in Op.RELATIONAL expression to specify the function part. + """ + EQ = 1 + NE = 2 + LT = 3 + LE = 4 + GT = 5 + GE = 6 + + @classmethod + def fromstring(cls, s, language=Language.C): + if language is Language.Fortran: + return {'.eq.': RelOp.EQ, '.ne.': RelOp.NE, + '.lt.': RelOp.LT, '.le.': RelOp.LE, + '.gt.': RelOp.GT, '.ge.': RelOp.GE}[s.lower()] + return {'==': RelOp.EQ, '!=': RelOp.NE, '<': RelOp.LT, + '<=': RelOp.LE, '>': RelOp.GT, '>=': RelOp.GE}[s] + + def tostring(self, language=Language.C): + if language is Language.Fortran: + return {RelOp.EQ: '.eq.', RelOp.NE: '.ne.', + RelOp.LT: '.lt.', RelOp.LE: '.le.', + RelOp.GT: '.gt.', RelOp.GE: '.ge.'}[self] + return {RelOp.EQ: '==', RelOp.NE: '!=', + RelOp.LT: '<', RelOp.LE: '<=', + RelOp.GT: '>', RelOp.GE: '>='}[self] + + class ArithOp(Enum): """ Used in Op.APPLY expression to specify the function part. @@ -77,12 +107,19 @@ class Precedence(Enum): """ Used as Expr.tostring precedence argument. """ - NONE = 0 - TUPLE = 1 - SUM = 2 + ATOM = 0 + POWER = 1 + UNARY = 2 PRODUCT = 3 - POWER = 4 - ATOM = 5 + SUM = 4 + LT = 6 + EQ = 7 + LAND = 11 + LOR = 12 + TERNARY = 13 + ASSIGN = 14 + TUPLE = 15 + NONE = 100 integer_types = (int,) @@ -178,6 +215,9 @@ class Expr: elif op in (Op.REF, Op.DEREF): # data is Expr instance assert isinstance(data, Expr) + elif op is Op.RELATIONAL: + # data is (<relop>, <left>, <right>) + assert isinstance(data, tuple) and len(data) == 3 else: raise NotImplementedError( f'unknown op or missing sanity check: {op}') @@ -341,19 +381,32 @@ class Expr: language=language) for a in self.data] if language is Language.C: - return f'({cond} ? {expr1} : {expr2})' - if language is Language.Python: - return f'({expr1} if {cond} else {expr2})' - if language is Language.Fortran: - return f'merge({expr1}, {expr2}, {cond})' - raise NotImplementedError(f'tostring for {self.op} and {language}') + r = f'({cond} ? {expr1} : {expr2})' + elif language is Language.Python: + r = f'({expr1} if {cond} else {expr2})' + elif language is Language.Fortran: + r = f'merge({expr1}, {expr2}, {cond})' + else: + raise NotImplementedError( + f'tostring for {self.op} and {language}') + precedence = Precedence.ATOM elif self.op is Op.REF: - return '&' + self.data.tostring(language=language) + r = '&' + self.data.tostring(Precedence.UNARY, language=language) + precedence = Precedence.UNARY elif self.op is Op.DEREF: - return '*' + self.data.tostring(language=language) + r = '*' + self.data.tostring(Precedence.UNARY, language=language) + precedence = Precedence.UNARY + elif self.op is Op.RELATIONAL: + rop, left, right = self.data + precedence = (Precedence.EQ if rop in (RelOp.EQ, RelOp.NE) + else Precedence.LT) + left = left.tostring(precedence, language=language) + right = right.tostring(precedence, language=language) + rop = rop.tostring(language=language) + r = f'{left} {rop} {right}' else: raise NotImplementedError(f'tostring for op {self.op}') - if parent_precedence.value > precedence.value: + if parent_precedence.value < precedence.value: # If parent precedence is higher than operand precedence, # operand will be enclosed in parenthesis. return '(' + r + ')' @@ -590,7 +643,11 @@ class Expr: return normalize(Expr(self.op, operands)) if self.op in (Op.REF, Op.DEREF): return normalize(Expr(self.op, self.data.substitute(symbols_map))) - + if self.op is Op.RELATIONAL: + rop, left, right = self.data + left = left.substitute(symbols_map) + right = right.substitute(symbols_map) + return normalize(Expr(self.op, (rop, left, right))) raise NotImplementedError(f'substitute method for {self.op}: {self!r}') def traverse(self, visit, *args, **kwargs): @@ -642,6 +699,11 @@ class Expr: elif self.op in (Op.REF, Op.DEREF): return normalize(Expr(self.op, self.data.traverse(visit, *args, **kwargs))) + elif self.op is Op.RELATIONAL: + rop, left, right = self.data + left = left.traverse(visit, *args, **kwargs) + right = right.traverse(visit, *args, **kwargs) + return normalize(Expr(self.op, (rop, left, right))) raise NotImplementedError(f'traverse method for {self.op}') def contains(self, other): @@ -963,6 +1025,30 @@ def as_deref(expr): return Expr(Op.DEREF, expr) +def as_eq(left, right): + return Expr(Op.RELATIONAL, (RelOp.EQ, left, right)) + + +def as_ne(left, right): + return Expr(Op.RELATIONAL, (RelOp.NE, left, right)) + + +def as_lt(left, right): + return Expr(Op.RELATIONAL, (RelOp.LT, left, right)) + + +def as_le(left, right): + return Expr(Op.RELATIONAL, (RelOp.LE, left, right)) + + +def as_gt(left, right): + return Expr(Op.RELATIONAL, (RelOp.GT, left, right)) + + +def as_ge(left, right): + return Expr(Op.RELATIONAL, (RelOp.GE, left, right)) + + def as_terms(obj): """Return expression as TERMS expression. """ @@ -1257,23 +1343,32 @@ class _FromStringWorker: return as_complex(*self.process(operands)) raise NotImplementedError( f'parsing comma-separated list (context={context}): {r}') - return tuple(self.process(restore(r.split(',')), context)) # ternary operation m = re.match(r'\A([^?]+)[?]([^:]+)[:](.+)\Z', r) if m: assert context == 'expr', context oper, expr1, expr2 = restore(m.groups()) - if 0: - # TODO: enable this when support for boolean - # expressions is fully implemented - oper = self.process(oper) - else: - oper = as_symbol(self.finalize_string(oper)) + oper = self.process(oper) expr1 = self.process(expr1) expr2 = self.process(expr2) return as_ternary(oper, expr1, expr2) + # relational expression + if self.language is Language.Fortran: + m = re.match( + r'\A(.+)\s*[.](eq|ne|lt|le|gt|ge)[.]\s*(.+)\Z', r, re.I) + else: + m = re.match( + r'\A(.+)\s*([=][=]|[!][=]|[<][=]|[<]|[>][=]|[>])\s*(.+)\Z', r) + if m: + left, rop, right = m.groups() + if self.language is Language.Fortran: + rop = '.' + rop + '.' + left, right = self.process(restore((left, right))) + rop = RelOp.fromstring(rop, language=self.language) + return Expr(Op.RELATIONAL, (rop, left, right)) + # keyword argument m = re.match(r'\A(\w[\w\d_]*)\s*[=](.*)\Z', r) if m: diff --git a/numpy/f2py/tests/test_symbolic.py b/numpy/f2py/tests/test_symbolic.py index 61f0d4b3a..52cabac53 100644 --- a/numpy/f2py/tests/test_symbolic.py +++ b/numpy/f2py/tests/test_symbolic.py @@ -5,7 +5,7 @@ from numpy.f2py.symbolic import ( as_terms, as_factors, eliminate_quotes, insert_quotes, fromstring, as_expr, as_apply, as_numer_denom, as_ternary, as_ref, as_deref, - normalize + normalize, as_eq, as_ne, as_lt, as_gt, as_le, as_ge ) from . import util @@ -100,6 +100,13 @@ class TestSymbolic(util.F2PyTest): assert t != u assert hash(t) is not None + e = as_eq(x, y) + f = as_lt(x, y) + assert e.op == Op.RELATIONAL + assert e == e + assert e != f + assert hash(e) is not None + def test_tostring_fortran(self): x = as_symbol('x') y = as_symbol('y') @@ -142,6 +149,12 @@ class TestSymbolic(util.F2PyTest): assert str(Expr(Op.INDEXING, ('f', x))) == 'f[x]' assert str(as_ternary(x, y, z)) == 'merge(y, z, x)' + assert str(as_eq(x, y)) == 'x .eq. y' + assert str(as_ne(x, y)) == 'x .ne. y' + assert str(as_lt(x, y)) == 'x .lt. y' + assert str(as_le(x, y)) == 'x .le. y' + assert str(as_gt(x, y)) == 'x .gt. y' + assert str(as_ge(x, y)) == 'x .ge. y' def test_tostring_c(self): language = Language.C @@ -166,6 +179,12 @@ class TestSymbolic(util.F2PyTest): language=language) == '123 + x + (x - y) / (x + y)' assert as_ternary(x, y, z).tostring(language=language) == '(x ? y : z)' + assert as_eq(x, y).tostring(language=language) == 'x == y' + assert as_ne(x, y).tostring(language=language) == 'x != y' + assert as_lt(x, y).tostring(language=language) == 'x < y' + assert as_le(x, y).tostring(language=language) == 'x <= y' + assert as_gt(x, y).tostring(language=language) == 'x > y' + assert as_ge(x, y).tostring(language=language) == 'x >= y' def test_operations(self): x = as_symbol('x') @@ -240,6 +259,8 @@ class TestSymbolic(util.F2PyTest): assert as_ternary(x, y, z).substitute( {x: y + z}) == as_ternary(y + z, y, z) + assert as_eq(x, y).substitute( + {x: y + z}) == as_eq(y + z, y) def test_fromstring(self): @@ -319,6 +340,20 @@ class TestSymbolic(util.F2PyTest): assert fromstring('*x * *y') == as_deref(x) * as_deref(y) assert fromstring('*x**y') == as_deref(x) * as_deref(y) + assert fromstring('x == y') == as_eq(x, y) + assert fromstring('x != y') == as_ne(x, y) + assert fromstring('x < y') == as_lt(x, y) + assert fromstring('x > y') == as_gt(x, y) + assert fromstring('x <= y') == as_le(x, y) + assert fromstring('x >= y') == as_ge(x, y) + + assert fromstring('x .eq. y', language=Language.Fortran) == as_eq(x, y) + assert fromstring('x .ne. y', language=Language.Fortran) == as_ne(x, y) + assert fromstring('x .lt. y', language=Language.Fortran) == as_lt(x, y) + assert fromstring('x .gt. y', language=Language.Fortran) == as_gt(x, y) + assert fromstring('x .le. y', language=Language.Fortran) == as_le(x, y) + assert fromstring('x .ge. y', language=Language.Fortran) == as_ge(x, y) + def test_traverse(self): x = as_symbol('x') y = as_symbol('y') @@ -340,6 +375,7 @@ class TestSymbolic(util.F2PyTest): assert (x + y + z).traverse(replace_visit) == (2 * z + y) assert (x + f(y, x - z)).traverse( replace_visit) == (z + f(y, as_number(0))) + assert as_eq(x, y).traverse(replace_visit) == as_eq(z, y) # Use traverse to collect symbols, method 1 function_symbols = set() |
