diff options
author | Alex Stewart <foogod@gmail.com> | 2014-05-06 15:54:05 -0700 |
---|---|---|
committer | Alex Stewart <foogod@gmail.com> | 2014-05-06 15:54:05 -0700 |
commit | e4adab82ec64122dc953a40663e6d71aa7af0fee (patch) | |
tree | 940c11612fbcc381eab5a059e05816f8741ae6c4 /numpy/fft/tests/test_fftpack.py | |
parent | 5286f5e949214c6e6a7485109ea8ee861fe337d7 (diff) | |
download | numpy-e4adab82ec64122dc953a40663e6d71aa7af0fee.tar.gz |
Make TestFFTThreadSafe tests Py3-compatible
Diffstat (limited to 'numpy/fft/tests/test_fftpack.py')
-rw-r--r-- | numpy/fft/tests/test_fftpack.py | 23 |
1 files changed, 16 insertions, 7 deletions
diff --git a/numpy/fft/tests/test_fftpack.py b/numpy/fft/tests/test_fftpack.py index 13db3e682..ba19e5594 100644 --- a/numpy/fft/tests/test_fftpack.py +++ b/numpy/fft/tests/test_fftpack.py @@ -1,8 +1,15 @@ from __future__ import division, absolute_import, print_function import numpy as np -from numpy.testing import TestCase, run_module_suite, assert_array_equal, assert_array_almost_equal -import threading, Queue +from numpy.testing import TestCase, run_module_suite, assert_array_almost_equal +from numpy.testing import assert_array_equal +import threading +import sys +if sys.version_info[0] >= 3: + import queue +else: + import Queue as queue + def fft1(x): L = len(x) @@ -10,6 +17,7 @@ def fft1(x): phase = np.arange(L).reshape(-1, 1) * phase return np.sum(x*np.exp(phase), axis=1) + class TestFFTShift(TestCase): def test_fft_n(self): @@ -25,23 +33,24 @@ class TestFFT1D(TestCase): class TestFFTThreadSafe(TestCase): - threads = 32 + threads = 16 input_shape = (1000, 1000) def _test_mtsame(self, func, *args): def worker(args, q): q.put(func(*args)) - q = Queue.Queue() + q = queue.Queue() expected = func(*args) # Spin off a bunch of threads to call the same function simultaneously - for i in xrange(self.threads): + for i in range(self.threads): threading.Thread(target=worker, args=(args, q)).start() # Make sure all threads returned the correct value - for i in xrange(self.threads): - assert_array_equal(q.get(), expected, 'Function returned wrong value in multithreaded context') + for i in range(self.threads): + assert_array_equal(q.get(timeout=5), expected, + 'Function returned wrong value in multithreaded context') def test_fft(self): a = np.ones(self.input_shape) * 1+0j |