diff options
| -rw-r--r-- | sqlparse/sql.py | 26 | ||||
| -rw-r--r-- | tests/test_parse.py | 24 |
2 files changed, 50 insertions, 0 deletions
diff --git a/sqlparse/sql.py b/sqlparse/sql.py index b5ba902..484b0f0 100644 --- a/sqlparse/sql.py +++ b/sqlparse/sql.py @@ -93,6 +93,32 @@ class Token(object): """Return ``True`` if this token is a whitespace token.""" return self.ttype and self.ttype in T.Whitespace + def within(self, group_cls): + """Returns ``True`` if this token is within *group_cls*. + + Use this method for example to check if an identifier is within + a function: ``t.within(sql.Function)``. + """ + parent = self.parent + while parent: + if isinstance(parent, group_cls): + return True + parent = parent.parent + return False + + def is_child_of(self, other): + """Returns ``True`` if this token is a direct child of *other*.""" + return self.parent == other + + def has_ancestor(self, other): + """Returns ``True`` if *other* is in this tokens ancestry.""" + parent = self.parent + while parent: + if parent == other: + return True + parent = parent.parent + return False + class TokenList(Token): """A group of tokens. diff --git a/tests/test_parse.py b/tests/test_parse.py index 3ee2203..d14e329 100644 --- a/tests/test_parse.py +++ b/tests/test_parse.py @@ -5,6 +5,7 @@ from tests.utils import TestCaseBase import sqlparse +import sqlparse.sql class SQLParseTest(TestCaseBase): @@ -37,3 +38,26 @@ class SQLParseTest(TestCaseBase): sql = u'select\r\n*from foo\n' p = sqlparse.parse(sql)[0] self.assertEqual(unicode(p), sql) + + def test_within(self): + sql = 'foo(col1, col2)' + p = sqlparse.parse(sql)[0] + col1 = p.tokens[0].tokens[1].tokens[1].tokens[0] + self.assert_(col1.within(sqlparse.sql.Function)) + + def test_child_of(self): + sql = '(col1, col2)' + p = sqlparse.parse(sql)[0] + self.assert_(p.tokens[0].tokens[1].is_child_of(p.tokens[0])) + sql = 'select foo' + p = sqlparse.parse(sql)[0] + self.assert_(not p.tokens[2].is_child_of(p.tokens[0])) + self.assert_(p.tokens[2].is_child_of(p)) + + def test_has_ancestor(self): + sql = 'foo or (bar, baz)' + p = sqlparse.parse(sql)[0] + baz = p.tokens[-1].tokens[1].tokens[-1] + self.assert_(baz.has_ancestor(p.tokens[-1].tokens[1])) + self.assert_(baz.has_ancestor(p.tokens[-1])) + self.assert_(baz.has_ancestor(p)) |
