← Back to team overview

launchpad-reviewers team mailing list archive

[Merge] lp:~rvb/maas/share-mem into lp:maas

 

Raphaël Badin has proposed merging lp:~rvb/maas/share-mem into lp:maas.

Requested reviews:
  Launchpad code reviewers (launchpad-reviewers)

For more details, see:
https://code.launchpad.net/~rvb/maas/share-mem/+merge/119371

This branch uses the process-safe data containers provided by Python's multiprocessing module to share objects across Celery workers.

This was discussed with Jeroen.

= Notes =

- Since multiprocessing's Array structure requires a fixed length, I've chosen to play it safe and size each field using Django's maximum CharField size (255 characters). The credentials buffer therefore holds three such fields plus the two colon separators.  I could have been more economical (for instance, the 'secret' part of the credential string is empty), but I chose to keep it simple.

- This change simplifies the tests a great deal, since we no longer need to use patch() as much; the trade-off is that I had to create a provisioningserver TestCase class.  Because patch() no longer restores the shared values, that TestCase registers init_shared_globals() as a cleanup so that each test starts from empty shared state.  I could have created a fixture instead but, although I agree that we should avoid having too many TestCase classes, having one per subproject seems like the right thing to do here. Since maasserver's tests also call tasks, I've registered the init_shared_globals() function as a cleanup in maasserver's TestCase as well.
-- 
https://code.launchpad.net/~rvb/maas/share-mem/+merge/119371
Your team Launchpad code reviewers is requested to review the proposed merge of lp:~rvb/maas/share-mem into lp:maas.
=== modified file 'src/maasserver/testing/testcase.py'
--- src/maasserver/testing/testcase.py	2012-06-07 11:44:14 +0000
+++ src/maasserver/testing/testcase.py	2012-08-13 15:40:24 +0000
@@ -22,6 +22,7 @@
 from maasserver.testing.factory import factory
 from maastesting.celery import CeleryFixture
 import maastesting.djangotestcase
+from provisioningserver.auth import init_shared_globals
 
 
 class TestCase(maastesting.djangotestcase.DjangoTestCase):
@@ -31,6 +32,7 @@
         self.addCleanup(cache.clear)
         self.addCleanup(reset_fake_provisioning_api_proxy)
         self.celery = self.useFixture(CeleryFixture())
+        self.addCleanup(init_shared_globals)
 
 
 class TestModelTestCase(TestCase,

=== modified file 'src/provisioningserver/auth.py'
--- src/provisioningserver/auth.py	2012-08-10 13:11:19 +0000
+++ src/provisioningserver/auth.py	2012-08-13 15:40:24 +0000
@@ -18,23 +18,39 @@
     'record_nodegroup_name',
     ]
 
+
+from multiprocessing import Array
+
 # API credentials as last sent by the server.  The worker uses these
 # credentials to access the MAAS API.
-# Shared between threads.
+# Shared between threads/processes.
 recorded_api_credentials = None
 
 
+# The name of the nodegroup that this worker manages.
+# Shared between threads/processes.
+recorded_nodegroup_name = None
+
+
+def init_shared_globals():
+    """Initialize the process-safe globals from this module."""
+    global recorded_api_credentials
+    # credentials=<consumer_key>:<key>:<secret>
+    recorded_api_credentials = Array('c', 3 * 255 + 2)
+
+    global recorded_nodegroup_name
+    recorded_nodegroup_name = Array('c', 255)
+
+
+init_shared_globals()
+
+
 def locate_maas_api():
     """Return the base URL for the MAAS API."""
 # TODO: Configure this somehow.  What you see here is a placeholder.
     return "http://localhost/MAAS/";
 
 
-# The name of the nodegroup that this worker manages.
-# Shared between threads.
-recorded_nodegroup_name = None
-
-
 def record_api_credentials(api_credentials):
     """Update the recorded API credentials.
 
@@ -43,7 +59,7 @@
         separated by colons.
     """
     global recorded_api_credentials
-    recorded_api_credentials = api_credentials
+    recorded_api_credentials.value = api_credentials
 
 
 def get_recorded_api_credentials():
@@ -54,16 +70,16 @@
         :class:`MAASOauth`.  Otherwise, None.
     """
     credentials_string = recorded_api_credentials
-    if credentials_string is None:
+    if credentials_string.value == '':
         return None
     else:
-        return tuple(credentials_string.split(':'))
+        return tuple(credentials_string.value.split(':'))
 
 
 def record_nodegroup_name(nodegroup_name):
     """Record the name of the nodegroup we manage, as sent by the server."""
     global recorded_nodegroup_name
-    recorded_nodegroup_name = nodegroup_name
+    recorded_nodegroup_name.value = nodegroup_name
 
 
 def get_recorded_nodegroup_name():
@@ -71,4 +87,7 @@
 
     If the server has not sent the name yet, returns None.
     """
-    return recorded_nodegroup_name
+    if recorded_nodegroup_name.value == '':
+        return None
+    else:
+        return recorded_nodegroup_name.value

=== modified file 'src/provisioningserver/dhcp/tests/test_leases.py'
--- src/provisioningserver/dhcp/tests/test_leases.py	2012-08-10 13:11:19 +0000
+++ src/provisioningserver/dhcp/tests/test_leases.py	2012-08-13 15:40:24 +0000
@@ -21,7 +21,6 @@
 from apiclient.maas_client import MAASClient
 from maastesting.factory import factory
 from maastesting.fakemethod import FakeMethod
-from maastesting.testcase import TestCase
 from maastesting.utils import (
     age_file,
     get_write_time,
@@ -41,6 +40,7 @@
     upload_leases,
     )
 from provisioningserver.omshell import Omshell
+from provisioningserver.testing.testcase import TestCase
 from testtools.testcase import ExpectedException
 
 
@@ -133,22 +133,14 @@
     def set_nodegroup_name(self):
         """Set the recorded nodegroup name for the duration of this test."""
         name = factory.make_name('nodegroup')
-        self.patch(auth, 'recorded_nodegroup_name', name)
+        auth.record_nodegroup_name(name)
         return name
 
     def set_api_credentials(self):
         """Set recorded API credentials for the duration of this test."""
         creds_string = ':'.join(
             factory.getRandomString() for counter in range(3))
-        self.patch(auth, 'recorded_api_credentials', creds_string)
-
-    def clear_api_credentials(self):
-        """Clear recorded API credentials for the duration of this test."""
-        self.patch(auth, 'recorded_api_credentials', None)
-
-    def clear_nodegroup_name(self):
-        """Set the recorded nodegroup name for the duration of this test."""
-        self.patch(auth, 'recorded_nodegroup_name', None)
+        auth.record_api_credentials(creds_string)
 
     def set_lease_state(self, time=None, leases=None):
         """Set the recorded state of DHCP leases.
@@ -393,7 +385,6 @@
         self.patch(Omshell, 'create', FakeMethod())
         self.set_lease_state()
         self.clear_omapi_key()
-        self.clear_nodegroup_name()
         new_leases = {
             factory.getRandomIPAddress(): factory.getRandomMACAddress(),
         }
@@ -416,7 +407,6 @@
             MAASClient.post.calls)
 
     def test_send_leases_does_nothing_without_credentials(self):
-        self.clear_api_credentials()
         self.patch(MAASClient, 'post', FakeMethod())
         leases = {
             factory.getRandomIPAddress(): factory.getRandomMACAddress(),

=== added file 'src/provisioningserver/testing/testcase.py'
--- src/provisioningserver/testing/testcase.py	1970-01-01 00:00:00 +0000
+++ src/provisioningserver/testing/testcase.py	2012-08-13 15:40:24 +0000
@@ -0,0 +1,25 @@
+# Copyright 2012 Canonical Ltd.  This software is licensed under the
+# GNU Affero General Public License version 3 (see the file LICENSE).
+
+"""Provisioningserver-specific test-case classes."""
+
+from __future__ import (
+    absolute_import,
+    print_function,
+    unicode_literals,
+    )
+
+__metaclass__ = type
+__all__ = [
+    'TestCase',
+    ]
+
+from provisioningserver.auth import init_shared_globals
+from maastesting import testcase
+
+
+class TestCase(testcase.TestCase):
+
+    def setUp(self):
+        super(TestCase, self).setUp()
+        self.addCleanup(init_shared_globals)

=== modified file 'src/provisioningserver/tests/test_auth.py'
--- src/provisioningserver/tests/test_auth.py	2012-08-10 12:32:22 +0000
+++ src/provisioningserver/tests/test_auth.py	2012-08-13 15:40:24 +0000
@@ -12,9 +12,11 @@
 __metaclass__ = type
 __all__ = []
 
+from multiprocessing.sharedctypes import SynchronizedString
+
 from maastesting.factory import factory
-from maastesting.testcase import TestCase
 from provisioningserver import auth
+from provisioningserver.testing.testcase import TestCase
 
 
 def make_credentials():
@@ -33,24 +35,37 @@
 
 class TestAuth(TestCase):
 
+    def test_init_globals_initializes_recorded_api_credentials(self):
+        self.patch(auth, 'recorded_api_credentials', None)
+        auth.init_shared_globals()
+        self.assertIsInstance(
+            auth.recorded_api_credentials, SynchronizedString)
+        self.assertEqual('', auth.recorded_api_credentials.value)
+
+    def test_init_globals_initializes_recorded_nodegroup_name(self):
+        self.patch(auth, 'recorded_nodegroup_name', None)
+        auth.init_shared_globals()
+        self.assertIsInstance(
+            auth.recorded_nodegroup_name, SynchronizedString)
+        self.assertEqual('', auth.recorded_nodegroup_name.value)
+
     def test_record_api_credentials_records_credentials_string(self):
-        self.patch(auth, 'recorded_api_credentials', None)
         creds_string = represent_credentials(make_credentials())
         auth.record_api_credentials(creds_string)
-        self.assertEqual(creds_string, auth.recorded_api_credentials)
+        self.assertEqual(creds_string, auth.recorded_api_credentials.value)
 
     def test_get_recorded_api_credentials_returns_credentials_as_tuple(self):
-        self.patch(auth, 'recorded_api_credentials', None)
         creds = make_credentials()
         auth.record_api_credentials(represent_credentials(creds))
         self.assertEqual(creds, auth.get_recorded_api_credentials())
 
     def test_get_recorded_api_credentials_returns_None_without_creds(self):
-        self.patch(auth, 'recorded_api_credentials', None)
         self.assertIsNone(auth.get_recorded_api_credentials())
 
+    def test_get_recorded_nodegroup_name_returns_None_initially(self):
+        self.assertIsNone(auth.get_recorded_nodegroup_name())
+
     def test_get_recorded_nodegroup_name_vs_record_nodegroup_name(self):
-        self.patch(auth, 'recorded_nodegroup_name', None)
         nodegroup_name = factory.make_name('nodegroup')
         auth.record_nodegroup_name(nodegroup_name)
         self.assertEqual(nodegroup_name, auth.get_recorded_nodegroup_name())

=== modified file 'src/provisioningserver/tests/test_tasks.py'
--- src/provisioningserver/tests/test_tasks.py	2012-08-10 13:20:48 +0000
+++ src/provisioningserver/tests/test_tasks.py	2012-08-13 15:40:24 +0000
@@ -23,7 +23,6 @@
     MultiFakeMethod,
     )
 from maastesting.matchers import ContainsAll
-from maastesting.testcase import TestCase
 from netaddr import IPNetwork
 from provisioningserver import (
     auth,
@@ -56,6 +55,7 @@
     write_full_dns_config,
     )
 from provisioningserver.testing import network_infos
+from provisioningserver.testing.testcase import TestCase
 from testresources import FixtureResource
 from testtools.matchers import (
     Equals,
@@ -112,13 +112,11 @@
             factory.make_name('token'),
             factory.make_name('secret'),
             )
-        self.patch(auth, 'recorded_api_credentials', None)
         refresh_secrets(api_credentials=':'.join(credentials))
         self.assertEqual(credentials, auth.get_recorded_api_credentials())
 
     def test_updates_nodegroup_name(self):
         nodegroup_name = factory.make_name('nodegroup')
-        self.patch(auth, 'recorded_nodegroup_name', None)
         refresh_secrets(nodegroup_name=nodegroup_name)
         self.assertEqual(nodegroup_name, auth.get_recorded_nodegroup_name())
 


Follow ups