← Back to team overview

cloud-init-dev team mailing list archive

[Merge] ~mgerdts/cloud-init:lp1667735 into cloud-init:master

 

Mike Gerdts has proposed merging ~mgerdts/cloud-init:lp1667735 into cloud-init:master.

Commit message:
DataSourceSmartOS: fix hang when metadata service is down

    If the metadata service on the host is down while a guest that uses
    DataSourceSmartOS is booting, the request from the guest is silently
    lost.  When the metadata service is eventually started, the guest has no
    awareness of this and does not resend the request.  This results in
    cloud-init hanging forever with a guest reboot as the only recovery option.

    This fix updates the metadata protocol to implement the initialization
    phase, as mdata-get and related utilities already do.  The
    initialization phase drains all pending data from the serial port,
    writes an empty command, and expects the error message "invalid
    command" in reply, after which the client negotiates protocol V2.  If
    the initialization phase times out, it is retried every five seconds
    (or sooner if the configured timeout is shorter).  Each timeout results
    in a warning message: "Timeout while initializing metadata client. Is
    the host metadata service running?"  By default, warning messages are
    logged to the console, so the reason for a hung boot is readily
    apparent.  The transport is also now opened once for all metadata
    requests rather than once per request.

    LP: #1667735

Requested reviews:
  cloud-init commiters (cloud-init-dev)
Related bugs:
  Bug #1667735 in cloud-init: "cloud-init doesn't retry metadata lookups and hangs forever if metadata is down"
  https://bugs.launchpad.net/cloud-init/+bug/1667735

For more details, see:
https://code.launchpad.net/~mgerdts/cloud-init/+git/cloud-init/+merge/343118
-- 
Your team cloud-init commiters is requested to review the proposed merge of ~mgerdts/cloud-init:lp1667735 into cloud-init:master.
diff --git a/cloudinit/sources/DataSourceSmartOS.py b/cloudinit/sources/DataSourceSmartOS.py
index 86bfa5d..5717dae 100644
--- a/cloudinit/sources/DataSourceSmartOS.py
+++ b/cloudinit/sources/DataSourceSmartOS.py
@@ -1,4 +1,5 @@
 # Copyright (C) 2013 Canonical Ltd.
+# Copyright (c) 2018, Joyent, Inc.
 #
 # Author: Ben Howard <ben.howard@xxxxxxxxxxxxx>
 #
@@ -21,6 +22,7 @@
 
 import base64
 import binascii
+import errno
 import json
 import os
 import random
@@ -229,6 +231,9 @@ class DataSourceSmartOS(sources.DataSource):
                       self.md_client)
             return False
 
+        # Open once for many requests, rather than once for each request
+        self.md_client.open_transport()
+
         for ci_noun, attribute in SMARTOS_ATTRIB_MAP.items():
             smartos_noun, strip = attribute
             md[ci_noun] = self.md_client.get(smartos_noun, strip=strip)
@@ -236,6 +241,8 @@ class DataSourceSmartOS(sources.DataSource):
         for ci_noun, smartos_noun in SMARTOS_ATTRIB_JSON.items():
             md[ci_noun] = self.md_client.get_json(smartos_noun)
 
+        self.md_client.close_transport()
+
         # @datadictionary: This key may contain a program that is written
         # to a file in the filesystem of the guest on each boot and then
         # executed. It may be of any format that would be considered
@@ -316,6 +323,10 @@ class JoyentMetadataFetchException(Exception):
     pass
 
 
+class JoyentMetadataTimeoutException(JoyentMetadataFetchException):
+    pass
+
+
 class JoyentMetadataClient(object):
     """
     A client implementing v2 of the Joyent Metadata Protocol Specification.
@@ -360,6 +371,45 @@ class JoyentMetadataClient(object):
         LOG.debug('Value "%s" found.', value)
         return value
 
+    def _readline(self):
+        """
+        Read a line from self.fp one byte at a time until a newline is
+        seen.  Return it as an ascii string without the trailing newline.
+
+        Raise JoyentMetadataTimeoutException if a read returns no data
+        (per-byte timeout or EOF) or fails with EAGAIN.
+        """
+        response = bytearray()
+        while True:
+            try:
+                byte = self.fp.read(1)
+                if len(byte) == 0:
+                    raise JoyentMetadataTimeoutException(
+                        "Partial response: %r" % bytes(response))
+                if ord(byte) == ord(b'\n'):
+                    return response.decode('ascii')
+                response.extend([ord(byte)])
+            except OSError as exc:
+                if exc.errno == errno.EAGAIN:
+                    raise JoyentMetadataTimeoutException(
+                        "Partial response: %r" % bytes(response))
+                raise
+
+    def _write(self, msg):
+        self.fp.write(msg.encode('ascii'))
+        self.fp.flush()
+
+    def _negotiate(self):
+        """Negotiate protocol V2, raising if the host does not reply V2_OK."""
+        LOG.debug('Negotiating protocol V2')
+        self._write('NEGOTIATE V2\n')
+        response = self._readline()
+        LOG.debug('read "%s"', response)
+        if response != 'V2_OK':
+            raise JoyentMetadataFetchException(
+                'Invalid response "%s" to "NEGOTIATE V2"' % response)
+        LOG.debug('Negotiation complete')
+
     def request(self, rtype, param=None):
         request_id = '{0:08x}'.format(random.randint(0, 0xffffffff))
         message_body = ' '.join((request_id, rtype,))
@@ -374,18 +424,11 @@ class JoyentMetadataClient(object):
             self.open_transport()
             need_close = True
 
-        self.fp.write(msg.encode('ascii'))
-        self.fp.flush()
-
-        response = bytearray()
-        response.extend(self.fp.read(1))
-        while response[-1:] != b'\n':
-            response.extend(self.fp.read(1))
-
+        self._write(msg)
+        response = self._readline()
         if need_close:
             self.close_transport()
 
-        response = response.rstrip().decode('ascii')
         LOG.debug('Read "%s" from metadata transport.', response)
 
         if 'SUCCESS' not in response:
@@ -450,6 +493,7 @@ class JoyentMetadataSocketClient(JoyentMetadataClient):
         sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
         sock.connect(self.socketpath)
         self.fp = sock.makefile('rwb')
+        self._negotiate()
 
     def exists(self):
         return os.path.exists(self.socketpath)
@@ -459,8 +503,9 @@ class JoyentMetadataSocketClient(JoyentMetadataClient):
 
 
 class JoyentMetadataSerialClient(JoyentMetadataClient):
-    def __init__(self, device, timeout=10, smartos_type=SMARTOS_ENV_KVM):
-        super(JoyentMetadataSerialClient, self).__init__(smartos_type)
+    def __init__(self, device, timeout=10, smartos_type=SMARTOS_ENV_KVM,
+                 fp=None):
+        super(JoyentMetadataSerialClient, self).__init__(smartos_type, fp)
         self.device = device
         self.timeout = timeout
 
@@ -468,10 +513,50 @@ class JoyentMetadataSerialClient(JoyentMetadataClient):
         return os.path.exists(self.device)
 
     def open_transport(self):
-        ser = serial.Serial(self.device, timeout=self.timeout)
-        if not ser.isOpen():
-            raise SystemError("Unable to open %s" % self.device)
-        self.fp = ser
+        if self.fp is None:
+            ser = serial.Serial(self.device, timeout=self.timeout)
+            if not ser.isOpen():
+                raise SystemError("Unable to open %s" % self.device)
+            self.fp = ser
+        self._flush()
+        self._negotiate()
+
+    def _flush(self):
+        LOG.debug('Flushing input')
+        timeout = self.fp.timeout
+        try:
+            # Drain any stale data, using a short per-byte read timeout.
+            self.fp.timeout = 0.1
+            while True:
+                try:
+                    self._readline()
+                except JoyentMetadataTimeoutException:
+                    break
+            LOG.debug('Input empty')
+
+            # Send a newline and expect "invalid command", retrying until it
+            # arrives.  Cap each attempt at 5s (a None timeout would block
+            # forever) so the "Is the host metadata service running" warning
+            # reaches the console soon after someone attaches to debug.
+            self.fp.timeout = 5 if timeout is None or timeout > 5 else timeout
+            while True:
+                LOG.debug('Writing newline, expecting "invalid command"')
+                self._write('\n')
+                try:
+                    response = self._readline()
+                    if response == 'invalid command':
+                        break
+                    if response == 'FAILURE':
+                        LOG.debug('Got "FAILURE".  Retrying.')
+                        continue
+                    LOG.warning('Unexpected response "%s" during flush',
+                                response)
+                except JoyentMetadataTimeoutException:
+                    LOG.warning('Timeout while initializing metadata client. '
+                                'Is the host metadata service running?')
+            LOG.debug('Got "invalid command".  Flush complete.')
+        finally:
+            self.fp.timeout = timeout
 
     def __repr__(self):
         return "%s(device=%s, timeout=%s)" % (
diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py
index 88bae5f..6a25af4 100644
--- a/tests/unittests/test_datasource/test_smartos.py
+++ b/tests/unittests/test_datasource/test_smartos.py
@@ -1,4 +1,5 @@
 # Copyright (C) 2013 Canonical Ltd.
+# Copyright (c) 2018, Joyent, Inc.
 #
 # Author: Ben Howard <ben.howard@xxxxxxxxxxxxx>
 #
@@ -324,6 +325,7 @@ class PsuedoJoyentClient(object):
         if data is None:
             data = MOCK_RETURNS.copy()
         self.data = data
+        self._is_open = False
         return
 
     def get(self, key, default=None, strip=False):
@@ -344,6 +346,14 @@ class PsuedoJoyentClient(object):
     def exists(self):
         return True
 
+    def open_transport(self):
+        assert(not self._is_open)
+        self._is_open = True
+
+    def close_transport(self):
+        assert(self._is_open)
+        self._is_open = False
+
 
 class TestSmartOSDataSource(FilesystemMockingTestCase):
     def setUp(self):
@@ -636,6 +646,11 @@ class TestJoyentMetadataClient(FilesystemMockingTestCase):
         return DataSourceSmartOS.JoyentMetadataClient(
             fp=self.serial, smartos_type=DataSourceSmartOS.SMARTOS_ENV_KVM)
 
+    def _get_serial_client(self):
+        self.serial.timeout = 1
+        return DataSourceSmartOS.JoyentMetadataSerialClient(None,
+                                                            fp=self.serial)
+
     def assertEndsWith(self, haystack, prefix):
         self.assertTrue(haystack.endswith(prefix),
                         "{0} does not end with '{1}'".format(
@@ -646,6 +661,9 @@ class TestJoyentMetadataClient(FilesystemMockingTestCase):
                         "{0} does not start with '{1}'".format(
                             repr(haystack), prefix))
 
+    def assertNoMoreSideEffects(self, obj):
+        self.assertRaises(StopIteration, obj)
+
     def test_get_metadata_writes_a_single_line(self):
         client = self._get_client()
         client.get('some_key')
@@ -737,6 +755,48 @@ class TestJoyentMetadataClient(FilesystemMockingTestCase):
         client._checksum = lambda _: self.response_parts['crc']
         self.assertIsNone(client.get('some_key'))
 
+    def test_negotiate(self):
+        client = self._get_client()
+        client.fp.read.side_effect = list('V2_OK\n')
+        client._negotiate()
+        self.assertNoMoreSideEffects(client.fp.read)
+
+    def test_negotiate_short_response(self):
+        client = self._get_client()
+        client.fp.read.side_effect = list('V2_OK') + ['']
+        self.assertRaises(DataSourceSmartOS.JoyentMetadataTimeoutException,
+                          client._negotiate)
+        self.assertNoMoreSideEffects(client.fp.read)
+
+    def test_negotiate_bad_response(self):
+        client = self._get_client()
+        client.fp.read.side_effect = list('garbage\nV2_OK\n')
+        self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException,
+                          client._negotiate)
+        self.assertEqual(list(client.fp.read.side_effect),
+                         list('V2_OK\n'))
+
+    def test_serial_open_transport(self):
+        client = self._get_serial_client()
+        client.fp.read.side_effect = \
+            list('garbage') + [''] + list('invalid command\nV2_OK\n')
+        client.open_transport()
+        self.assertNoMoreSideEffects(client.fp.read)
+
+    def test_flush_failure(self):
+        client = self._get_serial_client()
+        client.fp.read.side_effect = \
+            list('garbage') + [''] + list('FAILURE\ninvalid command\nV2_OK\n')
+        client.open_transport()
+        self.assertNoMoreSideEffects(client.fp.read)
+
+    def test_flush_many_timeouts(self):
+        client = self._get_serial_client()
+        client.fp.read.side_effect = \
+            [''] * 100 + list('invalid command\nV2_OK\n')
+        client.open_transport()
+        self.assertNoMoreSideEffects(client.fp.read)
+
 
 class TestNetworkConversion(TestCase):
     def test_convert_simple(self):

Follow ups