diff options
Diffstat (limited to 'numpy/fft')
-rw-r--r-- | numpy/fft/fftpack.py | 31 | ||||
-rw-r--r-- | numpy/fft/fftpack_litemodule.c | 24 | ||||
-rw-r--r-- | numpy/fft/tests/test_fftpack.py | 48 |
3 files changed, 82 insertions, 21 deletions
diff --git a/numpy/fft/fftpack.py b/numpy/fft/fftpack.py index 9a169b04c..706fcdd2f 100644 --- a/numpy/fft/fftpack.py +++ b/numpy/fft/fftpack.py @@ -53,10 +53,12 @@ def _raw_fft(a, n=None, axis=-1, init_function=fftpack.cffti, raise ValueError("Invalid number of FFT data points (%d) specified." % n) try: - wsave = fft_cache[n] - except(KeyError): + # Thread-safety note: We rely on list.pop() here to atomically + # retrieve-and-remove a wsave from the cache. This ensures that no + # other thread can get the same wsave while we're using it. + wsave = fft_cache.setdefault(n, []).pop() + except (IndexError): wsave = init_function(n) - fft_cache[n] = wsave if a.shape[axis] != n: s = list(a.shape) @@ -77,6 +79,12 @@ def _raw_fft(a, n=None, axis=-1, init_function=fftpack.cffti, r = work_function(a, wsave) if axis != -1: r = swapaxes(r, axis, -1) + + # As soon as we put wsave back into the cache, another thread could pick it + # up and start using it, so we must not do this until after we're + # completely done using it ourselves. + fft_cache[n].append(wsave) + return r @@ -296,7 +304,7 @@ def rfft(a, n=None, axis=-1): conjugates of the corresponding positive-frequency terms, and the negative-frequency terms are therefore redundant. This function does not compute the negative frequency terms, and the length of the transformed - axis of the output is therefore ``n//2+1``. + axis of the output is therefore ``n//2 + 1``. When ``A = rfft(a)`` and fs is the sampling frequency, ``A[0]`` contains the zero-frequency term 0*fs, which is real due to Hermitian symmetry. @@ -821,11 +829,16 @@ def fft2(a, s=None, axes=(-2, -1)): -------- >>> a = np.mgrid[:5, :5][0] >>> np.fft.fft2(a) - array([[ 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], - [ 5.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], - [ 10.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], - [ 15.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], - [ 20.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j]]) + array([[ 50.0 +0.j , 0.0 +0.j , 0.0 +0.j , + 0.0 +0.j , 0.0 +0.j ], + [-12.5+17.20477401j, 0.0 +0.j , 0.0 +0.j , + 0.0 +0.j , 0.0 +0.j ], + [-12.5 +4.0614962j , 0.0 +0.j , 0.0 +0.j , + 0.0 +0.j , 0.0 +0.j ], + [-12.5 -4.0614962j , 0.0 +0.j , 0.0 +0.j , + 0.0 +0.j , 0.0 +0.j ], + [-12.5-17.20477401j, 0.0 +0.j , 0.0 +0.j , + 0.0 +0.j , 0.0 +0.j ]]) """ diff --git a/numpy/fft/fftpack_litemodule.c b/numpy/fft/fftpack_litemodule.c index 6f6a6c9f3..95da3194f 100644 --- a/numpy/fft/fftpack_litemodule.c +++ b/numpy/fft/fftpack_litemodule.c @@ -44,14 +44,14 @@ fftpack_cfftf(PyObject *NPY_UNUSED(self), PyObject *args) nrepeats = PyArray_SIZE(data)/npts; dptr = (double *)PyArray_DATA(data); - NPY_SIGINT_ON; Py_BEGIN_ALLOW_THREADS; + NPY_SIGINT_ON; for (i = 0; i < nrepeats; i++) { cfftf(npts, dptr, wsave); dptr += npts*2; } - Py_END_ALLOW_THREADS; NPY_SIGINT_OFF; + Py_END_ALLOW_THREADS; PyArray_Free(op2, (char *)wsave); return (PyObject *)data; @@ -97,14 +97,14 @@ fftpack_cfftb(PyObject *NPY_UNUSED(self), PyObject *args) nrepeats = PyArray_SIZE(data)/npts; dptr = (double *)PyArray_DATA(data); - NPY_SIGINT_ON; Py_BEGIN_ALLOW_THREADS; + NPY_SIGINT_ON; for (i = 0; i < nrepeats; i++) { cfftb(npts, dptr, wsave); dptr += npts*2; } - Py_END_ALLOW_THREADS; NPY_SIGINT_OFF; + Py_END_ALLOW_THREADS; PyArray_Free(op2, (char *)wsave); return (PyObject *)data; @@ -134,11 +134,11 @@ fftpack_cffti(PyObject *NPY_UNUSED(self), PyObject *args) return NULL; } - NPY_SIGINT_ON; Py_BEGIN_ALLOW_THREADS; + NPY_SIGINT_ON; cffti(n, (double *)PyArray_DATA((PyArrayObject*)op)); - Py_END_ALLOW_THREADS; NPY_SIGINT_OFF; + Py_END_ALLOW_THREADS; return (PyObject *)op; } @@ -188,8 +188,8 @@ fftpack_rfftf(PyObject *NPY_UNUSED(self), PyObject *args) dptr = (double *)PyArray_DATA(data); - NPY_SIGINT_ON; Py_BEGIN_ALLOW_THREADS; + NPY_SIGINT_ON; for (i = 0; i < nrepeats; i++) { memcpy((char *)(rptr+1), dptr, npts*sizeof(double)); rfftf(npts, rptr+1, wsave); @@ -198,8 +198,8 @@ fftpack_rfftf(PyObject *NPY_UNUSED(self), PyObject *args) rptr += rstep; dptr += npts; } - Py_END_ALLOW_THREADS; NPY_SIGINT_OFF; + Py_END_ALLOW_THREADS; PyArray_Free(op2, (char *)wsave); Py_DECREF(data); return (PyObject *)ret; @@ -252,8 +252,8 @@ fftpack_rfftb(PyObject *NPY_UNUSED(self), PyObject *args) rptr = (double *)PyArray_DATA(ret); dptr = (double *)PyArray_DATA(data); - NPY_SIGINT_ON; Py_BEGIN_ALLOW_THREADS; + NPY_SIGINT_ON; for (i = 0; i < nrepeats; i++) { memcpy((char *)(rptr + 1), (dptr + 2), (npts - 1)*sizeof(double)); rptr[0] = dptr[0]; @@ -261,8 +261,8 @@ fftpack_rfftb(PyObject *NPY_UNUSED(self), PyObject *args) rptr += npts; dptr += npts*2; } - Py_END_ALLOW_THREADS; NPY_SIGINT_OFF; + Py_END_ALLOW_THREADS; PyArray_Free(op2, (char *)wsave); Py_DECREF(data); return (PyObject *)ret; @@ -294,11 +294,11 @@ fftpack_rffti(PyObject *NPY_UNUSED(self), PyObject *args) if (op == NULL) { return NULL; } - NPY_SIGINT_ON; Py_BEGIN_ALLOW_THREADS; + NPY_SIGINT_ON; rffti(n, (double *)PyArray_DATA((PyArrayObject*)op)); - Py_END_ALLOW_THREADS; NPY_SIGINT_OFF; + Py_END_ALLOW_THREADS; return (PyObject *)op; } diff --git a/numpy/fft/tests/test_fftpack.py b/numpy/fft/tests/test_fftpack.py index 50702ab3c..ac892c83b 100644 --- a/numpy/fft/tests/test_fftpack.py +++ b/numpy/fft/tests/test_fftpack.py @@ -2,6 +2,14 @@ from __future__ import division, absolute_import, print_function import numpy as np from numpy.testing import TestCase, run_module_suite, assert_array_almost_equal +from numpy.testing import assert_array_equal +import threading +import sys +if sys.version_info[0] >= 3: + import queue +else: + import Queue as queue + def fft1(x): L = len(x) @@ -9,6 +17,7 @@ def fft1(x): phase = np.arange(L).reshape(-1, 1) * phase return np.sum(x*np.exp(phase), axis=1) + class TestFFTShift(TestCase): def test_fft_n(self): @@ -23,5 +32,44 @@ class TestFFT1D(TestCase): assert_array_almost_equal(fft1(x), np.fft.fft(x)) +class TestFFTThreadSafe(TestCase): + threads = 16 + input_shape = (800, 200) + + def _test_mtsame(self, func, *args): + def worker(args, q): + q.put(func(*args)) + + q = queue.Queue() + expected = func(*args) + + # Spin off a bunch of threads to call the same function simultaneously + t = [threading.Thread(target=worker, args=(args, q)) + for i in range(self.threads)] + [x.start() for x in t] + + # Make sure all threads returned the correct value + for i in range(self.threads): + assert_array_equal(q.get(timeout=5), expected, + 'Function returned wrong value in multithreaded context') + [x.join() for x in t] + + def test_fft(self): + a = np.ones(self.input_shape) * 1+0j + self._test_mtsame(np.fft.fft, a) + + def test_ifft(self): + a = np.ones(self.input_shape) * 1+0j + self._test_mtsame(np.fft.ifft, a) + + def test_rfft(self): + a = np.ones(self.input_shape) + self._test_mtsame(np.fft.rfft, a) + + def test_irfft(self): + a = np.ones(self.input_shape) * 1+0j + self._test_mtsame(np.fft.irfft, a) + + if __name__ == "__main__": run_module_suite() |