← Back to team overview

launchpad-reviewers team mailing list archive

[Merge] ~pappacena/turnip:ssh-master-connection-fix into turnip:master

 

Thiago F. Pappacena has proposed merging ~pappacena/turnip:ssh-master-connection-fix into turnip:master.

Commit message:
Fix backend connections not being closed when parallel SSH connections share a master connection via ControlMaster/ControlPath

Requested reviews:
  Launchpad code reviewers (launchpad-reviewers)

For more details, see:
https://code.launchpad.net/~pappacena/turnip/+git/turnip/+merge/391219
-- 
Your team Launchpad code reviewers is requested to review the proposed merge of ~pappacena/turnip:ssh-master-connection-fix into turnip:master.
diff --git a/turnip/pack/ssh.py b/turnip/pack/ssh.py
index 8e87aa8..ec0d54f 100644
--- a/turnip/pack/ssh.py
+++ b/turnip/pack/ssh.py
@@ -129,7 +129,15 @@ class SmartSSHSession(DoNothingSession):
 
     def __init__(self, *args, **kwargs):
         super(SmartSSHSession, self).__init__(*args, **kwargs)
-        self.pack_protocol = None
+        # There is usually only one pack_protocol call per connection,
+        # but when SSH's ControlMaster option is used, there will be only
+        # one SmartSSHSession holding several calls to the backend.
+        # So, in this set, we keep track of all pairs of
+        # (twisted.conch.ssh.session.SSHSessionProcessProtocol,
+        # SSHPackClientProtocol) to make sure we are closing backend
+        # connections when child connections are closed.
+        # See `SmartSSHSession.closed` method.
+        self.pack_protocols = set()
         self.env = {}
 
     def setEnv(self, name, value):
@@ -165,7 +173,8 @@ class SmartSSHSession(DoNothingSession):
         service = self.avatar.service
         conn = reactor.connectTCP(
             service.backend_host, service.backend_port, client_factory)
-        self.pack_protocol = yield d
+        pack_protocol = yield d
+        self.pack_protocols.add((ssh_protocol, pack_protocol))
         ssh_protocol.makeConnection(conn.transport)
 
     def execCommand(self, protocol, command):
@@ -193,13 +202,27 @@ class SmartSSHSession(DoNothingSession):
             self.errorWithMessage(protocol, str(e).encode("UTF-8"))
 
     def closed(self):
-        if self.pack_protocol is not None:
-            self.pack_protocol.transport.loseConnection()
+        """Called when a session closes its connection.
+
+        Please note that master connections (using ControlMaster/ControlPath)
+        cause this method to be called for each child connection on the
+        same session, so we need to keep track of all the pack protocols
+        created and their corresponding ssh protocols.
+        """
+        pairs_to_remove = []
+        for proto_pair in self.pack_protocols.copy():
+            ssh_proto, pack_proto = proto_pair
+            if not ssh_proto.transport.connected:
+                pairs_to_remove.append(proto_pair)
+                pack_proto.transport.loseConnection()
+                continue
+        for pair in pairs_to_remove:
+            self.pack_protocols.remove(pair)
 
     def eofReceived(self):
-        if (self.pack_protocol is not None and
-                self.pack_protocol.transport.connected):
-            self.pack_protocol.transport.loseWriteConnection()
+        for ssh_proto, pack_proto in self.pack_protocols:
+            if pack_proto.transport.connected:
+                pack_proto.transport.loseWriteConnection()
 
 
 class SmartSSHAvatar(LaunchpadAvatar):
diff --git a/turnip/pack/tests/test_ssh.py b/turnip/pack/tests/test_ssh.py
index 12f5dfb..052b5a8 100644
--- a/turnip/pack/tests/test_ssh.py
+++ b/turnip/pack/tests/test_ssh.py
@@ -9,6 +9,7 @@ from __future__ import (
 
 from testtools import TestCase
 
+from turnip.tests.compat import mock
 from turnip.pack.ssh import SmartSSHSession
 
 
@@ -27,3 +28,38 @@ class TestSSHSessionProtocolVersion(TestCase):
         session = SmartSSHSession(None)
         session.setEnv('GIT_PROTOCOL', b'version=2')
         self.assertEqual(b'2', session.getProtocolVersion())
+
+
+class TestSmartSSHSessionMultipleConnectionsTracking(TestCase):
+    def test_closed_checks_all_connections(self):
+        # Still running.
+        running_pair = (
+            mock.Mock(transport=mock.Mock(connected=1)),
+            mock.Mock(transport=mock.Mock(connected=1)))
+        # Connected SSH, but no backend connection anymore.
+        active_ssh = (
+            mock.Mock(transport=mock.Mock(connected=1)),
+            mock.Mock(transport=mock.Mock(connected=0)))
+        # Connected to the backend, but SSH connection is gone.
+        active_backend = (
+            mock.Mock(transport=mock.Mock(connected=0)),
+            mock.Mock(transport=mock.Mock(connected=1)))
+        # All connections are done.
+        done_pair = (
+            mock.Mock(transport=mock.Mock(connected=0)),
+            mock.Mock(transport=mock.Mock(connected=0)))
+
+        avatar = mock.Mock()
+        session = SmartSSHSession(avatar)
+        session.pack_protocols = {
+            running_pair, active_ssh, active_backend, done_pair}
+        session.closed()
+        self.assertEqual({running_pair, active_ssh}, session.pack_protocols)
+        self.assertEqual(
+            0, running_pair[1].transport.loseConnection.call_count)
+        self.assertEqual(
+            0, active_ssh[1].transport.loseConnection.call_count)
+        self.assertEqual(
+            1, active_backend[1].transport.loseConnection.call_count)
+        self.assertEqual(
+            1, done_pair[1].transport.loseConnection.call_count)