| author | Andi Albrecht <albrecht.andi@googlemail.com> | 2012-05-15 22:30:18 -0700 |
|---|---|---|
| committer | Andi Albrecht <albrecht.andi@googlemail.com> | 2012-05-15 22:30:18 -0700 |
| commit | 7442f145233db0d23e1a7d74cf20ce316b890f97 (patch) | |
| tree | 99effac8466f2ab884ef6fbf557f4fdf9ad87fc4 /sqlparse | |
| parent | 210dce4cdf97441fc9a4e9bfe6c96b8f34612e5b (diff) | |
| parent | c45e0c62373407544c089ada7240c99da79ff29e (diff) | |
| download | sqlparse-7442f145233db0d23e1a7d74cf20ce316b890f97.tar.gz | |
Merge pull request #67 from piranna/master
Some optimizations
Diffstat (limited to 'sqlparse')
| -rw-r--r-- | sqlparse/engine/filter.py | 13 |
| -rw-r--r-- | sqlparse/filters.py | 146 |
| -rw-r--r-- | sqlparse/sql.py | 54 |
3 files changed, 105 insertions, 108 deletions
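
For context before the patch itself: the two output filters reworked in this merge (`OutputPythonFilter`, `OutputPHPFilter`) back the `output_format` option of `sqlparse.format()`. A quick sketch of the behavior they implement, assuming a checkout of roughly this vintage (Python 2 era; the exact whitespace of the output may differ):

```python
import sqlparse

# Render a statement as a Python string assignment
print(sqlparse.format('select * from foo;', output_format='python'))
# expected (roughly): sql = 'select * from foo;'

# ... and as a PHP string assignment
print(sqlparse.format('select * from foo;', output_format='php'))
# expected (roughly): $sql = "select * from foo;";
```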
```diff
diff --git a/sqlparse/engine/filter.py b/sqlparse/engine/filter.py
index 9ea9703..9af2f99 100644
--- a/sqlparse/engine/filter.py
+++ b/sqlparse/engine/filter.py
@@ -4,21 +4,10 @@
 from sqlparse.sql import Statement, Token
 from sqlparse import tokens as T
 
-class TokenFilter(object):
-
-    def __init__(self, **options):
-        self.options = options
-
-    def process(self, stack, stream):
-        """Process token stream."""
-        raise NotImplementedError
-
-
-class StatementFilter(TokenFilter):
+class StatementFilter:
     "Filter that split stream at individual statements"
 
     def __init__(self):
-        TokenFilter.__init__(self)
         self._in_declare = False
         self._in_dbldollar = False
         self._is_create = False
diff --git a/sqlparse/filters.py b/sqlparse/filters.py
index 99ef80c..291a613 100644
--- a/sqlparse/filters.py
+++ b/sqlparse/filters.py
@@ -11,22 +11,10 @@
 from sqlparse.tokens import (Comment, Comparison, Keyword, Name,
                              Punctuation, String, Whitespace)
 
-class Filter(object):
-
-    def process(self, *args):
-        raise NotImplementedError
-
-
-class TokenFilter(Filter):
-
-    def process(self, stack, stream):
-        raise NotImplementedError
-
-
 # --------------------------
 # token process
 
-class _CaseFilter(TokenFilter):
+class _CaseFilter:
 
     ttype = None
 
@@ -57,7 +45,7 @@ class IdentifierCaseFilter(_CaseFilter):
         yield ttype, value
 
 
-class GetComments(Filter):
+class GetComments:
     """Get the comments from a stack"""
     def process(self, stack, stream):
         for token_type, value in stream:
@@ -65,7 +53,7 @@
             yield token_type, value
 
 
-class StripComments(Filter):
+class StripComments:
     """Strip the comments from a stack"""
     def process(self, stack, stream):
         for token_type, value in stream:
@@ -101,7 +89,7 @@ def StripWhitespace(stream):
         last_type = token_type
 
 
-class IncludeStatement(Filter):
+class IncludeStatement:
    """Filter that enable a INCLUDE statement"""
 
     def __init__(self, dirpath=".", maxRecursive=10):
@@ -163,7 +151,7 @@
 # ----------------------
 # statement process
 
-class StripCommentsFilter(Filter):
+class StripCommentsFilter:
 
     def _get_next_comment(self, tlist):
         # TODO(andi) Comment types should be unified, see related issue38
@@ -194,7 +182,7 @@
         self._process(stmt)
 
 
-class StripWhitespaceFilter(Filter):
+class StripWhitespaceFilter:
 
     def _stripws(self, tlist):
         func_name = '_stripws_%s' % tlist.__class__.__name__.lower()
@@ -226,7 +214,7 @@
             stmt.tokens.pop(-1)
 
 
-class ReindentFilter(Filter):
+class ReindentFilter:
 
     def __init__(self, width=2, char=' ', line_width=None):
         self.width = width
@@ -391,7 +379,7 @@
 
 
 # FIXME: Doesn't work ;)
-class RightMarginFilter(Filter):
+class RightMarginFilter:
 
     keep_together = (
         # sql.TypeCast, sql.Identifier, sql.Alias,
@@ -429,7 +417,7 @@
         group.tokens = self._process(stack, group, group.tokens)
 
 
-class ColumnsSelect(Filter):
+class ColumnsSelect:
     """Get the columns names of a SELECT query"""
     def process(self, stack, stream):
         mode = 0
@@ -483,7 +471,7 @@
 # ---------------------------
 # postprocess
 
-class SerializerUnicode(Filter):
+class SerializerUnicode:
 
     def process(self, stack, stmt):
         raw = unicode(stmt)
@@ -503,14 +491,32 @@ def Tokens2Unicode(stream):
     return result
 
 
-class OutputPythonFilter(Filter):
+class OutputFilter:
+    varname_prefix = ''
 
     def __init__(self, varname='sql'):
-        self.varname = varname
-        self.cnt = 0
+        self.varname = self.varname_prefix + varname
+        self.count = 0
+
+    def _process(self, stream, varname, has_nl):
+        raise NotImplementedError
 
-    def _process(self, stream, varname, count, has_nl):
-        if count > 1:
+    def process(self, stack, stmt):
+        self.count += 1
+        if self.count > 1:
+            varname = '%s%d' % (self.varname, self.count)
+        else:
+            varname = self.varname
+
+        has_nl = len(unicode(stmt).strip().splitlines()) > 1
+        stmt.tokens = self._process(stmt.tokens, varname, has_nl)
+        return stmt
+
+
+class OutputPythonFilter(OutputFilter):
+    def _process(self, stream, varname, has_nl):
+        # SQL query asignation to varname
+        if self.count > 1:
             yield sql.Token(T.Whitespace, '\n')
         yield sql.Token(T.Name, varname)
         yield sql.Token(T.Whitespace, ' ')
@@ -519,85 +525,87 @@
         if has_nl:
             yield sql.Token(T.Operator, '(')
         yield sql.Token(T.Text, "'")
-        cnt = 0
+
+        # Print the tokens on the quote
         for token in stream:
-            cnt += 1
+            # Token is a new line separator
             if token.is_whitespace() and '\n' in token.value:
-                if cnt == 1:
-                    continue
-                after_lb = token.value.split('\n', 1)[1]
+                # Close quote and add a new line
                 yield sql.Token(T.Text, " '")
                 yield sql.Token(T.Whitespace, '\n')
-                for i in range(len(varname) + 4):
-                    yield sql.Token(T.Whitespace, ' ')
+
+                # Quote header on secondary lines
+                yield sql.Token(T.Whitespace, ' ' * (len(varname) + 4))
                 yield sql.Token(T.Text, "'")
-                if after_lb:  # it's the indendation
+
+                # Indentation
+                after_lb = token.value.split('\n', 1)[1]
+                if after_lb:
                     yield sql.Token(T.Whitespace, after_lb)
                 continue
-            elif token.value and "'" in token.value:
+
+            # Token has escape chars
+            elif "'" in token.value:
                 token.value = token.value.replace("'", "\\'")
-            yield sql.Token(T.Text, token.value or '')
+
+            # Put the token
+            yield sql.Token(T.Text, token.value)
+
+        # Close quote
         yield sql.Token(T.Text, "'")
         if has_nl:
             yield sql.Token(T.Operator, ')')
 
-    def process(self, stack, stmt):
-        self.cnt += 1
-        if self.cnt > 1:
-            varname = '%s%d' % (self.varname, self.cnt)
-        else:
-            varname = self.varname
-        has_nl = len(unicode(stmt).strip().splitlines()) > 1
-        stmt.tokens = self._process(stmt.tokens, varname, self.cnt, has_nl)
-        return stmt
 
+class OutputPHPFilter(OutputFilter):
+    varname_prefix = '$'
 
-class OutputPHPFilter(Filter):
-
-    def __init__(self, varname='sql'):
-        self.varname = '$%s' % varname
-        self.count = 0
-
-    def _process(self, stream, varname):
+    def _process(self, stream, varname, has_nl):
+        # SQL query asignation to varname (quote header)
         if self.count > 1:
             yield sql.Token(T.Whitespace, '\n')
         yield sql.Token(T.Name, varname)
         yield sql.Token(T.Whitespace, ' ')
+        if has_nl:
+            yield sql.Token(T.Whitespace, ' ')
         yield sql.Token(T.Operator, '=')
         yield sql.Token(T.Whitespace, ' ')
         yield sql.Token(T.Text, '"')
+
+        # Print the tokens on the quote
         for token in stream:
+            # Token is a new line separator
            if token.is_whitespace() and '\n' in token.value:
-                after_lb = token.value.split('\n', 1)[1]
-                yield sql.Token(T.Text, ' "')
-                yield sql.Token(T.Operator, ';')
+                # Close quote and add a new line
+                yield sql.Token(T.Text, ' ";')
                 yield sql.Token(T.Whitespace, '\n')
+
+                # Quote header on secondary lines
                 yield sql.Token(T.Name, varname)
                 yield sql.Token(T.Whitespace, ' ')
-                yield sql.Token(T.Punctuation, '.')
-                yield sql.Token(T.Operator, '=')
+                yield sql.Token(T.Operator, '.=')
                 yield sql.Token(T.Whitespace, ' ')
                 yield sql.Token(T.Text, '"')
+
+                # Indentation
+                after_lb = token.value.split('\n', 1)[1]
                 if after_lb:
-                    yield sql.Token(T.Text, after_lb)
+                    yield sql.Token(T.Whitespace, after_lb)
                 continue
+
+            # Token has escape chars
             elif '"' in token.value:
                 token.value = token.value.replace('"', '\\"')
+
+            # Put the token
            yield sql.Token(T.Text, token.value)
+
+        # Close quote
         yield sql.Token(T.Text, '"')
         yield sql.Token(T.Punctuation, ';')
 
-    def process(self, stack, stmt):
-        self.count += 1
-        if self.count > 1:
-            varname = '%s%d' % (self.varname, self.count)
-        else:
-            varname = self.varname
-        stmt.tokens = tuple(self._process(stmt.tokens, varname))
-        return stmt
-
 
-class Limit(Filter):
+class Limit:
     """Get the LIMIT of a query.
 
     If not defined, return -1 (SQL specification for no LIMIT query)
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index 31fa34d..05e078d 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -38,7 +38,7 @@ class Token(object):
 
     def to_unicode(self):
         """Returns a unicode representation of this object.
-        
+
         @deprecated: please use __unicode__()
         """
         return unicode(self)
@@ -49,10 +49,8 @@
     def _get_repr_value(self):
         raw = unicode(self)
         if len(raw) > 7:
-            short = raw[:6] + u'...'
-        else:
-            short = raw
-        return re.sub('\s+', ' ', short)
+            raw = raw[:6] + u'...'
+        return re.sub('\s+', ' ', raw)
 
     def flatten(self):
         """Resolve subgroups."""
@@ -73,31 +71,33 @@
         type_matched = self.ttype is ttype
         if not type_matched or values is None:
             return type_matched
+
         if regex:
             if isinstance(values, basestring):
                 values = set([values])
+
             if self.ttype is T.Keyword:
-                values = set([re.compile(v, re.IGNORECASE) for v in values])
+                values = set(re.compile(v, re.IGNORECASE) for v in values)
             else:
-                values = set([re.compile(v) for v in values])
+                values = set(re.compile(v) for v in values)
+
             for pattern in values:
                 if pattern.search(self.value):
                     return True
             return False
-        else:
-            if isinstance(values, basestring):
-                if self.is_keyword:
-                    return values.upper() == self.normalized
-                else:
-                    return values == self.value
+
+        if isinstance(values, basestring):
             if self.is_keyword:
-                for v in values:
-                    if v.upper() == self.normalized:
-                        return True
-                return False
-            else:
-                print len(values)
-                return self.value in values
+                return values.upper() == self.normalized
+            return values == self.value
+
+        if self.is_keyword:
+            for v in values:
+                if v.upper() == self.normalized:
+                    return True
+            return False
+
+        return self.value in values
 
     def is_group(self):
         """Returns ``True`` if this object has children."""
@@ -271,7 +271,7 @@
         if not isinstance(idx, int):
             idx = self.token_index(idx)
 
-        while idx != 0:
+        while idx:
             idx -= 1
             if self.tokens[idx].is_whitespace() and skip_ws:
                 continue
@@ -379,12 +379,12 @@
         dot = self.token_next_match(0, T.Punctuation, '.')
         if dot is None:
             return self.token_next_by_type(0, T.Name).value
-        else:
-            next_ = self.token_next_by_type(self.token_index(dot),
-                                            (T.Name, T.Wildcard))
-            if next_ is None:  # invalid identifier, e.g. "a."
-                return None
-            return next_.value
+
+        next_ = self.token_next_by_type(self.token_index(dot),
+                                        (T.Name, T.Wildcard))
+        if next_ is None:  # invalid identifier, e.g. "a."
+            return None
+        return next_.value
 
 
 class Statement(TokenList):
```
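
The core of the change is visible in the `filters.py` hunks: the duplicated `process()` bookkeeping moved into a shared `OutputFilter` base class (a template-method pattern), and the subclasses now only render tokens in `_process()`. A minimal standalone sketch of that shape, with the token machinery stripped out (simplified names, not the actual sqlparse classes):

```python
class OutputFilter(object):
    """Base class: owns statement counting and variable naming."""
    varname_prefix = ''

    def __init__(self, varname='sql'):
        self.varname = self.varname_prefix + varname
        self.count = 0

    def _process(self, text, varname):
        # Rendering is deferred to subclasses.
        raise NotImplementedError

    def process(self, text):
        # Statements after the first get numbered variables: sql, sql2, ...
        self.count += 1
        if self.count > 1:
            varname = '%s%d' % (self.varname, self.count)
        else:
            varname = self.varname
        return self._process(text, varname)


class PythonOutput(OutputFilter):
    def _process(self, text, varname):
        return "%s = '%s'" % (varname, text.replace("'", "\\'"))


class PHPOutput(OutputFilter):
    varname_prefix = '$'  # mirrors OutputPHPFilter in the diff

    def _process(self, text, varname):
        return '%s = "%s";' % (varname, text.replace('"', '\\"'))


py = PythonOutput()
print(py.process('select * from foo'))   # sql = 'select * from foo'
print(py.process('select 1'))            # sql2 = 'select 1'
print(PHPOutput().process('select 1'))   # $sql = "select 1";
```

The `sql.py` hunks likewise flatten `Token.match()` into early returns (and drop a stray debugging `print`) without changing observable behavior; keyword values still compare case-insensitively against the normalized form. A quick check, again assuming a contemporary checkout:

```python
import sqlparse
from sqlparse import tokens as T

stmt = sqlparse.parse('SELECT * FROM foo')[0]
first = stmt.token_first()

# Lower-case input still matches, since keywords compare via .normalized
print(first.match(T.Keyword.DML, 'select'))  # True
print(first.match(T.Keyword.DML, 'insert'))  # False
```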
