← Back to team overview

cloud-init-dev team mailing list archive

[Merge] ~smoser/cloud-init:cleanup/openssl-manager-no-cleanup into cloud-init:master

 

Scott Moser has proposed merging ~smoser/cloud-init:cleanup/openssl-manager-no-cleanup into cloud-init:master.

Commit message:
Azure: Make the openssl manager object not need cleanup.

The 'clean_up' method on the OpenSSLManager object was annoying
as it had to be called or it would leave temp files around.

The change here removes the need for persistent temporary storage:
each operation now creates a temp dir and removes it when finished.


Requested reviews:
  cloud-init commiters (cloud-init-dev)

For more details, see:
https://code.launchpad.net/~smoser/cloud-init/+git/cloud-init/+merge/363757

Azure: Make the openssl manager object not need cleanup.

The 'clean_up' method on the OpenSSLManager object was annoying
as it had to be called or it would leave temp files around.

The change here removes the need for persistent temporary storage:
each operation now creates a temp dir and removes it when finished.

-- 
Your team cloud-init commiters is requested to review the proposed merge of ~smoser/cloud-init:cleanup/openssl-manager-no-cleanup into cloud-init:master.
diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py
index 2829dd2..1053f69 100644
--- a/cloudinit/sources/helpers/azure.py
+++ b/cloudinit/sources/helpers/azure.py
@@ -4,6 +4,7 @@ import json
 import logging
 import os
 import re
+import shutil
 import socket
 import struct
 import time
@@ -20,16 +21,6 @@ from cloudinit import util
 LOG = logging.getLogger(__name__)
 
 
-@contextmanager
-def cd(newdir):
-    prevdir = os.getcwd()
-    os.chdir(os.path.expanduser(newdir))
-    try:
-        yield
-    finally:
-        os.chdir(prevdir)
-
-
 def _get_dhcp_endpoint_option_name():
     if util.is_FreeBSD():
         azure_endpoint = "option-245"
@@ -105,43 +96,34 @@ class GoalState(object):
 
 
 class OpenSSLManager(object):
-
-    certificate_names = {
-        'private_key': 'TransportPrivate.pem',
-        'certificate': 'TransportCert.pem',
-    }
-
     def __init__(self):
-        self.tmpdir = temp_utils.mkdtemp()
-        self.certificate = None
-        self.generate_certificate()
-
-    def clean_up(self):
-        util.del_dir(self.tmpdir)
-
-    def generate_certificate(self):
-        LOG.debug('Generating certificate for communication with fabric...')
-        if self.certificate is not None:
-            LOG.debug('Certificate already generated.')
-            return
-        with cd(self.tmpdir):
-            util.subp([
-                'openssl', 'req', '-x509', '-nodes', '-subj',
-                '/CN=LinuxTransport', '-days', '32768', '-newkey', 'rsa:2048',
-                '-keyout', self.certificate_names['private_key'],
-                '-out', self.certificate_names['certificate'],
-            ])
-            certificate = ''
-            for line in open(self.certificate_names['certificate']):
-                if "CERTIFICATE" not in line:
-                    certificate += line.rstrip()
-            self.certificate = certificate
-        LOG.debug('New certificate generated.')
+        self.certificate, self.private_key = self._generate_certificate()
+        self.certificate_data = ''.join(
+            [l for l in self.certificate.splitlines()
+             if not l.startswith("---")])
+
+    @staticmethod
+    def _generate_certificate():
+        tmpd = temp_utils.mkdtemp(suffix="sslmanager.d")
+        key_out = os.path.join(tmpd, 'private.key')
+        cert_out = os.path.join(tmpd, 'out.cert')
+        cmd = ['openssl', 'req', '-x509', '-nodes', '-subj',
+               '/CN=LinuxTransport', '-days', '32768', '-newkey', 'rsa:2048',
+               '-keyout', key_out, '-out', cert_out]
+        try:
+            util.subp(cmd)
+            with open(key_out) as fp:
+                private = fp.read()
+            with open(cert_out) as fp:
+                cert = fp.read()
+        finally:
+            shutil.rmtree(tmpd)
+        return cert, private
 
     @staticmethod
     def _run_x509_action(action, cert):
         cmd = ['openssl', 'x509', '-noout', action]
-        result, _ = util.subp(cmd, data=cert)
+        result, _ = util.subp(cmd, data=cert.encode('utf-8'))
         return result
 
     def _get_ssh_key_from_cert(self, certificate):
@@ -170,19 +152,29 @@ class OpenSSLManager(object):
         tag = ElementTree.fromstring(certificates_xml).find('.//Data')
         certificates_content = tag.text
         lines = [
-            b'MIME-Version: 1.0',
-            b'Content-Disposition: attachment; filename="Certificates.p7m"',
-            b'Content-Type: application/x-pkcs7-mime; name="Certificates.p7m"',
-            b'Content-Transfer-Encoding: base64',
-            b'',
-            certificates_content.encode('utf-8'),
+            'MIME-Version: 1.0',
+            'Content-Disposition: attachment; filename="Certificates.p7m"',
+            'Content-Type: application/x-pkcs7-mime; name="Certificates.p7m"',
+            'Content-Transfer-Encoding: base64',
+            '',
+            certificates_content
         ]
-        with cd(self.tmpdir):
-            out, _ = util.subp(
-                'openssl cms -decrypt -in /dev/stdin -inkey'
-                ' {private_key} -recip {certificate} | openssl pkcs12 -nodes'
-                ' -password pass:'.format(**self.certificate_names),
-                shell=True, data=b'\n'.join(lines))
+        data = '\n'.join(lines).encode('utf-8')
+        tmpd = temp_utils.mkdtemp(suffix="sslmanager.d")
+        pkey_file = os.path.join(tmpd, "private.key")
+        cert_file = os.path.join(tmpd, "certificate")
+        cmd = ('openssl cms -decrypt -in /dev/stdin -inkey '
+               '{pkey_file} -recip {cert_file} | openssl pkcs12 -nodes'
+               ' -password pass:').format(pkey_file=pkey_file,
+                                          cert_file=cert_file)
+        try:
+            with open(pkey_file, "w") as fp:
+                fp.write(self.private_key)
+            with open(cert_file, "w") as fp:
+                fp.write(self.certificate)
+            out, _ = util.subp(cmd, shell=True, data=data)
+        finally:
+            shutil.rmtree(tmpd)
         return out
 
     def parse_certificates(self, certificates_xml):
@@ -233,10 +225,6 @@ class WALinuxAgentShim(object):
         self.openssl_manager = None
         self.lease_file = fallback_lease_file
 
-    def clean_up(self):
-        if self.openssl_manager is not None:
-            self.openssl_manager.clean_up()
-
     @staticmethod
     def _get_hooks_dir():
         _paths = stages.Init()
@@ -355,7 +343,7 @@ class WALinuxAgentShim(object):
     def register_with_azure_and_fetch_data(self, pubkey_info=None):
         if self.openssl_manager is None:
             self.openssl_manager = OpenSSLManager()
-        http_client = AzureEndpointHttpClient(self.openssl_manager.certificate)
+        http_client = AzureEndpointHttpClient(self.openssl_manager.certificate_data)
         LOG.info('Registering with Azure...')
         attempts = 0
         while True:
@@ -423,9 +411,6 @@ def get_metadata_from_fabric(fallback_lease_file=None, dhcp_opts=None,
                              pubkey_info=None):
     shim = WALinuxAgentShim(fallback_lease_file=fallback_lease_file,
                             dhcp_options=dhcp_opts)
-    try:
-        return shim.register_with_azure_and_fetch_data(pubkey_info=pubkey_info)
-    finally:
-        shim.clean_up()
+    return shim.register_with_azure_and_fetch_data(pubkey_info=pubkey_info)
 
 # vi: ts=4 expandtab
diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py
index 0255616..bc13a73 100644
--- a/tests/unittests/test_datasource/test_azure_helper.py
+++ b/tests/unittests/test_datasource/test_azure_helper.py
@@ -249,48 +249,6 @@ class TestAzureEndpointHttpClient(CiTestCase):
             self.read_file_or_url.call_args)
 
 
-class TestOpenSSLManager(CiTestCase):
-
-    def setUp(self):
-        super(TestOpenSSLManager, self).setUp()
-        patches = ExitStack()
-        self.addCleanup(patches.close)
-
-        self.subp = patches.enter_context(
-            mock.patch.object(azure_helper.util, 'subp'))
-        try:
-            self.open = patches.enter_context(
-                mock.patch('__builtin__.open'))
-        except ImportError:
-            self.open = patches.enter_context(
-                mock.patch('builtins.open'))
-
-    @mock.patch.object(azure_helper, 'cd', mock.MagicMock())
-    @mock.patch.object(azure_helper.temp_utils, 'mkdtemp')
-    def test_openssl_manager_creates_a_tmpdir(self, mkdtemp):
-        manager = azure_helper.OpenSSLManager()
-        self.assertEqual(mkdtemp.return_value, manager.tmpdir)
-
-    def test_generate_certificate_uses_tmpdir(self):
-        subp_directory = {}
-
-        def capture_directory(*args, **kwargs):
-            subp_directory['path'] = os.getcwd()
-
-        self.subp.side_effect = capture_directory
-        manager = azure_helper.OpenSSLManager()
-        self.assertEqual(manager.tmpdir, subp_directory['path'])
-        manager.clean_up()
-
-    @mock.patch.object(azure_helper, 'cd', mock.MagicMock())
-    @mock.patch.object(azure_helper.temp_utils, 'mkdtemp', mock.MagicMock())
-    @mock.patch.object(azure_helper.util, 'del_dir')
-    def test_clean_up(self, del_dir):
-        manager = azure_helper.OpenSSLManager()
-        manager.clean_up()
-        self.assertEqual([mock.call(manager.tmpdir)], del_dir.call_args_list)
-
-
 class TestOpenSSLManagerActions(CiTestCase):
 
     def setUp(self):
@@ -356,8 +314,9 @@ class TestWALinuxAgentShim(CiTestCase):
     def test_http_client_uses_certificate(self):
         shim = wa_shim()
         shim.register_with_azure_and_fetch_data()
+        data = self.OpenSSLManager.return_value.certificate_data
         self.assertEqual(
-            [mock.call(self.OpenSSLManager.return_value.certificate)],
+            [mock.call(data)],
             self.AzureEndpointHttpClient.call_args_list)
 
     def test_correct_url_used_for_goalstate(self):
@@ -424,17 +383,6 @@ class TestWALinuxAgentShim(CiTestCase):
         self.assertIn('TestContainerId', posted_document)
         self.assertIn('TestInstanceId', posted_document)
 
-    def test_clean_up_can_be_called_at_any_time(self):
-        shim = wa_shim()
-        shim.clean_up()
-
-    def test_clean_up_will_clean_up_openssl_manager_if_instantiated(self):
-        shim = wa_shim()
-        shim.register_with_azure_and_fetch_data()
-        shim.clean_up()
-        self.assertEqual(
-            1, self.OpenSSLManager.return_value.clean_up.call_count)
-
     def test_failure_to_fetch_goalstate_bubbles_up(self):
         class SentinelException(Exception):
             pass
@@ -454,21 +402,6 @@ class TestGetMetadataFromFabric(CiTestCase):
             shim.return_value.register_with_azure_and_fetch_data.return_value,
             ret)
 
-    @mock.patch.object(azure_helper, 'WALinuxAgentShim')
-    def test_success_calls_clean_up(self, shim):
-        azure_helper.get_metadata_from_fabric()
-        self.assertEqual(1, shim.return_value.clean_up.call_count)
-
-    @mock.patch.object(azure_helper, 'WALinuxAgentShim')
-    def test_failure_in_registration_calls_clean_up(self, shim):
-        class SentinelException(Exception):
-            pass
-        shim.return_value.register_with_azure_and_fetch_data.side_effect = (
-            SentinelException)
-        self.assertRaises(SentinelException,
-                          azure_helper.get_metadata_from_fabric)
-        self.assertEqual(1, shim.return_value.clean_up.call_count)
-
 
 class TestExtractIpAddressFromNetworkd(CiTestCase):
 

Follow ups