diff options
Diffstat (limited to 'numpy/fft/helper.py')
-rw-r--r-- | numpy/fft/helper.py | 93 |
1 files changed, 65 insertions, 28 deletions
diff --git a/numpy/fft/helper.py b/numpy/fft/helper.py index 5d51c1a24..0832bc5a4 100644 --- a/numpy/fft/helper.py +++ b/numpy/fft/helper.py @@ -4,7 +4,8 @@ Discrete Fourier Transforms - helper.py """ from __future__ import division, absolute_import, print_function -from collections import OrderedDict +import collections +import threading from numpy.compat import integer_types from numpy.core import ( @@ -228,7 +229,7 @@ def rfftfreq(n, d=1.0): class _FFTCache(object): """ - Cache for the FFT init functions as an LRU (least recently used) cache. + Cache for the FFT twiddle factors as an LRU (least recently used) cache. Parameters ---------- @@ -250,37 +251,73 @@ class _FFTCache(object): def __init__(self, max_size_in_mb, max_item_count): self._max_size_in_bytes = max_size_in_mb * 1024 ** 2 self._max_item_count = max_item_count - # Much simpler than inheriting from it and having to work around - # recursive behaviour. - self._dict = OrderedDict() - - def setdefault(self, key, value): - return self._dict.setdefault(key, value) - - def __getitem__(self, key): - # pop + add to move it to the end. - value = self._dict.pop(key) - self._dict[key] = value - self._prune_dict() - return value - - def __setitem__(self, key, value): - # Just setting is it not enough to move it to the end if it already - # exists. - try: - del self._dict[key] - except: - pass - self._dict[key] = value - self._prune_dict() - - def _prune_dict(self): + self._dict = collections.OrderedDict() + self._lock = threading.Lock() + + def put_twiddle_factors(self, n, factors): + """ + Store twiddle factors for an FFT of length n in the cache. + + Putting multiple twiddle factors for a certain n will store it multiple + times. + + Parameters + ---------- + n : int + Data length for the FFT. + factors : ndarray + The actual twiddle values. + """ + with self._lock: + # Pop + later add to move it to the end for LRU behavior. + # Internally everything is stored in a dictionary whose values are + # lists. + try: + value = self._dict.pop(n) + except KeyError: + value = [] + value.append(factors) + self._dict[n] = value + self._prune_cache() + + def pop_twiddle_factors(self, n): + """ + Pop twiddle factors for an FFT of length n from the cache. + + Will return None if the requested twiddle factors are not available in + the cache. + + Parameters + ---------- + n : int + Data length for the FFT. + + Returns + ------- + out : ndarray or None + The retrieved twiddle factors if available, else None. + """ + with self._lock: + if n not in self._dict or not self._dict[n]: + return None + # Pop + later add to move it to the end for LRU behavior. + all_values = self._dict.pop(n) + value = all_values.pop() + # Only put pack if there are still some arrays left in the list. + if all_values: + self._dict[n] = all_values + return value + + def _prune_cache(self): # Always keep at least one item. while len(self._dict) > 1 and ( len(self._dict) > self._max_item_count or self._check_size()): self._dict.popitem(last=False) def _check_size(self): - item_sizes = [_i[0].nbytes for _i in self._dict.values() if _i] + item_sizes = [sum(_j.nbytes for _j in _i) + for _i in self._dict.values() if _i] + if not item_sizes: + return False max_size = max(self._max_size_in_bytes, 1.5 * max(item_sizes)) return sum(item_sizes) > max_size |