summaryrefslogtreecommitdiff
path: root/sqlparse
diff options
context:
space:
mode:
authorJesús Leganés Combarro "Piranna" <piranna@gmail.com>2012-04-28 15:03:33 +0200
committerJesús Leganés Combarro "Piranna" <piranna@gmail.com>2012-04-28 15:03:33 +0200
commita8ded6465322b3b6f9b6eb7299172268b4e4bd40 (patch)
treeabe6ffe44cfb82a02dc73655bdd9db240c0b5200 /sqlparse
parent28f9c777545bb18fd3141568e2a25de685c3c30f (diff)
downloadsqlparse-a8ded6465322b3b6f9b6eb7299172268b4e4bd40.tar.gz
Put common code from Python and PHP output filters in OutputFilter
Diffstat (limited to 'sqlparse')
-rw-r--r--sqlparse/filters.py70
1 files changed, 35 insertions, 35 deletions
diff --git a/sqlparse/filters.py b/sqlparse/filters.py
index 99ef80c..8f4b3dd 100644
--- a/sqlparse/filters.py
+++ b/sqlparse/filters.py
@@ -503,14 +503,28 @@ def Tokens2Unicode(stream):
return result
-class OutputPythonFilter(Filter):
-
+class OutputFilter(Filter):
def __init__(self, varname='sql'):
- self.varname = varname
- self.cnt = 0
+ self.varname = self.varname_prefix + varname
+ self.count = 0
- def _process(self, stream, varname, count, has_nl):
- if count > 1:
+ def process(self, stack, stmt):
+ self.count += 1
+ if self.count > 1:
+ varname = '%s%d' % (self.varname, self.count)
+ else:
+ varname = self.varname
+
+ has_nl = len(unicode(stmt).strip().splitlines()) > 1
+ stmt.tokens = self._process(stmt.tokens, varname, has_nl)
+ return stmt
+
+
+class OutputPythonFilter(OutputFilter):
+ varname_prefix = ''
+
+ def _process(self, stream, varname, has_nl):
+ if self.count > 1:
yield sql.Token(T.Whitespace, '\n')
yield sql.Token(T.Name, varname)
yield sql.Token(T.Whitespace, ' ')
@@ -519,6 +533,7 @@ class OutputPythonFilter(Filter):
if has_nl:
yield sql.Token(T.Operator, '(')
yield sql.Token(T.Text, "'")
+
cnt = 0
for token in stream:
cnt += 1
@@ -528,74 +543,59 @@ class OutputPythonFilter(Filter):
after_lb = token.value.split('\n', 1)[1]
yield sql.Token(T.Text, " '")
yield sql.Token(T.Whitespace, '\n')
- for i in range(len(varname) + 4):
- yield sql.Token(T.Whitespace, ' ')
+
+ yield sql.Token(T.Whitespace, ' ' * (len(varname) + 4))
yield sql.Token(T.Text, "'")
if after_lb: # it's the indendation
yield sql.Token(T.Whitespace, after_lb)
continue
+
elif token.value and "'" in token.value:
token.value = token.value.replace("'", "\\'")
yield sql.Token(T.Text, token.value or '')
+
yield sql.Token(T.Text, "'")
if has_nl:
yield sql.Token(T.Operator, ')')
- def process(self, stack, stmt):
- self.cnt += 1
- if self.cnt > 1:
- varname = '%s%d' % (self.varname, self.cnt)
- else:
- varname = self.varname
- has_nl = len(unicode(stmt).strip().splitlines()) > 1
- stmt.tokens = self._process(stmt.tokens, varname, self.cnt, has_nl)
- return stmt
-
-
-class OutputPHPFilter(Filter):
- def __init__(self, varname='sql'):
- self.varname = '$%s' % varname
- self.count = 0
+class OutputPHPFilter(OutputFilter):
+ varname_prefix = '$'
- def _process(self, stream, varname):
+ def _process(self, stream, varname, has_nl):
if self.count > 1:
yield sql.Token(T.Whitespace, '\n')
yield sql.Token(T.Name, varname)
yield sql.Token(T.Whitespace, ' ')
+ if has_nl:
+ yield sql.Token(T.Whitespace, ' ')
yield sql.Token(T.Operator, '=')
yield sql.Token(T.Whitespace, ' ')
yield sql.Token(T.Text, '"')
+
for token in stream:
if token.is_whitespace() and '\n' in token.value:
after_lb = token.value.split('\n', 1)[1]
yield sql.Token(T.Text, ' "')
yield sql.Token(T.Operator, ';')
yield sql.Token(T.Whitespace, '\n')
+
yield sql.Token(T.Name, varname)
yield sql.Token(T.Whitespace, ' ')
- yield sql.Token(T.Punctuation, '.')
- yield sql.Token(T.Operator, '=')
+ yield sql.Token(T.Operator, '.=')
yield sql.Token(T.Whitespace, ' ')
yield sql.Token(T.Text, '"')
if after_lb:
yield sql.Token(T.Text, after_lb)
continue
+
elif '"' in token.value:
token.value = token.value.replace('"', '\\"')
yield sql.Token(T.Text, token.value)
+
yield sql.Token(T.Text, '"')
yield sql.Token(T.Punctuation, ';')
- def process(self, stack, stmt):
- self.count += 1
- if self.count > 1:
- varname = '%s%d' % (self.varname, self.count)
- else:
- varname = self.varname
- stmt.tokens = tuple(self._process(stmt.tokens, varname))
- return stmt
-
class Limit(Filter):
"""Get the LIMIT of a query.