← Back to team overview

launchpad-reviewers team mailing list archive

[Merge] ~pappacena/turnip:celery-repo-creation into turnip:master

 

Thiago F. Pappacena has proposed merging ~pappacena/turnip:celery-repo-creation into turnip:master.

Commit message:
Add an option to the POST /repo API call to create a repository asynchronously using Celery

Requested reviews:
  Launchpad code reviewers (launchpad-reviewers)

For more details, see:
https://code.launchpad.net/~pappacena/turnip/+git/turnip/+merge/387611
-- 
Your team Launchpad code reviewers is requested to review the proposed merge of ~pappacena/turnip:celery-repo-creation into turnip:master.
diff --git a/Makefile b/Makefile
index d10a9b2..67f60cd 100644
--- a/Makefile
+++ b/Makefile
@@ -8,6 +8,7 @@ PIP_CACHE = $(CURDIR)/pip-cache
 PYTHON := $(ENV)/bin/python
 PSERVE := $(ENV)/bin/pserve
 FLAKE8 := $(ENV)/bin/flake8
+CELERY := $(ENV)/bin/celery
 PIP := $(ENV)/bin/pip
 VIRTUALENV := virtualenv

@@ -64,7 +65,17 @@ endif
 	$(PIP) install $(PIP_ARGS) -c requirements.txt \
 		-e '.[test,deploy]'

-test: $(ENV)
+.PHONY: bootstrap-test
+
+# (Re)create the RabbitMQ vhost that the Celery tests use as their broker.
+# Every command is prefixed with "-" so that a failure (for example,
+# delete_vhost when the vhost does not exist yet) does not abort the build.
+bootstrap-test:
+	-sudo rabbitmqctl delete_vhost turnip-test-vhost
+	-sudo rabbitmqctl add_vhost turnip-test-vhost
+	-sudo rabbitmqctl set_permissions -p "turnip-test-vhost" "guest" ".*" ".*" ".*"
+
+test: $(ENV) bootstrap-test
 	$(PYTHON) -m unittest discover $(ARGS) turnip

 clean:
@@ -101,6 +112,33 @@ run-api: $(ENV)
 run-pack: $(ENV)
 	$(PYTHON) turnipserver.py

+.PHONY: run-worker run stop
+
+# Run a Celery worker for the "tasks" app (turnip/tasks.py), which executes
+# asynchronous jobs such as repository creation.
+run-worker: $(ENV)
+	PYTHONPATH="turnip" $(CELERY) -A tasks worker \
+		--loglevel=info \
+		--concurrency=20 \
+		--pool=gevent
+
+# Run the API server, pack server and Celery worker side by side and wait
+# for all of them.  Sub-makes are started through the MAKE variable rather
+# than a literal "make" so that command-line flags are propagated.
+run:
+	$(MAKE) run-api & \
+	$(MAKE) run-pack & \
+	$(MAKE) run-worker & \
+	wait
+
+# Best-effort shutdown of the processes started by "make run": pkill exits
+# non-zero when nothing matches, hence the "-" prefixes.
+stop:
+	-pkill -f 'make run-api'
+	-pkill -f 'make run-pack'
+	-pkill -f 'make run-worker'
+	-pkill -f '$(CELERY) -A tasks worker'
+
 $(PIP_CACHE): $(ENV)
 	mkdir -p $(PIP_CACHE)
 	$(PIP) install $(PIP_ARGS) -d $(PIP_CACHE) \
diff --git a/config.yaml b/config.yaml
index e078c33..99b0e7b 100644
--- a/config.yaml
+++ b/config.yaml
@@ -20,3 +20,4 @@ cgit_secret_path: null
 openid_provider_root: https://testopenid.test/
 site_name: git.launchpad.test
 main_site_root: https://launchpad.test/
+celery_broker: pyamqp://guest@localhost//
diff --git a/requirements.txt b/requirements.txt
index 365f354..29ca5d2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,9 +4,10 @@ attrs==19.3.0
 Automat==20.2.0
 bcrypt==3.1.7
 beautifulsoup4==4.6.3
+celery==4.4.6
 cffi==1.14.0
 constantly==15.1.0
-contextlib2==0.4.0
+contextlib2==0.6.0
 cornice==3.6.1
 cryptography==2.8
 docutils==0.14
@@ -15,6 +16,7 @@ envdir==0.7
 extras==1.0.0
 fixtures==3.0.0
 flake8==2.4.0
+gevent==20.6.2
 gmpy==1.17
 gunicorn==19.3.0
 hyperlink==19.0.0
@@ -31,6 +33,7 @@ Paste==2.0.2
 PasteDeploy==2.1.0
 pbr==5.4.4
 pep8==1.5.7
+psutil==5.7.0
 pyasn1==0.4.8
 pycparser==2.17
 pycrypto==2.6.1
diff --git a/setup.py b/setup.py
index 6eff51e..0c1f5f0 100755
--- a/setup.py
+++ b/setup.py
@@ -18,9 +18,11 @@ with open(os.path.join(here, 'NEWS')) as f:
     README += "\n\n" + f.read()
 
 requires = [
+    'celery',
     'contextlib2',
     'cornice',
     'enum34; python_version < "3.4"',
+    'gevent',
     'lazr.sshserver>=0.1.7',
     'Paste',
     'pygit2>=0.27.4,<0.28.0',
diff --git a/system-dependencies.txt b/system-dependencies.txt
index 845d453..fdcb1d8 100644
--- a/system-dependencies.txt
+++ b/system-dependencies.txt
@@ -6,3 +6,4 @@ libgit2-27
 libssl-dev
 python-dev
 virtualenv
+rabbitmq-server
diff --git a/turnip/api/store.py b/turnip/api/store.py
index 5723048..d6568c3 100644
--- a/turnip/api/store.py
+++ b/turnip/api/store.py
@@ -27,7 +27,10 @@ from pygit2 import (
     Repository,
     )

+from turnip.config import config
+from turnip.helpers import TimeoutServerProxy
 from turnip.pack.helpers import ensure_config
+from turnip.tasks import app


 REF_TYPE_NAME = {
@@ -38,6 +41,10 @@ REF_TYPE_NAME = {
     }


+# Where to store repository status information inside a repository directory.
+REPOSITORY_CREATING_FILE_NAME = '.turnip-creating'
+
+
 def format_ref(ref, git_object):
     """Return a formatted object dict from a ref."""
     return {
@@ -209,6 +216,7 @@ def init_repo(repo_path, clone_from=None, clone_refs=False,
     if os.path.exists(repo_path):
         raise AlreadyExistsError(repo_path)
     init_repository(repo_path, is_bare)
+    set_repository_creating(repo_path, True)

     if clone_from:
         # The clone_from's objects and refs are in fact cloned into a
@@ -240,6 +248,44 @@ def init_repo(repo_path, clone_from=None, clone_refs=False,
         write_packed_refs(repo_path, packable_refs)

     ensure_config(repo_path)  # set repository configuration defaults
+    set_repository_creating(repo_path, False)
+
+
+@app.task
+def init_and_confirm_repo(untranslated_path, repo_path, clone_from=None,
+                          clone_refs=False, alternate_repo_paths=None,
+                          is_bare=True):
+    """Celery task: create a repository and report the outcome to virtinfo.
+
+    `repo_path` and the remaining arguments are passed to `init_repo`;
+    `untranslated_path` is the repository path as received by the API and
+    is what gets reported to the virtinfo XML-RPC service.  On success,
+    `confirmRepoCreation` is called.  If either the creation or the
+    confirmation raises, the repository is deleted (unless it already
+    existed before this task ran) and `abortRepoCreation` is called.
+    """
+    xmlrpc_endpoint = config.get("virtinfo_endpoint")
+    xmlrpc_timeout = float(config.get("virtinfo_timeout"))
+    xmlrpc_auth_params = {"user": "+launchpad-services"}
+    xmlrpc_proxy = TimeoutServerProxy(
+        xmlrpc_endpoint, timeout=xmlrpc_timeout, allow_none=True)
+    try:
+        init_repo(
+            repo_path, clone_from, clone_refs, alternate_repo_paths, is_bare)
+        xmlrpc_proxy.confirmRepoCreation(untranslated_path, xmlrpc_auth_params)
+    except Exception as e:
+        # init_repo raises AlreadyExistsError when repo_path already exists;
+        # that repository was not created by this task, so it must be kept.
+        if not isinstance(e, AlreadyExistsError):
+            try:
+                delete_repo(repo_path)
+            except (IOError, OSError):
+                # delete_repo may fail, e.g. if init_repo failed before the
+                # directory was created; that must not prevent the abort
+                # below.  OSError is listed explicitly because on Python 2
+                # it is not a subclass of IOError.
+                pass
+        xmlrpc_proxy.abortRepoCreation(untranslated_path, xmlrpc_auth_params)


 @contextmanager
@@ -287,6 +333,24 @@ def get_default_branch(repo_path):
     return repo.references['HEAD'].target


+def set_repository_creating(repo_path, is_creating):
+    file_path = os.path.join(repo_path, REPOSITORY_CREATING_FILE_NAME)
+    if is_creating:
+        open(file_path, 'a').close()
+    else:
+        os.unlink(file_path)
+
+
+def is_repository_available(repo_path):
+    """Checks if the repository is available (that is, if it is not in the
+    middle of a clone or init operation)."""
+    if not os.path.exists(repo_path):
+        return False
+
+    status_file_path = os.path.join(repo_path, REPOSITORY_CREATING_FILE_NAME)
+    return not os.path.exists(status_file_path)
+
+
 def set_default_branch(repo_path, target):
     repo = Repository(repo_path)
     repo.set_head(target)
diff --git a/turnip/api/tests/test_api.py b/turnip/api/tests/test_api.py
index 396369b..5908143 100644
--- a/turnip/api/tests/test_api.py
+++ b/turnip/api/tests/test_api.py
@@ -6,9 +6,11 @@
 from __future__ import print_function

 import base64
+from datetime import timedelta, datetime
 import os
 import subprocess
 from textwrap import dedent
+import time
 import unittest
 import uuid

@@ -22,6 +24,8 @@ from testtools.matchers import (
     Equals,
     MatchesSetwise,
     )
+from twisted.internet import reactor as default_reactor
+from twisted.web import server
 from webtest import TestApp

 from turnip import api
@@ -31,12 +35,15 @@ from turnip.api.tests.test_helpers import (
     open_repo,
     RepoFactory,
     )
+from turnip.config import config
+from turnip.pack.tests.fake_servers import FakeVirtInfoService
+from turnip.tests.compat import mock
+from turnip.tests.tasks import CeleryWorkerFixture


-class ApiTestCase(TestCase):
-
-    def setUp(self):
-        super(ApiTestCase, self).setUp()
+class ApiRepoStoreMixin(object):
+    """Provides setupRepoStore: a temporary REPO_STORE plus a WebTest app."""
+    def setupRepoStore(self):
         repo_store = self.useFixture(TempDir()).path
         self.useFixture(EnvironmentVariable("REPO_STORE", repo_store))
         self.app = TestApp(api.main({}))
@@ -46,6 +53,13 @@ class ApiTestCase(TestCase):
         self.commit = {'ref': 'refs/heads/master', 'message': 'test commit.'}
         self.tag = {'ref': 'refs/tags/tag0', 'message': 'tag message'}

+
+class ApiTestCase(TestCase, ApiRepoStoreMixin):
+
+    def setUp(self):
+        super(ApiTestCase, self).setUp()
+        self.setupRepoStore()
+
     def assertReferencesEqual(self, repo, expected, observed):
         self.assertEqual(
             repo.references[expected].peel().oid,
@@ -98,7 +112,9 @@ class ApiTestCase(TestCase):

         resp = self.app.get('/repo/{}'.format(self.repo_path))
         self.assertEqual(200, resp.status_code)
-        self.assertEqual({'default_branch': 'refs/heads/branch-0'}, resp.json)
+        self.assertEqual({
+            'default_branch': 'refs/heads/branch-0',
+            'is_available': True}, resp.json)

     def test_repo_get_default_branch_missing(self):
         """default_branch is returned even if that branch has been deleted."""
@@ -109,7 +125,9 @@ class ApiTestCase(TestCase):

         resp = self.app.get('/repo/{}'.format(self.repo_path))
         self.assertEqual(200, resp.status_code)
-        self.assertEqual({'default_branch': 'refs/heads/branch-0'}, resp.json)
+        self.assertEqual({
+            'default_branch': 'refs/heads/branch-0',
+            'is_available': True}, resp.json)

     def test_repo_patch_default_branch(self):
         """A repository's default branch ("HEAD") can be changed."""
@@ -872,5 +890,168 @@ class ApiTestCase(TestCase):
         self.assertEqual(404, resp.status_code)


+class AsyncRepoCreationAPI(TestCase, ApiRepoStoreMixin):
+
+    def setUp(self):
+        super(AsyncRepoCreationAPI, self).setUp()
+        self.setupRepoStore()
+        # XML-RPC server
+        self.virtinfo = FakeVirtInfoService(allowNone=True)
+        self.virtinfo_listener = default_reactor.listenTCP(0, server.Site(
+            self.virtinfo))
+        self.virtinfo_port = self.virtinfo_listener.getHost().port
+        # A native string (not bytes): the value is passed to the Celery
+        # worker via str() in CeleryWorkerFixture and used as an XML-RPC URI.
+        self.virtinfo_url = 'http://localhost:%d/' % self.virtinfo_port
+        self.addCleanup(self.virtinfo_listener.stopListening)
+        # Restore the global configuration afterwards so that the fake
+        # endpoint does not leak into other tests.
+        self.addCleanup(
+            config.defaults.__setitem__, 'virtinfo_endpoint',
+            config.defaults.get('virtinfo_endpoint'))
+        config.defaults['virtinfo_endpoint'] = self.virtinfo_url
+
+    def _doReactorIteration(self):
+        """Yield to the reactor so it can process virtinfo requests.
+
+        This is a bit hacky, but allows us to simulate the Twisted XML-RPC
+        fake server without needing to make this test suite async.
+        Making this test suite async could make it less realistic, since the
+        API being tested does not itself run on the Twisted event loop.
+        """
+        reactor_iterations = (
+            len(default_reactor._reads) + len(default_reactor._writes))
+        for i in range(reactor_iterations * 100):
+            default_reactor.iterate()
+
+    def assertRepositoryCreatedAsynchronously(self, repo_path, timeout_secs=5):
+        """Waits up to `timeout_secs` for a repository to be available."""
+        timeout = timedelta(seconds=timeout_secs)
+        start = datetime.now()
+        while datetime.now() <= (start + timeout):
+            self._doReactorIteration()
+            try:
+                resp = self.app.get('/repo/{}'.format(repo_path),
+                                    expect_errors=True)
+                if resp.status_code == 200 and resp.json['is_available']:
+                    return
+            except Exception:
+                # While the repository does not exist yet the request may
+                # fail or return a non-JSON error page; keep polling.
+                pass
+            time.sleep(0.1)
+        self.fail(
+            "Repository %s was not created after %s secs"
+            % (repo_path, timeout_secs))
+
+    def assertAnyMockCalledAsync(self, mocks, timeout_secs=5):
+        """Asserts that at least one of the given `mocks` is called within
+        the next `timeout_secs` seconds.
+        """
+        timeout = timedelta(seconds=timeout_secs)
+        start = datetime.now()
+        while datetime.now() <= (start + timeout):
+            self._doReactorIteration()
+            if any(i.called for i in mocks):
+                return
+            time.sleep(0.1)
+        self.fail(
+            "None of the given args was called after %s seconds."
+            % timeout_secs)
+
+    def test_repo_async_creation_with_clone(self):
+        """Repo can be initialised with optional clone asynchronously."""
+        self.useFixture(CeleryWorkerFixture())
+        self.virtinfo.xmlrpc_confirmRepoCreation = mock.Mock(return_value=None)
+        self.virtinfo.xmlrpc_abortRepoCreation = mock.Mock(return_value=None)
+
+        factory = RepoFactory(self.repo_store, num_commits=2)
+        factory.build()
+        new_repo_path = uuid.uuid1().hex
+        resp = self.app.post_json('/repo', {
+            'async': True,
+            'repo_path': new_repo_path,
+            'clone_from': self.repo_path,
+            'clone_refs': True})
+
+        self.assertRepositoryCreatedAsynchronously(new_repo_path)
+
+        repo1_revlist = get_revlist(factory.repo)
+        clone_from = resp.json['repo_url'].split('/')[-1]
+        repo2 = open_repo(os.path.join(self.repo_root, clone_from))
+        repo2_revlist = get_revlist(repo2)
+
+        self.assertEqual(repo1_revlist, repo2_revlist)
+        self.assertEqual(200, resp.status_code)
+        self.assertIn(new_repo_path, resp.json['repo_url'])
+
+        self.assertEqual([mock.call(
+            mock.ANY, new_repo_path, {"user": "+launchpad-services"})],
+            self.virtinfo.xmlrpc_confirmRepoCreation.call_args_list)
+        self.assertEqual(
+            [], self.virtinfo.xmlrpc_abortRepoCreation.call_args_list)
+
+    def test_repo_async_creation_aborts_when_fails_to_create_locally(self):
+        """Async creation is aborted if the repo cannot be created locally."""
+        self.useFixture(
+            EnvironmentVariable("REPO_STORE", '/tmp/invalid/path/to/repos/'))
+        self.useFixture(CeleryWorkerFixture())
+        self.virtinfo.xmlrpc_confirmRepoCreation = mock.Mock(return_value=None)
+        self.virtinfo.xmlrpc_abortRepoCreation = mock.Mock(return_value=None)
+
+        factory = RepoFactory(self.repo_store, num_commits=2)
+        factory.build()
+        new_repo_path = uuid.uuid1().hex
+        self.app.post_json('/repo', {
+            'async': True,
+            'repo_path': new_repo_path,
+            'clone_from': self.repo_path,
+            'clone_refs': True})
+
+        # Wait until the repository creation is either confirmed or aborted
+        # (and we hope it was aborted...)
+        self.assertAnyMockCalledAsync([
+            self.virtinfo.xmlrpc_confirmRepoCreation,
+            self.virtinfo.xmlrpc_abortRepoCreation])
+        self.assertFalse(
+            os.path.exists(os.path.join(self.repo_root, new_repo_path)))
+
+        self.assertEqual([mock.call(
+            mock.ANY, new_repo_path, {"user": "+launchpad-services"})],
+            self.virtinfo.xmlrpc_abortRepoCreation.call_args_list)
+        self.assertEqual(
+            [], self.virtinfo.xmlrpc_confirmRepoCreation.call_args_list)
+
+    def test_repo_async_creation_aborts_when_fails_confirm(self):
+        """Async creation is aborted if confirming it with virtinfo fails."""
+        self.useFixture(CeleryWorkerFixture())
+        self.virtinfo.xmlrpc_confirmRepoCreation = mock.Mock(
+            side_effect=Exception("?"))
+        self.virtinfo.xmlrpc_abortRepoCreation = mock.Mock(return_value=None)
+
+        factory = RepoFactory(self.repo_store, num_commits=2)
+        factory.build()
+        new_repo_path = uuid.uuid1().hex
+        self.app.post_json('/repo', {
+            'async': True,
+            'repo_path': new_repo_path,
+            'clone_from': self.repo_path,
+            'clone_refs': True})
+
+        # Wait until the repository creation is either confirmed or aborted
+        # (and we hope it was aborted...)
+        self.assertAnyMockCalledAsync([
+            self.virtinfo.xmlrpc_abortRepoCreation])
+        self.assertFalse(
+            os.path.exists(os.path.join(self.repo_root, new_repo_path)))
+
+        self.assertEqual([mock.call(
+            mock.ANY, new_repo_path, {"user": "+launchpad-services"})],
+            self.virtinfo.xmlrpc_abortRepoCreation.call_args_list)
+        self.assertEqual([mock.call(
+            mock.ANY, new_repo_path, {"user": "+launchpad-services"})],
+            self.virtinfo.xmlrpc_confirmRepoCreation.call_args_list)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/turnip/api/tests/test_store.py b/turnip/api/tests/test_store.py
index 9276303..70672d8 100644
--- a/turnip/api/tests/test_store.py
+++ b/turnip/api/tests/test_store.py
@@ -125,6 +125,19 @@ class InitTestCase(TestCase):
         self.assertEqual(str(yaml_config['pack.depth']),
                          repo_config['pack.depth'])

+    def test_is_repository_available(self):
+        """A repository is available only once it exists on disk and is no
+        longer flagged as being created."""
+        repo_path = os.path.join(self.repo_store, 'repo/')
+        self.assertFalse(store.is_repository_available(repo_path))
+
+        store.init_repository(repo_path, True)
+        store.set_repository_creating(repo_path, True)
+        self.assertFalse(store.is_repository_available(repo_path))
+
+        store.set_repository_creating(repo_path, False)
+        self.assertTrue(store.is_repository_available(repo_path))
+
     def test_open_ephemeral_repo(self):
         """Opening a repo where a repo name contains ':' should return
         a new ephemeral repo.
@@ -224,7 +237,10 @@ class InitTestCase(TestCase):
         # repo with the same set of refs. And the objects are copied
         # too.
         to_path = os.path.join(self.repo_store, 'to/')
+        self.assertFalse(store.is_repository_available(to_path))
         store.init_repo(to_path, clone_from=self.orig_path, clone_refs=True)
+        self.assertTrue(store.is_repository_available(to_path))
+
         to = pygit2.Repository(to_path)
         self.assertIsNot(None, to[self.master_oid])
         self.assertEqual(
@@ -260,7 +276,10 @@ class InitTestCase(TestCase):
         # init_repo with clone_from=orig and clone_refs=False creates a
         # repo without any refs, but the objects are copied.
         to_path = os.path.join(self.repo_store, 'to/')
+        self.assertFalse(store.is_repository_available(to_path))
         store.init_repo(to_path, clone_from=self.orig_path, clone_refs=False)
+        self.assertTrue(store.is_repository_available(to_path))
+
         to = pygit2.Repository(to_path)
         self.assertIsNot(None, to[self.master_oid])
         self.assertEqual([], to.listall_references())
@@ -286,7 +305,10 @@ class InitTestCase(TestCase):

         self.assertAllLinkCounts(1, self.orig_objs)
         to_path = os.path.join(self.repo_store, 'to/')
+        self.assertFalse(store.is_repository_available(to_path))
         store.init_repo(to_path, clone_from=self.orig_path)
+        self.assertTrue(store.is_repository_available(to_path))
+
         self.assertAllLinkCounts(2, self.orig_objs)
         to = pygit2.Repository(to_path)
         to_blob = to.create_blob(b'to')
diff --git a/turnip/api/views.py b/turnip/api/views.py
index 2e8a082..de6bb4e 100644
--- a/turnip/api/views.py
+++ b/turnip/api/views.py
@@ -53,9 +53,11 @@ class RepoAPI(BaseAPI):

     def collection_post(self):
         """Initialise a new git repository, or clone from an existing repo."""
-        repo_path = extract_json_data(self.request).get('repo_path')
-        clone_path = extract_json_data(self.request).get('clone_from')
-        clone_refs = extract_json_data(self.request).get('clone_refs', False)
+        json_data = extract_json_data(self.request)
+        repo_path = json_data.get('repo_path')
+        clone_path = json_data.get('clone_from')
+        clone_refs = json_data.get('clone_refs', False)
+        async_run = json_data.get('async', False)

         if not repo_path:
             self.request.errors.add('body', 'repo_path',
@@ -72,7 +74,17 @@ class RepoAPI(BaseAPI):
             repo_clone = None

         try:
-            store.init_repo(repo, clone_from=repo_clone, clone_refs=clone_refs)
+            kwargs = dict(
+                repo_path=repo, clone_from=repo_clone, clone_refs=clone_refs)
+            if async_run:
+                # Create the repository in a Celery worker and return
+                # immediately; the worker reports the outcome to virtinfo
+                # (confirmRepoCreation/abortRepoCreation).  Clients can poll
+                # GET /repo/<path> and check 'is_available'.
+                kwargs["untranslated_path"] = repo_path
+                store.init_and_confirm_repo.apply_async(kwargs=kwargs)
+            else:
+                store.init_repo(**kwargs)
             repo_name = os.path.basename(os.path.normpath(repo))
             return {'repo_url': '/'.join([self.request.url, repo_name])}
         except GitError:
@@ -88,6 +100,7 @@ class RepoAPI(BaseAPI):
             raise exc.HTTPNotFound()
         return {
             'default_branch': store.get_default_branch(repo_path),
+            'is_available': store.is_repository_available(repo_path),
             }

     def _patch_default_branch(self, repo_path, value):
diff --git a/turnip/helpers.py b/turnip/helpers.py
index 77e12ed..b23e4a6 100644
--- a/turnip/helpers.py
+++ b/turnip/helpers.py
@@ -10,6 +10,7 @@ from __future__ import (
 import os.path

 import six
+from six.moves import xmlrpc_client


 def compose_path(root, path):
@@ -22,3 +23,34 @@ def compose_path(root, path):
     if not full_path.startswith(os.path.abspath(root)):
         raise ValueError('Path not contained within root')
     return full_path
+
+
+class TimeoutTransport(xmlrpc_client.Transport):
+    """XML-RPC transport that applies `timeout` (in seconds) to the HTTP
+    connections it makes."""
+
+    def __init__(self, timeout, use_datetime=0):
+        self.timeout = timeout
+        xmlrpc_client.Transport.__init__(self, use_datetime)
+
+    def make_connection(self, host):
+        connection = xmlrpc_client.Transport.make_connection(self, host)
+        connection.timeout = self.timeout
+        return connection
+
+
+class TimeoutServerProxy(xmlrpc_client.ServerProxy):
+    """XML-RPC server proxy whose requests time out after `timeout` seconds.
+
+    If an explicit `transport` is given it is used as-is, in which case
+    `timeout` has no effect.
+    """
+
+    def __init__(self, uri, timeout=10, transport=None, encoding=None,
+                 verbose=0, allow_none=0, use_datetime=0):
+        if transport is None:
+            # ServerProxy only applies use_datetime to a transport it builds
+            # itself, so pass it on to the one built here.
+            transport = TimeoutTransport(timeout, use_datetime)
+        xmlrpc_client.ServerProxy.__init__(
+            self, uri, transport, encoding, verbose, allow_none, use_datetime)
diff --git a/turnip/tasks.py b/turnip/tasks.py
new file mode 100644
index 0000000..51b5d20
--- /dev/null
+++ b/turnip/tasks.py
@@ -0,0 +1,21 @@
+# Copyright 2020 Canonical Ltd.  This software is licensed under the
+# GNU Affero General Public License version 3 (see the file LICENSE).
+
+"""Celery application used to run turnip's asynchronous tasks."""
+
+from __future__ import print_function, unicode_literals, absolute_import
+
+__all__ = [
+    'app',
+]
+
+from celery import Celery
+
+from turnip.config import config
+
+
+# The broker URL comes from the "celery_broker" configuration key.  The
+# "imports" setting makes workers import turnip.api.store, so that the
+# tasks defined there (e.g. init_and_confirm_repo) are registered.
+app = Celery('tasks', broker=config.get('celery_broker'))
+app.conf.update(imports=('turnip.api.store', ))
diff --git a/turnip/tests/__init__.py b/turnip/tests/__init__.py
index 494f3fd..6be77e3 100644
--- a/turnip/tests/__init__.py
+++ b/turnip/tests/__init__.py
@@ -8,6 +8,9 @@ from __future__ import (
     )

 from turnip.tests.logging import setupLogger
+from turnip.tests.tasks import setupCelery


 setupLogger()
+# Point the Celery app at the dedicated test broker vhost.
+setupCelery()
diff --git a/turnip/tests/tasks.py b/turnip/tests/tasks.py
new file mode 100644
index 0000000..401d079
--- /dev/null
+++ b/turnip/tests/tasks.py
@@ -0,0 +1,97 @@
+# Copyright 2020 Canonical Ltd.  This software is licensed under the
+# GNU Affero General Public License version 3 (see the file LICENSE).
+
+import atexit
+import os
+import subprocess
+import sys
+
+from testtools.testcase import fixtures
+
+from turnip.config import config
+from turnip.tasks import app
+
+# RabbitMQ vhost reserved for the test suite (created by "make
+# bootstrap-test").
+BROKER_URL = 'pyamqp://guest@localhost/turnip-test-vhost'
+
+
+def setupCelery():
+    """Point the Celery app at the test broker."""
+    app.conf.update(broker_url=BROKER_URL)
+
+
+class CeleryWorkerFixture(fixtures.Fixture):
+    """Celery worker fixture for tests
+
+    This fixture starts a celery worker with the configuration set when the
+    fixture is setUp. Keep in mind that this will run in a separate new
+    process, so mock patches, for example, will be lost.
+    """
+    _worker_proc = None
+    # Whether stopCeleryWorker has already been registered with atexit.
+    _atexit_registered = False
+
+    def __init__(self, loglevel="error", force_restart=True, env=None):
+        """
+        Build a celery worker for test cases.
+
+        :param loglevel: Which log level to use for the worker.
+        :param force_restart: If True and a celery worker is already running,
+            stop it. If False, do not restart if another worker is
+            already running.
+        :param env: The environment variables to be used when creating
+            the worker.
+        """
+        self.force_restart = force_restart
+        self.loglevel = loglevel
+        self.env = env
+
+    def startCeleryWorker(self):
+        """Start a celery worker for integration tests."""
+        if self.force_restart:
+            self.stopCeleryWorker()
+        if CeleryWorkerFixture._worker_proc is not None:
+            return
+        bin_path = os.path.dirname(sys.executable)
+        celery = os.path.join(bin_path, 'celery')
+        turnip_path = os.path.join(os.path.dirname(__file__), '..')
+        cmd = [
+            celery, 'worker', '-A', 'tasks', '--quiet',
+            '--pool=gevent',
+            '--concurrency=2',
+            '--broker=%s' % BROKER_URL,
+            '--loglevel=%s' % self.loglevel]
+
+        # Send to the subprocess, as env variables, the same configurations we
+        # are currently using.  Unset (None) values are skipped instead of
+        # being sent as the literal string "None".
+        proc_env = {'PYTHONPATH': turnip_path}
+        for k in config.defaults:
+            value = config.get(k)
+            if value is not None:
+                proc_env[k.upper()] = str(value)
+        proc_env.update(self.env or {})
+
+        CeleryWorkerFixture._worker_proc = subprocess.Popen(cmd, env=proc_env)
+        # Make sure the worker does not outlive the test run, registering
+        # the hook only once however many workers are started.
+        if not CeleryWorkerFixture._atexit_registered:
+            atexit.register(self.stopCeleryWorker)
+            CeleryWorkerFixture._atexit_registered = True
+
+    def stopCeleryWorker(self):
+        """Kill the running worker (if any) and purge queued tasks."""
+        worker_proc = CeleryWorkerFixture._worker_proc
+        if worker_proc:
+            worker_proc.kill()
+            worker_proc.wait()
+        CeleryWorkerFixture._worker_proc = None
+        # Cleanup the queue.
+        app.control.purge()
+
+    def _setUp(self):
+        # fixtures.Fixture has no "_cleanup" hook, so the worker has to be
+        # stopped through an explicitly registered cleanup.
+        self.addCleanup(self.stopCeleryWorker)
+        self.startCeleryWorker()

Follow ups