path: root/sqlparse
author     Andi Albrecht <albrecht.andi@googlemail.com>   2012-04-21 21:28:03 -0700
committer  Andi Albrecht <albrecht.andi@googlemail.com>   2012-04-21 21:28:03 -0700
commit     9b643b52bfd59b583094d08615c7bd698f98e576 (patch)
tree       5d13bc4428bf678c75e0cbbdf1e35ec5655788ee /sqlparse
parent     0afebf47e24d8a1ee1981faef39c0a15a798f7fd (diff)
parent     a16c08703c8eb213a8b570bb16636fbe7a2b4a28 (diff)
download   sqlparse-9b643b52bfd59b583094d08615c7bd698f98e576.tar.gz
Merge pull request #63 from bittrance/master
Support for reading from file-like object
Diffstat (limited to 'sqlparse')
-rw-r--r--  sqlparse/__init__.py          11
-rw-r--r--  sqlparse/engine/grouping.py    3
-rw-r--r--  sqlparse/lexer.py             63
-rw-r--r--  sqlparse/sql.py               27
4 files changed, 75 insertions, 29 deletions
diff --git a/sqlparse/__init__.py b/sqlparse/__init__.py
index f924c04..58a560c 100644
--- a/sqlparse/__init__.py
+++ b/sqlparse/__init__.py
@@ -31,6 +31,16 @@ def parse(sql):
return tuple(stack.run(sql))
+def parsestream(stream):
+    """Parses sql statements from a file-like object.
+
+ Returns a generator of Statement instances.
+ """
+ stack = engine.FilterStack()
+ stack.full_analyze()
+ return stack.run(stream)
+
+
def format(sql, **options):
"""Format *sql* according to *options*.
@@ -54,7 +64,6 @@ def split(sql):
stack.split_statements = True
return [unicode(stmt) for stmt in stack.run(sql)]
-
from sqlparse.engine.filter import StatementFilter
def split2(stream):
splitter = StatementFilter()
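
For context, a minimal usage sketch of the new parsestream() entry point (the queries.sql filename is hypothetical; any file-like object such as an open file or a StringIO instance should work):

    import sqlparse

    with open('queries.sql') as f:
        # parsestream() returns a generator of Statement instances
        for stmt in sqlparse.parsestream(f):
            print stmt.get_type(), unicode(stmt).strip()
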
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py
index 55ec7e2..1487c24 100644
--- a/sqlparse/engine/grouping.py
+++ b/sqlparse/engine/grouping.py
@@ -55,7 +55,8 @@ def _group_matching(tlist, start_ttype, start_value, end_ttype, end_value,
cls, include_semicolon=False, recurse=False):
def _find_matching(i, tl, stt, sva, ett, eva):
depth = 1
- for t in tl.tokens[i:]:
+ for n in xrange(i, len(tl.tokens)):
+ t = tl.tokens[n]
if t.match(stt, sva):
depth += 1
elif t.match(ett, eva):
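
The grouping change only swaps slice iteration for index-based iteration, presumably to avoid copying the tail of the token list on every lookahead: tl.tokens[i:] builds a snapshot copy, whereas indexing with xrange walks the live list. A tiny standalone illustration, not sqlparse code:

    tokens = ['a', 'b', 'c', 'd']

    for t in tokens[1:]:              # iterates over a snapshot copy of the tail
        pass

    for n in xrange(1, len(tokens)):  # walks the live list, no copy is made
        t = tokens[n]
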
diff --git a/sqlparse/lexer.py b/sqlparse/lexer.py
index 321669d..dc794ab 100644
--- a/sqlparse/lexer.py
+++ b/sqlparse/lexer.py
@@ -16,7 +16,7 @@ import re
from sqlparse import tokens
from sqlparse.keywords import KEYWORDS, KEYWORDS_COMMON
-
+from cStringIO import StringIO
class include(str):
pass
@@ -159,6 +159,7 @@ class Lexer(object):
stripnl = False
tabsize = 0
flags = re.IGNORECASE
+ bufsize = 4096
tokens = {
'root': [
@@ -214,6 +215,21 @@ class Lexer(object):
filter_ = filter_(**options)
self.filters.append(filter_)
+ def _decode(self, text):
+ if self.encoding == 'guess':
+ try:
+ text = text.decode('utf-8')
+ if text.startswith(u'\ufeff'):
+ text = text[len(u'\ufeff'):]
+ except UnicodeDecodeError:
+ text = text.decode('latin1')
+ else:
+ text = text.decode(self.encoding)
+
+ if self.tabsize > 0:
+ text = text.expandtabs(self.tabsize)
+ return text
+
def get_tokens(self, text, unfiltered=False):
"""
Return an iterable of (tokentype, value) pairs generated from
@@ -223,24 +239,17 @@ class Lexer(object):
Also preprocess the text, i.e. expand tabs and strip it if
wanted and applies registered filters.
"""
- if not isinstance(text, unicode):
- if self.encoding == 'guess':
- try:
- text = text.decode('utf-8')
- if text.startswith(u'\ufeff'):
- text = text[len(u'\ufeff'):]
- except UnicodeDecodeError:
- text = text.decode('latin1')
+ if isinstance(text, basestring):
+ if self.stripall:
+ text = text.strip()
+ elif self.stripnl:
+ text = text.strip('\n')
+
+ if isinstance(text, unicode):
+ text = StringIO(text.encode('utf-8'))
+ self.encoding = 'utf-8'
else:
- text = text.decode(self.encoding)
- if self.stripall:
- text = text.strip()
- elif self.stripnl:
- text = text.strip('\n')
- if self.tabsize > 0:
- text = text.expandtabs(self.tabsize)
-# if not text.endswith('\n'):
-# text += '\n'
+ text = StringIO(text)
def streamer():
for i, t, v in self.get_tokens_unprocessed(text):
@@ -250,7 +259,7 @@ class Lexer(object):
stream = apply_filters(stream, self.filters, self)
return stream
- def get_tokens_unprocessed(self, text, stack=('root',)):
+ def get_tokens_unprocessed(self, stream, stack=('root',)):
"""
Split ``text`` into (tokentype, text) pairs.
@@ -261,10 +270,19 @@ class Lexer(object):
statestack = list(stack)
statetokens = tokendefs[statestack[-1]]
known_names = {}
+
+ text = stream.read(self.bufsize)
+ hasmore = len(text) == self.bufsize
+ text = self._decode(text)
+
while 1:
for rexmatch, action, new_state in statetokens:
m = rexmatch(text, pos)
if m:
+ if hasmore and m.end() == len(text):
+                        # match runs to the end of the buffer, so the token may be truncated
+ continue
+
# print rex.pattern
value = m.group()
if value in known_names:
@@ -299,6 +317,13 @@ class Lexer(object):
statetokens = tokendefs[statestack[-1]]
break
else:
+ if hasmore:
+ buf = stream.read(self.bufsize)
+ hasmore = len(buf) == self.bufsize
+ text = text[pos:] + self._decode(buf)
+ pos = 0
+ continue
+
try:
if text[pos] == '\n':
# at EOL, reset state to "root"
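
The core of the lexer change is chunked reading: input is pulled from the stream in bufsize-byte blocks, and a regex match that runs exactly to the end of the current buffer is treated as possibly truncated, so the unconsumed tail is kept and another block is appended before matching again. A simplified, standalone sketch of that pattern (not the actual Lexer code; the WORD regex and BUFSIZE value are illustrative):

    import re
    from cStringIO import StringIO

    BUFSIZE = 8
    WORD = re.compile(r'\w+|\s+')

    def tokenize(stream):
        text = stream.read(BUFSIZE)
        hasmore = len(text) == BUFSIZE
        pos = 0
        while True:
            m = WORD.match(text, pos)
            if m and not (hasmore and m.end() == len(text)):
                yield m.group()
                pos = m.end()
                continue
            if hasmore:
                # A match up to the buffer end may be a truncated token:
                # keep the unread tail and pull in another block.
                buf = stream.read(BUFSIZE)
                hasmore = len(buf) == BUFSIZE
                text = text[pos:] + buf
                pos = 0
                continue
            if pos >= len(text):
                return
            yield text[pos]   # no rule matched; emit the character as-is
            pos += 1

    print list(tokenize(StringIO('select * from mytable')))

As in the patch, end-of-input is detected heuristically: a read that returns fewer than BUFSIZE bytes is taken to mean the stream is exhausted.
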
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index 9c7aeee..31fa34d 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -15,11 +15,13 @@ class Token(object):
the type of the token.
"""
- __slots__ = ('value', 'ttype', 'parent')
+ __slots__ = ('value', 'ttype', 'parent', 'normalized', 'is_keyword')
def __init__(self, ttype, value):
self.value = value
+ self.normalized = value.upper() if ttype in T.Keyword else value
self.ttype = ttype
+ self.is_keyword = ttype in T.Keyword
self.parent = None
def __str__(self):
@@ -71,9 +73,9 @@ class Token(object):
type_matched = self.ttype is ttype
if not type_matched or values is None:
return type_matched
- if isinstance(values, basestring):
- values = set([values])
if regex:
+ if isinstance(values, basestring):
+ values = set([values])
if self.ttype is T.Keyword:
values = set([re.compile(v, re.IGNORECASE) for v in values])
else:
@@ -83,10 +85,18 @@ class Token(object):
return True
return False
else:
- if self.ttype in T.Keyword:
- values = set([v.upper() for v in values])
- return self.value.upper() in values
+ if isinstance(values, basestring):
+ if self.is_keyword:
+ return values.upper() == self.normalized
+ else:
+ return values == self.value
+ if self.is_keyword:
+ for v in values:
+ if v.upper() == self.normalized:
+ return True
+ return False
else:
+ print len(values)
return self.value in values
def is_group(self):
@@ -227,7 +237,8 @@ class TokenList(Token):
if not isinstance(idx, int):
idx = self.token_index(idx)
- for token in self.tokens[idx:]:
+ for n in xrange(idx, len(self.tokens)):
+ token = self.tokens[n]
if token.match(ttype, value, regex):
return token
@@ -395,7 +406,7 @@ class Statement(TokenList):
return 'UNKNOWN'
elif first_token.ttype in (T.Keyword.DML, T.Keyword.DDL):
- return first_token.value.upper()
+ return first_token.normalized
return 'UNKNOWN'
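
The sql.py change caches two attributes on Token at construction time so that match() and get_type() can compare keywords without re-uppercasing the value on every call. A small illustration of the cached attributes, assuming this patch is applied (expected output noted in the comments):

    from sqlparse import sql, tokens as T

    tok = sql.Token(T.Keyword.DML, 'select')
    print tok.is_keyword                      # True
    print tok.normalized                      # 'SELECT' (uppercased once, in __init__)
    print tok.match(T.Keyword.DML, 'SELECT')  # True; compared against tok.normalized
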