Diffstat (limited to 'sqlparse/sql.py')
-rw-r--r--   sqlparse/sql.py   292
1 file changed, 111 insertions, 181 deletions
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index 6abc432..57bf1e7 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -6,15 +6,16 @@
 # the BSD License: http://www.opensource.org/licenses/bsd-license.php
 
 """This module contains classes representing syntactical elements of SQL."""
 
+from __future__ import print_function
 import re
-import sys
 
 from sqlparse import tokens as T
-from sqlparse.compat import string_types, u
+from sqlparse.compat import string_types, text_type, unicode_compatible
 from sqlparse.utils import imt, remove_quotes
 
 
+@unicode_compatible
 class Token(object):
     """Base class for all other classes in this module.
@@ -26,40 +27,29 @@ class Token(object):
     __slots__ = ('value', 'ttype', 'parent', 'normalized', 'is_keyword')
 
     def __init__(self, ttype, value):
+        value = text_type(value)
         self.value = value
-        if ttype in T.Keyword:
-            self.normalized = value.upper()
-        else:
-            self.normalized = value
         self.ttype = ttype
-        self.is_keyword = ttype in T.Keyword
         self.parent = None
+        self.is_keyword = ttype in T.Keyword
+        self.normalized = value.upper() if self.is_keyword else value
 
     def __str__(self):
-        if sys.version_info[0] == 3:
-            return self.value
-        else:
-            return u(self).encode('utf-8')
+        return self.value
 
     def __repr__(self):
-        short = self._get_repr_value()
-        if sys.version_info[0] < 3:
-            short = short.encode('utf-8')
-        return '<%s \'%s\' at 0x%07x>' % (self._get_repr_name(),
-                                          short, id(self))
-
-    def __unicode__(self):
-        """Returns a unicode representation of this object."""
-        return self.value or ''
+        cls = self._get_repr_name()
+        value = self._get_repr_value()
+        return "<{cls} '{value}' at 0x{id:2X}>".format(id=id(self), **locals())
 
     def _get_repr_name(self):
         return str(self.ttype).split('.')[-1]
 
     def _get_repr_value(self):
-        raw = u(self)
+        raw = text_type(self)
         if len(raw) > 7:
-            raw = raw[:6] + u'...'
-        return re.sub('\s+', ' ', raw)
+            raw = raw[:6] + '...'
+        return re.sub(r'\s+', ' ', raw)
 
     def flatten(self):
         """Resolve subgroups."""
@@ -81,32 +71,23 @@ class Token(object):
         if not type_matched or values is None:
             return type_matched
 
-        if regex:
-            if isinstance(values, string_types):
-                values = {values}
+        if isinstance(values, string_types):
+            values = (values,)
 
-            if self.ttype is T.Keyword:
-                values = set(re.compile(v, re.IGNORECASE) for v in values)
-            else:
-                values = set(re.compile(v) for v in values)
+        if regex:
+            # TODO: Add test for regex with is_keyboard = false
+            flag = re.IGNORECASE if self.is_keyword else 0
+            values = (re.compile(v, flag) for v in values)
 
             for pattern in values:
-                if pattern.search(self.value):
+                if pattern.search(self.normalized):
                     return True
             return False
 
-        if isinstance(values, string_types):
-            if self.is_keyword:
-                return values.upper() == self.normalized
-            return values == self.value
-
         if self.is_keyword:
-            for v in values:
-                if v.upper() == self.normalized:
-                    return True
-            return False
+            values = (v.upper() for v in values)
 
-        return self.value in values
+        return self.normalized in values
 
     def is_group(self):
         """Returns ``True`` if this object has children."""
@@ -114,7 +95,7 @@ class Token(object):
     def is_whitespace(self):
         """Return ``True`` if this token is a whitespace token."""
-        return self.ttype and self.ttype in T.Whitespace
+        return self.ttype in T.Whitespace
 
     def within(self, group_cls):
         """Returns ``True`` if this token is within *group_cls*.
@@ -143,6 +124,7 @@ class Token(object):
         return False
 
 
+@unicode_compatible
 class TokenList(Token):
     """A group of tokens.
 
     It has an additional instance attribute ``tokens`` which holds a
@@ -150,45 +132,35 @@ class TokenList(Token):
     list of child-tokens.
     """
 
-    __slots__ = ('value', 'ttype', 'tokens')
+    __slots__ = 'tokens'
 
     def __init__(self, tokens=None):
-        if tokens is None:
-            tokens = []
-        self.tokens = tokens
-        super(TokenList, self).__init__(None, self.__str__())
-
-    def __unicode__(self):
-        return self._to_string()
+        self.tokens = tokens or []
+        super(TokenList, self).__init__(None, text_type(self))
 
     def __str__(self):
-        str_ = self._to_string()
-        if sys.version_info[0] <= 2:
-            str_ = str_.encode('utf-8')
-        return str_
-
-    def _to_string(self):
-        if sys.version_info[0] == 3:
-            return ''.join(x.value for x in self.flatten())
-        else:
-            return ''.join(u(x) for x in self.flatten())
+        return ''.join(token.value for token in self.flatten())
+
+    def __iter__(self):
+        return iter(self.tokens)
+
+    def __getitem__(self, item):
+        return self.tokens[item]
 
     def _get_repr_name(self):
-        return self.__class__.__name__
+        return type(self).__name__
 
-    def _pprint_tree(self, max_depth=None, depth=0):
+    def _pprint_tree(self, max_depth=None, depth=0, f=None):
         """Pretty-print the object tree."""
-        indent = ' ' * (depth * 2)
+        ind = ' ' * (depth * 2)
         for idx, token in enumerate(self.tokens):
-            if token.is_group():
-                pre = ' +-'
-            else:
-                pre = ' | '
-            print('%s%s%d %s \'%s\'' % (indent, pre, idx,
-                                        token._get_repr_name(),
-                                        token._get_repr_value()))
-            if (token.is_group() and (max_depth is None or depth < max_depth)):
-                token._pprint_tree(max_depth, depth + 1)
+            pre = ' +-' if token.is_group() else ' | '
+            cls = token._get_repr_name()
+            value = token._get_repr_value()
+            print("{ind}{pre}{idx} {cls} '{value}'".format(**locals()), file=f)
+
+            if token.is_group() and (max_depth is None or depth < max_depth):
+                token._pprint_tree(max_depth, depth + 1, f)
 
     def get_token_at_offset(self, offset):
         """Returns the token that is on position offset."""
@@ -205,26 +177,19 @@ class TokenList(Token):
         This method is recursively called for all child tokens.
         """
         for token in self.tokens:
-            if isinstance(token, TokenList):
+            if token.is_group():
                 for item in token.flatten():
                     yield item
             else:
                 yield token
 
-    # def __iter__(self):
-    #     return self
-    #
-    # def next(self):
-    #     for token in self.tokens:
-    #         yield token
-
     def is_group(self):
         return True
 
     def get_sublists(self):
-        for x in self.tokens:
-            if isinstance(x, TokenList):
-                yield x
+        for token in self.tokens:
+            if token.is_group():
+                yield token
 
     @property
     def _groupable_tokens(self):
@@ -242,7 +207,7 @@ class TokenList(Token):
             funcs = (funcs,)
 
         if reverse:
-            iterable = iter(reversed(self.tokens[end:start - 1]))
+            iterable = reversed(self.tokens[end:start - 1])
         else:
             iterable = self.tokens[start:end]
 
@@ -251,7 +216,7 @@ class TokenList(Token):
             if func(token):
                 return token
 
-    def token_first(self, ignore_whitespace=True, ignore_comments=False):
+    def token_first(self, skip_ws=True, skip_cm=False):
         """Returns the first child token.
 
         If *ignore_whitespace* is ``True`` (the default), whitespace
@@ -260,35 +225,15 @@ class TokenList(Token):
         if *ignore_comments* is ``True`` (default: ``False``),
         comments are ignored too.
         """
-        funcs = lambda tk: not ((ignore_whitespace and tk.is_whitespace()) or
-                                (ignore_comments and imt(tk, i=Comment)))
+        # this on is inconsistent, using Comment instead of T.Comment...
+        funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or
+                                (skip_cm and imt(tk, i=Comment)))
         return self._token_matching(funcs)
 
     def token_next_by(self, i=None, m=None, t=None, idx=0, end=None):
         funcs = lambda tk: imt(tk, i, m, t)
         return self._token_matching(funcs, idx, end)
 
-    def token_next_by_instance(self, idx, clss, end=None):
-        """Returns the next token matching a class.
-
-        *idx* is where to start searching in the list of child tokens.
-        *clss* is a list of classes the token should be an instance of.
-
-        If no matching token can be found ``None`` is returned.
-        """
-        funcs = lambda tk: imt(tk, i=clss)
-        return self._token_matching(funcs, idx, end)
-
-    def token_next_by_type(self, idx, ttypes):
-        """Returns next matching token by it's token type."""
-        funcs = lambda tk: imt(tk, t=ttypes)
-        return self._token_matching(funcs, idx)
-
-    def token_next_match(self, idx, ttype, value, regex=False):
-        """Returns next token where it's ``match`` method returns ``True``."""
-        funcs = lambda tk: imt(tk, m=(ttype, value, regex))
-        return self._token_matching(funcs, idx)
-
     def token_not_matching(self, idx, funcs):
         funcs = (funcs,) if not isinstance(funcs, (list, tuple)) else funcs
         funcs = [lambda tk: not func(tk) for func in funcs]
@@ -297,7 +242,7 @@ class TokenList(Token):
     def token_matching(self, idx, funcs):
         return self._token_matching(funcs, idx)
 
-    def token_prev(self, idx, skip_ws=True):
+    def token_prev(self, idx, skip_ws=True, skip_cm=False):
         """Returns the previous token relative to *idx*.
 
         If *skip_ws* is ``True`` (the default) whitespace tokens are ignored.
@@ -305,10 +250,11 @@ class TokenList(Token):
         """
         if isinstance(idx, int):
             idx += 1  # alot of code usage current pre-compensates for this
-        funcs = lambda tk: not (tk.is_whitespace() and skip_ws)
+        funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or
+                                (skip_cm and imt(tk, t=T.Comment)))
         return self._token_matching(funcs, idx, reverse=True)
 
-    def token_next(self, idx, skip_ws=True):
+    def token_next(self, idx, skip_ws=True, skip_cm=False):
         """Returns the next token relative to *idx*.
 
         If *skip_ws* is ``True`` (the default) whitespace tokens are ignored.
@@ -316,7 +262,8 @@ class TokenList(Token):
         """
         if isinstance(idx, int):
             idx += 1  # alot of code usage current pre-compensates for this
-        funcs = lambda tk: not (tk.is_whitespace() and skip_ws)
+        funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or
+                                (skip_cm and imt(tk, t=T.Comment)))
        return self._token_matching(funcs, idx)
 
     def token_index(self, token, start=0):
@@ -334,9 +281,9 @@ class TokenList(Token):
         end_idx = include_end + self.token_index(end)
         return self.tokens[start_idx:end_idx]
 
-    def group_tokens(self, grp_cls, tokens, ignore_ws=False, extend=False):
+    def group_tokens(self, grp_cls, tokens, skip_ws=False, extend=False):
         """Replace tokens by an instance of *grp_cls*."""
-        if ignore_ws:
+        if skip_ws:
             while tokens and tokens[-1].is_whitespace():
                 tokens = tokens[:-1]
 
@@ -352,7 +299,7 @@ class TokenList(Token):
             left.parent = self
             tokens = tokens[1:]
             left.tokens.extend(tokens)
-            left.value = left.__str__()
+            left.value = str(left)
         else:
             left = grp_cls(tokens)
 
@@ -393,8 +340,6 @@ class TokenList(Token):
         if len(self.tokens) > 2 and self.token_next_by(t=T.Whitespace):
             return self._get_first_name(reverse=True)
 
-        return None
-
     def get_name(self):
         """Returns the name of this identifier.
 
@@ -402,32 +347,22 @@ class TokenList(Token):
         be considered as the name under which the object corresponding to
         this identifier is known within the current statement.
         """
-        alias = self.get_alias()
-        if alias is not None:
-            return alias
-        return self.get_real_name()
+        return self.get_alias() or self.get_real_name()
 
     def get_real_name(self):
         """Returns the real name (object name) of this identifier."""
         # a.b
-        dot = self.token_next_match(0, T.Punctuation, '.')
-        if dot is not None:
-            return self._get_first_name(self.token_index(dot))
-
-        return self._get_first_name()
+        dot = self.token_next_by(m=(T.Punctuation, '.'))
+        return self._get_first_name(dot)
 
     def get_parent_name(self):
         """Return name of the parent object if any.
 
         A parent object is identified by the first occuring dot.
         """
-        dot = self.token_next_match(0, T.Punctuation, '.')
-        if dot is None:
-            return None
-        prev_ = self.token_prev(self.token_index(dot))
-        if prev_ is None:  # something must be verry wrong here..
-            return None
-        return remove_quotes(prev_.value)
+        dot = self.token_next_by(m=(T.Punctuation, '.'))
+        prev_ = self.token_prev(dot)
+        return remove_quotes(prev_.value) if prev_ is not None else None
 
     def _get_first_name(self, idx=None, reverse=False, keywords=False):
         """Returns the name of the first token with a name"""
@@ -442,19 +377,16 @@ class TokenList(Token):
         if keywords:
             types.append(T.Keyword)
 
-        for tok in tokens:
-            if tok.ttype in types:
-                return remove_quotes(tok.value)
-            elif isinstance(tok, Identifier) or isinstance(tok, Function):
-                return tok.get_name()
-        return None
+        for token in tokens:
+            if token.ttype in types:
+                return remove_quotes(token.value)
+            elif isinstance(token, (Identifier, Function)):
+                return token.get_name()
 
 
 class Statement(TokenList):
     """Represents a SQL statement."""
 
-    __slots__ = ('value', 'ttype', 'tokens')
-
     def get_type(self):
         """Returns the type of a statement.
 
@@ -465,7 +397,7 @@ class Statement(TokenList):
         Whitespaces and comments at the beginning of the statement
         are ignored.
         """
-        first_token = self.token_first(ignore_comments=True)
+        first_token = self.token_first(skip_cm=True)
         if first_token is None:
             # An "empty" statement that either has not tokens at all
             # or only whitespace tokens.
@@ -478,16 +410,14 @@ class Statement(TokenList):
             # The WITH keyword should be followed by either an Identifier or
             # an IdentifierList containing the CTE definitions; the actual
             # DML keyword (e.g. SELECT, INSERT) will follow next.
-            idents = self.token_next(
-                self.token_index(first_token), skip_ws=True)
-            if isinstance(idents, (Identifier, IdentifierList)):
-                dml_keyword = self.token_next(
-                    self.token_index(idents), skip_ws=True)
+            token = self.token_next(first_token, skip_ws=True)
+            if isinstance(token, (Identifier, IdentifierList)):
+                dml_keyword = self.token_next(token, skip_ws=True)
+
                 if dml_keyword.ttype == T.Keyword.DML:
                     return dml_keyword.normalized
 
-            # Hmm, probably invalid syntax, so return unknown.
-            return 'UNKNOWN'
+        # Hmm, probably invalid syntax, so return unknown.
         return 'UNKNOWN'
@@ -499,33 +429,33 @@ class Identifier(TokenList):
     def is_wildcard(self):
         """Return ``True`` if this identifier contains a wildcard."""
-        token = self.token_next_by_type(0, T.Wildcard)
+        token = self.token_next_by(t=T.Wildcard)
         return token is not None
 
     def get_typecast(self):
         """Returns the typecast or ``None`` of this object as a string."""
-        marker = self.token_next_match(0, T.Punctuation, '::')
+        marker = self.token_next_by(m=(T.Punctuation, '::'))
         if marker is None:
             return None
-        next_ = self.token_next(self.token_index(marker), False)
+        next_ = self.token_next(marker, False)
         if next_ is None:
             return None
-        return u(next_)
+        return next_.value
 
     def get_ordering(self):
         """Returns the ordering or ``None`` as uppercase string."""
-        ordering = self.token_next_by_type(0, T.Keyword.Order)
+        ordering = self.token_next_by(t=T.Keyword.Order)
         if ordering is None:
             return None
-        return ordering.value.upper()
+        return ordering.normalized
 
     def get_array_indices(self):
         """Returns an iterator of index token lists"""
-        for tok in self.tokens:
-            if isinstance(tok, SquareBrackets):
+        for token in self.tokens:
+            if isinstance(token, SquareBrackets):
                 # Use [1:-1] index to discard the square brackets
-                yield tok.tokens[1:-1]
+                yield token.tokens[1:-1]
 
 
 class IdentifierList(TokenList):
@@ -536,15 +466,15 @@ class IdentifierList(TokenList):
         Whitespaces and punctuations are not included in this generator.
         """
-        for x in self.tokens:
-            if not x.is_whitespace() and not x.match(T.Punctuation, ','):
-                yield x
+        for token in self.tokens:
+            if not (token.is_whitespace() or token.match(T.Punctuation, ',')):
+                yield token
 
 
 class Parenthesis(TokenList):
     """Tokens between parenthesis."""
-    M_OPEN = (T.Punctuation, '(')
-    M_CLOSE = (T.Punctuation, ')')
+    M_OPEN = T.Punctuation, '('
+    M_CLOSE = T.Punctuation, ')'
 
     @property
     def _groupable_tokens(self):
@@ -553,8 +483,8 @@ class Parenthesis(TokenList):
 
 class SquareBrackets(TokenList):
     """Tokens between square brackets"""
-    M_OPEN = (T.Punctuation, '[')
-    M_CLOSE = (T.Punctuation, ']')
+    M_OPEN = T.Punctuation, '['
+    M_CLOSE = T.Punctuation, ']'
 
     @property
     def _groupable_tokens(self):
@@ -567,14 +497,14 @@ class Assignment(TokenList):
 
 class If(TokenList):
     """An 'if' clause with possible 'else if' or 'else' parts."""
-    M_OPEN = (T.Keyword, 'IF')
-    M_CLOSE = (T.Keyword, 'END IF')
+    M_OPEN = T.Keyword, 'IF'
+    M_CLOSE = T.Keyword, 'END IF'
 
 
 class For(TokenList):
     """A 'FOR' loop."""
-    M_OPEN = (T.Keyword, ('FOR', 'FOREACH'))
-    M_CLOSE = (T.Keyword, 'END LOOP')
+    M_OPEN = T.Keyword, ('FOR', 'FOREACH')
+    M_CLOSE = T.Keyword, 'END LOOP'
 
 
 class Comparison(TokenList):
@@ -598,15 +528,15 @@ class Comment(TokenList):
 
 class Where(TokenList):
     """A WHERE clause."""
-    M_OPEN = (T.Keyword, 'WHERE')
-    M_CLOSE = (T.Keyword,
-               ('ORDER', 'GROUP', 'LIMIT', 'UNION', 'EXCEPT', 'HAVING'))
+    M_OPEN = T.Keyword, 'WHERE'
+    M_CLOSE = T.Keyword, ('ORDER', 'GROUP', 'LIMIT', 'UNION', 'EXCEPT',
+                          'HAVING')
 
 
 class Case(TokenList):
     """A CASE statement with one or more WHEN and possibly an ELSE part."""
-    M_OPEN = (T.Keyword, 'CASE')
-    M_CLOSE = (T.Keyword, 'END')
+    M_OPEN = T.Keyword, 'CASE'
+    M_CLOSE = T.Keyword, 'END'
 
     def get_cases(self):
         """Returns a list of 2-tuples (condition, value).
@@ -659,15 +589,15 @@ class Function(TokenList):
     def get_parameters(self):
         """Return a list of parameters."""
         parenthesis = self.tokens[-1]
-        for t in parenthesis.tokens:
-            if imt(t, i=IdentifierList):
-                return t.get_identifiers()
-            elif imt(t, i=(Function, Identifier), t=T.Literal):
-                return [t, ]
+        for token in parenthesis.tokens:
+            if imt(token, i=IdentifierList):
+                return token.get_identifiers()
+            elif imt(token, i=(Function, Identifier), t=T.Literal):
+                return [token, ]
         return []
 
 
 class Begin(TokenList):
     """A BEGIN/END block."""
-    M_OPEN = (T.Keyword, 'BEGIN')
-    M_CLOSE = (T.Keyword, 'END')
+    M_OPEN = T.Keyword, 'BEGIN'
+    M_CLOSE = T.Keyword, 'END'
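Not part of the commit itself: a minimal sketch of how the refactored token API reads after this change, assuming this revision of sqlparse is importable. The statement text and variable names below are illustrative only; the methods used (token_first with skip_ws/skip_cm, __iter__/__getitem__ on TokenList, normalized, match, get_type) are the ones introduced or touched by the diff above.

import sqlparse
from sqlparse import tokens as T

stmt = sqlparse.parse('/* hint */ SELECT col FROM tbl')[0]

# token_first() now takes skip_ws/skip_cm instead of
# ignore_whitespace/ignore_comments; skip_cm=True steps over the comment.
first = stmt.token_first(skip_cm=True)
print(first.ttype is T.Keyword.DML)          # True
print(first.normalized)                      # 'SELECT' (keywords upper-cased)

# match() now compares keywords against .normalized, so case is irrelevant.
print(first.match(T.Keyword.DML, 'select'))  # True

# TokenList gained __iter__ and __getitem__ in this change.
print(stmt[0])                               # first child token (the comment)
print(len([token for token in stmt]))

# Statement.get_type() relies on token_first(skip_cm=True) internally.
print(stmt.get_type())                       # 'SELECT'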
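A second sketch, likewise hypothetical: the f parameter that _pprint_tree() gains in this diff lets the tree dump go to any file-like object instead of stdout (note the method is a private helper, so its shape may change in later releases).

import io

import sqlparse

stmt = sqlparse.parse('SELECT col FROM tbl')[0]

# f=None keeps the old behavior of printing to stdout; any file-like
# object captures the dump instead.
buf = io.StringIO()
stmt._pprint_tree(f=buf)
print(buf.getvalue())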