diff options
-rw-r--r-- | CHANGES.txt | 1 | ||||
-rw-r--r-- | README.md | 1 | ||||
-rw-r--r-- | defusedxml/ElementTree.py | 21 | ||||
-rw-r--r-- | tests.py | 6 |
4 files changed, 23 insertions, 6 deletions
diff --git a/CHANGES.txt b/CHANGES.txt index 3f011de..4671370 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -8,6 +8,7 @@ defusedxml 0.7.0.rc2 - Re-add and deprecate ``defusedxml.cElementTree`` - Use GitHub Actions instead of TravisCI +- Restore ``ElementTree`` attribute of ``xml.etree`` module after patching defusedxml 0.7.0.rc1 -------------------- @@ -722,6 +722,7 @@ See <https://www.python.org/psf/license> for licensing details. - Re-add and deprecate `defusedxml.cElementTree` - Use GitHub Actions instead of TravisCI + - Restore `ElementTree` attribute of `xml.etree` module after patching ## defusedxml 0.7.0.rc1 diff --git a/defusedxml/ElementTree.py b/defusedxml/ElementTree.py index b1504e4..55c123e 100644 --- a/defusedxml/ElementTree.py +++ b/defusedxml/ElementTree.py @@ -45,12 +45,21 @@ def _get_py3_cls(): cmod = sys.modules.pop(cmodname, None) sys.modules[cmodname] = None - pure_pymod = importlib.import_module(pymodname) - if cmod is not None: - sys.modules[cmodname] = cmod - else: - sys.modules.pop(cmodname) - sys.modules[pymodname] = pymod + try: + pure_pymod = importlib.import_module(pymodname) + finally: + # restore module + sys.modules[pymodname] = pymod + if cmod is not None: + sys.modules[cmodname] = cmod + else: + sys.modules.pop(cmodname, None) + # restore attribute on original package + etree_pkg = sys.modules["xml.etree"] + if pymod is not None: + etree_pkg.ElementTree = pymod + elif hasattr(etree_pkg, "ElementTree"): + del etree_pkg.ElementTree _XMLParser = pure_pymod.XMLParser _iterparse = pure_pymod.iterparse @@ -6,6 +6,7 @@ import sys import unittest import warnings +from xml.etree import ElementTree as orig_elementtree from xml.sax.saxutils import XMLGenerator from xml.sax import SAXParseException from pyexpat import ExpatError @@ -208,6 +209,11 @@ class TestDefusedElementTree(BaseTests): assert self.module.XMLParser is parser assert self.module.XMLParse is parser + def test_import_order(self): + from xml.etree import ElementTree as second_elementtree + + self.assertIs(orig_elementtree, second_elementtree) + class TestDefusedcElementTree(TestDefusedElementTree): module = cElementTree |