| author | Vik <vmuriart@users.noreply.github.com> | 2016-06-06 06:29:25 -0700 |
| committer | Vik <vmuriart@users.noreply.github.com> | 2016-06-06 06:29:25 -0700 |
| commit | b9d81ac4fe49114f57dc33c0d635f99ff56e62f2 (patch) | |
| tree | 88642eeb84d318511191a822fd781b44e1d63df1 /sqlparse | |
| parent | c6a5e7ac2a5ecc993f4e5292ab16e6df6b84f26c (diff) | |
| parent | 5747015634a39191511de8db576f2cd0aa5eafc9 (diff) | |
| download | sqlparse-b9d81ac4fe49114f57dc33c0d635f99ff56e62f2.tar.gz | |
Merge pull request #251 from andialbrecht/filters_sql
Update Filters sql
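
The thrust of this change, visible throughout the diff below, is that filters drop the `stack` argument: token-stream filters now implement `process(stream)` and statement filters `process(stmt)`, while the legacy `pipeline.py` and `functions.py` modules are removed. A minimal sketch of a token-stream filter under the new protocol — the `UppercaseKeywordFilter` class is illustrative only, not part of the library:

```python
from sqlparse import tokens as T


class UppercaseKeywordFilter(object):
    """Uppercase keyword tokens in a (ttype, value) stream."""

    def process(self, stream):
        # Filters no longer receive the FilterStack; the stream of
        # (ttype, value) pairs is the only argument, mirroring the
        # reworked KeywordCaseFilter in sqlparse/filters.py.
        for ttype, value in stream:
            if ttype in T.Keyword:
                value = value.upper()
            yield ttype, value
```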
Diffstat (limited to 'sqlparse')
| -rw-r--r-- | sqlparse/__init__.py | 7 |
| -rw-r--r-- | sqlparse/compat.py | 10 |
| -rw-r--r-- | sqlparse/engine/__init__.py | 48 |
| -rw-r--r-- | sqlparse/engine/filter.py | 100 |
| -rw-r--r-- | sqlparse/engine/grouping.py | 5 |
| -rw-r--r-- | sqlparse/filters.py | 355 |
| -rw-r--r-- | sqlparse/functions.py | 44 |
| -rw-r--r-- | sqlparse/pipeline.py | 31 |
| -rw-r--r-- | sqlparse/sql.py | 292 |
| -rw-r--r-- | sqlparse/utils.py | 71 |
10 files changed, 248 insertions, 715 deletions
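
With the reworked engine (see `sqlparse/engine/__init__.py` below), `FilterStack.run()` is a generator that tokenizes, splits, optionally groups, and then yields each processed statement, with `stmtprocess` and `postprocess` filters receiving the statement directly. A hypothetical usage sketch — the SQL string and the particular filters appended here are arbitrary examples:

```python
from sqlparse import engine, filters

stack = engine.FilterStack()
stack.enable_grouping()  # statements are grouped before stmtprocess filters run
stack.stmtprocess.append(filters.StripWhitespaceFilter())
stack.postprocess.append(filters.SerializerUnicode())  # renders each statement to text

# run() yields one processed result per ';'-terminated statement
for stmt in stack.run("select a,  b from foo ; select 1 ;"):
    print(stmt)
```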
diff --git a/sqlparse/__init__.py b/sqlparse/__init__.py index 2943997..cb83a71 100644 --- a/sqlparse/__init__.py +++ b/sqlparse/__init__.py @@ -66,11 +66,4 @@ def split(sql, encoding=None): :returns: A list of strings. """ stack = engine.FilterStack() - stack.split_statements = True return [u(stmt).strip() for stmt in stack.run(sql, encoding)] - - -def split2(stream): - from sqlparse.engine.filter import StatementFilter - splitter = StatementFilter() - return list(splitter.process(None, stream)) diff --git a/sqlparse/compat.py b/sqlparse/compat.py index 0226a00..0defd86 100644 --- a/sqlparse/compat.py +++ b/sqlparse/compat.py @@ -25,6 +25,10 @@ if PY3: return str(s) + def unicode_compatible(cls): + return cls + + text_type = str string_types = (str,) from io import StringIO @@ -39,6 +43,12 @@ elif PY2: return unicode(s, encoding) + def unicode_compatible(cls): + cls.__unicode__ = cls.__str__ + cls.__str__ = lambda x: x.__unicode__().encode('utf-8') + return cls + + text_type = unicode string_types = (basestring,) from StringIO import StringIO diff --git a/sqlparse/engine/__init__.py b/sqlparse/engine/__init__.py index 1c2bf09..7f00c57 100644 --- a/sqlparse/engine/__init__.py +++ b/sqlparse/engine/__init__.py @@ -13,12 +13,10 @@ from sqlparse.engine.filter import StatementFilter class FilterStack(object): - def __init__(self): self.preprocess = [] self.stmtprocess = [] self.postprocess = [] - self.split_statements = False self._grouping = False def enable_grouping(self): @@ -27,42 +25,20 @@ class FilterStack(object): def run(self, sql, encoding=None): stream = lexer.tokenize(sql, encoding) # Process token stream - if self.preprocess: - for filter_ in self.preprocess: - stream = filter_.process(self, stream) - - if (self.stmtprocess or self.postprocess or - self.split_statements or self._grouping): - splitter = StatementFilter() - stream = splitter.process(self, stream) - - if self._grouping: - - def _group(stream): - for stmt in stream: - grouping.group(stmt) - yield stmt - stream = _group(stream) + for filter_ in self.preprocess: + stream = filter_.process(stream) - if self.stmtprocess: + stream = StatementFilter().process(stream) - def _run1(stream): - ret = [] - for stmt in stream: - for filter_ in self.stmtprocess: - filter_.process(self, stmt) - ret.append(stmt) - return ret - stream = _run1(stream) + # Output: Stream processed Statements + for stmt in stream: + if self._grouping: + stmt = grouping.group(stmt) - if self.postprocess: + for filter_ in self.stmtprocess: + filter_.process(stmt) - def _run2(stream): - for stmt in stream: - stmt.tokens = list(stmt.flatten()) - for filter_ in self.postprocess: - stmt = filter_.process(self, stmt) - yield stmt - stream = _run2(stream) + for filter_ in self.postprocess: + stmt = filter_.process(stmt) - return stream + yield stmt diff --git a/sqlparse/engine/filter.py b/sqlparse/engine/filter.py index 71020e7..ea2033a 100644 --- a/sqlparse/engine/filter.py +++ b/sqlparse/engine/filter.py @@ -5,113 +5,119 @@ # This module is part of python-sqlparse and is released under # the BSD License: http://www.opensource.org/licenses/bsd-license.php -from sqlparse.sql import Statement, Token -from sqlparse import tokens as T +from sqlparse import sql, tokens as T class StatementFilter(object): - "Filter that split stream at individual statements" + """Filter that split stream at individual statements""" def __init__(self): - self._in_declare = False - self._in_dbldollar = False - self._is_create = False - self._begin_depth = 0 + self._reset() def 
_reset(self): - "Set the filter attributes to its default values" + """Set the filter attributes to its default values""" self._in_declare = False self._in_dbldollar = False self._is_create = False self._begin_depth = 0 + self.consume_ws = False + self.tokens = [] + self.level = 0 + def _change_splitlevel(self, ttype, value): - "Get the new split level (increase, decrease or remain equal)" + """Get the new split level (increase, decrease or remain equal)""" # PostgreSQL - if ttype == T.Name.Builtin \ - and value.startswith('$') and value.endswith('$'): + if ttype == T.Name.Builtin and value[0] == '$' and value[-1] == '$': + + # 2nd dbldollar found. $quote$ completed + # decrease level if self._in_dbldollar: self._in_dbldollar = False return -1 else: self._in_dbldollar = True return 1 + + # if inside $$ everything inside is defining function character. + # Nothing inside can create a new statement elif self._in_dbldollar: return 0 # ANSI + # if normal token return + # wouldn't parenthesis increase/decrease a level? + # no, inside a paranthesis can't start new statement if ttype not in T.Keyword: return 0 + # Everything after here is ttype = T.Keyword + # Also to note, once entered an If statement you are done and basically + # returning unified = value.upper() + # three keywords begin with CREATE, but only one of them is DDL + # DDL Create though can contain more words such as "or replace" + if ttype is T.Keyword.DDL and unified.startswith('CREATE'): + self._is_create = True + return 0 + + # can have nested declare inside of being... if unified == 'DECLARE' and self._is_create and self._begin_depth == 0: self._in_declare = True return 1 if unified == 'BEGIN': self._begin_depth += 1 - if self._in_declare or self._is_create: + if self._is_create: # FIXME(andi): This makes no sense. return 1 return 0 - if unified in ('END IF', 'END FOR', 'END WHILE'): - return -1 - + # Should this respect a preceeding BEGIN? + # In CASE ... WHEN ... END this results in a split level -1. + # Would having multiple CASE WHEN END and a Assigment Operator + # cause the statement to cut off prematurely? if unified == 'END': - # Should this respect a preceeding BEGIN? - # In CASE ... WHEN ... END this results in a split level -1. self._begin_depth = max(0, self._begin_depth - 1) return -1 - if ttype is T.Keyword.DDL and unified.startswith('CREATE'): - self._is_create = True - return 0 - - if unified in ('IF', 'FOR', 'WHILE') \ - and self._is_create and self._begin_depth > 0: + if (unified in ('IF', 'FOR', 'WHILE') and + self._is_create and self._begin_depth > 0): return 1 + if unified in ('END IF', 'END FOR', 'END WHILE'): + return -1 + # Default return 0 - def process(self, stack, stream): - "Process the stream" - consume_ws = False - splitlevel = 0 - stmt = None - stmt_tokens = [] + def process(self, stream): + """Process the stream""" + EOS_TTYPE = T.Whitespace, T.Comment.Single # Run over all stream tokens for ttype, value in stream: # Yield token if we finished a statement and there's no whitespaces - if consume_ws and ttype not in (T.Whitespace, T.Comment.Single): - stmt.tokens = stmt_tokens - yield stmt + # It will count newline token as a non whitespace. In this context + # whitespace ignores newlines. + # why don't multi line comments also count? 
+ if self.consume_ws and ttype not in EOS_TTYPE: + yield sql.Statement(self.tokens) # Reset filter and prepare to process next statement self._reset() - consume_ws = False - splitlevel = 0 - stmt = None - - # Create a new statement if we are not currently in one of them - if stmt is None: - stmt = Statement() - stmt_tokens = [] # Change current split level (increase, decrease or remain equal) - splitlevel += self._change_splitlevel(ttype, value) + self.level += self._change_splitlevel(ttype, value) # Append the token to the current statement - stmt_tokens.append(Token(ttype, value)) + self.tokens.append(sql.Token(ttype, value)) # Check if we get the end of a statement - if splitlevel <= 0 and ttype is T.Punctuation and value == ';': - consume_ws = True + if self.level <= 0 and ttype is T.Punctuation and value == ';': + self.consume_ws = True # Yield pending statement (if any) - if stmt is not None: - stmt.tokens = stmt_tokens - yield stmt + if self.tokens: + yield sql.Statement(self.tokens) diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py index 0ac1cb3..c680995 100644 --- a/sqlparse/engine/grouping.py +++ b/sqlparse/engine/grouping.py @@ -266,7 +266,7 @@ def align_comments(tlist): token = tlist.token_next_by(i=sql.Comment, idx=token) -def group(tlist): +def group(stmt): for func in [ group_comments, group_brackets, @@ -291,4 +291,5 @@ def group(tlist): group_foreach, group_begin, ]: - func(tlist) + func(stmt) + return stmt diff --git a/sqlparse/filters.py b/sqlparse/filters.py index 1cb2f16..95ac74c 100644 --- a/sqlparse/filters.py +++ b/sqlparse/filters.py @@ -7,15 +7,8 @@ import re -from os.path import abspath, join - from sqlparse import sql, tokens as T -from sqlparse.compat import u, text_type -from sqlparse.engine import FilterStack -from sqlparse.pipeline import Pipeline -from sqlparse.tokens import (Comment, Comparison, Keyword, Name, Punctuation, - String, Whitespace) -from sqlparse.utils import memoize_generator +from sqlparse.compat import text_type from sqlparse.utils import split_unquoted_newlines @@ -23,16 +16,13 @@ from sqlparse.utils import split_unquoted_newlines # token process class _CaseFilter(object): - ttype = None def __init__(self, case=None): - if case is None: - case = 'upper' - assert case in ['lower', 'upper', 'capitalize'] + case = case or 'upper' self.convert = getattr(text_type, case) - def process(self, stack, stream): + def process(self, stream): for ttype, value in stream: if ttype in self.ttype: value = self.convert(value) @@ -44,165 +34,42 @@ class KeywordCaseFilter(_CaseFilter): class IdentifierCaseFilter(_CaseFilter): - ttype = (T.Name, T.String.Symbol) + ttype = T.Name, T.String.Symbol - def process(self, stack, stream): + def process(self, stream): for ttype, value in stream: - if ttype in self.ttype and not value.strip()[0] == '"': + if ttype in self.ttype and value.strip()[0] != '"': value = self.convert(value) yield ttype, value class TruncateStringFilter(object): - def __init__(self, width, char): - self.width = max(width, 1) - self.char = u(char) + self.width = width + self.char = char - def process(self, stack, stream): + def process(self, stream): for ttype, value in stream: - if ttype is T.Literal.String.Single: - if value[:2] == '\'\'': - inner = value[2:-2] - quote = u'\'\'' - else: - inner = value[1:-1] - quote = u'\'' - if len(inner) > self.width: - value = u''.join((quote, inner[:self.width], self.char, - quote)) - yield ttype, value - - -class GetComments(object): - """Get the comments from a stack""" - def 
process(self, stack, stream): - for token_type, value in stream: - if token_type in Comment: - yield token_type, value - - -class StripComments(object): - """Strip the comments from a stack""" - def process(self, stack, stream): - for token_type, value in stream: - if token_type not in Comment: - yield token_type, value - - -def StripWhitespace(stream): - "Strip the useless whitespaces from a stream leaving only the minimal ones" - last_type = None - has_space = False - ignore_group = frozenset((Comparison, Punctuation)) - - for token_type, value in stream: - # We got a previous token (not empty first ones) - if last_type: - if token_type in Whitespace: - has_space = True - continue - - # Ignore first empty spaces and dot-commas - elif token_type in (Whitespace, Whitespace.Newline, ignore_group): - continue - - # Yield a whitespace if it can't be ignored - if has_space: - if not ignore_group.intersection((last_type, token_type)): - yield Whitespace, ' ' - has_space = False - - # Yield the token and set its type for checking with the next one - yield token_type, value - last_type = token_type - - -class IncludeStatement(object): - """Filter that enable a INCLUDE statement""" - - def __init__(self, dirpath=".", maxrecursive=10, raiseexceptions=False): - if maxrecursive <= 0: - raise ValueError('Max recursion limit reached') - - self.dirpath = abspath(dirpath) - self.maxRecursive = maxrecursive - self.raiseexceptions = raiseexceptions - - self.detected = False - - @memoize_generator - def process(self, stack, stream): - # Run over all tokens in the stream - for token_type, value in stream: - # INCLUDE statement found, set detected mode - if token_type in Name and value.upper() == 'INCLUDE': - self.detected = True + if ttype != T.Literal.String.Single: + yield ttype, value continue - # INCLUDE statement was found, parse it - elif self.detected: - # Omit whitespaces - if token_type in Whitespace: - continue - - # Found file path to include - if token_type in String.Symbol: - # Get path of file to include - path = join(self.dirpath, value[1:-1]) - - try: - f = open(path) - raw_sql = f.read() - f.close() - - # There was a problem loading the include file - except IOError as err: - # Raise the exception to the interpreter - if self.raiseexceptions: - raise - - # Put the exception as a comment on the SQL code - yield Comment, u'-- IOError: %s\n' % err - - else: - # Create new FilterStack to parse readed file - # and add all its tokens to the main stack recursively - try: - filtr = IncludeStatement(self.dirpath, - self.maxRecursive - 1, - self.raiseexceptions) - - # Max recursion limit reached - except ValueError as err: - # Raise the exception to the interpreter - if self.raiseexceptions: - raise - - # Put the exception as a comment on the SQL code - yield Comment, u'-- ValueError: %s\n' % err - - stack = FilterStack() - stack.preprocess.append(filtr) - - for tv in stack.run(raw_sql): - yield tv - - # Set normal mode - self.detected = False - - # Don't include any token while in detected mode - continue + if value[:2] == "''": + inner = value[2:-2] + quote = "''" + else: + inner = value[1:-1] + quote = "'" - # Normal token - yield token_type, value + if len(inner) > self.width: + value = ''.join((quote, inner[:self.width], self.char, quote)) + yield ttype, value # ---------------------- # statement process class StripCommentsFilter(object): - def _get_next_comment(self, tlist): # TODO(andi) Comment types should be unified, see related issue38 token = tlist.token_next_by(i=sql.Comment, t=T.Comment) 
@@ -212,8 +79,8 @@ class StripCommentsFilter(object): token = self._get_next_comment(tlist) while token: tidx = tlist.token_index(token) - prev = tlist.token_prev(tidx, False) - next_ = tlist.token_next(tidx, False) + prev = tlist.token_prev(tidx, skip_ws=False) + next_ = tlist.token_next(tidx, skip_ws=False) # Replace by whitespace if prev and next exist and if they're not # whitespaces. This doesn't apply if prev or next is a paranthesis. if (prev is not None and next_ is not None @@ -225,13 +92,12 @@ class StripCommentsFilter(object): tlist.tokens.pop(tidx) token = self._get_next_comment(tlist) - def process(self, stack, stmt): - [self.process(stack, sgroup) for sgroup in stmt.get_sublists()] + def process(self, stmt): + [self.process(sgroup) for sgroup in stmt.get_sublists()] self._process(stmt) class StripWhitespaceFilter(object): - def _stripws(self, tlist): func_name = '_stripws_%s' % tlist.__class__.__name__.lower() func = getattr(self, func_name, self._stripws_default) @@ -253,14 +119,10 @@ class StripWhitespaceFilter(object): # Removes newlines before commas, see issue140 last_nl = None for token in tlist.tokens[:]: - if token.ttype is T.Punctuation \ - and token.value == ',' \ - and last_nl is not None: + if last_nl and token.ttype is T.Punctuation and token.value == ',': tlist.tokens.remove(last_nl) - if token.is_whitespace(): - last_nl = token - else: - last_nl = None + + last_nl = token if token.is_whitespace() else None return self._stripws_default(tlist) def _stripws_parenthesis(self, tlist): @@ -270,20 +132,14 @@ class StripWhitespaceFilter(object): tlist.tokens.pop(-2) self._stripws_default(tlist) - def process(self, stack, stmt, depth=0): - [self.process(stack, sgroup, depth + 1) - for sgroup in stmt.get_sublists()] + def process(self, stmt, depth=0): + [self.process(sgroup, depth + 1) for sgroup in stmt.get_sublists()] self._stripws(stmt) - if ( - depth == 0 - and stmt.tokens - and stmt.tokens[-1].is_whitespace() - ): + if depth == 0 and stmt.tokens and stmt.tokens[-1].is_whitespace(): stmt.tokens.pop(-1) class ReindentFilter(object): - def __init__(self, width=2, char=' ', line_width=None, wrap_after=0): self.width = width self.char = char @@ -327,8 +183,7 @@ class ReindentFilter(object): 'SET', 'BETWEEN', 'EXCEPT', 'HAVING') def _next_token(i): - t = tlist.token_next_match(i, T.Keyword, split_words, - regex=True) + t = tlist.token_next_by(m=(T.Keyword, split_words, True), idx=i) if t and t.value.upper() == 'BETWEEN': t = _next_token(tlist.token_index(t) + 1) if t and t.value.upper() == 'AND': @@ -339,13 +194,13 @@ class ReindentFilter(object): token = _next_token(idx) added = set() while token: - prev = tlist.token_prev(tlist.token_index(token), False) + prev = tlist.token_prev(token, skip_ws=False) offset = 1 if prev and prev.is_whitespace() and prev not in added: tlist.tokens.pop(tlist.token_index(prev)) offset += 1 - uprev = u(prev) - if (prev and (uprev.endswith('\n') or uprev.endswith('\r'))): + uprev = text_type(prev) + if prev and (uprev.endswith('\n') or uprev.endswith('\r')): nl = tlist.token_next(token) else: nl = self.nl() @@ -355,18 +210,17 @@ class ReindentFilter(object): token = _next_token(tlist.token_index(nl) + offset) def _split_statements(self, tlist): - idx = 0 - token = tlist.token_next_by_type(idx, (T.Keyword.DDL, T.Keyword.DML)) + token = tlist.token_next_by(t=(T.Keyword.DDL, T.Keyword.DML)) while token: - prev = tlist.token_prev(tlist.token_index(token), False) + prev = tlist.token_prev(token, skip_ws=False) if prev and prev.is_whitespace(): 
tlist.tokens.pop(tlist.token_index(prev)) # only break if it's not the first token if prev: nl = self.nl() tlist.insert_before(token, nl) - token = tlist.token_next_by_type(tlist.token_index(token) + 1, - (T.Keyword.DDL, T.Keyword.DML)) + token = tlist.token_next_by(t=(T.Keyword.DDL, T.Keyword.DML), + idx=token) def _process(self, tlist): func_name = '_process_%s' % tlist.__class__.__name__.lower() @@ -374,7 +228,7 @@ class ReindentFilter(object): func(tlist) def _process_where(self, tlist): - token = tlist.token_next_match(0, T.Keyword, 'WHERE') + token = tlist.token_next_by(m=(T.Keyword, 'WHERE')) try: tlist.insert_before(token, self.nl()) except ValueError: # issue121, errors in statement @@ -384,7 +238,7 @@ class ReindentFilter(object): self.indent -= 1 def _process_having(self, tlist): - token = tlist.token_next_match(0, T.Keyword, 'HAVING') + token = tlist.token_next_by(m=(T.Keyword, 'HAVING')) try: tlist.insert_before(token, self.nl()) except ValueError: # issue121, errors in statement @@ -401,7 +255,7 @@ class ReindentFilter(object): tlist.tokens.insert(0, self.nl()) indented = True num_offset = self._get_offset( - tlist.token_next_match(0, T.Punctuation, '(')) + tlist.token_next_by(m=(T.Punctuation, '('))) self.offset += num_offset self._process_default(tlist, stmts=not indented) if indented: @@ -454,7 +308,7 @@ class ReindentFilter(object): self.offset -= 5 if num_offset is not None: self.offset -= num_offset - end = tlist.token_next_match(0, T.Keyword, 'END') + end = tlist.token_next_by(m=(T.Keyword, 'END')) tlist.insert_before(end, self.nl()) self.offset -= outer_offset @@ -465,13 +319,13 @@ class ReindentFilter(object): self._split_kwds(tlist) [self._process(sgroup) for sgroup in tlist.get_sublists()] - def process(self, stack, stmt): + def process(self, stmt): if isinstance(stmt, sql.Statement): self._curr_stmt = stmt self._process(stmt) if isinstance(stmt, sql.Statement): if self._last_stmt is not None: - if u(self._last_stmt).endswith('\n'): + if text_type(self._last_stmt).endswith('\n'): nl = '\n' else: nl = '\n\n' @@ -481,9 +335,8 @@ class ReindentFilter(object): self._last_stmt = stmt -# FIXME: Doesn't work ;) +# FIXME: Doesn't work class RightMarginFilter(object): - keep_together = ( # sql.TypeCast, sql.Identifier, sql.Alias, ) @@ -492,20 +345,19 @@ class RightMarginFilter(object): self.width = width self.line = '' - def _process(self, stack, group, stream): + def _process(self, group, stream): for token in stream: if token.is_whitespace() and '\n' in token.value: if token.value.endswith('\n'): self.line = '' else: self.line = token.value.splitlines()[-1] - elif (token.is_group() - and token.__class__ not in self.keep_together): - token.tokens = self._process(stack, token, token.tokens) + elif token.is_group() and type(token) not in self.keep_together: + token.tokens = self._process(token, token.tokens) else: - val = u(token) + val = text_type(token) if len(self.line) + len(val) > self.width: - match = re.search('^ +', self.line) + match = re.search(r'^ +', self.line) if match is not None: indent = match.group() else: @@ -515,83 +367,23 @@ class RightMarginFilter(object): self.line += val yield token - def process(self, stack, group): - return - group.tokens = self._process(stack, group, group.tokens) - - -class ColumnsSelect(object): - """Get the columns names of a SELECT query""" - def process(self, stack, stream): - mode = 0 - oldValue = "" - parenthesis = 0 - - for token_type, value in stream: - # Ignore comments - if token_type in Comment: - continue - - # We have 
not detected a SELECT statement - if mode == 0: - if token_type in Keyword and value == 'SELECT': - mode = 1 - - # We have detected a SELECT statement - elif mode == 1: - if value == 'FROM': - if oldValue: - yield oldValue - - mode = 3 # Columns have been checked - - elif value == 'AS': - oldValue = "" - mode = 2 - - elif (token_type == Punctuation - and value == ',' and not parenthesis): - if oldValue: - yield oldValue - oldValue = "" - - elif token_type not in Whitespace: - if value == '(': - parenthesis += 1 - elif value == ')': - parenthesis -= 1 - - oldValue += value - - # We are processing an AS keyword - elif mode == 2: - # We check also for Keywords because a bug in SQLParse - if token_type == Name or token_type == Keyword: - yield value - mode = 1 + def process(self, group): + # return + # group.tokens = self._process(group, group.tokens) + raise NotImplementedError # --------------------------- # postprocess class SerializerUnicode(object): - - def process(self, stack, stmt): - raw = u(stmt) + def process(self, stmt): + raw = text_type(stmt) lines = split_unquoted_newlines(raw) res = '\n'.join(line.rstrip() for line in lines) return res -def Tokens2Unicode(stream): - result = "" - - for _, value in stream: - result += u(value) - - return result - - class OutputFilter(object): varname_prefix = '' @@ -602,14 +394,14 @@ class OutputFilter(object): def _process(self, stream, varname, has_nl): raise NotImplementedError - def process(self, stack, stmt): + def process(self, stmt): self.count += 1 if self.count > 1: varname = '%s%d' % (self.varname, self.count) else: varname = self.varname - has_nl = len(u(stmt).strip().splitlines()) > 1 + has_nl = len(text_type(stmt).strip().splitlines()) > 1 stmt.tokens = self._process(stmt.tokens, varname, has_nl) return stmt @@ -704,34 +496,3 @@ class OutputPHPFilter(OutputFilter): # Close quote yield sql.Token(T.Text, '"') yield sql.Token(T.Punctuation, ';') - - -class Limit(object): - """Get the LIMIT of a query. 
- - If not defined, return -1 (SQL specification for no LIMIT query) - """ - def process(self, stack, stream): - index = 7 - stream = list(stream) - stream.reverse() - - # Run over all tokens in the stream from the end - for token_type, value in stream: - index -= 1 - -# if index and token_type in Keyword: - if index and token_type in Keyword and value == 'LIMIT': - return stream[4 - index][1] - - return -1 - - -def compact(stream): - """Function that return a compacted version of the stream""" - pipe = Pipeline() - - pipe.append(StripComments()) - pipe.append(StripWhitespace) - - return pipe(stream) diff --git a/sqlparse/functions.py b/sqlparse/functions.py deleted file mode 100644 index e54457e..0000000 --- a/sqlparse/functions.py +++ /dev/null @@ -1,44 +0,0 @@ -''' -Created on 17/05/2012 - -@author: piranna - -Several utility functions to extract info from the SQL sentences -''' - -from sqlparse.filters import ColumnsSelect, Limit -from sqlparse.pipeline import Pipeline -from sqlparse.tokens import Keyword, Whitespace - - -def getlimit(stream): - """Function that return the LIMIT of a input SQL """ - pipe = Pipeline() - - pipe.append(Limit()) - - result = pipe(stream) - try: - return int(result) - except ValueError: - return result - - -def getcolumns(stream): - """Function that return the colums of a SELECT query""" - pipe = Pipeline() - - pipe.append(ColumnsSelect()) - - return pipe(stream) - - -class IsType(object): - """Functor that return is the statement is of a specific type""" - def __init__(self, type): - self.type = type - - def __call__(self, stream): - for token_type, value in stream: - if token_type not in Whitespace: - return token_type in Keyword and value == self.type diff --git a/sqlparse/pipeline.py b/sqlparse/pipeline.py deleted file mode 100644 index 34dad19..0000000 --- a/sqlparse/pipeline.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (C) 2011 Jesus Leganes "piranna", piranna@gmail.com -# -# This module is part of python-sqlparse and is released under -# the BSD License: http://www.opensource.org/licenses/bsd-license.php. - -from types import GeneratorType - - -class Pipeline(list): - """Pipeline to process filters sequentially""" - - def __call__(self, stream): - """Run the pipeline - - Return a static (non generator) version of the result - """ - - # Run the stream over all the filters on the pipeline - for filter in self: - # Functions and callable objects (objects with '__call__' method) - if callable(filter): - stream = filter(stream) - - # Normal filters (objects with 'process' method) - else: - stream = filter.process(None, stream) - - # If last filter return a generator, staticalize it inside a list - if isinstance(stream, GeneratorType): - return list(stream) - return stream diff --git a/sqlparse/sql.py b/sqlparse/sql.py index 6abc432..57bf1e7 100644 --- a/sqlparse/sql.py +++ b/sqlparse/sql.py @@ -6,15 +6,16 @@ # the BSD License: http://www.opensource.org/licenses/bsd-license.php """This module contains classes representing syntactical elements of SQL.""" +from __future__ import print_function import re -import sys from sqlparse import tokens as T -from sqlparse.compat import string_types, u +from sqlparse.compat import string_types, text_type, unicode_compatible from sqlparse.utils import imt, remove_quotes +@unicode_compatible class Token(object): """Base class for all other classes in this module. 
@@ -26,40 +27,29 @@ class Token(object): __slots__ = ('value', 'ttype', 'parent', 'normalized', 'is_keyword') def __init__(self, ttype, value): + value = text_type(value) self.value = value - if ttype in T.Keyword: - self.normalized = value.upper() - else: - self.normalized = value self.ttype = ttype - self.is_keyword = ttype in T.Keyword self.parent = None + self.is_keyword = ttype in T.Keyword + self.normalized = value.upper() if self.is_keyword else value def __str__(self): - if sys.version_info[0] == 3: - return self.value - else: - return u(self).encode('utf-8') + return self.value def __repr__(self): - short = self._get_repr_value() - if sys.version_info[0] < 3: - short = short.encode('utf-8') - return '<%s \'%s\' at 0x%07x>' % (self._get_repr_name(), - short, id(self)) - - def __unicode__(self): - """Returns a unicode representation of this object.""" - return self.value or '' + cls = self._get_repr_name() + value = self._get_repr_value() + return "<{cls} '{value}' at 0x{id:2X}>".format(id=id(self), **locals()) def _get_repr_name(self): return str(self.ttype).split('.')[-1] def _get_repr_value(self): - raw = u(self) + raw = text_type(self) if len(raw) > 7: - raw = raw[:6] + u'...' - return re.sub('\s+', ' ', raw) + raw = raw[:6] + '...' + return re.sub(r'\s+', ' ', raw) def flatten(self): """Resolve subgroups.""" @@ -81,32 +71,23 @@ class Token(object): if not type_matched or values is None: return type_matched - if regex: - if isinstance(values, string_types): - values = {values} + if isinstance(values, string_types): + values = (values,) - if self.ttype is T.Keyword: - values = set(re.compile(v, re.IGNORECASE) for v in values) - else: - values = set(re.compile(v) for v in values) + if regex: + # TODO: Add test for regex with is_keyboard = false + flag = re.IGNORECASE if self.is_keyword else 0 + values = (re.compile(v, flag) for v in values) for pattern in values: - if pattern.search(self.value): + if pattern.search(self.normalized): return True return False - if isinstance(values, string_types): - if self.is_keyword: - return values.upper() == self.normalized - return values == self.value - if self.is_keyword: - for v in values: - if v.upper() == self.normalized: - return True - return False + values = (v.upper() for v in values) - return self.value in values + return self.normalized in values def is_group(self): """Returns ``True`` if this object has children.""" @@ -114,7 +95,7 @@ class Token(object): def is_whitespace(self): """Return ``True`` if this token is a whitespace token.""" - return self.ttype and self.ttype in T.Whitespace + return self.ttype in T.Whitespace def within(self, group_cls): """Returns ``True`` if this token is within *group_cls*. @@ -143,6 +124,7 @@ class Token(object): return False +@unicode_compatible class TokenList(Token): """A group of tokens. @@ -150,45 +132,35 @@ class TokenList(Token): list of child-tokens. 
""" - __slots__ = ('value', 'ttype', 'tokens') + __slots__ = 'tokens' def __init__(self, tokens=None): - if tokens is None: - tokens = [] - self.tokens = tokens - super(TokenList, self).__init__(None, self.__str__()) - - def __unicode__(self): - return self._to_string() + self.tokens = tokens or [] + super(TokenList, self).__init__(None, text_type(self)) def __str__(self): - str_ = self._to_string() - if sys.version_info[0] <= 2: - str_ = str_.encode('utf-8') - return str_ - - def _to_string(self): - if sys.version_info[0] == 3: - return ''.join(x.value for x in self.flatten()) - else: - return ''.join(u(x) for x in self.flatten()) + return ''.join(token.value for token in self.flatten()) + + def __iter__(self): + return iter(self.tokens) + + def __getitem__(self, item): + return self.tokens[item] def _get_repr_name(self): - return self.__class__.__name__ + return type(self).__name__ - def _pprint_tree(self, max_depth=None, depth=0): + def _pprint_tree(self, max_depth=None, depth=0, f=None): """Pretty-print the object tree.""" - indent = ' ' * (depth * 2) + ind = ' ' * (depth * 2) for idx, token in enumerate(self.tokens): - if token.is_group(): - pre = ' +-' - else: - pre = ' | ' - print('%s%s%d %s \'%s\'' % (indent, pre, idx, - token._get_repr_name(), - token._get_repr_value())) - if (token.is_group() and (max_depth is None or depth < max_depth)): - token._pprint_tree(max_depth, depth + 1) + pre = ' +-' if token.is_group() else ' | ' + cls = token._get_repr_name() + value = token._get_repr_value() + print("{ind}{pre}{idx} {cls} '{value}'".format(**locals()), file=f) + + if token.is_group() and (max_depth is None or depth < max_depth): + token._pprint_tree(max_depth, depth + 1, f) def get_token_at_offset(self, offset): """Returns the token that is on position offset.""" @@ -205,26 +177,19 @@ class TokenList(Token): This method is recursively called for all child tokens. """ for token in self.tokens: - if isinstance(token, TokenList): + if token.is_group(): for item in token.flatten(): yield item else: yield token - # def __iter__(self): - # return self - # - # def next(self): - # for token in self.tokens: - # yield token - def is_group(self): return True def get_sublists(self): - for x in self.tokens: - if isinstance(x, TokenList): - yield x + for token in self.tokens: + if token.is_group(): + yield token @property def _groupable_tokens(self): @@ -242,7 +207,7 @@ class TokenList(Token): funcs = (funcs,) if reverse: - iterable = iter(reversed(self.tokens[end:start - 1])) + iterable = reversed(self.tokens[end:start - 1]) else: iterable = self.tokens[start:end] @@ -251,7 +216,7 @@ class TokenList(Token): if func(token): return token - def token_first(self, ignore_whitespace=True, ignore_comments=False): + def token_first(self, skip_ws=True, skip_cm=False): """Returns the first child token. If *ignore_whitespace* is ``True`` (the default), whitespace @@ -260,35 +225,15 @@ class TokenList(Token): if *ignore_comments* is ``True`` (default: ``False``), comments are ignored too. """ - funcs = lambda tk: not ((ignore_whitespace and tk.is_whitespace()) or - (ignore_comments and imt(tk, i=Comment))) + # this on is inconsistent, using Comment instead of T.Comment... 
+ funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or + (skip_cm and imt(tk, i=Comment))) return self._token_matching(funcs) def token_next_by(self, i=None, m=None, t=None, idx=0, end=None): funcs = lambda tk: imt(tk, i, m, t) return self._token_matching(funcs, idx, end) - def token_next_by_instance(self, idx, clss, end=None): - """Returns the next token matching a class. - - *idx* is where to start searching in the list of child tokens. - *clss* is a list of classes the token should be an instance of. - - If no matching token can be found ``None`` is returned. - """ - funcs = lambda tk: imt(tk, i=clss) - return self._token_matching(funcs, idx, end) - - def token_next_by_type(self, idx, ttypes): - """Returns next matching token by it's token type.""" - funcs = lambda tk: imt(tk, t=ttypes) - return self._token_matching(funcs, idx) - - def token_next_match(self, idx, ttype, value, regex=False): - """Returns next token where it's ``match`` method returns ``True``.""" - funcs = lambda tk: imt(tk, m=(ttype, value, regex)) - return self._token_matching(funcs, idx) - def token_not_matching(self, idx, funcs): funcs = (funcs,) if not isinstance(funcs, (list, tuple)) else funcs funcs = [lambda tk: not func(tk) for func in funcs] @@ -297,7 +242,7 @@ class TokenList(Token): def token_matching(self, idx, funcs): return self._token_matching(funcs, idx) - def token_prev(self, idx, skip_ws=True): + def token_prev(self, idx, skip_ws=True, skip_cm=False): """Returns the previous token relative to *idx*. If *skip_ws* is ``True`` (the default) whitespace tokens are ignored. @@ -305,10 +250,11 @@ class TokenList(Token): """ if isinstance(idx, int): idx += 1 # alot of code usage current pre-compensates for this - funcs = lambda tk: not (tk.is_whitespace() and skip_ws) + funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or + (skip_cm and imt(tk, t=T.Comment))) return self._token_matching(funcs, idx, reverse=True) - def token_next(self, idx, skip_ws=True): + def token_next(self, idx, skip_ws=True, skip_cm=False): """Returns the next token relative to *idx*. If *skip_ws* is ``True`` (the default) whitespace tokens are ignored. @@ -316,7 +262,8 @@ class TokenList(Token): """ if isinstance(idx, int): idx += 1 # alot of code usage current pre-compensates for this - funcs = lambda tk: not (tk.is_whitespace() and skip_ws) + funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or + (skip_cm and imt(tk, t=T.Comment))) return self._token_matching(funcs, idx) def token_index(self, token, start=0): @@ -334,9 +281,9 @@ class TokenList(Token): end_idx = include_end + self.token_index(end) return self.tokens[start_idx:end_idx] - def group_tokens(self, grp_cls, tokens, ignore_ws=False, extend=False): + def group_tokens(self, grp_cls, tokens, skip_ws=False, extend=False): """Replace tokens by an instance of *grp_cls*.""" - if ignore_ws: + if skip_ws: while tokens and tokens[-1].is_whitespace(): tokens = tokens[:-1] @@ -352,7 +299,7 @@ class TokenList(Token): left.parent = self tokens = tokens[1:] left.tokens.extend(tokens) - left.value = left.__str__() + left.value = str(left) else: left = grp_cls(tokens) @@ -393,8 +340,6 @@ class TokenList(Token): if len(self.tokens) > 2 and self.token_next_by(t=T.Whitespace): return self._get_first_name(reverse=True) - return None - def get_name(self): """Returns the name of this identifier. @@ -402,32 +347,22 @@ class TokenList(Token): be considered as the name under which the object corresponding to this identifier is known within the current statement. 
""" - alias = self.get_alias() - if alias is not None: - return alias - return self.get_real_name() + return self.get_alias() or self.get_real_name() def get_real_name(self): """Returns the real name (object name) of this identifier.""" # a.b - dot = self.token_next_match(0, T.Punctuation, '.') - if dot is not None: - return self._get_first_name(self.token_index(dot)) - - return self._get_first_name() + dot = self.token_next_by(m=(T.Punctuation, '.')) + return self._get_first_name(dot) def get_parent_name(self): """Return name of the parent object if any. A parent object is identified by the first occuring dot. """ - dot = self.token_next_match(0, T.Punctuation, '.') - if dot is None: - return None - prev_ = self.token_prev(self.token_index(dot)) - if prev_ is None: # something must be verry wrong here.. - return None - return remove_quotes(prev_.value) + dot = self.token_next_by(m=(T.Punctuation, '.')) + prev_ = self.token_prev(dot) + return remove_quotes(prev_.value) if prev_ is not None else None def _get_first_name(self, idx=None, reverse=False, keywords=False): """Returns the name of the first token with a name""" @@ -442,19 +377,16 @@ class TokenList(Token): if keywords: types.append(T.Keyword) - for tok in tokens: - if tok.ttype in types: - return remove_quotes(tok.value) - elif isinstance(tok, Identifier) or isinstance(tok, Function): - return tok.get_name() - return None + for token in tokens: + if token.ttype in types: + return remove_quotes(token.value) + elif isinstance(token, (Identifier, Function)): + return token.get_name() class Statement(TokenList): """Represents a SQL statement.""" - __slots__ = ('value', 'ttype', 'tokens') - def get_type(self): """Returns the type of a statement. @@ -465,7 +397,7 @@ class Statement(TokenList): Whitespaces and comments at the beginning of the statement are ignored. """ - first_token = self.token_first(ignore_comments=True) + first_token = self.token_first(skip_cm=True) if first_token is None: # An "empty" statement that either has not tokens at all # or only whitespace tokens. @@ -478,16 +410,14 @@ class Statement(TokenList): # The WITH keyword should be followed by either an Identifier or # an IdentifierList containing the CTE definitions; the actual # DML keyword (e.g. SELECT, INSERT) will follow next. - idents = self.token_next( - self.token_index(first_token), skip_ws=True) - if isinstance(idents, (Identifier, IdentifierList)): - dml_keyword = self.token_next( - self.token_index(idents), skip_ws=True) + token = self.token_next(first_token, skip_ws=True) + if isinstance(token, (Identifier, IdentifierList)): + dml_keyword = self.token_next(token, skip_ws=True) + if dml_keyword.ttype == T.Keyword.DML: return dml_keyword.normalized - # Hmm, probably invalid syntax, so return unknown. - return 'UNKNOWN' + # Hmm, probably invalid syntax, so return unknown. 
return 'UNKNOWN' @@ -499,33 +429,33 @@ class Identifier(TokenList): def is_wildcard(self): """Return ``True`` if this identifier contains a wildcard.""" - token = self.token_next_by_type(0, T.Wildcard) + token = self.token_next_by(t=T.Wildcard) return token is not None def get_typecast(self): """Returns the typecast or ``None`` of this object as a string.""" - marker = self.token_next_match(0, T.Punctuation, '::') + marker = self.token_next_by(m=(T.Punctuation, '::')) if marker is None: return None - next_ = self.token_next(self.token_index(marker), False) + next_ = self.token_next(marker, False) if next_ is None: return None - return u(next_) + return next_.value def get_ordering(self): """Returns the ordering or ``None`` as uppercase string.""" - ordering = self.token_next_by_type(0, T.Keyword.Order) + ordering = self.token_next_by(t=T.Keyword.Order) if ordering is None: return None - return ordering.value.upper() + return ordering.normalized def get_array_indices(self): """Returns an iterator of index token lists""" - for tok in self.tokens: - if isinstance(tok, SquareBrackets): + for token in self.tokens: + if isinstance(token, SquareBrackets): # Use [1:-1] index to discard the square brackets - yield tok.tokens[1:-1] + yield token.tokens[1:-1] class IdentifierList(TokenList): @@ -536,15 +466,15 @@ class IdentifierList(TokenList): Whitespaces and punctuations are not included in this generator. """ - for x in self.tokens: - if not x.is_whitespace() and not x.match(T.Punctuation, ','): - yield x + for token in self.tokens: + if not (token.is_whitespace() or token.match(T.Punctuation, ',')): + yield token class Parenthesis(TokenList): """Tokens between parenthesis.""" - M_OPEN = (T.Punctuation, '(') - M_CLOSE = (T.Punctuation, ')') + M_OPEN = T.Punctuation, '(' + M_CLOSE = T.Punctuation, ')' @property def _groupable_tokens(self): @@ -553,8 +483,8 @@ class Parenthesis(TokenList): class SquareBrackets(TokenList): """Tokens between square brackets""" - M_OPEN = (T.Punctuation, '[') - M_CLOSE = (T.Punctuation, ']') + M_OPEN = T.Punctuation, '[' + M_CLOSE = T.Punctuation, ']' @property def _groupable_tokens(self): @@ -567,14 +497,14 @@ class Assignment(TokenList): class If(TokenList): """An 'if' clause with possible 'else if' or 'else' parts.""" - M_OPEN = (T.Keyword, 'IF') - M_CLOSE = (T.Keyword, 'END IF') + M_OPEN = T.Keyword, 'IF' + M_CLOSE = T.Keyword, 'END IF' class For(TokenList): """A 'FOR' loop.""" - M_OPEN = (T.Keyword, ('FOR', 'FOREACH')) - M_CLOSE = (T.Keyword, 'END LOOP') + M_OPEN = T.Keyword, ('FOR', 'FOREACH') + M_CLOSE = T.Keyword, 'END LOOP' class Comparison(TokenList): @@ -598,15 +528,15 @@ class Comment(TokenList): class Where(TokenList): """A WHERE clause.""" - M_OPEN = (T.Keyword, 'WHERE') - M_CLOSE = (T.Keyword, - ('ORDER', 'GROUP', 'LIMIT', 'UNION', 'EXCEPT', 'HAVING')) + M_OPEN = T.Keyword, 'WHERE' + M_CLOSE = T.Keyword, ('ORDER', 'GROUP', 'LIMIT', 'UNION', 'EXCEPT', + 'HAVING') class Case(TokenList): """A CASE statement with one or more WHEN and possibly an ELSE part.""" - M_OPEN = (T.Keyword, 'CASE') - M_CLOSE = (T.Keyword, 'END') + M_OPEN = T.Keyword, 'CASE' + M_CLOSE = T.Keyword, 'END' def get_cases(self): """Returns a list of 2-tuples (condition, value). 
@@ -659,15 +589,15 @@ class Function(TokenList): def get_parameters(self): """Return a list of parameters.""" parenthesis = self.tokens[-1] - for t in parenthesis.tokens: - if imt(t, i=IdentifierList): - return t.get_identifiers() - elif imt(t, i=(Function, Identifier), t=T.Literal): - return [t, ] + for token in parenthesis.tokens: + if imt(token, i=IdentifierList): + return token.get_identifiers() + elif imt(token, i=(Function, Identifier), t=T.Literal): + return [token, ] return [] class Begin(TokenList): """A BEGIN/END block.""" - M_OPEN = (T.Keyword, 'BEGIN') - M_CLOSE = (T.Keyword, 'END') + M_OPEN = T.Keyword, 'BEGIN' + M_CLOSE = T.Keyword, 'END' diff --git a/sqlparse/utils.py b/sqlparse/utils.py index 2513c26..4da44c6 100644 --- a/sqlparse/utils.py +++ b/sqlparse/utils.py @@ -7,78 +7,9 @@ import itertools import re -from collections import OrderedDict, deque +from collections import deque from contextlib import contextmanager - -class Cache(OrderedDict): - """Cache with LRU algorithm using an OrderedDict as basis - """ - - def __init__(self, maxsize=100): - OrderedDict.__init__(self) - - self._maxsize = maxsize - - def __getitem__(self, key, *args, **kwargs): - # Get the key and remove it from the cache, or raise KeyError - value = OrderedDict.__getitem__(self, key) - del self[key] - - # Insert the (key, value) pair on the front of the cache - OrderedDict.__setitem__(self, key, value) - - # Return the value from the cache - return value - - def __setitem__(self, key, value, *args, **kwargs): - # Key was inserted before, remove it so we put it at front later - if key in self: - del self[key] - - # Too much items on the cache, remove the least recent used - elif len(self) >= self._maxsize: - self.popitem(False) - - # Insert the (key, value) pair on the front of the cache - OrderedDict.__setitem__(self, key, value, *args, **kwargs) - - -def memoize_generator(func): - """Memoize decorator for generators - - Store `func` results in a cache according to their arguments as 'memoize' - does but instead this works on decorators instead of regular functions. - Obviusly, this is only useful if the generator will always return the same - values for each specific parameters... - """ - cache = Cache() - - def wrapped_func(*args, **kwargs): - params = (args, tuple(sorted(kwargs.items()))) - - # Look if cached - try: - cached = cache[params] - - # Not cached, exec and store it - except KeyError: - cached = [] - - for item in func(*args, **kwargs): - cached.append(item) - yield item - - cache[params] = cached - - # Cached, yield its items - else: - for item in cached: - yield item - - return wrapped_func - - # This regular expression replaces the home-cooked parser that was here before. # It is much faster, but requires an extra post-processing step to get the # desired results (that are compatible with what you would expect from the |
