summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrian Wellington <bwelling@xbill.org>2023-04-06 06:03:07 -0700
committerGitHub <noreply@github.com>2023-04-06 06:03:07 -0700
commitcee55728316820d037ece45b11d2fed655a2754c (patch)
tree1030a420d5ca24a135b1da6b6237dc2632fa2651
parentc7613943d93986ac23cccff7350c42a32fa5b3bd (diff)
downloaddnspython-cee55728316820d037ece45b11d2fed655a2754c.tar.gz
Enum typing (#923)
* IntEnum improvements. This changes make() to always return an instance of the subclass, creating one on the fly if the value is not known, and updates the typ registration code to deal with this. It also adds typing annotations to make(). * Add missing int check. Some older versions of python weren't rejecting non-int values. * Fix int check. Raise TypeError for non-int, not ValueError, to make tests happy. * Annotate to_text/from_text. * Remove many the_ prefixed variables. These were needed in the past to work around typing issues.
-rw-r--r--dns/dnssec.py6
-rw-r--r--dns/edns.py8
-rw-r--r--dns/enum.py27
-rw-r--r--dns/message.py8
-rw-r--r--dns/rdata.py19
-rw-r--r--dns/rdataset.py6
-rw-r--r--dns/resolver.py12
-rw-r--r--dns/rrset.py6
-rw-r--r--dns/update.py4
-rw-r--r--dns/zone.py30
-rw-r--r--dns/zonefile.py16
11 files changed, 77 insertions, 65 deletions
diff --git a/dns/dnssec.py b/dns/dnssec.py
index 3caa22b..c219965 100644
--- a/dns/dnssec.py
+++ b/dns/dnssec.py
@@ -948,9 +948,9 @@ def _make_dnskey(
else:
raise ValueError("unsupported ECDSA curve")
- the_algorithm = Algorithm.make(algorithm)
+ algorithm = Algorithm.make(algorithm)
- _ensure_algorithm_key_combination(the_algorithm, public_key)
+ _ensure_algorithm_key_combination(algorithm, public_key)
if isinstance(public_key, rsa.RSAPublicKey):
key_bytes = encode_rsa_public_key(public_key)
@@ -974,7 +974,7 @@ def _make_dnskey(
rdtype=dns.rdatatype.DNSKEY,
flags=flags,
protocol=protocol,
- algorithm=the_algorithm,
+ algorithm=algorithm,
key=key_bytes,
)
diff --git a/dns/edns.py b/dns/edns.py
index 64436cd..40899ee 100644
--- a/dns/edns.py
+++ b/dns/edns.py
@@ -380,7 +380,7 @@ class EDEOption(Option): # lgtm[py/missing-equals]
def from_wire_parser(
cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option:
- the_code = EDECode.make(parser.get_uint16())
+ code = EDECode.make(parser.get_uint16())
text = parser.get_remaining()
if text:
@@ -390,7 +390,7 @@ class EDEOption(Option): # lgtm[py/missing-equals]
else:
btext = None
- return cls(the_code, btext)
+ return cls(code, btext)
_type_to_class: Dict[OptionType, Any] = {
@@ -424,8 +424,8 @@ def option_from_wire_parser(
Returns an instance of a subclass of ``dns.edns.Option``.
"""
- the_otype = OptionType.make(otype)
- cls = get_option_class(the_otype)
+ otype = OptionType.make(otype)
+ cls = get_option_class(otype)
return cls.from_wire_parser(otype, parser)
diff --git a/dns/enum.py b/dns/enum.py
index b5a4aed..968363a 100644
--- a/dns/enum.py
+++ b/dns/enum.py
@@ -15,19 +15,33 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+from typing import Type, TypeVar, Union
+
import enum
+TIntEnum = TypeVar("TIntEnum", bound="IntEnum")
+
class IntEnum(enum.IntEnum):
@classmethod
+ def _missing_(cls, value):
+ cls._check_value(value)
+ val = int.__new__(cls, value)
+ val._name_ = cls._extra_to_text(value, None) or f"{cls._prefix()}{value}"
+ val._value_ = value
+ return val
+
+ @classmethod
def _check_value(cls, value):
max = cls._maximum()
+ if not isinstance(value, int):
+ raise TypeError
if value < 0 or value > max:
name = cls._short_name()
- raise ValueError(f"{name} must be between >= 0 and <= {max}")
+ raise ValueError(f"{name} must be an int between >= 0 and <= {max}")
@classmethod
- def from_text(cls, text):
+ def from_text(cls : Type[TIntEnum], text: str) -> TIntEnum:
text = text.upper()
try:
return cls[text]
@@ -47,7 +61,7 @@ class IntEnum(enum.IntEnum):
raise cls._unknown_exception_class()
@classmethod
- def to_text(cls, value):
+ def to_text(cls : Type[TIntEnum], value : int) -> str:
cls._check_value(value)
try:
text = cls(value).name
@@ -59,7 +73,7 @@ class IntEnum(enum.IntEnum):
return text
@classmethod
- def make(cls, value):
+ def make(cls: Type[TIntEnum], value: Union[int, str]) -> TIntEnum:
"""Convert text or a value into an enumerated type, if possible.
*value*, the ``int`` or ``str`` to convert.
@@ -76,10 +90,7 @@ class IntEnum(enum.IntEnum):
if isinstance(value, str):
return cls.from_text(value)
cls._check_value(value)
- try:
- return cls(value)
- except ValueError:
- return value
+ return cls(value)
@classmethod
def _maximum(cls):
diff --git a/dns/message.py b/dns/message.py
index 3a6f427..2ccdc2b 100644
--- a/dns/message.py
+++ b/dns/message.py
@@ -1730,13 +1730,11 @@ def make_query(
if isinstance(qname, str):
qname = dns.name.from_text(qname, idna_codec=idna_codec)
- the_rdtype = dns.rdatatype.RdataType.make(rdtype)
- the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ rdclass = dns.rdataclass.RdataClass.make(rdclass)
m = QueryMessage(id=id)
m.flags = dns.flags.Flag(flags)
- m.find_rrset(
- m.question, qname, the_rdclass, the_rdtype, create=True, force_unique=True
- )
+ m.find_rrset(m.question, qname, rdclass, rdtype, create=True, force_unique=True)
# only pass keywords on to use_edns if they have been set to a
# non-None value. Setting a field will turn EDNS on if it hasn't
# been configured.
diff --git a/dns/rdata.py b/dns/rdata.py
index d166b8a..66c07ee 100644
--- a/dns/rdata.py
+++ b/dns/rdata.py
@@ -880,16 +880,19 @@ def register_type(
it applies to all classes.
"""
- the_rdtype = dns.rdatatype.RdataType.make(rdtype)
- existing_cls = get_rdata_class(rdclass, the_rdtype)
- if existing_cls != GenericRdata or dns.rdatatype.is_metatype(the_rdtype):
- raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ existing_cls = get_rdata_class(rdclass, rdtype)
+ if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype):
+ raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
try:
- if dns.rdatatype.RdataType(the_rdtype).name != rdtype_text:
- raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype)
+ if (
+ rdtype in dns.rdatatype.RdataType
+ and dns.rdatatype.RdataType(rdtype).name != rdtype_text
+ ):
+ raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
except ValueError:
pass
- _rdata_classes[(rdclass, the_rdtype)] = getattr(
+ _rdata_classes[(rdclass, rdtype)] = getattr(
implementation, rdtype_text.replace("-", "_")
)
- dns.rdatatype.register_type(the_rdtype, rdtype_text, is_singleton)
+ dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton)
diff --git a/dns/rdataset.py b/dns/rdataset.py
index c0ede42..b562d1f 100644
--- a/dns/rdataset.py
+++ b/dns/rdataset.py
@@ -471,9 +471,9 @@ def from_text_list(
Returns a ``dns.rdataset.Rdataset`` object.
"""
- the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
- the_rdtype = dns.rdatatype.RdataType.make(rdtype)
- r = Rdataset(the_rdclass, the_rdtype)
+ rdclass = dns.rdataclass.RdataClass.make(rdclass)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ r = Rdataset(rdclass, rdtype)
r.update_ttl(ttl)
for t in text_rdatas:
rd = dns.rdata.from_text(
diff --git a/dns/resolver.py b/dns/resolver.py
index cd041d9..61d0052 100644
--- a/dns/resolver.py
+++ b/dns/resolver.py
@@ -647,17 +647,17 @@ class _Resolution:
) -> None:
if isinstance(qname, str):
qname = dns.name.from_text(qname, None)
- the_rdtype = dns.rdatatype.RdataType.make(rdtype)
- if dns.rdatatype.is_metatype(the_rdtype):
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ if dns.rdatatype.is_metatype(rdtype):
raise NoMetaqueries
- the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
- if dns.rdataclass.is_metaclass(the_rdclass):
+ rdclass = dns.rdataclass.RdataClass.make(rdclass)
+ if dns.rdataclass.is_metaclass(rdclass):
raise NoMetaqueries
self.resolver = resolver
self.qnames_to_try = resolver._get_qnames_to_try(qname, search)
self.qnames = self.qnames_to_try[:]
- self.rdtype = the_rdtype
- self.rdclass = the_rdclass
+ self.rdtype = rdtype
+ self.rdclass = rdclass
self.tcp = tcp
self.raise_on_no_answer = raise_on_no_answer
self.nxdomain_responses: Dict[dns.name.Name, dns.message.QueryMessage] = {}
diff --git a/dns/rrset.py b/dns/rrset.py
index 3f22a90..0519051 100644
--- a/dns/rrset.py
+++ b/dns/rrset.py
@@ -214,9 +214,9 @@ def from_text_list(
if isinstance(name, str):
name = dns.name.from_text(name, None, idna_codec=idna_codec)
- the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
- the_rdtype = dns.rdatatype.RdataType.make(rdtype)
- r = RRset(name, the_rdclass, the_rdtype)
+ rdclass = dns.rdataclass.RdataClass.make(rdclass)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ r = RRset(name, rdclass, rdtype)
r.update_ttl(ttl)
for t in text_rdatas:
rd = dns.rdata.from_text(
diff --git a/dns/update.py b/dns/update.py
index b10f6ac..2219ec5 100644
--- a/dns/update.py
+++ b/dns/update.py
@@ -335,12 +335,12 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals]
True,
)
else:
- the_rdtype = dns.rdatatype.RdataType.make(rdtype)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
self.find_rrset(
self.prerequisite,
name,
dns.rdataclass.NONE,
- the_rdtype,
+ rdtype,
dns.rdatatype.NONE,
None,
True,
diff --git a/dns/zone.py b/dns/zone.py
index 35724d7..647538c 100644
--- a/dns/zone.py
+++ b/dns/zone.py
@@ -321,11 +321,11 @@ class Zone(dns.transaction.TransactionManager):
Returns a ``dns.rdataset.Rdataset``.
"""
- the_name = self._validate_name(name)
- the_rdtype = dns.rdatatype.RdataType.make(rdtype)
- the_covers = dns.rdatatype.RdataType.make(covers)
- node = self.find_node(the_name, create)
- return node.find_rdataset(self.rdclass, the_rdtype, the_covers, create)
+ name = self._validate_name(name)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ covers = dns.rdatatype.RdataType.make(covers)
+ node = self.find_node(name, create)
+ return node.find_rdataset(self.rdclass, rdtype, covers, create)
def get_rdataset(
self,
@@ -404,14 +404,14 @@ class Zone(dns.transaction.TransactionManager):
types were aggregated into a single RRSIG rdataset.
"""
- the_name = self._validate_name(name)
- the_rdtype = dns.rdatatype.RdataType.make(rdtype)
- the_covers = dns.rdatatype.RdataType.make(covers)
- node = self.get_node(the_name)
+ name = self._validate_name(name)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ covers = dns.rdatatype.RdataType.make(covers)
+ node = self.get_node(name)
if node is not None:
- node.delete_rdataset(self.rdclass, the_rdtype, the_covers)
+ node.delete_rdataset(self.rdclass, rdtype, covers)
if len(node) == 0:
- self.delete_node(the_name)
+ self.delete_node(name)
def replace_rdataset(
self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset
@@ -484,10 +484,10 @@ class Zone(dns.transaction.TransactionManager):
"""
vname = self._validate_name(name)
- the_rdtype = dns.rdatatype.RdataType.make(rdtype)
- the_covers = dns.rdatatype.RdataType.make(covers)
- rdataset = self.nodes[vname].find_rdataset(self.rdclass, the_rdtype, the_covers)
- rrset = dns.rrset.RRset(vname, self.rdclass, the_rdtype, the_covers)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ covers = dns.rdatatype.RdataType.make(covers)
+ rdataset = self.nodes[vname].find_rdataset(self.rdclass, rdtype, covers)
+ rrset = dns.rrset.RRset(vname, self.rdclass, rdtype, covers)
rrset.update(rdataset)
return rrset
diff --git a/dns/zonefile.py b/dns/zonefile.py
index fad78c3..48bedad 100644
--- a/dns/zonefile.py
+++ b/dns/zonefile.py
@@ -710,26 +710,26 @@ def read_rrsets(
if isinstance(default_ttl, str):
default_ttl = dns.ttl.from_text(default_ttl)
if rdclass is not None:
- the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
+ rdclass = dns.rdataclass.RdataClass.make(rdclass)
else:
- the_rdclass = None
- the_default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass)
+ rdclass = None
+ default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass)
if rdtype is not None:
- the_rdtype = dns.rdatatype.RdataType.make(rdtype)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
else:
- the_rdtype = None
+ rdtype = None
manager = RRSetsReaderManager(origin, relativize, default_rdclass)
with manager.writer(True) as txn:
tok = dns.tokenizer.Tokenizer(text, "<input>", idna_codec=idna_codec)
reader = Reader(
tok,
- the_default_rdclass,
+ default_rdclass,
txn,
allow_directives=False,
force_name=name,
force_ttl=ttl,
- force_rdclass=the_rdclass,
- force_rdtype=the_rdtype,
+ force_rdclass=rdclass,
+ force_rdtype=rdtype,
default_ttl=default_ttl,
)
reader.read()