sts-sponsors team mailing list archive
-
sts-sponsors team
-
Mailing list archive
-
Message #05804
[Merge] ~ack/maas:more-ws-cleanups into maas:master
Alberto Donato has proposed merging ~ack/maas:more-ws-cleanups into maas:master.
Commit message:
Some more clean-ups to the websocket code.
Requested reviews:
MAAS Maintainers (maas-maintainers)
For more details, see:
https://code.launchpad.net/~ack/maas/+git/maas/+merge/438604
--
Your team MAAS Maintainers is requested to review the proposed merge of ~ack/maas:more-ws-cleanups into maas:master.
diff --git a/src/maasserver/websockets/protocol.py b/src/maasserver/websockets/protocol.py
index bd6add8..6cdaad7 100644
--- a/src/maasserver/websockets/protocol.py
+++ b/src/maasserver/websockets/protocol.py
@@ -75,6 +75,7 @@ class WebSocketProtocol(Protocol):
def __init__(self):
self.messages = deque()
self.user = None
+ self.session = None
self.request = None
self.cache = {}
self.sequence_number = 0
@@ -88,13 +89,12 @@ class WebSocketProtocol(Protocol):
# from an authenticated user.
cookies = self.transport.cookies.decode("ascii")
- user = yield self.authenticate(
+ authenticated = yield self.authenticate(
get_cookie(cookies, "sessionid"), get_cookie(cookies, "csrftoken")
)
- if not user:
+ if not authenticated:
return
- self.user = user
# XXX newell 2018-10-17 bug=1798479:
# Check that 'SERVER_NAME' and 'SERVER_PORT' are set.
# 'SERVER_NAME' and 'SERVER_PORT' are required so
@@ -109,6 +109,7 @@ class WebSocketProtocol(Protocol):
# Create the request for the handlers for this connection.
self.request = HttpRequest()
self.request.user = self.user
+ self.request.session = self.session
self.request.META.update(
{
"HTTP_USER_AGENT": self.transport.user_agent,
@@ -152,17 +153,17 @@ class WebSocketProtocol(Protocol):
@synchronous
@transactional
- def getUserFromSessionId(self, session_id):
- """Return the user from `session_id`."""
+ def get_user_and_session(self, session_id):
+ """Return the user and its session from the session ID."""
session_engine = self.factory.getSessionEngine()
- session_wrapper = session_engine.SessionStore(session_key=session_id)
- user_id = session_wrapper.get(SESSION_KEY)
- backend = session_wrapper.get(BACKEND_SESSION_KEY)
+ session = session_engine.SessionStore(session_key=session_id)
+ backend = session.get(BACKEND_SESSION_KEY)
if backend is None:
return None
auth_backend = load_backend(backend)
+ user_id = session.get(SESSION_KEY)
if user_id is not None and auth_backend is not None:
- return auth_backend.get_user(user_id)
+ return auth_backend.get_user(user_id), session
return None
@@ -173,8 +174,7 @@ class WebSocketProtocol(Protocol):
- Check that the CSRF token is valid.
- Authenticate the user using the session id.
- This returns the authenticated user or ``None``. The latter means that
- the connection is being dropped, and that processing should cease.
+ It returns whether authentication succeeded.
"""
# Check the CSRF token.
tokens = parse_qs(urlparse(self.transport.uri).query).get(b"csrftoken")
@@ -185,24 +185,31 @@ class WebSocketProtocol(Protocol):
if tokens is None or csrftoken not in tokens:
# No csrftoken in the request or the token does not match.
self.loseConnection(STATUSES.PROTOCOL_ERROR, "Invalid CSRF token.")
- return None
+ returnValue(False)
+ return
try:
- user = yield deferToDatabase(self.getUserFromSessionId, session_id)
+ result = yield deferToDatabase(
+ self.get_user_and_session, session_id
+ )
except Exception as error:
self.loseConnection(
STATUSES.PROTOCOL_ERROR, f"Error authenticating user: {error}"
)
- returnValue(None)
+ returnValue(False)
return
+ if result:
+ self.user, self.session = result
+ else:
+ self.user = self.session = None
- if user is None or user.id is None:
+ if self.user is None or self.user.id is None:
self.loseConnection(
STATUSES.PROTOCOL_ERROR, "Failed to authenticate user."
)
- returnValue(None)
+ returnValue(False)
else:
- returnValue(user)
+ returnValue(True)
def dataReceived(self, data):
"""Received message from client and queue up the message."""
diff --git a/src/maasserver/websockets/tests/test_protocol.py b/src/maasserver/websockets/tests/test_protocol.py
index 02348c7..92dda5a 100644
--- a/src/maasserver/websockets/tests/test_protocol.py
+++ b/src/maasserver/websockets/tests/test_protocol.py
@@ -133,15 +133,14 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
self.assertEqual(protocol.request.META["SERVER_NAME"], "localhost")
self.assertEqual(protocol.request.META["SERVER_PORT"], 5248)
- def test_connectionMade_sets_user_and_processes_messages(self):
+ def test_connectionMade_processes_messages(self):
protocol, factory = self.make_protocol(patch_authenticate=False)
self.patch_autospec(protocol, "authenticate")
self.patch_autospec(protocol, "processMessages")
- protocol.authenticate.return_value = defer.succeed(sentinel.user)
+ protocol.authenticate.return_value = defer.succeed(True)
protocol.connectionMade()
self.addCleanup(protocol.connectionLost, "")
- self.assertIs(protocol.user, sentinel.user)
- self.assertThat(protocol.processMessages, MockCalledOnceWith())
+ protocol.processMessages.assert_called_once_with()
def test_connectionMade_adds_self_to_factory_if_auth_succeeds(self):
protocol, factory = self.make_protocol()
@@ -249,19 +248,20 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
@wait_for_reactor
@inlineCallbacks
- def test_getUserFromSessionId_returns_User(self):
+ def test_get_user_and_session_returns_user_and_session(self):
user, session_id = yield deferToDatabase(self.get_user_and_session_id)
protocol, _ = self.make_protocol()
- protocol_user = yield deferToDatabase(
- lambda: protocol.getUserFromSessionId(session_id)
+ protocol_user, session = yield deferToDatabase(
+ lambda: protocol.get_user_and_session(session_id)
)
self.assertEqual(user, protocol_user)
+ self.assertEqual(session.session_key, session_id)
- def test_getUserFromSessionId_returns_None_for_invalid_key(self):
+ def test_get_user_and_session_returns_None_for_invalid_key(self):
self.client.login(user=maas_factory.make_User())
session_id = maas_factory.make_name("sessionid")
protocol, _ = self.make_protocol()
- self.assertIsNone(protocol.getUserFromSessionId(session_id))
+ self.assertIsNone(protocol.get_user_and_session(session_id))
@wait_for_reactor
@inlineCallbacks
@@ -272,19 +272,16 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
patch_authenticate=False, transport_uri=uri
)
mock_loseConnection = self.patch_autospec(protocol, "loseConnection")
- mock_getUserFromSessionId = self.patch_autospec(
- protocol, "getUserFromSessionId"
+ mock_get_user_and_session = self.patch_autospec(
+ protocol, "get_user_and_session"
)
- mock_getUserFromSessionId.return_value = None
+ mock_get_user_and_session.return_value = None
yield protocol.authenticate(
maas_factory.make_name("sessionid"), csrftoken
)
- self.expectThat(
- mock_loseConnection,
- MockCalledOnceWith(
- STATUSES.PROTOCOL_ERROR, "Failed to authenticate user."
- ),
+ mock_loseConnection.assert_called_once_with(
+ STATUSES.PROTOCOL_ERROR, "Failed to authenticate user."
)
@wait_for_reactor
@@ -296,28 +293,25 @@ class TestWebSocketProtocol(MAASTransactionServerTestCase):
patch_authenticate=False, transport_uri=uri
)
mock_loseConnection = self.patch_autospec(protocol, "loseConnection")
- mock_getUserFromSessionId = self.patch_autospec(
- protocol, "getUserFromSessionId"
+ mock_get_user_and_session = self.patch_autospec(
+ protocol, "get_user_and_session"
)
- mock_getUserFromSessionId.side_effect = maas_factory.make_exception(
+ mock_get_user_and_session.side_effect = maas_factory.make_exception(
"unknown reason"
)
yield protocol.authenticate(
maas_factory.make_name("sessionid"), csrftoken
)
- self.expectThat(
- mock_loseConnection,
- MockCalledOnceWith(
- STATUSES.PROTOCOL_ERROR,
- "Error authenticating user: unknown reason",
- ),
+ mock_loseConnection.assert_called_once_with(
+ STATUSES.PROTOCOL_ERROR,
+ "Error authenticating user: unknown reason",
)
@wait_for_reactor
@inlineCallbacks
def test_authenticate_calls_loseConnection_if_invalid_csrftoken(self):
- user, session_id = yield deferToDatabase(self.get_user_and_session_id)
+ _, session_id = yield deferToDatabase(self.get_user_and_session_id)
csrftoken = maas_factory.make_name("csrftoken")
uri = self.make_ws_uri(csrftoken)
protocol, _ = self.make_protocol(
@@ -755,8 +749,12 @@ class MakeProtocolFactoryMixin:
protocol.transport.cookies = b""
if user is None:
user = maas_factory.make_User()
- mock_authenticate = self.patch(protocol, "authenticate")
- mock_authenticate.return_value = defer.succeed(user)
+
+ def authenticate(*args):
+ protocol.user = user
+ return defer.succeed(True)
+
+ self.patch(protocol, "authenticate", authenticate)
protocol.connectionMade()
self.addCleanup(lambda: protocol.connectionLost(""))
return protocol, factory
Follow-ups