launchpad-reviewers team mailing list archive
-
launchpad-reviewers team
-
Mailing list archive
-
Message #25036
[Merge] ~pappacena/turnip:celery-repo-creation into turnip:master
Thiago F. Pappacena has proposed merging ~pappacena/turnip:celery-repo-creation into turnip:master.
Commit message:
Add an option to the POST /repo API call to create a repository asynchronously using Celery.
Requested reviews:
Launchpad code reviewers (launchpad-reviewers)
For more details, see:
https://code.launchpad.net/~pappacena/turnip/+git/turnip/+merge/387611
--
Your team Launchpad code reviewers is requested to review the proposed merge of ~pappacena/turnip:celery-repo-creation into turnip:master.
diff --git a/Makefile b/Makefile
index d10a9b2..67f60cd 100644
--- a/Makefile
+++ b/Makefile
@@ -8,6 +8,7 @@ PIP_CACHE = $(CURDIR)/pip-cache
PYTHON := $(ENV)/bin/python
PSERVE := $(ENV)/bin/pserve
FLAKE8 := $(ENV)/bin/flake8
+# Celery command-line program installed in $(ENV); used by the run-worker
+# and stop targets.
+CELERY := $(ENV)/bin/celery
PIP := $(ENV)/bin/pip
VIRTUALENV := virtualenv
@@ -64,7 +65,12 @@ endif
$(PIP) install $(PIP_ARGS) -c requirements.txt \
-e '.[test,deploy]'
-test: $(ENV)
+# (Re)create the RabbitMQ vhost that the test suite uses as its Celery
+# broker (it must match BROKER_URL in turnip/tests/tasks.py).  Each command
+# is prefixed with "-" so that "make test" carries on even if it fails,
+# e.g. delete_vhost when the vhost does not exist yet, or on machines
+# without RabbitMQ/sudo.
+.PHONY: bootstrap-test test
+bootstrap-test:
+	-sudo rabbitmqctl delete_vhost turnip-test-vhost
+	-sudo rabbitmqctl add_vhost turnip-test-vhost
+	-sudo rabbitmqctl set_permissions -p "turnip-test-vhost" "guest" ".*" ".*" ".*"
+
+test: $(ENV) bootstrap-test
$(PYTHON) -m unittest discover $(ARGS) turnip
clean:
@@ -101,6 +107,24 @@ run-api: $(ENV)
run-pack: $(ENV)
$(PYTHON) turnipserver.py
+.PHONY: run-worker run stop
+
+# Start a Celery worker that executes turnip's asynchronous tasks (those
+# defined in the modules listed in turnip/tasks.py).
+run-worker: $(ENV)
+	PYTHONPATH="turnip" $(CELERY) -A tasks worker \
+		--loglevel=info \
+		--concurrency=20 \
+		--pool=gevent
+
+# Run the API server, pack server and Celery worker together and wait for
+# all of them to exit.  run-api, run-pack and run-worker each depend on
+# $(ENV); building it here first stops the three concurrent sub-makes from
+# racing to create the same virtualenv.  $(MAKE) (rather than a literal
+# "make") propagates flags such as -n and the jobserver to the sub-makes.
+run: $(ENV)
+	$(MAKE) run-api & \
+	$(MAKE) run-pack & \
+	$(MAKE) run-worker & \
+	wait
+
+# Stop everything started by "make run".  Failures are ignored because any
+# of the processes may already have exited.
+stop:
+	-pkill -f 'make run-api'
+	-pkill -f 'make run-pack'
+	-pkill -f 'make run-worker'
+	-pkill -f '$(CELERY) -A tasks worker'
+
$(PIP_CACHE): $(ENV)
mkdir -p $(PIP_CACHE)
$(PIP) install $(PIP_ARGS) -d $(PIP_CACHE) \
diff --git a/config.yaml b/config.yaml
index e078c33..99b0e7b 100644
--- a/config.yaml
+++ b/config.yaml
@@ -20,3 +20,4 @@ cgit_secret_path: null
openid_provider_root: https://testopenid.test/
site_name: git.launchpad.test
main_site_root: https://launchpad.test/
+celery_broker: pyamqp://guest@localhost//
diff --git a/requirements.txt b/requirements.txt
index 365f354..29ca5d2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,9 +4,10 @@ attrs==19.3.0
Automat==20.2.0
bcrypt==3.1.7
beautifulsoup4==4.6.3
+celery==4.4.6
cffi==1.14.0
constantly==15.1.0
-contextlib2==0.4.0
+contextlib2==0.6.0
cornice==3.6.1
cryptography==2.8
docutils==0.14
@@ -15,6 +16,7 @@ envdir==0.7
extras==1.0.0
fixtures==3.0.0
flake8==2.4.0
+gevent==20.6.2
gmpy==1.17
gunicorn==19.3.0
hyperlink==19.0.0
@@ -31,6 +33,7 @@ Paste==2.0.2
PasteDeploy==2.1.0
pbr==5.4.4
pep8==1.5.7
+psutil==5.7.0
pyasn1==0.4.8
pycparser==2.17
pycrypto==2.6.1
diff --git a/setup.py b/setup.py
index 6eff51e..0c1f5f0 100755
--- a/setup.py
+++ b/setup.py
@@ -18,9 +18,11 @@ with open(os.path.join(here, 'NEWS')) as f:
README += "\n\n" + f.read()
requires = [
+ 'celery',
'contextlib2',
'cornice',
'enum34; python_version < "3.4"',
+ 'gevent',
'lazr.sshserver>=0.1.7',
'Paste',
'pygit2>=0.27.4,<0.28.0',
diff --git a/system-dependencies.txt b/system-dependencies.txt
index 845d453..fdcb1d8 100644
--- a/system-dependencies.txt
+++ b/system-dependencies.txt
@@ -6,3 +6,4 @@ libgit2-27
libssl-dev
python-dev
virtualenv
+rabbitmq-server
diff --git a/turnip/api/store.py b/turnip/api/store.py
index 5723048..d6568c3 100644
--- a/turnip/api/store.py
+++ b/turnip/api/store.py
@@ -27,7 +27,10 @@ from pygit2 import (
Repository,
)
+from turnip.config import config
+from turnip.helpers import TimeoutServerProxy
from turnip.pack.helpers import ensure_config
+from turnip.tasks import app
REF_TYPE_NAME = {
@@ -38,6 +41,10 @@ REF_TYPE_NAME = {
}
+# Name of a marker file kept inside a repository directory while that
+# repository is still being created (init or clone); see
+# set_repository_creating and is_repository_available.
+REPOSITORY_CREATING_FILE_NAME = '.turnip-creating'
+
+
def format_ref(ref, git_object):
"""Return a formatted object dict from a ref."""
return {
@@ -209,6 +216,7 @@ def init_repo(repo_path, clone_from=None, clone_refs=False,
if os.path.exists(repo_path):
raise AlreadyExistsError(repo_path)
init_repository(repo_path, is_bare)
+ set_repository_creating(repo_path, True)
if clone_from:
# The clone_from's objects and refs are in fact cloned into a
@@ -240,6 +248,28 @@ def init_repo(repo_path, clone_from=None, clone_refs=False,
write_packed_refs(repo_path, packable_refs)
ensure_config(repo_path) # set repository configuration defaults
+ set_repository_creating(repo_path, False)
+
+
+@app.task
+def init_and_confirm_repo(untranslated_path, repo_path, clone_from=None,
+                          clone_refs=False, alternate_repo_paths=None,
+                          is_bare=True):
+    """Create a repository and report the outcome to virtinfo.
+
+    Registered as a task on the turnip Celery app.  After `init_repo`
+    succeeds, ``confirmRepoCreation`` is called on the XML-RPC endpoint
+    configured as ``virtinfo_endpoint``.  If creating or confirming the
+    repository fails, ``abortRepoCreation`` is called instead; a repository
+    created by this call is first removed with `delete_repo`, but a
+    repository that already existed at `repo_path` is left untouched.
+
+    :param untranslated_path: The repository path as known to virtinfo;
+        it is passed back in the confirm/abort call.
+    :param repo_path: Local filesystem path of the repository to create.
+    :param clone_from: Passed through to `init_repo`.
+    :param clone_refs: Passed through to `init_repo`.
+    :param alternate_repo_paths: Passed through to `init_repo`.
+    :param is_bare: Passed through to `init_repo`.
+    """
+    xmlrpc_endpoint = config.get("virtinfo_endpoint")
+    xmlrpc_timeout = float(config.get("virtinfo_timeout"))
+    xmlrpc_auth_params = {"user": "+launchpad-services"}
+    xmlrpc_proxy = TimeoutServerProxy(
+        xmlrpc_endpoint, timeout=xmlrpc_timeout, allow_none=True)
+    try:
+        init_repo(
+            repo_path, clone_from, clone_refs, alternate_repo_paths, is_bare)
+        xmlrpc_proxy.confirmRepoCreation(untranslated_path, xmlrpc_auth_params)
+    except AlreadyExistsError:
+        # repo_path existed before this task ran, so it is not ours to
+        # delete; only report the failure.
+        xmlrpc_proxy.abortRepoCreation(untranslated_path, xmlrpc_auth_params)
+    except Exception:
+        try:
+            delete_repo(repo_path)
+        except (IOError, OSError):
+            # Best effort: the repository may not have been created at all.
+            # IOError alone does not cover OSError on Python 2.
+            pass
+        xmlrpc_proxy.abortRepoCreation(untranslated_path, xmlrpc_auth_params)
@contextmanager
@@ -287,6 +317,24 @@ def get_default_branch(repo_path):
return repo.references['HEAD'].target
+def set_repository_creating(repo_path, is_creating):
+    """Mark or unmark the repository at `repo_path` as being created.
+
+    The mark is an empty file named REPOSITORY_CREATING_FILE_NAME inside
+    the repository directory.  Unmarking removes that file, so it raises
+    if the file is not there.
+    """
+    marker_path = os.path.join(repo_path, REPOSITORY_CREATING_FILE_NAME)
+    if not is_creating:
+        os.unlink(marker_path)
+        return
+    # Append mode creates the file when missing and never truncates it.
+    with open(marker_path, 'a'):
+        pass
+
+
+def is_repository_available(repo_path):
+    """Return whether the repository at `repo_path` can be used.
+
+    A repository is available when its directory exists and it does not
+    carry the REPOSITORY_CREATING_FILE_NAME marker, i.e. it is not in the
+    middle of a clone or init operation.
+    """
+    marker_path = os.path.join(repo_path, REPOSITORY_CREATING_FILE_NAME)
+    return os.path.exists(repo_path) and not os.path.exists(marker_path)
+
+
def set_default_branch(repo_path, target):
repo = Repository(repo_path)
repo.set_head(target)
diff --git a/turnip/api/tests/test_api.py b/turnip/api/tests/test_api.py
index 396369b..5908143 100644
--- a/turnip/api/tests/test_api.py
+++ b/turnip/api/tests/test_api.py
@@ -6,9 +6,11 @@
from __future__ import print_function
import base64
+from datetime import timedelta, datetime
import os
import subprocess
from textwrap import dedent
+import time
import unittest
import uuid
@@ -22,6 +24,8 @@ from testtools.matchers import (
Equals,
MatchesSetwise,
)
+from twisted.internet import reactor as default_reactor
+from twisted.web import server
from webtest import TestApp
from turnip import api
@@ -31,12 +35,14 @@ from turnip.api.tests.test_helpers import (
open_repo,
RepoFactory,
)
+from turnip.config import config
+from turnip.pack.tests.fake_servers import FakeVirtInfoService
+from turnip.tests.compat import mock
+from turnip.tests.tasks import CeleryWorkerFixture
-class ApiTestCase(TestCase):
-
- def setUp(self):
- super(ApiTestCase, self).setUp()
+class ApiRepoStoreMixin:
+ def setupRepoStore(self):
repo_store = self.useFixture(TempDir()).path
self.useFixture(EnvironmentVariable("REPO_STORE", repo_store))
self.app = TestApp(api.main({}))
@@ -46,6 +52,13 @@ class ApiTestCase(TestCase):
self.commit = {'ref': 'refs/heads/master', 'message': 'test commit.'}
self.tag = {'ref': 'refs/tags/tag0', 'message': 'tag message'}
+
+class ApiTestCase(TestCase, ApiRepoStoreMixin):
+
+    def setUp(self):
+        """Give each test its own temporary repository store and test app."""
+        super(ApiTestCase, self).setUp()
+        self.setupRepoStore()
+
def assertReferencesEqual(self, repo, expected, observed):
self.assertEqual(
repo.references[expected].peel().oid,
@@ -98,7 +111,9 @@ class ApiTestCase(TestCase):
resp = self.app.get('/repo/{}'.format(self.repo_path))
self.assertEqual(200, resp.status_code)
- self.assertEqual({'default_branch': 'refs/heads/branch-0'}, resp.json)
+ self.assertEqual({
+ 'default_branch': 'refs/heads/branch-0',
+ 'is_available': True}, resp.json)
def test_repo_get_default_branch_missing(self):
"""default_branch is returned even if that branch has been deleted."""
@@ -109,7 +124,9 @@ class ApiTestCase(TestCase):
resp = self.app.get('/repo/{}'.format(self.repo_path))
self.assertEqual(200, resp.status_code)
- self.assertEqual({'default_branch': 'refs/heads/branch-0'}, resp.json)
+ self.assertEqual({
+ 'default_branch': 'refs/heads/branch-0',
+ 'is_available': True}, resp.json)
def test_repo_patch_default_branch(self):
"""A repository's default branch ("HEAD") can be changed."""
@@ -872,5 +889,159 @@ class ApiTestCase(TestCase):
self.assertEqual(404, resp.status_code)
+class AsyncRepoCreationAPI(TestCase, ApiRepoStoreMixin):
+    """Tests for POST /repo with ``async`` set.
+
+    A fake virtinfo XML-RPC service is served from the default Twisted
+    reactor, and repository creation runs in a real Celery worker started
+    by CeleryWorkerFixture.
+    """
+
+    def setUp(self):
+        super(AsyncRepoCreationAPI, self).setUp()
+        self.setupRepoStore()
+        # Fake virtinfo XML-RPC server, listening on any free port.
+        self.virtinfo = FakeVirtInfoService(allowNone=True)
+        self.virtinfo_listener = default_reactor.listenTCP(0, server.Site(
+            self.virtinfo))
+        self.virtinfo_port = self.virtinfo_listener.getHost().port
+        self.virtinfo_url = b'http://localhost:%d/' % self.virtinfo_port
+        self.addCleanup(self.virtinfo_listener.stopListening)
+        config.defaults['virtinfo_endpoint'] = self.virtinfo_url
+
+    def _doReactorIteration(self):
+        """Yield to the reactor so it can process virtinfo requests.
+
+        This is a bit hacky, but it allows us to run the Twisted XML-RPC
+        fake server without making this test suite asynchronous.  Making
+        the suite asynchronous could make it less realistic, since the API
+        being tested does not itself run on the Twisted event loop.
+        """
+        # NOTE(review): sized from the reactor's private _reads/_writes.
+        reactor_iterations = (
+            len(default_reactor._reads) + len(default_reactor._writes))
+        for _ in range(reactor_iterations * 100):
+            default_reactor.iterate()
+
+    def assertRepositoryCreatedAsynchronously(self, repo_path, timeout_secs=5):
+        """Wait up to `timeout_secs` for a repository to be available.
+
+        Fails the test unless GET /repo/<repo_path> reports
+        ``is_available`` within that time.
+        """
+        deadline = datetime.now() + timedelta(seconds=timeout_secs)
+        while datetime.now() <= deadline:
+            self._doReactorIteration()
+            try:
+                resp = self.app.get('/repo/{}'.format(repo_path),
+                                    expect_errors=True)
+                if resp.status_code == 200 and resp.json['is_available']:
+                    return
+            except Exception:
+                # The repository may not be readable yet; keep polling.
+                # (A bare "except:" would also swallow KeyboardInterrupt.)
+                pass
+            time.sleep(0.1)
+        self.fail(
+            "Repository %s was not created after %s secs"
+            % (repo_path, timeout_secs))
+
+    def assertAnyMockCalledAsync(self, mocks, timeout_secs=5):
+        """Assert that at least one of `mocks` is called within the next
+        `timeout_secs` seconds, iterating the reactor while waiting.
+        """
+        deadline = datetime.now() + timedelta(seconds=timeout_secs)
+        while datetime.now() <= deadline:
+            self._doReactorIteration()
+            if any(m.called for m in mocks):
+                return
+            time.sleep(0.1)
+        self.fail(
+            "None of the given args was called after %s seconds."
+            % timeout_secs)
+
+    def test_repo_async_creation_with_clone(self):
+        """Repo can be initialised with optional clone asynchronously."""
+        self.useFixture(CeleryWorkerFixture())
+        self.virtinfo.xmlrpc_confirmRepoCreation = mock.Mock(return_value=None)
+        self.virtinfo.xmlrpc_abortRepoCreation = mock.Mock(return_value=None)
+
+        factory = RepoFactory(self.repo_store, num_commits=2)
+        factory.build()
+        new_repo_path = uuid.uuid1().hex
+        resp = self.app.post_json('/repo', {
+            'async': True,
+            'repo_path': new_repo_path,
+            'clone_from': self.repo_path,
+            'clone_refs': True})
+        self.assertEqual(200, resp.status_code)
+        self.assertIn(new_repo_path, resp.json['repo_url'])
+
+        self.assertRepositoryCreatedAsynchronously(new_repo_path)
+        # The worker confirms only after the repository becomes available,
+        # and the fake virtinfo server only answers while the reactor is
+        # iterated, so wait for the confirm (or abort) call to arrive.
+        self.assertAnyMockCalledAsync([
+            self.virtinfo.xmlrpc_confirmRepoCreation,
+            self.virtinfo.xmlrpc_abortRepoCreation])
+
+        repo1_revlist = get_revlist(factory.repo)
+        new_repo_name = resp.json['repo_url'].split('/')[-1]
+        repo2 = open_repo(os.path.join(self.repo_root, new_repo_name))
+        repo2_revlist = get_revlist(repo2)
+        self.assertEqual(repo1_revlist, repo2_revlist)
+
+        self.assertEqual([mock.call(
+            mock.ANY, new_repo_path, {"user": "+launchpad-services"})],
+            self.virtinfo.xmlrpc_confirmRepoCreation.call_args_list)
+        self.assertEqual(
+            [], self.virtinfo.xmlrpc_abortRepoCreation.call_args_list)
+
+    def test_repo_async_creation_aborts_when_fails_to_create_locally(self):
+        """Creation is aborted if the repo cannot be created locally."""
+        self.useFixture(
+            EnvironmentVariable("REPO_STORE", '/tmp/invalid/path/to/repos/'))
+        self.useFixture(CeleryWorkerFixture())
+        self.virtinfo.xmlrpc_confirmRepoCreation = mock.Mock(return_value=None)
+        self.virtinfo.xmlrpc_abortRepoCreation = mock.Mock(return_value=None)
+
+        factory = RepoFactory(self.repo_store, num_commits=2)
+        factory.build()
+        new_repo_path = uuid.uuid1().hex
+        self.app.post_json('/repo', {
+            'async': True,
+            'repo_path': new_repo_path,
+            'clone_from': self.repo_path,
+            'clone_refs': True})
+
+        # Wait until the creation is either confirmed or aborted; the
+        # assertions below check that it was aborted.
+        self.assertAnyMockCalledAsync([
+            self.virtinfo.xmlrpc_confirmRepoCreation,
+            self.virtinfo.xmlrpc_abortRepoCreation])
+        self.assertFalse(
+            os.path.exists(os.path.join(self.repo_root, new_repo_path)))
+
+        self.assertEqual([mock.call(
+            mock.ANY, new_repo_path, {"user": "+launchpad-services"})],
+            self.virtinfo.xmlrpc_abortRepoCreation.call_args_list)
+        self.assertEqual(
+            [], self.virtinfo.xmlrpc_confirmRepoCreation.call_args_list)
+
+    def test_repo_async_creation_aborts_when_fails_confirm(self):
+        """Creation is aborted and the repo removed if confirming fails."""
+        self.useFixture(CeleryWorkerFixture())
+        self.virtinfo.xmlrpc_confirmRepoCreation = mock.Mock(
+            side_effect=Exception("?"))
+        self.virtinfo.xmlrpc_abortRepoCreation = mock.Mock(return_value=None)
+
+        factory = RepoFactory(self.repo_store, num_commits=2)
+        factory.build()
+        new_repo_path = uuid.uuid1().hex
+        self.app.post_json('/repo', {
+            'async': True,
+            'repo_path': new_repo_path,
+            'clone_from': self.repo_path,
+            'clone_refs': True})
+
+        # confirmRepoCreation always raises here, so wait for the abort.
+        self.assertAnyMockCalledAsync([
+            self.virtinfo.xmlrpc_abortRepoCreation])
+        self.assertFalse(
+            os.path.exists(os.path.join(self.repo_root, new_repo_path)))
+
+        self.assertEqual([mock.call(
+            mock.ANY, new_repo_path, {"user": "+launchpad-services"})],
+            self.virtinfo.xmlrpc_abortRepoCreation.call_args_list)
+        self.assertEqual([mock.call(
+            mock.ANY, new_repo_path, {"user": "+launchpad-services"})],
+            self.virtinfo.xmlrpc_confirmRepoCreation.call_args_list)
+
if __name__ == '__main__':
unittest.main()
diff --git a/turnip/api/tests/test_store.py b/turnip/api/tests/test_store.py
index 9276303..70672d8 100644
--- a/turnip/api/tests/test_store.py
+++ b/turnip/api/tests/test_store.py
@@ -125,6 +125,16 @@ class InitTestCase(TestCase):
self.assertEqual(str(yaml_config['pack.depth']),
repo_config['pack.depth'])
+    def test_is_repository_available(self):
+        """A repository is unavailable while it is marked as being created,
+        and available again once the mark is removed."""
+        repo_path = os.path.join(self.repo_store, 'repo/')
+
+        store.init_repository(repo_path, True)
+        store.set_repository_creating(repo_path, True)
+        self.assertFalse(store.is_repository_available(repo_path))
+
+        store.set_repository_creating(repo_path, False)
+        self.assertTrue(store.is_repository_available(repo_path))
+
def test_open_ephemeral_repo(self):
"""Opening a repo where a repo name contains ':' should return
a new ephemeral repo.
@@ -224,7 +234,10 @@ class InitTestCase(TestCase):
# repo with the same set of refs. And the objects are copied
# too.
to_path = os.path.join(self.repo_store, 'to/')
+ self.assertFalse(store.is_repository_available(to_path))
store.init_repo(to_path, clone_from=self.orig_path, clone_refs=True)
+ self.assertTrue(store.is_repository_available(to_path))
+
to = pygit2.Repository(to_path)
self.assertIsNot(None, to[self.master_oid])
self.assertEqual(
@@ -260,7 +273,10 @@ class InitTestCase(TestCase):
# init_repo with clone_from=orig and clone_refs=False creates a
# repo without any refs, but the objects are copied.
to_path = os.path.join(self.repo_store, 'to/')
+ self.assertFalse(store.is_repository_available(to_path))
store.init_repo(to_path, clone_from=self.orig_path, clone_refs=False)
+ self.assertTrue(store.is_repository_available(to_path))
+
to = pygit2.Repository(to_path)
self.assertIsNot(None, to[self.master_oid])
self.assertEqual([], to.listall_references())
@@ -286,7 +302,10 @@ class InitTestCase(TestCase):
self.assertAllLinkCounts(1, self.orig_objs)
to_path = os.path.join(self.repo_store, 'to/')
+ self.assertFalse(store.is_repository_available(to_path))
store.init_repo(to_path, clone_from=self.orig_path)
+ self.assertTrue(store.is_repository_available(to_path))
+
self.assertAllLinkCounts(2, self.orig_objs)
to = pygit2.Repository(to_path)
to_blob = to.create_blob(b'to')
diff --git a/turnip/api/views.py b/turnip/api/views.py
index 2e8a082..de6bb4e 100644
--- a/turnip/api/views.py
+++ b/turnip/api/views.py
@@ -53,9 +53,11 @@ class RepoAPI(BaseAPI):
def collection_post(self):
"""Initialise a new git repository, or clone from an existing repo."""
- repo_path = extract_json_data(self.request).get('repo_path')
- clone_path = extract_json_data(self.request).get('clone_from')
- clone_refs = extract_json_data(self.request).get('clone_refs', False)
+ json_data = extract_json_data(self.request)
+ repo_path = json_data.get('repo_path')
+ clone_path = json_data.get('clone_from')
+ clone_refs = json_data.get('clone_refs', False)
+ async_run = json_data.get('async', False)
if not repo_path:
self.request.errors.add('body', 'repo_path',
@@ -72,7 +74,13 @@ class RepoAPI(BaseAPI):
repo_clone = None
try:
- store.init_repo(repo, clone_from=repo_clone, clone_refs=clone_refs)
+ kwargs = dict(
+ repo_path=repo, clone_from=repo_clone, clone_refs=clone_refs)
+ if async_run:
+ kwargs["untranslated_path"] = repo_path
+ store.init_and_confirm_repo.apply_async(kwargs=kwargs)
+ else:
+ store.init_repo(**kwargs)
repo_name = os.path.basename(os.path.normpath(repo))
return {'repo_url': '/'.join([self.request.url, repo_name])}
except GitError:
@@ -88,6 +96,7 @@ class RepoAPI(BaseAPI):
raise exc.HTTPNotFound()
return {
'default_branch': store.get_default_branch(repo_path),
+ 'is_available': store.is_repository_available(repo_path)
}
def _patch_default_branch(self, repo_path, value):
diff --git a/turnip/helpers.py b/turnip/helpers.py
index 77e12ed..b23e4a6 100644
--- a/turnip/helpers.py
+++ b/turnip/helpers.py
@@ -10,6 +10,7 @@ from __future__ import (
import os.path
import six
+from six.moves import xmlrpc_client
def compose_path(root, path):
@@ -22,3 +23,24 @@ def compose_path(root, path):
if not full_path.startswith(os.path.abspath(root)):
raise ValueError('Path not contained within root')
return full_path
+
+
+class TimeoutTransport(xmlrpc_client.Transport):
+    """XML-RPC transport whose connections carry a socket timeout.
+
+    :param timeout: Timeout assigned to every connection returned by
+        `make_connection`.
+    :param use_datetime: Passed through to `xmlrpc_client.Transport`.
+    """
+
+    def __init__(self, timeout, use_datetime=0):
+        self.timeout = timeout
+        # Explicit base-class call: xmlrpclib's classes are old-style on
+        # Python 2, so super() cannot be used.
+        base_init = xmlrpc_client.Transport.__init__
+        base_init(self, use_datetime)
+
+    def make_connection(self, host):
+        conn = xmlrpc_client.Transport.make_connection(self, host)
+        conn.timeout = self.timeout
+        return conn
+
+
+class TimeoutServerProxy(xmlrpc_client.ServerProxy):
+    """An `xmlrpc_client.ServerProxy` with a connection timeout.
+
+    Takes the same arguments as `ServerProxy` plus `timeout`.  When no
+    `transport` is given, a `TimeoutTransport` using `timeout` and
+    `use_datetime` is created; an explicitly supplied `transport` is used
+    as-is (and `timeout` then has no effect).
+    """
+
+    def __init__(self, uri, timeout=10, transport=None, encoding=None,
+                 verbose=0, allow_none=0, use_datetime=0):
+        if transport is None:
+            # NOTE(review): TimeoutTransport is plain HTTP; an https `uri`
+            # would need an equivalent based on SafeTransport.
+            transport = TimeoutTransport(timeout, use_datetime)
+        # Explicit base-class call: ServerProxy is old-style on Python 2.
+        xmlrpc_client.ServerProxy.__init__(
+            self, uri, transport, encoding, verbose, allow_none,
+            use_datetime)
diff --git a/turnip/tasks.py b/turnip/tasks.py
new file mode 100644
index 0000000..51b5d20
--- /dev/null
+++ b/turnip/tasks.py
@@ -0,0 +1,16 @@
+# Copyright 2020 Canonical Ltd. This software is licensed under the
+# GNU Affero General Public License version 3 (see the file LICENSE).
+
+from __future__ import print_function, unicode_literals, absolute_import
+
+__all__ = [
+ 'app'
+]
+
+from celery import Celery
+
+from turnip.config import config
+
+
+# turnip's Celery application; the broker URL comes from the
+# "celery_broker" configuration key.
+app = Celery('tasks', broker=config.get('celery_broker'))
+# Modules imported when the worker starts, so that the tasks they define
+# (e.g. init_and_confirm_repo in turnip.api.store) are registered.
+app.conf.update(imports=('turnip.api.store', ))
diff --git a/turnip/tests/__init__.py b/turnip/tests/__init__.py
index 494f3fd..6be77e3 100644
--- a/turnip/tests/__init__.py
+++ b/turnip/tests/__init__.py
@@ -8,6 +8,8 @@ from __future__ import (
)
from turnip.tests.logging import setupLogger
+from turnip.tests.tasks import setupCelery
setupLogger()
+setupCelery()
diff --git a/turnip/tests/tasks.py b/turnip/tests/tasks.py
new file mode 100644
index 0000000..401d079
--- /dev/null
+++ b/turnip/tests/tasks.py
@@ -0,0 +1,84 @@
+# Copyright 2020 Canonical Ltd. This software is licensed under the
+# GNU Affero General Public License version 3 (see the file LICENSE).
+
+import atexit
+import os
+import subprocess
+import sys
+
+from testtools.testcase import fixtures
+
+from turnip.config import config
+from turnip.tasks import app
+
+BROKER_URL = 'pyamqp://guest@localhost/turnip-test-vhost'
+
+
+def setupCelery():
+    """Point the shared Celery app at the test broker (BROKER_URL)."""
+    app.conf.update(broker_url=BROKER_URL)
+
+
+class CeleryWorkerFixture(fixtures.Fixture):
+    """Celery worker fixture for tests.
+
+    This fixture starts a celery worker with the configuration set when the
+    fixture is set up.  Keep in mind that the worker runs in a separate
+    process, so mock patches, for example, do not apply to it.  The worker
+    is stopped when the fixture is cleaned up, if this fixture started it.
+    """
+    # The worker process started by any instance of this fixture, or None.
+    # Shared by all instances so that at most one worker runs at a time.
+    _worker_proc = None
+    # Whether stopCeleryWorker has already been registered with atexit.
+    _atexit_registered = False
+
+    def __init__(self, loglevel="error", force_restart=True, env=None):
+        """
+        Build a celery worker for test cases.
+
+        :param loglevel: Which log level to use for the worker.
+        :param force_restart: If True and a celery worker is already running,
+            stop it. If False, do not restart if another worker is
+            already running.
+        :param env: The environment variables to be used when creating
+            the worker.
+        """
+        self.force_restart = force_restart
+        self.loglevel = loglevel
+        self.env = env
+        # Whether this instance started the worker that is now running.
+        self._started_worker = False
+
+    def startCeleryWorker(self):
+        """Start a celery worker for integration tests."""
+        if self.force_restart:
+            self.stopCeleryWorker()
+        if CeleryWorkerFixture._worker_proc is not None:
+            return
+        bin_path = os.path.dirname(sys.executable)
+        celery = os.path.join(bin_path, 'celery')
+        turnip_path = os.path.join(os.path.dirname(__file__), '..')
+        cmd = [
+            celery, 'worker', '-A', 'tasks', '--quiet',
+            '--pool=gevent',
+            '--concurrency=2',
+            '--broker=%s' % BROKER_URL,
+            '--loglevel=%s' % self.loglevel]
+
+        # Send to the subprocess, as env variables, the same configurations
+        # we are currently using.
+        proc_env = {'PYTHONPATH': turnip_path}
+        for k in config.defaults:
+            proc_env[k.upper()] = str(config.get(k))
+        proc_env.update(self.env or {})
+
+        CeleryWorkerFixture._worker_proc = subprocess.Popen(cmd, env=proc_env)
+        self._started_worker = True
+        if not CeleryWorkerFixture._atexit_registered:
+            # Kill a worker still running when the interpreter exits.  One
+            # registration suffices, since stopCeleryWorker only uses
+            # class-level state.
+            atexit.register(self.stopCeleryWorker)
+            CeleryWorkerFixture._atexit_registered = True
+
+    def stopCeleryWorker(self):
+        """Kill the running worker, if any, then purge the task queue."""
+        worker_proc = CeleryWorkerFixture._worker_proc
+        if worker_proc:
+            worker_proc.kill()
+            worker_proc.wait()
+            CeleryWorkerFixture._worker_proc = None
+        # Cleanup the queue.
+        app.control.purge()
+
+    def _setUp(self):
+        self.startCeleryWorker()
+        # fixtures.Fixture never calls a "_cleanup" hook by itself, so the
+        # teardown has to be registered explicitly.
+        self.addCleanup(self._cleanup)
+
+    def _cleanup(self):
+        """Stop the worker, but only if this fixture instance started it."""
+        if self._started_worker:
+            self._started_worker = False
+            self.stopCeleryWorker()
Follow ups