summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorPearu Peterson <pearu.peterson@gmail.com>2022-01-04 09:22:45 +0200
committerGitHub <noreply@github.com>2022-01-04 09:22:45 +0200
commit20f972ce3693d7700afe51898089613ebe4b3ee5 (patch)
treed920706021c35f898b6cb5d057e7de834fe0dad2 /numpy
parent79137960472c1f723b983feff4663c3b60e64eb0 (diff)
downloadnumpy-20f972ce3693d7700afe51898089613ebe4b3ee5.tar.gz
BUG: Fix array dimensions solver for multidimensional arguments in f2py (#20721)
* BUG: Fix array dimensions solver for multidimensional arguments in f2py. See #20709
Diffstat (limited to 'numpy')
-rwxr-xr-xnumpy/f2py/crackfortran.py5
-rw-r--r--numpy/f2py/tests/test_crackfortran.py8
2 files changed, 9 insertions, 4 deletions
diff --git a/numpy/f2py/crackfortran.py b/numpy/f2py/crackfortran.py
index b02eb68b7..824d87e4c 100755
--- a/numpy/f2py/crackfortran.py
+++ b/numpy/f2py/crackfortran.py
@@ -2591,7 +2591,10 @@ def analyzevars(block):
if dsize.contains(s):
try:
a, b = dsize.linear_solve(s)
- solve_v = lambda s: (s - b) / a
+
+ def solve_v(s, a=a, b=b):
+ return (s - b) / a
+
all_symbols = set(a.symbols())
all_symbols.update(b.symbols())
except RuntimeError as msg:
diff --git a/numpy/f2py/tests/test_crackfortran.py b/numpy/f2py/tests/test_crackfortran.py
index 0b47264ad..fb47eb31d 100644
--- a/numpy/f2py/tests/test_crackfortran.py
+++ b/numpy/f2py/tests/test_crackfortran.py
@@ -147,17 +147,19 @@ class TestDimSpec(util.F2PyTest):
""")
linear_dimspecs = [
- "n", "2*n", "2:n", "n/2", "5 - n/2", "3*n:20", "n*(n+1):n*(n+5)"
+ "n", "2*n", "2:n", "n/2", "5 - n/2", "3*n:20", "n*(n+1):n*(n+5)",
+ "2*n, n"
]
nonlinear_dimspecs = ["2*n:3*n*n+2*n"]
all_dimspecs = linear_dimspecs + nonlinear_dimspecs
code = ""
for count, dimspec in enumerate(all_dimspecs):
+ lst = [(d.split(":")[0] if ":" in d else "1") for d in dimspec.split(',')]
code += code_template.format(
count=count,
dimspec=dimspec,
- first=dimspec.split(":")[0] if ":" in dimspec else "1",
+ first=", ".join(lst),
)
@pytest.mark.parametrize("dimspec", all_dimspecs)
@@ -168,7 +170,7 @@ class TestDimSpec(util.F2PyTest):
for n in [1, 2, 3, 4, 5]:
sz, a = get_arr_size(n)
- assert len(a) == sz
+ assert a.size == sz
@pytest.mark.parametrize("dimspec", all_dimspecs)
def test_inv_array_size(self, dimspec):