summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sqlparse/sql.py26
-rw-r--r--tests/test_parse.py24
2 files changed, 50 insertions, 0 deletions
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index b5ba902..484b0f0 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -93,6 +93,32 @@ class Token(object):
"""Return ``True`` if this token is a whitespace token."""
return self.ttype and self.ttype in T.Whitespace
+ def within(self, group_cls):
+ """Returns ``True`` if this token is within *group_cls*.
+
+ Use this method for example to check if an identifier is within
+ a function: ``t.within(sql.Function)``.
+ """
+ parent = self.parent
+ while parent:
+ if isinstance(parent, group_cls):
+ return True
+ parent = parent.parent
+ return False
+
+ def is_child_of(self, other):
+ """Returns ``True`` if this token is a direct child of *other*."""
+ return self.parent == other
+
+ def has_ancestor(self, other):
+ """Returns ``True`` if *other* is in this tokens ancestry."""
+ parent = self.parent
+ while parent:
+ if parent == other:
+ return True
+ parent = parent.parent
+ return False
+
class TokenList(Token):
"""A group of tokens.
diff --git a/tests/test_parse.py b/tests/test_parse.py
index 3ee2203..d14e329 100644
--- a/tests/test_parse.py
+++ b/tests/test_parse.py
@@ -5,6 +5,7 @@
from tests.utils import TestCaseBase
import sqlparse
+import sqlparse.sql
class SQLParseTest(TestCaseBase):
@@ -37,3 +38,26 @@ class SQLParseTest(TestCaseBase):
sql = u'select\r\n*from foo\n'
p = sqlparse.parse(sql)[0]
self.assertEqual(unicode(p), sql)
+
+ def test_within(self):
+ sql = 'foo(col1, col2)'
+ p = sqlparse.parse(sql)[0]
+ col1 = p.tokens[0].tokens[1].tokens[1].tokens[0]
+ self.assert_(col1.within(sqlparse.sql.Function))
+
+ def test_child_of(self):
+ sql = '(col1, col2)'
+ p = sqlparse.parse(sql)[0]
+ self.assert_(p.tokens[0].tokens[1].is_child_of(p.tokens[0]))
+ sql = 'select foo'
+ p = sqlparse.parse(sql)[0]
+ self.assert_(not p.tokens[2].is_child_of(p.tokens[0]))
+ self.assert_(p.tokens[2].is_child_of(p))
+
+ def test_has_ancestor(self):
+ sql = 'foo or (bar, baz)'
+ p = sqlparse.parse(sql)[0]
+ baz = p.tokens[-1].tokens[1].tokens[-1]
+ self.assert_(baz.has_ancestor(p.tokens[-1].tokens[1]))
+ self.assert_(baz.has_ancestor(p.tokens[-1]))
+ self.assert_(baz.has_ancestor(p))