summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2016-06-09 09:06:19 -0600
committerCharles Harris <charlesr.harris@gmail.com>2016-06-09 09:06:19 -0600
commitd4a39b0f0e7d5b999b834b86f4beababe23e77f1 (patch)
treeaceded64acc079e83bb8036c163bd054db006a9d
parentd69c1470ed09f18293b9a9dec1809e14b5b9b779 (diff)
parent820d527fe0bc6414ad19b10eea941a7fc0749757 (diff)
downloadnumpy-d4a39b0f0e7d5b999b834b86f4beababe23e77f1.tar.gz
Merge pull request #7712 from krischer/fft-cache-race-condition
BUG: Fix race condition with new FFT cache
-rw-r--r--numpy/fft/fftpack.py15
-rw-r--r--numpy/fft/helper.py93
-rw-r--r--numpy/fft/tests/test_helper.py110
3 files changed, 124 insertions, 94 deletions
diff --git a/numpy/fft/fftpack.py b/numpy/fft/fftpack.py
index 78cf214d2..8dc3eccbc 100644
--- a/numpy/fft/fftpack.py
+++ b/numpy/fft/fftpack.py
@@ -55,12 +55,13 @@ def _raw_fft(a, n=None, axis=-1, init_function=fftpack.cffti,
raise ValueError("Invalid number of FFT data points (%d) specified."
% n)
- try:
- # Thread-safety note: We rely on list.pop() here to atomically
- # retrieve-and-remove a wsave from the cache. This ensures that no
- # other thread can get the same wsave while we're using it.
- wsave = fft_cache.setdefault(n, []).pop()
- except (IndexError):
+ # We have to ensure that only a single thread can access a wsave array
+ # at any given time. Thus we remove it from the cache and insert it
+ # again after it has been used. Multiple threads might create multiple
+ # copies of the wsave array. This is intentional and a limitation of
+ # the current C code.
+ wsave = fft_cache.pop_twiddle_factors(n)
+ if wsave is None:
wsave = init_function(n)
if a.shape[axis] != n:
@@ -86,7 +87,7 @@ def _raw_fft(a, n=None, axis=-1, init_function=fftpack.cffti,
# As soon as we put wsave back into the cache, another thread could pick it
# up and start using it, so we must not do this until after we're
# completely done using it ourselves.
- fft_cache[n].append(wsave)
+ fft_cache.put_twiddle_factors(n, wsave)
return r
diff --git a/numpy/fft/helper.py b/numpy/fft/helper.py
index 5d51c1a24..0832bc5a4 100644
--- a/numpy/fft/helper.py
+++ b/numpy/fft/helper.py
@@ -4,7 +4,8 @@ Discrete Fourier Transforms - helper.py
"""
from __future__ import division, absolute_import, print_function
-from collections import OrderedDict
+import collections
+import threading
from numpy.compat import integer_types
from numpy.core import (
@@ -228,7 +229,7 @@ def rfftfreq(n, d=1.0):
class _FFTCache(object):
"""
- Cache for the FFT init functions as an LRU (least recently used) cache.
+ Cache for the FFT twiddle factors as an LRU (least recently used) cache.
Parameters
----------
@@ -250,37 +251,73 @@ class _FFTCache(object):
def __init__(self, max_size_in_mb, max_item_count):
self._max_size_in_bytes = max_size_in_mb * 1024 ** 2
self._max_item_count = max_item_count
- # Much simpler than inheriting from it and having to work around
- # recursive behaviour.
- self._dict = OrderedDict()
-
- def setdefault(self, key, value):
- return self._dict.setdefault(key, value)
-
- def __getitem__(self, key):
- # pop + add to move it to the end.
- value = self._dict.pop(key)
- self._dict[key] = value
- self._prune_dict()
- return value
-
- def __setitem__(self, key, value):
- # Just setting is it not enough to move it to the end if it already
- # exists.
- try:
- del self._dict[key]
- except:
- pass
- self._dict[key] = value
- self._prune_dict()
-
- def _prune_dict(self):
+ self._dict = collections.OrderedDict()
+ self._lock = threading.Lock()
+
+ def put_twiddle_factors(self, n, factors):
+ """
+ Store twiddle factors for an FFT of length n in the cache.
+
+ Putting multiple twiddle factors for a certain n will store it multiple
+ times.
+
+ Parameters
+ ----------
+ n : int
+ Data length for the FFT.
+ factors : ndarray
+ The actual twiddle values.
+ """
+ with self._lock:
+ # Pop + later add to move it to the end for LRU behavior.
+ # Internally everything is stored in a dictionary whose values are
+ # lists.
+ try:
+ value = self._dict.pop(n)
+ except KeyError:
+ value = []
+ value.append(factors)
+ self._dict[n] = value
+ self._prune_cache()
+
+ def pop_twiddle_factors(self, n):
+ """
+ Pop twiddle factors for an FFT of length n from the cache.
+
+ Will return None if the requested twiddle factors are not available in
+ the cache.
+
+ Parameters
+ ----------
+ n : int
+ Data length for the FFT.
+
+ Returns
+ -------
+ out : ndarray or None
+ The retrieved twiddle factors if available, else None.
+ """
+ with self._lock:
+ if n not in self._dict or not self._dict[n]:
+ return None
+ # Pop + later add to move it to the end for LRU behavior.
+ all_values = self._dict.pop(n)
+ value = all_values.pop()
+ # Only put pack if there are still some arrays left in the list.
+ if all_values:
+ self._dict[n] = all_values
+ return value
+
+ def _prune_cache(self):
# Always keep at least one item.
while len(self._dict) > 1 and (
len(self._dict) > self._max_item_count or self._check_size()):
self._dict.popitem(last=False)
def _check_size(self):
- item_sizes = [_i[0].nbytes for _i in self._dict.values() if _i]
+ item_sizes = [sum(_j.nbytes for _j in _i)
+ for _i in self._dict.values() if _i]
+ if not item_sizes:
+ return False
max_size = max(self._max_size_in_bytes, 1.5 * max(item_sizes))
return sum(item_sizes) > max_size
diff --git a/numpy/fft/tests/test_helper.py b/numpy/fft/tests/test_helper.py
index 9fd0e496d..cb85755d2 100644
--- a/numpy/fft/tests/test_helper.py
+++ b/numpy/fft/tests/test_helper.py
@@ -79,86 +79,78 @@ class TestFFTCache(TestCase):
def test_basic_behaviour(self):
c = _FFTCache(max_size_in_mb=1, max_item_count=4)
- # Setting
- c[1] = [np.ones(2, dtype=np.float32)]
- c[2] = [np.zeros(2, dtype=np.float32)]
- # Getting
- assert_array_almost_equal(c[1][0], np.ones(2, dtype=np.float32))
- assert_array_almost_equal(c[2][0], np.zeros(2, dtype=np.float32))
- # Setdefault
- c.setdefault(1, [np.array([1, 2], dtype=np.float32)])
- assert_array_almost_equal(c[1][0], np.ones(2, dtype=np.float32))
- c.setdefault(3, [np.array([1, 2], dtype=np.float32)])
- assert_array_almost_equal(c[3][0], np.array([1, 2], dtype=np.float32))
-
- self.assertEqual(len(c._dict), 3)
+
+ # Put
+ c.put_twiddle_factors(1, np.ones(2, dtype=np.float32))
+ c.put_twiddle_factors(2, np.zeros(2, dtype=np.float32))
+
+ # Get
+ assert_array_almost_equal(c.pop_twiddle_factors(1),
+ np.ones(2, dtype=np.float32))
+ assert_array_almost_equal(c.pop_twiddle_factors(2),
+ np.zeros(2, dtype=np.float32))
+
+ # Nothing should be left.
+ self.assertEqual(len(c._dict), 0)
+
+ # Now put everything in twice so it can be retrieved once and each will
+ # still have one item left.
+ for _ in range(2):
+ c.put_twiddle_factors(1, np.ones(2, dtype=np.float32))
+ c.put_twiddle_factors(2, np.zeros(2, dtype=np.float32))
+ assert_array_almost_equal(c.pop_twiddle_factors(1),
+ np.ones(2, dtype=np.float32))
+ assert_array_almost_equal(c.pop_twiddle_factors(2),
+ np.zeros(2, dtype=np.float32))
+ self.assertEqual(len(c._dict), 2)
def test_automatic_pruning(self):
- # Thats around 2600 single precision samples.
+ # That's around 2600 single precision samples.
c = _FFTCache(max_size_in_mb=0.01, max_item_count=4)
- c[1] = [np.ones(200, dtype=np.float32)]
- c[2] = [np.ones(200, dtype=np.float32)]
- # Don't raise errors.
- c[1], c[2], c[1], c[2]
+ c.put_twiddle_factors(1, np.ones(200, dtype=np.float32))
+ c.put_twiddle_factors(2, np.ones(200, dtype=np.float32))
+ self.assertEqual(list(c._dict.keys()), [1, 2])
# This is larger than the limit but should still be kept.
- c[3] = [np.ones(3000, dtype=np.float32)]
- # Should exist.
- c[1], c[2], c[3]
+ c.put_twiddle_factors(3, np.ones(3000, dtype=np.float32))
+ self.assertEqual(list(c._dict.keys()), [1, 2, 3])
# Add one more.
- c[4] = [np.ones(3000, dtype=np.float32)]
-
+ c.put_twiddle_factors(4, np.ones(3000, dtype=np.float32))
# The other three should no longer exist.
- with self.assertRaises(KeyError):
- c[1]
- with self.assertRaises(KeyError):
- c[2]
- with self.assertRaises(KeyError):
- c[3]
+ self.assertEqual(list(c._dict.keys()), [4])
# Now test the max item count pruning.
c = _FFTCache(max_size_in_mb=0.01, max_item_count=2)
- c[1] = [np.empty(2)]
- c[2] = [np.empty(2)]
+ c.put_twiddle_factors(2, np.empty(2))
+ c.put_twiddle_factors(1, np.empty(2))
# Can still be accessed.
- c[2], c[1]
-
- c[3] = [np.empty(2)]
+ self.assertEqual(list(c._dict.keys()), [2, 1])
+ c.put_twiddle_factors(3, np.empty(2))
# 1 and 3 can still be accessed - c[2] has been touched least recently
# and is thus evicted.
- c[1], c[3]
-
- with self.assertRaises(KeyError):
- c[2]
-
- c[1], c[3]
+ self.assertEqual(list(c._dict.keys()), [1, 3])
# One last test. We will add a single large item that is slightly
# bigger then the cache size. Some small items can still be added.
c = _FFTCache(max_size_in_mb=0.01, max_item_count=5)
- c[1] = [np.ones(3000, dtype=np.float32)]
- c[1]
- c[2] = [np.ones(2, dtype=np.float32)]
- c[3] = [np.ones(2, dtype=np.float32)]
- c[4] = [np.ones(2, dtype=np.float32)]
- c[1], c[2], c[3], c[4]
-
- # One more big item.
- c[5] = [np.ones(3000, dtype=np.float32)]
-
- # c[1] no longer in the cache. Rest still in the cache.
- c[2], c[3], c[4], c[5]
- with self.assertRaises(KeyError):
- c[1]
+ c.put_twiddle_factors(1, np.ones(3000, dtype=np.float32))
+ c.put_twiddle_factors(2, np.ones(2, dtype=np.float32))
+ c.put_twiddle_factors(3, np.ones(2, dtype=np.float32))
+ c.put_twiddle_factors(4, np.ones(2, dtype=np.float32))
+ self.assertEqual(list(c._dict.keys()), [1, 2, 3, 4])
+
+ # One more big item. This time it is 6 smaller ones but they are
+ # counted as one big item.
+ for _ in range(6):
+ c.put_twiddle_factors(5, np.ones(500, dtype=np.float32))
+ # '1' no longer in the cache. Rest still in the cache.
+ self.assertEqual(list(c._dict.keys()), [2, 3, 4, 5])
# Another big item - should now be the only item in the cache.
- c[6] = [np.ones(4000, dtype=np.float32)]
- for _i in range(1, 6):
- with self.assertRaises(KeyError):
- c[_i]
- c[6]
+ c.put_twiddle_factors(6, np.ones(4000, dtype=np.float32))
+ self.assertEqual(list(c._dict.keys()), [6])
if __name__ == "__main__":