Diffstat (limited to 'sqlparse')
-rw-r--r--  sqlparse/__init__.py         |  8
-rw-r--r--  sqlparse/compat.py           | 39
-rw-r--r--  sqlparse/engine/grouping.py  | 17
-rw-r--r--  sqlparse/filters.py          | 42
-rw-r--r--  sqlparse/lexer.py            | 51
-rw-r--r--  sqlparse/sql.py              | 31
-rw-r--r--  sqlparse/tokens.py           |  9
-rw-r--r--  sqlparse/utils.py            | 75
8 files changed, 121 insertions, 151 deletions
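
The diff below applies one pattern throughout: Python 2 only builtins (unicode, basestring, the StringIO variants) are replaced with helpers from the new sqlparse.compat module. A minimal sketch of how calling code is expected to use those helpers on both interpreters; normalize() is a made-up example and not part of the commit, only compat.text_type and compat.string_types come from the diff:

from sqlparse import compat

def normalize(stmt):
    # string_types is (basestring,) on Python 2 and (str,) on Python 3.
    if not isinstance(stmt, compat.string_types):
        raise TypeError('expected a string, got %r' % type(stmt))
    # text_type is unicode on Python 2 and str on Python 3.
    return compat.text_type(stmt).strip()
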
diff --git a/sqlparse/__init__.py b/sqlparse/__init__.py
index 83bb684..d943956 100644
--- a/sqlparse/__init__.py
+++ b/sqlparse/__init__.py
@@ -8,15 +8,12 @@
 __version__ = '0.1.17-dev'
 
-
 # Setup namespace
+from sqlparse import compat
 from sqlparse import engine
 from sqlparse import filters
 from sqlparse import formatter
 
-# Deprecated in 0.1.5. Will be removed in 0.2.0
-from sqlparse.exceptions import SQLParseError
-
 
 def parse(sql, encoding=None):
     """Parse sql and return a list of statements.
@@ -67,7 +64,8 @@ def split(sql, encoding=None):
     """
     stack = engine.FilterStack()
     stack.split_statements = True
-    return [unicode(stmt).strip() for stmt in stack.run(sql, encoding)]
+    return [compat.text_type(stmt).strip()
+            for stmt in stack.run(sql, encoding)]
 
 
 from sqlparse.engine.filter import StatementFilter
diff --git a/sqlparse/compat.py b/sqlparse/compat.py
new file mode 100644
index 0000000..1849c13
--- /dev/null
+++ b/sqlparse/compat.py
@@ -0,0 +1,39 @@
+"""Python 2/3 compatibility.
+
+This module only exists to avoid a dependency on six
+for very trivial stuff. We only need to take care regarding
+string types and buffers.
+"""
+
+import sys
+
+PY2 = sys.version_info[0] == 2
+PY3 = sys.version_info[0] == 3
+
+if PY3:
+    text_type = str
+    string_types = (str,)
+    from io import StringIO
+
+    def u(s):
+        return s
+
+elif PY2:
+    text_type = unicode
+    string_types = (basestring,)
+    from StringIO import StringIO  # flake8: noqa
+
+    def u(s):
+        return unicode(s, 'unicode_escape')
+
+
+# Directly copied from six:
+def with_metaclass(meta, *bases):
+    """Create a base class with a metaclass."""
+    # This requires a bit of explanation: the basic idea is to make a dummy
+    # metaclass for one level of class instantiation that replaces itself with
+    # the actual metaclass.
+    class metaclass(meta):
+        def __new__(cls, name, this_bases, d):
+            return meta(name, bases, d)
+    return type.__new__(metaclass, 'temporary_class', (), {})
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py
index a317044..12ae385 100644
--- a/sqlparse/engine/grouping.py
+++ b/sqlparse/engine/grouping.py
@@ -5,11 +5,6 @@
 import itertools
 
 from sqlparse import sql
 from sqlparse import tokens as T
 
-try:
-    next
-except NameError:  # Python < 2.6
-    next = lambda i: i.next()
-
 
 def _group_left_right(tlist, ttype, value, cls,
                       check_right=lambda t: True,
@@ -116,7 +111,7 @@ def group_as(tlist):
     def _right_valid(token):
         # Currently limited to DML/DDL. Maybe additional more non SQL reserved
         # keywords should appear here (see issue8).
-        return not token.ttype in (T.DML, T.DDL)
+        return token.ttype not in (T.DML, T.DDL)
 
     def _left_valid(token):
         if token.ttype is T.Keyword and token.value in ('NULL',):
@@ -216,9 +211,10 @@ def group_identifier(tlist):
         if identifier_tokens and identifier_tokens[-1].ttype is T.Whitespace:
             identifier_tokens = identifier_tokens[:-1]
         if not (len(identifier_tokens) == 1
-                and (isinstance(identifier_tokens[0], (sql.Function, sql.Parenthesis))
-                     or identifier_tokens[0].ttype in (T.Literal.Number.Integer,
-                                                       T.Literal.Number.Float))):
+                and (isinstance(identifier_tokens[0], (sql.Function,
+                                                       sql.Parenthesis))
+                     or identifier_tokens[0].ttype in (
+                         T.Literal.Number.Integer, T.Literal.Number.Float))):
             group = tlist.group_tokens(sql.Identifier, identifier_tokens)
             idx = tlist.token_index(group) + 1
         else:
@@ -451,6 +447,5 @@ def group(tlist):
              group_if,
              group_for,
              group_foreach,
-             group_begin,
-             ]:
+             group_begin]:
         func(tlist)
diff --git a/sqlparse/filters.py b/sqlparse/filters.py
index 676344f..b187907 100644
--- a/sqlparse/filters.py
+++ b/sqlparse/filters.py
@@ -4,9 +4,8 @@
 import re
 from os.path import abspath, join
 
-from sqlparse import sql, tokens as T
+from sqlparse import compat, sql, tokens as T
 from sqlparse.engine import FilterStack
-from sqlparse.lexer import tokenize
 from sqlparse.pipeline import Pipeline
 from sqlparse.tokens import (Comment, Comparison, Keyword, Name, Punctuation,
                              String, Whitespace)
@@ -25,7 +24,7 @@ class _CaseFilter:
         if case is None:
             case = 'upper'
         assert case in ['lower', 'upper', 'capitalize']
-        self.convert = getattr(unicode, case)
+        self.convert = getattr(compat.text_type, case)
 
     def process(self, stack, stream):
         for ttype, value in stream:
@@ -52,20 +51,20 @@ class TruncateStringFilter:
 
     def __init__(self, width, char):
         self.width = max(width, 1)
-        self.char = unicode(char)
+        self.char = compat.text_type(char)
 
     def process(self, stack, stream):
         for ttype, value in stream:
             if ttype is T.Literal.String.Single:
                 if value[:2] == '\'\'':
                     inner = value[2:-2]
-                    quote = u'\'\''
+                    quote = compat.text_type('\'\'')
                 else:
                     inner = value[1:-1]
-                    quote = u'\''
+                    quote = compat.text_type('\'')
                 if len(inner) > self.width:
-                    value = u''.join((quote, inner[:self.width], self.char,
-                                      quote))
+                    value = compat.text_type('').join(
+                        (quote, inner[:self.width], self.char, quote))
                 yield ttype, value
@@ -143,7 +142,6 @@ class IncludeStatement:
 
             # Found file path to include
             if token_type in String.Symbol:
-#                if token_type in tokens.String.Symbol:
 
                 # Get path of file to include
                 path = join(self.dirpath, value[1:-1])
@@ -154,13 +152,14 @@ class IncludeStatement:
                     f.close()
 
                 # There was a problem loading the include file
-                except IOError, err:
+                except IOError as err:
                     # Raise the exception to the interpreter
                     if self.raiseexceptions:
                         raise
 
                     # Put the exception as a comment on the SQL code
-                    yield Comment, u'-- IOError: %s\n' % err
+                    yield Comment, compat.text_type(
+                        '-- IOError: %s\n' % err)
 
                 else:
                     # Create new FilterStack to parse readed file
@@ -171,13 +170,14 @@ class IncludeStatement:
                                                 self.raiseexceptions)
 
                     # Max recursion limit reached
-                    except ValueError, err:
+                    except ValueError as err:
                         # Raise the exception to the interpreter
                         if self.raiseexceptions:
                             raise
 
                         # Put the exception as a comment on the SQL code
-                        yield Comment, u'-- ValueError: %s\n' % err
+                        yield Comment, compat.text_type(
+                            '-- ValueError: %s\n' % err)
 
                     stack = FilterStack()
                     stack.preprocess.append(filtr)
@@ -300,7 +300,7 @@ class ReindentFilter:
             raise StopIteration
 
     def _get_offset(self, token):
-        raw = ''.join(map(unicode, self._flatten_up_to_token(token)))
+        raw = ''.join(map(compat.text_type, self._flatten_up_to_token(token)))
         line = raw.splitlines()[-1]
         # Now take current offset into account and return relative offset.
         full_offset = len(line) - len(self.char * (self.width * self.indent))
@@ -340,7 +340,7 @@ class ReindentFilter:
             if prev and prev.is_whitespace() and prev not in added:
                 tlist.tokens.pop(tlist.token_index(prev))
                 offset += 1
-            uprev = unicode(prev)
+            uprev = compat.text_type(prev)
             if (prev and (uprev.endswith('\n') or uprev.endswith('\r'))):
                 nl = tlist.token_next(token)
             else:
@@ -462,7 +462,7 @@ class ReindentFilter:
         self._process(stmt)
         if isinstance(stmt, sql.Statement):
             if self._last_stmt is not None:
-                if unicode(self._last_stmt).endswith('\n'):
+                if compat.text_type(self._last_stmt).endswith('\n'):
                     nl = '\n'
                 else:
                     nl = '\n\n'
@@ -491,10 +491,10 @@ class RightMarginFilter:
             else:
                 self.line = token.value.splitlines()[-1]
         elif (token.is_group()
-              and not token.__class__ in self.keep_together):
+              and token.__class__ not in self.keep_together):
             token.tokens = self._process(stack, token, token.tokens)
         else:
-            val = unicode(token)
+            val = compat.text_type(token)
             if len(self.line) + len(val) > self.width:
                 match = re.search('^ +', self.line)
                 if match is not None:
@@ -568,7 +568,7 @@ class ColumnsSelect:
 class SerializerUnicode:
 
     def process(self, stack, stmt):
-        raw = unicode(stmt)
+        raw = compat.text_type(stmt)
         lines = split_unquoted_newlines(raw)
         res = '\n'.join(line.rstrip() for line in lines)
         return res
@@ -578,7 +578,7 @@ def Tokens2Unicode(stream):
     result = ""
 
     for _, value in stream:
-        result += unicode(value)
+        result += compat.text_type(value)
 
     return result
@@ -600,7 +600,7 @@ class OutputFilter:
         else:
             varname = self.varname
 
-        has_nl = len(unicode(stmt).strip().splitlines()) > 1
+        has_nl = len(compat.text_type(stmt).strip().splitlines()) > 1
         stmt.tokens = self._process(stmt.tokens, varname, has_nl)
         return stmt
diff --git a/sqlparse/lexer.py b/sqlparse/lexer.py
index fd29f5c..7ce6d36 100644
--- a/sqlparse/lexer.py
+++ b/sqlparse/lexer.py
@@ -15,31 +15,23 @@
 import re
 import sys
 
+from sqlparse import compat
 from sqlparse import tokens
+from sqlparse.compat import StringIO
 from sqlparse.keywords import KEYWORDS, KEYWORDS_COMMON
-from cStringIO import StringIO
 
 
 class include(str):
     pass
 
 
-class combined(tuple):
-    """Indicates a state combined from multiple states."""
-
-    def __new__(cls, *args):
-        return tuple.__new__(cls, args)
-
-    def __init__(self, *args):
-        # tuple.__init__ doesn't do anything
-        pass
-
-
 def is_keyword(value):
     test = value.upper()
     return KEYWORDS_COMMON.get(test, KEYWORDS.get(test, tokens.Name)), value
 
 
+# TODO(andi): Can this be removed? If so, add_filter and Lexer.filters
+# should be removed too.
 def apply_filters(stream, filters, lexer=None):
     """
     Use this method to apply an iterable of filters to
@@ -81,14 +73,14 @@ class LexerMeta(type):
 
             try:
                 rex = re.compile(tdef[0], rflags).match
-            except Exception, err:
+            except Exception as err:
                 raise ValueError(("uncompilable regex %r in state"
                                   " %r of %r: %s"
                                   % (tdef[0], state, cls, err)))
 
             assert type(tdef[1]) is tokens._TokenType or callable(tdef[1]), \
-                   ('token type must be simple type or callable, not %r'
-                    % (tdef[1],))
+                ('token type must be simple type or callable, not %r'
+                 % (tdef[1],))
 
             if len(tdef) == 2:
                 new_state = None
@@ -106,24 +98,12 @@ class LexerMeta(type):
                         new_state = -int(tdef2[5:])
                     else:
                         assert False, 'unknown new state %r' % tdef2
-                elif isinstance(tdef2, combined):
-                    # combine a new state from existing ones
-                    new_state = '_tmp_%d' % cls._tmpname
-                    cls._tmpname += 1
-                    itokens = []
-                    for istate in tdef2:
-                        assert istate != state, \
-                               'circular state ref %r' % istate
-                        itokens.extend(cls._process_state(unprocessed,
-                                                          processed, istate))
-                    processed[new_state] = itokens
-                    new_state = (new_state,)
                 elif isinstance(tdef2, tuple):
                     # push more than one state
                     for state in tdef2:
                         assert (state in unprocessed or
                                 state in ('#pop', '#push')), \
-                               'unknown new state ' + state
+                            'unknown new state ' + state
                     new_state = tdef2
                 else:
                     assert False, 'unknown new state def %r' % tdef2
@@ -134,7 +114,6 @@ class LexerMeta(type):
             cls._all_tokens = {}
             cls._tmpname = 0
         processed = cls._all_tokens[cls.__name__] = {}
-        #tokendefs = tokendefs or cls.tokens[name]
         for state in cls.tokens.keys():
             cls._process_state(cls.tokens, processed, state)
         return processed
@@ -152,9 +131,7 @@ class LexerMeta(type):
         return type.__call__(cls, *args, **kwds)
 
 
-class Lexer(object):
-
-    __metaclass__ = LexerMeta
+class Lexer(compat.with_metaclass(LexerMeta)):
 
     encoding = 'utf-8'
     stripall = False
@@ -235,8 +212,8 @@ class Lexer(object):
         if self.encoding == 'guess':
             try:
                 text = text.decode('utf-8')
-                if text.startswith(u'\ufeff'):
-                    text = text[len(u'\ufeff'):]
+                if text.startswith(compat.text_type('\ufeff')):
+                    text = text[len(compat.text_type('\ufeff')):]
             except UnicodeDecodeError:
                 text = text.decode('latin1')
         else:
@@ -258,13 +235,13 @@ class Lexer(object):
         Also preprocess the text, i.e. expand tabs and strip it if wanted and
         applies registered filters.
         """
-        if isinstance(text, basestring):
+        if isinstance(text, compat.string_types):
             if self.stripall:
                 text = text.strip()
             elif self.stripnl:
                 text = text.strip('\n')
 
-            if sys.version_info[0] < 3 and isinstance(text, unicode):
+            if compat.PY2 and isinstance(text, compat.text_type):
                 text = StringIO(text.encode('utf-8'))
                 self.encoding = 'utf-8'
             else:
@@ -342,7 +319,7 @@ class Lexer(object):
                     pos += 1
                     statestack = ['root']
                     statetokens = tokendefs['root']
-                    yield pos, tokens.Text, u'\n'
+                    yield pos, tokens.Text, compat.text_type('\n')
                     continue
                 yield pos, tokens.Error, text[pos]
                 pos += 1
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index 5ecfbdc..8601537 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -5,6 +5,7 @@
 import re
 import sys
 
+from sqlparse import compat
 from sqlparse import tokens as T
 
 
@@ -32,7 +33,7 @@ class Token(object):
         if sys.version_info[0] == 3:
             return self.value
         else:
-            return unicode(self).encode('utf-8')
+            return compat.text_type(self).encode('utf-8')
 
     def __repr__(self):
         short = self._get_repr_value()
@@ -51,15 +52,15 @@ class Token(object):
 
         .. deprecated:: 0.1.5
            Use ``unicode(token)`` (for Python 3: ``str(token)``) instead.
""" - return unicode(self) + return compat.text_type(self) def _get_repr_name(self): return str(self.ttype).split('.')[-1] def _get_repr_value(self): - raw = unicode(self) + raw = compat.text_type(self) if len(raw) > 7: - raw = raw[:6] + u'...' + raw = raw[:6] + compat.text_type('...') return re.sub('\s+', ' ', raw) def flatten(self): @@ -83,7 +84,7 @@ class Token(object): return type_matched if regex: - if isinstance(values, basestring): + if isinstance(values, compat.string_types): values = set([values]) if self.ttype is T.Keyword: @@ -96,7 +97,7 @@ class Token(object): return True return False - if isinstance(values, basestring): + if isinstance(values, compat.string_types): if self.is_keyword: return values.upper() == self.normalized return values == self.value @@ -172,7 +173,7 @@ class TokenList(Token): if sys.version_info[0] == 3: return ''.join(x.value for x in self.flatten()) else: - return ''.join(unicode(x) for x in self.flatten()) + return ''.join(compat.text_type(x) for x in self.flatten()) def _get_repr_name(self): return self.__class__.__name__ @@ -185,9 +186,9 @@ class TokenList(Token): pre = ' +-' else: pre = ' | ' - print '%s%s%d %s \'%s\'' % (indent, pre, idx, + print('%s%s%d %s \'%s\'' % (indent, pre, idx, token._get_repr_name(), - token._get_repr_value()) + token._get_repr_value())) if (token.is_group() and (max_depth is None or depth < max_depth)): token._pprint_tree(max_depth, depth + 1) @@ -220,18 +221,10 @@ class TokenList(Token): else: yield token -# def __iter__(self): -# return self -# -# def next(self): -# for token in self.tokens: -# yield token - def is_group(self): return True def get_sublists(self): -# return [x for x in self.tokens if isinstance(x, TokenList)] for x in self.tokens: if isinstance(x, TokenList): yield x @@ -285,7 +278,7 @@ class TokenList(Token): if not isinstance(idx, int): idx = self.token_index(idx) - for n in xrange(idx, len(self.tokens)): + for n in range(idx, len(self.tokens)): token = self.tokens[n] if token.match(ttype, value, regex): return token @@ -510,7 +503,7 @@ class Identifier(TokenList): next_ = self.token_next(self.token_index(marker), False) if next_ is None: return None - return unicode(next_) + return compat.text_type(next_) def get_ordering(self): """Returns the ordering or ``None`` as uppercase string.""" diff --git a/sqlparse/tokens.py b/sqlparse/tokens.py index 01a9b89..53c31ce 100644 --- a/sqlparse/tokens.py +++ b/sqlparse/tokens.py @@ -13,15 +13,6 @@ class _TokenType(tuple): parent = None - def split(self): - buf = [] - node = self - while node is not None: - buf.append(node) - node = node.parent - buf.reverse() - return buf - def __contains__(self, val): return val is not None and (self is val or val[:len(self)] == self) diff --git a/sqlparse/utils.py b/sqlparse/utils.py index 3a49ac2..7595e9d 100644 --- a/sqlparse/utils.py +++ b/sqlparse/utils.py @@ -5,65 +5,42 @@ Created on 17/05/2012 ''' import re +from collections import OrderedDict -try: - from collections import OrderedDict -except ImportError: - OrderedDict = None +class Cache(OrderedDict): + """Cache with LRU algorithm using an OrderedDict as basis.""" + def __init__(self, maxsize=100): + OrderedDict.__init__(self) -if OrderedDict: - class Cache(OrderedDict): - """Cache with LRU algorithm using an OrderedDict as basis - """ - def __init__(self, maxsize=100): - OrderedDict.__init__(self) + self._maxsize = maxsize - self._maxsize = maxsize + def __getitem__(self, key, *args, **kwargs): + # Get the key and remove it from the cache, or raise KeyError + value = 
OrderedDict.__getitem__(self, key) + del self[key] - def __getitem__(self, key, *args, **kwargs): - # Get the key and remove it from the cache, or raise KeyError - value = OrderedDict.__getitem__(self, key) - del self[key] - - # Insert the (key, value) pair on the front of the cache - OrderedDict.__setitem__(self, key, value) - - # Return the value from the cache - return value - - def __setitem__(self, key, value, *args, **kwargs): - # Key was inserted before, remove it so we put it at front later - if key in self: - del self[key] + # Insert the (key, value) pair on the front of the cache + OrderedDict.__setitem__(self, key, value) - # Too much items on the cache, remove the least recent used - elif len(self) >= self._maxsize: - self.popitem(False) + # Return the value from the cache + return value - # Insert the (key, value) pair on the front of the cache - OrderedDict.__setitem__(self, key, value, *args, **kwargs) - -else: - class Cache(dict): - """Cache that reset when gets full - """ - def __init__(self, maxsize=100): - dict.__init__(self) - - self._maxsize = maxsize + def __setitem__(self, key, value, *args, **kwargs): + # Key was inserted before, remove it so we put it at front later + if key in self: + del self[key] - def __setitem__(self, key, value, *args, **kwargs): - # Reset the cache if we have too much cached entries and start over - if len(self) >= self._maxsize: - self.clear() + # Too much items on the cache, remove the least recent used + elif len(self) >= self._maxsize: + self.popitem(False) - # Insert the (key, value) pair on the front of the cache - dict.__setitem__(self, key, value, *args, **kwargs) + # Insert the (key, value) pair on the front of the cache + OrderedDict.__setitem__(self, key, value, *args, **kwargs) def memoize_generator(func): - """Memoize decorator for generators + """Memoize decorator for generators. Store `func` results in a cache according to their arguments as 'memoize' does but instead this works on decorators instead of regular functions. @@ -73,7 +50,6 @@ def memoize_generator(func): cache = Cache() def wrapped_func(*args, **kwargs): -# params = (args, kwargs) params = (args, tuple(sorted(kwargs.items()))) # Look if cached @@ -120,6 +96,7 @@ SPLIT_REGEX = re.compile(r""" LINE_MATCH = re.compile(r'(\r\n|\r|\n)') + def split_unquoted_newlines(text): """Split a string on all unquoted newlines. @@ -134,4 +111,4 @@ def split_unquoted_newlines(text): outputlines.append('') else: outputlines[-1] += line - return outputlines
\ No newline at end of file
+    return outputlines
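
The Lexer class drops the Python 2 only __metaclass__ attribute in favour of compat.with_metaclass(), which builds a temporary base class so the same declaration works on both interpreters. A small sketch of the pattern; CountingMeta is a made-up stand-in for the real LexerMeta:

from sqlparse.compat import with_metaclass

class CountingMeta(type):
    """Toy metaclass: tags every class it creates with a counter attribute."""
    def __new__(mcs, name, bases, namespace):
        cls = super(CountingMeta, mcs).__new__(mcs, name, bases, namespace)
        cls.created = 0
        return cls

# Equivalent to `class Base(metaclass=CountingMeta)` on Python 3
# or setting `__metaclass__ = CountingMeta` in the class body on Python 2.
class Base(with_metaclass(CountingMeta)):
    pass

assert type(Base) is CountingMeta
assert Base.created == 0
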
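
utils.Cache now always subclasses collections.OrderedDict: __getitem__ re-inserts the accessed key at the end of the ordering, and __setitem__ evicts the oldest entry with popitem(False) once maxsize is reached, which gives least-recently-used behaviour. A short usage sketch; the keys and maxsize=2 are arbitrary example values:

from sqlparse.utils import Cache

cache = Cache(maxsize=2)
cache['a'] = 1
cache['b'] = 2
cache['a']      # reading 'a' marks it as most recently used
cache['c'] = 3  # cache is full, so the least recently used key ('b') is dropped
assert 'b' not in cache
assert 'a' in cache and 'c' in cache
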