summary | refs | log | tree | commit | diff
path: root/sqlparse/engine
diff options
context:
space:
mode:
author: Victor Uriarte <victor.m.uriarte@intel.com> 2016-06-02 22:54:25 -0700
committer: Victor Uriarte <victor.m.uriarte@intel.com> 2016-06-04 15:06:04 -0700
commit: 2b8ede11388e81e0f6dc871a45c5327eaf456e44 (patch)
tree: a441a56c583c923e059977846a7239e14a632fbc /sqlparse/engine
parent: f0a6af57b7a5c116528db73643b26f934477d350 (diff)
download: sqlparse-2b8ede11388e81e0f6dc871a45c5327eaf456e44.tar.gz
Refactor statement filter
Diffstat (limited to 'sqlparse/engine')
-rw-r--r-- sqlparse/engine/filter.py 65
1 files changed, 27 insertions, 38 deletions
diff --git a/sqlparse/engine/filter.py b/sqlparse/engine/filter.py
index c7b3bf8..ea2033a 100644
--- a/sqlparse/engine/filter.py
+++ b/sqlparse/engine/filter.py
@@ -5,8 +5,7 @@
# This module is part of python-sqlparse and is released under
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
-from sqlparse.sql import Statement, Token
-from sqlparse import tokens as T
+from sqlparse import sql, tokens as T
class StatementFilter(object):
@@ -22,11 +21,14 @@ class StatementFilter(object):
self._is_create = False
self._begin_depth = 0
+ self.consume_ws = False
+ self.tokens = []
+ self.level = 0
+
def _change_splitlevel(self, ttype, value):
"""Get the new split level (increase, decrease or remain equal)"""
# PostgreSQL
- if ttype == T.Name.Builtin \
- and value.startswith('$') and value.endswith('$'):
+ if ttype == T.Name.Builtin and value[0] == '$' and value[-1] == '$':
# 2nd dbldollar found. $quote$ completed
# decrease level
@@ -54,6 +56,12 @@ class StatementFilter(object):
# returning
unified = value.upper()
+ # three keywords begin with CREATE, but only one of them is DDL
+ # DDL Create though can contain more words such as "or replace"
+ if ttype is T.Keyword.DDL and unified.startswith('CREATE'):
+ self._is_create = True
+ return 0
+
# can have nested declare inside of being...
if unified == 'DECLARE' and self._is_create and self._begin_depth == 0:
self._in_declare = True
@@ -61,14 +69,11 @@ class StatementFilter(object):
if unified == 'BEGIN':
self._begin_depth += 1
- if self._in_declare or self._is_create:
+ if self._is_create:
# FIXME(andi): This makes no sense.
return 1
return 0
- if unified in ('END IF', 'END FOR', 'END WHILE'):
- return -1
-
# Should this respect a preceeding BEGIN?
# In CASE ... WHEN ... END this results in a split level -1.
# Would having multiple CASE WHEN END and a Assigment Operator
@@ -77,25 +82,19 @@ class StatementFilter(object):
self._begin_depth = max(0, self._begin_depth - 1)
return -1
- # three keywords begin with CREATE, but only one of them is DDL
- # DDL Create though can contain more words such as "or replace"
- if ttype is T.Keyword.DDL and unified.startswith('CREATE'):
- self._is_create = True
- return 0
-
- if unified in ('IF', 'FOR', 'WHILE') \
- and self._is_create and self._begin_depth > 0:
+ if (unified in ('IF', 'FOR', 'WHILE') and
+ self._is_create and self._begin_depth > 0):
return 1
+ if unified in ('END IF', 'END FOR', 'END WHILE'):
+ return -1
+
# Default
return 0
def process(self, stream):
"""Process the stream"""
- consume_ws = False
- splitlevel = 0
- stmt = None
- stmt_tokens = []
+ EOS_TTYPE = T.Whitespace, T.Comment.Single
# Run over all stream tokens
for ttype, value in stream:
@@ -103,32 +102,22 @@ class StatementFilter(object):
# It will count newline token as a non whitespace. In this context
# whitespace ignores newlines.
# why don't multi line comments also count?
- if consume_ws and ttype not in (T.Whitespace, T.Comment.Single):
- stmt.tokens = stmt_tokens
- yield stmt
+ if self.consume_ws and ttype not in EOS_TTYPE:
+ yield sql.Statement(self.tokens)
# Reset filter and prepare to process next statement
self._reset()
- consume_ws = False
- splitlevel = 0
- stmt = None
-
- # Create a new statement if we are not currently in one of them
- if stmt is None:
- stmt = Statement()
- stmt_tokens = []
# Change current split level (increase, decrease or remain equal)
- splitlevel += self._change_splitlevel(ttype, value)
+ self.level += self._change_splitlevel(ttype, value)
# Append the token to the current statement
- stmt_tokens.append(Token(ttype, value))
+ self.tokens.append(sql.Token(ttype, value))
# Check if we get the end of a statement
- if splitlevel <= 0 and ttype is T.Punctuation and value == ';':
- consume_ws = True
+ if self.level <= 0 and ttype is T.Punctuation and value == ';':
+ self.consume_ws = True
# Yield pending statement (if any)
- if stmt is not None:
- stmt.tokens = stmt_tokens
- yield stmt
+ if self.tokens:
+ yield sql.Statement(self.tokens)