diff options
Diffstat (limited to 'numpy/fft/tests/test_fftpack.py')
-rw-r--r-- | numpy/fft/tests/test_fftpack.py | 39 |
1 files changed, 38 insertions, 1 deletions
diff --git a/numpy/fft/tests/test_fftpack.py b/numpy/fft/tests/test_fftpack.py index 50702ab3c..13db3e682 100644 --- a/numpy/fft/tests/test_fftpack.py +++ b/numpy/fft/tests/test_fftpack.py @@ -1,7 +1,8 @@ from __future__ import division, absolute_import, print_function import numpy as np -from numpy.testing import TestCase, run_module_suite, assert_array_almost_equal +from numpy.testing import TestCase, run_module_suite, assert_array_equal, assert_array_almost_equal +import threading, Queue def fft1(x): L = len(x) @@ -23,5 +24,41 @@ class TestFFT1D(TestCase): assert_array_almost_equal(fft1(x), np.fft.fft(x)) +class TestFFTThreadSafe(TestCase): + threads = 32 + input_shape = (1000, 1000) + + def _test_mtsame(self, func, *args): + def worker(args, q): + q.put(func(*args)) + + q = Queue.Queue() + expected = func(*args) + + # Spin off a bunch of threads to call the same function simultaneously + for i in xrange(self.threads): + threading.Thread(target=worker, args=(args, q)).start() + + # Make sure all threads returned the correct value + for i in xrange(self.threads): + assert_array_equal(q.get(), expected, 'Function returned wrong value in multithreaded context') + + def test_fft(self): + a = np.ones(self.input_shape) * 1+0j + self._test_mtsame(np.fft.fft, a) + + def test_ifft(self): + a = np.ones(self.input_shape) * 1+0j + self._test_mtsame(np.fft.ifft, a) + + def test_rfft(self): + a = np.ones(self.input_shape) + self._test_mtsame(np.fft.rfft, a) + + def test_irfft(self): + a = np.ones(self.input_shape) * 1+0j + self._test_mtsame(np.fft.irfft, a) + + if __name__ == "__main__": run_module_suite() |