← Back to team overview

launchpad-reviewers team mailing list archive

[Merge] lp:~jtv/maas/use-get_one into lp:maas

 

Jeroen T. Vermeulen has proposed merging lp:~jtv/maas/use-get_one into lp:maas.

Requested reviews:
  MAAS Maintainers (maas-maintainers)

For more details, see:
https://code.launchpad.net/~jtv/maas/use-get_one/+merge/121159

As mentioned on today's stand-up call: now that we have the get_one() helper, we might as well make use of it.  I converted the places where it was awkward to retrieve zero-or-one objects.  I did not touch the cases where finding no object really is an exceptional path.

Along the way I found a few older tests that could be phrased more helpfully (we have learned a few things since they were written), and one case where it was equally awkward to get the first object of a result set — hence the additional get_first() helper.


Jeroen
-- 
https://code.launchpad.net/~jtv/maas/use-get_one/+merge/121159
Your team MAAS Maintainers is requested to review the proposed merge of lp:~jtv/maas/use-get_one into lp:maas.
=== modified file 'src/maasserver/api.py'
--- src/maasserver/api.py	2012-08-23 10:10:42 +0000
+++ src/maasserver/api.py	2012-08-24 11:04:18 +0000
@@ -124,6 +124,7 @@
     Node,
     NodeGroup,
     )
+from maasserver.utils.orm import get_one
 from piston.doc import generate_doc
 from piston.handler import (
     AnonymousBaseHandler,
@@ -1124,11 +1125,9 @@
     arch = get_mandatory_param(request.GET, 'arch')
     subarch = request.GET.get('subarch', 'generic')
 
-    # See if we have a record of this MAC address, and thus node.
-    try:
-        macaddress = MACAddress.objects.get(mac_address=mac)
-    except MACAddress.DoesNotExist:
-        macaddress = node = None
+    macaddress = get_one(MACAddress.objects.filter(mac_address=mac))
+    if macaddress is None:
+        node = None
     else:
         node = macaddress.node
 

=== modified file 'src/maasserver/forms.py'
--- src/maasserver/forms.py	2012-08-22 02:07:30 +0000
+++ src/maasserver/forms.py	2012-08-24 11:04:18 +0000
@@ -304,7 +304,7 @@
     def clean_mac_addresses(self):
         data = self.cleaned_data['mac_addresses']
         for mac in data:
-            if MACAddress.objects.filter(mac_address=mac.lower()).count() > 0:
+            if MACAddress.objects.filter(mac_address=mac.lower()).exists():
                 raise ValidationError(
                     {'mac_addresses': [
                         'Mac address %s already in use.' % mac]})
@@ -449,8 +449,7 @@
         site.
         """
         email = self.cleaned_data['email']
-        email_count = User.objects.filter(email__iexact=email).count()
-        if email_count != 0:
+        if User.objects.filter(email__iexact=email).exists():
             raise forms.ValidationError(
                 "User with this E-mail address already exists.")
         return email

=== modified file 'src/maasserver/models/filestorage.py'
--- src/maasserver/models/filestorage.py	2012-05-17 08:44:42 +0000
+++ src/maasserver/models/filestorage.py	2012-08-24 11:04:18 +0000
@@ -30,6 +30,7 @@
     )
 from maasserver import DefaultMeta
 from maasserver.models.cleansave import CleanSave
+from maasserver.utils.orm import get_one
 
 
 class FileStorageManager(Manager):
@@ -52,15 +53,7 @@
 
     def get_existing_storage(self, filename):
         """Return an existing `FileStorage` of this name, or None."""
-        existing_storage = self.filter(filename=filename)
-        if len(existing_storage) == 0:
-            return None
-        elif len(existing_storage) == 1:
-            return existing_storage[0]
-        else:
-            raise AssertionError(
-                "There are %d files called '%s'."
-                % (len(existing_storage), filename))
+        return get_one(self.filter(filename=filename))
 
     def save_file(self, filename, file_object):
         """Save the file to the filesystem and persist to the database.

=== modified file 'src/maasserver/models/node.py'
--- src/maasserver/models/node.py	2012-08-16 14:01:00 +0000
+++ src/maasserver/models/node.py	2012-08-24 11:04:18 +0000
@@ -51,6 +51,7 @@
 from maasserver.models.config import Config
 from maasserver.models.timestampedmodel import TimestampedModel
 from maasserver.utils import get_db_state
+from maasserver.utils.orm import get_first
 from piston.models import Token
 from provisioningserver.enum import (
     POWER_TYPE,
@@ -249,11 +250,7 @@
             available_nodes = available_nodes.filter(
                 hostname=constraints['name'])
 
-        available_nodes = list(available_nodes[:1])
-        if len(available_nodes) == 0:
-            return None
-        else:
-            return available_nodes[0]
+        return get_first(available_nodes)
 
     def stop_nodes(self, ids, by_user):
         """Request on given user's behalf that the given nodes be shut down.

=== modified file 'src/maasserver/testing/__init__.py'
--- src/maasserver/testing/__init__.py	2012-08-16 12:26:51 +0000
+++ src/maasserver/testing/__init__.py	2012-08-24 11:04:18 +0000
@@ -24,6 +24,7 @@
 from urlparse import urlparse
 
 from lxml.html import fromstring
+from maasserver.utils.orm import get_one
 
 
 def extract_redirect(http_response):
@@ -55,19 +56,15 @@
     Use this when a test needs to inspect changes to model objects made by
     the API.
 
-    If the object has been deleted, this will raise the `DoesNotExist`
-    exception for its model class.
+    If the object has been deleted, this will return None.
 
     :param model_object: Model object to reload.
     :type model_object: Concrete `Model` subtype.
-    :return: Freshly-loaded instance of `model_object`.
+    :return: Freshly-loaded instance of `model_object`, or None.
     :rtype: Same as `model_object`.
     """
     model_class = model_object.__class__
-    try:
-        return model_class.objects.get(id=model_object.id)
-    except model_class.DoesNotExist:
-        return None
+    return get_one(model_class.objects.filter(id=model_object.id))
 
 
 def reload_objects(model_class, model_objects):

=== modified file 'src/maasserver/tests/test_api.py'
--- src/maasserver/tests/test_api.py	2012-08-23 10:10:42 +0000
+++ src/maasserver/tests/test_api.py	2012-08-24 11:04:18 +0000
@@ -74,6 +74,7 @@
     TestCase,
     )
 from maasserver.utils import map_enum
+from maasserver.utils.orm import get_one
 from maasserver.worker_user import get_worker_user
 from maastesting.celery import CeleryFixture
 from maastesting.djangotestcase import TransactionTestCase
@@ -289,7 +290,7 @@
                     NODE_AFTER_COMMISSIONING_ACTION.DEFAULT,
                 'mac_addresses': ['aa:bb:cc:dd:ee:ff', '22:bb:cc:dd:ee:ff'],
             })
-        [diane] = Node.objects.filter(hostname='diane')
+        diane = get_one(Node.objects.filter(hostname='diane'))
         self.assertItemsEqual(
             ['aa:bb:cc:dd:ee:ff', '22:bb:cc:dd:ee:ff'],
             [mac.mac_address for mac in diane.macaddress_set.all()])

=== modified file 'src/maasserver/tests/test_commands.py'
--- src/maasserver/tests/test_commands.py	2012-04-19 15:48:46 +0000
+++ src/maasserver/tests/test_commands.py	2012-08-24 11:04:18 +0000
@@ -22,6 +22,7 @@
 from django.core.management import call_command
 from maasserver.models import FileStorage
 from maasserver.testing.factory import factory
+from maasserver.utils.orm import get_one
 from maastesting.djangotestcase import DjangoTestCase
 
 
@@ -95,14 +96,13 @@
         call_command(
             'createadmin', username=username, password=password,
             email=email, stderr=stderr, stdout=stdout)
-        users = list(User.objects.filter(username=username))
+        user = get_one(User.objects.filter(username=username))
 
         self.assertEquals('', stderr.getvalue().strip())
         self.assertEquals('', stdout.getvalue().strip())
-        self.assertEqual(1, len(users))  # One user with that name.
-        self.assertTrue(users[0].check_password(password))
-        self.assertTrue(users[0].is_superuser)
-        self.assertEqual(email, users[0].email)
+        self.assertTrue(user.check_password(password))
+        self.assertTrue(user.is_superuser)
+        self.assertEqual(email, user.email)
 
     def test_clearcache_clears_entire_cache(self):
         key = factory.getRandomString()

=== modified file 'src/maasserver/tests/test_node.py'
--- src/maasserver/tests/test_node.py	2012-08-16 14:01:00 +0000
+++ src/maasserver/tests/test_node.py	2012-08-24 11:04:18 +0000
@@ -77,19 +77,20 @@
         self.assertEqual(token, node.token)
 
     def test_add_mac_address(self):
+        mac = factory.getRandomMACAddress()
         node = factory.make_node()
-        node.add_mac_address('AA:BB:CC:DD:EE:FF')
-        macs = MACAddress.objects.filter(
-            node=node, mac_address='AA:BB:CC:DD:EE:FF').count()
+        node.add_mac_address(mac)
+        macs = MACAddress.objects.filter(node=node, mac_address=mac).count()
         self.assertEqual(1, macs)
 
     def test_remove_mac_address(self):
+        mac = factory.getRandomMACAddress()
         node = factory.make_node()
-        node.add_mac_address('AA:BB:CC:DD:EE:FF')
-        node.remove_mac_address('AA:BB:CC:DD:EE:FF')
-        macs = MACAddress.objects.filter(
-            node=node, mac_address='AA:BB:CC:DD:EE:FF').count()
-        self.assertEqual(0, macs)
+        node.add_mac_address(mac)
+        node.remove_mac_address(mac)
+        self.assertItemsEqual(
+            [],
+            MACAddress.objects.filter(node=node, mac_address=mac))
 
     def test_get_primary_mac_returns_mac_address(self):
         node = factory.make_node()

=== modified file 'src/maasserver/utils/__init__.py'
--- src/maasserver/utils/__init__.py	2012-06-26 16:31:54 +0000
+++ src/maasserver/utils/__init__.py	2012-08-24 11:04:18 +0000
@@ -22,6 +22,7 @@
 
 from django.conf import settings
 from django.core.urlresolvers import reverse
+from maasserver.utils.orm import get_one
 
 
 def get_db_state(instance, field_name):
@@ -32,11 +33,11 @@
     :param field_name: The name of the field to return.
     :type field_name: basestring
     """
-    try:
-        return getattr(
-            instance.__class__.objects.get(pk=instance.pk), field_name)
-    except instance.DoesNotExist:
+    obj = get_one(instance.__class__.objects.filter(pk=instance.pk))
+    if obj is None:
         return None
+    else:
+        return getattr(obj, field_name)
 
 
 def ignore_unused(*args):

=== modified file 'src/maasserver/utils/orm.py'
--- src/maasserver/utils/orm.py	2012-08-24 06:49:46 +0000
+++ src/maasserver/utils/orm.py	2012-08-24 11:04:18 +0000
@@ -11,6 +11,7 @@
 
 __metaclass__ = type
 __all__ = [
+    'get_first',
     'get_one',
     ]
 
@@ -52,3 +53,12 @@
         return retrieved_items[0]
     else:
         raise get_exception_class(items)("Got more than one item.")
+
+
+def get_first(items):
+    """Get the first of `items`, or None."""
+    first_item = tuple(islice(items, 0, 1))
+    if len(first_item) == 0:
+        return None
+    else:
+        return first_item[0]

=== modified file 'src/maasserver/utils/tests/test_orm.py'
--- src/maasserver/utils/tests/test_orm.py	2012-08-24 06:48:47 +0000
+++ src/maasserver/utils/tests/test_orm.py	2012-08-24 11:04:18 +0000
@@ -12,8 +12,13 @@
 __metaclass__ = type
 __all__ = []
 
+from itertools import repeat
+
 from django.core.exceptions import MultipleObjectsReturned
-from maasserver.utils.orm import get_one
+from maasserver.utils.orm import (
+    get_first,
+    get_one,
+    )
 from maastesting.factory import factory
 from maastesting.testcase import TestCase
 from mock import Mock
@@ -95,3 +100,27 @@
 
     def test_get_one_raises_generic_error_if_other_sequence_is_too_big(self):
         self.assertRaises(MultipleObjectsReturned, get_one, range(2))
+
+
+class TestGetFirst(TestCase):
+    def test_get_first_returns_None_for_empty_list(self):
+        self.assertIsNone(get_first([]))
+
+    def test_get_first_returns_first_item(self):
+        items = [factory.getRandomString() for counter in range(10)]
+        self.assertEqual(items[0], get_first(items))
+
+    def test_get_first_accepts_any_sequence(self):
+        item = factory.getRandomString()
+        self.assertEqual(item, get_first(repeat(item)))
+
+    def test_get_first_does_not_retrieve_beyond_first_item(self):
+
+        class SecondItemRetrieved(Exception):
+            """Second item as retrieved.  It shouldn't be."""
+
+        def multiple_items():
+            yield "Item 1"
+            raise SecondItemRetrieved()
+
+        self.assertEqual("Item 1", get_first(multiple_items()))

=== modified file 'src/metadataserver/api.py'
--- src/metadataserver/api.py	2012-08-03 15:17:01 +0000
+++ src/metadataserver/api.py	2012-08-24 11:04:18 +0000
@@ -47,6 +47,7 @@
     get_enlist_userdata,
     get_preseed,
     )
+from maasserver.utils.orm import get_one
 from metadataserver.models import (
     NodeCommissionResult,
     NodeKey,
@@ -88,10 +89,9 @@
     if not settings.ALLOW_UNSAFE_METADATA_ACCESS:
         raise PermissionDenied(
             "Unauthenticated metadata access is not allowed on this MAAS.")
-    matching_macs = list(MACAddress.objects.filter(mac_address=mac))
-    if len(matching_macs) == 0:
+    match = get_one(MACAddress.objects.filter(mac_address=mac))
+    if match is None:
         raise MAASAPINotFound()
-    [match] = matching_macs
     return match.node
 
 

=== modified file 'src/metadataserver/models/nodekey.py'
--- src/metadataserver/models/nodekey.py	2012-06-30 04:16:16 +0000
+++ src/metadataserver/models/nodekey.py	2012-08-24 11:04:18 +0000
@@ -23,6 +23,7 @@
     )
 from maasserver.models.cleansave import CleanSave
 from maasserver.models.user import create_auth_token
+from maasserver.utils.orm import get_one
 from metadataserver import DefaultMeta
 from metadataserver.nodeinituser import get_node_init_user
 from piston.models import (
@@ -82,14 +83,10 @@
             uniquely associated with this node.
         :rtype: piston.models.Token
         """
-        existing_nodekey = self.filter(node=node)
-        assert len(existing_nodekey) in (0, 1), (
-            "Found %d keys for node (expected at most one)."
-            % len(existing_nodekey))
-        if len(existing_nodekey) == 0:
+        nodekey = get_one(self.filter(node=node))
+        if nodekey is None:
             return self._create_token(node)
         else:
-            [nodekey] = existing_nodekey
             return nodekey.token
 
     def get_node_for_key(self, key):

=== modified file 'src/metadataserver/tests/test_nodecommissionresult.py'
--- src/metadataserver/tests/test_nodecommissionresult.py	2012-07-09 13:06:08 +0000
+++ src/metadataserver/tests/test_nodecommissionresult.py	2012-08-24 11:04:18 +0000
@@ -15,6 +15,7 @@
 from django.core.exceptions import ValidationError
 from django.http import Http404
 from maasserver.testing.factory import factory
+from maasserver.utils.orm import get_one
 from maastesting.djangotestcase import DjangoTestCase
 from metadataserver.models import NodeCommissionResult
 
@@ -82,9 +83,9 @@
         NodeCommissionResult.objects.store_data(
             node, name=name, data=data)
 
-        results = NodeCommissionResult.objects.filter(node=node)
-        [ncr] = results
-        self.assertAttributes(ncr, dict(name=name, data=data))
+        self.assertAttributes(
+            get_one(NodeCommissionResult.objects.filter(node=node)),
+            dict(name=name, data=data))
 
     def test_store_data_updates_existing(self):
         node = factory.make_node()
@@ -94,9 +95,9 @@
         NodeCommissionResult.objects.store_data(
             node, name=name, data=data)
 
-        results = NodeCommissionResult.objects.filter(node=node)
-        [ncr] = results
-        self.assertAttributes(ncr, dict(name=name, data=data))
+        self.assertAttributes(
+            get_one(NodeCommissionResult.objects.filter(node=node)),
+            dict(name=name, data=data))
 
     def test_get_data(self):
         ncr = factory.make_node_commission_result()