← Back to team overview

sts-sponsors team mailing list archive

[Merge] ~ack/maas:websocket-inlinecallbacks into maas:master

 

Alberto Donato has proposed merging ~ack/maas:websocket-inlinecallbacks into maas:master.

Commit message:
rework websockets code to use inlineCallbacks instead of explicit Deferreds


Requested reviews:
  MAAS Maintainers (maas-maintainers)

For more details, see:
https://code.launchpad.net/~ack/maas/+git/maas/+merge/438520
-- 
Your team MAAS Maintainers is requested to review the proposed merge of ~ack/maas:websocket-inlinecallbacks into maas:master.
diff --git a/src/maasserver/websockets/protocol.py b/src/maasserver/websockets/protocol.py
index 57c5f33..b44ff2f 100644
--- a/src/maasserver/websockets/protocol.py
+++ b/src/maasserver/websockets/protocol.py
@@ -16,7 +16,7 @@ from django.contrib.auth import BACKEND_SESSION_KEY, load_backend, SESSION_KEY
 from django.core.exceptions import ValidationError
 from django.http import HttpRequest
 from twisted.internet import defer
-from twisted.internet.defer import fail, inlineCallbacks
+from twisted.internet.defer import fail, inlineCallbacks, returnValue
 from twisted.internet.protocol import Factory, Protocol
 from twisted.python.modules import getModule
 from twisted.web.server import NOT_DONE_YET
@@ -27,7 +27,7 @@ from maasserver.utils.threads import deferToDatabase
 from maasserver.websockets import handlers
 from maasserver.websockets.websockets import STATUSES
 from provisioningserver.logger import LegacyLogger
-from provisioningserver.utils.twisted import deferred, synchronous
+from provisioningserver.utils.twisted import synchronous
 from provisioningserver.utils.url import splithost
 
 log = LegacyLogger()
@@ -80,6 +80,7 @@ class WebSocketProtocol(Protocol):
         self.cache = {}
         self.sequence_number = 0
 
+    @inlineCallbacks
     def connectionMade(self):
         """Connection has been made to client."""
         # Using the provided cookies on the connection request, authenticate
@@ -88,51 +89,40 @@ class WebSocketProtocol(Protocol):
         # from an authenticated user.
 
         cookies = self.transport.cookies.decode("ascii")
-        d = self.authenticate(
+        user = yield self.authenticate(
             get_cookie(cookies, "sessionid"), get_cookie(cookies, "csrftoken")
         )
+        if not user:
+            return
+
+        self.user = user
+        # XXX newell 2018-10-17 bug=1798479:
+        # Check that 'SERVER_NAME' and 'SERVER_PORT' are set.
+        # 'SERVER_NAME' and 'SERVER_PORT' are required so
+        # `build_absolute_uri` can create an actual absolute URI so
+        # that the curtin configuration is valid.  See the bug and
+        # maasserver.node_actions for more details.
+        #
+        # `splithost` will split the host and port from either an
+        # ipv4 or an ipv6 address.
+        host, port = splithost(str(self.transport.host))
+
+        # Create the request for the handlers for this connection.
+        self.request = HttpRequest()
+        self.request.user = self.user
+        self.request.META.update(
+            {
+                "HTTP_USER_AGENT": self.transport.user_agent,
+                "REMOTE_ADDR": self.transport.ip_address,
+                "SERVER_NAME": host or "localhost",
+                "SERVER_PORT": port or 5248,
+            }
+        )
 
-        # Only add the client to the list of known clients if/when the
-        # authentication succeeds.
-        def authenticated(user):
-            if user is None:
-                # This user could not be authenticated. No further interaction
-                # should take place. The connection is already being dropped.
-                pass
-            else:
-                # This user is a keeper. Record it and process any message
-                # that have already been received.
-                self.user = user
-
-                # XXX newell 2018-10-17 bug=1798479:
-                # Check that 'SERVER_NAME' and 'SERVER_PORT' are set.
-                # 'SERVER_NAME' and 'SERVER_PORT' are required so
-                # `build_absolure_uri` can create an actual absolute URI so
-                # that the curtin configuration is valid.  See the bug and
-                # maasserver.node_actions for more details.
-                #
-                # `splithost` will split the host and port from either an
-                # ipv4 or an ipv6 address.
-                host, port = splithost(str(self.transport.host))
-
-                # Create the request for the handlers for this connection.
-                self.request = HttpRequest()
-                self.request.user = self.user
-                self.request.META.update(
-                    {
-                        "HTTP_USER_AGENT": self.transport.user_agent,
-                        "REMOTE_ADDR": self.transport.ip_address,
-                        "SERVER_NAME": host or "localhost",
-                        "SERVER_PORT": port or 5248,
-                    }
-                )
-
-                # Be sure to process messages after the metadata is populated,
-                # in order to avoid bug #1802390.
-                self.processMessages()
-                self.factory.clients.append(self)
-
-        d.addCallback(authenticated)
+        # Be sure to process messages after the metadata is populated,
+        # in order to avoid bug #1802390.
+        self.processMessages()
+        self.factory.clients.append(self)
 
     def connectionLost(self, reason):
         """Connection to the client has been lost."""
@@ -145,9 +135,7 @@ class WebSocketProtocol(Protocol):
         """Close connection with status and reason."""
         msgFormat = "Closing connection: {status!r} ({reason!r})"
         log.debug(msgFormat, status=status, reason=reason)
-        self.transport._receiver._transport.loseConnection(
-            status, reason.encode("utf-8")
-        )
+        self.transport.loseConnection(status, reason.encode("utf-8"))
 
     def getMessageField(self, message, field):
         """Get `field` value from `message`.
@@ -168,7 +156,7 @@ class WebSocketProtocol(Protocol):
     def getUserFromSessionId(self, session_id):
         """Return the user from `session_id`."""
         session_engine = self.factory.getSessionEngine()
-        session_wrapper = session_engine.SessionStore(session_id)
+        session_wrapper = session_engine.SessionStore(session_key=session_id)
         user_id = session_wrapper.get(SESSION_KEY)
         backend = session_wrapper.get(BACKEND_SESSION_KEY)
         if backend is None:
@@ -179,7 +167,7 @@ class WebSocketProtocol(Protocol):
 
         return None
 
-    @deferred
+    @inlineCallbacks
     def authenticate(self, session_id, csrftoken):
         """Authenticate the connection.
 
@@ -200,27 +188,22 @@ class WebSocketProtocol(Protocol):
             self.loseConnection(STATUSES.PROTOCOL_ERROR, "Invalid CSRF token.")
             return None
 
-        # Authenticate user.
-        def got_user(user):
-            if user is None:
-                self.loseConnection(
-                    STATUSES.PROTOCOL_ERROR, "Failed to authenticate user."
-                )
-                return None
-            else:
-                return user
-
-        def got_user_error(failure):
+        try:
+            user = yield deferToDatabase(self.getUserFromSessionId, session_id)
+        except Exception as error:
             self.loseConnection(
-                STATUSES.PROTOCOL_ERROR,
-                "Error authenticating user: %s" % failure.getErrorMessage(),
+                STATUSES.PROTOCOL_ERROR, f"Error authenticating user: {error}"
             )
-            return None
-
-        d = deferToDatabase(self.getUserFromSessionId, session_id)
-        d.addCallbacks(got_user, got_user_error)
+            returnValue(None)
+            return
 
-        return d
+        if user is None or user.id is None:
+            self.loseConnection(
+                STATUSES.PROTOCOL_ERROR, "Failed to authenticate user."
+            )
+            returnValue(None)
+        else:
+            returnValue(user)
 
     def dataReceived(self, data):
         """Received message from client and queue up the message."""
@@ -305,12 +288,22 @@ class WebSocketProtocol(Protocol):
             return None

         handler = self.buildHandler(handler_class)
-        d = handler.execute(method, message.get("params", {}))
-        d.addCallbacks(
-            partial(self.sendResult, request_id),
-            partial(self.sendError, request_id, handler, method),
-        )
-        return d
+        # Validation failures above return None synchronously because
+        # processMessages treats None as "stop processing the queue"; only
+        # the handler call itself is asynchronous.
+        return self._executeRequest(
+            request_id, handler, method, message.get("params", {})
+        )
+
+    @inlineCallbacks
+    def _executeRequest(self, request_id, handler, method, params):
+        """Execute `method` on `handler` and send the result or error."""
+        try:
+            result = yield handler.execute(method, params)
+        except Exception as error:
+            self.sendError(request_id, handler, method, error)
+        else:
+            self.sendResult(request_id, result)

     def _json_encode(self, obj):
         """Allow byte strings embedded in the 'result' object passed to
@@ -334,25 +327,27 @@ class WebSocketProtocol(Protocol):
         )
         return result

-    def sendError(self, request_id, handler, method, failure):
+    def sendError(self, request_id, handler, method, exception):
         """Log and send error to client."""
-        if isinstance(failure.value, ValidationError):
+        if isinstance(exception, ValidationError):
             try:
                 # When the error is a validation issue, send the error as a
                 # JSON object. The client will use this to JSON to render the
                 # error messages for the correct fields.
-                error = json.dumps(failure.value.message_dict)
+                error = json.dumps(exception.message_dict)
             except AttributeError:
-                error = failure.value.message
+                error = exception.message
         else:
-            error = failure.getErrorMessage()
+            error = str(exception)
         why = "Error on request ({}) {}.{}: {}".format(
             request_id,
             handler._meta.handler_name,
             method,
             error,
         )
-        log.err(failure, why)
+        # Pass the exception itself (not the message string) so the logger
+        # records the exception rather than only its message text.
+        log.err(exception, why)

         error_msg = {
             "type": MSG_TYPE.RESPONSE,
@@ -363,7 +358,6 @@ class WebSocketProtocol(Protocol):
         self.transport.write(
             json.dumps(error_msg, default=self._json_encode).encode("ascii")
         )
-        return None

     def sendNotify(self, name, action, data):
         """Send the notify message with data."""
diff --git a/src/maasserver/websockets/tests/test_protocol.py b/src/maasserver/websockets/tests/test_protocol.py
index 6c591af..5657f1b 100644
--- a/src/maasserver/websockets/tests/test_protocol.py
+++ b/src/maasserver/websockets/tests/test_protocol.py
@@ -1,8 +1,6 @@
-# Copyright 2015-2018 Canonical Ltd.  This software is licensed under the
+# Copyright 2015-2023 Canonical Ltd.  This software is licensed under the
 # GNU Affero General Public License version 3 (see the file LICENSE).
 
-"""Tests for `maasserver.websockets.protocol`"""
-
 
 from collections import deque
 import json
@@ -90,6 +88,10 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
             url += "?csrftoken=%s" % csrftoken
         return ascii_url(url)
 
+    @transactional
+    def make_user(self):
+        return maas_factory.make_User()
+
     def get_written_transport_message(self, protocol):
         call = protocol.transport.write.call_args_list.pop()
         return json.loads(call[0][0].decode("ascii"))
@@ -213,24 +215,23 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         )
 
     def test_loseConnection_calls_loseConnection_with_status_and_reason(self):
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         status = random.randint(1000, 1010)
         reason = maas_factory.make_name("reason")
         protocol.loseConnection(status, reason)
-        self.assertThat(
-            protocol.transport._receiver._transport.loseConnection,
-            MockCalledOnceWith(status, reason.encode("utf-8")),
+        protocol.transport.loseConnection.assert_called_once_with(
+            status, reason.encode("utf-8")
         )
 
     def test_getMessageField_returns_value_in_message(self):
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         key = maas_factory.make_name("key")
         value = maas_factory.make_name("value")
         message = {key: value}
         self.assertEqual(value, protocol.getMessageField(message, key))
 
     def test_getMessageField_calls_loseConnection_if_key_missing(self):
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         key = maas_factory.make_name("key")
         mock_loseConnection = self.patch_autospec(protocol, "loseConnection")
         self.expectThat(protocol.getMessageField({}, key), Is(None))
@@ -254,7 +255,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
     @inlineCallbacks
     def test_getUserFromSessionId_returns_User(self):
         user, session_id = yield deferToDatabase(self.get_user_and_session_id)
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         protocol_user = yield deferToDatabase(
             lambda: protocol.getUserFromSessionId(session_id)
         )
@@ -263,7 +264,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
     def test_getUserFromSessionId_returns_None_for_invalid_key(self):
         self.client.login(user=maas_factory.make_User())
         session_id = maas_factory.make_name("sessionid")
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         self.assertIsNone(protocol.getUserFromSessionId(session_id))
 
     @wait_for_reactor
@@ -271,7 +272,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
     def test_authenticate_calls_loseConnection_if_user_is_None(self):
         csrftoken = maas_factory.make_name("csrftoken")
         uri = self.make_ws_uri(csrftoken)
-        protocol, factory = self.make_protocol(
+        protocol, _ = self.make_protocol(
             patch_authenticate=False, transport_uri=uri
         )
         mock_loseConnection = self.patch_autospec(protocol, "loseConnection")
@@ -295,7 +296,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
     def test_authenticate_calls_loseConnection_if_error_getting_user(self):
         csrftoken = maas_factory.make_name("csrftoken")
         uri = self.make_ws_uri(csrftoken)
-        protocol, factory = self.make_protocol(
+        protocol, _ = self.make_protocol(
             patch_authenticate=False, transport_uri=uri
         )
         mock_loseConnection = self.patch_autospec(protocol, "loseConnection")
@@ -323,7 +324,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         user, session_id = yield deferToDatabase(self.get_user_and_session_id)
         csrftoken = maas_factory.make_name("csrftoken")
         uri = self.make_ws_uri(csrftoken)
-        protocol, factory = self.make_protocol(
+        protocol, _ = self.make_protocol(
             patch_authenticate=False, transport_uri=uri
         )
         mock_loseConnection = self.patch_autospec(protocol, "loseConnection")
@@ -342,7 +343,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
     def test_authenticate_calls_loseConnection_if_csrftoken_is_missing(self):
         user, session_id = yield deferToDatabase(self.get_user_and_session_id)
         uri = self.make_ws_uri(csrftoken=None)
-        protocol, factory = self.make_protocol(
+        protocol, _ = self.make_protocol(
             patch_authenticate=False, transport_uri=uri
         )
         mock_loseConnection = self.patch_autospec(protocol, "loseConnection")
@@ -357,7 +358,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         )
 
     def test_dataReceived_calls_loseConnection_if_json_error(self):
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         mock_loseConnection = self.patch_autospec(protocol, "loseConnection")
         self.expectThat(protocol.dataReceived(b"{{{{"), Is(""))
         self.expectThat(
@@ -368,7 +369,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         )
 
     def test_dataReceived_adds_message_to_queue(self):
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         self.patch_autospec(protocol, "processMessages")
         message = {"type": MSG_TYPE.REQUEST}
         self.expectThat(
@@ -378,7 +379,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         self.expectThat(protocol.messages, Equals(deque([message])))
 
     def test_dataReceived_calls_processMessages(self):
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         mock_processMessages = self.patch_autospec(protocol, "processMessages")
         message = {"type": MSG_TYPE.REQUEST}
         self.expectThat(
@@ -398,7 +399,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         self.assertEqual([], protocol.processMessages())
 
     def test_processMessages_process_all_messages_in_the_queue(self):
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         protocol.user = maas_factory.make_User()
         self.patch_autospec(
             protocol, "handleRequest"
@@ -411,7 +412,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         self.assertEqual(messages, protocol.processMessages())
 
     def test_processMessages_calls_loseConnection_if_missing_type_field(self):
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         protocol.user = maas_factory.make_User()
         mock_loseConnection = self.patch_autospec(protocol, "loseConnection")
         self.patch_autospec(
@@ -432,7 +433,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         )
 
     def test_processMessages_calls_loseConnection_if_type_not_request(self):
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         protocol.user = maas_factory.make_User()
         mock_loseConnection = self.patch_autospec(protocol, "loseConnection")
         self.patch_autospec(
@@ -452,7 +453,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         )
 
     def test_processMessages_stops_processing_msgs_handleRequest_fails(self):
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         protocol.user = maas_factory.make_User()
         self.patch_autospec(protocol, "handleRequest").return_value = None
         messages = [
@@ -463,7 +464,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         self.expectThat([messages[0]], Equals(protocol.processMessages()))
 
     def test_processMessages_calls_handleRequest_with_message(self):
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         protocol.user = maas_factory.make_User()
         mock_handleRequest = self.patch_autospec(protocol, "handleRequest")
         mock_handleRequest.return_value = NOT_DONE_YET
@@ -474,66 +475,66 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
             mock_handleRequest, MockCalledOnceWith(message, MSG_TYPE.REQUEST)
         )
 
+    @wait_for_reactor
+    @inlineCallbacks
     def test_handleRequest_calls_loseConnection_if_missing_request_id(self):
-        protocol, factory = self.make_protocol()
-        protocol.user = maas_factory.make_User()
+        protocol, _ = self.make_protocol()
+        protocol.user = yield deferToDatabase(self.make_user)
         mock_loseConnection = self.patch_autospec(protocol, "loseConnection")
         message = {"type": MSG_TYPE.REQUEST}
-        self.expectThat(protocol.handleRequest(message), Is(None))
-        self.expectThat(
-            mock_loseConnection,
-            MockCalledOnceWith(
-                STATUSES.PROTOCOL_ERROR,
-                "Missing request_id field in the received message.",
-            ),
+        result = yield protocol.handleRequest(message)
+        self.assertIsNone(result)
+        mock_loseConnection.assert_called_once_with(
+            STATUSES.PROTOCOL_ERROR,
+            "Missing request_id field in the received message.",
         )
 
+    @wait_for_reactor
+    @inlineCallbacks
     def test_handleRequest_calls_loseConnection_if_missing_method(self):
-        protocol, factory = self.make_protocol()
-        protocol.user = maas_factory.make_User()
+        protocol, _ = self.make_protocol()
+        protocol.user = yield deferToDatabase(self.make_user)
         mock_loseConnection = self.patch_autospec(protocol, "loseConnection")
         message = {"type": MSG_TYPE.REQUEST, "request_id": 1}
-        self.expectThat(protocol.handleRequest(message), Is(None))
-        self.expectThat(
-            mock_loseConnection,
-            MockCalledOnceWith(
-                STATUSES.PROTOCOL_ERROR,
-                "Missing method field in the received message.",
-            ),
+        result = yield protocol.handleRequest(message)
+        self.assertIsNone(result)
+        mock_loseConnection.assert_called_once_with(
+            STATUSES.PROTOCOL_ERROR,
+            "Missing method field in the received message.",
         )
 
+    @wait_for_reactor
+    @inlineCallbacks
     def test_handleRequest_calls_loseConnection_if_bad_method(self):
-        protocol, factory = self.make_protocol()
-        protocol.user = maas_factory.make_User()
+        protocol, _ = self.make_protocol()
+        protocol.user = yield deferToDatabase(self.make_user)
         mock_loseConnection = self.patch_autospec(protocol, "loseConnection")
         message = {
             "type": MSG_TYPE.REQUEST,
             "request_id": 1,
             "method": "nodes",
         }
-        self.expectThat(protocol.handleRequest(message), Is(None))
-        self.expectThat(
-            mock_loseConnection,
-            MockCalledOnceWith(
-                STATUSES.PROTOCOL_ERROR, "Invalid method formatting."
-            ),
+        result = yield protocol.handleRequest(message)
+        self.assertIsNone(result)
+        mock_loseConnection.assert_called_once_with(
+            STATUSES.PROTOCOL_ERROR, "Invalid method formatting."
         )
 
+    @wait_for_reactor
+    @inlineCallbacks
     def test_handleRequest_calls_loseConnection_if_unknown_handler(self):
-        protocol, factory = self.make_protocol()
-        protocol.user = maas_factory.make_User()
+        protocol, _ = self.make_protocol()
+        protocol.user = yield deferToDatabase(self.make_user)
         mock_loseConnection = self.patch_autospec(protocol, "loseConnection")
         message = {
             "type": MSG_TYPE.REQUEST,
             "request_id": 1,
             "method": "unknown.list",
         }
-        self.expectThat(protocol.handleRequest(message), Is(None))
-        self.expectThat(
-            mock_loseConnection,
-            MockCalledOnceWith(
-                STATUSES.PROTOCOL_ERROR, "Handler unknown does not exist."
-            ),
+        result = yield protocol.handleRequest(message)
+        self.assertIsNone(result)
+        mock_loseConnection.assert_called_once_with(
+            STATUSES.PROTOCOL_ERROR, "Handler unknown does not exist."
         )
 
     @synchronous
@@ -609,7 +610,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         # Need to delete the node as the transaction is committed
         self.addCleanup(self.clean_node, node)
 
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         protocol.user = MagicMock()
         message = {
             "type": MSG_TYPE.REQUEST,
@@ -631,7 +632,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         node = yield deferToDatabase(self.make_node)
         # Need to delete the node as the transaction is committed
         self.addCleanup(self.clean_node, node)
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         protocol.user = MagicMock()
 
         error_dict = {"error": ["bad"]}
@@ -659,7 +660,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         node = yield deferToDatabase(self.make_node)
         # Need to delete the node as the transaction is committed
         self.addCleanup(self.clean_node, node)
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         protocol.user = MagicMock()
 
         self.patch(Handler, "execute").return_value = fail(
@@ -686,7 +687,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         node = yield deferToDatabase(self.make_node)
         # Need to delete the node as the transaction is committed
         self.addCleanup(self.clean_node, node)
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         protocol.user = MagicMock()
 
         self.patch(Handler, "execute").return_value = fail(
@@ -710,7 +711,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
     @wait_for_reactor
     @inlineCallbacks
     def test_handleRequest_sends_ping_reply_on_ping(self):
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         protocol.user = MagicMock()
 
         request_id = random.choice([1, 3, 5, 7, 9])
@@ -726,7 +727,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         self.expectThat(sent_obj["result"], Equals(seq + 1))
 
     def test_sendNotify_sends_correct_json(self):
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         name = maas_factory.make_name("name")
         action = maas_factory.make_name("action")
         data = maas_factory.make_name("data")
diff --git a/src/maasserver/websockets/websockets.py b/src/maasserver/websockets/websockets.py
index 65c7038..b08673b 100644
--- a/src/maasserver/websockets/websockets.py
+++ b/src/maasserver/websockets/websockets.py
@@ -533,12 +533,12 @@ class WebSocketsProtocolWrapper(WebSocketsProtocol):
         for chunk in data:
             self.write(chunk)
 
-    def loseConnection(self):
+    def loseConnection(self, *args, **kwargs):
         """
         Try to lose the connection gracefully when closing by sending a close
         frame.
         """
-        self._receiver._transport.loseConnection()
+        self._receiver._transport.loseConnection(*args, **kwargs)
 
     def __getattr__(self, name):
         """

Follow ups