summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2020-08-17 06:27:07 -0700
committerBob Halley <halley@dnspython.org>2020-08-17 06:27:07 -0700
commite5545ade37d15ebf87fac626e602ac1ad9852be2 (patch)
treed3a0cd660080d003cbcb28b7b4378f0577fbc64f
parent1359d2465f2749fa1b525ff39f2098fd7aef0c5e (diff)
downloaddnspython-e5545ade37d15ebf87fac626e602ac1ad9852be2.tar.gz
Update _clone protocol for immutable rdatasets.
-rw-r--r--dns/rdataset.py17
-rw-r--r--dns/set.py8
-rw-r--r--tests/test_rdataset.py9
3 files changed, 32 insertions, 2 deletions
diff --git a/dns/rdataset.py b/dns/rdataset.py
index 1f372cd..10cb252 100644
--- a/dns/rdataset.py
+++ b/dns/rdataset.py
@@ -312,6 +312,8 @@ class ImmutableRdataset(Rdataset):
"""An immutable DNS rdataset."""
+ _clone_class = Rdataset
+
def __init__(self, rdataset):
"""Create an immutable rdataset from the specified rdataset."""
@@ -352,6 +354,21 @@ class ImmutableRdataset(Rdataset):
def clear(self):
raise TypeError('immutable')
+ def __copy__(self):
+ return ImmutableRdataset(super().copy())
+
+ def copy(self):
+ return ImmutableRdataset(super().copy())
+
+ def union(self, other):
+ return ImmutableRdataset(super().union(other))
+
+ def intersection(self, other):
+ return ImmutableRdataset(super().intersection(other))
+
+ def difference(self, other):
+ return ImmutableRdataset(super().difference(other))
+
def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None,
origin=None, relativize=True, relativize_to=None):
diff --git a/dns/set.py b/dns/set.py
index 0982d78..1fd4d0a 100644
--- a/dns/set.py
+++ b/dns/set.py
@@ -84,9 +84,13 @@ class Set:
subclasses.
"""
- cls = self.__class__
+ if hasattr(self, '_clone_class'):
+ cls = self._clone_class
+ else:
+ cls = self.__class__
obj = cls.__new__(cls)
- obj.items = self.items.copy()
+ obj.items = odict()
+ obj.items.update(self.items)
return obj
def __copy__(self):
diff --git a/tests/test_rdataset.py b/tests/test_rdataset.py
index 88b4840..4710e2a 100644
--- a/tests/test_rdataset.py
+++ b/tests/test_rdataset.py
@@ -151,5 +151,14 @@ class ImmutableRdatasetTestCase(unittest.TestCase):
with self.assertRaises(TypeError):
irds.clear()
+ def test_cloning(self):
+ rds1 = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1', '10.0.0.2')
+ rds1 = dns.rdataset.ImmutableRdataset(rds1)
+ rds2 = dns.rdataset.from_text('in', 'a', 300, '10.0.0.2', '10.0.0.3')
+ rds2 = dns.rdataset.ImmutableRdataset(rds2)
+ expected = dns.rdataset.from_text('in', 'a', 300, '10.0.0.2')
+ intersection = rds1.intersection(rds2)
+ self.assertEqual(intersection, expected)
+
if __name__ == '__main__':
unittest.main()