← Back to team overview

cloud-init-dev team mailing list archive

[Merge] lp:~harlowja/cloud-init/resolv-conf-niceness into lp:cloud-init

 

Joshua Harlow has proposed merging lp:~harlowja/cloud-init/resolv-conf-niceness into lp:cloud-init.

Requested reviews:
  cloud init development team (cloud-init-dev)

For more details, see:
https://code.launchpad.net/~harlowja/cloud-init/resolv-conf-niceness/+merge/125578
-- 
https://code.launchpad.net/~harlowja/cloud-init/resolv-conf-niceness/+merge/125578
Your team, the cloud init development team, is requested to review the proposed merge of lp:~harlowja/cloud-init/resolv-conf-niceness into lp:cloud-init.
=== added file 'cloudinit/distros/helpers.py'
--- cloudinit/distros/helpers.py	1970-01-01 00:00:00 +0000
+++ cloudinit/distros/helpers.py	2012-09-20 21:05:25 +0000
@@ -0,0 +1,179 @@
+# vi: ts=4 expandtab
+#
+#    Copyright (C) 2012 Canonical Ltd.
+#    Copyright (C) 2012 Yahoo! Inc.
+#
+#    Author: Scott Moser <scott.moser@xxxxxxxxxxxxx>
+#    Author: Joshua Harlow <harlowja@xxxxxxxxxxxxx>
+#
+#    This program is free software: you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License version 3, as
+#    published by the Free Software Foundation.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from StringIO import StringIO
+
+from cloudinit import util
+
+
+# See: man resolv.conf
+class ResolvConf(object):
+    def __init__(self, text):
+        self._text = text
+        self._contents = None
+
+    def parse(self):
+        if self._contents is None:
+            self._contents = self._parse(self._text)
+
+    @property
+    def nameservers(self):
+        self.parse()
+        return self._retr_option('nameserver')
+
+    @property
+    def local_domain(self):
+        self.parse()
+        dm = self._retr_option('domain')
+        if dm:
+            return dm[0]
+        return None
+
+    @property
+    def search_domains(self):
+        self.parse()
+        current_sds = self._retr_option('search')
+        flat_sds = []
+        for sdlist in current_sds:
+            for sd in sdlist.split(None):
+                if sd:
+                    flat_sds.append(sd)
+        return flat_sds
+
+    def __str__(self):
+        self.parse()
+        contents = StringIO()
+        for (line_type, components) in self._contents:
+            if line_type == 'blank':
+                contents.write("\n")
+            elif line_type == 'all_comment':
+                contents.write("%s\n" % (components[0]))
+            elif line_type == 'option':
+                (cfg_opt, cfg_value, comment_tail) = components
+                line = "%s %s" % (cfg_opt, cfg_value)
+                if len(comment_tail):
+                    line += comment_tail
+                contents.write("%s\n" % (line))
+        return contents.getvalue()
+
+    def _retr_option(self, opt_name):
+        found = []
+        for (line_type, components) in self._contents:
+            if line_type == 'option':
+                (cfg_opt, cfg_value, comment_tail) = components
+                if cfg_opt == opt_name:
+                    found.append(cfg_value)
+        return found
+
+    def add_nameserver(self, ns):
+        self.parse()
+        current_ns = self._retr_option('nameserver')
+        new_ns = list(current_ns)
+        new_ns.append(str(ns))
+        new_ns = util.uniq_list(new_ns)
+        if len(new_ns) == len(current_ns):
+            return current_ns
+        if len(current_ns) >= 3:
+            # Hard restriction on only 3 name servers
+            raise ValueError(("Adding %r would go beyond the "
+                              "'3' maximum name servers") % (ns))
+        self._remove_option('nameserver')
+        for n in new_ns:
+            self._contents.append(('option', ['nameserver', n, '']))
+        return new_ns
+
+    def _remove_option(self, opt_name):
+
+        def remove_opt(item):
+            line_type, components = item
+            if line_type != 'option':
+                return True
+            (cfg_opt, cfg_value, comment_tail) = components
+            if cfg_opt != opt_name:
+                return True
+            return False
+
+        new_contents = filter(remove_opt, self._contents)
+        self._contents = new_contents
+
+    def add_search_domain(self, search_domain):
+        flat_sds = self.search_domains
+        new_sds = list(flat_sds)
+        new_sds.append(str(search_domain))
+        new_sds = util.uniq_list(new_sds)
+        if len(flat_sds) == len(new_sds):
+            return new_sds
+        if len(flat_sds) >= 6:
+            # Hard restriction on only 6 search domains
+            raise ValueError(("Adding %r would go beyond the "
+                              "'6' maximum search domains") % (search_domain))
+        s_list  = " ".join(new_sds)
+        if len(s_list) > 256:
+            # resolv.conf(5) also limits the whole search list to 256 characters
+            raise ValueError(("Adding %r would go beyond the "
+                              "256 maximum search list character limit")
+                              % (search_domain))
+        self._remove_option('search')
+        self._contents.append(('option', ['search', s_list, '']))
+        return flat_sds
+
+    @local_domain.setter
+    def local_domain(self, domain):
+        self.parse()
+        self._remove_option('domain')
+        self._contents.append(('option', ['domain', str(domain), '']))
+        return domain
+
+    def _parse(self, contents):
+        entries = []
+        for (i, line) in enumerate(contents.splitlines()):
+            sline = line.strip()
+            if not sline:
+                entries.append(('blank', [line]))
+                continue
+            comment_s_loc = sline.find(";")
+            comment_h_loc = sline.find("#")
+            comment_loc = -1
+            if comment_s_loc != -1 and comment_h_loc != -1:
+                comment_loc = min(comment_h_loc, comment_s_loc)
+            elif comment_s_loc != -1:
+                comment_loc = comment_s_loc
+            elif comment_h_loc != -1:
+                comment_loc = comment_h_loc
+            head = line
+            tail = None
+            if comment_loc != -1:
+                head = line[:comment_loc]
+                tail = line[comment_loc:]
+            if not len(head.strip()):
+                entries.append(('all_comment', [line]))
+                continue
+            if not tail:
+                tail = ''
+            try:
+                (cfg_opt, cfg_values) = head.split(None, 1)
+            except (IndexError, ValueError):
+                raise IOError("Incorrectly formatted resolv.conf line %s" % (i + 1))
+            if cfg_opt not in ('nameserver', 'domain', 'search', 'sortlist', 'options'):
+                raise IOError("Unexpected resolv.conf option %s" % (cfg_opt))
+            entries.append(("option", [cfg_opt, cfg_values, tail]))
+        return entries
+
+

=== modified file 'cloudinit/distros/rhel.py'
--- cloudinit/distros/rhel.py	2012-09-19 20:06:58 +0000
+++ cloudinit/distros/rhel.py	2012-09-20 21:05:25 +0000
@@ -23,6 +23,8 @@
 import os
 
 from cloudinit import distros
+from cloudinit.distros import helpers as d_helpers
+
 from cloudinit import helpers
 from cloudinit import log as logging
 from cloudinit import util
@@ -68,17 +70,29 @@
     def install_packages(self, pkglist):
         self.package_command('install', pkglist)
 
-    def _write_resolve(self, dns_servers, search_servers):
-        contents = []
+    def _adjust_resolve(self, dns_servers, search_servers):
+        r_conf = d_helpers.ResolvConf(util.load_file("/etc/resolv.conf"))
+        try:
+            r_conf.parse()
+        except IOError:
+            util.logexc(LOG, 
+                        "Failed at parsing %s reverting to an empty instance",
+                        "/etc/resolv.conf")
+            r_conf = d_helpers.ResolvConf('')
+            r_conf.parse()
         if dns_servers:
             for s in dns_servers:
-                contents.append("nameserver %s" % (s))
+                try:
+                    r_conf.add_nameserver(s)
+                except ValueError:
+                    util.logexc(LOG, "Failed at adding nameserver %s", s)
         if search_servers:
-            contents.append("search %s" % (" ".join(search_servers)))
-        if contents:
-            resolve_rw_fn = self._paths.join(False, "/etc/resolv.conf")
-            contents.insert(0, '# Created by cloud-init')
-            util.write_file(resolve_rw_fn, "\n".join(contents), 0644)
+            for s in search_servers:
+                try:
+                    r_conf.add_search_domain(s)
+                except ValueError:
+                    util.logexc(LOG, "Failed at adding search domain %s", s)
+        util.write_file("/etc/resolv.conf", str(r_conf), 0644)
 
     def _write_network(self, settings):
         # TODO(harlowja) fix this... since this is the ubuntu format
@@ -126,7 +140,7 @@
             net_rw_fn = self._paths.join(False, net_fn)
             util.write_file(net_rw_fn, w_contents, 0644)
         if nameservers or searchservers:
-            self._write_resolve(nameservers, searchservers)
+            self._adjust_resolve(nameservers, searchservers)
 
     def set_hostname(self, hostname):
         out_fn = self._paths.join(False, '/etc/sysconfig/network')

=== modified file 'cloudinit/util.py'
--- cloudinit/util.py	2012-08-28 03:51:00 +0000
+++ cloudinit/util.py	2012-09-20 21:05:25 +0000
@@ -952,6 +952,16 @@
     return entries
 
 
+def uniq_list(in_list):
+    out_list = []
+    for i in in_list:
+        if i in out_list:
+            continue
+        else:
+            out_list.append(i)
+    return out_list
+
+
 def load_file(fname, read_cb=None, quiet=False):
     LOG.debug("Reading from %s (quiet=%s)", fname, quiet)
     ofh = StringIO()

=== added file 'tests/unittests/test_distros/test_resolv.py'
--- tests/unittests/test_distros/test_resolv.py	1970-01-01 00:00:00 +0000
+++ tests/unittests/test_distros/test_resolv.py	2012-09-20 21:05:25 +0000
@@ -0,0 +1,61 @@
+from mocker import MockerTestCase
+
+from cloudinit.distros import helpers
+
+
+BASE_RESOLVE = '''
+; generated by /sbin/dhclient-script
+search blah.yahoo.com yahoo.com
+nameserver 10.15.44.14
+nameserver 10.15.30.92
+'''
+BASE_RESOLVE = BASE_RESOLVE.strip()
+
+
+class TestResolvHelper(MockerTestCase):
+    def test_parse_same(self):
+        rp = helpers.ResolvConf(BASE_RESOLVE)
+        rp_r = str(rp).strip()
+        self.assertEquals(BASE_RESOLVE, rp_r)
+
+    def test_local_domain(self):
+        rp = helpers.ResolvConf(BASE_RESOLVE)
+        self.assertEquals(None, rp.local_domain)
+
+        rp.local_domain = "bob"
+        self.assertEquals('bob', rp.local_domain)
+        self.assertIn('domain bob', str(rp))
+
+    def test_nameservers(self):
+        rp = helpers.ResolvConf(BASE_RESOLVE)
+        self.assertIn('10.15.44.14', rp.nameservers)
+        self.assertIn('10.15.30.92', rp.nameservers)
+        rp.add_nameserver('10.2')
+        self.assertIn('10.2', rp.nameservers)
+        self.assertIn('nameserver 10.2', str(rp))
+        self.assertNotIn('10.3', rp.nameservers)
+        self.assertEquals(len(rp.nameservers), 3)
+        rp.add_nameserver('10.2')
+        with self.assertRaises(ValueError):
+            rp.add_nameserver('10.3')
+        self.assertNotIn('10.3', rp.nameservers)
+
+    def test_search_domains(self):
+        rp = helpers.ResolvConf(BASE_RESOLVE)
+        self.assertIn('yahoo.com', rp.search_domains)
+        self.assertIn('blah.yahoo.com', rp.search_domains)
+        rp.add_search_domain('bbb.y.com')
+        self.assertIn('bbb.y.com', rp.search_domains)
+        self.assertRegexpMatches(str(rp), r'search(.*)bbb.y.com(.*)')
+        self.assertIn('bbb.y.com', rp.search_domains)
+        rp.add_search_domain('bbb.y.com')
+        self.assertEquals(len(rp.search_domains), 3)
+        rp.add_search_domain('bbb2.y.com')
+        self.assertEquals(len(rp.search_domains), 4)
+        rp.add_search_domain('bbb3.y.com')
+        self.assertEquals(len(rp.search_domains), 5)
+        rp.add_search_domain('bbb4.y.com')
+        self.assertEquals(len(rp.search_domains), 6)
+        with self.assertRaises(ValueError):
+            rp.add_search_domain('bbb5.y.com')
+        self.assertEquals(len(rp.search_domains), 6)