← Back to team overview

sts-sponsors team mailing list archive

Re: [Merge] ~ack/maas:websocket-check-expired-sessions into maas:master

 


Diff comments:

> diff --git a/src/maasserver/websockets/protocol.py b/src/maasserver/websockets/protocol.py
> index 6cdaad7..5eda482 100644
> --- a/src/maasserver/websockets/protocol.py
> +++ b/src/maasserver/websockets/protocol.py
> @@ -380,16 +382,18 @@ class WebSocketFactory(Factory):
>          self.handlers = {}
>          self.clients = []
>          self.listener = listener
> +        self.session_checker = LoopingCall(self._check_sessions)
> +        self.session_checker_done = None
>          self.cacheHandlers()
>          self.registerNotifiers()
>  
>      def startFactory(self):
> -        """Register for RPC events."""
>          self.registerRPCEvents()
> +        self.session_checker_done = self.session_checker.start(5, now=True)
>  
>      def stopFactory(self):
> -        """Unregister RPC events."""
>          self.unregisterRPCEvents()
> +        self.session_checker.stop()

What about error handling? How will we ensure that the session checker stops if unregisterRPCEvents fails?

>  
>      def getSessionEngine(self):
>          """Returns the session engine being used by Django.
> @@ -491,3 +495,35 @@ class WebSocketFactory(Factory):
>              )
>          else:
>              return fail("Unable to get the 'controller' handler.")
> +
> +    @inlineCallbacks
> +    def _check_sessions(self):
> +        client_sessions = {
> +            client.session.session_key: client
> +            for client in self.clients
> +            if client.session is not None
> +        }
> +        client_session_keys = set(client_sessions)
> +
> +        def get_valid_sessions(session_keys):
> +            session_engine = self.getSessionEngine()
> +            Session = session_engine.SessionStore.get_model_class()
> +            return set(
> +                Session.objects.filter(
> +                    session_key__in=session_keys,
> +                    expire_date__gt=timezone.now(),
> +                ).values_list("session_key", flat=True)
> +            )
> +
> +        valid_session_keys = yield deferToDatabase(
> +            get_valid_sessions,
> +            client_session_keys,
> +        )
> +        # drop connections for expired sessions
> +        for session_key in client_session_keys - valid_session_keys:
> +            client = client_sessions.get(session_key)
> +
> +            if client:

Use the walrus operator to one-line the lookup and the check, e.g. `if client := client_sessions.get(session_key):`?

> +                client.loseConnection(STATUSES.GOING_AWAY, "Session expired")
> +
> +        returnValue(None)
> diff --git a/src/maasserver/websockets/tests/test_protocol.py b/src/maasserver/websockets/tests/test_protocol.py
> index 92dda5a..3ab1e23 100644
> --- a/src/maasserver/websockets/tests/test_protocol.py
> +++ b/src/maasserver/websockets/tests/test_protocol.py
> @@ -995,3 +996,59 @@ class TestWebSocketFactoryTransactional(
>                  controller.system_id,
>              ),
>          )
> +
> +    @wait_for_reactor
> +    @inlineCallbacks
> +    def test_check_sessions(self):
> +        factory = self.make_factory()
> +
> +        session_engine = factory.getSessionEngine()
> +        Session = session_engine.SessionStore.get_model_class()
> +        key1 = maas_factory.make_string()
> +        key2 = maas_factory.make_string()
> +
> +        def make_sessions():
> +            now = datetime.utcnow()
> +            delta = timedelta(hours=1)
> +            # first session is expired, second one is valid
> +            return (
> +                Session.objects.create(
> +                    session_key=key1, expire_date=now - delta
> +                ),
> +                Session.objects.create(
> +                    session_key=key2, expire_date=now + delta
> +                ),
> +            )
> +
> +        session1, session2 = yield deferToDatabase(make_sessions)

Give them more descriptive names, e.g. expired_session, active_session.

> +
> +        def make_protocol_with_session(session):
> +            protocol = factory.buildProtocol(None)
> +            protocol.transport = MagicMock()
> +            protocol.transport.cookies = b""
> +
> +            def authenticate(*args):
> +                protocol.session = session
> +                return defer.succeed(True)
> +
> +            self.patch(protocol, "authenticate", authenticate)
> +            self.patch(protocol, "loseConnection")
> +            return protocol
> +
> +        proto1 = make_protocol_with_session(session1)
> +        proto2 = make_protocol_with_session(session2)

Same here - expired_protocol, active_protocol?

> +
> +        yield proto1.connectionMade()
> +        self.addCleanup(lambda: proto1.connectionLost(""))

Drop the lambda and just use `self.addCleanup(proto1.connectionLost, "")`.

> +        yield proto2.connectionMade()
> +        self.addCleanup(lambda: proto2.connectionLost(""))
> +
> +        yield factory.startFactory()
> +        factory.stopFactory()
> +        # wait until it's stopped, sessions are checked
> +        yield factory.session_checker_done
> +        # the first client gets disconnected
> +        proto1.loseConnection.assert_called_once_with(
> +            STATUSES.GOING_AWAY, "Session expired"
> +        )
> +        proto2.loseConnection.assert_not_called()


-- 
https://code.launchpad.net/~ack/maas/+git/maas/+merge/438669
Your team MAAS Maintainers is requested to review the proposed merge of ~ack/maas:websocket-check-expired-sessions into maas:master.



References