author    Vik <vmuriart@users.noreply.github.com>  2016-06-06 06:29:25 -0700
committer Vik <vmuriart@users.noreply.github.com>  2016-06-06 06:29:25 -0700
commit    b9d81ac4fe49114f57dc33c0d635f99ff56e62f2 (patch)
tree      88642eeb84d318511191a822fd781b44e1d63df1 /sqlparse
parent    c6a5e7ac2a5ecc993f4e5292ab16e6df6b84f26c (diff)
parent    5747015634a39191511de8db576f2cd0aa5eafc9 (diff)
download  sqlparse-b9d81ac4fe49114f57dc33c0d635f99ff56e62f2.tar.gz
Merge pull request #251 from andialbrecht/filters_sql
Update Filters sql
Diffstat (limited to 'sqlparse')
-rw-r--r--  sqlparse/__init__.py          7
-rw-r--r--  sqlparse/compat.py           10
-rw-r--r--  sqlparse/engine/__init__.py  48
-rw-r--r--  sqlparse/engine/filter.py   100
-rw-r--r--  sqlparse/engine/grouping.py   5
-rw-r--r--  sqlparse/filters.py         355
-rw-r--r--  sqlparse/functions.py        44
-rw-r--r--  sqlparse/pipeline.py         31
-rw-r--r--  sqlparse/sql.py             292
-rw-r--r--  sqlparse/utils.py            71
10 files changed, 248 insertions, 715 deletions
diff --git a/sqlparse/__init__.py b/sqlparse/__init__.py
index 2943997..cb83a71 100644
--- a/sqlparse/__init__.py
+++ b/sqlparse/__init__.py
@@ -66,11 +66,4 @@ def split(sql, encoding=None):
:returns: A list of strings.
"""
stack = engine.FilterStack()
- stack.split_statements = True
return [u(stmt).strip() for stmt in stack.run(sql, encoding)]
-
-
-def split2(stream):
- from sqlparse.engine.filter import StatementFilter
- splitter = StatementFilter()
- return list(splitter.process(None, stream))
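With the private split2() helper removed, statement splitting goes exclusively through the public API. A minimal usage sketch (the sample SQL is illustrative):

    import sqlparse

    # split() drives the FilterStack shown below and returns plain strings
    statements = sqlparse.split("select * from foo; select 1 from dual;")
    for stmt in statements:
        print(stmt)
    # -> select * from foo;
    # -> select 1 from dual;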
diff --git a/sqlparse/compat.py b/sqlparse/compat.py
index 0226a00..0defd86 100644
--- a/sqlparse/compat.py
+++ b/sqlparse/compat.py
@@ -25,6 +25,10 @@ if PY3:
return str(s)
+ def unicode_compatible(cls):
+ return cls
+
+
text_type = str
string_types = (str,)
from io import StringIO
@@ -39,6 +43,12 @@ elif PY2:
return unicode(s, encoding)
+ def unicode_compatible(cls):
+ cls.__unicode__ = cls.__str__
+ cls.__str__ = lambda x: x.__unicode__().encode('utf-8')
+ return cls
+
+
text_type = unicode
string_types = (basestring,)
from StringIO import StringIO
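The new unicode_compatible() class decorator centralizes the text handling that Token and TokenList previously implemented with inline version checks. A hedged sketch of its intended use (the example class is illustrative, not part of sqlparse):

    from sqlparse.compat import unicode_compatible

    @unicode_compatible
    class Label(object):
        def __init__(self, text):
            self.text = text

        def __str__(self):
            # written Python 3 style: return text, never bytes
            return self.text

    # On Python 3 the decorator is a no-op; on Python 2 it moves __str__
    # to __unicode__ and rebinds __str__ to return UTF-8 encoded bytes.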
diff --git a/sqlparse/engine/__init__.py b/sqlparse/engine/__init__.py
index 1c2bf09..7f00c57 100644
--- a/sqlparse/engine/__init__.py
+++ b/sqlparse/engine/__init__.py
@@ -13,12 +13,10 @@ from sqlparse.engine.filter import StatementFilter
class FilterStack(object):
-
def __init__(self):
self.preprocess = []
self.stmtprocess = []
self.postprocess = []
- self.split_statements = False
self._grouping = False
def enable_grouping(self):
@@ -27,42 +25,20 @@ class FilterStack(object):
def run(self, sql, encoding=None):
stream = lexer.tokenize(sql, encoding)
# Process token stream
- if self.preprocess:
- for filter_ in self.preprocess:
- stream = filter_.process(self, stream)
-
- if (self.stmtprocess or self.postprocess or
- self.split_statements or self._grouping):
- splitter = StatementFilter()
- stream = splitter.process(self, stream)
-
- if self._grouping:
-
- def _group(stream):
- for stmt in stream:
- grouping.group(stmt)
- yield stmt
- stream = _group(stream)
+ for filter_ in self.preprocess:
+ stream = filter_.process(stream)
- if self.stmtprocess:
+ stream = StatementFilter().process(stream)
- def _run1(stream):
- ret = []
- for stmt in stream:
- for filter_ in self.stmtprocess:
- filter_.process(self, stmt)
- ret.append(stmt)
- return ret
- stream = _run1(stream)
+ # Output: Stream processed Statements
+ for stmt in stream:
+ if self._grouping:
+ stmt = grouping.group(stmt)
- if self.postprocess:
+ for filter_ in self.stmtprocess:
+ filter_.process(stmt)
- def _run2(stream):
- for stmt in stream:
- stmt.tokens = list(stmt.flatten())
- for filter_ in self.postprocess:
- stmt = filter_.process(self, stmt)
- yield stmt
- stream = _run2(stream)
+ for filter_ in self.postprocess:
+ stmt = filter_.process(stmt)
- return stream
+ yield stmt
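FilterStack.run() is now a generator that yields one processed statement at a time instead of materializing intermediate lists, and filters no longer receive the stack in process(). A sketch of driving the stack directly (the filter selection is illustrative; sqlparse.format() assembles a similar stack internally):

    from sqlparse import filters
    from sqlparse.engine import FilterStack

    stack = FilterStack()
    stack.enable_grouping()
    stack.stmtprocess.append(filters.StripWhitespaceFilter())
    stack.postprocess.append(filters.SerializerUnicode())

    # each iteration yields one fully processed statement
    for statement in stack.run("select 1 ;   select   2 ;"):
        print(statement)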
diff --git a/sqlparse/engine/filter.py b/sqlparse/engine/filter.py
index 71020e7..ea2033a 100644
--- a/sqlparse/engine/filter.py
+++ b/sqlparse/engine/filter.py
@@ -5,113 +5,119 @@
# This module is part of python-sqlparse and is released under
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
-from sqlparse.sql import Statement, Token
-from sqlparse import tokens as T
+from sqlparse import sql, tokens as T
class StatementFilter(object):
- "Filter that split stream at individual statements"
+ """Filter that split stream at individual statements"""
def __init__(self):
- self._in_declare = False
- self._in_dbldollar = False
- self._is_create = False
- self._begin_depth = 0
+ self._reset()
def _reset(self):
- "Set the filter attributes to its default values"
+ """Set the filter attributes to its default values"""
self._in_declare = False
self._in_dbldollar = False
self._is_create = False
self._begin_depth = 0
+ self.consume_ws = False
+ self.tokens = []
+ self.level = 0
+
def _change_splitlevel(self, ttype, value):
- "Get the new split level (increase, decrease or remain equal)"
+ """Get the new split level (increase, decrease or remain equal)"""
# PostgreSQL
- if ttype == T.Name.Builtin \
- and value.startswith('$') and value.endswith('$'):
+ if ttype == T.Name.Builtin and value[0] == '$' and value[-1] == '$':
+
+ # 2nd dbldollar found. $quote$ completed
+ # decrease level
if self._in_dbldollar:
self._in_dbldollar = False
return -1
else:
self._in_dbldollar = True
return 1
+
+ # if inside $$, everything is part of the function definition.
+ # Nothing inside can create a new statement
elif self._in_dbldollar:
return 0
# ANSI
+ # if it's a normal (non-keyword) token, return
+ # wouldn't parenthesis increase/decrease a level?
+ # no, a new statement can't start inside a parenthesis
if ttype not in T.Keyword:
return 0
+ # Everything after here is ttype = T.Keyword
+ # Also note: once one of the keyword branches below is entered, we are
+ # done and simply return
unified = value.upper()
+ # three keywords begin with CREATE, but only one of them is DDL
+ # a DDL CREATE, though, can contain more words such as "OR REPLACE"
+ if ttype is T.Keyword.DDL and unified.startswith('CREATE'):
+ self._is_create = True
+ return 0
+
+ # can have a nested DECLARE inside of a BEGIN...
if unified == 'DECLARE' and self._is_create and self._begin_depth == 0:
self._in_declare = True
return 1
if unified == 'BEGIN':
self._begin_depth += 1
- if self._in_declare or self._is_create:
+ if self._is_create:
# FIXME(andi): This makes no sense.
return 1
return 0
- if unified in ('END IF', 'END FOR', 'END WHILE'):
- return -1
-
+ # Should this respect a preceding BEGIN?
+ # In CASE ... WHEN ... END this results in a split level -1.
+ # Would having multiple CASE WHEN END and an assignment operator
+ # cause the statement to be cut off prematurely?
if unified == 'END':
- # Should this respect a preceeding BEGIN?
- # In CASE ... WHEN ... END this results in a split level -1.
self._begin_depth = max(0, self._begin_depth - 1)
return -1
- if ttype is T.Keyword.DDL and unified.startswith('CREATE'):
- self._is_create = True
- return 0
-
- if unified in ('IF', 'FOR', 'WHILE') \
- and self._is_create and self._begin_depth > 0:
+ if (unified in ('IF', 'FOR', 'WHILE') and
+ self._is_create and self._begin_depth > 0):
return 1
+ if unified in ('END IF', 'END FOR', 'END WHILE'):
+ return -1
+
# Default
return 0
- def process(self, stack, stream):
- "Process the stream"
- consume_ws = False
- splitlevel = 0
- stmt = None
- stmt_tokens = []
+ def process(self, stream):
+ """Process the stream"""
+ EOS_TTYPE = T.Whitespace, T.Comment.Single
# Run over all stream tokens
for ttype, value in stream:
# Yield token if we finished a statement and there's no whitespaces
- if consume_ws and ttype not in (T.Whitespace, T.Comment.Single):
- stmt.tokens = stmt_tokens
- yield stmt
+ # It will count a newline token as non-whitespace. In this context,
+ # whitespace ignores newlines.
+ # Why don't multi-line comments also count?
+ if self.consume_ws and ttype not in EOS_TTYPE:
+ yield sql.Statement(self.tokens)
# Reset filter and prepare to process next statement
self._reset()
- consume_ws = False
- splitlevel = 0
- stmt = None
-
- # Create a new statement if we are not currently in one of them
- if stmt is None:
- stmt = Statement()
- stmt_tokens = []
# Change current split level (increase, decrease or remain equal)
- splitlevel += self._change_splitlevel(ttype, value)
+ self.level += self._change_splitlevel(ttype, value)
# Append the token to the current statement
- stmt_tokens.append(Token(ttype, value))
+ self.tokens.append(sql.Token(ttype, value))
# Check if we get the end of a statement
- if splitlevel <= 0 and ttype is T.Punctuation and value == ';':
- consume_ws = True
+ if self.level <= 0 and ttype is T.Punctuation and value == ';':
+ self.consume_ws = True
# Yield pending statement (if any)
- if stmt is not None:
- stmt.tokens = stmt_tokens
- yield stmt
+ if self.tokens:
+ yield sql.Statement(self.tokens)
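StatementFilter.process() now takes only the token stream and keeps its splitting state (consume_ws, tokens, level) on the instance. A minimal sketch of using it standalone, which FilterStack.run() normally does for you:

    from sqlparse import lexer
    from sqlparse.engine.filter import StatementFilter

    stream = lexer.tokenize("insert into t values (1); select * from t;")
    for stmt in StatementFilter().process(stream):
        # each stmt is a sql.Statement built from the accumulated tokens
        print(repr(str(stmt)))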
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py
index 0ac1cb3..c680995 100644
--- a/sqlparse/engine/grouping.py
+++ b/sqlparse/engine/grouping.py
@@ -266,7 +266,7 @@ def align_comments(tlist):
token = tlist.token_next_by(i=sql.Comment, idx=token)
-def group(tlist):
+def group(stmt):
for func in [
group_comments,
group_brackets,
@@ -291,4 +291,5 @@ def group(tlist):
group_foreach,
group_begin,
]:
- func(tlist)
+ func(stmt)
+ return stmt
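group() returning the statement lets callers use it in expression position, as FilterStack.run() now does. A sketch of standalone use (illustrative only):

    from sqlparse import lexer
    from sqlparse.engine import grouping
    from sqlparse.engine.filter import StatementFilter

    stream = lexer.tokenize("select a, b from t;")
    stmt = next(StatementFilter().process(stream))
    stmt = grouping.group(stmt)  # same object, now with grouped sub-lists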
diff --git a/sqlparse/filters.py b/sqlparse/filters.py
index 1cb2f16..95ac74c 100644
--- a/sqlparse/filters.py
+++ b/sqlparse/filters.py
@@ -7,15 +7,8 @@
import re
-from os.path import abspath, join
-
from sqlparse import sql, tokens as T
-from sqlparse.compat import u, text_type
-from sqlparse.engine import FilterStack
-from sqlparse.pipeline import Pipeline
-from sqlparse.tokens import (Comment, Comparison, Keyword, Name, Punctuation,
- String, Whitespace)
-from sqlparse.utils import memoize_generator
+from sqlparse.compat import text_type
from sqlparse.utils import split_unquoted_newlines
@@ -23,16 +16,13 @@ from sqlparse.utils import split_unquoted_newlines
# token process
class _CaseFilter(object):
-
ttype = None
def __init__(self, case=None):
- if case is None:
- case = 'upper'
- assert case in ['lower', 'upper', 'capitalize']
+ case = case or 'upper'
self.convert = getattr(text_type, case)
- def process(self, stack, stream):
+ def process(self, stream):
for ttype, value in stream:
if ttype in self.ttype:
value = self.convert(value)
@@ -44,165 +34,42 @@ class KeywordCaseFilter(_CaseFilter):
class IdentifierCaseFilter(_CaseFilter):
- ttype = (T.Name, T.String.Symbol)
+ ttype = T.Name, T.String.Symbol
- def process(self, stack, stream):
+ def process(self, stream):
for ttype, value in stream:
- if ttype in self.ttype and not value.strip()[0] == '"':
+ if ttype in self.ttype and value.strip()[0] != '"':
value = self.convert(value)
yield ttype, value
class TruncateStringFilter(object):
-
def __init__(self, width, char):
- self.width = max(width, 1)
- self.char = u(char)
+ self.width = width
+ self.char = char
- def process(self, stack, stream):
+ def process(self, stream):
for ttype, value in stream:
- if ttype is T.Literal.String.Single:
- if value[:2] == '\'\'':
- inner = value[2:-2]
- quote = u'\'\''
- else:
- inner = value[1:-1]
- quote = u'\''
- if len(inner) > self.width:
- value = u''.join((quote, inner[:self.width], self.char,
- quote))
- yield ttype, value
-
-
-class GetComments(object):
- """Get the comments from a stack"""
- def process(self, stack, stream):
- for token_type, value in stream:
- if token_type in Comment:
- yield token_type, value
-
-
-class StripComments(object):
- """Strip the comments from a stack"""
- def process(self, stack, stream):
- for token_type, value in stream:
- if token_type not in Comment:
- yield token_type, value
-
-
-def StripWhitespace(stream):
- "Strip the useless whitespaces from a stream leaving only the minimal ones"
- last_type = None
- has_space = False
- ignore_group = frozenset((Comparison, Punctuation))
-
- for token_type, value in stream:
- # We got a previous token (not empty first ones)
- if last_type:
- if token_type in Whitespace:
- has_space = True
- continue
-
- # Ignore first empty spaces and dot-commas
- elif token_type in (Whitespace, Whitespace.Newline, ignore_group):
- continue
-
- # Yield a whitespace if it can't be ignored
- if has_space:
- if not ignore_group.intersection((last_type, token_type)):
- yield Whitespace, ' '
- has_space = False
-
- # Yield the token and set its type for checking with the next one
- yield token_type, value
- last_type = token_type
-
-
-class IncludeStatement(object):
- """Filter that enable a INCLUDE statement"""
-
- def __init__(self, dirpath=".", maxrecursive=10, raiseexceptions=False):
- if maxrecursive <= 0:
- raise ValueError('Max recursion limit reached')
-
- self.dirpath = abspath(dirpath)
- self.maxRecursive = maxrecursive
- self.raiseexceptions = raiseexceptions
-
- self.detected = False
-
- @memoize_generator
- def process(self, stack, stream):
- # Run over all tokens in the stream
- for token_type, value in stream:
- # INCLUDE statement found, set detected mode
- if token_type in Name and value.upper() == 'INCLUDE':
- self.detected = True
+ if ttype != T.Literal.String.Single:
+ yield ttype, value
continue
- # INCLUDE statement was found, parse it
- elif self.detected:
- # Omit whitespaces
- if token_type in Whitespace:
- continue
-
- # Found file path to include
- if token_type in String.Symbol:
- # Get path of file to include
- path = join(self.dirpath, value[1:-1])
-
- try:
- f = open(path)
- raw_sql = f.read()
- f.close()
-
- # There was a problem loading the include file
- except IOError as err:
- # Raise the exception to the interpreter
- if self.raiseexceptions:
- raise
-
- # Put the exception as a comment on the SQL code
- yield Comment, u'-- IOError: %s\n' % err
-
- else:
- # Create new FilterStack to parse readed file
- # and add all its tokens to the main stack recursively
- try:
- filtr = IncludeStatement(self.dirpath,
- self.maxRecursive - 1,
- self.raiseexceptions)
-
- # Max recursion limit reached
- except ValueError as err:
- # Raise the exception to the interpreter
- if self.raiseexceptions:
- raise
-
- # Put the exception as a comment on the SQL code
- yield Comment, u'-- ValueError: %s\n' % err
-
- stack = FilterStack()
- stack.preprocess.append(filtr)
-
- for tv in stack.run(raw_sql):
- yield tv
-
- # Set normal mode
- self.detected = False
-
- # Don't include any token while in detected mode
- continue
+ if value[:2] == "''":
+ inner = value[2:-2]
+ quote = "''"
+ else:
+ inner = value[1:-1]
+ quote = "'"
- # Normal token
- yield token_type, value
+ if len(inner) > self.width:
+ value = ''.join((quote, inner[:self.width], self.char, quote))
+ yield ttype, value
# ----------------------
# statement process
class StripCommentsFilter(object):
-
def _get_next_comment(self, tlist):
# TODO(andi) Comment types should be unified, see related issue38
token = tlist.token_next_by(i=sql.Comment, t=T.Comment)
@@ -212,8 +79,8 @@ class StripCommentsFilter(object):
token = self._get_next_comment(tlist)
while token:
tidx = tlist.token_index(token)
- prev = tlist.token_prev(tidx, False)
- next_ = tlist.token_next(tidx, False)
+ prev = tlist.token_prev(tidx, skip_ws=False)
+ next_ = tlist.token_next(tidx, skip_ws=False)
# Replace by whitespace if prev and next exist and if they're not
# whitespaces. This doesn't apply if prev or next is a paranthesis.
if (prev is not None and next_ is not None
@@ -225,13 +92,12 @@ class StripCommentsFilter(object):
tlist.tokens.pop(tidx)
token = self._get_next_comment(tlist)
- def process(self, stack, stmt):
- [self.process(stack, sgroup) for sgroup in stmt.get_sublists()]
+ def process(self, stmt):
+ [self.process(sgroup) for sgroup in stmt.get_sublists()]
self._process(stmt)
class StripWhitespaceFilter(object):
-
def _stripws(self, tlist):
func_name = '_stripws_%s' % tlist.__class__.__name__.lower()
func = getattr(self, func_name, self._stripws_default)
@@ -253,14 +119,10 @@ class StripWhitespaceFilter(object):
# Removes newlines before commas, see issue140
last_nl = None
for token in tlist.tokens[:]:
- if token.ttype is T.Punctuation \
- and token.value == ',' \
- and last_nl is not None:
+ if last_nl and token.ttype is T.Punctuation and token.value == ',':
tlist.tokens.remove(last_nl)
- if token.is_whitespace():
- last_nl = token
- else:
- last_nl = None
+
+ last_nl = token if token.is_whitespace() else None
return self._stripws_default(tlist)
def _stripws_parenthesis(self, tlist):
@@ -270,20 +132,14 @@ class StripWhitespaceFilter(object):
tlist.tokens.pop(-2)
self._stripws_default(tlist)
- def process(self, stack, stmt, depth=0):
- [self.process(stack, sgroup, depth + 1)
- for sgroup in stmt.get_sublists()]
+ def process(self, stmt, depth=0):
+ [self.process(sgroup, depth + 1) for sgroup in stmt.get_sublists()]
self._stripws(stmt)
- if (
- depth == 0
- and stmt.tokens
- and stmt.tokens[-1].is_whitespace()
- ):
+ if depth == 0 and stmt.tokens and stmt.tokens[-1].is_whitespace():
stmt.tokens.pop(-1)
class ReindentFilter(object):
-
def __init__(self, width=2, char=' ', line_width=None, wrap_after=0):
self.width = width
self.char = char
@@ -327,8 +183,7 @@ class ReindentFilter(object):
'SET', 'BETWEEN', 'EXCEPT', 'HAVING')
def _next_token(i):
- t = tlist.token_next_match(i, T.Keyword, split_words,
- regex=True)
+ t = tlist.token_next_by(m=(T.Keyword, split_words, True), idx=i)
if t and t.value.upper() == 'BETWEEN':
t = _next_token(tlist.token_index(t) + 1)
if t and t.value.upper() == 'AND':
@@ -339,13 +194,13 @@ class ReindentFilter(object):
token = _next_token(idx)
added = set()
while token:
- prev = tlist.token_prev(tlist.token_index(token), False)
+ prev = tlist.token_prev(token, skip_ws=False)
offset = 1
if prev and prev.is_whitespace() and prev not in added:
tlist.tokens.pop(tlist.token_index(prev))
offset += 1
- uprev = u(prev)
- if (prev and (uprev.endswith('\n') or uprev.endswith('\r'))):
+ uprev = text_type(prev)
+ if prev and (uprev.endswith('\n') or uprev.endswith('\r')):
nl = tlist.token_next(token)
else:
nl = self.nl()
@@ -355,18 +210,17 @@ class ReindentFilter(object):
token = _next_token(tlist.token_index(nl) + offset)
def _split_statements(self, tlist):
- idx = 0
- token = tlist.token_next_by_type(idx, (T.Keyword.DDL, T.Keyword.DML))
+ token = tlist.token_next_by(t=(T.Keyword.DDL, T.Keyword.DML))
while token:
- prev = tlist.token_prev(tlist.token_index(token), False)
+ prev = tlist.token_prev(token, skip_ws=False)
if prev and prev.is_whitespace():
tlist.tokens.pop(tlist.token_index(prev))
# only break if it's not the first token
if prev:
nl = self.nl()
tlist.insert_before(token, nl)
- token = tlist.token_next_by_type(tlist.token_index(token) + 1,
- (T.Keyword.DDL, T.Keyword.DML))
+ token = tlist.token_next_by(t=(T.Keyword.DDL, T.Keyword.DML),
+ idx=token)
def _process(self, tlist):
func_name = '_process_%s' % tlist.__class__.__name__.lower()
@@ -374,7 +228,7 @@ class ReindentFilter(object):
func(tlist)
def _process_where(self, tlist):
- token = tlist.token_next_match(0, T.Keyword, 'WHERE')
+ token = tlist.token_next_by(m=(T.Keyword, 'WHERE'))
try:
tlist.insert_before(token, self.nl())
except ValueError: # issue121, errors in statement
@@ -384,7 +238,7 @@ class ReindentFilter(object):
self.indent -= 1
def _process_having(self, tlist):
- token = tlist.token_next_match(0, T.Keyword, 'HAVING')
+ token = tlist.token_next_by(m=(T.Keyword, 'HAVING'))
try:
tlist.insert_before(token, self.nl())
except ValueError: # issue121, errors in statement
@@ -401,7 +255,7 @@ class ReindentFilter(object):
tlist.tokens.insert(0, self.nl())
indented = True
num_offset = self._get_offset(
- tlist.token_next_match(0, T.Punctuation, '('))
+ tlist.token_next_by(m=(T.Punctuation, '(')))
self.offset += num_offset
self._process_default(tlist, stmts=not indented)
if indented:
@@ -454,7 +308,7 @@ class ReindentFilter(object):
self.offset -= 5
if num_offset is not None:
self.offset -= num_offset
- end = tlist.token_next_match(0, T.Keyword, 'END')
+ end = tlist.token_next_by(m=(T.Keyword, 'END'))
tlist.insert_before(end, self.nl())
self.offset -= outer_offset
@@ -465,13 +319,13 @@ class ReindentFilter(object):
self._split_kwds(tlist)
[self._process(sgroup) for sgroup in tlist.get_sublists()]
- def process(self, stack, stmt):
+ def process(self, stmt):
if isinstance(stmt, sql.Statement):
self._curr_stmt = stmt
self._process(stmt)
if isinstance(stmt, sql.Statement):
if self._last_stmt is not None:
- if u(self._last_stmt).endswith('\n'):
+ if text_type(self._last_stmt).endswith('\n'):
nl = '\n'
else:
nl = '\n\n'
@@ -481,9 +335,8 @@ class ReindentFilter(object):
self._last_stmt = stmt
-# FIXME: Doesn't work ;)
+# FIXME: Doesn't work
class RightMarginFilter(object):
-
keep_together = (
# sql.TypeCast, sql.Identifier, sql.Alias,
)
@@ -492,20 +345,19 @@ class RightMarginFilter(object):
self.width = width
self.line = ''
- def _process(self, stack, group, stream):
+ def _process(self, group, stream):
for token in stream:
if token.is_whitespace() and '\n' in token.value:
if token.value.endswith('\n'):
self.line = ''
else:
self.line = token.value.splitlines()[-1]
- elif (token.is_group()
- and token.__class__ not in self.keep_together):
- token.tokens = self._process(stack, token, token.tokens)
+ elif token.is_group() and type(token) not in self.keep_together:
+ token.tokens = self._process(token, token.tokens)
else:
- val = u(token)
+ val = text_type(token)
if len(self.line) + len(val) > self.width:
- match = re.search('^ +', self.line)
+ match = re.search(r'^ +', self.line)
if match is not None:
indent = match.group()
else:
@@ -515,83 +367,23 @@ class RightMarginFilter(object):
self.line += val
yield token
- def process(self, stack, group):
- return
- group.tokens = self._process(stack, group, group.tokens)
-
-
-class ColumnsSelect(object):
- """Get the columns names of a SELECT query"""
- def process(self, stack, stream):
- mode = 0
- oldValue = ""
- parenthesis = 0
-
- for token_type, value in stream:
- # Ignore comments
- if token_type in Comment:
- continue
-
- # We have not detected a SELECT statement
- if mode == 0:
- if token_type in Keyword and value == 'SELECT':
- mode = 1
-
- # We have detected a SELECT statement
- elif mode == 1:
- if value == 'FROM':
- if oldValue:
- yield oldValue
-
- mode = 3 # Columns have been checked
-
- elif value == 'AS':
- oldValue = ""
- mode = 2
-
- elif (token_type == Punctuation
- and value == ',' and not parenthesis):
- if oldValue:
- yield oldValue
- oldValue = ""
-
- elif token_type not in Whitespace:
- if value == '(':
- parenthesis += 1
- elif value == ')':
- parenthesis -= 1
-
- oldValue += value
-
- # We are processing an AS keyword
- elif mode == 2:
- # We check also for Keywords because a bug in SQLParse
- if token_type == Name or token_type == Keyword:
- yield value
- mode = 1
+ def process(self, group):
+ # return
+ # group.tokens = self._process(group, group.tokens)
+ raise NotImplementedError
# ---------------------------
# postprocess
class SerializerUnicode(object):
-
- def process(self, stack, stmt):
- raw = u(stmt)
+ def process(self, stmt):
+ raw = text_type(stmt)
lines = split_unquoted_newlines(raw)
res = '\n'.join(line.rstrip() for line in lines)
return res
-def Tokens2Unicode(stream):
- result = ""
-
- for _, value in stream:
- result += u(value)
-
- return result
-
-
class OutputFilter(object):
varname_prefix = ''
@@ -602,14 +394,14 @@ class OutputFilter(object):
def _process(self, stream, varname, has_nl):
raise NotImplementedError
- def process(self, stack, stmt):
+ def process(self, stmt):
self.count += 1
if self.count > 1:
varname = '%s%d' % (self.varname, self.count)
else:
varname = self.varname
- has_nl = len(u(stmt).strip().splitlines()) > 1
+ has_nl = len(text_type(stmt).strip().splitlines()) > 1
stmt.tokens = self._process(stmt.tokens, varname, has_nl)
return stmt
@@ -704,34 +496,3 @@ class OutputPHPFilter(OutputFilter):
# Close quote
yield sql.Token(T.Text, '"')
yield sql.Token(T.Punctuation, ';')
-
-
-class Limit(object):
- """Get the LIMIT of a query.
-
- If not defined, return -1 (SQL specification for no LIMIT query)
- """
- def process(self, stack, stream):
- index = 7
- stream = list(stream)
- stream.reverse()
-
- # Run over all tokens in the stream from the end
- for token_type, value in stream:
- index -= 1
-
-# if index and token_type in Keyword:
- if index and token_type in Keyword and value == 'LIMIT':
- return stream[4 - index][1]
-
- return -1
-
-
-def compact(stream):
- """Function that return a compacted version of the stream"""
- pipe = Pipeline()
-
- pipe.append(StripComments())
- pipe.append(StripWhitespace)
-
- return pipe(stream)
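The stream helpers dropped here (GetComments, StripComments, StripWhitespace, IncludeStatement, ColumnsSelect, Limit, compact) lived outside the formatting pipeline; the remaining filters are driven through sqlparse.format(). A short sketch of that surviving entry point (the options shown are illustrative):

    import sqlparse

    raw = "select  a,   b  from  t  -- trailing comment"
    print(sqlparse.format(raw, keyword_case='upper',
                          strip_comments=True, reindent=True))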
diff --git a/sqlparse/functions.py b/sqlparse/functions.py
deleted file mode 100644
index e54457e..0000000
--- a/sqlparse/functions.py
+++ /dev/null
@@ -1,44 +0,0 @@
-'''
-Created on 17/05/2012
-
-@author: piranna
-
-Several utility functions to extract info from the SQL sentences
-'''
-
-from sqlparse.filters import ColumnsSelect, Limit
-from sqlparse.pipeline import Pipeline
-from sqlparse.tokens import Keyword, Whitespace
-
-
-def getlimit(stream):
- """Function that return the LIMIT of a input SQL """
- pipe = Pipeline()
-
- pipe.append(Limit())
-
- result = pipe(stream)
- try:
- return int(result)
- except ValueError:
- return result
-
-
-def getcolumns(stream):
- """Function that return the colums of a SELECT query"""
- pipe = Pipeline()
-
- pipe.append(ColumnsSelect())
-
- return pipe(stream)
-
-
-class IsType(object):
- """Functor that return is the statement is of a specific type"""
- def __init__(self, type):
- self.type = type
-
- def __call__(self, stream):
- for token_type, value in stream:
- if token_type not in Whitespace:
- return token_type in Keyword and value == self.type
diff --git a/sqlparse/pipeline.py b/sqlparse/pipeline.py
deleted file mode 100644
index 34dad19..0000000
--- a/sqlparse/pipeline.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright (C) 2011 Jesus Leganes "piranna", piranna@gmail.com
-#
-# This module is part of python-sqlparse and is released under
-# the BSD License: http://www.opensource.org/licenses/bsd-license.php.
-
-from types import GeneratorType
-
-
-class Pipeline(list):
- """Pipeline to process filters sequentially"""
-
- def __call__(self, stream):
- """Run the pipeline
-
- Return a static (non generator) version of the result
- """
-
- # Run the stream over all the filters on the pipeline
- for filter in self:
- # Functions and callable objects (objects with '__call__' method)
- if callable(filter):
- stream = filter(stream)
-
- # Normal filters (objects with 'process' method)
- else:
- stream = filter.process(None, stream)
-
- # If last filter return a generator, staticalize it inside a list
- if isinstance(stream, GeneratorType):
- return list(stream)
- return stream
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index 6abc432..57bf1e7 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -6,15 +6,16 @@
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
"""This module contains classes representing syntactical elements of SQL."""
+from __future__ import print_function
import re
-import sys
from sqlparse import tokens as T
-from sqlparse.compat import string_types, u
+from sqlparse.compat import string_types, text_type, unicode_compatible
from sqlparse.utils import imt, remove_quotes
+@unicode_compatible
class Token(object):
"""Base class for all other classes in this module.
@@ -26,40 +27,29 @@ class Token(object):
__slots__ = ('value', 'ttype', 'parent', 'normalized', 'is_keyword')
def __init__(self, ttype, value):
+ value = text_type(value)
self.value = value
- if ttype in T.Keyword:
- self.normalized = value.upper()
- else:
- self.normalized = value
self.ttype = ttype
- self.is_keyword = ttype in T.Keyword
self.parent = None
+ self.is_keyword = ttype in T.Keyword
+ self.normalized = value.upper() if self.is_keyword else value
def __str__(self):
- if sys.version_info[0] == 3:
- return self.value
- else:
- return u(self).encode('utf-8')
+ return self.value
def __repr__(self):
- short = self._get_repr_value()
- if sys.version_info[0] < 3:
- short = short.encode('utf-8')
- return '<%s \'%s\' at 0x%07x>' % (self._get_repr_name(),
- short, id(self))
-
- def __unicode__(self):
- """Returns a unicode representation of this object."""
- return self.value or ''
+ cls = self._get_repr_name()
+ value = self._get_repr_value()
+ return "<{cls} '{value}' at 0x{id:2X}>".format(id=id(self), **locals())
def _get_repr_name(self):
return str(self.ttype).split('.')[-1]
def _get_repr_value(self):
- raw = u(self)
+ raw = text_type(self)
if len(raw) > 7:
- raw = raw[:6] + u'...'
- return re.sub('\s+', ' ', raw)
+ raw = raw[:6] + '...'
+ return re.sub(r'\s+', ' ', raw)
def flatten(self):
"""Resolve subgroups."""
@@ -81,32 +71,23 @@ class Token(object):
if not type_matched or values is None:
return type_matched
- if regex:
- if isinstance(values, string_types):
- values = {values}
+ if isinstance(values, string_types):
+ values = (values,)
- if self.ttype is T.Keyword:
- values = set(re.compile(v, re.IGNORECASE) for v in values)
- else:
- values = set(re.compile(v) for v in values)
+ if regex:
+ # TODO: Add test for regex with is_keyword = False
+ flag = re.IGNORECASE if self.is_keyword else 0
+ values = (re.compile(v, flag) for v in values)
for pattern in values:
- if pattern.search(self.value):
+ if pattern.search(self.normalized):
return True
return False
- if isinstance(values, string_types):
- if self.is_keyword:
- return values.upper() == self.normalized
- return values == self.value
-
if self.is_keyword:
- for v in values:
- if v.upper() == self.normalized:
- return True
- return False
+ values = (v.upper() for v in values)
- return self.value in values
+ return self.normalized in values
def is_group(self):
"""Returns ``True`` if this object has children."""
@@ -114,7 +95,7 @@ class Token(object):
def is_whitespace(self):
"""Return ``True`` if this token is a whitespace token."""
- return self.ttype and self.ttype in T.Whitespace
+ return self.ttype in T.Whitespace
def within(self, group_cls):
"""Returns ``True`` if this token is within *group_cls*.
@@ -143,6 +124,7 @@ class Token(object):
return False
+@unicode_compatible
class TokenList(Token):
"""A group of tokens.
@@ -150,45 +132,35 @@ class TokenList(Token):
list of child-tokens.
"""
- __slots__ = ('value', 'ttype', 'tokens')
+ __slots__ = 'tokens'
def __init__(self, tokens=None):
- if tokens is None:
- tokens = []
- self.tokens = tokens
- super(TokenList, self).__init__(None, self.__str__())
-
- def __unicode__(self):
- return self._to_string()
+ self.tokens = tokens or []
+ super(TokenList, self).__init__(None, text_type(self))
def __str__(self):
- str_ = self._to_string()
- if sys.version_info[0] <= 2:
- str_ = str_.encode('utf-8')
- return str_
-
- def _to_string(self):
- if sys.version_info[0] == 3:
- return ''.join(x.value for x in self.flatten())
- else:
- return ''.join(u(x) for x in self.flatten())
+ return ''.join(token.value for token in self.flatten())
+
+ def __iter__(self):
+ return iter(self.tokens)
+
+ def __getitem__(self, item):
+ return self.tokens[item]
def _get_repr_name(self):
- return self.__class__.__name__
+ return type(self).__name__
- def _pprint_tree(self, max_depth=None, depth=0):
+ def _pprint_tree(self, max_depth=None, depth=0, f=None):
"""Pretty-print the object tree."""
- indent = ' ' * (depth * 2)
+ ind = ' ' * (depth * 2)
for idx, token in enumerate(self.tokens):
- if token.is_group():
- pre = ' +-'
- else:
- pre = ' | '
- print('%s%s%d %s \'%s\'' % (indent, pre, idx,
- token._get_repr_name(),
- token._get_repr_value()))
- if (token.is_group() and (max_depth is None or depth < max_depth)):
- token._pprint_tree(max_depth, depth + 1)
+ pre = ' +-' if token.is_group() else ' | '
+ cls = token._get_repr_name()
+ value = token._get_repr_value()
+ print("{ind}{pre}{idx} {cls} '{value}'".format(**locals()), file=f)
+
+ if token.is_group() and (max_depth is None or depth < max_depth):
+ token._pprint_tree(max_depth, depth + 1, f)
def get_token_at_offset(self, offset):
"""Returns the token that is on position offset."""
@@ -205,26 +177,19 @@ class TokenList(Token):
This method is recursively called for all child tokens.
"""
for token in self.tokens:
- if isinstance(token, TokenList):
+ if token.is_group():
for item in token.flatten():
yield item
else:
yield token
- # def __iter__(self):
- # return self
- #
- # def next(self):
- # for token in self.tokens:
- # yield token
-
def is_group(self):
return True
def get_sublists(self):
- for x in self.tokens:
- if isinstance(x, TokenList):
- yield x
+ for token in self.tokens:
+ if token.is_group():
+ yield token
@property
def _groupable_tokens(self):
@@ -242,7 +207,7 @@ class TokenList(Token):
funcs = (funcs,)
if reverse:
- iterable = iter(reversed(self.tokens[end:start - 1]))
+ iterable = reversed(self.tokens[end:start - 1])
else:
iterable = self.tokens[start:end]
@@ -251,7 +216,7 @@ class TokenList(Token):
if func(token):
return token
- def token_first(self, ignore_whitespace=True, ignore_comments=False):
+ def token_first(self, skip_ws=True, skip_cm=False):
"""Returns the first child token.
If *ignore_whitespace* is ``True`` (the default), whitespace
@@ -260,35 +225,15 @@ class TokenList(Token):
if *ignore_comments* is ``True`` (default: ``False``), comments are
ignored too.
"""
- funcs = lambda tk: not ((ignore_whitespace and tk.is_whitespace()) or
- (ignore_comments and imt(tk, i=Comment)))
+ # this one is inconsistent, using Comment instead of T.Comment...
+ funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or
+ (skip_cm and imt(tk, i=Comment)))
return self._token_matching(funcs)
def token_next_by(self, i=None, m=None, t=None, idx=0, end=None):
funcs = lambda tk: imt(tk, i, m, t)
return self._token_matching(funcs, idx, end)
- def token_next_by_instance(self, idx, clss, end=None):
- """Returns the next token matching a class.
-
- *idx* is where to start searching in the list of child tokens.
- *clss* is a list of classes the token should be an instance of.
-
- If no matching token can be found ``None`` is returned.
- """
- funcs = lambda tk: imt(tk, i=clss)
- return self._token_matching(funcs, idx, end)
-
- def token_next_by_type(self, idx, ttypes):
- """Returns next matching token by it's token type."""
- funcs = lambda tk: imt(tk, t=ttypes)
- return self._token_matching(funcs, idx)
-
- def token_next_match(self, idx, ttype, value, regex=False):
- """Returns next token where it's ``match`` method returns ``True``."""
- funcs = lambda tk: imt(tk, m=(ttype, value, regex))
- return self._token_matching(funcs, idx)
-
def token_not_matching(self, idx, funcs):
funcs = (funcs,) if not isinstance(funcs, (list, tuple)) else funcs
funcs = [lambda tk: not func(tk) for func in funcs]
@@ -297,7 +242,7 @@ class TokenList(Token):
def token_matching(self, idx, funcs):
return self._token_matching(funcs, idx)
- def token_prev(self, idx, skip_ws=True):
+ def token_prev(self, idx, skip_ws=True, skip_cm=False):
"""Returns the previous token relative to *idx*.
If *skip_ws* is ``True`` (the default) whitespace tokens are ignored.
@@ -305,10 +250,11 @@ class TokenList(Token):
"""
if isinstance(idx, int):
idx += 1 # alot of code usage current pre-compensates for this
- funcs = lambda tk: not (tk.is_whitespace() and skip_ws)
+ funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or
+ (skip_cm and imt(tk, t=T.Comment)))
return self._token_matching(funcs, idx, reverse=True)
- def token_next(self, idx, skip_ws=True):
+ def token_next(self, idx, skip_ws=True, skip_cm=False):
"""Returns the next token relative to *idx*.
If *skip_ws* is ``True`` (the default) whitespace tokens are ignored.
@@ -316,7 +262,8 @@ class TokenList(Token):
"""
if isinstance(idx, int):
idx += 1 # alot of code usage current pre-compensates for this
- funcs = lambda tk: not (tk.is_whitespace() and skip_ws)
+ funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or
+ (skip_cm and imt(tk, t=T.Comment)))
return self._token_matching(funcs, idx)
def token_index(self, token, start=0):
@@ -334,9 +281,9 @@ class TokenList(Token):
end_idx = include_end + self.token_index(end)
return self.tokens[start_idx:end_idx]
- def group_tokens(self, grp_cls, tokens, ignore_ws=False, extend=False):
+ def group_tokens(self, grp_cls, tokens, skip_ws=False, extend=False):
"""Replace tokens by an instance of *grp_cls*."""
- if ignore_ws:
+ if skip_ws:
while tokens and tokens[-1].is_whitespace():
tokens = tokens[:-1]
@@ -352,7 +299,7 @@ class TokenList(Token):
left.parent = self
tokens = tokens[1:]
left.tokens.extend(tokens)
- left.value = left.__str__()
+ left.value = str(left)
else:
left = grp_cls(tokens)
@@ -393,8 +340,6 @@ class TokenList(Token):
if len(self.tokens) > 2 and self.token_next_by(t=T.Whitespace):
return self._get_first_name(reverse=True)
- return None
-
def get_name(self):
"""Returns the name of this identifier.
@@ -402,32 +347,22 @@ class TokenList(Token):
be considered as the name under which the object corresponding to
this identifier is known within the current statement.
"""
- alias = self.get_alias()
- if alias is not None:
- return alias
- return self.get_real_name()
+ return self.get_alias() or self.get_real_name()
def get_real_name(self):
"""Returns the real name (object name) of this identifier."""
# a.b
- dot = self.token_next_match(0, T.Punctuation, '.')
- if dot is not None:
- return self._get_first_name(self.token_index(dot))
-
- return self._get_first_name()
+ dot = self.token_next_by(m=(T.Punctuation, '.'))
+ return self._get_first_name(dot)
def get_parent_name(self):
"""Return name of the parent object if any.
A parent object is identified by the first occuring dot.
"""
- dot = self.token_next_match(0, T.Punctuation, '.')
- if dot is None:
- return None
- prev_ = self.token_prev(self.token_index(dot))
- if prev_ is None: # something must be verry wrong here..
- return None
- return remove_quotes(prev_.value)
+ dot = self.token_next_by(m=(T.Punctuation, '.'))
+ prev_ = self.token_prev(dot)
+ return remove_quotes(prev_.value) if prev_ is not None else None
def _get_first_name(self, idx=None, reverse=False, keywords=False):
"""Returns the name of the first token with a name"""
@@ -442,19 +377,16 @@ class TokenList(Token):
if keywords:
types.append(T.Keyword)
- for tok in tokens:
- if tok.ttype in types:
- return remove_quotes(tok.value)
- elif isinstance(tok, Identifier) or isinstance(tok, Function):
- return tok.get_name()
- return None
+ for token in tokens:
+ if token.ttype in types:
+ return remove_quotes(token.value)
+ elif isinstance(token, (Identifier, Function)):
+ return token.get_name()
class Statement(TokenList):
"""Represents a SQL statement."""
- __slots__ = ('value', 'ttype', 'tokens')
-
def get_type(self):
"""Returns the type of a statement.
@@ -465,7 +397,7 @@ class Statement(TokenList):
Whitespaces and comments at the beginning of the statement
are ignored.
"""
- first_token = self.token_first(ignore_comments=True)
+ first_token = self.token_first(skip_cm=True)
if first_token is None:
# An "empty" statement that either has not tokens at all
# or only whitespace tokens.
@@ -478,16 +410,14 @@ class Statement(TokenList):
# The WITH keyword should be followed by either an Identifier or
# an IdentifierList containing the CTE definitions; the actual
# DML keyword (e.g. SELECT, INSERT) will follow next.
- idents = self.token_next(
- self.token_index(first_token), skip_ws=True)
- if isinstance(idents, (Identifier, IdentifierList)):
- dml_keyword = self.token_next(
- self.token_index(idents), skip_ws=True)
+ token = self.token_next(first_token, skip_ws=True)
+ if isinstance(token, (Identifier, IdentifierList)):
+ dml_keyword = self.token_next(token, skip_ws=True)
+
if dml_keyword.ttype == T.Keyword.DML:
return dml_keyword.normalized
- # Hmm, probably invalid syntax, so return unknown.
- return 'UNKNOWN'
+ # Hmm, probably invalid syntax, so return unknown.
return 'UNKNOWN'
@@ -499,33 +429,33 @@ class Identifier(TokenList):
def is_wildcard(self):
"""Return ``True`` if this identifier contains a wildcard."""
- token = self.token_next_by_type(0, T.Wildcard)
+ token = self.token_next_by(t=T.Wildcard)
return token is not None
def get_typecast(self):
"""Returns the typecast or ``None`` of this object as a string."""
- marker = self.token_next_match(0, T.Punctuation, '::')
+ marker = self.token_next_by(m=(T.Punctuation, '::'))
if marker is None:
return None
- next_ = self.token_next(self.token_index(marker), False)
+ next_ = self.token_next(marker, False)
if next_ is None:
return None
- return u(next_)
+ return next_.value
def get_ordering(self):
"""Returns the ordering or ``None`` as uppercase string."""
- ordering = self.token_next_by_type(0, T.Keyword.Order)
+ ordering = self.token_next_by(t=T.Keyword.Order)
if ordering is None:
return None
- return ordering.value.upper()
+ return ordering.normalized
def get_array_indices(self):
"""Returns an iterator of index token lists"""
- for tok in self.tokens:
- if isinstance(tok, SquareBrackets):
+ for token in self.tokens:
+ if isinstance(token, SquareBrackets):
# Use [1:-1] index to discard the square brackets
- yield tok.tokens[1:-1]
+ yield token.tokens[1:-1]
class IdentifierList(TokenList):
@@ -536,15 +466,15 @@ class IdentifierList(TokenList):
Whitespaces and punctuations are not included in this generator.
"""
- for x in self.tokens:
- if not x.is_whitespace() and not x.match(T.Punctuation, ','):
- yield x
+ for token in self.tokens:
+ if not (token.is_whitespace() or token.match(T.Punctuation, ',')):
+ yield token
class Parenthesis(TokenList):
"""Tokens between parenthesis."""
- M_OPEN = (T.Punctuation, '(')
- M_CLOSE = (T.Punctuation, ')')
+ M_OPEN = T.Punctuation, '('
+ M_CLOSE = T.Punctuation, ')'
@property
def _groupable_tokens(self):
@@ -553,8 +483,8 @@ class Parenthesis(TokenList):
class SquareBrackets(TokenList):
"""Tokens between square brackets"""
- M_OPEN = (T.Punctuation, '[')
- M_CLOSE = (T.Punctuation, ']')
+ M_OPEN = T.Punctuation, '['
+ M_CLOSE = T.Punctuation, ']'
@property
def _groupable_tokens(self):
@@ -567,14 +497,14 @@ class Assignment(TokenList):
class If(TokenList):
"""An 'if' clause with possible 'else if' or 'else' parts."""
- M_OPEN = (T.Keyword, 'IF')
- M_CLOSE = (T.Keyword, 'END IF')
+ M_OPEN = T.Keyword, 'IF'
+ M_CLOSE = T.Keyword, 'END IF'
class For(TokenList):
"""A 'FOR' loop."""
- M_OPEN = (T.Keyword, ('FOR', 'FOREACH'))
- M_CLOSE = (T.Keyword, 'END LOOP')
+ M_OPEN = T.Keyword, ('FOR', 'FOREACH')
+ M_CLOSE = T.Keyword, 'END LOOP'
class Comparison(TokenList):
@@ -598,15 +528,15 @@ class Comment(TokenList):
class Where(TokenList):
"""A WHERE clause."""
- M_OPEN = (T.Keyword, 'WHERE')
- M_CLOSE = (T.Keyword,
- ('ORDER', 'GROUP', 'LIMIT', 'UNION', 'EXCEPT', 'HAVING'))
+ M_OPEN = T.Keyword, 'WHERE'
+ M_CLOSE = T.Keyword, ('ORDER', 'GROUP', 'LIMIT', 'UNION', 'EXCEPT',
+ 'HAVING')
class Case(TokenList):
"""A CASE statement with one or more WHEN and possibly an ELSE part."""
- M_OPEN = (T.Keyword, 'CASE')
- M_CLOSE = (T.Keyword, 'END')
+ M_OPEN = T.Keyword, 'CASE'
+ M_CLOSE = T.Keyword, 'END'
def get_cases(self):
"""Returns a list of 2-tuples (condition, value).
@@ -659,15 +589,15 @@ class Function(TokenList):
def get_parameters(self):
"""Return a list of parameters."""
parenthesis = self.tokens[-1]
- for t in parenthesis.tokens:
- if imt(t, i=IdentifierList):
- return t.get_identifiers()
- elif imt(t, i=(Function, Identifier), t=T.Literal):
- return [t, ]
+ for token in parenthesis.tokens:
+ if imt(token, i=IdentifierList):
+ return token.get_identifiers()
+ elif imt(token, i=(Function, Identifier), t=T.Literal):
+ return [token, ]
return []
class Begin(TokenList):
"""A BEGIN/END block."""
- M_OPEN = (T.Keyword, 'BEGIN')
- M_CLOSE = (T.Keyword, 'END')
+ M_OPEN = T.Keyword, 'BEGIN'
+ M_CLOSE = T.Keyword, 'END'
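token_next_by_instance(), token_next_by_type() and token_next_match() are folded into the single token_next_by(i=..., m=..., t=...) helper, and token_first()/token_next()/token_prev() gain skip_ws/skip_cm keywords. A sketch of the consolidated search API (the query is illustrative):

    import sqlparse
    from sqlparse import sql, tokens as T

    stmt = sqlparse.parse("select * from foo where bar = 1")[0]
    where = stmt.token_next_by(i=sql.Where)          # by instance class
    dml = stmt.token_next_by(t=T.Keyword.DML)        # by token type
    star = stmt.token_next_by(m=(T.Wildcard, '*'))   # by (ttype, value) match
    first = stmt.token_first(skip_cm=True)           # skip leading comments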
diff --git a/sqlparse/utils.py b/sqlparse/utils.py
index 2513c26..4da44c6 100644
--- a/sqlparse/utils.py
+++ b/sqlparse/utils.py
@@ -7,78 +7,9 @@
import itertools
import re
-from collections import OrderedDict, deque
+from collections import deque
from contextlib import contextmanager
-
-class Cache(OrderedDict):
- """Cache with LRU algorithm using an OrderedDict as basis
- """
-
- def __init__(self, maxsize=100):
- OrderedDict.__init__(self)
-
- self._maxsize = maxsize
-
- def __getitem__(self, key, *args, **kwargs):
- # Get the key and remove it from the cache, or raise KeyError
- value = OrderedDict.__getitem__(self, key)
- del self[key]
-
- # Insert the (key, value) pair on the front of the cache
- OrderedDict.__setitem__(self, key, value)
-
- # Return the value from the cache
- return value
-
- def __setitem__(self, key, value, *args, **kwargs):
- # Key was inserted before, remove it so we put it at front later
- if key in self:
- del self[key]
-
- # Too much items on the cache, remove the least recent used
- elif len(self) >= self._maxsize:
- self.popitem(False)
-
- # Insert the (key, value) pair on the front of the cache
- OrderedDict.__setitem__(self, key, value, *args, **kwargs)
-
-
-def memoize_generator(func):
- """Memoize decorator for generators
-
- Store `func` results in a cache according to their arguments as 'memoize'
- does but instead this works on decorators instead of regular functions.
- Obviusly, this is only useful if the generator will always return the same
- values for each specific parameters...
- """
- cache = Cache()
-
- def wrapped_func(*args, **kwargs):
- params = (args, tuple(sorted(kwargs.items())))
-
- # Look if cached
- try:
- cached = cache[params]
-
- # Not cached, exec and store it
- except KeyError:
- cached = []
-
- for item in func(*args, **kwargs):
- cached.append(item)
- yield item
-
- cache[params] = cached
-
- # Cached, yield its items
- else:
- for item in cached:
- yield item
-
- return wrapped_func
-
-
# This regular expression replaces the home-cooked parser that was here before.
# It is much faster, but requires an extra post-processing step to get the
# desired results (that are compatible with what you would expect from the