author    Andi Albrecht <albrecht.andi@gmail.com>  2009-04-03 21:26:42 +0200
committer Andi Albrecht <albrecht.andi@gmail.com>  2009-04-03 21:26:42 +0200
commit    361122eb22d5681c58dac731009e4814b3dd5fa5 (patch)
tree      b096496bc9c6b8febe092d0aefd56de1a4f8f4a0 /sqlparse/engine
download  sqlparse-361122eb22d5681c58dac731009e4814b3dd5fa5.tar.gz
Initial import.
Diffstat (limited to 'sqlparse/engine')
-rw-r--r--  sqlparse/engine/__init__.py    81
-rw-r--r--  sqlparse/engine/_grouping.py  499
-rw-r--r--  sqlparse/engine/filter.py      98
-rw-r--r--  sqlparse/engine/grouping.py   537
4 files changed, 1215 insertions, 0 deletions
diff --git a/sqlparse/engine/__init__.py b/sqlparse/engine/__init__.py
new file mode 100644
index 0000000..5cac528
--- /dev/null
+++ b/sqlparse/engine/__init__.py
@@ -0,0 +1,81 @@
+# Copyright (C) 2008 Andi Albrecht, albrecht.andi@gmail.com
+#
+# This module is part of python-sqlparse and is released under
+# the BSD License: http://www.opensource.org/licenses/bsd-license.php.
+
+"""filter"""
+
+import logging
+import re
+
+from sqlparse import lexer, SQLParseError
+from sqlparse.engine import grouping
+from sqlparse.engine.filter import StatementFilter
+
+# XXX remove this when cleanup is complete
+Filter = object
+
+
+class FilterStack(object):
+
+ def __init__(self):
+ self.preprocess = []
+ self.stmtprocess = []
+ self.postprocess = []
+ self.split_statements = False
+ self._grouping = False
+
+ def _flatten(self, stream):
+ for token in stream:
+ if token.is_group():
+ for t in self._flatten(token.tokens):
+ yield t
+ else:
+ yield token
+
+ def enable_grouping(self):
+ self._grouping = True
+
+ def full_analyze(self):
+ self.enable_grouping()
+
+ def run(self, sql):
+ stream = lexer.tokenize(sql)
+ # Process token stream
+ if self.preprocess:
+ for filter_ in self.preprocess:
+ stream = filter_.process(self, stream)
+
+ if (self.stmtprocess or self.postprocess or self.split_statements
+ or self._grouping):
+ splitter = StatementFilter()
+ stream = splitter.process(self, stream)
+
+ if self._grouping:
+ def _group(stream):
+ for stmt in stream:
+ grouping.group(stmt)
+ yield stmt
+ stream = _group(stream)
+
+ if self.stmtprocess:
+ def _run(stream):
+ ret = []
+ for stmt in stream:
+ for filter_ in self.stmtprocess:
+ filter_.process(self, stmt)
+ ret.append(stmt)
+ return ret
+ stream = _run(stream)
+
+ if self.postprocess:
+ def _run(stream):
+ for stmt in stream:
+ stmt.tokens = list(self._flatten(stmt.tokens))
+ for filter_ in self.postprocess:
+ stmt = filter_.process(self, stmt)
+ yield stmt
+ stream = _run(stream)
+
+ return stream
+
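A minimal usage sketch of the FilterStack defined above (illustrative only; the
package-level wrappers in sqlparse/__init__.py are not part of this diff):

    from sqlparse.engine import FilterStack

    stack = FilterStack()
    stack.enable_grouping()                     # build a grouped token tree per statement
    for stmt in stack.run('select * from foo; select 1;'):
        # each stmt is a grouped Statement (see grouping.py below)
        print stmt.get_type(), repr(stmt.to_unicode())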
diff --git a/sqlparse/engine/_grouping.py b/sqlparse/engine/_grouping.py
new file mode 100644
index 0000000..512c590
--- /dev/null
+++ b/sqlparse/engine/_grouping.py
@@ -0,0 +1,499 @@
+# -*- coding: utf-8 -*-
+
+import re
+import types
+
+from sqlparse.engine.filter import TokenFilter
+from sqlparse import tokens as T
+
+class _Base(object):
+
+    __slots__ = ()  # method names must not be listed in __slots__
+
+ def __unicode__(self):
+        return 'Unknown _Base object'
+
+ def __str__(self):
+ return unicode(self).encode('latin-1')
+
+ def __repr__(self):
+ raw = unicode(self)
+ if len(raw) > 7:
+ short = raw[:6]+u'...'
+ else:
+ short = raw
+        short = re.sub(r'\s+', ' ', short)
+ return '<%s \'%s\' at 0x%07x>' % (self._get_repr_name(),
+ short, id(self))
+
+ def _get_repr_name(self):
+ return self.__class__.__name__
+
+ def to_unicode(self):
+ return unicode(self)
+
+ def to_str(self):
+ return str(self)
+
+
+class Token(_Base):
+
+ __slots__ = ('value', 'ttype')
+
+ def __init__(self, ttype, value):
+ self.value = value
+ self.ttype = ttype
+
+ def __unicode__(self):
+ return self.value
+
+ def _get_repr_name(self):
+ return str(self.ttype).split('.')[-1]
+
+ def match(self, ttype, values):
+ if self.ttype is not ttype:
+ return False
+ if isinstance(values, basestring):
+ values = [values]
+ if self.ttype is T.Keyword:
+ return self.value.upper() in [v.upper() for v in values]
+ else:
+ return self.value in values
+
+ def is_group(self):
+ return False
+
+ def is_whitespace(self):
+ return self.ttype and self.ttype is T.Whitespace
+
+
+class _Group(Token):
+
+    __slots__ = ('value', 'ttype', '_tokens')  # data lives in _tokens; 'tokens' is the property below
+
+ def __init__(self, tokens=None):
+ super(_Group, self).__init__(None, None)
+ if tokens is None:
+ tokens = []
+ self._tokens = tokens
+
+ def _set_tokens(self, tokens):
+ self._tokens = tokens
+ def _get_tokens(self):
+ if type(self._tokens) is not types.TupleType:
+ self._tokens = tuple(self._tokens)
+ return self._tokens
+ tokens = property(fget=_get_tokens, fset=_set_tokens)
+
+ def _get_repr_name(self):
+ return self.__class__.__name__
+
+ def _pprint_tree(self, depth=0):
+ """Pretty-print the object tree."""
+ indent = ' '*(depth*2)
+ for token in self.tokens:
+ print '%s%r' % (indent, token)
+ if token.is_group():
+ token._pprint_tree(depth+1)
+
+ def __unicode__(self):
+ return u''.join(unicode(t) for t in self.tokens)
+
+ @property
+ def subgroups(self):
+ #return [x for x in self.tokens if isinstance(x, _Group)]
+ for item in self.tokens:
+ if item.is_group():
+ yield item
+
+ def is_group(self):
+ return True
+
+
+class Statement(_Group):
+ __slots__ = ('value', 'ttype', '_tokens')
+
+
+class Parenthesis(_Group):
+ __slots__ = ('value', 'ttype', '_tokens')
+
+
+class Where(_Group):
+ __slots__ = ('value', 'ttype', '_tokens')
+
+
+class CommentMulti(_Group):
+ __slots__ = ('value', 'ttype', '_tokens')
+
+
+class Identifier(_Group):
+ __slots__ = ('value', 'ttype', '_tokens')
+
+
+class TypeCast(_Group):
+ __slots__ = ('value', 'ttype', '_tokens')
+
+ @property
+ def casted_object(self):
+ return self.tokens[0]
+
+ @property
+ def casted_type(self):
+ return self.tokens[-1]
+
+
+class Alias(_Group):
+ __slots__ = ('value', 'ttype', '_tokens')
+
+ @property
+ def aliased_object(self):
+ return self.tokens[0]
+
+ @property
+ def alias(self):
+ return self.tokens[-1]
+
+
+
+
+# - Filter
+
+class StatementFilter(TokenFilter):
+
+ def __init__(self):
+ self._in_declare = False
+ self._in_dbldollar = False
+ self._is_create = False
+
+ def _reset(self):
+ self._in_declare = False
+ self._in_dbldollar = False
+ self._is_create = False
+
+ def _change_splitlevel(self, ttype, value):
+ # PostgreSQL
+ if (ttype == T.Name.Builtin
+ and value.startswith('$') and value.endswith('$')):
+ if self._in_dbldollar:
+ self._in_dbldollar = False
+ return -1
+ else:
+ self._in_dbldollar = True
+ return 1
+ elif self._in_dbldollar:
+ return 0
+
+ # ANSI
+ if ttype is not T.Keyword:
+ return 0
+
+ unified = value.upper()
+
+ if unified == 'DECLARE':
+ self._in_declare = True
+ return 1
+
+ if unified == 'BEGIN':
+ if self._in_declare:
+ return 0
+ return 0
+
+ if unified == 'END':
+ return -1
+
+ if ttype is T.Keyword.DDL and unified.startswith('CREATE'):
+ self._is_create = True
+
+ if unified in ('IF', 'FOR') and self._is_create:
+ return 1
+
+ # Default
+ return 0
+
+ def process(self, stack, stream):
+ splitlevel = 0
+ stmt = None
+ consume_ws = False
+ stmt_tokens = []
+ for ttype, value in stream:
+ # Before appending the token
+ if (consume_ws and ttype is not T.Whitespace
+ and ttype is not T.Comment.Single):
+ consume_ws = False
+ stmt.tokens = stmt_tokens
+ yield stmt
+ self._reset()
+ stmt = None
+ splitlevel = 0
+ if stmt is None:
+ stmt = Statement()
+ stmt_tokens = []
+ splitlevel += self._change_splitlevel(ttype, value)
+ # Append the token
+ stmt_tokens.append(Token(ttype, value))
+ # After appending the token
+ if (not splitlevel and ttype is T.Punctuation
+ and value == ';'):
+ consume_ws = True
+ if stmt is not None:
+ stmt.tokens = stmt_tokens
+ yield stmt
+
+
+class GroupFilter(object):
+
+ def process(self, stream):
+ pass
+
+
+class GroupParenthesis(GroupFilter):
+ """Group parenthesis groups."""
+
+ def _finish_group(self, group):
+ start = group[0]
+ end = group[-1]
+ tokens = list(self._process(group[1:-1]))
+ return [start]+tokens+[end]
+
+ def _process(self, stream):
+ group = None
+ depth = 0
+ for token in stream:
+ if token.is_group():
+ token.tokens = self._process(token.tokens)
+ if token.match(T.Punctuation, '('):
+ if depth == 0:
+ group = []
+ depth += 1
+ if group is not None:
+ group.append(token)
+ if token.match(T.Punctuation, ')'):
+ depth -= 1
+ if depth == 0:
+ yield Parenthesis(self._finish_group(group))
+ group = None
+ continue
+ if group is None:
+ yield token
+
+ def process(self, group):
+ if not isinstance(group, Parenthesis):
+ group.tokens = self._process(group.tokens)
+
+
+class GroupWhere(GroupFilter):
+
+ def _process(self, stream):
+ group = None
+ depth = 0
+ for token in stream:
+ if token.is_group():
+ token.tokens = self._process(token.tokens)
+ if token.match(T.Keyword, 'WHERE'):
+ if depth == 0:
+ group = []
+ depth += 1
+ # Process conditions here? E.g. "A =|!=|in|is|... B"...
+ elif (token.ttype is T.Keyword
+ and token.value.upper() in ('ORDER', 'GROUP',
+ 'LIMIT', 'UNION')):
+ depth -= 1
+ if depth == 0:
+ yield Where(group)
+ group = None
+ if depth < 0:
+ depth = 0
+ if group is not None:
+ group.append(token)
+ else:
+ yield token
+ if group is not None:
+ yield Where(group)
+
+ def process(self, group):
+ if not isinstance(group, Where):
+ group.tokens = self._process(group.tokens)
+
+
+class GroupMultiComments(GroupFilter):
+ """Groups Comment.Multiline and adds trailing whitespace up to first lb."""
+
+ def _process(self, stream):
+ new_tokens = []
+ grp = None
+ consume_ws = False
+ for token in stream:
+ if token.is_group():
+ token.tokens = self._process(token.tokens)
+ if token.ttype is T.Comment.Multiline:
+ if grp is None:
+ grp = []
+ consume_ws = True
+ grp.append(token)
+ elif consume_ws and token.ttype is not T.Whitespace:
+ yield CommentMulti(grp)
+ grp = None
+ consume_ws = False
+ yield token
+ elif consume_ws:
+ lines = token.value.splitlines(True)
+ grp.append(Token(T.Whitespace, lines[0]))
+ if lines[0].endswith('\n'):
+ yield CommentMulti(grp)
+ grp = None
+ consume_ws = False
+ if lines[1:]:
+ yield Token(T.Whitespace, ''.join(lines[1:]))
+ else:
+ yield token
+
+ def process(self, group):
+ if not isinstance(group, CommentMulti):
+ group.tokens = self._process(group.tokens)
+
+
+## class GroupIdentifier(GroupFilter):
+
+## def _process(self, stream):
+## buff = []
+## expect_dot = False
+## for token in stream:
+## if token.is_group():
+## token.tokens = self._process(token.tokens)
+## if (token.ttype is T.String.Symbol or token.ttype is T.Name
+## and not expect_dot):
+## buff.append(token)
+## expect_dot = True
+## elif expect_dot and token.match(T.Punctuation, '.'):
+## buff.append(token)
+## expect_dot = False
+## else:
+## if expect_dot == False:
+## # something's wrong, it ends with a dot...
+## while buff:
+## yield buff.pop(0)
+## expect_dot = False
+## elif buff:
+## idt = Identifier()
+## idt.tokens = buff
+## yield idt
+## buff = []
+## yield token
+## if buff and expect_dot:
+## idt = Identifier()
+## idt.tokens = buff
+## yield idt
+## buff = []
+## while buff:
+## yield buff.pop(0)
+
+## def process(self, group):
+## if not isinstance(group, Identifier):
+## group.tokens = self._process(group.tokens)
+
+
+class AddTypeCastFilter(GroupFilter):
+
+ def _process(self, stream):
+ buff = []
+ expect_colon = False
+ has_colons = False
+ for token in stream:
+ if token.is_group():
+ token.tokens = self._process(token.tokens)
+ if ((isinstance(token, Parenthesis)
+ or isinstance(token, Identifier))
+ and not expect_colon):
+ buff.append(token)
+ expect_colon = True
+ elif expect_colon and token.match(T.Punctuation, ':'):
+ buff.append(token)
+ has_colons = True
+ elif (expect_colon
+ and (token.ttype in T.Name
+ or isinstance(token, Identifier))
+ ):
+ if not has_colons:
+ while buff:
+ yield buff.pop(0)
+ yield token
+ else:
+ buff.append(token)
+ grp = TypeCast()
+ grp.tokens = buff
+ buff = []
+ yield grp
+                expect_colon = has_colons = False
+ else:
+ while buff:
+ yield buff.pop(0)
+ yield token
+ while buff:
+ yield buff.pop(0)
+
+ def process(self, group):
+ if not isinstance(group, TypeCast):
+ group.tokens = self._process(group.tokens)
+
+
+class AddAliasFilter(GroupFilter):
+
+ def _process(self, stream):
+ buff = []
+ search_alias = False
+ lazy = False
+ for token in stream:
+ if token.is_group():
+ token.tokens = self._process(token.tokens)
+ if search_alias and (isinstance(token, Identifier)
+ or token.ttype in (T.Name,
+ T.String.Symbol)
+ or (lazy and not token.is_whitespace())):
+ buff.append(token)
+ search_alias = lazy = False
+ grp = Alias()
+ grp.tokens = buff
+ buff = []
+ yield grp
+ elif (isinstance(token, (Identifier, TypeCast))
+ or token.ttype in (T.Name, T.String.Symbol)):
+ buff.append(token)
+ search_alias = True
+ elif search_alias and (token.is_whitespace()
+ or token.match(T.Keyword, 'as')):
+ buff.append(token)
+ if token.match(T.Keyword, 'as'):
+ lazy = True
+ else:
+ while buff:
+ yield buff.pop(0)
+ yield token
+ search_alias = False
+ while buff:
+ yield buff.pop(0)
+
+ def process(self, group):
+ if not isinstance(group, Alias):
+ group.tokens = self._process(group.tokens)
+
+
+GROUP_FILTER = (GroupParenthesis(),
+ GroupMultiComments(),
+ GroupWhere(),
+                ## GroupIdentifier(),
+ AddTypeCastFilter(),
+ AddAliasFilter(),
+ )
+
+def group_tokens(group):
+ def _materialize(g):
+ if type(g.tokens) is not types.TupleType:
+ g.tokens = tuple(g.tokens)
+ for sg in g.subgroups:
+ _materialize(sg)
+ for groupfilter in GROUP_FILTER:
+ groupfilter.process(group)
+# _materialize(group)
+# group.tokens = tuple(group.tokens)
+# for subgroup in group.subgroups:
+# group_tokens(subgroup)
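For reference, the filter-based pipeline in this module can be exercised roughly
like this (a sketch; note that FilterStack above wires in the newer grouping
module, not this one):

    from sqlparse import lexer
    from sqlparse.engine import _grouping

    stream = lexer.tokenize('select (1 + 2) from foo where bar = 1;')
    for stmt in _grouping.StatementFilter().process(None, stream):
        _grouping.group_tokens(stmt)    # applies GROUP_FILTER to the statement
        stmt._pprint_tree()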
diff --git a/sqlparse/engine/filter.py b/sqlparse/engine/filter.py
new file mode 100644
index 0000000..146690c
--- /dev/null
+++ b/sqlparse/engine/filter.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+
+from sqlparse import tokens as T
+from sqlparse.engine.grouping import Statement, Token
+
+
+class TokenFilter(object):
+
+ def __init__(self, **options):
+ self.options = options
+
+ def process(self, stack, stream):
+ """Process token stream."""
+ raise NotImplementedError
+
+
+class StatementFilter(TokenFilter):
+
+ def __init__(self):
+ self._in_declare = False
+ self._in_dbldollar = False
+ self._is_create = False
+
+ def _reset(self):
+ self._in_declare = False
+ self._in_dbldollar = False
+ self._is_create = False
+
+ def _change_splitlevel(self, ttype, value):
+ # PostgreSQL
+ if (ttype == T.Name.Builtin
+ and value.startswith('$') and value.endswith('$')):
+ if self._in_dbldollar:
+ self._in_dbldollar = False
+ return -1
+ else:
+ self._in_dbldollar = True
+ return 1
+ elif self._in_dbldollar:
+ return 0
+
+ # ANSI
+ if ttype is not T.Keyword:
+ return 0
+
+ unified = value.upper()
+
+ if unified == 'DECLARE':
+ self._in_declare = True
+ return 1
+
+ if unified == 'BEGIN':
+ if self._in_declare:
+ return 0
+ return 0
+
+ if unified == 'END':
+        # Should this respect a preceding BEGIN?
+        # In CASE ... WHEN ... END this results in a split level of -1.
+ return -1
+
+ if ttype is T.Keyword.DDL and unified.startswith('CREATE'):
+ self._is_create = True
+
+ if unified in ('IF', 'FOR') and self._is_create:
+ return 1
+
+ # Default
+ return 0
+
+ def process(self, stack, stream):
+ splitlevel = 0
+ stmt = None
+ consume_ws = False
+ stmt_tokens = []
+ for ttype, value in stream:
+ # Before appending the token
+ if (consume_ws and ttype is not T.Whitespace
+ and ttype is not T.Comment.Single):
+ consume_ws = False
+ stmt.tokens = stmt_tokens
+ yield stmt
+ self._reset()
+ stmt = None
+ splitlevel = 0
+ if stmt is None:
+ stmt = Statement()
+ stmt_tokens = []
+ splitlevel += self._change_splitlevel(ttype, value)
+ # Append the token
+ stmt_tokens.append(Token(ttype, value))
+ # After appending the token
+ if (splitlevel <= 0 and ttype is T.Punctuation
+ and value == ';'):
+ consume_ws = True
+ if stmt is not None:
+ stmt.tokens = stmt_tokens
+ yield stmt
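The statement splitter above can also be driven on its own; a small sketch (the
stack argument is unused by this filter, so None is passed):

    from sqlparse import lexer
    from sqlparse.engine.filter import StatementFilter

    stream = lexer.tokenize('select 1; select 2;')
    for stmt in StatementFilter().process(None, stream):
        print repr(stmt.to_unicode())
    # yields two Statement objects, split at the top-level semicolons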
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py
new file mode 100644
index 0000000..433f539
--- /dev/null
+++ b/sqlparse/engine/grouping.py
@@ -0,0 +1,537 @@
+# -*- coding: utf-8 -*-
+
+import itertools
+import re
+import types
+
+from sqlparse import tokens as T
+
+
+class Token(object):
+
+ __slots__ = ('value', 'ttype')
+
+ def __init__(self, ttype, value):
+ self.value = value
+ self.ttype = ttype
+
+ def __str__(self):
+ return unicode(self).encode('latin-1')
+
+ def __repr__(self):
+ short = self._get_repr_value()
+ return '<%s \'%s\' at 0x%07x>' % (self._get_repr_name(),
+ short, id(self))
+
+ def __unicode__(self):
+ return self.value
+
+ def to_unicode(self):
+ return unicode(self)
+
+ def _get_repr_name(self):
+ return str(self.ttype).split('.')[-1]
+
+ def _get_repr_value(self):
+ raw = unicode(self)
+ if len(raw) > 7:
+ short = raw[:6]+u'...'
+ else:
+ short = raw
+        return re.sub(r'\s+', ' ', short)
+
+ def match(self, ttype, values, regex=False):
+ if self.ttype is not ttype:
+ return False
+ if values is None:
+ return self.ttype is ttype
+ if isinstance(values, basestring):
+ values = [values]
+ if regex:
+ if self.ttype is T.Keyword:
+ values = [re.compile(v, re.IGNORECASE) for v in values]
+ else:
+ values = [re.compile(v) for v in values]
+ for pattern in values:
+ if pattern.search(self.value):
+ return True
+ return False
+ else:
+ if self.ttype is T.Keyword:
+ return self.value.upper() in [v.upper() for v in values]
+ else:
+ return self.value in values
+
+ def is_group(self):
+ return False
+
+ def is_whitespace(self):
+ return self.ttype and self.ttype in T.Whitespace
+
+
+class TokenList(Token):
+
+ __slots__ = ('value', 'ttype', 'tokens')
+
+ def __init__(self, tokens=None):
+ if tokens is None:
+ tokens = []
+ self.tokens = tokens
+ Token.__init__(self, None, None)
+
+ def __unicode__(self):
+ return ''.join(unicode(x) for x in self.flatten())
+
+ def __str__(self):
+ return unicode(self).encode('latin-1')
+
+ def _get_repr_name(self):
+ return self.__class__.__name__
+
+ def _pprint_tree(self, max_depth=None, depth=0):
+ """Pretty-print the object tree."""
+ indent = ' '*(depth*2)
+ for token in self.tokens:
+ if token.is_group():
+ pre = ' | '
+ else:
+ pre = ' | '
+ print '%s%s%s \'%s\'' % (indent, pre, token._get_repr_name(),
+ token._get_repr_value())
+ if (token.is_group() and max_depth is not None
+ and depth < max_depth):
+ token._pprint_tree(max_depth, depth+1)
+
+ def flatten(self):
+ for token in self.tokens:
+ if isinstance(token, TokenList):
+ for item in token.flatten():
+ yield item
+ else:
+ yield token
+
+ def is_group(self):
+ return True
+
+ def get_sublists(self):
+ return [x for x in self.tokens if isinstance(x, TokenList)]
+
+ def token_first(self, ignore_whitespace=True):
+ for token in self.tokens:
+ if ignore_whitespace and token.is_whitespace():
+ continue
+ return token
+ return None
+
+ def token_next_by_instance(self, idx, clss):
+ if type(clss) not in (types.ListType, types.TupleType):
+ clss = (clss,)
+ if type(clss) is not types.TupleType:
+ clss = tuple(clss)
+ for token in self.tokens[idx:]:
+ if isinstance(token, clss):
+ return token
+ return None
+
+ def token_next_by_type(self, idx, ttypes):
+ if not isinstance(ttypes, (types.TupleType, types.ListType)):
+ ttypes = [ttypes]
+ for token in self.tokens[idx:]:
+ if token.ttype in ttypes:
+ return token
+ return None
+
+ def token_next_match(self, idx, ttype, value, regex=False):
+ if type(idx) != types.IntType:
+ idx = self.token_index(idx)
+ for token in self.tokens[idx:]:
+ if token.match(ttype, value, regex):
+ return token
+ return None
+
+ def token_not_matching(self, idx, funcs):
+ for token in self.tokens[idx:]:
+ passed = False
+ for func in funcs:
+ if func(token):
+ passed = True
+ break
+ if not passed:
+ return token
+ return None
+
+ def token_prev(self, idx, skip_ws=True):
+ while idx != 0:
+ idx -= 1
+ if self.tokens[idx].is_whitespace() and skip_ws:
+ continue
+ return self.tokens[idx]
+
+ def token_next(self, idx, skip_ws=True):
+ while idx < len(self.tokens)-1:
+ idx += 1
+ if self.tokens[idx].is_whitespace() and skip_ws:
+ continue
+ return self.tokens[idx]
+
+ def token_index(self, token):
+ """Return list index of token."""
+ return self.tokens.index(token)
+
+ def tokens_between(self, start, end, exclude_end=False):
+ """Return all tokens between (and including) start and end."""
+ if exclude_end:
+ offset = 0
+ else:
+ offset = 1
+ return self.tokens[self.token_index(start):self.token_index(end)+offset]
+
+ def group_tokens(self, grp_cls, tokens):
+ """Replace tokens by instance of grp_cls."""
+ idx = self.token_index(tokens[0])
+ for t in tokens:
+ self.tokens.remove(t)
+ grp = grp_cls(tokens)
+ self.tokens.insert(idx, grp)
+ return grp
+
+ def insert_before(self, where, token):
+ self.tokens.insert(self.token_index(where), token)
+
+
+class Statement(TokenList):
+
+ __slots__ = ('value', 'ttype', 'tokens')
+
+ def get_type(self):
+ first_token = self.token_first()
+ if first_token.ttype in (T.Keyword.DML, T.Keyword.DDL):
+ return first_token.value.upper()
+ else:
+ return 'UNKNOWN'
+
+
+class Identifier(TokenList):
+
+ __slots__ = ('value', 'ttype', 'tokens')
+
+ def has_alias(self):
+ return self.get_alias() is not None
+
+ def get_alias(self):
+ kw = self.token_next_match(0, T.Keyword, 'AS')
+ if kw is not None:
+ alias = self.token_next(self.token_index(kw))
+ if alias is None:
+ return None
+ else:
+ next_ = self.token_next(0)
+ if next_ is None or not isinstance(next_, Identifier):
+ return None
+ alias = next_
+ if isinstance(alias, Identifier):
+ return alias.get_name()
+ else:
+ return alias.to_unicode()
+
+ def get_name(self):
+ alias = self.get_alias()
+ if alias is not None:
+ return alias
+ return self.get_real_name()
+
+ def get_real_name(self):
+ return self.token_next_by_type(0, T.Name).value
+
+ def get_typecast(self):
+ marker = self.token_next_match(0, T.Punctuation, '::')
+ if marker is None:
+ return None
+ next_ = self.token_next(self.token_index(marker), False)
+ if next_ is None:
+ return None
+ return next_.to_unicode()
+
+
+class IdentifierList(TokenList):
+
+ __slots__ = ('value', 'ttype', 'tokens')
+
+ def get_identifiers(self):
+ return [x for x in self.tokens if isinstance(x, Identifier)]
+
+
+class Parenthesis(TokenList):
+ __slots__ = ('value', 'ttype', 'tokens')
+
+
+class Assignment(TokenList):
+ __slots__ = ('value', 'ttype', 'tokens')
+
+class If(TokenList):
+ __slots__ = ('value', 'ttype', 'tokens')
+
+class For(TokenList):
+ __slots__ = ('value', 'ttype', 'tokens')
+
+class Comparsion(TokenList):
+ __slots__ = ('value', 'ttype', 'tokens')
+
+class Comment(TokenList):
+ __slots__ = ('value', 'ttype', 'tokens')
+
+class Where(TokenList):
+ __slots__ = ('value', 'ttype', 'tokens')
+
+
+class Case(TokenList):
+
+ __slots__ = ('value', 'ttype', 'tokens')
+
+ def get_cases(self):
+ """Returns a list of 2-tuples (condition, value).
+
+        If an ELSE clause exists, its condition is None.
+ """
+ ret = []
+ in_condition = in_value = False
+ for token in self.tokens:
+ if token.match(T.Keyword, 'WHEN'):
+ ret.append(([], []))
+ in_condition = True
+ in_value = False
+ elif token.match(T.Keyword, 'ELSE'):
+ ret.append((None, []))
+ in_condition = False
+ in_value = True
+ elif token.match(T.Keyword, 'THEN'):
+ in_condition = False
+ in_value = True
+ elif token.match(T.Keyword, 'END'):
+ in_condition = False
+ in_value = False
+ if in_condition:
+ ret[-1][0].append(token)
+ elif in_value:
+ ret[-1][1].append(token)
+ return ret
+
+def _group_left_right(tlist, ttype, value, cls,
+ check_right=lambda t: True,
+ include_semicolon=False):
+# [_group_left_right(sgroup, ttype, value, cls, check_right,
+# include_semicolon) for sgroup in tlist.get_sublists()
+# if not isinstance(sgroup, cls)]
+ idx = 0
+ token = tlist.token_next_match(idx, ttype, value)
+ while token:
+ right = tlist.token_next(tlist.token_index(token))
+ left = tlist.token_prev(tlist.token_index(token))
+ if (right is None or not check_right(right)
+ or left is None):
+ token = tlist.token_next_match(tlist.token_index(token)+1,
+ ttype, value)
+ else:
+ if include_semicolon:
+ right = tlist.token_next_match(tlist.token_index(right),
+ T.Punctuation, ';')
+ tokens = tlist.tokens_between(left, right)[1:]
+ if not isinstance(left, cls):
+ new = cls([left])
+ new_idx = tlist.token_index(left)
+ tlist.tokens.remove(left)
+ tlist.tokens.insert(new_idx, new)
+ left = new
+ left.tokens.extend(tokens)
+ for t in tokens:
+ tlist.tokens.remove(t)
+ token = tlist.token_next_match(tlist.token_index(left)+1,
+ ttype, value)
+
+def _group_matching(tlist, start_ttype, start_value, end_ttype, end_value,
+ cls, include_semicolon=False, recurse=False):
+ def _find_matching(i, tl, stt, sva, ett, eva):
+ depth = 1
+ for t in tl.tokens[i:]:
+ if t.match(stt, sva):
+ depth += 1
+ elif t.match(ett, eva):
+ depth -= 1
+ if depth == 1:
+ return t
+ return None
+ [_group_matching(sgroup, start_ttype, start_value, end_ttype, end_value,
+ cls, include_semicolon) for sgroup in tlist.get_sublists()
+ if recurse]
+ if isinstance(tlist, cls):
+ idx = 1
+ else:
+ idx = 0
+ token = tlist.token_next_match(idx, start_ttype, start_value)
+ while token:
+ tidx = tlist.token_index(token)
+ end = _find_matching(tidx, tlist, start_ttype, start_value,
+ end_ttype, end_value)
+ if end is None:
+ idx = tidx+1
+ else:
+ if include_semicolon:
+ next_ = tlist.token_next(tlist.token_index(end))
+ if next_ and next_.match(T.Punctuation, ';'):
+ end = next_
+ group = tlist.group_tokens(cls, tlist.tokens_between(token, end))
+ _group_matching(group, start_ttype, start_value,
+ end_ttype, end_value, cls, include_semicolon)
+ idx = tlist.token_index(group)+1
+ token = tlist.token_next_match(idx, start_ttype, start_value)
+
+def group_if(tlist):
+ _group_matching(tlist, T.Keyword, 'IF', T.Keyword, 'END IF', If, True)
+
+def group_for(tlist):
+ _group_matching(tlist, T.Keyword, 'FOR', T.Keyword, 'END LOOP', For, True)
+
+def group_as(tlist):
+ _group_left_right(tlist, T.Keyword, 'AS', Identifier)
+
+def group_assignment(tlist):
+ _group_left_right(tlist, T.Assignment, ':=', Assignment,
+ include_semicolon=True)
+
+def group_comparsion(tlist):
+ _group_left_right(tlist, T.Operator, None, Comparsion)
+
+
+def group_case(tlist):
+ _group_matching(tlist, T.Keyword, 'CASE', T.Keyword, 'END', Case, True)
+
+
+def group_identifier(tlist):
+ def _consume_cycle(tl, i):
+ x = itertools.cycle((lambda y: y.match(T.Punctuation, '.'),
+ lambda y: y.ttype in (T.String.Symbol, T.Name)))
+ for t in tl.tokens[i:]:
+ if x.next()(t):
+ yield t
+ else:
+ raise StopIteration
+
+ # bottom up approach: group subgroups first
+ [group_identifier(sgroup) for sgroup in tlist.get_sublists()
+ if not isinstance(sgroup, Identifier)]
+
+ # real processing
+ idx = 0
+ token = tlist.token_next_by_type(idx, (T.String.Symbol, T.Name))
+ while token:
+ identifier_tokens = [token]+list(
+ _consume_cycle(tlist,
+ tlist.token_index(token)+1))
+ group = tlist.group_tokens(Identifier, identifier_tokens)
+ idx = tlist.token_index(group)+1
+ token = tlist.token_next_by_type(idx, (T.String.Symbol, T.Name))
+
+
+def group_identifier_list(tlist):
+ [group_identifier_list(sgroup) for sgroup in tlist.get_sublists()
+ if not isinstance(sgroup, IdentifierList)]
+ idx = 0
+ token = tlist.token_next_by_instance(idx, Identifier)
+ while token:
+ tidx = tlist.token_index(token)
+ end = tlist.token_not_matching(tidx+1,
+ [lambda t: isinstance(t, Identifier),
+ lambda t: t.is_whitespace(),
+ lambda t: t.match(T.Punctuation,
+ ',')
+ ])
+ if end is None:
+ idx = tidx + 1
+ else:
+ grp_tokens = tlist.tokens_between(token, end, exclude_end=True)
+ while grp_tokens and (grp_tokens[-1].is_whitespace()
+ or grp_tokens[-1].match(T.Punctuation, ',')):
+ grp_tokens.pop()
+ if len(grp_tokens) <= 1:
+ idx = tidx + 1
+ else:
+ group = tlist.group_tokens(IdentifierList, grp_tokens)
+ idx = tlist.token_index(group)
+ token = tlist.token_next_by_instance(idx, Identifier)
+
+
+def group_parenthesis(tlist):
+ _group_matching(tlist, T.Punctuation, '(', T.Punctuation, ')', Parenthesis)
+
+def group_comments(tlist):
+ [group_comments(sgroup) for sgroup in tlist.get_sublists()
+ if not isinstance(sgroup, Comment)]
+ idx = 0
+ token = tlist.token_next_by_type(idx, T.Comment)
+ while token:
+ tidx = tlist.token_index(token)
+ end = tlist.token_not_matching(tidx+1,
+ [lambda t: t.ttype in T.Comment,
+ lambda t: t.is_whitespace()])
+ if end is None:
+ idx = tidx + 1
+ else:
+ eidx = tlist.token_index(end)
+ grp_tokens = tlist.tokens_between(token,
+ tlist.token_prev(eidx, False))
+ group = tlist.group_tokens(Comment, grp_tokens)
+ idx = tlist.token_index(group)
+ token = tlist.token_next_by_type(idx, T.Comment)
+
+def group_where(tlist):
+ [group_where(sgroup) for sgroup in tlist.get_sublists()
+ if not isinstance(sgroup, Where)]
+ idx = 0
+ token = tlist.token_next_match(idx, T.Keyword, 'WHERE')
+ stopwords = ('ORDER', 'GROUP', 'LIMIT', 'UNION')
+ while token:
+ tidx = tlist.token_index(token)
+ end = tlist.token_next_match(tidx+1, T.Keyword, stopwords)
+ if end is None:
+ end = tlist.tokens[-1]
+ else:
+ end = tlist.tokens[tlist.token_index(end)-1]
+ group = tlist.group_tokens(Where, tlist.tokens_between(token, end))
+ idx = tlist.token_index(group)
+ token = tlist.token_next_match(idx, T.Keyword, 'WHERE')
+
+def group_aliased(tlist):
+ [group_aliased(sgroup) for sgroup in tlist.get_sublists()
+ if not isinstance(sgroup, Identifier)]
+ idx = 0
+ token = tlist.token_next_by_instance(idx, Identifier)
+ while token:
+ next_ = tlist.token_next(tlist.token_index(token))
+ if next_ is not None and isinstance(next_, Identifier):
+ grp = tlist.tokens_between(token, next_)[1:]
+ token.tokens.extend(grp)
+ for t in grp:
+ tlist.tokens.remove(t)
+ idx = tlist.token_index(token)+1
+ token = tlist.token_next_by_instance(idx, Identifier)
+
+
+def group_typecasts(tlist):
+ _group_left_right(tlist, T.Punctuation, '::', Identifier)
+
+
+def group(tlist):
+ for func in [group_parenthesis,
+ group_comments,
+ group_where,
+ group_case,
+ group_identifier,
+ group_typecasts,
+ group_as,
+ group_aliased,
+ group_assignment,
+ group_comparsion,
+ group_identifier_list,
+ group_if,
+ group_for,]:
+ func(tlist)
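A short sketch of driving the grouping functions above end to end and inspecting
the result (input SQL chosen for illustration):

    from sqlparse import lexer
    from sqlparse.engine import grouping
    from sqlparse.engine.filter import StatementFilter

    stream = lexer.tokenize('select foo.id::varchar as name from foo where foo.id = 1;')
    stmt = list(StatementFilter().process(None, stream))[0]
    grouping.group(stmt)

    print stmt.get_type()              # 'SELECT'
    stmt._pprint_tree(max_depth=4)     # dump the grouped token tree
    for tl in stmt.get_sublists():
        if isinstance(tl, grouping.Identifier):
            print tl.get_real_name(), tl.get_alias(), tl.get_typecast()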