← Back to team overview

txaws-dev team mailing list archive

[Merge] lp:~free.ekanayaka/txaws/use-port-in-signature into lp:txaws

 

Free Ekanayaka has proposed merging lp:~free.ekanayaka/txaws/use-port-in-signature into lp:txaws.

Requested reviews:
  txAWS Developers (txaws-dev)
Related bugs:
  Bug #781540 in txAWS: "The EC2 Query doesn't take the endpoint port into account when signing"
  https://bugs.launchpad.net/txaws/+bug/781540

For more details, see:
https://code.launchpad.net/~free.ekanayaka/txaws/use-port-in-signature/+merge/60739

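This branch adds AWSServiceEndpoint.get_canonical_host(), which returns the host name followed by ":<port>" when the endpoint port differs from the default, and modifies txaws.ec2.client.Query so that:

- it uses the canonical host, including the endpoint port, when building the text to be signed
- it sends the canonical host in the Host HTTP header whenever it differs from the bare host name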

-- 
https://code.launchpad.net/~free.ekanayaka/txaws/use-port-in-signature/+merge/60739
Your team txAWS Developers is requested to review the proposed merge of lp:~free.ekanayaka/txaws/use-port-in-signature into lp:txaws.
=== modified file 'txaws/ec2/client.py'
--- txaws/ec2/client.py	2011-04-26 16:28:34 +0000
+++ txaws/ec2/client.py	2011-05-12 08:29:15 +0000
@@ -922,7 +922,8 @@
 
     def signing_text(self):
         """Return the text to be signed when signing the query."""
-        result = "%s\n%s\n%s\n%s" % (self.endpoint.method, self.endpoint.host,
+        result = "%s\n%s\n%s\n%s" % (self.endpoint.method,
+                                     self.endpoint.get_canonical_host(),
                                      self.endpoint.path,
                                      self.get_canonical_query_params())
         return result
@@ -969,13 +970,17 @@
         url = self.endpoint.get_uri()
         method = self.endpoint.method
         params = self.get_canonical_query_params()
+        headers = {}
         kwargs = {"method": method}
         if method == "POST":
-            kwargs["headers"] = {
-                "Content-Type": "application/x-www-form-urlencoded"}
+            headers["Content-Type"] = "application/x-www-form-urlencoded"
             kwargs["postdata"] = params
         else:
             url += "?%s" % params
+        if self.endpoint.get_host() != self.endpoint.get_canonical_host():
+            headers["Host"] = self.endpoint.get_canonical_host()
+        if headers:
+            kwargs["headers"] = headers
         if self.timeout:
             kwargs["timeout"] = self.timeout
         d = self.get_page(url, **kwargs)

=== modified file 'txaws/ec2/tests/test_client.py'
--- txaws/ec2/tests/test_client.py	2011-04-21 21:16:37 +0000
+++ txaws/ec2/tests/test_client.py	2011-05-12 08:29:15 +0000
@@ -1608,6 +1608,25 @@
             "Timestamp=2007-11-12T13%3A14%3A15Z&Version=2008-12-01")
         self.assertEqual(signing_text, query.signing_text())
 
+    def test_signing_text_with_non_default_port(self):
+        """
+        The signing text uses the canonical host name, which includes
+        the port number, if it differs from the default one.
+        """
+        endpoint = AWSServiceEndpoint(uri="http://example.com:99/path")
+        query = client.Query(
+            action="DescribeInstances", creds=self.creds, endpoint=endpoint,
+            time_tuple=(2007, 11, 12, 13, 14, 15, 0, 0, 0))
+        signing_text = ("GET\n"
+                        "example.com:99\n"
+                        "/path\n"
+                        "AWSAccessKeyId=foo&"
+                        "Action=DescribeInstances&"
+                        "SignatureVersion=2&"
+                        "Timestamp=2007-11-12T13%3A14%3A15Z&"
+                        "Version=2008-12-01")
+        self.assertEqual(signing_text, query.signing_text())
+
     def test_old_signing_text(self):
         query = client.Query(
             action="DescribeInstances", creds=self.creds,
@@ -1646,6 +1665,26 @@
             other_params={"SignatureVersion": "0"})
         self.assertRaises(RuntimeError, query.sign)
 
+    def test_submit_with_port(self):
+        """
+        If the endpoint port differs from the default one, the Host header
+        of the request will include it.
+        """
+        self.addCleanup(setattr, client.Query, "get_page",
+                        client.Query.get_page)
+
+        def get_page(query, url, **kwargs):
+            self.assertEqual("example.com:99", kwargs["headers"]["Host"])
+            return succeed(None)
+
+        client.Query.get_page = get_page
+        endpoint = AWSServiceEndpoint(uri="http://example.com:99/foo")
+        query = client.Query(action="SomeQuery", creds=self.creds,
+                             endpoint=endpoint)
+
+        d = query.submit()
+        return d
+
     def test_submit_400(self):
         """A 4xx response status from EC2 should raise a txAWS EC2Error."""
         status = 400

=== modified file 'txaws/service.py'
--- txaws/service.py	2010-10-25 21:06:51 +0000
+++ txaws/service.py	2011-05-12 08:29:15 +0000
@@ -45,15 +45,24 @@
     def get_host(self):
         return self.host
 
+    def get_canonical_host(self):
+        """Return the canonical host as for the Host HTTP header specification.
+
+        The host name is lowercased, as AWS signature version 2 requires, and
+        the port is appended to it if it differs from the default one.
+        """
+        host = self.host.lower()
+        if self.port and self.port != DEFAULT_PORT:
+            host += ":%s" % self.port
+        return host
+
     def set_path(self, path):
         self.path = path
 
     def get_uri(self):
         """Get a URL representation of the service."""
-        uri = "%s://%s" % (self.scheme, self.host)
-        if self.port and self.port != DEFAULT_PORT:
-            uri = "%s:%s" % (uri, self.port)
-        return uri + self.path
+        uri = "%s://%s%s" % (self.scheme, self.get_canonical_host(), self.path)
+        return uri
 
     def set_method(self, method):
         self.method = method

=== modified file 'txaws/tests/test_service.py'
--- txaws/tests/test_service.py	2011-03-26 10:48:22 +0000
+++ txaws/tests/test_service.py	2011-05-12 08:29:15 +0000
@@ -65,6 +65,32 @@
     def test_get_host(self):
         self.assertEquals(self.endpoint.host, self.endpoint.get_host())
 
+    def test_get_canonical_host(self):
+        """
+        If the port is not specified, the canonical host is the same as
+        the host.
+        """
+        uri = "http://my.service/endpoint"
+        endpoint = AWSServiceEndpoint(uri=uri)
+        self.assertEquals("my.service", endpoint.get_canonical_host())
+
+    def test_get_canonical_host_with_default_port(self):
+        """
+        If the port is the default one, the canonical host is the same as
+        the host.
+        """
+        uri = "http://my.service:80/endpoint"
+        endpoint = AWSServiceEndpoint(uri=uri)
+        self.assertEquals("my.service", endpoint.get_canonical_host())
+
+    def test_get_canonical_host_with_non_default_port(self):
+        """
+        If the port is not the default, the canonical host includes it.
+        """
+        uri = "http://my.service:99/endpoint"
+        endpoint = AWSServiceEndpoint(uri=uri)
+        self.assertEquals("my.service:99", endpoint.get_canonical_host())
+
     def test_set_path(self):
         self.endpoint.set_path("/newpath")
         self.assertEquals(