summaryrefslogtreecommitdiff
path: root/numpy/fft/tests
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/fft/tests')
-rw-r--r--numpy/fft/tests/test_fftpack.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/numpy/fft/tests/test_fftpack.py b/numpy/fft/tests/test_fftpack.py
index 50702ab3c..ac892c83b 100644
--- a/numpy/fft/tests/test_fftpack.py
+++ b/numpy/fft/tests/test_fftpack.py
@@ -2,6 +2,14 @@ from __future__ import division, absolute_import, print_function
import numpy as np
from numpy.testing import TestCase, run_module_suite, assert_array_almost_equal
+from numpy.testing import assert_array_equal
+import threading
+import sys
+if sys.version_info[0] >= 3:
+ import queue
+else:
+ import Queue as queue
+
def fft1(x):
L = len(x)
@@ -9,6 +17,7 @@ def fft1(x):
phase = np.arange(L).reshape(-1, 1) * phase
return np.sum(x*np.exp(phase), axis=1)
+
class TestFFTShift(TestCase):
def test_fft_n(self):
@@ -23,5 +32,44 @@ class TestFFT1D(TestCase):
assert_array_almost_equal(fft1(x), np.fft.fft(x))
+class TestFFTThreadSafe(TestCase):
+ threads = 16
+ input_shape = (800, 200)
+
+ def _test_mtsame(self, func, *args):
+ def worker(args, q):
+ q.put(func(*args))
+
+ q = queue.Queue()
+ expected = func(*args)
+
+ # Spin off a bunch of threads to call the same function simultaneously
+ t = [threading.Thread(target=worker, args=(args, q))
+ for i in range(self.threads)]
+ [x.start() for x in t]
+
+ # Make sure all threads returned the correct value
+ for i in range(self.threads):
+ assert_array_equal(q.get(timeout=5), expected,
+ 'Function returned wrong value in multithreaded context')
+ [x.join() for x in t]
+
+ def test_fft(self):
+ a = np.ones(self.input_shape) * 1+0j
+ self._test_mtsame(np.fft.fft, a)
+
+ def test_ifft(self):
+ a = np.ones(self.input_shape) * 1+0j
+ self._test_mtsame(np.fft.ifft, a)
+
+ def test_rfft(self):
+ a = np.ones(self.input_shape)
+ self._test_mtsame(np.fft.rfft, a)
+
+ def test_irfft(self):
+ a = np.ones(self.input_shape) * 1+0j
+ self._test_mtsame(np.fft.irfft, a)
+
+
if __name__ == "__main__":
run_module_suite()