# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

"""Tests for the low-level ``dns.renderer.Renderer`` wire-format builder."""

import unittest

import dns.exception
import dns.flags
import dns.message
import dns.name
import dns.rdataset
import dns.rdatatype
import dns.renderer
import dns.tsig
import dns.tsigkeyring

# Text form of the response built by hand in RendererTestCase.test_basic:
# a QR response with EDNS version 0 (payload 4096), one A question for
# foo.example. and two A answers.
basic_answer = """flags QR
edns 0
payload 4096
;QUESTION
foo.example. IN A
;ANSWER
foo.example. 30 IN A 10.0.0.1
foo.example. 30 IN A 10.0.0.2
"""


class RendererTestCase(unittest.TestCase):
    """Exercise dns.renderer.Renderer by rendering and re-parsing messages."""

    def _assert_is_a_query(self, message, qname):
        """Assert *message* equals a plain ``make_query(qname, A)`` query.

        The rendered message's id is copied into the expected query first,
        because the tests compare content, not the id the renderer chose.
        """
        expected = dns.message.make_query(qname, dns.rdatatype.A)
        expected.id = message.id
        self.assertEqual(message, expected)

    def test_basic(self):
        """A hand-rendered response round-trips to the text-form message."""
        r = dns.renderer.Renderer(flags=dns.flags.QR, max_size=512)
        qname = dns.name.from_text("foo.example")
        r.add_question(qname, dns.rdatatype.A)
        rds = dns.rdataset.from_text("in", "a", 30, "10.0.0.1", "10.0.0.2")
        r.add_rdataset(dns.renderer.ANSWER, qname, rds)
        r.add_edns(0, 0, 4096)
        r.write_header()
        wire = r.get_wire()
        message = dns.message.from_wire(wire)
        expected = dns.message.from_text(basic_answer)
        # Our rendered message purposely has a random query id so we
        # exercise that code, so copy it into the expected message.
        expected.id = message.id
        self.assertEqual(message, expected)

    def test_tsig(self):
        """A TSIG-signed query verifies and parses back to a plain query."""
        r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512)
        qname = dns.name.from_text("foo.example")
        r.add_question(qname, dns.rdatatype.A)
        keyring = dns.tsigkeyring.from_text({"key": "12345678"})
        keyname = next(iter(keyring))
        # The header is written before signing; add_tsig appends the TSIG
        # record after it.
        r.write_header()
        r.add_tsig(
            keyname, keyring[keyname], 300, r.id, 0, b"", b"", dns.tsig.HMAC_SHA256
        )
        wire = r.get_wire()
        message = dns.message.from_wire(wire, keyring=keyring)
        self._assert_is_a_query(message, qname)

    def test_multi_tsig(self):
        """Two messages signed with a chained multi-TSIG context both verify."""
        qname = dns.name.from_text("foo.example")
        keyring = dns.tsigkeyring.from_text({"key": "12345678"})
        keyname = next(iter(keyring))

        # First message: start a new signing context (ctx=None).
        r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512)
        r.add_question(qname, dns.rdatatype.A)
        r.write_header()
        ctx = r.add_multi_tsig(
            None,
            keyname,
            keyring[keyname],
            300,
            r.id,
            0,
            b"",
            b"",
            dns.tsig.HMAC_SHA256,
        )
        wire = r.get_wire()
        message = dns.message.from_wire(wire, keyring=keyring, multi=True)
        self._assert_is_a_query(message, qname)

        # Second message: continue the signing context from the first one,
        # and verify it with the context recovered from the first parse.
        r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512)
        r.add_question(qname, dns.rdatatype.A)
        r.write_header()
        r.add_multi_tsig(
            ctx, keyname, keyring[keyname], 300, r.id, 0, b"", b"", dns.tsig.HMAC_SHA256
        )
        wire = r.get_wire()
        message = dns.message.from_wire(
            wire, keyring=keyring, tsig_ctx=message.tsig_ctx, multi=True
        )
        self._assert_is_a_query(message, qname)

    def test_going_backwards_fails(self):
        """Adding to ANSWER after ADDITIONAL (via EDNS) raises FormError."""
        r = dns.renderer.Renderer(flags=dns.flags.QR, max_size=512)
        qname = dns.name.from_text("foo.example")
        r.add_question(qname, dns.rdatatype.A)
        r.add_edns(0, 0, 4096)
        rds = dns.rdataset.from_text("in", "a", 30, "10.0.0.1", "10.0.0.2")
        with self.assertRaises(dns.exception.FormError):
            r.add_rdataset(dns.renderer.ANSWER, qname, rds)

    def test_reservation(self):
        """reserve() shrinks max_size, release_reserved() restores it."""
        r = dns.renderer.Renderer(flags=dns.flags.QR, max_size=512)
        r.reserve(100)
        self.assertEqual(r.max_size, 412)
        r.release_reserved()
        self.assertEqual(r.max_size, 512)
        # Negative reservations and reservations larger than max_size are
        # rejected.
        with self.assertRaises(ValueError):
            r.reserve(-1)
        with self.assertRaises(ValueError):
            r.reserve(513)