← Back to team overview

sts-sponsors team mailing list archive

[Merge] ~cgrabowski/maas:fix_getaddrinfo_for_v6_only into maas:master

 

Christian Grabowski has proposed merging ~cgrabowski/maas:fix_getaddrinfo_for_v6_only into maas:master.

Commit message:
Use a higher-level resolver when the stack is IPv6-only



Requested reviews:
  MAAS Maintainers (maas-maintainers)
Related bugs:
  Bug #2020142 in MAAS: "commission fails if `maas-url` uses an IPv6"
  https://bugs.launchpad.net/maas/+bug/2020142

For more details, see:
https://code.launchpad.net/~cgrabowski/maas/+git/maas/+merge/443923
-- 
Your team MAAS Committers is subscribed to branch maas:master.
diff --git a/src/provisioningserver/utils/network.py b/src/provisioningserver/utils/network.py
index 4ac6962..23dfa2e 100644
--- a/src/provisioningserver/utils/network.py
+++ b/src/provisioningserver/utils/network.py
@@ -751,6 +751,104 @@ def get_all_interface_addresses() -> Iterable[str]:
         yield from get_all_addresses_for_interface(interface)
 
 
+def _resolve_ipv6_addresses(hostname):
+    """Return the addresses in `hostname`'s AAAA records, as strings.
+
+    The lookup runs on the reactor thread via Twisted's resolver while the
+    calling thread blocks, so this must not be called from the reactor
+    thread. If the reactor is not running, `getaddrinfo` is used instead.
+
+    :raise gaierror: If the lookup fails or finds no AAAA records, so that
+        callers that handle `getaddrinfo` errors keep working.
+    """
+    from twisted.internet import reactor
+    from twisted.internet.threads import blockingCallFromThread
+    from twisted.names import dns
+    from twisted.names.error import (
+        DNSQueryTimeoutError,
+        DomainError,
+        ResolverError,
+    )
+
+    if not reactor.running:
+        # Nothing would service a Twisted lookup; don't block forever.
+        infos = getaddrinfo(hostname, None, family=AF_INET6)
+        return list(dict.fromkeys(info[4][0] for info in infos))
+
+    def lookup():
+        # Runs on the reactor thread, where Twisted APIs are safe to call.
+        return getResolver().lookupIPv6Address(hostname)
+
+    try:
+        answers, _, _ = blockingCallFromThread(reactor, lookup)
+    except (DomainError, DNSQueryTimeoutError, ResolverError) as error:
+        raise gaierror(socket.EAI_NONAME, str(error)) from error
+    # Answers are RRHeaders and may include e.g. CNAMEs; keep only AAAA.
+    addresses = [
+        socket.inet_ntop(AF_INET6, rr.payload.address)
+        for rr in answers
+        if rr.type == dns.AAAA
+    ]
+    if not addresses:
+        raise gaierror(socket.EAI_NONAME, f"No AAAA records for {hostname}")
+    return addresses
+
+
+@synchronous
+def safe_getaddrinfo(hostname, port, family=0, proto=0):
+    """Drop-in for `getaddrinfo` that also works on IPv6-only hosts.
+
+    Lookups for any family other than `AF_INET6` (including 0, meaning any
+    family) are delegated to `getaddrinfo`. For `AF_INET6`, IPv6 literals
+    are answered directly and names are resolved by
+    `_resolve_ipv6_addresses`, because IPv6-only hosts currently error
+    when using `getaddrinfo`.
+
+    `family` and `proto` mirror the keyword parameters and defaults of
+    `socket.getaddrinfo`, so callers may pass them by keyword or omit them.
+
+    :param hostname: Host name or IP address literal.
+    :param port: Port placed in the returned socket addresses.
+    :param family: `AF_INET`, `AF_INET6`, or 0 for any family.
+    :param proto: IP protocol such as `IPPROTO_TCP`, or 0 for any.
+    :return: A list of ``(family, type, proto, canonname, sockaddr)`` tuples.
+    :raise gaierror: If `hostname` cannot be resolved to an IPv6 address.
+    """
+    if family != AF_INET6:
+        return getaddrinfo(hostname, port, family=family, proto=proto)
+
+    # TODO: use getaddrinfo for all families once it works on IPv6-only hosts.
+    if port is None:
+        port = 0  # getaddrinfo reports port 0 when none is given.
+    if proto == IPPROTO_TCP:
+        sock_types = [(socket.SOCK_STREAM, IPPROTO_TCP)]
+    elif proto == 0:
+        # proto 0 means any protocol: offer both TCP and UDP entries.
+        sock_types = [
+            (socket.SOCK_STREAM, IPPROTO_TCP),
+            (socket.SOCK_DGRAM, socket.IPPROTO_UDP),
+        ]
+    else:
+        sock_types = [(socket.SOCK_DGRAM, proto)]
+
+    try:
+        address = IPAddress(hostname)
+    except AddrFormatError:
+        addresses = _resolve_ipv6_addresses(hostname)
+    else:
+        if address.version != 6:
+            raise gaierror(
+                socket.EAI_NONAME, f"{hostname} is not an IPv6 address"
+            )
+        addresses = [hostname]
+
+    return [
+        (AF_INET6, sock_type, sock_proto, "", (addr, port, 0, 0))
+        for addr in addresses
+        for sock_type, sock_proto in sock_types
+    ]
+
+
 def resolve_host_to_addrinfo(
     hostname, ip_version=4, port=0, proto=IPPROTO_TCP
 ):
@@ -767,7 +797,7 @@ def resolve_host_to_addrinfo(
     addr_families = {4: AF_INET, 6: AF_INET6, 0: 0}
     assert ip_version in addr_families
     try:
-        address_info = getaddrinfo(
+        address_info = safe_getaddrinfo(
             hostname, port, family=addr_families[ip_version], proto=proto
         )
     except gaierror as e:
@@ -1309,7 +1339,7 @@ def resolves_to_loopback_address(hostname):
     :return: True if the hostname appears to be a loopback address.
     """
     try:
-        addrinfo = socket.getaddrinfo(hostname, None, proto=IPPROTO_TCP)
+        addrinfo = safe_getaddrinfo(hostname, None, proto=IPPROTO_TCP)
     except socket.gaierror:
         return hostname.lower() in {"localhost", "localhost."}
     else:
diff --git a/src/provisioningserver/utils/tests/test_network.py b/src/provisioningserver/utils/tests/test_network.py
index 474b57c..f4a86d0 100644
--- a/src/provisioningserver/utils/tests/test_network.py
+++ b/src/provisioningserver/utils/tests/test_network.py
@@ -24,6 +24,7 @@ from testtools.matchers import (
     StartsWith,
 )
 from twisted.internet.defer import inlineCallbacks, succeed
+from twisted.names.dns import Record_AAAA
 from twisted.names.error import (
     AuthoritativeDomainError,
     DNSQueryTimeoutError,
@@ -89,6 +90,7 @@ from provisioningserver.utils.network import (
     resolve_hostname,
     resolves_to_loopback_address,
     reverseResolve,
+    safe_getaddrinfo,
 )
 from provisioningserver.utils.shell import get_env_with_locale
 
@@ -2518,3 +2520,46 @@ class TestGetIfnameForLabel(MAASTestCase):
 
     def test_scenarios(self):
         self.assertEqual(self.expected, get_ifname_for_label(self.input))
+
+
+class TestSafeGetaddrinfo(MAASTestCase):
+    def test_safe_getaddrinfo_uses_getaddrinfo_for_v4(self):
+        getaddrinfo = self.patch(network_module, "getaddrinfo")
+        expected = [
+            (AF_INET, socket.SOCK_STREAM, IPPROTO_TCP, "", ("0.0.0.0", 53))
+        ]
+        getaddrinfo.return_value = expected
+        result = safe_getaddrinfo("0.0.0.0", 53, AF_INET, IPPROTO_TCP)
+        self.assertCountEqual(expected, result)
+        getaddrinfo.assert_called_once_with(
+            "0.0.0.0", 53, family=AF_INET, proto=IPPROTO_TCP
+        )
+
+    def test_safe_getaddrinfo_uses_getaddrinfo_for_dual_stack(self):
+        getaddrinfo = self.patch(network_module, "getaddrinfo")
+        expected = [
+            (AF_INET, socket.SOCK_STREAM, IPPROTO_TCP, "", ("0.0.0.0", 53))
+        ]
+        getaddrinfo.return_value = expected
+        result = safe_getaddrinfo("0.0.0.0", 53, 0, IPPROTO_TCP)
+        self.assertCountEqual(expected, result)
+        getaddrinfo.assert_called_once_with(
+            "0.0.0.0", 53, family=0, proto=IPPROTO_TCP
+        )
+
+    def test_safe_getaddrinfo_answers_v6_literal_without_resolver(self):
+        expected = [
+            (AF_INET6, socket.SOCK_STREAM, IPPROTO_TCP, "", ("::", 53, 0, 0))
+        ]
+        # If the resolver were consulted it would answer "::1", not "::".
+        mock_resolver = Mock()
+        mock_resolver.lookupIPv6Address = lambda _: (
+            [Record_AAAA(address="::1")],
+            [],
+            [],
+        )
+        get_resolver = self.patch(network_module, "getResolver")
+        get_resolver.return_value = mock_resolver
+        result = safe_getaddrinfo("::", 53, AF_INET6, IPPROTO_TCP)
+        self.assertCountEqual(expected, result)
+        get_resolver.assert_not_called()

Follow ups