summaryrefslogtreecommitdiff
path: root/numpy/fft/helper.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/fft/helper.py')
-rw-r--r--numpy/fft/helper.py93
1 files changed, 65 insertions, 28 deletions
diff --git a/numpy/fft/helper.py b/numpy/fft/helper.py
index 5d51c1a24..0832bc5a4 100644
--- a/numpy/fft/helper.py
+++ b/numpy/fft/helper.py
@@ -4,7 +4,8 @@ Discrete Fourier Transforms - helper.py
"""
from __future__ import division, absolute_import, print_function
-from collections import OrderedDict
+import collections
+import threading
from numpy.compat import integer_types
from numpy.core import (
@@ -228,7 +229,7 @@ def rfftfreq(n, d=1.0):
class _FFTCache(object):
"""
- Cache for the FFT init functions as an LRU (least recently used) cache.
+ Cache for the FFT twiddle factors as an LRU (least recently used) cache.
Parameters
----------
@@ -250,37 +251,73 @@ class _FFTCache(object):
def __init__(self, max_size_in_mb, max_item_count):
self._max_size_in_bytes = max_size_in_mb * 1024 ** 2
self._max_item_count = max_item_count
- # Much simpler than inheriting from it and having to work around
- # recursive behaviour.
- self._dict = OrderedDict()
-
- def setdefault(self, key, value):
- return self._dict.setdefault(key, value)
-
- def __getitem__(self, key):
- # pop + add to move it to the end.
- value = self._dict.pop(key)
- self._dict[key] = value
- self._prune_dict()
- return value
-
- def __setitem__(self, key, value):
- # Just setting is it not enough to move it to the end if it already
- # exists.
- try:
- del self._dict[key]
- except:
- pass
- self._dict[key] = value
- self._prune_dict()
-
- def _prune_dict(self):
+ self._dict = collections.OrderedDict()
+ self._lock = threading.Lock()
+
+ def put_twiddle_factors(self, n, factors):
+ """
+ Store twiddle factors for an FFT of length n in the cache.
+
+ Putting multiple twiddle factors for a certain n will store it multiple
+ times.
+
+ Parameters
+ ----------
+ n : int
+ Data length for the FFT.
+ factors : ndarray
+ The actual twiddle values.
+ """
+ with self._lock:
+ # Pop + later add to move it to the end for LRU behavior.
+ # Internally everything is stored in a dictionary whose values are
+ # lists.
+ try:
+ value = self._dict.pop(n)
+ except KeyError:
+ value = []
+ value.append(factors)
+ self._dict[n] = value
+ self._prune_cache()
+
+ def pop_twiddle_factors(self, n):
+ """
+ Pop twiddle factors for an FFT of length n from the cache.
+
+ Will return None if the requested twiddle factors are not available in
+ the cache.
+
+ Parameters
+ ----------
+ n : int
+ Data length for the FFT.
+
+ Returns
+ -------
+ out : ndarray or None
+ The retrieved twiddle factors if available, else None.
+ """
+ with self._lock:
+ if n not in self._dict or not self._dict[n]:
+ return None
+ # Pop + later add to move it to the end for LRU behavior.
+ all_values = self._dict.pop(n)
+ value = all_values.pop()
+ # Only put pack if there are still some arrays left in the list.
+ if all_values:
+ self._dict[n] = all_values
+ return value
+
+ def _prune_cache(self):
# Always keep at least one item.
while len(self._dict) > 1 and (
len(self._dict) > self._max_item_count or self._check_size()):
self._dict.popitem(last=False)
def _check_size(self):
- item_sizes = [_i[0].nbytes for _i in self._dict.values() if _i]
+ item_sizes = [sum(_j.nbytes for _j in _i)
+ for _i in self._dict.values() if _i]
+ if not item_sizes:
+ return False
max_size = max(self._max_size_in_bytes, 1.5 * max(item_sizes))
return sum(item_sizes) > max_size