summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2019-01-10 10:01:05 -0700
committerGitHub <noreply@github.com>2019-01-10 10:01:05 -0700
commit83a16723876994120f13537103f70aadb5bf7475 (patch)
treedaa80af2473718abe6e679f0ab64253277a7a163 /numpy
parent49a24687c76d7fc99cf82eb1912776cf5501feb5 (diff)
parent4bba58a862d10c3e3ffd26441854e3aaf24ca6fc (diff)
downloadnumpy-83a16723876994120f13537103f70aadb5bf7475.tar.gz
Merge pull request #12702 from rth/fft-3d-fortran-test
TST: Check FFT results for C/Fortran ordered and non contigous input
Diffstat (limited to 'numpy')
-rw-r--r--numpy/fft/tests/test_pocketfft.py38
1 files changed, 38 insertions, 0 deletions
diff --git a/numpy/fft/tests/test_pocketfft.py b/numpy/fft/tests/test_pocketfft.py
index 1029294a1..08f80076a 100644
--- a/numpy/fft/tests/test_pocketfft.py
+++ b/numpy/fft/tests/test_pocketfft.py
@@ -166,6 +166,44 @@ class TestFFT1D(object):
assert_array_almost_equal(np.fft.ifft(np.fft.fft(x)), x)
assert_array_almost_equal(np.fft.irfft(np.fft.rfft(x)), x)
+
+@pytest.mark.parametrize(
+ "dtype",
+ [np.float32, np.float64, np.complex64, np.complex128])
+@pytest.mark.parametrize("order", ["F", 'non-contiguous'])
+@pytest.mark.parametrize(
+ "fft",
+ [np.fft.fft, np.fft.fft2, np.fft.fftn,
+ np.fft.ifft, np.fft.ifft2, np.fft.ifftn])
+def test_fft_with_order(dtype, order, fft):
+ # Check that FFT/IFFT produces identical results for C, Fortran and
+ # non contiguous arrays
+ rng = np.random.RandomState(42)
+ X = rng.rand(8, 7, 13).astype(dtype, copy=False)
+ if order == 'F':
+ Y = np.asfortranarray(X)
+ else:
+ # Make a non contiguous array
+ Y = X[::-1]
+ X = np.ascontiguousarray(X[::-1])
+
+ if fft.__name__.endswith('fft'):
+ for axis in range(3):
+ X_res = fft(X, axis=axis)
+ Y_res = fft(Y, axis=axis)
+ assert_array_almost_equal(X_res, Y_res)
+ elif fft.__name__.endswith(('fft2', 'fftn')):
+ axes = [(0, 1), (1, 2), (0, 2)]
+ if fft.__name__.endswith('fftn'):
+ axes.extend([(0,), (1,), (2,), None])
+ for ax in axes:
+ X_res = fft(X, axes=ax)
+ Y_res = fft(Y, axes=ax)
+ assert_array_almost_equal(X_res, Y_res)
+ else:
+ raise ValueError
+
+
class TestFFTThreadSafe(object):
threads = 16
input_shape = (800, 200)