← Back to team overview

launchpad-reviewers team mailing list archive

[Merge] ~pappacena/launchpad:public-ecr-aws into launchpad:master

 

Thiago F. Pappacena has proposed merging ~pappacena/launchpad:public-ecr-aws into launchpad:master.

Commit message:
Add bearer-token authentication for AWS registries (e.g. public ECR)

Requested reviews:
  Launchpad code reviewers (launchpad-reviewers)

For more details, see:
https://code.launchpad.net/~pappacena/launchpad/+git/launchpad/+merge/394598
-- 
Your team, Launchpad code reviewers, is requested to review the proposed merge of ~pappacena/launchpad:public-ecr-aws into launchpad:master.
diff --git a/lib/lp/oci/interfaces/ociregistrycredentials.py b/lib/lp/oci/interfaces/ociregistrycredentials.py
index c6d2707..f55f274 100644
--- a/lib/lp/oci/interfaces/ociregistrycredentials.py
+++ b/lib/lp/oci/interfaces/ociregistrycredentials.py
@@ -57,6 +57,9 @@ class IOCIRegistryCredentialsView(Interface):
     def getCredentials():
         """Get the saved credentials."""
 
+    def getCredentialsValue(key):
+        """Gets the credential value for a specific key."""
+
     username = TextLine(
         title=_("Username"),
         description=_("The username for the credentials, if available."),
diff --git a/lib/lp/oci/model/ociregistryclient.py b/lib/lp/oci/model/ociregistryclient.py
index 56225da..0500601 100644
--- a/lib/lp/oci/model/ociregistryclient.py
+++ b/lib/lp/oci/model/ociregistryclient.py
@@ -20,9 +20,11 @@ try:
 except ImportError:
     JSONDecodeError = ValueError
 import logging
+import re
 import tarfile
 
 import boto3
+from botocore.config import Config
 from requests.exceptions import (
     ConnectionError,
     HTTPError,
@@ -41,6 +43,8 @@ from tenacity import (
     )
 from zope.interface import implementer
 
+from lp.services.config import config as lp_config
+from lp.services.features import getFeatureFlag
 from lp.oci.interfaces.ociregistryclient import (
     BlobUploadFailed,
     IOCIRegistryClient,
@@ -57,6 +61,20 @@ log = logging.getLogger(__name__)
 proxy_urlfetch = partial(urlfetch, use_proxy=True)
 
 
+OCI_AWS_BEARER_TOKEN_DOMAINS_FLAG = 'oci.push.aws.bearer_token_domains'
+OCI_AWS_BOT_EXTRA_MODEL_PATH = 'oci.push.aws.boto.extra_paths'
+OCI_AWS_BOT_EXTRA_MODEL_NAME = 'oci.push.aws.boto.extra_model_name'
+
+
+def is_aws_bearer_token_domain(domain):
+    """Returns True if the given registry domain should use bearer token
+    instead of basic auth."""
+    domains = getFeatureFlag(OCI_AWS_BEARER_TOKEN_DOMAINS_FLAG)
+    if not domains:
+        return False
+    return any(domain.endswith(i) for i in domains.split())
+
+
 @implementer(IOCIRegistryClient)
 class OCIRegistryClient:
 
@@ -482,8 +500,10 @@ class RegistryHTTPClient:
     def getInstance(cls, push_rule):
         """Returns an instance of RegistryHTTPClient adapted to the
         given push rule and registry's authentication flow."""
-        registry_domain = urlparse(push_rule.registry_url).netloc
-        if registry_domain.endswith(".amazonaws.com"):
+        domain = urlparse(push_rule.registry_url).netloc
+        if is_aws_bearer_token_domain(domain):
+            return AWSRegistryBearerTokenClient(push_rule)
+        if domain.endswith(".amazonaws.com"):
             return AWSRegistryHTTPClient(push_rule)
         try:
             proxy_urlfetch("{}/v2/".format(push_rule.registry_url))
@@ -577,29 +597,76 @@ class BearerTokenRegistryClient(RegistryHTTPClient):
             raise
 
 
-class AWSRegistryHTTPClient(RegistryHTTPClient):
+class AWSAuthenticatorMixin:
+    """Basic class to override the way we get credentials, exchanging
+    registered aws_access_key_id and aws_secret_access_key with the
+    temporary token got from AWS API.
+    """
+
+    def _getClientParameters(self):
+        if lp_config.launchpad.http_proxy:
+            boto_config = Config(proxies={
+                'http': lp_config.launchpad.http_proxy,
+                'https': lp_config.launchpad.http_proxy})
+        else:
+            boto_config = Config()
+        auth = self.push_rule.registry_credentials.getCredentials()
+        username, password = auth['username'], auth.get('password')
+        region = self._getRegion()
+        log.info("Trying to authenticate with AWS in region %s" % region)
+        return dict(
+            aws_access_key_id=username,
+            aws_secret_access_key=password, region_name=region,
+            config=boto_config)
+
+    def _getBotoClient(self):
+        params = self._getClientParameters()
+        if not self.should_use_aws_extra_model:
+            return boto3.client('ecr', **params)
+        model_path = getFeatureFlag(OCI_AWS_BOT_EXTRA_MODEL_PATH)
+        model_name = getFeatureFlag(OCI_AWS_BOT_EXTRA_MODEL_NAME)
+        if not model_path or not model_name:
+            log.warning(
+                "%s or %s feature rules are not set. Using default model." %
+                (OCI_AWS_BOT_EXTRA_MODEL_PATH, OCI_AWS_BOT_EXTRA_MODEL_NAME))
+            return boto3.client('ecr', **params)
+        session = boto3.Session()
+        session._loader.search_paths.extend([model_path])
+        return session.client(model_name, **params)
+
+    @property
+    def should_use_aws_extra_model(self):
+        """Returns True if the given registry domain requires extra boto API
+        model.
+        """
+        domain = urlparse(self.push_rule.registry_url).netloc
+        return is_aws_bearer_token_domain(domain)
 
     def _getRegion(self):
         """Returns the region from the push URL domain."""
-        domain = urlparse(self.push_rule.registry_url).netloc
-        # The domain format should be something like
+        push_rule = self.push_rule
+        region = push_rule.registry_credentials.getCredentialsValue("region")
+        if region is not None:
+            return region
+        # Try to guess from the domain. The format should be something like
         # 'xxx.dkr.ecr.sa-east-1.amazonaws.com'. 'sa-east-1' is the region.
-        return domain.split(".")[-3]
+        domain = urlparse(self.push_rule.registry_url).netloc
+        if re.match(r'.+\.dkr\.ecr\..+\.amazonaws\.com', domain):
+            return domain.split(".")[-3]
+        raise OCIRegistryAuthenticationError("Unknown AWS region.")
 
     @cachedproperty
     def credentials(self):
         """Exchange aws_access_key_id and aws_secret_access_key with the
         authentication token that should be used when talking to ECR."""
         try:
-            auth = self.push_rule.registry_credentials.getCredentials()
-            username, password = auth['username'], auth.get('password')
-            region = self._getRegion()
-            log.info("Trying to authenticate with AWS in region %s" % region)
-            client = boto3.client('ecr', aws_access_key_id=username,
-                                  aws_secret_access_key=password,
-                                  region_name=region)
+            client = self._getBotoClient()
             token = client.get_authorization_token()
-            auth_data = token["authorizationData"][0]
+            auth_data = token["authorizationData"]
+            # ecr-public returns a dict directly, but ecr returns list with
+            # one element inside. Go figure...
+            if isinstance(auth_data, list):
+                auth_data = auth_data[0]
             authorization_token = auth_data['authorizationToken']
             username, password = base64.b64decode(
                 authorization_token).decode().split(':')
@@ -610,3 +677,18 @@ class AWSRegistryHTTPClient(RegistryHTTPClient):
             raise OCIRegistryAuthenticationError(
                 "It was not possible to get AWS credentials for %s: %s" %
                 (self.push_rule.registry_url, e))
+
+
+class AWSRegistryHTTPClient(AWSAuthenticatorMixin, RegistryHTTPClient):
+    """AWS registry client with authentication flow based on basic auth
+    (private ECR, for example).
+    """
+    pass
+
+
+class AWSRegistryBearerTokenClient(
+        AWSAuthenticatorMixin, BearerTokenRegistryClient):
+    """AWS registry client with authentication flow based on bearer token
+    flow (public ECR, for example).
+    """
+    pass
diff --git a/lib/lp/oci/model/ociregistrycredentials.py b/lib/lp/oci/model/ociregistrycredentials.py
index 396124e..377dded 100644
--- a/lib/lp/oci/model/ociregistrycredentials.py
+++ b/lib/lp/oci/model/ociregistrycredentials.py
@@ -119,6 +119,9 @@ class OCIRegistryCredentials(Storm):
             data["username"] = username
         self._credentials = data
 
+    def getCredentialsValue(self, key):
+        return self.getCredentials().get(key)
+
     @property
     def username(self):
         return self._credentials.get('username')
diff --git a/lib/lp/oci/tests/helpers.py b/lib/lp/oci/tests/helpers.py
index 52561a3..b0ee421 100644
--- a/lib/lp/oci/tests/helpers.py
+++ b/lib/lp/oci/tests/helpers.py
@@ -23,7 +23,7 @@ from lp.services.features.testing import FeatureFixture
 
 class OCIConfigHelperMixin:
 
-    def setConfig(self):
+    def setConfig(self, feature_flags=None):
         self.private_key = PrivateKey.generate()
         self.pushConfig(
             "oci",
@@ -34,7 +34,9 @@ class OCIConfigHelperMixin:
             registry_secrets_private_key=base64.b64encode(
                 bytes(self.private_key)).decode("UTF-8"))
         # Default feature flags for our tests
-        self.useFixture(FeatureFixture({OCI_RECIPE_ALLOW_CREATE: 'on'}))
+        feature_flags = feature_flags or {}
+        feature_flags.update({OCI_RECIPE_ALLOW_CREATE: 'on'})
+        self.useFixture(FeatureFixture(feature_flags))
 
 
 class MatchesOCIRegistryCredentials(MatchesAll):
diff --git a/lib/lp/oci/tests/test_ociregistryclient.py b/lib/lp/oci/tests/test_ociregistryclient.py
index 65e416b..7d48ec1 100644
--- a/lib/lp/oci/tests/test_ociregistryclient.py
+++ b/lib/lp/oci/tests/test_ociregistryclient.py
@@ -39,6 +39,7 @@ from zope.component import getUtility
 from zope.security.proxy import removeSecurityProxy
 
 from lp.buildmaster.interfaces.processor import IProcessorSet
+from lp.oci.interfaces.ocirecipe import OCI_RECIPE_ALLOW_CREATE
 from lp.oci.interfaces.ocirecipejob import IOCIRecipeRequestBuildsJobSource
 from lp.oci.interfaces.ociregistryclient import (
     BlobUploadFailed,
@@ -47,8 +48,11 @@ from lp.oci.interfaces.ociregistryclient import (
     )
 from lp.oci.model.ocirecipe import OCIRecipeBuildRequest
 from lp.oci.model.ociregistryclient import (
+    AWSAuthenticatorMixin,
+    AWSRegistryBearerTokenClient,
     AWSRegistryHTTPClient,
     BearerTokenRegistryClient,
+    OCI_AWS_BEARER_TOKEN_DOMAINS_FLAG,
     OCIRegistryAuthenticationError,
     OCIRegistryClient,
     proxy_urlfetch,
@@ -56,7 +60,11 @@ from lp.oci.model.ociregistryclient import (
     )
 from lp.oci.tests.helpers import OCIConfigHelperMixin
 from lp.services.compat import mock
-from lp.testing import TestCaseWithFactory
+from lp.services.features.testing import FeatureFixture
+from lp.testing import (
+    admin_logged_in,
+    TestCaseWithFactory,
+    )
 from lp.testing.fixture import ZopeUtilityFixture
 from lp.testing.layers import (
     DatabaseFunctionalLayer,
@@ -861,7 +869,7 @@ class TestRegistryHTTPClient(OCIConfigHelperMixin, SpyProxyCallsMixin,
         self.assertEqual("%s/v2/" % push_rule.registry_url, call.request.url)
 
     @responses.activate
-    def test_get_aws_client_instance(self):
+    def test_get_aws_basic_auth_client_instance(self):
         credentials = self.factory.makeOCIRegistryCredentials(
             url="https://123456789.dkr.ecr.sa-east-1.amazonaws.com";,
             credentials={
@@ -873,10 +881,33 @@ class TestRegistryHTTPClient(OCIConfigHelperMixin, SpyProxyCallsMixin,
 
         instance = RegistryHTTPClient.getInstance(push_rule)
         self.assertEqual(AWSRegistryHTTPClient, type(instance))
+        self.assertFalse(instance.should_use_aws_extra_model)
+        self.assertIsInstance(instance, RegistryHTTPClient)
+
+    @responses.activate
+    def test_get_aws_bearer_token_auth_client_instance(self):
+        self.useFixture(FeatureFixture({
+            OCI_RECIPE_ALLOW_CREATE: 'on',
+            OCI_AWS_BEARER_TOKEN_DOMAINS_FLAG: (
+                'foo.domain.com fake.domain.com'),
+        }))
+        credentials = self.factory.makeOCIRegistryCredentials(
+            url="https://fake.domain.com";,
+            credentials={
+                'username': 'aws_access_key_id',
+                'password': "aws_secret_access_key"})
+        push_rule = removeSecurityProxy(self.factory.makeOCIPushRule(
+            registry_credentials=credentials,
+            image_name="ecr-test"))
+
+        instance = RegistryHTTPClient.getInstance(push_rule)
+        self.assertEqual(AWSRegistryBearerTokenClient, type(instance))
+        self.assertTrue(instance.should_use_aws_extra_model)
         self.assertIsInstance(instance, RegistryHTTPClient)
 
     @responses.activate
     def test_aws_credentials(self):
+        self.pushConfig('launchpad', http_proxy='http://proxy.local.com:123')
         boto_patch = self.useFixture(
             MockPatch('lp.oci.model.ociregistryclient.boto3'))
         boto = boto_patch.mock
@@ -908,8 +939,12 @@ class TestRegistryHTTPClient(OCIConfigHelperMixin, SpyProxyCallsMixin,
             self.assertEqual(mock.call(
                 'ecr', aws_access_key_id="my_aws_access_key_id",
                 aws_secret_access_key="my_aws_secret_access_key",
-                region_name="sa-east-1"),
+                region_name="sa-east-1", config=mock.ANY),
                 boto.client.call_args)
+            config = boto.client.call_args[-1]['config']
+            self.assertEqual({
+                u'http': u'http://proxy.local.com:123',
+                u'https': u'http://proxy.local.com:123'}, config.proxies)
 
     @responses.activate
     def test_aws_malformed_url_region(self):
@@ -1098,3 +1133,65 @@ class TestBearerTokenRegistryClient(OCIConfigHelperMixin,
 
         self.assertRaises(OCIRegistryAuthenticationError,
                           client.authenticate, previous_request)
+
+
+class TestAWSAuthenticator(OCIConfigHelperMixin, TestCaseWithFactory):
+    layer = DatabaseFunctionalLayer
+
+    def setUp(self):
+        super(TestAWSAuthenticator, self).setUp()
+        self.setConfig()
+
+    def test_get_region_from_credential(self):
+        cred = self.factory.makeOCIRegistryCredentials(
+            url="https://any.com";, credentials={"region": "sa-east-1"})
+        push_rule = self.factory.makeOCIPushRule(registry_credentials=cred)
+
+        with admin_logged_in():
+            auth = AWSAuthenticatorMixin()
+            auth.push_rule = push_rule
+            self.assertEqual("sa-east-1", auth._getRegion())
+
+    def test_get_region_from_url(self):
+        cred = self.factory.makeOCIRegistryCredentials(
+            url="https://123456789.dkr.ecr.sa-west-1.amazonaws.com";)
+        push_rule = self.factory.makeOCIPushRule(registry_credentials=cred)
+
+        with admin_logged_in():
+            auth = AWSAuthenticatorMixin()
+            auth.push_rule = push_rule
+            self.assertEqual("sa-west-1", auth._getRegion())
+
+    def test_get_region_invalid_url(self):
+        cred = self.factory.makeOCIRegistryCredentials(
+            url="https://something.invalid";)
+        push_rule = self.factory.makeOCIPushRule(registry_credentials=cred)
+
+        with admin_logged_in():
+            auth = AWSAuthenticatorMixin()
+            auth.push_rule = push_rule
+            self.assertRaises(OCIRegistryAuthenticationError, auth._getRegion)
+
+    def test_should_use_extra_model(self):
+        self.setConfig({
+            OCI_AWS_BEARER_TOKEN_DOMAINS_FLAG: 'bearertoken.aws.com'})
+        cred = self.factory.makeOCIRegistryCredentials(
+            url="https://myregistry.bearertoken.aws.com";)
+        push_rule = self.factory.makeOCIPushRule(registry_credentials=cred)
+
+        with admin_logged_in():
+            auth = AWSAuthenticatorMixin()
+            auth.push_rule = push_rule
+            self.assertTrue(auth.should_use_aws_extra_model)
+
+    def test_should_not_use_extra_model(self):
+        self.setConfig({
+            OCI_AWS_BEARER_TOKEN_DOMAINS_FLAG: 'bearertoken.aws.com'})
+        cred = self.factory.makeOCIRegistryCredentials(
+            url="https://123456789.dkr.ecr.sa-west-1.amazonaws.com";)
+        push_rule = self.factory.makeOCIPushRule(registry_credentials=cred)
+
+        with admin_logged_in():
+            auth = AWSAuthenticatorMixin()
+            auth.push_rule = push_rule
+            self.assertFalse(auth.should_use_aws_extra_model)

Follow-ups