summaryrefslogtreecommitdiff
path: root/sqlparse
diff options
context:
space:
mode:
Diffstat (limited to 'sqlparse')
-rw-r--r--sqlparse/filters.py47
-rw-r--r--sqlparse/utils.py46
2 files changed, 81 insertions, 12 deletions
diff --git a/sqlparse/filters.py b/sqlparse/filters.py
index 291a613..08c1f69 100644
--- a/sqlparse/filters.py
+++ b/sqlparse/filters.py
@@ -4,11 +4,11 @@ import re
from os.path import abspath, join
-from sqlparse import sql
-from sqlparse import tokens as T
+from sqlparse import sql, tokens as T
from sqlparse.engine import FilterStack
from sqlparse.tokens import (Comment, Comparison, Keyword, Name, Punctuation,
String, Whitespace)
+from sqlparse.utils import memoize_generator
# --------------------------
@@ -92,12 +92,17 @@ def StripWhitespace(stream):
class IncludeStatement:
"""Filter that enable a INCLUDE statement"""
- def __init__(self, dirpath=".", maxRecursive=10):
+ def __init__(self, dirpath=".", maxrecursive=10, raiseexceptions=False):
+ if maxrecursive <= 0:
+ raise ValueError('Max recursion limit reached')
+
self.dirpath = abspath(dirpath)
- self.maxRecursive = maxRecursive
+ self.maxRecursive = maxrecursive
+ self.raiseexceptions = raiseexceptions
self.detected = False
+ @memoize_generator
def process(self, stack, stream):
# Run over all tokens in the stream
for token_type, value in stream:
@@ -110,30 +115,48 @@ class IncludeStatement:
elif self.detected:
# Omit whitespaces
if token_type in Whitespace:
- pass
-
- # Get path of file to include
- path = None
+ continue
+ # Found file path to include
if token_type in String.Symbol:
# if token_type in tokens.String.Symbol:
+
+ # Get path of file to include
path = join(self.dirpath, value[1:-1])
- # Include file if path was found
- if path:
try:
f = open(path)
raw_sql = f.read()
f.close()
+
+ # There was a problem loading the include file
except IOError, err:
+ # Raise the exception to the interpreter
+ if self.raiseexceptions:
+ raise
+
+ # Put the exception as a comment on the SQL code
yield Comment, u'-- IOError: %s\n' % err
else:
# Create new FilterStack to parse the read file
# and add all its tokens to the main stack recursively
- # [ToDo] Add maximum recursive iteration value
+ try:
+ filtr = IncludeStatement(self.dirpath,
+ self.maxRecursive - 1,
+ self.raiseexceptions)
+
+ # Max recursion limit reached
+ except ValueError, err:
+ # Raise the exception to the interpreter
+ if self.raiseexceptions:
+ raise
+
+ # Put the exception as a comment on the SQL code
+ yield Comment, u'-- ValueError: %s\n' % err
+
stack = FilterStack()
- stack.preprocess.append(IncludeStatement(self.dirpath))
+ stack.preprocess.append(filtr)
for tv in stack.run(raw_sql):
yield tv
diff --git a/sqlparse/utils.py b/sqlparse/utils.py
new file mode 100644
index 0000000..fd6651a
--- /dev/null
+++ b/sqlparse/utils.py
@@ -0,0 +1,46 @@
+'''
+Created on 17/05/2012
+
+@author: piranna
+'''
+
+
+def memoize_generator(func):
+ """Memoize decorator for generators
+
+    Store `func` results in a cache according to their arguments, as 'memoize'
+    does, but this works on generators instead of regular functions.
+    Obviously, this is only useful if the generator will always yield the same
+    values for each specific set of parameters...
+ """
+ cache = {}
+
+ def wrapped_func(*args, **kwargs):
+ params = (args, kwargs)
+
+ # Look if cached
+ try:
+ cached = cache[params]
+
+ # Not cached, exec and store it
+ except KeyError:
+        # Reset the cache if we have too many cached entries and start over.
+        # In the future it would be better to use an OrderedDict and drop the
+        # Least Recently Used entries
+ if len(cache) >= 10:
+ cache.clear()
+
+ cached = []
+
+ for item in func(*args, **kwargs):
+ cached.append(item)
+ yield item
+
+ cache[params] = cached
+
+ # Cached, yield its items
+ else:
+ for item in cached:
+ yield item
+
+ return wrapped_func