← Back to team overview

cloud-init-dev team mailing list archive

[Merge] ~illfelder/cloud-init:master into cloud-init:master

 

Scott Moser has proposed merging ~illfelder/cloud-init:master into cloud-init:master.

Commit message:
Improve GCE logic setting Ubuntu user SSH keys.

The behavior improvements include:
- Only set SSH keys for the ubuntu user.
- Use instance- or project-level keys according to GCE conventions: project
  keys are ignored when the instance sets block-project-ssh-keys or the
  legacy sshKeys attribute.
- Ignore keys whose embedded expiration time has passed.
- Support ssh-keys in project-level metadata (the GCE default).

As part of this change, we also update the request header used when
querying the metadata server, as described in the documentation:
https://cloud.google.com/compute/docs/storing-retrieving-metadata#querying

Requested reviews:
  Dan Watkins (daniel-thewatkins)
  cloud-init commiters (cloud-init-dev)

For more details, see:
https://code.launchpad.net/~illfelder/cloud-init/+git/cloud-init/+merge/334777
-- 
Your team cloud-init commiters is requested to review the proposed merge of ~illfelder/cloud-init:master into cloud-init:master.
diff --git a/cloudinit/sources/DataSourceGCE.py b/cloudinit/sources/DataSourceGCE.py
index ccae420..0bec9bc 100644
--- a/cloudinit/sources/DataSourceGCE.py
+++ b/cloudinit/sources/DataSourceGCE.py
@@ -2,6 +2,9 @@
 #
 # This file is part of cloud-init. See LICENSE file for license information.
 
+import datetime
+import json
+
 from base64 import b64decode
 
 from cloudinit import log as logging
@@ -17,16 +20,19 @@ REQUIRED_FIELDS = ('instance-id', 'availability-zone', 'local-hostname')
 
 
 class GoogleMetadataFetcher(object):
-    headers = {'X-Google-Metadata-Request': 'True'}
+    headers = {'Metadata-Flavor': 'Google'}
 
     def __init__(self, metadata_address):
         self.metadata_address = metadata_address
 
-    def get_value(self, path, is_text):
+    def get_value(self, path, is_text, is_recursive=False):
         value = None
         try:
-            resp = url_helper.readurl(url=self.metadata_address + path,
-                                      headers=self.headers)
+            url = self.metadata_address + path
+            if is_recursive:
+                # Fetch the whole subtree as a single JSON document.
+                url += '/?recursive=True'
+            resp = url_helper.readurl(url=url, headers=self.headers)
         except url_helper.UrlError as exc:
             msg = "url %s raised exception %s"
             LOG.debug(msg, path, exc)
@@ -35,7 +41,7 @@ class GoogleMetadataFetcher(object):
                 if is_text:
                     value = util.decode_binary(resp.contents)
                 else:
-                    value = resp.contents
+                    value = resp.contents.decode('utf-8')
             else:
                 LOG.debug("url %s returned code %s", path, resp.code)
         return value
@@ -89,15 +95,56 @@ class DataSourceGCE(sources.DataSource):
         return self.availability_zone.rsplit('-', 1)[0]
 
 
-def _trim_key(public_key):
-    # GCE takes sshKeys attribute in the format of '<user>:<public_key>'
-    # so we have to trim each key to remove the username part
+def _has_expired(public_key):
+    # Check whether an SSH key is expired using GCE specific key format:
+    #   <key-type> <key> google-ssh {"expireOn": "<UTC timestamp>", ...}
+    # Keys in any other format, or with unparseable data, never expire.
+    try:
+        # Check for the Google-specific schema identifier.
+        schema, json_str = public_key.split(None, 3)[2:]
+    except (ValueError, AttributeError):
+        return False
+
+    # Do not expire keys if they do not have the expected schema identifier.
+    if schema != 'google-ssh':
+        return False
+
+    try:
+        json_obj = json.loads(json_str)
+    except ValueError:
+        return False
+
+    # Do not expire keys if there is no expiration timestamp.
+    if not isinstance(json_obj, dict) or 'expireOn' not in json_obj:
+        return False
+
+    expire_str = json_obj['expireOn']
+    format_str = '%Y-%m-%dT%H:%M:%S+0000'
     try:
-        index = public_key.index(':')
-        if index > 0:
-            return public_key[(index + 1):]
-    except Exception:
-        return public_key
+        expire_time = datetime.datetime.strptime(expire_str, format_str)
+    except (TypeError, ValueError):
+        return False
+
+    # Expire the key if and only if we have exceeded the expiration timestamp.
+    return datetime.datetime.utcnow() > expire_time
+
+
+def _parse_public_keys(public_keys_data):
+    # Return the unexpired keys from ASCII 'ubuntu:<key>' lines; skip others.
+    public_keys = []
+    if not public_keys_data:
+        return public_keys
+    lines = [line for line in public_keys_data.splitlines() if line]
+    for line in lines:
+        if not all(ord(c) < 128 for c in line):
+            continue
+        split_line = line.split(':', 1)
+        if len(split_line) != 2:
+            continue
+        user, key = split_line
+        if user == 'ubuntu' and not _has_expired(key):
+            public_keys.append(key)
+    return public_keys
 
 
 def read_md(address=None, platform_check=True):
@@ -119,25 +166,22 @@ def read_md(address=None, platform_check=True):
         ret['reason'] = 'address "%s" is not resolvable' % address
         return ret
 
-    # url_map: (our-key, path, required, is_text)
+    # url_map: (our-key, path, required, is_text, is_recursive)
     url_map = [
-        ('instance-id', ('instance/id',), True, True),
-        ('availability-zone', ('instance/zone',), True, True),
-        ('local-hostname', ('instance/hostname',), True, True),
-        ('public-keys', ('project/attributes/sshKeys',
-                         'instance/attributes/ssh-keys'), False, True),
-        ('user-data', ('instance/attributes/user-data',), False, False),
-        ('user-data-encoding', ('instance/attributes/user-data-encoding',),
-         False, True),
+        ('instance-id', ('instance/id',), True, True, False),
+        ('availability-zone', ('instance/zone',), True, True, False),
+        ('local-hostname', ('instance/hostname',), True, True, False),
+        ('instance-data', ('instance/attributes',), False, False, True),
+        ('project-data', ('project/attributes',), False, False, True),
     ]
 
     metadata_fetcher = GoogleMetadataFetcher(address)
     md = {}
     # iterate over url_map keys to get metadata items
-    for (mkey, paths, required, is_text) in url_map:
+    for (mkey, paths, required, is_text, is_recursive) in url_map:
         value = None
         for path in paths:
-            new_value = metadata_fetcher.get_value(path, is_text)
+            new_value = metadata_fetcher.get_value(path, is_text, is_recursive)
             if new_value is not None:
                 value = new_value
         if required and value is None:
@@ -146,17 +190,25 @@ def read_md(address=None, platform_check=True):
             return ret
         md[mkey] = value
 
-    if md['public-keys']:
-        lines = md['public-keys'].splitlines()
-        md['public-keys'] = [_trim_key(k) for k in lines]
+    instance_data = json.loads(md['instance-data'] or '{}')
+    project_data = json.loads(md['project-data'] or '{}')
+    valid_keys = [instance_data.get('sshKeys'), instance_data.get('ssh-keys')]
+    block_project = instance_data.get('block-project-ssh-keys', '').lower()
+    if block_project != 'true' and not instance_data.get('sshKeys'):
+        valid_keys.append(project_data.get('ssh-keys'))
+        valid_keys.append(project_data.get('sshKeys'))
+    public_keys_data = '\n'.join([key for key in valid_keys if key])
+    md['public-keys'] = _parse_public_keys(public_keys_data)
 
     if md['availability-zone']:
         md['availability-zone'] = md['availability-zone'].split('/')[-1]
 
-    encoding = md.get('user-data-encoding')
+    # user-data comes from the recursive instance attributes fetch.
+    md['user-data'] = instance_data.get('user-data')
+    encoding = instance_data.get('user-data-encoding')
     if encoding:
         if encoding == 'base64':
             md['user-data'] = b64decode(md['user-data'])
         else:
             LOG.warning('unknown user-data-encoding: %s, ignoring', encoding)
 
@@ -198,7 +250,6 @@ def get_datasource_list(depends):
 
 if __name__ == "__main__":
     import argparse
-    import json
     import sys
 
     from base64 import b64encode
diff --git a/tests/unittests/test_datasource/test_gce.py b/tests/unittests/test_datasource/test_gce.py
index d399ae7..87acb22 100644
--- a/tests/unittests/test_datasource/test_gce.py
+++ b/tests/unittests/test_datasource/test_gce.py
@@ -5,6 +5,7 @@
 # This file is part of cloud-init. See LICENSE file for license information.
 
 import httpretty
+import json
 import mock
 import re
 
@@ -21,10 +22,7 @@ from cloudinit.tests import helpers as test_helpers
 GCE_META = {
     'instance/id': '123',
     'instance/zone': 'foo/bar',
-    'project/attributes/sshKeys': 'user:ssh-rsa AA2..+aRD0fyVw== root@server',
     'instance/hostname': 'server.project-foo.local',
-    # UnicodeDecodeError below if set to ds.userdata instead of userdata_raw
-    'instance/attributes/user-data': b'/bin/echo \xff\n',
 }
 
 GCE_META_PARTIAL = {
@@ -37,11 +35,13 @@ GCE_META_ENCODING = {
     'instance/id': '12345',
     'instance/hostname': 'server.project-baz.local',
     'instance/zone': 'baz/bang',
-    'instance/attributes/user-data': b64encode(b'/bin/echo baz\n'),
-    'instance/attributes/user-data-encoding': 'base64',
+    'instance/attributes': {
+        'user-data': b64encode(b'/bin/echo baz\n').decode('utf-8'),
+        'user-data-encoding': 'base64',
+    }
 }
 
-HEADERS = {'X-Google-Metadata-Request': 'True'}
+HEADERS = {'Metadata-Flavor': 'Google'}
 MD_URL_RE = re.compile(
     r'http://metadata.google.internal/computeMetadata/v1/.*')
 
@@ -54,10 +54,15 @@ def _set_mock_metadata(gce_meta=None):
         url_path = urlparse(uri).path
         if url_path.startswith('/computeMetadata/v1/'):
             path = url_path.split('/computeMetadata/v1/')[1:][0]
+            recursive = path.endswith('/')
+            path = path.rstrip('/')
         else:
             path = None
         if path in gce_meta:
-            return (200, headers, gce_meta.get(path))
+            response = gce_meta.get(path)
+            if recursive:
+                response = json.dumps(response)
+            return (200, headers, response)
         else:
             return (404, headers, '')
 
@@ -117,8 +122,8 @@ class TestDataSourceGCE(test_helpers.HttprettyTestCase):
         _set_mock_metadata(GCE_META_ENCODING)
         self.ds.get_data()
 
-        decoded = b64decode(
-            GCE_META_ENCODING.get('instance/attributes/user-data'))
+        instance_data = GCE_META_ENCODING.get('instance/attributes')
+        decoded = b64decode(instance_data.get('user-data'))
         self.assertEqual(decoded, self.ds.get_userdata_raw())
 
     def test_missing_required_keys_return_false(self):
@@ -130,33 +135,104 @@ class TestDataSourceGCE(test_helpers.HttprettyTestCase):
             self.assertEqual(False, self.ds.get_data())
             httpretty.reset()
 
-    def test_project_level_ssh_keys_are_used(self):
+    def test_no_ssh_keys_metadata(self):
         _set_mock_metadata()
         self.ds.get_data()
+        self.assertEqual([], self.ds.get_public_ssh_keys())
+
+    def test_ubuntu_ssh_keys(self):
+        valid_key = 'ssh-rsa VALID {0}'
+        invalid_key = 'ssh-rsa INVALID {0}'
+        project_attributes = {
+            'sshKeys': '\n'.join([
+                'ubuntu:{0}'.format(valid_key.format(0)),
+                'user:{0}'.format(invalid_key.format(0)),
+            ]),
+            'ssh-keys': '\n'.join([
+                'ubuntu:{0}'.format(valid_key.format(1)),
+                'user:{0}'.format(invalid_key.format(1)),
+            ]),
+        }
+        instance_attributes = {
+            'ssh-keys': '\n'.join([
+                'ubuntu:{0}'.format(valid_key.format(2)),
+                'user:{0}'.format(invalid_key.format(2)),
+            ]),
+            'block-project-ssh-keys': 'False',
+        }
+
+        meta = GCE_META.copy()
+        meta['project/attributes'] = project_attributes
+        meta['instance/attributes'] = instance_attributes
+
+        _set_mock_metadata(meta)
+        self.ds.get_data()
 
-        # we expect a list of public ssh keys with user names stripped
-        self.assertEqual(['ssh-rsa AA2..+aRD0fyVw== root@server'],
-                         self.ds.get_public_ssh_keys())
+        expected = [valid_key.format(key) for key in range(3)]
+        self.assertEqual(set(expected), set(self.ds.get_public_ssh_keys()))
+
+    def test_instance_ssh_keys_override(self):
+        valid_key = 'ssh-rsa VALID {0}'
+        invalid_key = 'ssh-rsa INVALID {0}'
+        project_attributes = {
+            'sshKeys': 'ubuntu:{0}'.format(invalid_key.format(0)),
+            'ssh-keys': 'ubuntu:{0}'.format(invalid_key.format(1)),
+        }
+        instance_attributes = {
+            'sshKeys': 'ubuntu:{0}'.format(valid_key.format(0)),
+            'ssh-keys': 'ubuntu:{0}'.format(valid_key.format(1)),
+            'block-project-ssh-keys': 'False',
+        }
 
-    def test_instance_level_ssh_keys_are_used(self):
-        key_content = 'ssh-rsa JustAUser root@server'
         meta = GCE_META.copy()
-        meta['instance/attributes/ssh-keys'] = 'user:{0}'.format(key_content)
+        meta['project/attributes'] = project_attributes
+        meta['instance/attributes'] = instance_attributes
 
         _set_mock_metadata(meta)
         self.ds.get_data()
 
-        self.assertIn(key_content, self.ds.get_public_ssh_keys())
+        expected = [valid_key.format(key) for key in range(2)]
+        self.assertEqual(set(expected), set(self.ds.get_public_ssh_keys()))
+
+    def test_block_project_ssh_keys_override(self):
+        valid_key = 'ssh-rsa VALID {0}'
+        invalid_key = 'ssh-rsa INVALID {0}'
+        project_attributes = {
+            'sshKeys': 'ubuntu:{0}'.format(invalid_key.format(0)),
+            'ssh-keys': 'ubuntu:{0}'.format(invalid_key.format(1)),
+        }
+        instance_attributes = {
+            'ssh-keys': 'ubuntu:{0}'.format(valid_key.format(0)),
+            'block-project-ssh-keys': 'True',
+        }
 
-    def test_instance_level_keys_replace_project_level_keys(self):
-        key_content = 'ssh-rsa JustAUser root@server'
         meta = GCE_META.copy()
-        meta['instance/attributes/ssh-keys'] = 'user:{0}'.format(key_content)
+        meta['project/attributes'] = project_attributes
+        meta['instance/attributes'] = instance_attributes
 
         _set_mock_metadata(meta)
         self.ds.get_data()
 
-        self.assertEqual([key_content], self.ds.get_public_ssh_keys())
+        expected = [valid_key.format(0)]
+        self.assertEqual(set(expected), set(self.ds.get_public_ssh_keys()))
+
+    def test_expired_ssh_keys_are_ignored(self):
+        expired = 'ssh-rsa EXPIRED google-ssh {0}'.format(json.dumps(
+            {'expireOn': '2000-01-01T00:00:00+0000'}))
+        current = 'ssh-rsa CURRENT google-ssh {0}'.format(json.dumps(
+            {'expireOn': '2100-01-01T00:00:00+0000'}))
+        meta = GCE_META.copy()
+        meta['instance/attributes'] = {
+            'ssh-keys': '\n'.join([
+                'ubuntu:{0}'.format(expired),
+                'ubuntu:{0}'.format(current),
+            ]),
+        }
+
+        _set_mock_metadata(meta)
+        self.ds.get_data()
+
+        self.assertEqual([current], self.ds.get_public_ssh_keys())
 
     def test_only_last_part_of_zone_used_for_availability_zone(self):
         _set_mock_metadata()

References