summaryrefslogtreecommitdiff
path: root/numpy/fft/tests/test_fftpack.py
diff options
context:
space:
mode:
authorAlex Stewart <foogod@gmail.com>2014-05-06 15:54:05 -0700
committerAlex Stewart <foogod@gmail.com>2014-05-06 15:54:05 -0700
commite4adab82ec64122dc953a40663e6d71aa7af0fee (patch)
tree940c11612fbcc381eab5a059e05816f8741ae6c4 /numpy/fft/tests/test_fftpack.py
parent5286f5e949214c6e6a7485109ea8ee861fe337d7 (diff)
downloadnumpy-e4adab82ec64122dc953a40663e6d71aa7af0fee.tar.gz
Make TestFFTThreadSafe tests Py3-compatible
Diffstat (limited to 'numpy/fft/tests/test_fftpack.py')
-rw-r--r--numpy/fft/tests/test_fftpack.py23
1 files changed, 16 insertions, 7 deletions
diff --git a/numpy/fft/tests/test_fftpack.py b/numpy/fft/tests/test_fftpack.py
index 13db3e682..ba19e5594 100644
--- a/numpy/fft/tests/test_fftpack.py
+++ b/numpy/fft/tests/test_fftpack.py
@@ -1,8 +1,15 @@
from __future__ import division, absolute_import, print_function
import numpy as np
-from numpy.testing import TestCase, run_module_suite, assert_array_equal, assert_array_almost_equal
-import threading, Queue
+from numpy.testing import TestCase, run_module_suite, assert_array_almost_equal
+from numpy.testing import assert_array_equal
+import threading
+import sys
+if sys.version_info[0] >= 3:
+ import queue
+else:
+ import Queue as queue
+
def fft1(x):
L = len(x)
@@ -10,6 +17,7 @@ def fft1(x):
phase = np.arange(L).reshape(-1, 1) * phase
return np.sum(x*np.exp(phase), axis=1)
+
class TestFFTShift(TestCase):
def test_fft_n(self):
@@ -25,23 +33,24 @@ class TestFFT1D(TestCase):
class TestFFTThreadSafe(TestCase):
- threads = 32
+ threads = 16
input_shape = (1000, 1000)
def _test_mtsame(self, func, *args):
def worker(args, q):
q.put(func(*args))
- q = Queue.Queue()
+ q = queue.Queue()
expected = func(*args)
# Spin off a bunch of threads to call the same function simultaneously
- for i in xrange(self.threads):
+ for i in range(self.threads):
threading.Thread(target=worker, args=(args, q)).start()
# Make sure all threads returned the correct value
- for i in xrange(self.threads):
- assert_array_equal(q.get(), expected, 'Function returned wrong value in multithreaded context')
+ for i in range(self.threads):
+ assert_array_equal(q.get(timeout=5), expected,
+ 'Function returned wrong value in multithreaded context')
def test_fft(self):
a = np.ones(self.input_shape) * 1+0j