← Back to team overview

sts-sponsors team mailing list archive

[Merge] ~ack/maas:more-ws-cleanups into maas:master

 

Alberto Donato has proposed merging ~ack/maas:more-ws-cleanups into maas:master.

Commit message:
Some more cleanups to the websocket code



Requested reviews:
  MAAS Maintainers (maas-maintainers)

For more details, see:
https://code.launchpad.net/~ack/maas/+git/maas/+merge/438604
-- 
Your team MAAS Maintainers is requested to review the proposed merge of ~ack/maas:more-ws-cleanups into maas:master.
diff --git a/src/maasserver/websockets/protocol.py b/src/maasserver/websockets/protocol.py
index bd6add8..6cdaad7 100644
--- a/src/maasserver/websockets/protocol.py
+++ b/src/maasserver/websockets/protocol.py
@@ -75,6 +75,7 @@ class WebSocketProtocol(Protocol):
     def __init__(self):
         self.messages = deque()
         self.user = None
+        self.session = None
         self.request = None
         self.cache = {}
         self.sequence_number = 0
@@ -88,13 +89,12 @@ class WebSocketProtocol(Protocol):
         # from an authenticated user.
 
         cookies = self.transport.cookies.decode("ascii")
-        user = yield self.authenticate(
+        authenticated = yield self.authenticate(
             get_cookie(cookies, "sessionid"), get_cookie(cookies, "csrftoken")
         )
-        if not user:
+        if not authenticated:
             return
 
-        self.user = user
         # XXX newell 2018-10-17 bug=1798479:
         # Check that 'SERVER_NAME' and 'SERVER_PORT' are set.
         # 'SERVER_NAME' and 'SERVER_PORT' are required so
@@ -109,6 +109,7 @@ class WebSocketProtocol(Protocol):
         # Create the request for the handlers for this connection.
         self.request = HttpRequest()
         self.request.user = self.user
+        self.request.session = self.session
         self.request.META.update(
             {
                 "HTTP_USER_AGENT": self.transport.user_agent,
@@ -152,17 +153,17 @@ class WebSocketProtocol(Protocol):
 
     @synchronous
     @transactional
-    def getUserFromSessionId(self, session_id):
-        """Return the user from `session_id`."""
+    def get_user_and_session(self, session_id):
+        """Return the user and its session from the session ID."""
         session_engine = self.factory.getSessionEngine()
-        session_wrapper = session_engine.SessionStore(session_key=session_id)
-        user_id = session_wrapper.get(SESSION_KEY)
-        backend = session_wrapper.get(BACKEND_SESSION_KEY)
+        session = session_engine.SessionStore(session_key=session_id)
+        backend = session.get(BACKEND_SESSION_KEY)
         if backend is None:
             return None
         auth_backend = load_backend(backend)
+        user_id = session.get(SESSION_KEY)
         if user_id is not None and auth_backend is not None:
-            return auth_backend.get_user(user_id)
+            return auth_backend.get_user(user_id), session
 
         return None
 
@@ -173,8 +174,7 @@ class WebSocketProtocol(Protocol):
         - Check that the CSRF token is valid.
         - Authenticate the user using the session id.
 
-        This returns the authenticated user or ``None``. The latter means that
-        the connection is being dropped, and that processing should cease.
+        It returns whether authentication succeeded.
         """
         # Check the CSRF token.
         tokens = parse_qs(urlparse(self.transport.uri).query).get(b"csrftoken")
@@ -185,24 +185,31 @@ class WebSocketProtocol(Protocol):
         if tokens is None or csrftoken not in tokens:
             # No csrftoken in the request or the token does not match.
             self.loseConnection(STATUSES.PROTOCOL_ERROR, "Invalid CSRF token.")
-            return None
+            returnValue(False)
+            return
 
         try:
-            user = yield deferToDatabase(self.getUserFromSessionId, session_id)
+            result = yield deferToDatabase(
+                self.get_user_and_session, session_id
+            )
         except Exception as error:
             self.loseConnection(
                 STATUSES.PROTOCOL_ERROR, f"Error authenticating user: {error}"
             )
-            returnValue(None)
+            returnValue(False)
             return
+        if result:
+            self.user, self.session = result
+        else:
+            self.user = self.session = None
 
-        if user is None or user.id is None:
+        if self.user is None or self.user.id is None:
             self.loseConnection(
                 STATUSES.PROTOCOL_ERROR, "Failed to authenticate user."
             )
-            returnValue(None)
+            returnValue(False)
         else:
-            returnValue(user)
+            returnValue(True)
 
     def dataReceived(self, data):
         """Received message from client and queue up the message."""
diff --git a/src/maasserver/websockets/tests/test_protocol.py b/src/maasserver/websockets/tests/test_protocol.py
index 02348c7..92dda5a 100644
--- a/src/maasserver/websockets/tests/test_protocol.py
+++ b/src/maasserver/websockets/tests/test_protocol.py
@@ -133,15 +133,14 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         self.assertEqual(protocol.request.META["SERVER_NAME"], "localhost")
         self.assertEqual(protocol.request.META["SERVER_PORT"], 5248)
 
-    def test_connectionMade_sets_user_and_processes_messages(self):
+    def test_connectionMade_processes_messages(self):
         protocol, factory = self.make_protocol(patch_authenticate=False)
         self.patch_autospec(protocol, "authenticate")
         self.patch_autospec(protocol, "processMessages")
-        protocol.authenticate.return_value = defer.succeed(sentinel.user)
+        protocol.authenticate.return_value = defer.succeed(True)
         protocol.connectionMade()
         self.addCleanup(protocol.connectionLost, "")
-        self.assertIs(protocol.user, sentinel.user)
-        self.assertThat(protocol.processMessages, MockCalledOnceWith())
+        protocol.processMessages.assert_called_once_with()
 
     def test_connectionMade_adds_self_to_factory_if_auth_succeeds(self):
         protocol, factory = self.make_protocol()
@@ -249,19 +248,20 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
 
     @wait_for_reactor
     @inlineCallbacks
-    def test_getUserFromSessionId_returns_User(self):
+    def test_get_user_and_session_returns_user_and_session(self):
         user, session_id = yield deferToDatabase(self.get_user_and_session_id)
         protocol, _ = self.make_protocol()
-        protocol_user = yield deferToDatabase(
-            lambda: protocol.getUserFromSessionId(session_id)
+        protocol_user, session = yield deferToDatabase(
+            lambda: protocol.get_user_and_session(session_id)
         )
         self.assertEqual(user, protocol_user)
+        self.assertEqual(session.session_key, session_id)
 
-    def test_getUserFromSessionId_returns_None_for_invalid_key(self):
+    def test_get_user_and_session_returns_None_for_invalid_key(self):
         self.client.login(user=maas_factory.make_User())
         session_id = maas_factory.make_name("sessionid")
         protocol, _ = self.make_protocol()
-        self.assertIsNone(protocol.getUserFromSessionId(session_id))
+        self.assertIsNone(protocol.get_user_and_session(session_id))
 
     @wait_for_reactor
     @inlineCallbacks
@@ -272,19 +272,16 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
             patch_authenticate=False, transport_uri=uri
         )
         mock_loseConnection = self.patch_autospec(protocol, "loseConnection")
-        mock_getUserFromSessionId = self.patch_autospec(
-            protocol, "getUserFromSessionId"
+        mock_get_user_and_session = self.patch_autospec(
+            protocol, "get_user_and_session"
         )
-        mock_getUserFromSessionId.return_value = None
+        mock_get_user_and_session.return_value = None
 
         yield protocol.authenticate(
             maas_factory.make_name("sessionid"), csrftoken
         )
-        self.expectThat(
-            mock_loseConnection,
-            MockCalledOnceWith(
-                STATUSES.PROTOCOL_ERROR, "Failed to authenticate user."
-            ),
+        mock_loseConnection.assert_called_once_with(
+            STATUSES.PROTOCOL_ERROR, "Failed to authenticate user."
         )
 
     @wait_for_reactor
@@ -296,28 +293,25 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
             patch_authenticate=False, transport_uri=uri
         )
         mock_loseConnection = self.patch_autospec(protocol, "loseConnection")
-        mock_getUserFromSessionId = self.patch_autospec(
-            protocol, "getUserFromSessionId"
+        mock_get_user_and_session = self.patch_autospec(
+            protocol, "get_user_and_session"
         )
-        mock_getUserFromSessionId.side_effect = maas_factory.make_exception(
+        mock_get_user_and_session.side_effect = maas_factory.make_exception(
             "unknown reason"
         )
 
         yield protocol.authenticate(
             maas_factory.make_name("sessionid"), csrftoken
         )
-        self.expectThat(
-            mock_loseConnection,
-            MockCalledOnceWith(
-                STATUSES.PROTOCOL_ERROR,
-                "Error authenticating user: unknown reason",
-            ),
+        mock_loseConnection.assert_called_once_with(
+            STATUSES.PROTOCOL_ERROR,
+            "Error authenticating user: unknown reason",
         )
 
     @wait_for_reactor
     @inlineCallbacks
     def test_authenticate_calls_loseConnection_if_invalid_csrftoken(self):
-        user, session_id = yield deferToDatabase(self.get_user_and_session_id)
+        _, session_id = yield deferToDatabase(self.get_user_and_session_id)
         csrftoken = maas_factory.make_name("csrftoken")
         uri = self.make_ws_uri(csrftoken)
         protocol, _ = self.make_protocol(
@@ -755,8 +749,12 @@ class MakeProtocolFactoryMixin:
         protocol.transport.cookies = b""
         if user is None:
             user = maas_factory.make_User()
-        mock_authenticate = self.patch(protocol, "authenticate")
-        mock_authenticate.return_value = defer.succeed(user)
+
+        def authenticate(*args):
+            protocol.user = user
+            return defer.succeed(True)
+
+        self.patch(protocol, "authenticate", authenticate)
         protocol.connectionMade()
         self.addCleanup(lambda: protocol.connectionLost(""))
         return protocol, factory

Follow ups