sts-sponsors team mailing list archive
-
sts-sponsors team
-
Mailing list archive
-
Message #08879
[Merge] ~cgrabowski/maas:fix_getaddrinfo_for_v6_only into maas:master
Christian Grabowski has proposed merging ~cgrabowski/maas:fix_getaddrinfo_for_v6_only into maas:master.
Commit message:
Use a higher-level resolver when the network stack is IPv6-only.
Requested reviews:
MAAS Maintainers (maas-maintainers)
Related bugs:
Bug #2020142 in MAAS: "commission fails if `maas-url` uses an IPv6"
https://bugs.launchpad.net/maas/+bug/2020142
For more details, see:
https://code.launchpad.net/~cgrabowski/maas/+git/maas/+merge/443923
--
Your team MAAS Committers is subscribed to branch maas:master.
diff --git a/src/provisioningserver/utils/network.py b/src/provisioningserver/utils/network.py
index 4ac6962..23dfa2e 100644
--- a/src/provisioningserver/utils/network.py
+++ b/src/provisioningserver/utils/network.py
@@ -751,6 +751,36 @@ def get_all_interface_addresses() -> Iterable[str]:
yield from get_all_addresses_for_interface(interface)
@synchronous
def safe_getaddrinfo(hostname, port, addr_family=0, proto=0, *, family=None):
    """Resolve `hostname` like `socket.getaddrinfo`, tolerating IPv6-only hosts.

    For ``AF_INET`` and unspecified (``0``) families this delegates straight
    to `getaddrinfo`. For any other family (i.e. ``AF_INET6``) the stdlib call
    is avoided (see LP: #2020142):

    - an IPv6 address literal is returned directly, without any lookup;
    - an IPv4 literal raises `socket.gaierror`, as `getaddrinfo` would;
    - any other name is resolved with an AAAA lookup through the Twisted
      resolver returned by `getResolver()`.

    NOTE(review): the AAAA lookup is run in the reactor via
    `blockingCallFromThread`, so this must be called from a non-reactor
    thread (presumably what `@synchronous` enforces) while the reactor is
    running in another thread; otherwise the call blocks indefinitely.

    :param hostname: Host name or IP address literal to resolve.
    :param port: Port placed in each returned sockaddr; ``None`` means 0.
    :param addr_family: Address family: ``AF_INET``, ``AF_INET6`` or ``0``.
    :param proto: IP protocol number, e.g. ``IPPROTO_TCP``.
    :param family: Keyword alias for `addr_family` (it wins if both are
        given), matching the `socket.getaddrinfo` keyword used by callers.
    :return: List of ``(family, type, proto, canonname, sockaddr)`` tuples.
    :raise socket.gaierror: If the name cannot be resolved for the family.
    """
    if family is not None:
        addr_family = family

    if addr_family in (AF_INET, 0):
        # getaddrinfo is only trusted for IPv4/unspecified lookups; the
        # original author reports it errors for IPv6 on IPv6-only hosts.
        return getaddrinfo(hostname, port, family=addr_family, proto=proto)

    # TODO use getaddrinfo for all versions once fixed in the stdlib
    sock_type = (
        socket.SOCK_STREAM if proto == IPPROTO_TCP else socket.SOCK_DGRAM
    )
    # getaddrinfo() reports a port of None as 0 in the sockaddr.
    port = 0 if port is None else port

    def _addrinfo(address):
        # Build one getaddrinfo-style entry for an IPv6 address string.
        return (AF_INET6, sock_type, proto, "", (address, port, 0, 0))

    try:
        addr = IPAddress(hostname)
    except AddrFormatError:
        pass  # Not a literal: fall through to a DNS lookup below.
    else:
        if addr.version != 6:
            # Mirror getaddrinfo(ipv4_literal, ..., family=AF_INET6).
            raise socket.gaierror(
                getattr(socket, "EAI_ADDRFAMILY", socket.EAI_NONAME),
                "Address family for hostname not supported",
            )
        return [_addrinfo(str(addr))]

    # Only needed on the resolver path; twisted is already a dependency.
    from twisted.internet import reactor
    from twisted.internet.threads import blockingCallFromThread
    from twisted.names import dns

    @inlineCallbacks
    def _v6_lookup():
        resolver = getResolver()
        answers, _authority, _additional = yield resolver.lookupIPv6Address(
            hostname
        )
        # Answers are RRHeaders; skip e.g. CNAMEs and decode the packed
        # 16-byte AAAA payload into its textual form.
        return [
            _addrinfo(socket.inet_ntop(AF_INET6, answer.payload.address))
            for answer in answers
            if answer.type == dns.AAAA
        ]

    try:
        # _v6_lookup() returns a Deferred; block until the reactor fires it.
        result = blockingCallFromThread(reactor, _v6_lookup)
    except Exception as e:
        # Surface resolver failures the way getaddrinfo() does, so callers
        # that catch socket.gaierror keep working.
        raise socket.gaierror(socket.EAI_NONAME, str(e)) from e
    if not result:
        raise socket.gaierror(
            socket.EAI_NONAME, f"No AAAA records found for {hostname}"
        )
    return result
+
def resolve_host_to_addrinfo(
hostname, ip_version=4, port=0, proto=IPPROTO_TCP
):
@@ -767,7 +797,7 @@ def resolve_host_to_addrinfo(
addr_families = {4: AF_INET, 6: AF_INET6, 0: 0}
assert ip_version in addr_families
try:
- address_info = getaddrinfo(
+ address_info = safe_getaddrinfo(
hostname, port, family=addr_families[ip_version], proto=proto
)
except gaierror as e:
@@ -1309,7 +1339,7 @@ def resolves_to_loopback_address(hostname):
:return: True if the hostname appears to be a loopback address.
"""
try:
- addrinfo = socket.getaddrinfo(hostname, None, proto=IPPROTO_TCP)
+ addrinfo = safe_getaddrinfo(hostname, None, proto=IPPROTO_TCP)
except socket.gaierror:
return hostname.lower() in {"localhost", "localhost."}
else:
diff --git a/src/provisioningserver/utils/tests/test_network.py b/src/provisioningserver/utils/tests/test_network.py
index 474b57c..f4a86d0 100644
--- a/src/provisioningserver/utils/tests/test_network.py
+++ b/src/provisioningserver/utils/tests/test_network.py
@@ -24,6 +24,7 @@ from testtools.matchers import (
StartsWith,
)
from twisted.internet.defer import inlineCallbacks, succeed
+from twisted.names.dns import Record_AAAA
from twisted.names.error import (
AuthoritativeDomainError,
DNSQueryTimeoutError,
@@ -89,6 +90,7 @@ from provisioningserver.utils.network import (
resolve_hostname,
resolves_to_loopback_address,
reverseResolve,
+ safe_getaddrinfo,
)
from provisioningserver.utils.shell import get_env_with_locale
@@ -2518,3 +2520,44 @@ class TestGetIfnameForLabel(MAASTestCase):
def test_scenarios(self):
self.assertEqual(self.expected, get_ifname_for_label(self.input))
+
+
class TestSafeGetaddrinfo(MAASTestCase):
    """Tests for `safe_getaddrinfo`."""

    def test_safe_getaddrinfo_uses_getaddrinfo_for_v4(self):
        getaddrinfo = self.patch(network_module, "getaddrinfo")
        expected = [
            (AF_INET, socket.SOCK_STREAM, IPPROTO_TCP, "", ("0.0.0.0", 53))
        ]
        getaddrinfo.return_value = expected
        result = safe_getaddrinfo("0.0.0.0", 53, AF_INET, IPPROTO_TCP)
        self.assertCountEqual(expected, result)
        # The family and protocol must be forwarded unchanged.
        getaddrinfo.assert_called_once_with(
            "0.0.0.0", 53, family=AF_INET, proto=IPPROTO_TCP
        )

    def test_safe_getaddrinfo_uses_getaddrinfo_for_dual_stack(self):
        getaddrinfo = self.patch(network_module, "getaddrinfo")
        expected = [
            (AF_INET, socket.SOCK_STREAM, IPPROTO_TCP, "", ("0.0.0.0", 53))
        ]
        getaddrinfo.return_value = expected
        result = safe_getaddrinfo("0.0.0.0", 53, 0, IPPROTO_TCP)
        self.assertCountEqual(expected, result)
        # An unspecified family (0) is delegated to getaddrinfo as-is.
        getaddrinfo.assert_called_once_with(
            "0.0.0.0", 53, family=0, proto=IPPROTO_TCP
        )

    def test_safe_getaddrinfo_returns_v6_literal_without_resolver(self):
        expected = [
            (AF_INET6, socket.SOCK_STREAM, IPPROTO_TCP, "", ("::", 53, 0, 0))
        ]
        # Give the patched resolver a valid answer, so that the assertion
        # below can only fail because the resolver was consulted.
        mock_resolver = Mock()
        mock_resolver.lookupIPv6Address = lambda _: (
            [Record_AAAA(address="::")],
            [],
            [],
        )
        get_resolver = self.patch(network_module, "getResolver")
        get_resolver.return_value = mock_resolver
        result = safe_getaddrinfo("::", 53, AF_INET6, IPPROTO_TCP)
        self.assertCountEqual(expected, result)
        # An address literal needs no DNS lookup.
        get_resolver.assert_not_called()
Follow ups