Diffstat (limited to 'sqlparse')
-rw-r--r--  sqlparse/engine/filter.py |  13
-rw-r--r--  sqlparse/filters.py       | 146
-rw-r--r--  sqlparse/sql.py           |  54
3 files changed, 105 insertions, 108 deletions
diff --git a/sqlparse/engine/filter.py b/sqlparse/engine/filter.py
index 9ea9703..9af2f99 100644
--- a/sqlparse/engine/filter.py
+++ b/sqlparse/engine/filter.py
@@ -4,21 +4,10 @@ from sqlparse.sql import Statement, Token
from sqlparse import tokens as T
-class TokenFilter(object):
-
- def __init__(self, **options):
- self.options = options
-
- def process(self, stack, stream):
- """Process token stream."""
- raise NotImplementedError
-
-
-class StatementFilter(TokenFilter):
+class StatementFilter:
"Filter that split stream at individual statements"
def __init__(self):
- TokenFilter.__init__(self)
self._in_declare = False
self._in_dbldollar = False
self._is_create = False
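
The deleted TokenFilter base class only declared an interface; the engine duck-types filters on their process() method, so a plain class is enough. A minimal stand-alone sketch (hypothetical filter, not part of this commit):

    from sqlparse import tokens as T

    class UppercaseKeywords:
        """Token filter with no base class: only process() matters."""
        def process(self, stack, stream):
            for ttype, value in stream:
                if ttype in T.Keyword:
                    value = value.upper()
                yield ttype, value
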
diff --git a/sqlparse/filters.py b/sqlparse/filters.py
index 99ef80c..291a613 100644
--- a/sqlparse/filters.py
+++ b/sqlparse/filters.py
@@ -11,22 +11,10 @@ from sqlparse.tokens import (Comment, Comparison, Keyword, Name, Punctuation,
String, Whitespace)
-class Filter(object):
-
- def process(self, *args):
- raise NotImplementedError
-
-
-class TokenFilter(Filter):
-
- def process(self, stack, stream):
- raise NotImplementedError
-
-
# --------------------------
# token process
-class _CaseFilter(TokenFilter):
+class _CaseFilter:
ttype = None
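
The _CaseFilter subclasses are normally reached through the formatter options rather than instantiated directly; a usage sketch (keyword_case and identifier_case are the documented format() parameters):

    import sqlparse

    print(sqlparse.format('select * from FOO',
                          keyword_case='upper', identifier_case='lower'))
    # SELECT * FROM foo
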
@@ -57,7 +45,7 @@ class IdentifierCaseFilter(_CaseFilter):
yield ttype, value
-class GetComments(Filter):
+class GetComments:
"""Get the comments from a stack"""
def process(self, stack, stream):
for token_type, value in stream:
@@ -65,7 +53,7 @@ class GetComments(Filter):
yield token_type, value
-class StripComments(Filter):
+class StripComments:
"""Strip the comments from a stack"""
def process(self, stack, stream):
for token_type, value in stream:
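
Both comment filters follow the same generator-over-(ttype, value) pattern; distilled into a free function (a sketch mirroring StripComments.process):

    from sqlparse import tokens as T

    def strip_comments(stream):
        # Drop every token whose type falls under Comment; pass the rest through.
        for ttype, value in stream:
            if ttype not in T.Comment:
                yield ttype, value
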
@@ -101,7 +89,7 @@ def StripWhitespace(stream):
last_type = token_type
-class IncludeStatement(Filter):
+class IncludeStatement:
"""Filter that enable a INCLUDE statement"""
def __init__(self, dirpath=".", maxRecursive=10):
@@ -163,7 +151,7 @@ class IncludeStatement(Filter):
# ----------------------
# statement process
-class StripCommentsFilter(Filter):
+class StripCommentsFilter:
def _get_next_comment(self, tlist):
# TODO(andi) Comment types should be unified, see related issue38
@@ -194,7 +182,7 @@ class StripCommentsFilter(Filter):
self._process(stmt)
-class StripWhitespaceFilter(Filter):
+class StripWhitespaceFilter:
def _stripws(self, tlist):
func_name = '_stripws_%s' % tlist.__class__.__name__.lower()
@@ -226,7 +214,7 @@ class StripWhitespaceFilter(Filter):
stmt.tokens.pop(-1)
-class ReindentFilter(Filter):
+class ReindentFilter:
def __init__(self, width=2, char=' ', line_width=None):
self.width = width
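
ReindentFilter backs the reindent formatter option; a usage sketch (the exact whitespace in the output is approximate):

    import sqlparse

    print(sqlparse.format('select a, b from foo where a = 1', reindent=True))
    # select a,
    #        b
    # from foo
    # where a = 1
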
@@ -391,7 +379,7 @@ class ReindentFilter(Filter):
# FIXME: Doesn't work ;)
-class RightMarginFilter(Filter):
+class RightMarginFilter:
keep_together = (
# sql.TypeCast, sql.Identifier, sql.Alias,
@@ -429,7 +417,7 @@ class RightMarginFilter(Filter):
group.tokens = self._process(stack, group, group.tokens)
-class ColumnsSelect(Filter):
+class ColumnsSelect:
"""Get the columns names of a SELECT query"""
def process(self, stack, stream):
mode = 0
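
ColumnsSelect consumes a raw (ttype, value) stream rather than a parsed statement; a direct-call sketch (the stack argument is unused here, and the exact yielded values are assumed):

    from sqlparse.lexer import tokenize
    from sqlparse.filters import ColumnsSelect

    print(list(ColumnsSelect().process(None, tokenize('SELECT a, b FROM t'))))
    # [u'a', u'b']
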
@@ -483,7 +471,7 @@ class ColumnsSelect(Filter):
# ---------------------------
# postprocess
-class SerializerUnicode(Filter):
+class SerializerUnicode:
def process(self, stack, stmt):
raw = unicode(stmt)
@@ -503,14 +491,32 @@ def Tokens2Unicode(stream):
return result
-class OutputPythonFilter(Filter):
+class OutputFilter:
+ varname_prefix = ''
def __init__(self, varname='sql'):
- self.varname = varname
- self.cnt = 0
+ self.varname = self.varname_prefix + varname
+ self.count = 0
+
+ def _process(self, stream, varname, has_nl):
+ raise NotImplementedError
- def _process(self, stream, varname, count, has_nl):
- if count > 1:
+ def process(self, stack, stmt):
+ self.count += 1
+ if self.count > 1:
+ varname = '%s%d' % (self.varname, self.count)
+ else:
+ varname = self.varname
+
+ has_nl = len(unicode(stmt).strip().splitlines()) > 1
+ stmt.tokens = self._process(stmt.tokens, varname, has_nl)
+ return stmt
+
+
+class OutputPythonFilter(OutputFilter):
+ def _process(self, stream, varname, has_nl):
+ # SQL query assignment to varname
+ if self.count > 1:
yield sql.Token(T.Whitespace, '\n')
yield sql.Token(T.Name, varname)
yield sql.Token(T.Whitespace, ' ')
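
The new OutputFilter base is a classic template method: process() owns statement counting and variable naming ('sql', then 'sql2', 'sql3', ...), while subclasses only render tokens in _process(). A hypothetical third subclass to illustrate the contract (not part of this commit; newline handling omitted):

    from sqlparse import sql
    from sqlparse import tokens as T

    class OutputJSFilter(OutputFilter):  # OutputFilter as defined above
        varname_prefix = 'var '

        def _process(self, stream, varname, has_nl):
            yield sql.Token(T.Name, varname)
            yield sql.Token(T.Whitespace, ' ')
            yield sql.Token(T.Operator, '=')
            yield sql.Token(T.Whitespace, ' ')
            yield sql.Token(T.Text, '"')
            for token in stream:
                # Escape double quotes; continuation lines not handled in this sketch.
                yield sql.Token(T.Text, token.value.replace('"', '\\"'))
            yield sql.Token(T.Text, '"')
            yield sql.Token(T.Punctuation, ';')
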
@@ -519,85 +525,87 @@ class OutputPythonFilter(Filter):
if has_nl:
yield sql.Token(T.Operator, '(')
yield sql.Token(T.Text, "'")
- cnt = 0
+
+ # Emit the tokens inside the quoted string
for token in stream:
- cnt += 1
+ # Token is a new line separator
if token.is_whitespace() and '\n' in token.value:
- if cnt == 1:
- continue
- after_lb = token.value.split('\n', 1)[1]
+ # Close quote and add a new line
yield sql.Token(T.Text, " '")
yield sql.Token(T.Whitespace, '\n')
- for i in range(len(varname) + 4):
- yield sql.Token(T.Whitespace, ' ')
+
+ # Re-open the quote on continuation lines
+ yield sql.Token(T.Whitespace, ' ' * (len(varname) + 4))
yield sql.Token(T.Text, "'")
- if after_lb: # it's the indendation
+
+ # Indentation
+ after_lb = token.value.split('\n', 1)[1]
+ if after_lb:
yield sql.Token(T.Whitespace, after_lb)
continue
- elif token.value and "'" in token.value:
+
+ # Token contains quotes that must be escaped
+ elif "'" in token.value:
token.value = token.value.replace("'", "\\'")
- yield sql.Token(T.Text, token.value or '')
+
+ # Put the token
+ yield sql.Token(T.Text, token.value)
+
+ # Close quote
yield sql.Token(T.Text, "'")
if has_nl:
yield sql.Token(T.Operator, ')')
- def process(self, stack, stmt):
- self.cnt += 1
- if self.cnt > 1:
- varname = '%s%d' % (self.varname, self.cnt)
- else:
- varname = self.varname
- has_nl = len(unicode(stmt).strip().splitlines()) > 1
- stmt.tokens = self._process(stmt.tokens, varname, self.cnt, has_nl)
- return stmt
+class OutputPHPFilter(OutputFilter):
+ varname_prefix = '$'
-class OutputPHPFilter(Filter):
-
- def __init__(self, varname='sql'):
- self.varname = '$%s' % varname
- self.count = 0
-
- def _process(self, stream, varname):
+ def _process(self, stream, varname, has_nl):
+ # SQL query assignment to varname (quote header)
if self.count > 1:
yield sql.Token(T.Whitespace, '\n')
yield sql.Token(T.Name, varname)
yield sql.Token(T.Whitespace, ' ')
+ if has_nl:
+ yield sql.Token(T.Whitespace, ' ')
yield sql.Token(T.Operator, '=')
yield sql.Token(T.Whitespace, ' ')
yield sql.Token(T.Text, '"')
+
+ # Emit the tokens inside the quoted string
for token in stream:
+ # Token is a new line separator
if token.is_whitespace() and '\n' in token.value:
- after_lb = token.value.split('\n', 1)[1]
- yield sql.Token(T.Text, ' "')
- yield sql.Token(T.Operator, ';')
+ # Close quote and add a new line
+ yield sql.Token(T.Text, ' ";')
yield sql.Token(T.Whitespace, '\n')
+
+ # Re-open the quote on continuation lines
yield sql.Token(T.Name, varname)
yield sql.Token(T.Whitespace, ' ')
- yield sql.Token(T.Punctuation, '.')
- yield sql.Token(T.Operator, '=')
+ yield sql.Token(T.Operator, '.=')
yield sql.Token(T.Whitespace, ' ')
yield sql.Token(T.Text, '"')
+
+ # Indentation
+ after_lb = token.value.split('\n', 1)[1]
if after_lb:
- yield sql.Token(T.Text, after_lb)
+ yield sql.Token(T.Whitespace, after_lb)
continue
+
+ # Token contains quotes that must be escaped
elif '"' in token.value:
token.value = token.value.replace('"', '\\"')
+
+ # Put the token
yield sql.Token(T.Text, token.value)
+
+ # Close quote
yield sql.Token(T.Text, '"')
yield sql.Token(T.Punctuation, ';')
- def process(self, stack, stmt):
- self.count += 1
- if self.count > 1:
- varname = '%s%d' % (self.varname, self.count)
- else:
- varname = self.varname
- stmt.tokens = tuple(self._process(stmt.tokens, varname))
- return stmt
-
-class Limit(Filter):
+class Limit:
"""Get the LIMIT of a query.
If not defined, return -1 (SQL specification for no LIMIT query)
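
Both output filters are normally reached through sqlparse.format(); a usage sketch (results shown as comments, assuming the standard output_format option):

    import sqlparse

    print(sqlparse.format('select * from foo;', output_format='python'))
    # sql = 'select * from foo;'

    print(sqlparse.format('select * from foo;', output_format='php'))
    # $sql = "select * from foo;";
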
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index 31fa34d..05e078d 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -38,7 +38,7 @@ class Token(object):
def to_unicode(self):
"""Returns a unicode representation of this object.
-
+
@deprecated: please use __unicode__()
"""
return unicode(self)
@@ -49,10 +49,8 @@ class Token(object):
def _get_repr_value(self):
raw = unicode(self)
if len(raw) > 7:
- short = raw[:6] + u'...'
- else:
- short = raw
- return re.sub('\s+', ' ', short)
+ raw = raw[:6] + u'...'
+ return re.sub('\s+', ' ', raw)
def flatten(self):
"""Resolve subgroups."""
@@ -73,31 +71,33 @@ class Token(object):
type_matched = self.ttype is ttype
if not type_matched or values is None:
return type_matched
+
if regex:
if isinstance(values, basestring):
values = set([values])
+
if self.ttype is T.Keyword:
- values = set([re.compile(v, re.IGNORECASE) for v in values])
+ values = set(re.compile(v, re.IGNORECASE) for v in values)
else:
- values = set([re.compile(v) for v in values])
+ values = set(re.compile(v) for v in values)
+
for pattern in values:
if pattern.search(self.value):
return True
return False
- else:
- if isinstance(values, basestring):
- if self.is_keyword:
- return values.upper() == self.normalized
- else:
- return values == self.value
+
+ if isinstance(values, basestring):
if self.is_keyword:
- for v in values:
- if v.upper() == self.normalized:
- return True
- return False
- else:
- print len(values)
- return self.value in values
+ return values.upper() == self.normalized
+ return values == self.value
+
+ if self.is_keyword:
+ for v in values:
+ if v.upper() == self.normalized:
+ return True
+ return False
+
+ return self.value in values
def is_group(self):
"""Returns ``True`` if this object has children."""
@@ -271,7 +271,7 @@ class TokenList(Token):
if not isinstance(idx, int):
idx = self.token_index(idx)
- while idx != 0:
+ while idx:
idx -= 1
if self.tokens[idx].is_whitespace() and skip_ws:
continue
@@ -379,12 +379,12 @@ class TokenList(Token):
dot = self.token_next_match(0, T.Punctuation, '.')
if dot is None:
return self.token_next_by_type(0, T.Name).value
- else:
- next_ = self.token_next_by_type(self.token_index(dot),
- (T.Name, T.Wildcard))
- if next_ is None: # invalid identifier, e.g. "a."
- return None
- return next_.value
+
+ next_ = self.token_next_by_type(self.token_index(dot),
+ (T.Name, T.Wildcard))
+ if next_ is None: # invalid identifier, e.g. "a."
+ return None
+ return next_.value
class Statement(TokenList):
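
A behavior sketch for the rewritten name lookup (the identifier's position in the token list is assumed):

    import sqlparse

    stmt = sqlparse.parse('select foo.bar from x')[0]
    ident = stmt.tokens[2]           # the 'foo.bar' identifier group
    print(ident.get_real_name())     # 'bar', the part after the dot
    # An invalid identifier such as 'a.' now returns None.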