← Back to team overview

sts-sponsors team mailing list archive

[Merge] ~ack/maas:websocket-check-expired-sessions into maas:master

 

Alberto Donato has proposed merging ~ack/maas:websocket-check-expired-sessions into maas:master.

Commit message:
periodically check and disconnect expired websocket sessions



Requested reviews:
  MAAS Maintainers (maas-maintainers)

For more details, see:
https://code.launchpad.net/~ack/maas/+git/maas/+merge/438669
-- 
Your team MAAS Committers is subscribed to branch maas:master.
diff --git a/src/maasserver/websockets/protocol.py b/src/maasserver/websockets/protocol.py
index 6cdaad7..5eda482 100644
--- a/src/maasserver/websockets/protocol.py
+++ b/src/maasserver/websockets/protocol.py
@@ -15,8 +15,10 @@ from django.conf import settings
 from django.contrib.auth import BACKEND_SESSION_KEY, load_backend, SESSION_KEY
 from django.core.exceptions import ValidationError
 from django.http import HttpRequest
+from django.utils import timezone
 from twisted.internet.defer import fail, inlineCallbacks, returnValue, succeed
 from twisted.internet.protocol import Factory, Protocol
+from twisted.internet.task import LoopingCall
 from twisted.python.modules import getModule
 from twisted.web.server import NOT_DONE_YET
 
@@ -380,16 +382,23 @@ class WebSocketFactory(Factory):
         self.handlers = {}
         self.clients = []
         self.listener = listener
+        self.session_checker = LoopingCall(self._check_sessions)
+        self.session_checker_done = None
         self.cacheHandlers()
         self.registerNotifiers()
 
     def startFactory(self):
-        """Register for RPC events."""
+        """Register for RPC events and start checking for expired sessions."""
         self.registerRPCEvents()
+        self.session_checker_done = self.session_checker.start(5, now=True)
 
     def stopFactory(self):
-        """Unregister RPC events."""
+        """Unregister RPC events and stop checking for expired sessions."""
         self.unregisterRPCEvents()
+        # LoopingCall.stop() asserts that the loop is running, which is not
+        # the case if it was never started or a failed check stopped it.
+        if self.session_checker.running:
+            self.session_checker.stop()
 
     def getSessionEngine(self):
         """Returns the session engine being used by Django.
@@ -491,3 +495,55 @@ class WebSocketFactory(Factory):
             )
         else:
             return fail("Unable to get the 'controller' handler.")
+
+    @inlineCallbacks
+    def _check_sessions(self):
+        """Disconnect websocket clients whose session is no longer valid.
+
+        Run periodically by `session_checker`. A session is no longer valid
+        when it has expired or no longer exists in the database. Several
+        connections can share one session key, so clients are grouped per
+        key and every client of an invalid session is disconnected.
+
+        Query errors are logged instead of propagated: a failure would stop
+        the `LoopingCall`, so no further checks would run.
+        """
+        clients_by_session = {}
+        for client in self.clients:
+            if client.session is not None:
+                clients_by_session.setdefault(
+                    client.session.session_key, []
+                ).append(client)
+        if not clients_by_session:
+            # No authenticated clients: skip the database query.
+            return
+
+        def get_valid_sessions(session_keys):
+            session_engine = self.getSessionEngine()
+            Session = session_engine.SessionStore.get_model_class()
+            return set(
+                Session.objects.filter(
+                    session_key__in=session_keys,
+                    expire_date__gt=timezone.now(),
+                ).values_list("session_key", flat=True)
+            )
+
+        try:
+            valid_session_keys = yield deferToDatabase(
+                get_valid_sessions, set(clients_by_session)
+            )
+        except Exception:
+            from twisted.python import log as twisted_log
+
+            twisted_log.err(None, "Failed to check websocket sessions.")
+            return
+
+        for session_key, clients in clients_by_session.items():
+            if session_key in valid_session_keys:
+                continue
+            for client in clients:
+                # The client may have gone away while the query was running.
+                if client in self.clients:
+                    client.loseConnection(
+                        STATUSES.GOING_AWAY, "Session expired"
+                    )
diff --git a/src/maasserver/websockets/tests/test_protocol.py b/src/maasserver/websockets/tests/test_protocol.py
index 92dda5a..3ab1e23 100644
--- a/src/maasserver/websockets/tests/test_protocol.py
+++ b/src/maasserver/websockets/tests/test_protocol.py
@@ -3,6 +3,7 @@
 
 
 from collections import deque
+from datetime import datetime, timedelta
 import json
 import random
 from unittest.mock import MagicMock, sentinel
@@ -93,7 +94,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         return json.loads(call[0][0].decode("ascii"))
 
     def test_connectionMade_sets_the_request(self):
-        protocol, factory = self.make_protocol(patch_authenticate=False)
+        protocol, _ = self.make_protocol(patch_authenticate=False)
         self.patch_autospec(protocol, "authenticate")
         # Be sure the request field is populated by the time that
         # processMessages() is called.
@@ -122,7 +123,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         )
 
     def test_connectionMade_sets_the_request_default_server_name_port(self):
-        protocol, factory = self.make_protocol(patch_authenticate=False)
+        protocol, _ = self.make_protocol(patch_authenticate=False)
         self.patch_autospec(protocol, "authenticate")
         self.patch_autospec(protocol, "processMessages")
         mock_splithost = self.patch_autospec(protocol_module, "splithost")
@@ -134,7 +135,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         self.assertEqual(protocol.request.META["SERVER_PORT"], 5248)
 
     def test_connectionMade_processes_messages(self):
-        protocol, factory = self.make_protocol(patch_authenticate=False)
+        protocol, _ = self.make_protocol(patch_authenticate=False)
         self.patch_autospec(protocol, "authenticate")
         self.patch_autospec(protocol, "processMessages")
         protocol.authenticate.return_value = defer.succeed(True)
@@ -161,7 +162,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         self.assertNotIn(protocol, factory.clients)
 
     def test_connectionMade_extracts_sessionid_and_csrftoken(self):
-        protocol, factory = self.make_protocol(patch_authenticate=False)
+        protocol, _ = self.make_protocol(patch_authenticate=False)
         sessionid = maas_factory.make_name("sessionid")
         csrftoken = maas_factory.make_name("csrftoken")
         cookies = {
@@ -194,7 +195,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
         self.assertEqual([], factory.clients)
 
     def test_loseConnection_writes_to_log(self):
-        protocol, factory = self.make_protocol()
+        protocol, _ = self.make_protocol()
         status = random.randint(1000, 1010)
         reason = maas_factory.make_name("reason")
         with TwistedLoggerFixture() as logger:
@@ -331,7 +332,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
     @wait_for_reactor
     @inlineCallbacks
     def test_authenticate_calls_loseConnection_if_csrftoken_is_missing(self):
-        user, session_id = yield deferToDatabase(self.get_user_and_session_id)
+        _, session_id = yield deferToDatabase(self.get_user_and_session_id)
         uri = self.make_ws_uri(csrftoken=None)
         protocol, _ = self.make_protocol(
             patch_authenticate=False, transport_uri=uri
@@ -981,7 +982,7 @@ class TestWebSocketFactoryTransactional(
         controller = yield deferToDatabase(
             transactional(maas_factory.make_RackController)
         )
-        protocol, factory = self.make_protocol_with_factory(user=user)
+        _, factory = self.make_protocol_with_factory(user=user)
         mock_onNotify = self.patch(factory, "onNotify")
         controller_handler = MagicMock()
         factory.handlers["controller"] = controller_handler
@@ -995,3 +996,66 @@ class TestWebSocketFactoryTransactional(
                 controller.system_id,
             ),
         )
+
+    @wait_for_reactor
+    @inlineCallbacks
+    def test_check_sessions(self):
+        """Clients with an expired session are disconnected, others kept."""
+        factory = self.make_factory()
+
+        session_engine = factory.getSessionEngine()
+        Session = session_engine.SessionStore.get_model_class()
+        key1 = maas_factory.make_string()
+        key2 = maas_factory.make_string()
+
+        def make_sessions():
+            # NOTE(review): _check_sessions() compares against
+            # django.utils.timezone.now(); a naive utcnow() presumably only
+            # matches it when Django's TIME_ZONE is UTC. Consider using
+            # timezone.now() here.
+            now = datetime.utcnow()
+            delta = timedelta(hours=1)
+            # The first session is expired, the second one is still valid.
+            return (
+                Session.objects.create(
+                    session_key=key1, expire_date=now - delta
+                ),
+                Session.objects.create(
+                    session_key=key2, expire_date=now + delta
+                ),
+            )
+
+        session1, session2 = yield deferToDatabase(make_sessions)
+
+        def make_protocol_with_session(session):
+            protocol = factory.buildProtocol(None)
+            protocol.transport = MagicMock()
+            protocol.transport.cookies = b""
+
+            def authenticate(*args):
+                protocol.session = session
+                return defer.succeed(True)
+
+            self.patch(protocol, "authenticate", authenticate)
+            self.patch(protocol, "loseConnection")
+            return protocol
+
+        proto1 = make_protocol_with_session(session1)
+        proto2 = make_protocol_with_session(session2)
+
+        yield proto1.connectionMade()
+        self.addCleanup(lambda: proto1.connectionLost(""))
+        yield proto2.connectionMade()
+        self.addCleanup(lambda: proto2.connectionLost(""))
+
+        # startFactory() returns None, so there is nothing to yield on.
+        factory.startFactory()
+        factory.stopFactory()
+        # The loop's deferred fires once the check that startFactory() ran
+        # immediately (now=True) has completed.
+        yield factory.session_checker_done
+        # Only the client with the expired session is disconnected.
+        proto1.loseConnection.assert_called_once_with(
+            STATUSES.GOING_AWAY, "Session expired"
+        )
+        proto2.loseConnection.assert_not_called()

Follow ups