diff options
| author | Victor Uriarte <victor.m.uriarte@intel.com> | 2016-06-02 22:54:25 -0700 |
|---|---|---|
| committer | Victor Uriarte <victor.m.uriarte@intel.com> | 2016-06-04 15:06:04 -0700 |
| commit | 2b8ede11388e81e0f6dc871a45c5327eaf456e44 (patch) | |
| tree | a441a56c583c923e059977846a7239e14a632fbc /sqlparse/engine | |
| parent | f0a6af57b7a5c116528db73643b26f934477d350 (diff) | |
| download | sqlparse-2b8ede11388e81e0f6dc871a45c5327eaf456e44.tar.gz | |
Refactor statement filter
Diffstat (limited to 'sqlparse/engine')
| -rw-r--r-- | sqlparse/engine/filter.py | 65 |
1 files changed, 27 insertions, 38 deletions
```diff
diff --git a/sqlparse/engine/filter.py b/sqlparse/engine/filter.py
index c7b3bf8..ea2033a 100644
--- a/sqlparse/engine/filter.py
+++ b/sqlparse/engine/filter.py
@@ -5,8 +5,7 @@
 # This module is part of python-sqlparse and is released under
 # the BSD License: http://www.opensource.org/licenses/bsd-license.php
 
-from sqlparse.sql import Statement, Token
-from sqlparse import tokens as T
+from sqlparse import sql, tokens as T
 
 
 class StatementFilter(object):
@@ -22,11 +21,14 @@ class StatementFilter(object):
         self._is_create = False
         self._begin_depth = 0
 
+        self.consume_ws = False
+        self.tokens = []
+        self.level = 0
+
     def _change_splitlevel(self, ttype, value):
         """Get the new split level (increase, decrease or remain equal)"""
         # PostgreSQL
-        if ttype == T.Name.Builtin \
-                and value.startswith('$') and value.endswith('$'):
+        if ttype == T.Name.Builtin and value[0] == '$' and value[-1] == '$':
 
             # 2nd dbldollar found. $quote$ completed
             # decrease level
@@ -54,6 +56,12 @@ class StatementFilter(object):
         # returning
         unified = value.upper()
 
+        # three keywords begin with CREATE, but only one of them is DDL
+        # DDL Create though can contain more words such as "or replace"
+        if ttype is T.Keyword.DDL and unified.startswith('CREATE'):
+            self._is_create = True
+            return 0
+
         # can have nested declare inside of being...
         if unified == 'DECLARE' and self._is_create and self._begin_depth == 0:
             self._in_declare = True
@@ -61,14 +69,11 @@ class StatementFilter(object):
 
         if unified == 'BEGIN':
             self._begin_depth += 1
-            if self._in_declare or self._is_create:
+            if self._is_create:
                 # FIXME(andi): This makes no sense.
                 return 1
             return 0
 
-        if unified in ('END IF', 'END FOR', 'END WHILE'):
-            return -1
-
         # Should this respect a preceeding BEGIN?
         # In CASE ... WHEN ... END this results in a split level -1.
         # Would having multiple CASE WHEN END and a Assigment Operator
@@ -77,25 +82,19 @@ class StatementFilter(object):
             self._begin_depth = max(0, self._begin_depth - 1)
             return -1
 
-        # three keywords begin with CREATE, but only one of them is DDL
-        # DDL Create though can contain more words such as "or replace"
-        if ttype is T.Keyword.DDL and unified.startswith('CREATE'):
-            self._is_create = True
-            return 0
-
-        if unified in ('IF', 'FOR', 'WHILE') \
-                and self._is_create and self._begin_depth > 0:
+        if (unified in ('IF', 'FOR', 'WHILE') and
+                self._is_create and self._begin_depth > 0):
             return 1
 
+        if unified in ('END IF', 'END FOR', 'END WHILE'):
+            return -1
+
         # Default
         return 0
 
     def process(self, stream):
         """Process the stream"""
-        consume_ws = False
-        splitlevel = 0
-        stmt = None
-        stmt_tokens = []
+        EOS_TTYPE = T.Whitespace, T.Comment.Single
 
         # Run over all stream tokens
         for ttype, value in stream:
@@ -103,32 +102,22 @@ class StatementFilter(object):
             # It will count newline token as a non whitespace. In this context
             # whitespace ignores newlines.
             # why don't multi line comments also count?
-            if consume_ws and ttype not in (T.Whitespace, T.Comment.Single):
-                stmt.tokens = stmt_tokens
-                yield stmt
+            if self.consume_ws and ttype not in EOS_TTYPE:
+                yield sql.Statement(self.tokens)
 
                 # Reset filter and prepare to process next statement
                 self._reset()
-                consume_ws = False
-                splitlevel = 0
-                stmt = None
-
-            # Create a new statement if we are not currently in one of them
-            if stmt is None:
-                stmt = Statement()
-                stmt_tokens = []
 
             # Change current split level (increase, decrease or remain equal)
-            splitlevel += self._change_splitlevel(ttype, value)
+            self.level += self._change_splitlevel(ttype, value)
 
             # Append the token to the current statement
-            stmt_tokens.append(Token(ttype, value))
+            self.tokens.append(sql.Token(ttype, value))
 
             # Check if we get the end of a statement
-            if splitlevel <= 0 and ttype is T.Punctuation and value == ';':
-                consume_ws = True
+            if self.level <= 0 and ttype is T.Punctuation and value == ';':
+                self.consume_ws = True
 
         # Yield pending statement (if any)
-        if stmt is not None:
-            stmt.tokens = stmt_tokens
-            yield stmt
+        if self.tokens:
+            yield sql.Statement(self.tokens)
```
