summaryrefslogtreecommitdiff
path: root/pylint/test/unittest_pyreverse_writer.py
blob: 4fa964c2807bad476ece61041b7a41d701f9f9b2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright (c) 2003-2016 LOGILAB S.A. (Paris, FRANCE).
# http://www.logilab.fr/ -- mailto:contact@logilab.fr
# Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
# For details: https://github.com/PyCQA/pylint/blob/master/COPYING

"""
unittest for the pyreverse.writer module (DotWriter output) and pyreverse.utils.get_visibility
"""


import os
import sys
import codecs
from os.path import join, dirname, abspath
from difflib import unified_diff
import unittest

from astroid import MANAGER

from pylint.pyreverse.inspector import Linker, project_from_files
from pylint.pyreverse.diadefslib import DefaultDiadefGenerator, DiadefsHandler
from pylint.pyreverse.writer import DotWriter
from pylint.pyreverse.utils import get_visibility


# Option values handed to the pyreverse machinery by the tests in this module.
_DEFAULTS = {
    'all_ancestors': None,
    'show_associated': None,
    'module_names': None,
    'output_format': 'dot',
    'diadefs_file': None,
    'quiet': 0,
    'show_ancestors': None,
    'classes': (),
    'all_associated': None,
    'mode': 'PUB_ONLY',
    'show_builtin': False,
    'only_classnames': False,
}


class Config(object):
    """Lightweight options object for the tests.

    Each key of ``_DEFAULTS`` is exposed as an instance attribute carrying
    the matching default value.
    """

    def __init__(self):
        # Copy every default onto the instance in one step.
        self.__dict__.update(_DEFAULTS)


def _file_lines(path):
    # we don't care about the actual encoding, but python3 forces us to pick one
    with codecs.open(path, encoding='latin1') as stream:
        lines = [line.strip() for line in stream.readlines()
                 if (line.find('squeleton generated by ') == -1 and
                     not line.startswith('__revision__ = "$Id:'))]
    return [line for line in lines if line]

def get_project(module, name="No Name"):
    """Build and return an astroid project for *module*, named *name*."""
    # The wrapper passed to project_from_files receives a function and a
    # module name, and simply applies the one to the other.
    return project_from_files(
        [module],
        lambda func, modname: func(modname),
        project_name=name,
    )


# Shared default options object used by the diagram-writing tests.
CONFIG = Config()

class DotWriterTC(unittest.TestCase):
    """Compare the .dot files produced by DotWriter with reference copies.

    setUpClass builds the default diagrams for the test ``data`` package and
    writes them with DotWriter; each test then diffs one generated file
    against the file of the same name stored under ``data``.
    """

    @classmethod
    def setUpClass(cls):
        project = get_project(os.path.join(os.path.dirname(__file__), 'data'))
        linker = Linker(project)
        handler = DiadefsHandler(CONFIG)
        dd = DefaultDiadefGenerator(linker, handler).visit(project)
        for diagram in dd:
            diagram.extract_relationships()
        # NOTE(review): the tests read the output back by bare file name,
        # i.e. relative to the current working directory, so the writer is
        # presumably emitting files there — confirm against DotWriter.
        writer = DotWriter(CONFIG)
        writer.write(dd)

    @classmethod
    def tearDownClass(cls):
        # Each passing test already deletes its generated file, so a missing
        # file here is expected. Only OS-level removal errors are ignored;
        # a bare ``except`` would also swallow KeyboardInterrupt/SystemExit.
        for fname in ('packages_No_Name.dot', 'classes_No_Name.dot'):
            try:
                os.remove(fname)
            except OSError:
                pass

    def _test_same_file(self, generated_file):
        """Assert that *generated_file* matches its reference copy in ``data``.

        On mismatch the failure message names both files and includes a
        unified diff. The generated file is removed only when they match;
        otherwise tearDownClass cleans it up.
        """
        expected_file = os.path.join(os.path.dirname(__file__), 'data', generated_file)
        generated = '\n'.join(_file_lines(generated_file))
        expected = '\n'.join(_file_lines(expected_file))
        files = "\n *** expected : %s, generated : %s \n" % (
            expected_file, generated_file)
        # The inputs carry no trailing newlines, so lineterm='' keeps the
        # diff header lines newline-free as well; otherwise they would be
        # double-spaced once joined with '\n'.
        diff = '\n'.join(unified_diff(expected.splitlines(),
                                      generated.splitlines(), lineterm=''))
        self.assertEqual(expected, generated, '%s%s' % (files, diff))
        os.remove(generated_file)

    def test_package_diagram(self):
        self._test_same_file('packages_No_Name.dot')

    def test_class_diagram(self):
        self._test_same_file('classes_No_Name.dot')



class GetVisibilityTC(unittest.TestCase):
    """Check how get_visibility classifies attribute and method names."""

    def test_special(self):
        special_names = ("__reduce_ex__", "__setattr__")
        for method_name in special_names:
            self.assertEqual(get_visibility(method_name), 'special')

    def test_private(self):
        private_names = ("__g_", "____dsf", "__23_9")
        for attr_name in private_names:
            visibility = get_visibility(attr_name)
            failure = 'got %s instead of private for value %s' % (visibility, attr_name)
            self.assertEqual(visibility, 'private', failure)

    def test_public(self):
        visibility = get_visibility('simple')
        self.assertEqual(visibility, 'public')

    def test_protected(self):
        # Single-underscore prefixes, all-underscore names and mixes such as
        # "___e__" are all expected to be reported as protected.
        protected_names = ("_", "__", "___", "____", "_____", "___e__",
                           "_nextsimple", "_filter_it_")
        for attr_name in protected_names:
            visibility = get_visibility(attr_name)
            failure = 'got %s instead of protected for value %s' % (visibility, attr_name)
            self.assertEqual(visibility, 'protected', failure)


if __name__ == '__main__':
    unittest.main()