summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xnumpy/f2py/crackfortran.py4
-rw-r--r--numpy/f2py/symbolic.py449
-rw-r--r--numpy/f2py/tests/test_symbolic.py67
3 files changed, 348 insertions, 172 deletions
diff --git a/numpy/f2py/crackfortran.py b/numpy/f2py/crackfortran.py
index 8cb1f73d2..2adca9fba 100755
--- a/numpy/f2py/crackfortran.py
+++ b/numpy/f2py/crackfortran.py
@@ -2609,8 +2609,8 @@ def analyzevars(block):
# not defined in block['vars']. Here we assume
# these correspond to Fortran/C intrinsic
# functions or that are defined by other
- # means. We'll let the compiler to validate
- # the definiteness of such symbols.
+ # means. We'll let the compiler validate the
+ # definiteness of such symbols.
dimension_exprs[d] = solver_and_deps
vars[n]['dimension'].append(d)
diff --git a/numpy/f2py/symbolic.py b/numpy/f2py/symbolic.py
index 56faa309c..0a2c3ba04 100644
--- a/numpy/f2py/symbolic.py
+++ b/numpy/f2py/symbolic.py
@@ -46,11 +46,14 @@ class Op(Enum):
STRING = 20
ARRAY = 30
SYMBOL = 40
+ TERNARY = 100
APPLY = 200
INDEXING = 210
CONCAT = 220
TERMS = 1000
FACTORS = 2000
+ REF = 3000
+ DEREF = 3001
class ArithOp(Enum):
@@ -114,10 +117,10 @@ class Expr:
"""
@staticmethod
- def parse(s):
+ def parse(s, language=Language.C):
"""Parse a Fortran expression to a Expr.
"""
- return fromstring(s)
+ return fromstring(s, language=language)
def __init__(self, op, data):
assert isinstance(op, Op)
@@ -169,6 +172,12 @@ class Expr:
assert isinstance(data, tuple) and len(data) == 2
# function is any hashable object
assert hash(data[0]) is not None
+ elif op is Op.TERNARY:
+ # data is (<cond>, <expr1>, <expr2>)
+ assert isinstance(data, tuple) and len(data) == 3
+ elif op in (Op.REF, Op.DEREF):
+ # data is Expr instance
+ assert isinstance(data, Expr)
else:
raise NotImplementedError(
f'unknown op or missing sanity check: {op}')
@@ -327,6 +336,21 @@ class Expr:
for arg in self.data]
r = " // ".join(args)
precedence = Precedence.PRODUCT
+ elif self.op is Op.TERNARY:
+ cond, expr1, expr2 = [a.tostring(Precedence.TUPLE,
+ language=language)
+ for a in self.data]
+ if language is Language.C:
+ return f'({cond} ? {expr1} : {expr2})'
+ if language is Language.Python:
+ return f'({expr1} if {cond} else {expr2})'
+ if language is Language.Fortran:
+ return f'merge({expr1}, {expr2}, {cond})'
+ raise NotImplementedError(f'tostring for {self.op} and {language}')
+ elif self.op is Op.REF:
+ return '&' + self.data.tostring(language=language)
+ elif self.op is Op.DEREF:
+ return '*' + self.data.tostring(language=language)
else:
raise NotImplementedError(f'tostring for op {self.op}')
if parent_precedence.value > precedence.value:
@@ -561,6 +585,12 @@ class Expr:
func = func.substitute(symbols_map)
args = tuple(a.substitute(symbols_map) for a in self.data[1:])
return normalize(Expr(self.op, (func,) + args))
+ if self.op is Op.TERNARY:
+ operands = tuple(a.substitute(symbols_map) for a in self.data)
+ return normalize(Expr(self.op, operands))
+ if self.op in (Op.REF, Op.DEREF):
+ return normalize(Expr(self.op, self.data.substitute(symbols_map)))
+
raise NotImplementedError(f'substitute method for {self.op}: {self!r}')
def traverse(self, visit, *args, **kwargs):
@@ -579,7 +609,7 @@ class Expr:
if self.op in (Op.INTEGER, Op.REAL, Op.STRING, Op.SYMBOL):
return self
- elif self.op in (Op.COMPLEX, Op.ARRAY, Op.CONCAT):
+ elif self.op in (Op.COMPLEX, Op.ARRAY, Op.CONCAT, Op.TERNARY):
return normalize(Expr(self.op, tuple(
item.traverse(visit, *args, **kwargs)
for item in self.data)))
@@ -609,6 +639,9 @@ class Expr:
indices = tuple(index.traverse(visit, *args, **kwargs)
for index in self.data[1:])
return normalize(Expr(self.op, (obj,) + indices))
+ elif self.op in (Op.REF, Op.DEREF):
+ return normalize(Expr(self.op,
+ self.data.traverse(visit, *args, **kwargs)))
raise NotImplementedError(f'traverse method for {self.op}')
def contains(self, other):
@@ -813,6 +846,13 @@ def normalize(obj):
if len(lst) == 1:
return lst[0]
return Expr(Op.CONCAT, tuple(lst))
+
+ if obj.op is Op.TERNARY:
+ cond, expr1, expr2 = map(normalize, obj.data)
+ if cond.op is Op.INTEGER:
+ return expr1 if cond.data[0] else expr2
+ return Expr(Op.TERNARY, (cond, expr1, expr2))
+
return obj
@@ -905,6 +945,24 @@ def as_apply(func, *args, **kwargs):
dict((k, as_expr(v)) for k, v in kwargs.items())))
+def as_ternary(cond, expr1, expr2):
+ """Return object as TERNARY expression (cond?expr1:expr2).
+ """
+ return Expr(Op.TERNARY, (cond, expr1, expr2))
+
+
+def as_ref(expr):
+ """Return object as referencing expression.
+ """
+ return Expr(Op.REF, expr)
+
+
+def as_deref(expr):
+ """Return object as dereferencing expression.
+ """
+ return Expr(Op.DEREF, expr)
+
+
def as_terms(obj):
"""Return expression as TERMS expression.
"""
@@ -967,7 +1025,8 @@ def as_numer_denom(obj):
"""
if isinstance(obj, Expr):
obj = normalize(obj)
- if obj.op in (Op.INTEGER, Op.REAL, Op.COMPLEX, Op.SYMBOL, Op.INDEXING):
+ if obj.op in (Op.INTEGER, Op.REAL, Op.COMPLEX, Op.SYMBOL,
+ Op.INDEXING, Op.TERNARY):
return obj, as_number(1)
elif obj.op is Op.APPLY:
if obj.data[0] is ArithOp.DIV and not obj.data[2]:
@@ -1017,8 +1076,8 @@ def _counter():
COUNTER = _counter()
-def replace_quotes(s):
- """Replace quoted substrings of input.
+def eliminate_quotes(s):
+ """Replace quoted substrings of input string.
Return a new string and a mapping of replacements.
"""
@@ -1030,15 +1089,31 @@ def replace_quotes(s):
# remove trailing underscore
kind = kind[:-1]
p = {"'": "SINGLE", '"': "DOUBLE"}[value[0]]
- k = f'@__f2py_QUOTES_{p}_{COUNTER.__next__()}@'
- d[as_symbol(k)] = as_string(value, kind or 1)
+ k = f'{kind}@__f2py_QUOTES_{p}_{COUNTER.__next__()}@'
+ d[k] = value
return k
- return (re.sub(r'({kind}_|)({single_quoted}|{double_quoted})'.format(
+ new_s = re.sub(r'({kind}_|)({single_quoted}|{double_quoted})'.format(
kind=r'\w[\w\d_]*',
single_quoted=r"('([^'\\]|(\\.))*')",
double_quoted=r'("([^"\\]|(\\.))*")'),
- repl, s), d)
+ repl, s)
+
+ assert '"' not in new_s
+ assert "'" not in new_s
+
+ return new_s, d
+
+
+def insert_quotes(s, d):
+ """Inverse of eliminate_quotes.
+ """
+ for k, v in d.items():
+ kind = k[:k.find('@')]
+ if kind:
+ kind += '_'
+ s = s.replace(k, kind + v)
+ return s
def replace_parenthesis(s):
@@ -1084,21 +1159,32 @@ def replace_parenthesis(s):
return s[:i] + k + r, d
-def fromstring(s):
- """Create an expression from string.
+def _get_parenthesis_kind(s):
+ assert s.startswith('@__f2py_PARENTHESIS_'), s
+ return s.split('_')[4]
- This is a "lazy" parser, that is, only arithmetic operations are
- resolved, non-arithmetic operations are treated as symbols.
+
+def unreplace_parenthesis(s, d):
+ """Inverse of replace_parenthesis.
"""
- # We replace string literal constants that may include parenthesis
- # that may confuse the arithmetic expression parser.
- s, quotes_map = replace_quotes(s)
+ for k, v in d.items():
+ p = _get_parenthesis_kind(k)
+ left = dict(ROUND='(', SQUARE='[', CURLY='{', ROUNDDIV='(/')[p]
+ right = dict(ROUND=')', SQUARE=']', CURLY='}', ROUNDDIV='/)')[p]
+ s = s.replace(k, left + v + right)
+ return s
- # Apply arithmetic expression parser.
- r = _fromstring_worker(s)
- # Restore string constants.
- return r.substitute(quotes_map)
+def fromstring(s, language=Language.C):
+ """Create an expression from a string.
+
+ This is a "lazy" parser, that is, only arithmetic operations are
+ resolved, non-arithmetic operations are treated as symbols.
+ """
+ r = _FromStringWorker(language=language).parse(s)
+ if isinstance(r, Expr):
+ return r
+ raise ValueError(f'failed to parse `{s}` to Expr instance: got `{r}`')
class _Pair:
@@ -1120,151 +1206,210 @@ class _Pair:
return f'{type(self).__name__}({self.left}, {self.right})'
-def _fromstring_worker(s, dummy=None):
- # Internal method used by fromstring to convert a string of
- # arithmetic expressions to an expression tree.
+class _FromStringWorker:
- # Replace subexpression enclosed in parenthesis with dummy symbols
- # and recursively parse subexpression into symbol-expression map:
- r, raw_symbols_map = replace_parenthesis(s)
- r = r.strip()
+ def __init__(self, language=Language.C):
+ self.original = None
+ self.quotes_map = None
+ self.language = language
- symbols_map = dict([(as_symbol(k), _fromstring_worker(v))
- for k, v in raw_symbols_map.items()])
+ def finalize_string(self, s):
+ return insert_quotes(s, self.quotes_map)
- if ',' in r:
- # comma-separate tuple of expressions
- return tuple(_fromstring_worker(r_).substitute(symbols_map)
- for r_ in r.split(','))
+ def parse(self, inp):
+ self.original = inp
+ unquoted, self.quotes_map = eliminate_quotes(inp)
+ return self.process(unquoted)
- m = re.match(r'\A(\w[\w\d_]*)\s*[=](.*)\Z', r)
- if m:
- # keyword argument
- keyname, value = m.groups()
- return _Pair(keyname,
- _fromstring_worker(value).substitute(symbols_map))
-
- operands = re.split(r'((?<!\d[edED])[+-])', r)
- if len(operands) > 1:
- # Expression is an arithmetic sum
- result = _fromstring_worker(operands[0] or '0')
- for op, operand in zip(operands[1::2], operands[2::2]):
- op = op.strip()
- if op == '+':
- result += _fromstring_worker(operand)
- else:
- assert op == '-'
- result -= _fromstring_worker(operand)
- return result.substitute(symbols_map)
-
- operands = re.split(r'//', r)
- if len(operands) > 1:
- # Expression is string concatenate operation
- return Expr(Op.CONCAT,
- tuple(_fromstring_worker(operand).substitute(symbols_map)
- for operand in operands))
-
- operands = re.split(r'([*]|/)', r.replace('**', '@__f2py_DOUBLE_STAR@'))
- operands = [operand.replace('@__f2py_DOUBLE_STAR@', '**')
- for operand in operands]
- if len(operands) > 1:
- # Expression is an arithmetic product
- result = _fromstring_worker(operands[0])
- for op, operand in zip(operands[1::2], operands[2::2]):
- op = op.strip()
- if op == '*':
- result *= _fromstring_worker(operand)
+ def process(self, s, context='expr'):
+ """Parse string within the given context.
+
+ The context may define the result in case of ambiguous
+ expressions. For instance, consider expressions `f(x, y)` and
+ `(x, y) + (a, b)` where `f` is a function and pair `(x, y)`
+ denotes complex number. Specifying context as "args" or
+ "expr", the subexpression `(x, y)` will be parse to an
+ argument list or to a complex number, respectively.
+ """
+ if isinstance(s, (list, tuple)):
+ return type(s)(self.process(s_, context) for s_ in s)
+
+ assert isinstance(s, str), (type(s), s)
+
+ # replace subexpressions in parenthesis with f2py @-names
+ r, raw_symbols_map = replace_parenthesis(s)
+ r = r.strip()
+
+ def restore(r):
+ # restores subexpressions marked with f2py @-names
+ if isinstance(r, (list, tuple)):
+ return type(r)(map(restore, r))
+ return unreplace_parenthesis(r, raw_symbols_map)
+
+ # comma-separated tuple
+ if ',' in r:
+ operands = restore(r.split(','))
+ if context == 'args':
+ return tuple(self.process(operands))
+ if context == 'expr':
+ if len(operands) == 2:
+ # complex number literal
+ return as_complex(*self.process(operands))
+ raise NotImplementedError(
+ f'parsing comma-separated list (context={context}): {r}')
+ return tuple(self.process(restore(r.split(',')), context))
+
+ # ternary operation
+ m = re.match(r'\A([^?]+)[?]([^:]+)[:](.+)\Z', r)
+ if m:
+ assert context == 'expr', context
+ oper, expr1, expr2 = restore(m.groups())
+ if 0:
+ # TODO: enable this when support for boolean
+ # expressions is fully implemented
+ oper = self.process(oper)
else:
- assert op == '/'
- result /= _fromstring_worker(operand)
- return result.substitute(symbols_map)
-
- operands = list(reversed(re.split(r'[*][*]', r)))
- if len(operands) > 1:
- # Expression is an arithmetic exponentiation
- result = _fromstring_worker(operands[0])
- for operand in operands[1:]:
- result = _fromstring_worker(operand) ** result
- return result
-
- m = re.match(r'\A({digit_string})({kind}|)\Z'.format(
- digit_string=r'\d+',
- kind=r'_(\d+|\w[\w\d_]*)'), r)
- if m:
- # Expression is int-literal-constant
- value, _, kind = m.groups()
- if kind and kind.isdigit():
- kind = int(kind)
- return as_integer(int(value), kind or 4)
-
- m = re.match(r'\A({significant}({exponent}|)|\d+{exponent})({kind}|)\Z'
- .format(
- significant=r'[.]\d+|\d+[.]\d*',
- exponent=r'[edED][+-]?\d+',
- kind=r'_(\d+|\w[\w\d_]*)'), r)
- if m:
- # Expression is real-literal-constant
- value, _, _, kind = m.groups()
- if kind and kind.isdigit():
- kind = int(kind)
- value = value.lower()
- if 'd' in value:
- return as_real(float(value.replace('d', 'e')), kind or 8)
- return as_real(float(value), kind or 4)
-
- m = re.match(r'\A(\w[\w\d_]*)_(@__f2py_QUOTES_(\w+)_\d+@)\Z', r)
- if m:
+ oper = as_symbol(self.finalize_string(oper))
+ expr1 = self.process(expr1)
+ expr2 = self.process(expr2)
+ return as_ternary(oper, expr1, expr2)
+
+ # keyword argument
+ m = re.match(r'\A(\w[\w\d_]*)\s*[=](.*)\Z', r)
+ if m:
+ keyname, value = m.groups()
+ value = restore(value)
+ return _Pair(keyname, self.process(value))
+
+ # addition/subtraction operations
+ operands = re.split(r'((?<!\d[edED])[+-])', r)
+ if len(operands) > 1:
+ result = self.process(restore(operands[0] or '0'))
+ for op, operand in zip(operands[1::2], operands[2::2]):
+ operand = self.process(restore(operand))
+ op = op.strip()
+ if op == '+':
+ result += operand
+ else:
+ assert op == '-'
+ result -= operand
+ return result
+
+ # string concatenate operation
+ if self.language is Language.Fortran and '//' in r:
+ operands = restore(r.split('//'))
+ return Expr(Op.CONCAT,
+ tuple(self.process(operands)))
+
+ # multiplication/division operations
+ operands = re.split(r'(?<=[@\w\d_])\s*([*]|/)',
+ (r if self.language is Language.C
+ else r.replace('**', '@__f2py_DOUBLE_STAR@')))
+ if len(operands) > 1:
+ operands = restore(operands)
+ if self.language is not Language.C:
+ operands = [operand.replace('@__f2py_DOUBLE_STAR@', '**')
+ for operand in operands]
+ # Expression is an arithmetic product
+ result = self.process(operands[0])
+ for op, operand in zip(operands[1::2], operands[2::2]):
+ operand = self.process(operand)
+ op = op.strip()
+ if op == '*':
+ result *= operand
+ else:
+ assert op == '/'
+ result /= operand
+ return result
+
+ # referencing/dereferencing
+ if r.startswith('*') or r.startswith('&'):
+ op = {'*': Op.DEREF, '&': Op.REF}[r[0]]
+ operand = self.process(restore(r[1:]))
+ return Expr(op, operand)
+
+ # exponentiation operations
+ if self.language is not Language.C and '**' in r:
+ operands = list(reversed(restore(r.split('**'))))
+ result = self.process(operands[0])
+ for operand in operands[1:]:
+ operand = self.process(operand)
+ result = operand ** result
+ return result
+
+ # int-literal-constant
+ m = re.match(r'\A({digit_string})({kind}|)\Z'.format(
+ digit_string=r'\d+',
+ kind=r'_(\d+|\w[\w\d_]*)'), r)
+ if m:
+ value, _, kind = m.groups()
+ if kind and kind.isdigit():
+ kind = int(kind)
+ return as_integer(int(value), kind or 4)
+
+ # real-literal-constant
+ m = re.match(r'\A({significant}({exponent}|)|\d+{exponent})({kind}|)\Z'
+ .format(
+ significant=r'[.]\d+|\d+[.]\d*',
+ exponent=r'[edED][+-]?\d+',
+ kind=r'_(\d+|\w[\w\d_]*)'), r)
+ if m:
+ value, _, _, kind = m.groups()
+ if kind and kind.isdigit():
+ kind = int(kind)
+ value = value.lower()
+ if 'd' in value:
+ return as_real(float(value.replace('d', 'e')), kind or 8)
+ return as_real(float(value), kind or 4)
+
# string-literal-constant with kind parameter specification
- kind, name, _ = m.groups()
- return as_string(name, kind)
-
- m = re.match(r'\A(\w[\w\d+]*)\s*(@__f2py_PARENTHESIS_(\w+)_\d+@)\Z', r)
- if m:
- # apply or indexing expression
- name, args, paren = m.groups()
- target = as_symbol(name).substitute(symbols_map)
- if args in raw_symbols_map:
- args = symbols_map[as_symbol(args)]
+ if r in self.quotes_map:
+ kind = r[:r.find('@')]
+ return as_string(self.quotes_map[r], kind or 1)
+
+ # array constructor or literal complex constant or
+ # parenthesized expression
+ if r in raw_symbols_map:
+ paren = _get_parenthesis_kind(r)
+ items = self.process(restore(raw_symbols_map[r]),
+ 'expr' if paren == 'ROUND' else 'args')
+ if paren == 'ROUND':
+ if isinstance(items, Expr):
+ return items
+ if paren in ['ROUNDDIV', 'SQUARE']:
+ # Expression is a array constructor
+ if isinstance(items, Expr):
+ items = (items,)
+ return as_array(items)
+
+ # function call/indexing
+ m = re.match(r'\A(.+)\s*(@__f2py_PARENTHESIS_(ROUND|SQUARE)_\d+@)\Z',
+ r)
+ if m:
+ target, args, paren = m.groups()
+ target = self.process(restore(target))
+ args = self.process(restore(args)[1:-1], 'args')
if not isinstance(args, tuple):
args = args,
if paren == 'ROUND':
- # Expression is a function call or a Fortran indexing
- # operation
kwargs = dict((a.left, a.right) for a in args
if isinstance(a, _Pair))
args = tuple(a for a in args if not isinstance(a, _Pair))
+ # Warning: this could also be Fortran indexing operation..
return as_apply(target, *args, **kwargs)
- if paren == 'SQUARE':
- # Expression is a C indexing operation (e.g. used in
- # .pyf files)
+ else:
+ # Expression is a C/Python indexing operation
+ # (e.g. used in .pyf files)
+ assert paren == 'SQUARE'
return target[args]
- # A non-Fortran/C expression?
- assert 0, (s, r)
-
- m = re.match(r'\A(@__f2py_PARENTHESIS_(\w+)_\d+@)\Z', r)
- if m:
- # array constructor or literal complex constant
- items, paren = m.groups()
- if items in raw_symbols_map:
- items = symbols_map[as_symbol(items)]
- if (paren == 'ROUND' and isinstance(items,
- tuple) and len(items) == 2):
- # Expression is a literal complex constant
- return as_complex(items[0], items[1])
- if paren in ['ROUNDDIV', 'SQUARE']:
- # Expression is a array constructor
- return as_array(items)
- assert 0, (r, items, paren)
- m = re.match(r'\A\w[\w\d_]*\Z', r)
- if m:
- # Fortran standard conforming name
- return as_symbol(r)
+ # Fortran standard conforming identifier
+ m = re.match(r'\A\w[\w\d_]*\Z', r)
+ if m:
+ return as_symbol(r)
- m = re.match(r'\A@\w[\w\d_]*@\Z', r)
- if m:
- # f2py special dummy name
+ # fall-back to symbol
+ r = self.finalize_string(restore(r))
+ ewarn(
+ f'fromstring: treating {r!r} as symbol (original={self.original})')
return as_symbol(r)
-
- ewarn(f'fromstring: treating {r!r} as symbol')
- return as_symbol(r)
diff --git a/numpy/f2py/tests/test_symbolic.py b/numpy/f2py/tests/test_symbolic.py
index 27244c995..61f0d4b3a 100644
--- a/numpy/f2py/tests/test_symbolic.py
+++ b/numpy/f2py/tests/test_symbolic.py
@@ -2,34 +2,35 @@ from numpy.testing import assert_raises
from numpy.f2py.symbolic import (
Expr, Op, ArithOp, Language,
as_symbol, as_number, as_string, as_array, as_complex,
- as_terms, as_factors, fromstring, replace_quotes, as_expr, as_apply,
- as_numer_denom
+ as_terms, as_factors, eliminate_quotes, insert_quotes,
+ fromstring, as_expr, as_apply,
+ as_numer_denom, as_ternary, as_ref, as_deref,
+ normalize
)
from . import util
class TestSymbolic(util.F2PyTest):
- def test_replace_quotes(self):
-
+ def test_eliminate_quotes(self):
def worker(s):
- r, d = replace_quotes(s)
- assert '"' not in r, r
- assert "'" not in r, r
- # undo replacements:
- s1 = r
- for k, v in d.items():
- s1 = s1.replace(k.data, v.data[0])
+ r, d = eliminate_quotes(s)
+ s1 = insert_quotes(r, d)
assert s1 == s
- worker('"1234" // "ABCD"')
- worker('"1234" // \'ABCD\'')
- worker('"1\\"2\'AB\'34"')
- worker("'1\\'2\"AB\"34'")
+ for kind in ['', 'mykind_']:
+ worker(kind + '"1234" // "ABCD"')
+ worker(kind + '"1234" // ' + kind + '"ABCD"')
+ worker(kind + '"1234" // \'ABCD\'')
+ worker(kind + '"1234" // ' + kind + '\'ABCD\'')
+ worker(kind + '"1\\"2\'AB\'34"')
+ worker('a = ' + kind + "'1\\'2\"AB\"34'")
def test_sanity(self):
x = as_symbol('x')
y = as_symbol('y')
+ z = as_symbol('z')
+
assert x.op == Op.SYMBOL
assert repr(x) == "Expr(Op.SYMBOL, 'x')"
assert x == x
@@ -92,9 +93,17 @@ class TestSymbolic(util.F2PyTest):
assert w != v
assert hash(v) is not None
+ t = as_ternary(x, y, z)
+ u = as_ternary(x, z, y)
+ assert t.op == Op.TERNARY
+ assert t == t
+ assert t != u
+ assert hash(t) is not None
+
def test_tostring_fortran(self):
x = as_symbol('x')
y = as_symbol('y')
+ z = as_symbol('z')
n = as_number(123)
m = as_number(456)
a = as_array((n, m))
@@ -132,10 +141,13 @@ class TestSymbolic(util.F2PyTest):
assert str(Expr(Op.APPLY, ('f', (x, y), {}))) == 'f(x, y)'
assert str(Expr(Op.INDEXING, ('f', x))) == 'f[x]'
+ assert str(as_ternary(x, y, z)) == 'merge(y, z, x)'
+
def test_tostring_c(self):
language = Language.C
x = as_symbol('x')
y = as_symbol('y')
+ z = as_symbol('z')
n = as_number(123)
assert Expr(Op.FACTORS, {x: 2}).tostring(language=language) == 'x * x'
@@ -153,6 +165,8 @@ class TestSymbolic(util.F2PyTest):
assert (x + (x - y) / (x + y) + n).tostring(
language=language) == '123 + x + (x - y) / (x + y)'
+ assert as_ternary(x, y, z).tostring(language=language) == '(x ? y : z)'
+
def test_operations(self):
x = as_symbol('x')
y = as_symbol('y')
@@ -224,6 +238,9 @@ class TestSymbolic(util.F2PyTest):
assert x.substitute({x: y + z}) == y + z
assert a.substitute({x: y + z}) == as_array((y + z, y))
+ assert as_ternary(x, y, z).substitute(
+ {x: y + z}) == as_ternary(y + z, y, z)
+
def test_fromstring(self):
x = as_symbol('x')
@@ -242,16 +259,20 @@ class TestSymbolic(util.F2PyTest):
assert fromstring('x * y') == x * y
assert fromstring('x * 2') == x * 2
assert fromstring('x / y') == x / y
- assert fromstring('x ** 2') == x ** 2
- assert fromstring('x ** 2 ** 3') == x ** 2 ** 3
+ assert fromstring('x ** 2',
+ language=Language.Python) == x ** 2
+ assert fromstring('x ** 2 ** 3',
+ language=Language.Python) == x ** 2 ** 3
assert fromstring('(x + y) * z') == (x + y) * z
assert fromstring('f(x)') == f(x)
assert fromstring('f(x,y)') == f(x, y)
assert fromstring('f[x]') == f[x]
+ assert fromstring('f[x][y]') == f[x][y]
assert fromstring('"ABC"') == s
- assert fromstring('"ABC" // "123" ') == s // t
+ assert normalize(fromstring('"ABC" // "123" ',
+ language=Language.Fortran)) == s // t
assert fromstring('f("ABC")') == f(s)
assert fromstring('MYSTRKIND_"ABC"') == as_string('"ABC"', 'MYSTRKIND')
@@ -288,6 +309,16 @@ class TestSymbolic(util.F2PyTest):
age=as_number(50),
shape=as_array((as_number(34), as_number(23)))))
+ assert fromstring('x?y:z') == as_ternary(x, y, z)
+
+ assert fromstring('*x') == as_deref(x)
+ assert fromstring('**x') == as_deref(as_deref(x))
+ assert fromstring('&x') == as_ref(x)
+ assert fromstring('(*x) * (*y)') == as_deref(x) * as_deref(y)
+ assert fromstring('(*x) * *y') == as_deref(x) * as_deref(y)
+ assert fromstring('*x * *y') == as_deref(x) * as_deref(y)
+ assert fromstring('*x**y') == as_deref(x) * as_deref(y)
+
def test_traverse(self):
x = as_symbol('x')
y = as_symbol('y')