sts-sponsors team mailing list archive
-
sts-sponsors team
-
Mailing list archive
-
Message #05825
[Merge] ~ack/maas:websocket-check-expired-sessions into maas:master
Alberto Donato has proposed merging ~ack/maas:websocket-check-expired-sessions into maas:master.
Commit message:
periodically check and disconnect expired websocket sessions
Requested reviews:
MAAS Maintainers (maas-maintainers)
For more details, see:
https://code.launchpad.net/~ack/maas/+git/maas/+merge/438669
--
Your team MAAS Committers is subscribed to branch maas:master.
diff --git a/src/maasserver/websockets/protocol.py b/src/maasserver/websockets/protocol.py
index 6cdaad7..5eda482 100644
--- a/src/maasserver/websockets/protocol.py
+++ b/src/maasserver/websockets/protocol.py
@@ -15,8 +15,10 @@ from django.conf import settings
from django.contrib.auth import BACKEND_SESSION_KEY, load_backend, SESSION_KEY
from django.core.exceptions import ValidationError
from django.http import HttpRequest
+from django.utils import timezone
from twisted.internet.defer import fail, inlineCallbacks, returnValue, succeed
from twisted.internet.protocol import Factory, Protocol
+from twisted.internet.task import LoopingCall
from twisted.python.modules import getModule
from twisted.web.server import NOT_DONE_YET
@@ -380,16 +382,18 @@ class WebSocketFactory(Factory):
self.handlers = {}
self.clients = []
self.listener = listener
+ self.session_checker = LoopingCall(self._check_sessions)
+ self.session_checker_done = None
self.cacheHandlers()
self.registerNotifiers()
def startFactory(self):
- """Register for RPC events."""
self.registerRPCEvents()
+ self.session_checker_done = self.session_checker.start(5, now=True)
def stopFactory(self):
- """Unregister RPC events."""
self.unregisterRPCEvents()
+ self.session_checker.stop()
def getSessionEngine(self):
"""Returns the session engine being used by Django.
@@ -491,3 +495,48 @@ class WebSocketFactory(Factory):
)
else:
return fail("Unable to get the 'controller' handler.")
+
+    @inlineCallbacks
+    def _check_sessions(self):
+        """Disconnect all clients whose Django session is no longer valid.
+
+        Called periodically by the ``session_checker`` looping call.
+        Clients that have no session set are skipped.  A session counts as
+        valid when the session model has a row for its key with an
+        ``expire_date`` later than the current time.
+        """
+        # Several connections (e.g. multiple browser tabs) can share the
+        # same session key, so keep a list of clients per key.  A plain
+        # key -> client mapping would retain only one of them and leave
+        # the others connected after the session has expired.
+        clients_by_session = {}
+        for client in self.clients:
+            if client.session is not None:
+                clients_by_session.setdefault(
+                    client.session.session_key, []
+                ).append(client)
+        if not clients_by_session:
+            # no client has a session, skip the database round trip
+            returnValue(None)
+        client_session_keys = set(clients_by_session)
+
+        def get_valid_sessions(session_keys):
+            session_engine = self.getSessionEngine()
+            Session = session_engine.SessionStore.get_model_class()
+            return set(
+                Session.objects.filter(
+                    session_key__in=session_keys,
+                    expire_date__gt=timezone.now(),
+                ).values_list("session_key", flat=True)
+            )
+
+        valid_session_keys = yield deferToDatabase(
+            get_valid_sessions,
+            client_session_keys,
+        )
+        # drop connections for expired sessions
+        for session_key in client_session_keys - valid_session_keys:
+            for client in clients_by_session[session_key]:
+                client.loseConnection(STATUSES.GOING_AWAY, "Session expired")
+
+        returnValue(None)
diff --git a/src/maasserver/websockets/tests/test_protocol.py b/src/maasserver/websockets/tests/test_protocol.py
index 92dda5a..3ab1e23 100644
--- a/src/maasserver/websockets/tests/test_protocol.py
+++ b/src/maasserver/websockets/tests/test_protocol.py
@@ -3,6 +3,7 @@
from collections import deque
+from datetime import datetime, timedelta
import json
import random
from unittest.mock import MagicMock, sentinel
@@ -93,7 +94,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
return json.loads(call[0][0].decode("ascii"))
def test_connectionMade_sets_the_request(self):
- protocol, factory = self.make_protocol(patch_authenticate=False)
+ protocol, _ = self.make_protocol(patch_authenticate=False)
self.patch_autospec(protocol, "authenticate")
# Be sure the request field is populated by the time that
# processMessages() is called.
@@ -122,7 +123,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
)
def test_connectionMade_sets_the_request_default_server_name_port(self):
- protocol, factory = self.make_protocol(patch_authenticate=False)
+ protocol, _ = self.make_protocol(patch_authenticate=False)
self.patch_autospec(protocol, "authenticate")
self.patch_autospec(protocol, "processMessages")
mock_splithost = self.patch_autospec(protocol_module, "splithost")
@@ -134,7 +135,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
self.assertEqual(protocol.request.META["SERVER_PORT"], 5248)
def test_connectionMade_processes_messages(self):
- protocol, factory = self.make_protocol(patch_authenticate=False)
+ protocol, _ = self.make_protocol(patch_authenticate=False)
self.patch_autospec(protocol, "authenticate")
self.patch_autospec(protocol, "processMessages")
protocol.authenticate.return_value = defer.succeed(True)
@@ -161,7 +162,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
self.assertNotIn(protocol, factory.clients)
def test_connectionMade_extracts_sessionid_and_csrftoken(self):
- protocol, factory = self.make_protocol(patch_authenticate=False)
+ protocol, _ = self.make_protocol(patch_authenticate=False)
sessionid = maas_factory.make_name("sessionid")
csrftoken = maas_factory.make_name("csrftoken")
cookies = {
@@ -194,7 +195,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
self.assertEqual([], factory.clients)
def test_loseConnection_writes_to_log(self):
- protocol, factory = self.make_protocol()
+ protocol, _ = self.make_protocol()
status = random.randint(1000, 1010)
reason = maas_factory.make_name("reason")
with TwistedLoggerFixture() as logger:
@@ -331,7 +332,7 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
@wait_for_reactor
@inlineCallbacks
def test_authenticate_calls_loseConnection_if_csrftoken_is_missing(self):
- user, session_id = yield deferToDatabase(self.get_user_and_session_id)
+ _, session_id = yield deferToDatabase(self.get_user_and_session_id)
uri = self.make_ws_uri(csrftoken=None)
protocol, _ = self.make_protocol(
patch_authenticate=False, transport_uri=uri
@@ -981,7 +982,7 @@ class TestWebSocketFactoryTransactional(
controller = yield deferToDatabase(
transactional(maas_factory.make_RackController)
)
- protocol, factory = self.make_protocol_with_factory(user=user)
+ _, factory = self.make_protocol_with_factory(user=user)
mock_onNotify = self.patch(factory, "onNotify")
controller_handler = MagicMock()
factory.handlers["controller"] = controller_handler
@@ -995,3 +996,64 @@ class TestWebSocketFactoryTransactional(
controller.system_id,
),
)
+
+ @wait_for_reactor
+ @inlineCallbacks
+ def test_check_sessions(self):
+ """Only clients whose session has expired get disconnected."""
+ factory = self.make_factory()
+
+ session_engine = factory.getSessionEngine()
+ Session = session_engine.SessionStore.get_model_class()
+ key1 = maas_factory.make_string()
+ key2 = maas_factory.make_string()
+
+ def make_sessions():
+ now = datetime.utcnow()
+ delta = timedelta(hours=1)
+ # first session is expired, second one is valid
+ return (
+ Session.objects.create(
+ session_key=key1, expire_date=now - delta
+ ),
+ Session.objects.create(
+ session_key=key2, expire_date=now + delta
+ ),
+ )
+
+ session1, session2 = yield deferToDatabase(make_sessions)
+
+ # Build a protocol on the factory with a mocked transport; its
+ # authenticate() is replaced by a stub that attaches the given session.
+ def make_protocol_with_session(session):
+ protocol = factory.buildProtocol(None)
+ protocol.transport = MagicMock()
+ protocol.transport.cookies = b""
+
+ def authenticate(*args):
+ protocol.session = session
+ return defer.succeed(True)
+
+ self.patch(protocol, "authenticate", authenticate)
+ self.patch(protocol, "loseConnection")
+ return protocol
+
+ proto1 = make_protocol_with_session(session1)
+ proto2 = make_protocol_with_session(session2)
+
+ yield proto1.connectionMade()
+ self.addCleanup(lambda: proto1.connectionLost(""))
+ yield proto2.connectionMade()
+ self.addCleanup(lambda: proto2.connectionLost(""))
+
+ # startFactory() starts the session checker with now=True, so the
+ # first check is triggered right away.
+ yield factory.startFactory()
+ factory.stopFactory()
+ # wait until it's stopped, sessions are checked
+ yield factory.session_checker_done
+ # the first client gets disconnected
+ proto1.loseConnection.assert_called_once_with(
+ STATUSES.GOING_AWAY, "Session expired"
+ )
+ proto2.loseConnection.assert_not_called()
Follow ups