diff options
author | Brian Wellington <bwelling@xbill.org> | 2023-04-06 06:03:07 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-06 06:03:07 -0700 |
commit | cee55728316820d037ece45b11d2fed655a2754c (patch) | |
tree | 1030a420d5ca24a135b1da6b6237dc2632fa2651 | |
parent | c7613943d93986ac23cccff7350c42a32fa5b3bd (diff) | |
download | dnspython-cee55728316820d037ece45b11d2fed655a2754c.tar.gz |
Enum typing (#923)
* IntEnum improvements.
This changes make() to always return an instance of the subclass,
creating one on the fly if the value is not known, and updates the typ
registration code to deal with this. It also adds typing annotations to
make().
* Add missing int check.
Some older versions of python weren't rejecting non-int values.
* Fix int check.
Raise TypeError for non-int, not ValueError, to make tests happy.
* Annotate to_text/from_text.
* Remove many the_ prefixed variables.
These were needed in the past to work around typing issues.
-rw-r--r-- | dns/dnssec.py | 6 | ||||
-rw-r--r-- | dns/edns.py | 8 | ||||
-rw-r--r-- | dns/enum.py | 27 | ||||
-rw-r--r-- | dns/message.py | 8 | ||||
-rw-r--r-- | dns/rdata.py | 19 | ||||
-rw-r--r-- | dns/rdataset.py | 6 | ||||
-rw-r--r-- | dns/resolver.py | 12 | ||||
-rw-r--r-- | dns/rrset.py | 6 | ||||
-rw-r--r-- | dns/update.py | 4 | ||||
-rw-r--r-- | dns/zone.py | 30 | ||||
-rw-r--r-- | dns/zonefile.py | 16 |
11 files changed, 77 insertions, 65 deletions
diff --git a/dns/dnssec.py b/dns/dnssec.py index 3caa22b..c219965 100644 --- a/dns/dnssec.py +++ b/dns/dnssec.py @@ -948,9 +948,9 @@ def _make_dnskey( else: raise ValueError("unsupported ECDSA curve") - the_algorithm = Algorithm.make(algorithm) + algorithm = Algorithm.make(algorithm) - _ensure_algorithm_key_combination(the_algorithm, public_key) + _ensure_algorithm_key_combination(algorithm, public_key) if isinstance(public_key, rsa.RSAPublicKey): key_bytes = encode_rsa_public_key(public_key) @@ -974,7 +974,7 @@ def _make_dnskey( rdtype=dns.rdatatype.DNSKEY, flags=flags, protocol=protocol, - algorithm=the_algorithm, + algorithm=algorithm, key=key_bytes, ) diff --git a/dns/edns.py b/dns/edns.py index 64436cd..40899ee 100644 --- a/dns/edns.py +++ b/dns/edns.py @@ -380,7 +380,7 @@ class EDEOption(Option): # lgtm[py/missing-equals] def from_wire_parser( cls, otype: Union[OptionType, str], parser: "dns.wire.Parser" ) -> Option: - the_code = EDECode.make(parser.get_uint16()) + code = EDECode.make(parser.get_uint16()) text = parser.get_remaining() if text: @@ -390,7 +390,7 @@ class EDEOption(Option): # lgtm[py/missing-equals] else: btext = None - return cls(the_code, btext) + return cls(code, btext) _type_to_class: Dict[OptionType, Any] = { @@ -424,8 +424,8 @@ def option_from_wire_parser( Returns an instance of a subclass of ``dns.edns.Option``. """ - the_otype = OptionType.make(otype) - cls = get_option_class(the_otype) + otype = OptionType.make(otype) + cls = get_option_class(otype) return cls.from_wire_parser(otype, parser) diff --git a/dns/enum.py b/dns/enum.py index b5a4aed..968363a 100644 --- a/dns/enum.py +++ b/dns/enum.py @@ -15,19 +15,33 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +from typing import Type, TypeVar, Union + import enum +TIntEnum = TypeVar("TIntEnum", bound="IntEnum") + class IntEnum(enum.IntEnum): @classmethod + def _missing_(cls, value): + cls._check_value(value) + val = int.__new__(cls, value) + val._name_ = cls._extra_to_text(value, None) or f"{cls._prefix()}{value}" + val._value_ = value + return val + + @classmethod def _check_value(cls, value): max = cls._maximum() + if not isinstance(value, int): + raise TypeError if value < 0 or value > max: name = cls._short_name() - raise ValueError(f"{name} must be between >= 0 and <= {max}") + raise ValueError(f"{name} must be an int between >= 0 and <= {max}") @classmethod - def from_text(cls, text): + def from_text(cls : Type[TIntEnum], text: str) -> TIntEnum: text = text.upper() try: return cls[text] @@ -47,7 +61,7 @@ class IntEnum(enum.IntEnum): raise cls._unknown_exception_class() @classmethod - def to_text(cls, value): + def to_text(cls : Type[TIntEnum], value : int) -> str: cls._check_value(value) try: text = cls(value).name @@ -59,7 +73,7 @@ class IntEnum(enum.IntEnum): return text @classmethod - def make(cls, value): + def make(cls: Type[TIntEnum], value: Union[int, str]) -> TIntEnum: """Convert text or a value into an enumerated type, if possible. *value*, the ``int`` or ``str`` to convert. @@ -76,10 +90,7 @@ class IntEnum(enum.IntEnum): if isinstance(value, str): return cls.from_text(value) cls._check_value(value) - try: - return cls(value) - except ValueError: - return value + return cls(value) @classmethod def _maximum(cls): diff --git a/dns/message.py b/dns/message.py index 3a6f427..2ccdc2b 100644 --- a/dns/message.py +++ b/dns/message.py @@ -1730,13 +1730,11 @@ def make_query( if isinstance(qname, str): qname = dns.name.from_text(qname, idna_codec=idna_codec) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - the_rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) + rdclass = dns.rdataclass.RdataClass.make(rdclass) m = QueryMessage(id=id) m.flags = dns.flags.Flag(flags) - m.find_rrset( - m.question, qname, the_rdclass, the_rdtype, create=True, force_unique=True - ) + m.find_rrset(m.question, qname, rdclass, rdtype, create=True, force_unique=True) # only pass keywords on to use_edns if they have been set to a # non-None value. Setting a field will turn EDNS on if it hasn't # been configured. diff --git a/dns/rdata.py b/dns/rdata.py index d166b8a..66c07ee 100644 --- a/dns/rdata.py +++ b/dns/rdata.py @@ -880,16 +880,19 @@ def register_type( it applies to all classes. """ - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - existing_cls = get_rdata_class(rdclass, the_rdtype) - if existing_cls != GenericRdata or dns.rdatatype.is_metatype(the_rdtype): - raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype) + rdtype = dns.rdatatype.RdataType.make(rdtype) + existing_cls = get_rdata_class(rdclass, rdtype) + if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype): + raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype) try: - if dns.rdatatype.RdataType(the_rdtype).name != rdtype_text: - raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype) + if ( + rdtype in dns.rdatatype.RdataType + and dns.rdatatype.RdataType(rdtype).name != rdtype_text + ): + raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype) except ValueError: pass - _rdata_classes[(rdclass, the_rdtype)] = getattr( + _rdata_classes[(rdclass, rdtype)] = getattr( implementation, rdtype_text.replace("-", "_") ) - dns.rdatatype.register_type(the_rdtype, rdtype_text, is_singleton) + dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton) diff --git a/dns/rdataset.py b/dns/rdataset.py index c0ede42..b562d1f 100644 --- a/dns/rdataset.py +++ b/dns/rdataset.py @@ -471,9 +471,9 @@ def from_text_list( Returns a ``dns.rdataset.Rdataset`` object. """ - the_rdclass = dns.rdataclass.RdataClass.make(rdclass) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - r = Rdataset(the_rdclass, the_rdtype) + rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) + r = Rdataset(rdclass, rdtype) r.update_ttl(ttl) for t in text_rdatas: rd = dns.rdata.from_text( diff --git a/dns/resolver.py b/dns/resolver.py index cd041d9..61d0052 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -647,17 +647,17 @@ class _Resolution: ) -> None: if isinstance(qname, str): qname = dns.name.from_text(qname, None) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - if dns.rdatatype.is_metatype(the_rdtype): + rdtype = dns.rdatatype.RdataType.make(rdtype) + if dns.rdatatype.is_metatype(rdtype): raise NoMetaqueries - the_rdclass = dns.rdataclass.RdataClass.make(rdclass) - if dns.rdataclass.is_metaclass(the_rdclass): + rdclass = dns.rdataclass.RdataClass.make(rdclass) + if dns.rdataclass.is_metaclass(rdclass): raise NoMetaqueries self.resolver = resolver self.qnames_to_try = resolver._get_qnames_to_try(qname, search) self.qnames = self.qnames_to_try[:] - self.rdtype = the_rdtype - self.rdclass = the_rdclass + self.rdtype = rdtype + self.rdclass = rdclass self.tcp = tcp self.raise_on_no_answer = raise_on_no_answer self.nxdomain_responses: Dict[dns.name.Name, dns.message.QueryMessage] = {} diff --git a/dns/rrset.py b/dns/rrset.py index 3f22a90..0519051 100644 --- a/dns/rrset.py +++ b/dns/rrset.py @@ -214,9 +214,9 @@ def from_text_list( if isinstance(name, str): name = dns.name.from_text(name, None, idna_codec=idna_codec) - the_rdclass = dns.rdataclass.RdataClass.make(rdclass) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - r = RRset(name, the_rdclass, the_rdtype) + rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) + r = RRset(name, rdclass, rdtype) r.update_ttl(ttl) for t in text_rdatas: rd = dns.rdata.from_text( diff --git a/dns/update.py b/dns/update.py index b10f6ac..2219ec5 100644 --- a/dns/update.py +++ b/dns/update.py @@ -335,12 +335,12 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] True, ) else: - the_rdtype = dns.rdatatype.RdataType.make(rdtype) + rdtype = dns.rdatatype.RdataType.make(rdtype) self.find_rrset( self.prerequisite, name, dns.rdataclass.NONE, - the_rdtype, + rdtype, dns.rdatatype.NONE, None, True, diff --git a/dns/zone.py b/dns/zone.py index 35724d7..647538c 100644 --- a/dns/zone.py +++ b/dns/zone.py @@ -321,11 +321,11 @@ class Zone(dns.transaction.TransactionManager): Returns a ``dns.rdataset.Rdataset``. """ - the_name = self._validate_name(name) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - the_covers = dns.rdatatype.RdataType.make(covers) - node = self.find_node(the_name, create) - return node.find_rdataset(self.rdclass, the_rdtype, the_covers, create) + name = self._validate_name(name) + rdtype = dns.rdatatype.RdataType.make(rdtype) + covers = dns.rdatatype.RdataType.make(covers) + node = self.find_node(name, create) + return node.find_rdataset(self.rdclass, rdtype, covers, create) def get_rdataset( self, @@ -404,14 +404,14 @@ class Zone(dns.transaction.TransactionManager): types were aggregated into a single RRSIG rdataset. """ - the_name = self._validate_name(name) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - the_covers = dns.rdatatype.RdataType.make(covers) - node = self.get_node(the_name) + name = self._validate_name(name) + rdtype = dns.rdatatype.RdataType.make(rdtype) + covers = dns.rdatatype.RdataType.make(covers) + node = self.get_node(name) if node is not None: - node.delete_rdataset(self.rdclass, the_rdtype, the_covers) + node.delete_rdataset(self.rdclass, rdtype, covers) if len(node) == 0: - self.delete_node(the_name) + self.delete_node(name) def replace_rdataset( self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset @@ -484,10 +484,10 @@ class Zone(dns.transaction.TransactionManager): """ vname = self._validate_name(name) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - the_covers = dns.rdatatype.RdataType.make(covers) - rdataset = self.nodes[vname].find_rdataset(self.rdclass, the_rdtype, the_covers) - rrset = dns.rrset.RRset(vname, self.rdclass, the_rdtype, the_covers) + rdtype = dns.rdatatype.RdataType.make(rdtype) + covers = dns.rdatatype.RdataType.make(covers) + rdataset = self.nodes[vname].find_rdataset(self.rdclass, rdtype, covers) + rrset = dns.rrset.RRset(vname, self.rdclass, rdtype, covers) rrset.update(rdataset) return rrset diff --git a/dns/zonefile.py b/dns/zonefile.py index fad78c3..48bedad 100644 --- a/dns/zonefile.py +++ b/dns/zonefile.py @@ -710,26 +710,26 @@ def read_rrsets( if isinstance(default_ttl, str): default_ttl = dns.ttl.from_text(default_ttl) if rdclass is not None: - the_rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdclass = dns.rdataclass.RdataClass.make(rdclass) else: - the_rdclass = None - the_default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass) + rdclass = None + default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass) if rdtype is not None: - the_rdtype = dns.rdatatype.RdataType.make(rdtype) + rdtype = dns.rdatatype.RdataType.make(rdtype) else: - the_rdtype = None + rdtype = None manager = RRSetsReaderManager(origin, relativize, default_rdclass) with manager.writer(True) as txn: tok = dns.tokenizer.Tokenizer(text, "<input>", idna_codec=idna_codec) reader = Reader( tok, - the_default_rdclass, + default_rdclass, txn, allow_directives=False, force_name=name, force_ttl=ttl, - force_rdclass=the_rdclass, - force_rdtype=the_rdtype, + force_rdclass=rdclass, + force_rdtype=rdtype, default_ttl=default_ttl, ) reader.read() |