summaryrefslogtreecommitdiff
path: root/sqlparse
diff options
context:
space:
mode:
Diffstat (limited to 'sqlparse')
-rw-r--r--sqlparse/filters.py158
1 files changed, 155 insertions, 3 deletions
diff --git a/sqlparse/filters.py b/sqlparse/filters.py
index 813be99..ce2fb80 100644
--- a/sqlparse/filters.py
+++ b/sqlparse/filters.py
@@ -2,8 +2,12 @@
import re
-from sqlparse import tokens as T
+from os.path import abspath, join
+
from sqlparse import sql
+from sqlparse import tokens as T
+from sqlparse.engine import FilterStack
+from sqlparse.tokens import Comment, Keyword, Name, Punctuation, String, Whitespace
class Filter(object):
@@ -52,6 +56,83 @@ class IdentifierCaseFilter(_CaseFilter):
yield ttype, value
+class Get_Comments(Filter):
+ """Get the comments from a stack"""
+ def process(self, stack, stream):
+ for token_type, value in stream:
+ if token_type in Comment:
+ yield token_type, value
+
+
+class StripComments(Filter):
+ """Strip the comments from a stack"""
+ def process(self, stack, stream):
+ for token_type, value in stream:
+ if token_type not in Comment:
+ yield token_type, value
+
+
+class IncludeStatement(Filter):
+ """Filter that enable a INCLUDE statement"""
+
+ def __init__(self, dirpath=".", maxRecursive=10):
+ self.dirpath = abspath(dirpath)
+ self.maxRecursive = maxRecursive
+
+ self.detected = False
+
+
+ def process(self, stack, stream):
+ # Run over all tokens in the stream
+ for token_type, value in stream:
+ # INCLUDE statement found, set detected mode
+ if token_type in Name and value.upper() == 'INCLUDE':
+ self.detected = True
+ continue
+
+ # INCLUDE statement was found, parse it
+ elif self.detected:
+ # Omit whitespaces
+ if token_type in Whitespace:
+ pass
+
+ # Get path of file to include
+ path = None
+
+ if token_type in String.Symbol:
+# if token_type in tokens.String.Symbol:
+ path = join(self.dirpath, value[1:-1])
+
+ # Include file if path was found
+ if path:
+ try:
+ with open(path) as f:
+ sql = f.read()
+
+ except IOError, err:
+ logging.error(err)
+ yield Comment, u'-- IOError: %s\n' % err
+
+ else:
+ # Create new FilterStack to parse readed file
+ # and add all its tokens to the main stack recursively
+ # [ToDo] Add maximum recursive iteration value
+ stack = FilterStack()
+ stack.preprocess.append(IncludeStatement(self.dirpath))
+
+ for tv in stack.run(sql):
+ yield tv
+
+ # Set normal mode
+ self.detected = False
+
+ # Don't include any token while in detected mode
+ continue
+
+ # Normal token
+ yield token_type, value
+
+
# ----------------------
# statement process
@@ -150,9 +231,9 @@ class ReindentFilter(Filter):
t = tlist.token_next_match(i, T.Keyword, split_words,
regex=True)
if t and t.value.upper() == 'BETWEEN':
- t = _next_token(tlist.token_index(t)+1)
+ t = _next_token(tlist.token_index(t) + 1)
if t and t.value.upper() == 'AND':
- t = _next_token(tlist.token_index(t)+1)
+ t = _next_token(tlist.token_index(t) + 1)
return t
idx = 0
@@ -316,6 +397,56 @@ class RightMarginFilter(Filter):
group.tokens = self._process(stack, group, group.tokens)
+class ColumnsSelect(Filter):
+ """Get the columns names of a SELECT query"""
+ def process(self, stack, stream):
+ mode = 0
+ oldValue = ""
+ parenthesis = 0
+
+ for token_type, value in stream:
+ # Ignore comments
+ if token_type in Comment:
+ continue
+
+ # We have not detected a SELECT statement
+ if mode == 0:
+ if token_type in Keyword and value == 'SELECT':
+ mode = 1
+
+ # We have detected a SELECT statement
+ elif mode == 1:
+ if value == 'FROM':
+ if oldValue:
+ yield Name, oldValue
+
+ mode = 3 # Columns have been checked
+
+ elif value == 'AS':
+ oldValue = ""
+ mode = 2
+
+ elif token_type == Punctuation and value == ',' and not parenthesis:
+ if oldValue:
+ yield Name, oldValue
+ oldValue = ""
+
+ elif token_type not in Whitespace:
+ if value == '(':
+ parenthesis += 1
+ elif value == ')':
+ parenthesis -= 1
+
+ oldValue += value
+
+ # We are processing an AS keyword
+ elif mode == 2:
+ # We check also for Keywords because a bug in SQLParse
+ if token_type == Name or token_type == Keyword:
+ yield Name, value
+ mode = 1
+
+
# ---------------------------
# postprocess
@@ -422,3 +553,24 @@ class OutputPHPFilter(Filter):
varname = self.varname
stmt.tokens = tuple(self._process(stmt.tokens, varname))
return stmt
+
+
+class Limit(Filter):
+ """Get the LIMIT of a query.
+
+ If not defined, return -1 (SQL specification for no LIMIT query)
+ """
+ def process(self, stack, stream):
+ index = 7
+ stream = list(stream)
+ stream.reverse()
+
+ # Run over all tokens in the stream from the end
+ for token_type, value in stream:
+ index -= 1
+
+# if index and token_type in Keyword:
+ if index and token_type in Keyword and value == 'LIMIT':
+ return stream[4 - index][1]
+
+ return -1 \ No newline at end of file