author    Vik <vmuriart@gmail.com>    2016-06-12 13:09:56 -0700
committer GitHub <noreply@github.com> 2016-06-12 13:09:56 -0700
commit    8ad44059d4d9ab5a8a7489a963dcb8de45ca3a0a (patch)
tree      3fac58eff5e7f0150874e1205dfcc4dfe8a28455 /sqlparse
parent    50de51a5d6abb2a2f8649091912090983dab843d (diff)
parent    42fb1d05b601444599f10d10c5d2dd0b431ccc15 (diff)
download  sqlparse-8ad44059d4d9ab5a8a7489a963dcb8de45ca3a0a.tar.gz
Merge pull request #255 from vmuriart/console-script-examples-sqlOperation
Add Console-script, sql.Operation, fix examples
Diffstat (limited to 'sqlparse')
-rw-r--r--  sqlparse/__main__.py         | 144
-rw-r--r--  sqlparse/engine/grouping.py  |  33
-rw-r--r--  sqlparse/filters/others.py   |   4
-rw-r--r--  sqlparse/sql.py              |  46
-rw-r--r--  sqlparse/utils.py            |   4
5 files changed, 189 insertions(+), 42 deletions(-)
diff --git a/sqlparse/__main__.py b/sqlparse/__main__.py
new file mode 100644
index 0000000..28abb6c
--- /dev/null
+++ b/sqlparse/__main__.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2016 Andi Albrecht, albrecht.andi@gmail.com
+#
+# This module is part of python-sqlparse and is released under
+# the BSD License: http://www.opensource.org/licenses/bsd-license.php
+
+import argparse
+import sys
+
+import sqlparse
+from sqlparse.compat import PY2
+from sqlparse.exceptions import SQLParseError
+
+_CASE_CHOICES = ['upper', 'lower', 'capitalize']
+
+# TODO: Add CLI Tests
+# TODO: Simplify formatter by using argparse `type` arguments
+parser = argparse.ArgumentParser(
+ prog='sqlparse',
+ description='Format FILE according to OPTIONS. Use "-" as FILE '
+ 'to read from stdin.',
+ usage='%(prog)s [OPTIONS] FILE ...')
+
+# Note: argparse no longer accepts a ``version=`` constructor argument
+# on Python 3, so the version flag is registered as an action instead.
+parser.add_argument('--version', action='version',
+ version=sqlparse.__version__)
+
+parser.add_argument('filename')
+
+parser.add_argument(
+ '-o', '--outfile',
+ dest='outfile',
+ metavar='FILE',
+ help='write output to FILE (defaults to stdout)')
+
+group = parser.add_argument_group('Formatting Options')
+
+group.add_argument(
+ '-k', '--keywords',
+ metavar='CHOICE',
+ dest='keyword_case',
+ choices=_CASE_CHOICES,
+ help='change case of keywords, CHOICE is one of {0}'.format(
+ ', '.join('"{0}"'.format(x) for x in _CASE_CHOICES)))
+
+group.add_argument(
+ '-i', '--identifiers',
+ metavar='CHOICE',
+ dest='identifier_case',
+ choices=_CASE_CHOICES,
+ help='change case of identifiers, CHOICE is one of {0}'.format(
+ ', '.join('"{0}"'.format(x) for x in _CASE_CHOICES)))
+
+group.add_argument(
+ '-l', '--language',
+ metavar='LANG',
+ dest='output_format',
+ choices=['python', 'php'],
+ help='output a snippet in programming language LANG, '
+ 'choices are "python", "php"')
+
+group.add_argument(
+ '--strip-comments',
+ dest='strip_comments',
+ action='store_true',
+ default=False,
+ help='remove comments')
+
+group.add_argument(
+ '-r', '--reindent',
+ dest='reindent',
+ action='store_true',
+ default=False,
+ help='reindent statements')
+
+group.add_argument(
+ '--indent_width',
+ dest='indent_width',
+ default=2,
+ type=int,
+ help='indentation width (defaults to 2 spaces)')
+
+group.add_argument(
+ '-a', '--reindent_aligned',
+ action='store_true',
+ default=False,
+ help='reindent statements to aligned format')
+
+group.add_argument(
+ '-s', '--use_space_around_operators',
+ action='store_true',
+ default=False,
+ help='place spaces around mathematical operators')
+
+group.add_argument(
+ '--wrap_after',
+ dest='wrap_after',
+ default=0,
+ type=int,
+ help='column after which lists should be wrapped')
+
+
+def _error(msg):
+ """Print msg and optionally exit with return code exit_."""
+ sys.stderr.write('[ERROR] %s\n' % msg)
+
+
+def main(args=None):
+ args = parser.parse_args(args)
+
+ if args.filename == '-': # read from stdin
+ data = sys.stdin.read()
+ else:
+ try:
+ data = open(args.filename).read()
+ except IOError as e:
+ _error('Failed to read %s: %s' % (args.filename, e))
+ return 1
+
+ if args.outfile:
+ try:
+ stream = open(args.outfile, 'w')
+ except IOError as e:
+ _error('Failed to open %s: %s' % (args.outfile, e))
+ return 1
+ else:
+ stream = sys.stdout
+
+ formatter_opts = vars(args)
+ try:
+ formatter_opts = sqlparse.formatter.validate_options(formatter_opts)
+ except SQLParseError as e:
+ _error('Invalid options: %s' % e)
+ return 1
+
+ s = sqlparse.format(data, **formatter_opts)
+ if PY2:
+ s = s.encode('utf-8', 'replace')
+ stream.write(s)
+ stream.flush()
+ return 0
+
+
+if __name__ == '__main__':
+ sys.exit(main())
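
A quick usage sketch for the new console script (editorial note, not part of the patch): the entry point can be driven from Python as well as from the shell. The input file name below is hypothetical.

    # Sketch: call the new CLI entry point directly. Assumes this
    # commit's sqlparse is importable; 'example.sql' is a made-up file.
    import sys
    from sqlparse.__main__ import main

    # Roughly equivalent to the shell call:
    #   sqlparse --reindent -k upper example.sql
    sys.exit(main(['--reindent', '-k', 'upper', 'example.sql']))
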
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py
index 91bb3d9..6e414b8 100644
--- a/sqlparse/engine/grouping.py
+++ b/sqlparse/engine/grouping.py
@@ -23,11 +23,15 @@ def _group_left_right(tlist, m, cls,
valid_right=lambda t: t is not None,
semicolon=False):
"""Groups together tokens that are joined by a middle token. ie. x < y"""
- [_group_left_right(sgroup, m, cls, valid_left, valid_right, semicolon)
- for sgroup in tlist.get_sublists() if not isinstance(sgroup, cls)]
- token = tlist.token_next_by(m=m)
- while token:
+ for token in list(tlist):
+ if token.is_group() and not isinstance(token, cls):
+ _group_left_right(token, m, cls, valid_left, valid_right,
+ semicolon)
+
+ if not token.match(*m):
+ continue
+
left, right = tlist.token_prev(token), tlist.token_next(token)
if valid_left(left) and valid_right(right):
@@ -36,8 +40,7 @@ def _group_left_right(tlist, m, cls,
sright = tlist.token_next_by(m=M_SEMICOLON, idx=right)
right = sright or right
tokens = tlist.tokens_between(left, right)
- token = tlist.group_tokens(cls, tokens, extend=True)
- token = tlist.token_next_by(m=m, idx=token)
+ tlist.group_tokens(cls, tokens, extend=True)
def _group_matching(tlist, cls):
@@ -85,11 +88,12 @@ def group_assignment(tlist):
def group_comparison(tlist):
- I_COMPERABLE = (sql.Parenthesis, sql.Function, sql.Identifier)
+ I_COMPERABLE = (sql.Parenthesis, sql.Function, sql.Identifier,
+ sql.Operation)
T_COMPERABLE = T_NUMERICAL + T_STRING + T_NAME
- func = lambda tk: imt(tk, t=T_COMPERABLE, i=I_COMPERABLE) or (
- imt(tk, t=T.Keyword) and tk.value.upper() == 'NULL')
+ func = lambda tk: (imt(tk, t=T_COMPERABLE, i=I_COMPERABLE) or
+ (tk and tk.is_keyword and tk.normalized == 'NULL'))
_group_left_right(tlist, (T.Operator.Comparison, None), sql.Comparison,
valid_left=func, valid_right=func)
@@ -134,9 +138,9 @@ def group_arrays(tlist):
@recurse(sql.Identifier)
def group_operator(tlist):
I_CYCLE = (sql.SquareBrackets, sql.Parenthesis, sql.Function,
- sql.Identifier,) # sql.Operation)
+ sql.Identifier, sql.Operation)
# wildcards wouldn't have operations next to them
- T_CYCLE = T_NUMERICAL + T_STRING + T_NAME # + T.Wildcard
+ T_CYCLE = T_NUMERICAL + T_STRING + T_NAME
func = lambda tk: imt(tk, i=I_CYCLE, t=T_CYCLE)
token = tlist.token_next_by(t=(T.Operator, T.Wildcard))
@@ -146,8 +150,7 @@ def group_operator(tlist):
if func(left) and func(right):
token.ttype = T.Operator
tokens = tlist.tokens_between(left, right)
- # token = tlist.group_tokens(sql.Operation, tokens)
- token = tlist.group_tokens(sql.Identifier, tokens)
+ token = tlist.group_tokens(sql.Operation, tokens)
token = tlist.token_next_by(t=(T.Operator, T.Wildcard), idx=token)
@@ -155,7 +158,7 @@ def group_operator(tlist):
@recurse(sql.IdentifierList)
def group_identifier_list(tlist):
I_IDENT_LIST = (sql.Function, sql.Case, sql.Identifier, sql.Comparison,
- sql.IdentifierList) # sql.Operation
+ sql.IdentifierList, sql.Operation)
T_IDENT_LIST = (T_NUMERICAL + T_STRING + T_NAME +
(T.Keyword, T.Comment, T.Wildcard))
@@ -212,7 +215,7 @@ def group_where(tlist):
@recurse()
def group_aliased(tlist):
I_ALIAS = (sql.Parenthesis, sql.Function, sql.Case, sql.Identifier,
- ) # sql.Operation)
+ sql.Operation)
token = tlist.token_next_by(i=I_ALIAS, t=T.Number)
while token:
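
With group_operator now emitting sql.Operation instead of folding operator expressions into Identifier, something like "a + b" should surface as an Operation group. A minimal check (editorial sketch, assuming this commit's API, where is_group() is still a method):

    # Sketch: recursively collect sql.Operation nodes from a parse tree.
    import sqlparse
    from sqlparse import sql

    def walk(tokens):
        # yield every token in the tree, depth-first
        for tok in tokens:
            yield tok
            if tok.is_group():
                for sub in walk(tok.tokens):
                    yield sub

    stmt = sqlparse.parse('SELECT a + b FROM t')[0]
    ops = [t for t in walk(stmt.tokens) if isinstance(t, sql.Operation)]
    print(ops)  # expected: a single Operation spanning 'a + b'
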
diff --git a/sqlparse/filters/others.py b/sqlparse/filters/others.py
index 6951c74..71b1f8e 100644
--- a/sqlparse/filters/others.py
+++ b/sqlparse/filters/others.py
@@ -6,7 +6,6 @@
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
from sqlparse import sql, tokens as T
-from sqlparse.compat import text_type
from sqlparse.utils import split_unquoted_newlines
@@ -114,6 +113,5 @@ class SpacesAroundOperatorsFilter(object):
class SerializerUnicode(object):
@staticmethod
def process(stmt):
- raw = text_type(stmt)
- lines = split_unquoted_newlines(raw)
+ lines = split_unquoted_newlines(stmt)
return '\n'.join(line.rstrip() for line in lines)
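
This is one half of a small refactor: the text_type() coercion moves out of SerializerUnicode and into split_unquoted_newlines (see the utils.py hunk below), so the splitter now accepts a parsed statement directly. A sketch of the documented behavior (sample SQL made up for illustration):

    # Sketch: newlines inside string literals survive the split.
    from sqlparse.utils import split_unquoted_newlines

    text = "SELECT 'a\nb'\nFROM t"
    print(split_unquoted_newlines(text))
    # expected: ["SELECT 'a\nb'", 'FROM t'] -- the quoted \n is kept
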
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index 43a89e7..cee6af5 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -37,6 +37,10 @@ class Token(object):
def __str__(self):
return self.value
+ # Pending tokenlist __len__ bug fix
+ # def __len__(self):
+ # return len(self.value)
+
def __repr__(self):
cls = self._get_repr_name()
value = self._get_repr_value()
@@ -141,6 +145,10 @@ class TokenList(Token):
def __str__(self):
return ''.join(token.value for token in self.flatten())
+ # Pending tokenlist __len__ bug fix
+ # def __len__(self):
+ # return len(self.tokens)
+
def __iter__(self):
return iter(self.tokens)
@@ -152,12 +160,12 @@ class TokenList(Token):
def _pprint_tree(self, max_depth=None, depth=0, f=None):
"""Pretty-print the object tree."""
- ind = ' ' * (depth * 2)
+ indent = ' | ' * depth
for idx, token in enumerate(self.tokens):
- pre = ' +-' if token.is_group() else ' | '
cls = token._get_repr_name()
value = token._get_repr_value()
- print("{ind}{pre}{idx} {cls} '{value}'".format(**locals()), file=f)
+ print("{indent}{idx:2d} {cls} '{value}'"
+ .format(**locals()), file=f)
if token.is_group() and (max_depth is None or depth < max_depth):
token._pprint_tree(max_depth, depth + 1, f)
@@ -216,20 +224,6 @@ class TokenList(Token):
if func(token):
return token
- def token_first(self, skip_ws=True, skip_cm=False):
- """Returns the first child token.
-
- If *ignore_whitespace* is ``True`` (the default), whitespace
- tokens are ignored.
-
- if *ignore_comments* is ``True`` (default: ``False``), comments are
- ignored too.
- """
- # this on is inconsistent, using Comment instead of T.Comment...
- funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or
- (skip_cm and imt(tk, i=Comment)))
- return self._token_matching(funcs)
-
def token_next_by(self, i=None, m=None, t=None, idx=0, end=None):
funcs = lambda tk: imt(tk, i, m, t)
return self._token_matching(funcs, idx, end)
@@ -242,24 +236,26 @@ class TokenList(Token):
def token_matching(self, idx, funcs):
return self._token_matching(funcs, idx)
- def token_prev(self, idx, skip_ws=True, skip_cm=False):
+ def token_prev(self, idx=0, skip_ws=True, skip_cm=False):
"""Returns the previous token relative to *idx*.
If *skip_ws* is ``True`` (the default) whitespace tokens are ignored.
``None`` is returned if there's no previous token.
"""
funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or
- (skip_cm and imt(tk, t=T.Comment)))
+ (skip_cm and imt(tk, t=T.Comment, i=Comment)))
return self._token_matching(funcs, idx, reverse=True)
- def token_next(self, idx, skip_ws=True, skip_cm=False):
+ def token_next(self, idx=0, skip_ws=True, skip_cm=False):
"""Returns the next token relative to *idx*.
+ If called with the default ``idx=0``, the first child token is returned.
If *skip_ws* is ``True`` (the default) whitespace tokens are ignored.
+ If *skip_cm* is ``True`` (default: ``False``), comments are ignored.
``None`` is returned if there's no next token.
"""
funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or
- (skip_cm and imt(tk, t=T.Comment)))
+ (skip_cm and imt(tk, t=T.Comment, i=Comment)))
return self._token_matching(funcs, idx)
def token_index(self, token, start=0):
@@ -387,7 +383,7 @@ class Statement(TokenList):
Whitespaces and comments at the beginning of the statement
are ignored.
"""
- first_token = self.token_first(skip_cm=True)
+ first_token = self.token_next(skip_cm=True)
if first_token is None:
# An "empty" statement that either has not tokens at all
# or only whitespace tokens.
@@ -425,7 +421,7 @@ class Identifier(TokenList):
def get_typecast(self):
"""Returns the typecast or ``None`` of this object as a string."""
marker = self.token_next_by(m=(T.Punctuation, '::'))
- next_ = self.token_next(marker, False)
+ next_ = self.token_next(marker, skip_ws=False)
return next_.value if next_ else None
def get_ordering(self):
@@ -588,3 +584,7 @@ class Begin(TokenList):
"""A BEGIN/END block."""
M_OPEN = T.Keyword, 'BEGIN'
M_CLOSE = T.Keyword, 'END'
+
+
+class Operation(TokenList):
+ """Grouping of operations"""
diff --git a/sqlparse/utils.py b/sqlparse/utils.py
index 8253e0b..4a8646d 100644
--- a/sqlparse/utils.py
+++ b/sqlparse/utils.py
@@ -9,6 +9,7 @@ import itertools
import re
from collections import deque
from contextlib import contextmanager
+from sqlparse.compat import text_type
# This regular expression replaces the home-cooked parser that was here before.
# It is much faster, but requires an extra post-processing step to get the
@@ -33,11 +34,12 @@ SPLIT_REGEX = re.compile(r"""
LINE_MATCH = re.compile(r'(\r\n|\r|\n)')
-def split_unquoted_newlines(text):
+def split_unquoted_newlines(stmt):
"""Split a string on all unquoted newlines.
Unlike str.splitlines(), this will ignore CR/LF/CR+LF if the requisite
character is inside of a string."""
+ text = text_type(stmt)
lines = SPLIT_REGEX.split(text)
outputlines = ['']
for line in lines: