← Back to team overview

cloud-init-dev team mailing list archive

[Merge] ~smoser/cloud-init:fix/1677205-eol-on-sshd_config into cloud-init:master

 

Scott Moser has proposed merging ~smoser/cloud-init:fix/1677205-eol-on-sshd_config into cloud-init:master.

Commit message:
set_passwords: Add newline to end of sshd config, only restart if updated.

This is, admittedly, a fairly extensive refactor just to add a newline
to the end of sshd_config.

It makes the sshd_config-updating portion of set_passwords more testable
and adds tests for it.

The new function, 'update_ssh_config_lines', allows you to apply
multiple changes to a config at once, even though only a single change
is currently used.

We also now restart the ssh daemon only if a change was made to the
config file.  Previously, it was always restarted if the user specified
a value for ssh_pwauth other than 'unchanged'.

LP: #1677205


Requested reviews:
  Server Team CI bot (server-team-bot): continuous-integration
  cloud-init commiters (cloud-init-dev)
Related bugs:
  Bug #1677205 in cloud-init: "cloud-init eats final EOL of sshd_config"
  https://bugs.launchpad.net/cloud-init/+bug/1677205

For more details, see:
https://code.launchpad.net/~smoser/cloud-init/+git/cloud-init/+merge/343246

see commit message
-- 
Your team cloud-init commiters is requested to review the proposed merge of ~smoser/cloud-init:fix/1677205-eol-on-sshd_config into cloud-init:master.
diff --git a/cloudinit/config/cc_set_passwords.py b/cloudinit/config/cc_set_passwords.py
index bb24d57..3e5ba4c 100755
--- a/cloudinit/config/cc_set_passwords.py
+++ b/cloudinit/config/cc_set_passwords.py
@@ -68,16 +68,49 @@ import re
 import sys
 
 from cloudinit.distros import ug_util
-from cloudinit import ssh_util
+from cloudinit import log as logging
+from cloudinit.ssh_util import update_ssh_config
 from cloudinit import util
 
 from string import ascii_letters, digits
 
+LOG = logging.getLogger(__name__)
+
 # We are removing certain 'painful' letters/numbers
 PW_SET = (''.join([x for x in ascii_letters + digits
                    if x not in 'loLOI01']))
 
 
+def handle_ssh_pwauth(pw_auth, service_cmd=None, service_name="ssh"):
+    cfg_name = "PasswordAuthentication"
+    if service_cmd is None:
+        service_cmd = ["service"]
+
+    if util.is_true(pw_auth):
+        cfg_val = 'yes'
+    elif util.is_false(pw_auth):
+        cfg_val = 'no'
+    else:
+        bmsg = "Leaving ssh config '%s' unchanged." % cfg_name
+        if pw_auth is None or pw_auth.lower() == 'unchanged':
+            LOG.debug("%s ssh_pwauth=%s", bmsg, pw_auth)
+        else:
+            LOG.warning("%s Unrecognized value: ssh_pwauth=%s", bmsg, pw_auth)
+        return
+
+    updated = update_ssh_config({cfg_name: cfg_val})
+    if not updated:
+        LOG.debug("No need to restart ssh service, %s not updated.", cfg_name)
+        return
+
+    if 'systemctl' in service_cmd:
+        cmd = list(service_cmd) + ["restart", service_name]
+    else:
+        cmd = list(service_cmd) + [service_name, "restart"]
+    util.subp(cmd)
+    LOG.debug("Restarted the ssh daemon.")
+
+
 def handle(_name, cfg, cloud, log, args):
     if len(args) != 0:
         # if run from command line, and give args, wipe the chpasswd['list']
@@ -170,65 +203,9 @@ def handle(_name, cfg, cloud, log, args):
             if expired_users:
                 log.debug("Expired passwords for: %s users", expired_users)
 
-    change_pwauth = False
-    pw_auth = None
-    if 'ssh_pwauth' in cfg:
-        if util.is_true(cfg['ssh_pwauth']):
-            change_pwauth = True
-            pw_auth = 'yes'
-        elif util.is_false(cfg['ssh_pwauth']):
-            change_pwauth = True
-            pw_auth = 'no'
-        elif str(cfg['ssh_pwauth']).lower() == 'unchanged':
-            log.debug('Leaving auth line unchanged')
-            change_pwauth = False
-        elif not str(cfg['ssh_pwauth']).strip():
-            log.debug('Leaving auth line unchanged')
-            change_pwauth = False
-        elif not cfg['ssh_pwauth']:
-            log.debug('Leaving auth line unchanged')
-            change_pwauth = False
-        else:
-            msg = 'Unrecognized value %s for ssh_pwauth' % cfg['ssh_pwauth']
-            util.logexc(log, msg)
-
-    if change_pwauth:
-        replaced_auth = False
-
-        # See: man sshd_config
-        old_lines = ssh_util.parse_ssh_config(ssh_util.DEF_SSHD_CFG)
-        new_lines = []
-        i = 0
-        for (i, line) in enumerate(old_lines):
-            # Keywords are case-insensitive and arguments are case-sensitive
-            if line.key == 'passwordauthentication':
-                log.debug("Replacing auth line %s with %s", i + 1, pw_auth)
-                replaced_auth = True
-                line.value = pw_auth
-            new_lines.append(line)
-
-        if not replaced_auth:
-            log.debug("Adding new auth line %s", i + 1)
-            replaced_auth = True
-            new_lines.append(ssh_util.SshdConfigLine('',
-                                                     'PasswordAuthentication',
-                                                     pw_auth))
-
-        lines = [str(l) for l in new_lines]
-        util.write_file(ssh_util.DEF_SSHD_CFG, "\n".join(lines),
-                        copy_mode=True)
-
-        try:
-            cmd = cloud.distro.init_cmd  # Default service
-            cmd.append(cloud.distro.get_option('ssh_svcname', 'ssh'))
-            cmd.append('restart')
-            if 'systemctl' in cmd:  # Switch action ordering
-                cmd[1], cmd[2] = cmd[2], cmd[1]
-            cmd = filter(None, cmd)  # Remove empty arguments
-            util.subp(cmd)
-            log.debug("Restarted the ssh daemon")
-        except Exception:
-            util.logexc(log, "Restarting of the ssh daemon failed")
+    handle_ssh_pwauth(
+        cfg.get('ssh_pwauth'), service_cmd=cloud.distro.init_cmd,
+        service_name=cloud.distro.get_option('ssh_svcname', 'ssh'))
 
     if len(errors):
         log.debug("%s errors occured, re-raising the last one", len(errors))
diff --git a/cloudinit/config/tests/test_set_passwords.py b/cloudinit/config/tests/test_set_passwords.py
new file mode 100644
index 0000000..b051ec8
--- /dev/null
+++ b/cloudinit/config/tests/test_set_passwords.py
@@ -0,0 +1,71 @@
+# This file is part of cloud-init. See LICENSE file for license information.
+
+import mock
+
+from cloudinit.config import cc_set_passwords as setpass
+from cloudinit.tests.helpers import CiTestCase
+from cloudinit import util
+
+MODPATH = "cloudinit.config.cc_set_passwords."
+
+
+class TestHandleSshPwauth(CiTestCase):
+    """Test cc_set_passwords handling of ssh_pwauth in handle_ssh_pwauth."""
+
+    with_logs = True
+
+    @mock.patch(MODPATH + "util.subp")
+    def test_unknown_value_logs_warning(self, m_subp):
+        setpass.handle_ssh_pwauth("floo")
+        self.assertIn("Unrecognized value: ssh_pwauth=floo",
+                      self.logs.getvalue())
+        m_subp.assert_not_called()
+
+    @mock.patch(MODPATH + "update_ssh_config", return_value=True)
+    @mock.patch(MODPATH + "util.subp")
+    def test_systemctl_as_service_cmd(self, m_subp, m_update_ssh_config):
+        """If systemctl in service cmd: systemctl restart name."""
+        setpass.handle_ssh_pwauth(
+            True, service_cmd=["systemctl"], service_name="myssh")
+        self.assertEqual(mock.call(["systemctl", "restart", "myssh"]),
+                         m_subp.call_args)
+
+    @mock.patch(MODPATH + "update_ssh_config", return_value=True)
+    @mock.patch(MODPATH + "util.subp")
+    def test_service_as_service_cmd(self, m_subp, m_update_ssh_config):
+        """If service in service cmd: service name restart."""
+        setpass.handle_ssh_pwauth(
+            True, service_cmd=["service"], service_name="myssh")
+        self.assertEqual(mock.call(["service", "myssh", "restart"]),
+                         m_subp.call_args)
+
+    @mock.patch(MODPATH + "update_ssh_config", return_value=False)
+    @mock.patch(MODPATH + "util.subp")
+    def test_not_restarted_if_not_updated(self, m_subp, m_update_ssh_config):
+        """If config is not updated, then no ssh restart should be done."""
+        setpass.handle_ssh_pwauth(True)
+        m_subp.assert_not_called()
+        self.assertIn("No need to restart ssh", self.logs.getvalue())
+
+    @mock.patch(MODPATH + "update_ssh_config", return_value=True)
+    @mock.patch(MODPATH + "util.subp")
+    def test_unchanged_does_nothing(self, m_subp, m_update_ssh_config):
+        """If 'unchanged', then no updates to config and no restart."""
+        setpass.handle_ssh_pwauth(
+            "unchanged", service_cmd=["systemctl"], service_name="myssh")
+        m_update_ssh_config.assert_not_called()
+        m_subp.assert_not_called()
+
+    @mock.patch(MODPATH + "util.subp")
+    def test_valid_change_values(self, m_subp):
+        """If value is a valid change value, then update should be called."""
+        upname = MODPATH + "update_ssh_config"
+        optname = "PasswordAuthentication"
+        for value in util.FALSE_STRINGS + util.TRUE_STRINGS:
+            optval = "yes" if value in util.TRUE_STRINGS else "no"
+            with mock.patch(upname, return_value=False) as m_update:
+                setpass.handle_ssh_pwauth(value)
+                m_update.assert_called_with({optname: optval})
+        m_subp.assert_not_called()
+
+# vi: ts=4 expandtab
diff --git a/cloudinit/ssh_util.py b/cloudinit/ssh_util.py
index 882517f..3cb235b 100644
--- a/cloudinit/ssh_util.py
+++ b/cloudinit/ssh_util.py
@@ -279,24 +279,28 @@ class SshdConfigLine(object):
 
 
 def parse_ssh_config(fname):
+    if not os.path.isfile(fname):  # missing file: treated as empty
+        return []
+    return parse_ssh_config_lines(util.load_file(fname).splitlines())
+
+
+def parse_ssh_config_lines(lines):
     # See: man sshd_config
     # The file contains keyword-argument pairs, one per line.
     # Lines starting with '#' and empty lines are interpreted as comments.
     # Note: key-words are case-insensitive and arguments are case-sensitive
-    lines = []
-    if not os.path.isfile(fname):
-        return lines
-    for line in util.load_file(fname).splitlines():
+    ret = []  # one SshdConfigLine per input line, in order
+    for line in lines:  # NOTE(review): a lone keyword raises ValueError
         line = line.strip()
         if not line or line.startswith("#"):
-            lines.append(SshdConfigLine(line))
+            ret.append(SshdConfigLine(line))
             continue
         try:
             key, val = line.split(None, 1)
         except ValueError:
             key, val = line.split('=', 1)
-        lines.append(SshdConfigLine(line, key, val))
-    return lines
+        ret.append(SshdConfigLine(line, key, val))
+    return ret
 
 
 def parse_ssh_config_map(fname):
@@ -310,4 +314,55 @@ def parse_ssh_config_map(fname):
         ret[line.key] = line.value
     return ret
 
+
+def update_ssh_config(updates, fname=DEF_SSHD_CFG):
+    """Apply updates ({Option: value}) to fname, writing it only on change.
+
+    @return: True if fname was written (newline-terminated), else False."""
+    lines = parse_ssh_config(fname)
+    changed = update_ssh_config_lines(lines=lines, updates=updates)
+    if changed:
+        util.write_file(  # keep final EOL (LP: #1677205)
+            fname, "\n".join([str(l) for l in lines]) + "\n", copy_mode=True)
+    return bool(changed)
+
+
+def update_ssh_config_lines(lines, updates):
+    """Update the ssh config lines per updates.
+
+    @param lines: array of SshdConfigLine.  This array is updated in place.
+    @param updates: dictionary of desired values {Option: value}
+    @return: A list of keys in updates that were changed or appended."""
+    found = set()
+    changed = []
+
+    # Keywords are case-insensitive and arguments are case-sensitive
+    casemap = dict([(k.lower(), k) for k in updates.keys()])
+
+    for (i, line) in enumerate(lines, start=1):
+        if not line.key:  # blank or comment line
+            continue
+        if line.key in casemap:
+            key = casemap[line.key]
+            value = updates[key]
+            found.add(key)
+            if line.value == value:
+                LOG.debug("line %d: option %s already set to %s",
+                          i, key, value)
+            else:
+                changed.append(key)
+                LOG.debug("line %d: option %s updated %s -> %s", i,
+                          key, line.value, value)
+                line.value = value
+
+    if len(found) != len(updates):  # NOTE: EOF may be inside a Match block
+        for key, value in updates.items():
+            if key in found:
+                continue
+            changed.append(key)
+            lines.append(SshdConfigLine('', key, value))  # at EOF
+            LOG.debug("line %d: option %s added with %s",
+                      len(lines), key, value)
+    return changed
+
 # vi: ts=4 expandtab
diff --git a/tests/unittests/test_sshutil.py b/tests/unittests/test_sshutil.py
index 4c62c8b..73ae897 100644
--- a/tests/unittests/test_sshutil.py
+++ b/tests/unittests/test_sshutil.py
@@ -4,6 +4,7 @@ from mock import patch
 
 from cloudinit import ssh_util
 from cloudinit.tests import helpers as test_helpers
+from cloudinit import util
 
 
 VALID_CONTENT = {
@@ -56,7 +57,7 @@ TEST_OPTIONS = (
     'user \"root\".\';echo;sleep 10"')
 
 
-class TestAuthKeyLineParser(test_helpers.TestCase):
+class TestAuthKeyLineParser(test_helpers.CiTestCase):
 
     def test_simple_parse(self):
         # test key line with common 3 fields (keytype, base64, comment)
@@ -126,7 +127,7 @@ class TestAuthKeyLineParser(test_helpers.TestCase):
         self.assertFalse(key.valid())
 
 
-class TestUpdateAuthorizedKeys(test_helpers.TestCase):
+class TestUpdateAuthorizedKeys(test_helpers.CiTestCase):
 
     def test_new_keys_replace(self):
         """new entries with the same base64 should replace old."""
@@ -168,7 +169,7 @@ class TestUpdateAuthorizedKeys(test_helpers.TestCase):
         self.assertEqual(expected, found)
 
 
-class TestParseSSHConfig(test_helpers.TestCase):
+class TestParseSSHConfig(test_helpers.CiTestCase):
 
     def setUp(self):
         self.load_file_patch = patch('cloudinit.ssh_util.util.load_file')
@@ -235,4 +236,94 @@ class TestParseSSHConfig(test_helpers.TestCase):
         self.assertEqual('foo', ret[0].key)
         self.assertEqual('bar', ret[0].value)
 
+
+class TestUpdateSshConfigLines(test_helpers.CiTestCase):
+    """Test the update_ssh_config_lines function."""
+    exlines = [
+        "#PasswordAuthentication yes",
+        "UsePAM yes",
+        "# Comment line",
+        "AcceptEnv LANG LC_*",
+        "X11Forwarding no",
+    ]
+    pwauth = "PasswordAuthentication"
+
+    def check_line(self, line, opt, val):  # check key/value and str() form
+        self.assertEqual(line.key, opt.lower())
+        self.assertEqual(line.value, val)
+        self.assertIn(opt, str(line))
+        self.assertIn(val, str(line))
+
+    def test_new_option_added(self):
+        """A single update of a non-existing option."""
+        lines = ssh_util.parse_ssh_config_lines(list(self.exlines))
+        result = ssh_util.update_ssh_config_lines(lines, {'MyKey': 'MyVal'})
+        self.assertEqual(['MyKey'], result)
+        self.check_line(lines[-1], "MyKey", "MyVal")
+
+    def test_commented_out_not_updated_but_appended(self):
+        """Implementation does not un-comment and update lines."""
+        lines = ssh_util.parse_ssh_config_lines(list(self.exlines))
+        result = ssh_util.update_ssh_config_lines(lines, {self.pwauth: "no"})
+        self.assertEqual([self.pwauth], result)
+        self.check_line(lines[-1], self.pwauth, "no")
+
+    def test_single_option_updated(self):
+        """A single update should have change made and line updated."""
+        opt, val = ("UsePAM", "no")
+        lines = ssh_util.parse_ssh_config_lines(list(self.exlines))
+        result = ssh_util.update_ssh_config_lines(lines, {opt: val})
+        self.assertEqual([opt], result)
+        self.check_line(lines[1], opt, val)
+
+    def test_multiple_updates_with_add(self):
+        """Verify multiple updates: some added, some changed, some not."""
+        updates = {"UsePAM": "no", "X11Forwarding": "no", "NewOpt": "newval",
+                   "AcceptEnv": "LANG ADD LC_*"}
+        lines = ssh_util.parse_ssh_config_lines(list(self.exlines))
+        result = ssh_util.update_ssh_config_lines(lines, updates)
+        self.assertEqual(set(["UsePAM", "NewOpt", "AcceptEnv"]), set(result))
+        self.check_line(lines[3], "AcceptEnv", updates["AcceptEnv"])
+
+    def test_return_empty_if_no_changes(self):
+        """If there are no changes, then return should be empty list."""
+        updates = {"UsePAM": "yes"}
+        lines = ssh_util.parse_ssh_config_lines(list(self.exlines))
+        result = ssh_util.update_ssh_config_lines(lines, updates)
+        self.assertEqual([], result)
+        self.assertEqual(self.exlines, [str(l) for l in lines])
+
+    def test_keycase_not_modified(self):
+        """Original case of key should not be changed on update.
+        This keeps the original config as intact as possible."""
+        updates = {"usepam": "no"}
+        lines = ssh_util.parse_ssh_config_lines(list(self.exlines))
+        result = ssh_util.update_ssh_config_lines(lines, updates)
+        self.assertEqual(["usepam"], result)
+        self.assertEqual("UsePAM no", str(lines[1]))
+
+
+class TestUpdateSshConfig(test_helpers.CiTestCase):  # update_ssh_config()
+    cfgdata = '\n'.join(["#Option val", "MyKey ORIG_VAL", ""])
+
+    def test_modified(self):  # value changed, file keeps final EOL
+        mycfg = self.tmp_path("ssh_config_1")
+        util.write_file(mycfg, self.cfgdata)
+        ret = ssh_util.update_ssh_config({"MyKey": "NEW_VAL"}, mycfg)
+        self.assertTrue(ret)
+        found = util.load_file(mycfg)
+        self.assertEqual(self.cfgdata.replace("ORIG_VAL", "NEW_VAL"), found)
+        # assert there is a newline at end of file (LP: #1677205)
+        self.assertEqual('\n', found[-1])
+
+    def test_not_modified(self):  # same value: no rewrite
+        mycfg = self.tmp_path("ssh_config_2")
+        util.write_file(mycfg, self.cfgdata)
+        with patch("cloudinit.ssh_util.util.write_file") as m_write_file:
+            ret = ssh_util.update_ssh_config({"MyKey": "ORIG_VAL"}, mycfg)
+        self.assertFalse(ret)
+        self.assertEqual(self.cfgdata, util.load_file(mycfg))
+        m_write_file.assert_not_called()
+
+
 # vi: ts=4 expandtab