summaryrefslogtreecommitdiff
path: root/numpy/fft
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/fft')
-rw-r--r--numpy/fft/fftpack.py31
-rw-r--r--numpy/fft/fftpack_litemodule.c24
-rw-r--r--numpy/fft/tests/test_fftpack.py48
3 files changed, 82 insertions, 21 deletions
diff --git a/numpy/fft/fftpack.py b/numpy/fft/fftpack.py
index 9a169b04c..706fcdd2f 100644
--- a/numpy/fft/fftpack.py
+++ b/numpy/fft/fftpack.py
@@ -53,10 +53,12 @@ def _raw_fft(a, n=None, axis=-1, init_function=fftpack.cffti,
raise ValueError("Invalid number of FFT data points (%d) specified." % n)
try:
- wsave = fft_cache[n]
- except(KeyError):
+ # Thread-safety note: We rely on list.pop() here to atomically
+ # retrieve-and-remove a wsave from the cache. This ensures that no
+ # other thread can get the same wsave while we're using it.
+ wsave = fft_cache.setdefault(n, []).pop()
+ except (IndexError):
wsave = init_function(n)
- fft_cache[n] = wsave
if a.shape[axis] != n:
s = list(a.shape)
@@ -77,6 +79,12 @@ def _raw_fft(a, n=None, axis=-1, init_function=fftpack.cffti,
r = work_function(a, wsave)
if axis != -1:
r = swapaxes(r, axis, -1)
+
+ # As soon as we put wsave back into the cache, another thread could pick it
+ # up and start using it, so we must not do this until after we're
+ # completely done using it ourselves.
+ fft_cache[n].append(wsave)
+
return r
@@ -296,7 +304,7 @@ def rfft(a, n=None, axis=-1):
conjugates of the corresponding positive-frequency terms, and the
negative-frequency terms are therefore redundant. This function does not
compute the negative frequency terms, and the length of the transformed
- axis of the output is therefore ``n//2+1``.
+ axis of the output is therefore ``n//2 + 1``.
When ``A = rfft(a)`` and fs is the sampling frequency, ``A[0]`` contains
the zero-frequency term 0*fs, which is real due to Hermitian symmetry.
@@ -821,11 +829,16 @@ def fft2(a, s=None, axes=(-2, -1)):
--------
>>> a = np.mgrid[:5, :5][0]
>>> np.fft.fft2(a)
- array([[ 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
- [ 5.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
- [ 10.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
- [ 15.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
- [ 20.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j]])
+ array([[ 50.0 +0.j , 0.0 +0.j , 0.0 +0.j ,
+ 0.0 +0.j , 0.0 +0.j ],
+ [-12.5+17.20477401j, 0.0 +0.j , 0.0 +0.j ,
+ 0.0 +0.j , 0.0 +0.j ],
+ [-12.5 +4.0614962j , 0.0 +0.j , 0.0 +0.j ,
+ 0.0 +0.j , 0.0 +0.j ],
+ [-12.5 -4.0614962j , 0.0 +0.j , 0.0 +0.j ,
+ 0.0 +0.j , 0.0 +0.j ],
+ [-12.5-17.20477401j, 0.0 +0.j , 0.0 +0.j ,
+ 0.0 +0.j , 0.0 +0.j ]])
"""
diff --git a/numpy/fft/fftpack_litemodule.c b/numpy/fft/fftpack_litemodule.c
index 6f6a6c9f3..95da3194f 100644
--- a/numpy/fft/fftpack_litemodule.c
+++ b/numpy/fft/fftpack_litemodule.c
@@ -44,14 +44,14 @@ fftpack_cfftf(PyObject *NPY_UNUSED(self), PyObject *args)
nrepeats = PyArray_SIZE(data)/npts;
dptr = (double *)PyArray_DATA(data);
- NPY_SIGINT_ON;
Py_BEGIN_ALLOW_THREADS;
+ NPY_SIGINT_ON;
for (i = 0; i < nrepeats; i++) {
cfftf(npts, dptr, wsave);
dptr += npts*2;
}
- Py_END_ALLOW_THREADS;
NPY_SIGINT_OFF;
+ Py_END_ALLOW_THREADS;
PyArray_Free(op2, (char *)wsave);
return (PyObject *)data;
@@ -97,14 +97,14 @@ fftpack_cfftb(PyObject *NPY_UNUSED(self), PyObject *args)
nrepeats = PyArray_SIZE(data)/npts;
dptr = (double *)PyArray_DATA(data);
- NPY_SIGINT_ON;
Py_BEGIN_ALLOW_THREADS;
+ NPY_SIGINT_ON;
for (i = 0; i < nrepeats; i++) {
cfftb(npts, dptr, wsave);
dptr += npts*2;
}
- Py_END_ALLOW_THREADS;
NPY_SIGINT_OFF;
+ Py_END_ALLOW_THREADS;
PyArray_Free(op2, (char *)wsave);
return (PyObject *)data;
@@ -134,11 +134,11 @@ fftpack_cffti(PyObject *NPY_UNUSED(self), PyObject *args)
return NULL;
}
- NPY_SIGINT_ON;
Py_BEGIN_ALLOW_THREADS;
+ NPY_SIGINT_ON;
cffti(n, (double *)PyArray_DATA((PyArrayObject*)op));
- Py_END_ALLOW_THREADS;
NPY_SIGINT_OFF;
+ Py_END_ALLOW_THREADS;
return (PyObject *)op;
}
@@ -188,8 +188,8 @@ fftpack_rfftf(PyObject *NPY_UNUSED(self), PyObject *args)
dptr = (double *)PyArray_DATA(data);
- NPY_SIGINT_ON;
Py_BEGIN_ALLOW_THREADS;
+ NPY_SIGINT_ON;
for (i = 0; i < nrepeats; i++) {
memcpy((char *)(rptr+1), dptr, npts*sizeof(double));
rfftf(npts, rptr+1, wsave);
@@ -198,8 +198,8 @@ fftpack_rfftf(PyObject *NPY_UNUSED(self), PyObject *args)
rptr += rstep;
dptr += npts;
}
- Py_END_ALLOW_THREADS;
NPY_SIGINT_OFF;
+ Py_END_ALLOW_THREADS;
PyArray_Free(op2, (char *)wsave);
Py_DECREF(data);
return (PyObject *)ret;
@@ -252,8 +252,8 @@ fftpack_rfftb(PyObject *NPY_UNUSED(self), PyObject *args)
rptr = (double *)PyArray_DATA(ret);
dptr = (double *)PyArray_DATA(data);
- NPY_SIGINT_ON;
Py_BEGIN_ALLOW_THREADS;
+ NPY_SIGINT_ON;
for (i = 0; i < nrepeats; i++) {
memcpy((char *)(rptr + 1), (dptr + 2), (npts - 1)*sizeof(double));
rptr[0] = dptr[0];
@@ -261,8 +261,8 @@ fftpack_rfftb(PyObject *NPY_UNUSED(self), PyObject *args)
rptr += npts;
dptr += npts*2;
}
- Py_END_ALLOW_THREADS;
NPY_SIGINT_OFF;
+ Py_END_ALLOW_THREADS;
PyArray_Free(op2, (char *)wsave);
Py_DECREF(data);
return (PyObject *)ret;
@@ -294,11 +294,11 @@ fftpack_rffti(PyObject *NPY_UNUSED(self), PyObject *args)
if (op == NULL) {
return NULL;
}
- NPY_SIGINT_ON;
Py_BEGIN_ALLOW_THREADS;
+ NPY_SIGINT_ON;
rffti(n, (double *)PyArray_DATA((PyArrayObject*)op));
- Py_END_ALLOW_THREADS;
NPY_SIGINT_OFF;
+ Py_END_ALLOW_THREADS;
return (PyObject *)op;
}
diff --git a/numpy/fft/tests/test_fftpack.py b/numpy/fft/tests/test_fftpack.py
index 50702ab3c..ac892c83b 100644
--- a/numpy/fft/tests/test_fftpack.py
+++ b/numpy/fft/tests/test_fftpack.py
@@ -2,6 +2,14 @@ from __future__ import division, absolute_import, print_function
import numpy as np
from numpy.testing import TestCase, run_module_suite, assert_array_almost_equal
+from numpy.testing import assert_array_equal
+import threading
+import sys
+if sys.version_info[0] >= 3:
+ import queue
+else:
+ import Queue as queue
+
def fft1(x):
L = len(x)
@@ -9,6 +17,7 @@ def fft1(x):
phase = np.arange(L).reshape(-1, 1) * phase
return np.sum(x*np.exp(phase), axis=1)
+
class TestFFTShift(TestCase):
def test_fft_n(self):
@@ -23,5 +32,44 @@ class TestFFT1D(TestCase):
assert_array_almost_equal(fft1(x), np.fft.fft(x))
+class TestFFTThreadSafe(TestCase):
+ threads = 16
+ input_shape = (800, 200)
+
+ def _test_mtsame(self, func, *args):
+ def worker(args, q):
+ q.put(func(*args))
+
+ q = queue.Queue()
+ expected = func(*args)
+
+ # Spin off a bunch of threads to call the same function simultaneously
+ t = [threading.Thread(target=worker, args=(args, q))
+ for i in range(self.threads)]
+ [x.start() for x in t]
+
+ # Make sure all threads returned the correct value
+ for i in range(self.threads):
+ assert_array_equal(q.get(timeout=5), expected,
+ 'Function returned wrong value in multithreaded context')
+ [x.join() for x in t]
+
+ def test_fft(self):
+ a = np.ones(self.input_shape) * 1+0j
+ self._test_mtsame(np.fft.fft, a)
+
+ def test_ifft(self):
+ a = np.ones(self.input_shape) * 1+0j
+ self._test_mtsame(np.fft.ifft, a)
+
+ def test_rfft(self):
+ a = np.ones(self.input_shape)
+ self._test_mtsame(np.fft.rfft, a)
+
+ def test_irfft(self):
+ a = np.ones(self.input_shape) * 1+0j
+ self._test_mtsame(np.fft.irfft, a)
+
+
if __name__ == "__main__":
run_module_suite()