diff options
| author | Jesús Leganés Combarro "Piranna" <piranna@gmail.com> | 2012-04-28 15:03:33 +0200 |
|---|---|---|
| committer | Jesús Leganés Combarro "Piranna" <piranna@gmail.com> | 2012-04-28 15:03:33 +0200 |
| commit | a8ded6465322b3b6f9b6eb7299172268b4e4bd40 (patch) | |
| tree | abe6ffe44cfb82a02dc73655bdd9db240c0b5200 /sqlparse | |
| parent | 28f9c777545bb18fd3141568e2a25de685c3c30f (diff) | |
| download | sqlparse-a8ded6465322b3b6f9b6eb7299172268b4e4bd40.tar.gz | |
Put common code from Python and PHP output filters in OutputFilter
Diffstat (limited to 'sqlparse')
| -rw-r--r-- | sqlparse/filters.py | 70 |
1 files changed, 35 insertions, 35 deletions
diff --git a/sqlparse/filters.py b/sqlparse/filters.py index 99ef80c..8f4b3dd 100644 --- a/sqlparse/filters.py +++ b/sqlparse/filters.py @@ -503,14 +503,28 @@ def Tokens2Unicode(stream): return result -class OutputPythonFilter(Filter): - +class OutputFilter(Filter): def __init__(self, varname='sql'): - self.varname = varname - self.cnt = 0 + self.varname = self.varname_prefix + varname + self.count = 0 - def _process(self, stream, varname, count, has_nl): - if count > 1: + def process(self, stack, stmt): + self.count += 1 + if self.count > 1: + varname = '%s%d' % (self.varname, self.count) + else: + varname = self.varname + + has_nl = len(unicode(stmt).strip().splitlines()) > 1 + stmt.tokens = self._process(stmt.tokens, varname, has_nl) + return stmt + + +class OutputPythonFilter(OutputFilter): + varname_prefix = '' + + def _process(self, stream, varname, has_nl): + if self.count > 1: yield sql.Token(T.Whitespace, '\n') yield sql.Token(T.Name, varname) yield sql.Token(T.Whitespace, ' ') @@ -519,6 +533,7 @@ class OutputPythonFilter(Filter): if has_nl: yield sql.Token(T.Operator, '(') yield sql.Token(T.Text, "'") + cnt = 0 for token in stream: cnt += 1 @@ -528,74 +543,59 @@ class OutputPythonFilter(Filter): after_lb = token.value.split('\n', 1)[1] yield sql.Token(T.Text, " '") yield sql.Token(T.Whitespace, '\n') - for i in range(len(varname) + 4): - yield sql.Token(T.Whitespace, ' ') + + yield sql.Token(T.Whitespace, ' ' * (len(varname) + 4)) yield sql.Token(T.Text, "'") if after_lb: # it's the indendation yield sql.Token(T.Whitespace, after_lb) continue + elif token.value and "'" in token.value: token.value = token.value.replace("'", "\\'") yield sql.Token(T.Text, token.value or '') + yield sql.Token(T.Text, "'") if has_nl: yield sql.Token(T.Operator, ')') - def process(self, stack, stmt): - self.cnt += 1 - if self.cnt > 1: - varname = '%s%d' % (self.varname, self.cnt) - else: - varname = self.varname - has_nl = len(unicode(stmt).strip().splitlines()) > 1 - stmt.tokens = self._process(stmt.tokens, varname, self.cnt, has_nl) - return stmt - - -class OutputPHPFilter(Filter): - def __init__(self, varname='sql'): - self.varname = '$%s' % varname - self.count = 0 +class OutputPHPFilter(OutputFilter): + varname_prefix = '$' - def _process(self, stream, varname): + def _process(self, stream, varname, has_nl): if self.count > 1: yield sql.Token(T.Whitespace, '\n') yield sql.Token(T.Name, varname) yield sql.Token(T.Whitespace, ' ') + if has_nl: + yield sql.Token(T.Whitespace, ' ') yield sql.Token(T.Operator, '=') yield sql.Token(T.Whitespace, ' ') yield sql.Token(T.Text, '"') + for token in stream: if token.is_whitespace() and '\n' in token.value: after_lb = token.value.split('\n', 1)[1] yield sql.Token(T.Text, ' "') yield sql.Token(T.Operator, ';') yield sql.Token(T.Whitespace, '\n') + yield sql.Token(T.Name, varname) yield sql.Token(T.Whitespace, ' ') - yield sql.Token(T.Punctuation, '.') - yield sql.Token(T.Operator, '=') + yield sql.Token(T.Operator, '.=') yield sql.Token(T.Whitespace, ' ') yield sql.Token(T.Text, '"') if after_lb: yield sql.Token(T.Text, after_lb) continue + elif '"' in token.value: token.value = token.value.replace('"', '\\"') yield sql.Token(T.Text, token.value) + yield sql.Token(T.Text, '"') yield sql.Token(T.Punctuation, ';') - def process(self, stack, stmt): - self.count += 1 - if self.count > 1: - varname = '%s%d' % (self.varname, self.count) - else: - varname = self.varname - stmt.tokens = tuple(self._process(stmt.tokens, varname)) - return stmt - class Limit(Filter): """Get the LIMIT of a query. |
