summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2023-03-12 07:55:36 -0700
committerBob Halley <halley@dnspython.org>2023-03-12 08:19:48 -0700
commitdc2355b2d91b753fcc45388dac150f8efce537a8 (patch)
tree6f04ee24f7cf36a4345a0c10f7c554887a1b97cf
parente24a50c7105299b5d015c253fc0fbf7243877cdf (diff)
downloaddnspython-txn-names.tar.gz
Add names iteration to transactions via iterate_names().txn-names
Also make rdataset iteration more obvious by adding an explicit iterate_rdatasets() API.
-rw-r--r--dns/transaction.py27
-rw-r--r--dns/zone.py14
-rw-r--r--dns/zonefile.py5
-rw-r--r--tests/test_transaction.py25
4 files changed, 58 insertions, 13 deletions
diff --git a/dns/transaction.py b/dns/transaction.py
index c4a9e1f..91ed732 100644
--- a/dns/transaction.py
+++ b/dns/transaction.py
@@ -1,6 +1,6 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
import collections
@@ -357,6 +357,27 @@ class Transaction:
"""
self._check_delete_name.append(check)
+ def iterate_rdatasets(
+ self,
+ ) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]:
+ """Iterate all the rdatasets in the transaction, returning
+ (`dns.name.Name`, `dns.rdataset.Rdataset`) tuples.
+
+ Note that as is usual with python iterators, adding or removing items
+ while iterating will invalidate the iterator and may raise `RuntimeError`
+ or fail to iterate over all entries."""
+ self._check_ended()
+ return self._iterate_rdatasets()
+
+ def iterate_names(self) -> Iterator[dns.name.Name]:
+ """Iterate all the names in the transaction.
+
+ Note that as is usual with python iterators, adding or removing names
+ while iterating will invalidate the iterator and may raise `RuntimeError`
+ or fail to iterate over all entries."""
+ self._check_ended()
+ return self._iterate_names()
+
#
# Helper methods
#
@@ -610,6 +631,10 @@ class Transaction:
"""Return an iterator that yields (name, rdataset) tuples."""
raise NotImplementedError # pragma: no cover
+ def _iterate_names(self):
+ """Return an iterator that yields a name."""
+ raise NotImplementedError # pragma: no cover
+
def _get_node(self, name):
"""Return the node at *name*, if any.
diff --git a/dns/zone.py b/dns/zone.py
index cc8268d..35724d7 100644
--- a/dns/zone.py
+++ b/dns/zone.py
@@ -565,7 +565,7 @@ class Zone(dns.transaction.TransactionManager):
rdtype = dns.rdatatype.RdataType.make(rdtype)
covers = dns.rdatatype.RdataType.make(covers)
- for (name, node) in self.items():
+ for name, node in self.items():
for rds in node:
if rdtype == dns.rdatatype.ANY or (
rds.rdtype == rdtype and rds.covers == covers
@@ -597,7 +597,7 @@ class Zone(dns.transaction.TransactionManager):
rdtype = dns.rdatatype.RdataType.make(rdtype)
covers = dns.rdatatype.RdataType.make(covers)
- for (name, node) in self.items():
+ for name, node in self.items():
for rds in node:
if rdtype == dns.rdatatype.ANY or (
rds.rdtype == rdtype and rds.covers == covers
@@ -795,7 +795,7 @@ class Zone(dns.transaction.TransactionManager):
assert self.origin is not None
origin_name = self.origin
hasher = hashinfo()
- for (name, node) in sorted(self.items()):
+ for name, node in sorted(self.items()):
rrnamebuf = name.to_digestable(self.origin)
for rdataset in sorted(node, key=lambda rds: (rds.rdtype, rds.covers)):
if name == origin_name and dns.rdatatype.ZONEMD in (
@@ -997,6 +997,9 @@ class Version:
return None
return node.get_rdataset(self.zone.rdclass, rdtype, covers)
+ def keys(self):
+ return self.nodes.keys()
+
def items(self):
return self.nodes.items()
@@ -1143,10 +1146,13 @@ class Transaction(dns.transaction.Transaction):
self.version.origin = origin
def _iterate_rdatasets(self):
- for (name, node) in self.version.items():
+ for name, node in self.version.items():
for rdataset in node:
yield (name, rdataset)
+ def _iterate_names(self):
+ return self.version.keys()
+
def _get_node(self, name):
return self.version.get_node(name)
diff --git a/dns/zonefile.py b/dns/zonefile.py
index 1a53f5b..fad78c3 100644
--- a/dns/zonefile.py
+++ b/dns/zonefile.py
@@ -581,7 +581,7 @@ class RRsetsReaderTransaction(dns.transaction.Transaction):
pass
def _name_exists(self, name):
- for (n, _, _) in self.rdatasets:
+ for n, _, _ in self.rdatasets:
if n == name:
return True
return False
@@ -606,6 +606,9 @@ class RRsetsReaderTransaction(dns.transaction.Transaction):
def _iterate_rdatasets(self):
raise NotImplementedError # pragma: no cover
+ def _iterate_names(self):
+ raise NotImplementedError # pragma: no cover
+
class RRSetsReaderManager(dns.transaction.TransactionManager):
def __init__(
diff --git a/tests/test_transaction.py b/tests/test_transaction.py
index 8e2744a..80559bd 100644
--- a/tests/test_transaction.py
+++ b/tests/test_transaction.py
@@ -499,12 +499,23 @@ def test_zone_ooz_name(zone):
def test_zone_iteration(zone):
expected = {}
- for (name, rdataset) in zone.iterate_rdatasets():
+ for name, rdataset in zone.iterate_rdatasets():
expected[(name, rdataset.rdtype, rdataset.covers)] = rdataset
with zone.writer() as txn:
- actual = {}
- for (name, rdataset) in txn:
- actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+ actual1 = {}
+ for name, rdataset in txn:
+ actual1[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+ actual2 = {}
+ for name, rdataset in txn.iterate_rdatasets():
+ actual2[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+ assert actual1 == expected
+ assert actual2 == expected
+
+
+def test_zone_name_iteration(zone):
+ expected = list(zone.keys())
+ with zone.writer() as txn:
+ actual = list(txn.iterate_names())
assert actual == expected
@@ -515,7 +526,7 @@ def test_iteration_in_replacement_txn(zone):
with zone.writer(True) as txn:
txn.replace(dns.name.empty, rds)
actual = {}
- for (name, rdataset) in txn:
+ for name, rdataset in txn:
actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset
assert actual == expected
@@ -528,7 +539,7 @@ def test_replacement_commit(zone):
txn.replace(dns.name.empty, rds)
with zone.reader() as txn:
actual = {}
- for (name, rdataset) in txn:
+ for name, rdataset in txn:
actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset
assert actual == expected
@@ -592,7 +603,7 @@ def test_vzone_multiple_versions(vzone):
def _dump(zone):
for v in zone._versions:
print("VERSION", v.id)
- for (name, n) in v.nodes.items():
+ for name, n in v.nodes.items():
for rdataset in n:
print(rdataset.to_text(name))