diff options
Diffstat (limited to 'numpy/fft/fftpack.py')
-rw-r--r-- | numpy/fft/fftpack.py | 31 |
1 files changed, 22 insertions, 9 deletions
diff --git a/numpy/fft/fftpack.py b/numpy/fft/fftpack.py index 9a169b04c..706fcdd2f 100644 --- a/numpy/fft/fftpack.py +++ b/numpy/fft/fftpack.py @@ -53,10 +53,12 @@ def _raw_fft(a, n=None, axis=-1, init_function=fftpack.cffti, raise ValueError("Invalid number of FFT data points (%d) specified." % n) try: - wsave = fft_cache[n] - except(KeyError): + # Thread-safety note: We rely on list.pop() here to atomically + # retrieve-and-remove a wsave from the cache. This ensures that no + # other thread can get the same wsave while we're using it. + wsave = fft_cache.setdefault(n, []).pop() + except (IndexError): wsave = init_function(n) - fft_cache[n] = wsave if a.shape[axis] != n: s = list(a.shape) @@ -77,6 +79,12 @@ def _raw_fft(a, n=None, axis=-1, init_function=fftpack.cffti, r = work_function(a, wsave) if axis != -1: r = swapaxes(r, axis, -1) + + # As soon as we put wsave back into the cache, another thread could pick it + # up and start using it, so we must not do this until after we're + # completely done using it ourselves. + fft_cache[n].append(wsave) + return r @@ -296,7 +304,7 @@ def rfft(a, n=None, axis=-1): conjugates of the corresponding positive-frequency terms, and the negative-frequency terms are therefore redundant. This function does not compute the negative frequency terms, and the length of the transformed - axis of the output is therefore ``n//2+1``. + axis of the output is therefore ``n//2 + 1``. When ``A = rfft(a)`` and fs is the sampling frequency, ``A[0]`` contains the zero-frequency term 0*fs, which is real due to Hermitian symmetry. @@ -821,11 +829,16 @@ def fft2(a, s=None, axes=(-2, -1)): -------- >>> a = np.mgrid[:5, :5][0] >>> np.fft.fft2(a) - array([[ 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], - [ 5.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], - [ 10.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], - [ 15.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], - [ 20.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j]]) + array([[ 50.0 +0.j , 0.0 +0.j , 0.0 +0.j , + 0.0 +0.j , 0.0 +0.j ], + [-12.5+17.20477401j, 0.0 +0.j , 0.0 +0.j , + 0.0 +0.j , 0.0 +0.j ], + [-12.5 +4.0614962j , 0.0 +0.j , 0.0 +0.j , + 0.0 +0.j , 0.0 +0.j ], + [-12.5 -4.0614962j , 0.0 +0.j , 0.0 +0.j , + 0.0 +0.j , 0.0 +0.j ], + [-12.5-17.20477401j, 0.0 +0.j , 0.0 +0.j , + 0.0 +0.j , 0.0 +0.j ]]) """ |