Diffstat (limited to 'sqlparse/sql.py')
 sqlparse/sql.py | 292 ++++++++++++++-----------------------
 1 file changed, 111 insertions(+), 181 deletions(-)
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index 6abc432..57bf1e7 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -6,15 +6,16 @@
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
"""This module contains classes representing syntactical elements of SQL."""
+from __future__ import print_function
import re
-import sys
from sqlparse import tokens as T
-from sqlparse.compat import string_types, u
+from sqlparse.compat import string_types, text_type, unicode_compatible
from sqlparse.utils import imt, remove_quotes
+@unicode_compatible
class Token(object):
"""Base class for all other classes in this module.
@@ -26,40 +27,29 @@ class Token(object):
__slots__ = ('value', 'ttype', 'parent', 'normalized', 'is_keyword')
def __init__(self, ttype, value):
+ value = text_type(value)
self.value = value
- if ttype in T.Keyword:
- self.normalized = value.upper()
- else:
- self.normalized = value
self.ttype = ttype
- self.is_keyword = ttype in T.Keyword
self.parent = None
+ self.is_keyword = ttype in T.Keyword
+ self.normalized = value.upper() if self.is_keyword else value
def __str__(self):
- if sys.version_info[0] == 3:
- return self.value
- else:
- return u(self).encode('utf-8')
+ return self.value
def __repr__(self):
- short = self._get_repr_value()
- if sys.version_info[0] < 3:
- short = short.encode('utf-8')
- return '<%s \'%s\' at 0x%07x>' % (self._get_repr_name(),
- short, id(self))
-
- def __unicode__(self):
- """Returns a unicode representation of this object."""
- return self.value or ''
+ cls = self._get_repr_name()
+ value = self._get_repr_value()
+ return "<{cls} '{value}' at 0x{id:2X}>".format(id=id(self), **locals())
def _get_repr_name(self):
return str(self.ttype).split('.')[-1]
def _get_repr_value(self):
- raw = u(self)
+ raw = text_type(self)
if len(raw) > 7:
- raw = raw[:6] + u'...'
- return re.sub('\s+', ' ', raw)
+ raw = raw[:6] + '...'
+ return re.sub(r'\s+', ' ', raw)
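The net effect of the constructor and repr rewrite above: values are coerced to
text up front, and ``normalized`` is derived from ``is_keyword`` instead of
duplicating the membership test. A minimal sketch of the resulting behavior
(assuming the package is importable as ``sqlparse``):

    from sqlparse import tokens as T
    from sqlparse.sql import Token

    tok = Token(T.Keyword, 'select')
    tok.value        # 'select' -- original casing is preserved
    tok.normalized   # 'SELECT' -- keywords are uppercased
    tok.is_keyword   # True
    repr(tok)        # "<Keyword 'select' at 0x...>" via the new __repr__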
def flatten(self):
"""Resolve subgroups."""
@@ -81,32 +71,23 @@ class Token(object):
if not type_matched or values is None:
return type_matched
- if regex:
- if isinstance(values, string_types):
- values = {values}
+ if isinstance(values, string_types):
+ values = (values,)
- if self.ttype is T.Keyword:
- values = set(re.compile(v, re.IGNORECASE) for v in values)
- else:
- values = set(re.compile(v) for v in values)
+ if regex:
+            # TODO: Add test for regex with is_keyword = False
+ flag = re.IGNORECASE if self.is_keyword else 0
+ values = (re.compile(v, flag) for v in values)
for pattern in values:
- if pattern.search(self.value):
+ if pattern.search(self.normalized):
return True
return False
- if isinstance(values, string_types):
- if self.is_keyword:
- return values.upper() == self.normalized
- return values == self.value
-
if self.is_keyword:
- for v in values:
- if v.upper() == self.normalized:
- return True
- return False
+ values = (v.upper() for v in values)
- return self.value in values
+ return self.normalized in values
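With the consolidated ``match`` above, keyword comparison always goes through
``normalized``, so it is case-insensitive with or without ``regex``. A short
sketch (the ``match(ttype, values, regex=False)`` signature itself is
unchanged by this diff):

    tok = Token(T.Keyword, 'select')
    tok.match(T.Keyword, 'SELECT')              # True, compared via normalized
    tok.match(T.Keyword, ('INSERT', 'SELECT'))  # True, any value may match
    tok.match(T.Keyword, r'SEL', regex=True)    # True, re.IGNORECASE for keywords
    tok.match(T.Name, 'select')                 # False, the ttype must match first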
def is_group(self):
"""Returns ``True`` if this object has children."""
@@ -114,7 +95,7 @@ class Token(object):
def is_whitespace(self):
"""Return ``True`` if this token is a whitespace token."""
- return self.ttype and self.ttype in T.Whitespace
+ return self.ttype in T.Whitespace
def within(self, group_cls):
"""Returns ``True`` if this token is within *group_cls*.
@@ -143,6 +124,7 @@ class Token(object):
return False
+@unicode_compatible
class TokenList(Token):
"""A group of tokens.
@@ -150,45 +132,35 @@ class TokenList(Token):
list of child-tokens.
"""
- __slots__ = ('value', 'ttype', 'tokens')
+ __slots__ = 'tokens'
def __init__(self, tokens=None):
- if tokens is None:
- tokens = []
- self.tokens = tokens
- super(TokenList, self).__init__(None, self.__str__())
-
- def __unicode__(self):
- return self._to_string()
+ self.tokens = tokens or []
+ super(TokenList, self).__init__(None, text_type(self))
def __str__(self):
- str_ = self._to_string()
- if sys.version_info[0] <= 2:
- str_ = str_.encode('utf-8')
- return str_
-
- def _to_string(self):
- if sys.version_info[0] == 3:
- return ''.join(x.value for x in self.flatten())
- else:
- return ''.join(u(x) for x in self.flatten())
+ return ''.join(token.value for token in self.flatten())
+
+ def __iter__(self):
+ return iter(self.tokens)
+
+ def __getitem__(self, item):
+ return self.tokens[item]
def _get_repr_name(self):
- return self.__class__.__name__
+ return type(self).__name__
- def _pprint_tree(self, max_depth=None, depth=0):
+ def _pprint_tree(self, max_depth=None, depth=0, f=None):
"""Pretty-print the object tree."""
- indent = ' ' * (depth * 2)
+ ind = ' ' * (depth * 2)
for idx, token in enumerate(self.tokens):
- if token.is_group():
- pre = ' +-'
- else:
- pre = ' | '
- print('%s%s%d %s \'%s\'' % (indent, pre, idx,
- token._get_repr_name(),
- token._get_repr_value()))
- if (token.is_group() and (max_depth is None or depth < max_depth)):
- token._pprint_tree(max_depth, depth + 1)
+ pre = ' +-' if token.is_group() else ' | '
+ cls = token._get_repr_name()
+ value = token._get_repr_value()
+ print("{ind}{pre}{idx} {cls} '{value}'".format(**locals()), file=f)
+
+ if token.is_group() and (max_depth is None or depth < max_depth):
+ token._pprint_tree(max_depth, depth + 1, f)
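``TokenList`` now supports the iteration and indexing protocols directly, and
``_pprint_tree`` gained a file argument. A quick sketch against a parsed
statement:

    import sys
    import sqlparse

    stmt = sqlparse.parse('select * from foo')[0]
    stmt[0]                          # first child token, via __getitem__
    [tok.ttype for tok in stmt]      # iterates direct children only
    stmt._pprint_tree(f=sys.stdout)  # the new f= argument redirects output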
def get_token_at_offset(self, offset):
"""Returns the token that is on position offset."""
@@ -205,26 +177,19 @@ class TokenList(Token):
This method is recursively called for all child tokens.
"""
for token in self.tokens:
- if isinstance(token, TokenList):
+ if token.is_group():
for item in token.flatten():
yield item
else:
yield token
- # def __iter__(self):
- # return self
- #
- # def next(self):
- # for token in self.tokens:
- # yield token
-
def is_group(self):
return True
def get_sublists(self):
- for x in self.tokens:
- if isinstance(x, TokenList):
- yield x
+ for token in self.tokens:
+ if token.is_group():
+ yield token
@property
def _groupable_tokens(self):
@@ -242,7 +207,7 @@ class TokenList(Token):
funcs = (funcs,)
if reverse:
- iterable = iter(reversed(self.tokens[end:start - 1]))
+ iterable = reversed(self.tokens[end:start - 1])
else:
iterable = self.tokens[start:end]
@@ -251,7 +216,7 @@ class TokenList(Token):
if func(token):
return token
- def token_first(self, ignore_whitespace=True, ignore_comments=False):
+ def token_first(self, skip_ws=True, skip_cm=False):
"""Returns the first child token.
-        If *ignore_whitespace* is ``True`` (the default), whitespace
-        tokens are ignored.
+        If *skip_ws* is ``True`` (the default), whitespace
+        tokens are ignored.
-        if *ignore_comments* is ``True`` (default: ``False``), comments are
-        ignored too.
+        If *skip_cm* is ``True`` (default: ``False``), comments are
+        ignored too.
"""
- funcs = lambda tk: not ((ignore_whitespace and tk.is_whitespace()) or
- (ignore_comments and imt(tk, i=Comment)))
+        # this one is inconsistent, using Comment instead of T.Comment...
+ funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or
+ (skip_cm and imt(tk, i=Comment)))
return self._token_matching(funcs)
def token_next_by(self, i=None, m=None, t=None, idx=0, end=None):
funcs = lambda tk: imt(tk, i, m, t)
return self._token_matching(funcs, idx, end)
- def token_next_by_instance(self, idx, clss, end=None):
- """Returns the next token matching a class.
-
- *idx* is where to start searching in the list of child tokens.
- *clss* is a list of classes the token should be an instance of.
-
- If no matching token can be found ``None`` is returned.
- """
- funcs = lambda tk: imt(tk, i=clss)
- return self._token_matching(funcs, idx, end)
-
- def token_next_by_type(self, idx, ttypes):
- """Returns next matching token by it's token type."""
- funcs = lambda tk: imt(tk, t=ttypes)
- return self._token_matching(funcs, idx)
-
- def token_next_match(self, idx, ttype, value, regex=False):
- """Returns next token where it's ``match`` method returns ``True``."""
- funcs = lambda tk: imt(tk, m=(ttype, value, regex))
- return self._token_matching(funcs, idx)
-
def token_not_matching(self, idx, funcs):
funcs = (funcs,) if not isinstance(funcs, (list, tuple)) else funcs
funcs = [lambda tk: not func(tk) for func in funcs]
@@ -297,7 +242,7 @@ class TokenList(Token):
def token_matching(self, idx, funcs):
return self._token_matching(funcs, idx)
- def token_prev(self, idx, skip_ws=True):
+ def token_prev(self, idx, skip_ws=True, skip_cm=False):
"""Returns the previous token relative to *idx*.
If *skip_ws* is ``True`` (the default) whitespace tokens are ignored.
+        If *skip_cm* is ``True``, comments are ignored too.
@@ -305,10 +250,11 @@ class TokenList(Token):
"""
if isinstance(idx, int):
idx += 1  # a lot of code currently pre-compensates for this
- funcs = lambda tk: not (tk.is_whitespace() and skip_ws)
+ funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or
+ (skip_cm and imt(tk, t=T.Comment)))
return self._token_matching(funcs, idx, reverse=True)
- def token_next(self, idx, skip_ws=True):
+ def token_next(self, idx, skip_ws=True, skip_cm=False):
"""Returns the next token relative to *idx*.
If *skip_ws* is ``True`` (the default) whitespace tokens are ignored.
+        If *skip_cm* is ``True``, comments are ignored too.
@@ -316,7 +262,8 @@ class TokenList(Token):
"""
if isinstance(idx, int):
idx += 1  # a lot of code currently pre-compensates for this
- funcs = lambda tk: not (tk.is_whitespace() and skip_ws)
+ funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or
+ (skip_cm and imt(tk, t=T.Comment)))
return self._token_matching(funcs, idx)
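Both scanning helpers now take a ``skip_cm`` flag to match ``token_first``,
and, as the updated ``get_type`` below shows, they accept a token as well as
an integer index. A sketch:

    import sqlparse

    stmt = sqlparse.parse('select /* hint */ 1')[0]
    first = stmt.token_first(skip_cm=True)  # the SELECT keyword
    stmt.token_next(first, skip_cm=True)    # skips whitespace and the comment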
def token_index(self, token, start=0):
@@ -334,9 +281,9 @@ class TokenList(Token):
end_idx = include_end + self.token_index(end)
return self.tokens[start_idx:end_idx]
- def group_tokens(self, grp_cls, tokens, ignore_ws=False, extend=False):
+ def group_tokens(self, grp_cls, tokens, skip_ws=False, extend=False):
"""Replace tokens by an instance of *grp_cls*."""
- if ignore_ws:
+ if skip_ws:
while tokens and tokens[-1].is_whitespace():
tokens = tokens[:-1]
@@ -352,7 +299,7 @@ class TokenList(Token):
left.parent = self
tokens = tokens[1:]
left.tokens.extend(tokens)
- left.value = left.__str__()
+ left.value = str(left)
else:
left = grp_cls(tokens)
@@ -393,8 +340,6 @@ class TokenList(Token):
if len(self.tokens) > 2 and self.token_next_by(t=T.Whitespace):
return self._get_first_name(reverse=True)
- return None
-
def get_name(self):
"""Returns the name of this identifier.
@@ -402,32 +347,22 @@ class TokenList(Token):
be considered as the name under which the object corresponding to
this identifier is known within the current statement.
"""
- alias = self.get_alias()
- if alias is not None:
- return alias
- return self.get_real_name()
+ return self.get_alias() or self.get_real_name()
def get_real_name(self):
"""Returns the real name (object name) of this identifier."""
# a.b
- dot = self.token_next_match(0, T.Punctuation, '.')
- if dot is not None:
- return self._get_first_name(self.token_index(dot))
-
- return self._get_first_name()
+ dot = self.token_next_by(m=(T.Punctuation, '.'))
+ return self._get_first_name(dot)
def get_parent_name(self):
"""Return name of the parent object if any.
A parent object is identified by the first occurring dot.
"""
- dot = self.token_next_match(0, T.Punctuation, '.')
- if dot is None:
- return None
- prev_ = self.token_prev(self.token_index(dot))
- if prev_ is None: # something must be verry wrong here..
- return None
- return remove_quotes(prev_.value)
+ dot = self.token_next_by(m=(T.Punctuation, '.'))
+ prev_ = self.token_prev(dot)
+ return remove_quotes(prev_.value) if prev_ is not None else None
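The streamlined name helpers compose as before. A sketch for a dotted, aliased
identifier (assuming the default grouping yields an ``Identifier`` here):

    import sqlparse
    from sqlparse.sql import Identifier

    stmt = sqlparse.parse('select a.b as c from t')[0]
    ident = stmt.token_next_by(i=Identifier)
    ident.get_parent_name()  # 'a', the part before the first dot
    ident.get_real_name()    # 'b', the object name
    ident.get_name()         # 'c', the alias wins over the real name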
def _get_first_name(self, idx=None, reverse=False, keywords=False):
"""Returns the name of the first token with a name"""
@@ -442,19 +377,16 @@ class TokenList(Token):
if keywords:
types.append(T.Keyword)
- for tok in tokens:
- if tok.ttype in types:
- return remove_quotes(tok.value)
- elif isinstance(tok, Identifier) or isinstance(tok, Function):
- return tok.get_name()
- return None
+ for token in tokens:
+ if token.ttype in types:
+ return remove_quotes(token.value)
+ elif isinstance(token, (Identifier, Function)):
+ return token.get_name()
class Statement(TokenList):
"""Represents a SQL statement."""
- __slots__ = ('value', 'ttype', 'tokens')
-
def get_type(self):
"""Returns the type of a statement.
@@ -465,7 +397,7 @@ class Statement(TokenList):
Whitespaces and comments at the beginning of the statement
are ignored.
"""
- first_token = self.token_first(ignore_comments=True)
+ first_token = self.token_first(skip_cm=True)
if first_token is None:
# An "empty" statement that either has not tokens at all
# or only whitespace tokens.
@@ -478,16 +410,14 @@ class Statement(TokenList):
# The WITH keyword should be followed by either an Identifier or
# an IdentifierList containing the CTE definitions; the actual
# DML keyword (e.g. SELECT, INSERT) will follow next.
- idents = self.token_next(
- self.token_index(first_token), skip_ws=True)
- if isinstance(idents, (Identifier, IdentifierList)):
- dml_keyword = self.token_next(
- self.token_index(idents), skip_ws=True)
+ token = self.token_next(first_token, skip_ws=True)
+ if isinstance(token, (Identifier, IdentifierList)):
+ dml_keyword = self.token_next(token, skip_ws=True)
+
if dml_keyword.ttype == T.Keyword.DML:
return dml_keyword.normalized
- # Hmm, probably invalid syntax, so return unknown.
- return 'UNKNOWN'
+ # Hmm, probably invalid syntax, so return unknown.
return 'UNKNOWN'
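The resulting ``get_type`` behavior, including the CTE path (a sketch; the
WITH case relies on the grouping described in the comment above):

    import sqlparse

    sqlparse.parse('select 1')[0].get_type()                # 'SELECT'
    sqlparse.parse('-- hint\ndelete from t')[0].get_type()  # 'DELETE'
    sqlparse.parse('with x as (select 1) select * from x')[0].get_type()  # 'SELECT'
    sqlparse.parse('foo')[0].get_type()                     # 'UNKNOWN'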
@@ -499,33 +429,33 @@ class Identifier(TokenList):
def is_wildcard(self):
"""Return ``True`` if this identifier contains a wildcard."""
- token = self.token_next_by_type(0, T.Wildcard)
+ token = self.token_next_by(t=T.Wildcard)
return token is not None
def get_typecast(self):
"""Returns the typecast or ``None`` of this object as a string."""
- marker = self.token_next_match(0, T.Punctuation, '::')
+ marker = self.token_next_by(m=(T.Punctuation, '::'))
if marker is None:
return None
- next_ = self.token_next(self.token_index(marker), False)
+ next_ = self.token_next(marker, False)
if next_ is None:
return None
- return u(next_)
+ return next_.value
def get_ordering(self):
"""Returns the ordering or ``None`` as uppercase string."""
- ordering = self.token_next_by_type(0, T.Keyword.Order)
+ ordering = self.token_next_by(t=T.Keyword.Order)
if ordering is None:
return None
- return ordering.value.upper()
+ return ordering.normalized
def get_array_indices(self):
"""Returns an iterator of index token lists"""
- for tok in self.tokens:
- if isinstance(tok, SquareBrackets):
+ for token in self.tokens:
+ if isinstance(token, SquareBrackets):
# Use [1:-1] index to discard the square brackets
- yield tok.tokens[1:-1]
+ yield token.tokens[1:-1]
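The retouched ``Identifier`` helpers in use (a sketch; ``foo::integer`` groups
into an Identifier holding the ``::`` punctuation):

    import sqlparse
    from sqlparse.sql import Identifier

    stmt = sqlparse.parse('select foo::integer from bar')[0]
    ident = stmt.token_next_by(i=Identifier)
    ident.get_typecast()  # 'integer'
    ident.is_wildcard()   # False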
class IdentifierList(TokenList):
@@ -536,15 +466,15 @@ class IdentifierList(TokenList):
Whitespaces and punctuations are not included in this generator.
"""
- for x in self.tokens:
- if not x.is_whitespace() and not x.match(T.Punctuation, ','):
- yield x
+ for token in self.tokens:
+ if not (token.is_whitespace() or token.match(T.Punctuation, ',')):
+ yield token
class Parenthesis(TokenList):
"""Tokens between parenthesis."""
- M_OPEN = (T.Punctuation, '(')
- M_CLOSE = (T.Punctuation, ')')
+ M_OPEN = T.Punctuation, '('
+ M_CLOSE = T.Punctuation, ')'
@property
def _groupable_tokens(self):
@@ -553,8 +483,8 @@ class Parenthesis(TokenList):
class SquareBrackets(TokenList):
"""Tokens between square brackets"""
- M_OPEN = (T.Punctuation, '[')
- M_CLOSE = (T.Punctuation, ']')
+ M_OPEN = T.Punctuation, '['
+ M_CLOSE = T.Punctuation, ']'
@property
def _groupable_tokens(self):
@@ -567,14 +497,14 @@ class Assignment(TokenList):
class If(TokenList):
"""An 'if' clause with possible 'else if' or 'else' parts."""
- M_OPEN = (T.Keyword, 'IF')
- M_CLOSE = (T.Keyword, 'END IF')
+ M_OPEN = T.Keyword, 'IF'
+ M_CLOSE = T.Keyword, 'END IF'
class For(TokenList):
"""A 'FOR' loop."""
- M_OPEN = (T.Keyword, ('FOR', 'FOREACH'))
- M_CLOSE = (T.Keyword, 'END LOOP')
+ M_OPEN = T.Keyword, ('FOR', 'FOREACH')
+ M_CLOSE = T.Keyword, 'END LOOP'
class Comparison(TokenList):
@@ -598,15 +528,15 @@ class Comment(TokenList):
class Where(TokenList):
"""A WHERE clause."""
- M_OPEN = (T.Keyword, 'WHERE')
- M_CLOSE = (T.Keyword,
- ('ORDER', 'GROUP', 'LIMIT', 'UNION', 'EXCEPT', 'HAVING'))
+ M_OPEN = T.Keyword, 'WHERE'
+ M_CLOSE = T.Keyword, ('ORDER', 'GROUP', 'LIMIT', 'UNION', 'EXCEPT',
+ 'HAVING')
class Case(TokenList):
"""A CASE statement with one or more WHEN and possibly an ELSE part."""
- M_OPEN = (T.Keyword, 'CASE')
- M_CLOSE = (T.Keyword, 'END')
+ M_OPEN = T.Keyword, 'CASE'
+ M_CLOSE = T.Keyword, 'END'
def get_cases(self):
"""Returns a list of 2-tuples (condition, value).
@@ -659,15 +589,15 @@ class Function(TokenList):
def get_parameters(self):
"""Return a list of parameters."""
parenthesis = self.tokens[-1]
- for t in parenthesis.tokens:
- if imt(t, i=IdentifierList):
- return t.get_identifiers()
- elif imt(t, i=(Function, Identifier), t=T.Literal):
- return [t, ]
+ for token in parenthesis.tokens:
+ if imt(token, i=IdentifierList):
+ return token.get_identifiers()
+ elif imt(token, i=(Function, Identifier), t=T.Literal):
+ return [token, ]
return []
class Begin(TokenList):
"""A BEGIN/END block."""
- M_OPEN = (T.Keyword, 'BEGIN')
- M_CLOSE = (T.Keyword, 'END')
+ M_OPEN = T.Keyword, 'BEGIN'
+ M_CLOSE = T.Keyword, 'END'
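Finally, ``Function.get_parameters`` after the cleanup (a sketch; it assumes
``max(a, b)`` groups its arguments into an IdentifierList inside the trailing
Parenthesis, and that the Function is a direct child of the statement):

    import sqlparse
    from sqlparse.sql import Function

    stmt = sqlparse.parse('select max(a, b) from t')[0]
    func = stmt.token_next_by(i=Function)
    [p.value for p in func.get_parameters()]  # ['a', 'b']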