diff options
author | Bob Halley <halley@dnspython.org> | 2023-03-14 13:26:51 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-14 13:26:51 -0700 |
commit | 1aeec724042b0af114ebf3f9966e7ea4be1dad33 (patch) | |
tree | 434e9140b3c67b44dd7a9bf909cd75a82902719b | |
parent | 8933f303adfc99f72f4985f10507e1b7c64bcc0c (diff) | |
download | dnspython-1aeec724042b0af114ebf3f9966e7ea4be1dad33.tar.gz |
Add names iteration to transactions via iterate_names(). (#907)
Also make rdataset iteration more obvious by adding an
explicit iterate_rdatasets() API.
-rw-r--r-- | dns/transaction.py | 27 | ||||
-rw-r--r-- | dns/zone.py | 14 | ||||
-rw-r--r-- | dns/zonefile.py | 5 | ||||
-rw-r--r-- | tests/test_transaction.py | 25 |
4 files changed, 58 insertions, 13 deletions
diff --git a/dns/transaction.py b/dns/transaction.py index c4a9e1f..91ed732 100644 --- a/dns/transaction.py +++ b/dns/transaction.py @@ -1,6 +1,6 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Iterator, List, Optional, Tuple, Union import collections @@ -357,6 +357,27 @@ class Transaction: """ self._check_delete_name.append(check) + def iterate_rdatasets( + self, + ) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]: + """Iterate all the rdatasets in the transaction, returning + (`dns.name.Name`, `dns.rdataset.Rdataset`) tuples. + + Note that as is usual with python iterators, adding or removing items + while iterating will invalidate the iterator and may raise `RuntimeError` + or fail to iterate over all entries.""" + self._check_ended() + return self._iterate_rdatasets() + + def iterate_names(self) -> Iterator[dns.name.Name]: + """Iterate all the names in the transaction. + + Note that as is usual with python iterators, adding or removing names + while iterating will invalidate the iterator and may raise `RuntimeError` + or fail to iterate over all entries.""" + self._check_ended() + return self._iterate_names() + # # Helper methods # @@ -610,6 +631,10 @@ class Transaction: """Return an iterator that yields (name, rdataset) tuples.""" raise NotImplementedError # pragma: no cover + def _iterate_names(self): + """Return an iterator that yields a name.""" + raise NotImplementedError # pragma: no cover + def _get_node(self, name): """Return the node at *name*, if any. diff --git a/dns/zone.py b/dns/zone.py index cc8268d..35724d7 100644 --- a/dns/zone.py +++ b/dns/zone.py @@ -565,7 +565,7 @@ class Zone(dns.transaction.TransactionManager): rdtype = dns.rdatatype.RdataType.make(rdtype) covers = dns.rdatatype.RdataType.make(covers) - for (name, node) in self.items(): + for name, node in self.items(): for rds in node: if rdtype == dns.rdatatype.ANY or ( rds.rdtype == rdtype and rds.covers == covers @@ -597,7 +597,7 @@ class Zone(dns.transaction.TransactionManager): rdtype = dns.rdatatype.RdataType.make(rdtype) covers = dns.rdatatype.RdataType.make(covers) - for (name, node) in self.items(): + for name, node in self.items(): for rds in node: if rdtype == dns.rdatatype.ANY or ( rds.rdtype == rdtype and rds.covers == covers @@ -795,7 +795,7 @@ class Zone(dns.transaction.TransactionManager): assert self.origin is not None origin_name = self.origin hasher = hashinfo() - for (name, node) in sorted(self.items()): + for name, node in sorted(self.items()): rrnamebuf = name.to_digestable(self.origin) for rdataset in sorted(node, key=lambda rds: (rds.rdtype, rds.covers)): if name == origin_name and dns.rdatatype.ZONEMD in ( @@ -997,6 +997,9 @@ class Version: return None return node.get_rdataset(self.zone.rdclass, rdtype, covers) + def keys(self): + return self.nodes.keys() + def items(self): return self.nodes.items() @@ -1143,10 +1146,13 @@ class Transaction(dns.transaction.Transaction): self.version.origin = origin def _iterate_rdatasets(self): - for (name, node) in self.version.items(): + for name, node in self.version.items(): for rdataset in node: yield (name, rdataset) + def _iterate_names(self): + return self.version.keys() + def _get_node(self, name): return self.version.get_node(name) diff --git a/dns/zonefile.py b/dns/zonefile.py index 1a53f5b..fad78c3 100644 --- a/dns/zonefile.py +++ b/dns/zonefile.py @@ -581,7 +581,7 @@ class RRsetsReaderTransaction(dns.transaction.Transaction): pass def _name_exists(self, name): - for (n, _, _) in self.rdatasets: + for n, _, _ in self.rdatasets: if n == name: return True return False @@ -606,6 +606,9 @@ class RRsetsReaderTransaction(dns.transaction.Transaction): def _iterate_rdatasets(self): raise NotImplementedError # pragma: no cover + def _iterate_names(self): + raise NotImplementedError # pragma: no cover + class RRSetsReaderManager(dns.transaction.TransactionManager): def __init__( diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 8e2744a..80559bd 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -499,12 +499,23 @@ def test_zone_ooz_name(zone): def test_zone_iteration(zone): expected = {} - for (name, rdataset) in zone.iterate_rdatasets(): + for name, rdataset in zone.iterate_rdatasets(): expected[(name, rdataset.rdtype, rdataset.covers)] = rdataset with zone.writer() as txn: - actual = {} - for (name, rdataset) in txn: - actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset + actual1 = {} + for name, rdataset in txn: + actual1[(name, rdataset.rdtype, rdataset.covers)] = rdataset + actual2 = {} + for name, rdataset in txn.iterate_rdatasets(): + actual2[(name, rdataset.rdtype, rdataset.covers)] = rdataset + assert actual1 == expected + assert actual2 == expected + + +def test_zone_name_iteration(zone): + expected = list(zone.keys()) + with zone.writer() as txn: + actual = list(txn.iterate_names()) assert actual == expected @@ -515,7 +526,7 @@ def test_iteration_in_replacement_txn(zone): with zone.writer(True) as txn: txn.replace(dns.name.empty, rds) actual = {} - for (name, rdataset) in txn: + for name, rdataset in txn: actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset assert actual == expected @@ -528,7 +539,7 @@ def test_replacement_commit(zone): txn.replace(dns.name.empty, rds) with zone.reader() as txn: actual = {} - for (name, rdataset) in txn: + for name, rdataset in txn: actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset assert actual == expected @@ -592,7 +603,7 @@ def test_vzone_multiple_versions(vzone): def _dump(zone): for v in zone._versions: print("VERSION", v.id) - for (name, n) in v.nodes.items(): + for name, n in v.nodes.items(): for rdataset in n: print(rdataset.to_text(name)) |