txaws-dev team mailing list archive
-
txaws-dev team
-
Mailing list archive
-
Message #00035
[Merge] lp:~free.ekanayaka/txaws/query-api into lp:txaws
Free Ekanayaka has proposed merging lp:~free.ekanayaka/txaws/query-api into lp:txaws.
Requested reviews:
txAWS Developers (txaws-dev)
Related bugs:
Bug #782546 in txAWS: "Add support for web resources using the Query API authorization model"
https://bugs.launchpad.net/txaws/+bug/782546
For more details, see:
https://code.launchpad.net/~free.ekanayaka/txaws/query-api/+merge/60985
This branch adds a new txaws.server package, sporting the following classes:
- resource.QueryAPI: a base class that can be used to implement EC2-like APIs
- schema.Schema: a schema class that can be used to specify and parse the parameters of an EC2-like HTTP request
--
https://code.launchpad.net/~free.ekanayaka/txaws/query-api/+merge/60985
Your team txAWS Developers is requested to review the proposed merge of lp:~free.ekanayaka/txaws/query-api into lp:txaws.
=== added directory 'txaws/server'
=== added file 'txaws/server/__init__.py'
=== added file 'txaws/server/call.py'
--- txaws/server/call.py 1970-01-01 00:00:00 +0000
+++ txaws/server/call.py 2011-05-14 08:15:55 +0000
@@ -0,0 +1,58 @@
+from uuid import uuid4
+
+from txaws.version import ec2_api as ec2_api_version
+from txaws.server.exception import APIError
+
+
class Call(object):
    """Hold information about a single API call initiated by an HTTP request.

    @param raw_params: The raw parameters for the action to be executed, the
        format is a dictionary mapping parameter names to parameter
        values, like C{{'ParamName': param_value}}. The dictionary is copied.
    @param principal: The principal issuing this API L{Call}.
    @param action: The action to be performed.
    @param version: The API version of the call, defaults to
        C{txaws.version.ec2_api}.
    @param id: An identifier for the call, a random UUID4 string is
        generated when not given.

    @ivar id: A unique identifier for the API call.
    @ivar principal: The principal performing the call.
    @ivar args: An L{Arguments} object holding parameters extracted from the
        raw parameters according to a L{Schema}, C{None} until L{parse} has
        been called.
    @ivar rest: Extra parameters not included in the given arguments schema,
        see L{parse}. C{None} until L{parse} has been called.
    @ivar version: The version of the API call. Defaults to 2008-12-01.
    """

    def __init__(self, raw_params=None, principal=None, action=None,
                 version=None, id=None):
        if id is None:
            id = str(uuid4())
        self.id = id
        # Keep a private copy so later changes by the caller don't leak in.
        self._raw_params = {}
        if raw_params is not None:
            self._raw_params.update(raw_params)
        self.action = action
        if version is None:
            version = ec2_api_version
        self.version = version
        self.principal = principal
        # Populated by parse(); defined here so they always exist.
        self.args = None
        self.rest = None

    def parse(self, schema, strict=True):
        """Update our C{args} parsing values from the raw request arguments.

        @param schema: The L{Schema} the parameters must be extracted with.
        @param strict: If C{True} an error is raised if parameters not included
            in the schema are found, otherwise the extra parameters will be
            saved in the C{rest} attribute.
        @raises APIError: With code 'UnknownParameter' if C{strict} is set and
            C{schema} doesn't know some of the raw parameters.
        """
        self.args, self.rest = schema.extract(self._raw_params)
        if strict and self.rest:
            # Report the alphabetically first unknown parameter so the error
            # message does not depend on dict ordering.
            raise APIError(400, "UnknownParameter",
                           "The parameter %s is not "
                           "recognized" % sorted(self.rest)[0])

    def get_raw_params(self):
        """Return a C{dict} holding the raw API call parameters.

        The format of the dictionary is C{{'ParamName': param_value}}. A copy
        is returned, so modifying it does not affect this call.
        """
        return self._raw_params.copy()
=== added file 'txaws/server/exception.py'
--- txaws/server/exception.py 1970-01-01 00:00:00 +0000
+++ txaws/server/exception.py 2011-05-14 08:15:55 +0000
@@ -0,0 +1,25 @@
class APIError(Exception):
    """Raised while handling an API request.

    Either a C{code}/C{message} pair or a full C{response} must be given, but
    not both.

    @param status: The HTTP status code the response will be set to. It is
        converted with C{int}, so a string like C{"400"} is accepted.
    @param code: A machine-parsable textual code for the error.
    @param message: A human-readable description of the error.
    @param response: The full body of the response to be sent to the client,
        if C{None} it will be generated from C{code} and C{message}. See
        also L{QueryAPI.dump_error}.
    @raises RuntimeError: If C{response} is C{None} and C{code} or
        C{message} is missing, or if C{response} is given together with
        C{code} or C{message}.
    """

    def __init__(self, status, code=None, message=None, response=None):
        super(APIError, self).__init__(message)
        self.status = int(status)
        self.code = code
        self.message = message
        self.response = response
        if response is None:
            if code is None or message is None:
                raise RuntimeError("If the response is not specified, code "
                                   "and message must both be set.")
        elif code is not None or message is not None:
            raise RuntimeError("If the full response payload is passed, "
                               "code and message must not be set.")
=== added file 'txaws/server/resource.py'
--- txaws/server/resource.py 1970-01-01 00:00:00 +0000
+++ txaws/server/resource.py 2011-05-14 08:15:55 +0000
@@ -0,0 +1,219 @@
+from datetime import datetime, timedelta
+from uuid import uuid4
+from urlparse import urljoin
+from pytz import UTC
+
+from twisted.python import log
+from twisted.internet.defer import maybeDeferred
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+
+from txaws.ec2.client import Signature
+from txaws.service import AWSServiceEndpoint
+from txaws.credentials import AWSCredentials
+from txaws.server.schema import (
+ Schema, Unicode, Integer, Enum, RawStr, Date)
+from txaws.server.exception import APIError
+from txaws.server.call import Call
+
+
class QueryAPI(Resource):
    """Base class for EC2-like query APIs.

    The following class variables must be defined by sub-classes:

    @cvar actions: The actions that the API supports. The 'Action' field of
        the request must contain one of these.
    @cvar signature_versions: A list of allowed values for 'SignatureVersion'.
    @cvar content_type: The content type to set the 'Content-Type' header to.
    """
    isLeaf = True
    time_format = "%Y-%m-%dT%H:%M:%SZ"

    schema = Schema(
        Unicode("Action"),
        RawStr("AWSAccessKeyId"),
        Date("Timestamp", optional=True),
        Date("Expires", optional=True),
        Unicode("Version", optional=True),
        Enum("SignatureMethod", {"HmacSHA256": "sha256", "HmacSHA1": "sha1"},
             optional=True, default="HmacSHA256"),
        Unicode("Signature"),
        Integer("SignatureVersion", optional=True, default=2))

    def get_principal(self, access_key):
        """Return a principal object by access key.

        The returned object must have C{access_key} and C{secret_key}
        attributes and if the authentication succeeds, it will be
        passed to the created L{Call}. It may also be returned through a
        C{Deferred}; returning C{None} makes authentication fail.
        """
        # NotImplementedError, not NotImplemented: the latter is a constant
        # and calling it raises an unrelated TypeError.
        raise NotImplementedError("Must be implemented by subclasses")

    def handle(self, request):
        """Handle an HTTP request for executing an API call.

        This method authenticates the request checking its signature, and then
        calls the C{execute} method, passing it a L{Call} object set with the
        principal for the authenticated user and the generic parameters
        extracted from the request.

        @param request: The L{HTTPRequest} to handle.
        @return: A C{Deferred} firing once the response has been written.
        """
        request.id = str(uuid4())
        deferred = maybeDeferred(self._validate, request)
        deferred.addCallback(self.execute)

        def write_response(response):
            request.setHeader("Content-Length", str(len(response)))
            request.setHeader("Content-Type", self.content_type)
            request.write(response)
            request.finish()
            return response

        def write_error(failure):
            log.err(failure)
            if failure.check(APIError):
                status = failure.value.status
                body = failure.value.response
                if body is None:
                    body = self.dump_error(failure.value, request)
            else:
                # Unexpected failure: report it as an internal server error.
                body = str(failure.value)
                status = 500
            request.setResponseCode(status)
            request.write(body)
            request.finish()

        deferred.addCallback(write_response)
        deferred.addErrback(write_error)
        return deferred

    def dump_error(self, error, request):
        """Serialize an error generating the response to send to the client.

        @param error: The L{APIError} to format.
        @param request: The request that generated the error.
        """
        raise NotImplementedError("Must be implemented by subclass.")

    def execute(self, call):
        """Execute an API L{Call}.

        At this point the request has been authenticated and C{call.principal}
        is set with the L{Principal} for the L{User} requesting the call.

        @return: The response to write in the request for the given L{Call}.
        @raises: An L{APIError} in case the execution fails, carrying an error
            message and the HTTP status code to return.
        """
        raise NotImplementedError()

    def get_utc_time(self):
        """Return a C{datetime} object with the current time in UTC."""
        return datetime.now(UTC)

    def _validate(self, request):
        """Validate an L{HTTPRequest} before executing it.

        The following conditions are checked:

          - The request contains all the generic parameters.
          - The action specified in the request is a supported one.
          - The signature mechanism is a supported one.
          - The provided signature matches the one calculated using the locally
            stored secret access key for the user.
          - The signature hasn't expired.

        @return: A C{Deferred} firing with the validated L{Call}, set with its
            default arguments and the L{Principal} of the accessing L{User}.
        """
        # Twisted maps every argument to a list of values; the last one wins.
        params = dict((k, v[-1]) for k, v in request.args.items())
        args, rest = self.schema.extract(params)

        self._validate_generic_parameters(args, self.get_utc_time())

        def create_call(principal):
            self._validate_principal(principal, args)
            self._validate_signature(request, principal, args, params)
            return Call(raw_params=rest,
                        principal=principal,
                        action=args.Action,
                        version=args.Version,
                        id=request.id)

        deferred = maybeDeferred(self.get_principal, args.AWSAccessKeyId)
        deferred.addCallback(create_call)
        return deferred

    def _validate_generic_parameters(self, args, utc_now):
        """Validate the generic request parameters.

        @param args: Parsed schema arguments.
        @param utc_now: The current UTC time in datetime format.
        @raises APIError: In the following cases:
            - Action is not included in C{self.actions}
            - SignatureVersion is not included in C{self.signature_versions}
            - Expires and Timestamp are present
            - Expires is before the current time
            - Timestamp is older than 15 minutes.
        """
        if args.Action not in self.actions:
            raise APIError(400, "InvalidAction", "The action %s is not valid "
                           "for this web service." % args.Action)

        if args.SignatureVersion not in self.signature_versions:
            raise APIError(403, "InvalidSignature", "SignatureVersion '%s' "
                           "not supported" % args.SignatureVersion)

        if args.Expires and args.Timestamp:
            raise APIError(400, "InvalidParameterCombination",
                           "The parameter Timestamp cannot be used with "
                           "the parameter Expires")
        if args.Expires and args.Expires < utc_now:
            raise APIError(400,
                           "RequestExpired",
                           "Request has expired. Expires date is %s" % (
                               args.Expires.strftime(self.time_format)))
        if args.Timestamp and args.Timestamp + timedelta(minutes=15) < utc_now:
            raise APIError(400,
                           "RequestExpired",
                           "Request has expired. Timestamp date is %s" % (
                               args.Timestamp.strftime(self.time_format)))

    def _validate_principal(self, principal, args):
        """Validate the principal.

        @raises APIError: With code 'AuthFailure' if no principal was found.
        """
        if principal is None:
            raise APIError(401, "AuthFailure",
                           "No user with access key '%s'" %
                           args.AWSAccessKeyId)

    def _validate_signature(self, request, principal, args, params):
        """Validate the signature.

        The signature is recomputed from the request method, host, path and
        all parameters except 'Signature', using the principal's secret key.

        @raises APIError: With code 'SignatureDoesNotMatch' on mismatch.
        """
        creds = AWSCredentials(principal.access_key, principal.secret_key)
        endpoint = AWSServiceEndpoint()
        endpoint.set_method(request.method)
        endpoint.set_canonical_host(request.getHeader("Host"))
        endpoint.set_path(request.path)
        # Work on a copy so the caller's parameters are left untouched.
        params = dict(params)
        params.pop("Signature")
        signature = Signature(creds, endpoint, params)
        if not self._signatures_match(signature.compute(), args.Signature):
            raise APIError(403, "SignatureDoesNotMatch",
                           "The request signature we calculated does not "
                           "match the signature you provided. Check your "
                           "key and signing method.")

    @staticmethod
    def _signatures_match(computed, provided):
        """Compare two signatures in time independent of where they differ.

        A plain C{!=} stops at the first mismatching character, which lets a
        client probing the service time its way to a valid signature.
        """
        if len(computed) != len(provided):
            return False
        difference = 0
        for left, right in zip(computed, provided):
            difference |= ord(left) ^ ord(right)
        return difference == 0

    def get_status_text(self):
        """Get the text to return when a status check is made."""
        return "Query API Service"

    def render_GET(self, request):
        """Handle a GET request.

        A request without arguments is treated as a status check; anything
        else is handled asynchronously as an API call.
        """
        if not request.args:
            request.setHeader("Content-Type", "text/plain")
            return self.get_status_text()
        else:
            self.handle(request)
            return NOT_DONE_YET

    render_POST = render_GET
=== added file 'txaws/server/schema.py'
--- txaws/server/schema.py 1970-01-01 00:00:00 +0000
+++ txaws/server/schema.py 2011-05-14 08:15:55 +0000
@@ -0,0 +1,490 @@
+from datetime import datetime
+from operator import itemgetter
+
+from pytz import UTC
+
+from zope.datetime import parse, SyntaxError
+
+from txaws.server.exception import APIError
+
+
class SchemaError(APIError):
    """Raised when failing to extract or bundle L{Parameter}s.

    The error code is the class name without its trailing 'Error', e.g.
    C{MissingParameterError} produces the code 'MissingParameter'.

    @param message: A human-readable description of the error.
    """

    def __init__(self, message):
        code = self.__class__.__name__
        # Only strip the suffix when it is actually there, so a subclass with
        # a different naming scheme doesn't get a truncated code.
        if code.endswith("Error"):
            code = code[:-len("Error")]
        super(SchemaError, self).__init__(400, code=code, message=message)
+
+
class MissingParameterError(SchemaError):
    """Raised when a required parameter is absent from the request.

    @param name: The name of the parameter that could not be found.
    """

    def __init__(self, name):
        super(MissingParameterError, self).__init__(
            "The request must contain the parameter %s" % name)
+
+
class InvalidParameterValueError(SchemaError):
    """Raised when a parameter is present but its value is not acceptable.

    The constructor takes the human-readable message describing the problem.
    """
+
+
class InvalidParameterCombinationError(SchemaError):
    """
    Raised when there is more than one parameter with the same name,
    when this isn't explicitly allowed for.

    @param name: The name of the parameter that was given more than once.
    """

    def __init__(self, name):
        message = "The parameter '%s' may only be specified once." % name
        super(InvalidParameterCombinationError, self).__init__(message)
+
+
class UnknownParameterError(SchemaError):
    """Raised when a parameter to extract is not recognized.

    @param name: The offending parameter name.
    """

    def __init__(self, name):
        text = "The parameter %s is not recognized" % name
        SchemaError.__init__(self, text)
+
+
class Parameter(object):
    """A single parameter in an HTTP request.

    @param name: A name for the key of the parameter, as specified
        in a request. For example, a single parameter would be specified
        simply as 'GroupName'. If more than one group name was accepted,
        it would be specified as 'GroupName.n'. A more complex example
        is 'IpPermissions.n.Groups.m.GroupName'.
    @param optional: If C{True} the parameter may be omitted, in which case
        C{default} is used.
    @param default: The value used when the parameter is omitted (and
        optional) or empty (and C{allow_none} is set).
    @param min: Lower bound for C{measure(value)} of a raw value.
    @param max: Upper bound for C{measure(value)} of a raw value.
    @param allow_none: Whether a missing or empty value is accepted for a
        non-optional parameter, yielding C{default}.
    """

    def __init__(self, name, optional=False, default=None,
                 min=None, max=None, allow_none=False):
        self.name, self.optional, self.default = name, optional, default
        self.min, self.max = min, max
        self.allow_none = allow_none

    def coerce(self, value):
        """Turn a raw value into a parsed one according to these settings.

        @param value: A L{str}, or L{None} when no value at all was supplied
            (not even the empty string). For an optional parameter L{None}
            yields L{self.default}.
        @raises MissingParameterError: If the value is empty or missing and
            neither C{optional} nor C{allow_none} permits it.
        @raises InvalidParameterValueError: If the value is out of range or
            cannot be parsed.
        """
        if value is None and self.optional:
            return self.default
        # A missing value on a required parameter is treated like "".
        if value is None or value == "":
            if self.allow_none:
                return self.default
            raise MissingParameterError(self.name)
        self._check_range(value)
        try:
            return self.parse(value)
        except ValueError:
            raise InvalidParameterValueError(
                "Invalid %s value %s" % (self.kind, value))

    def _check_range(self, value):
        """Raise if C{value}'s measure falls outside C{[min, max]}."""
        if self.min is None and self.max is None:
            return

        size = self.measure(value)
        template = bound = None
        if self.min is not None and size < self.min:
            template, bound = self.lower_than_min_template, self.min
        elif self.max is not None and size > self.max:
            template, bound = self.greater_than_max_template, self.max

        if template is not None:
            raise InvalidParameterValueError(
                "Value (%s) for parameter %s is invalid. %s"
                % (value, self.name, template % bound))

    def parse(self, value):
        """
        Parse a single parameter value converting it to the appropriate type.
        """
        raise NotImplementedError()

    def format(self, value):
        """
        Format a single parameter value in a way suitable for an HTTP request.
        """
        raise NotImplementedError()

    def measure(self, value):
        """
        Return an C{int} providing a measure for C{value}, used for C{range}.
        """
        raise NotImplementedError()
+
+
class Unicode(Parameter):
    """A parameter that must be a C{unicode}.

    Raw values are expected to be UTF-8 encoded byte strings. C{min} and
    C{max} bound the length in characters.
    """

    kind = "unicode"

    lower_than_min_template = "Length must be at least %s."
    greater_than_max_template = "Length exceeds maximum of %s."

    def parse(self, value):
        # Only byte strings need decoding. Calling decode() on text that is
        # already decoded is wrong: Python 2 first implicitly encodes it as
        # ASCII (failing on non-ASCII text) and Python 3 has no such method.
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return value

    def format(self, value):
        return value.encode("utf-8")

    def measure(self, value):
        if isinstance(value, bytes):
            try:
                value = value.decode("utf-8")
            except UnicodeDecodeError:
                # Malformed input is reported by parse(); fall back to the
                # raw byte length here.
                pass
        return len(value)
+
+
class RawStr(Parameter):
    """A parameter whose value is kept as a plain C{str}."""

    kind = "raw string"

    def format(self, value):
        # Raw strings are sent exactly as stored.
        return value

    def parse(self, value):
        return str(value)
+
+
class Integer(Parameter):
    """A parameter that must be a non-negative C{int} (zero is accepted)."""

    kind = "integer"

    def parse(self, value):
        """Parse C{value} with C{int}, rejecting negative numbers.

        @raises ValueError: If C{value} is not an integer or is negative.
        """
        number = int(value)
        if number < 0:
            raise ValueError()
        return number

    def format(self, value):
        return str(value)
+
+
class Bool(Parameter):
    """A parameter holding a C{bool}, spelled 'true' or 'false' on the wire."""

    kind = "boolean"

    def parse(self, value):
        # Anything other than the two exact literals is rejected.
        if value in ("true", "false"):
            return value == "true"
        raise ValueError()

    def format(self, value):
        return "true" if value else "false"
+
+
class Enum(Parameter):
    """A parameter restricted to a fixed set of accepted values.

    @param name: The name of the parameter, as specified in a request.
    @param mapping: A mapping of accepted raw values to the values that
        C{parse} returns for them. C{format} uses the inverse mapping.
    @param optional: If C{True} the parameter may not be present.
    @param default: A default value for the parameter, if not present.
    """

    kind = "enum"

    def __init__(self, name, mapping, optional=False, default=None):
        super(Enum, self).__init__(name, optional=optional, default=default)
        self.mapping = mapping
        self.reverse = dict(
            (target, source) for source, target in mapping.items())

    def parse(self, value):
        try:
            parsed = self.mapping[value]
        except KeyError:
            raise ValueError()
        return parsed

    def format(self, value):
        return self.reverse[value]
+
+
class Date(Parameter):
    """A parameter that must be a valid ISO 8601 formatted date.

    Parsed values are timezone-aware C{datetime}s in UTC.
    """

    kind = "date"

    def parse(self, value):
        """Parse C{value} into a UTC C{datetime}, dropping fractional seconds.

        @raises ValueError: If C{value} is not a valid date.
        """
        # DateTimeError is the common base of the errors zope.datetime raises
        # for bad input (its SyntaxError, but also DateError/TimeError), which
        # would otherwise escape as internal server errors.
        from zope.datetime import DateTimeError
        try:
            fields = parse(value, local=False)
            year, month, day, hour, minute, second = fields[:6]
            # datetime() needs integer fields while the parsed seconds may be
            # a float; fractional seconds are dropped.
            return datetime(year, month, day, hour, minute, int(second),
                            tzinfo=UTC)
        except (TypeError, DateTimeError):
            raise ValueError()

    def format(self, value):
        """Format C{value} as an ISO 8601 UTC timestamp (seconds precision).

        Aware datetimes are converted to UTC; naive ones are taken as UTC.
        """
        utc_value = datetime(*value.utctimetuple()[:6])
        return utc_value.strftime("%Y-%m-%dT%H:%M:%SZ")
+
+
class Arguments(object):
    """Arguments parsed from a request.

    Each top-level key of the parsed tree becomes an attribute. Nested dicts
    become nested L{Arguments}, and dicts keyed by integers become lists
    ordered by key.
    """

    def __init__(self, tree):
        """Initialize a new L{Arguments} instance.

        @param tree: The C{dict}-based structure of the L{Arguments} instance
            to create.
        """
        for key, value in tree.items():
            self.__dict__[key] = self._wrap(value)

    def __iter__(self):
        """Return an iterator yielding C{(name, value)} tuples."""
        return iter(list(self.__dict__.items()))

    def __getitem__(self, index):
        """Return the argument value with the given C{index}."""
        return self.__dict__[index]

    def __len__(self):
        """Return the number of arguments."""
        return len(self.__dict__)

    def _wrap(self, value):
        """Wrap the given C{value} with L{Arguments} as necessary.

        @param value: A leaf value, or a C{dict} containing C{dict}s and/or
            leaf values, nested arbitrarily deep.
        @raises RuntimeError: If a C{dict} mixes integer and non-integer keys.
        """
        if not isinstance(value, dict):
            return value
        keys = list(value.keys())
        if any(isinstance(name, int) for name in keys):
            if not all(isinstance(name, int) for name in keys):
                raise RuntimeError("Integer and non-integer keys: %r" % keys)
            items = sorted(value.items(), key=itemgetter(0))
            return [self._wrap(item) for _, item in items]
        return Arguments(value)
+
+
+class Schema(object):
+ """
+ The schema that the arguments of an HTTP request must be compliant with.
+ """
+
+ def __init__(self, *parameters):
+ """Initialize a new L{Schema} instance.
+
+ Any number of L{Parameter} instances can be passed. The parameter path
+ is used as the target in L{Schema.extract} and L{Schema.bundle}. For
+ example::
+
+ schema = Schema(Unicode('Name'))
+
+ means that the result of L{Schema.extract} would have a C{Name}
+ attribute. Similarly, L{Schema.bundle} would look for a C{Name}
+ attribute.
+
+ A more complex example::
+
+ schema = Schema(Unicode('Name.#'))
+
+ means that the result of L{Schema.extract} would have a C{Name}
+ attribute, which would itself contain a list of names. Similarly,
+ L{Schema.bundle} would look for a C{Name} attribute.
+ """
+ self._parameters = dict(
+ (self._get_template(parameter.name), parameter)
+ for parameter in parameters)
+
+ def extract(self, params):
+ """Extract parameters from a raw C{dict} according to this schema.
+
+ @param params: The raw parameters to parse.
+ @return: An L{Arguments} object holding the extracted arguments.
+
+ @raises UnknownParameterError: If C{params} contains keys that this
+ schema doesn't know about.
+ """
+ tree = {}
+ rest = {}
+
+ # Extract from the given arguments and parse according to the
+ # corresponding parameters.
+ for name, value in params.iteritems():
+ template = self._get_template(name)
+ parameter = self._parameters.get(template)
+
+ if template.endswith(".#") and parameter is None:
+ # If we were unable to find a direct match for a template that
+ # allows multiple values. Let's attempt to find it without the
+ # multiple value marker which Amazon allows. For example if the
+ # template is 'PublicIp', then a single key 'PublicIp.1' is
+ # allowed.
+ parameter = self._parameters.get(template[:-2])
+ if parameter is not None:
+ name = name[:-2]
+
+ # At this point, we have a template that doesn't have the .#
+ # marker to indicate multiple values. We don't allow multiple
+ # "single" values for the same element.
+ if name in tree.keys():
+ raise InvalidParameterCombinationError(name)
+
+ if parameter is None:
+ rest[name] = value
+ else:
+ self._set_value(tree, name, parameter.coerce(value))
+
+ # Ensure that the tree arguments are consistent with constraints
+ # defined in the schema.
+ for template, parameter in self._parameters.iteritems():
+ self._ensure_tree(tree, parameter, *template.split("."))
+
+ return Arguments(tree), rest
+
+ def bundle(self, *arguments, **extra):
+ """Bundle the given arguments in a C{dict} with EC2-style format.
+
+ @param arguments: L{Arguments} instances to bundle. Keys in
+ later objects will override those in earlier objects.
+ @param extra: Any number of additional parameters. These will override
+ similarly named arguments in L{arguments}.
+ """
+ params = {}
+
+ for argument in arguments:
+ self._flatten(params, argument)
+ self._flatten(params, extra)
+
+ for name, value in params.iteritems():
+ parameter = self._parameters.get(self._get_template(name))
+ if parameter is None:
+ raise RuntimeError("Parameter '%s' not in schema" % name)
+ else:
+ if value is None:
+ params[name] = ""
+ else:
+ params[name] = parameter.format(value)
+
+ return params
+
+ def _get_template(self, key):
+ """Return the canonical template for a given parameter key.
+
+ For example::
+
+ 'Child.1.Name.2'
+
+ becomes::
+
+ 'Child.#.Name.#'
+
+ """
+ parts = key.split(".")
+ for index, part in enumerate(parts[1::2]):
+ parts[index * 2 + 1] = "#"
+ return ".".join(parts)
+
+ def _set_value(self, tree, path, value):
+ """Set C{value} at C{path} in the given C{tree}.
+
+ For example::
+
+ tree = {}
+ _set_value(tree, 'foo.1.bar.2', True)
+
+ results in C{tree} becoming::
+
+ {'foo': {1: {'bar': {2: True}}}}
+
+ @param tree: A L{dict}.
+ @param path: A L{str}.
+ @param value: The value to set. Can be anything.
+ """
+ nodes = []
+ for index, node in enumerate(path.split(".")):
+ if index % 2:
+ # Nodes with odd indexes must be non-negative integers
+ try:
+ node = int(node)
+ except ValueError:
+ raise UnknownParameterError(path)
+ if node < 0:
+ raise UnknownParameterError(path)
+ nodes.append(node)
+ for node in nodes[:-1]:
+ tree = tree.setdefault(node, {})
+ tree[nodes[-1]] = value
+
+ def _ensure_tree(self, tree, parameter, node, *nodes):
+ """Check that C{node} exists in C{tree} and is followed by C{nodes}.
+
+ C{node} and C{nodes} should correspond to a template path (i.e. where
+ there are no absolute indexes, but C{#} instead).
+ """
+ if node == "#":
+ if len(nodes) == 0:
+ if len(tree.keys()) == 0 and not parameter.optional:
+ raise MissingParameterError(parameter.name)
+ else:
+ for subtree in tree.itervalues():
+ self._ensure_tree(subtree, parameter, *nodes)
+ else:
+ if len(nodes) == 0:
+ if node not in tree.keys():
+ # No value for this parameter is present, if it's not
+ # optional nor allow_none is set, the call below will
+ # raise a MissingParameterError
+ tree[node] = parameter.coerce(None)
+ else:
+ if node not in tree.keys():
+ tree[node] = {}
+ self._ensure_tree(tree[node], parameter, *nodes)
+
+ def _flatten(self, params, tree, path=""):
+ """
+ For every element in L{tree}, set C{path} to C{value} in the given
+ L{params} dictionary.
+
+ @param params: A L{dict} which will be populated.
+ @param tree: A structure made up of L{Argument}s, L{list}s, L{dict}s
+ and leaf values.
+ """
+ if isinstance(tree, Arguments):
+ for name, value in tree:
+ self._flatten(params, value, "%s.%s" % (path, name))
+ elif isinstance(tree, dict):
+ for name, value in tree.iteritems():
+ self._flatten(params, value, "%s.%s" % (path, name))
+ elif isinstance(tree, list):
+ for index, value in enumerate(tree):
+ self._flatten(params, value, "%s.%d" % (path, index + 1))
+ elif tree is not None:
+ params[path.lstrip(".")] = tree
+ else:
+ # None is discarded.
+ pass
=== added directory 'txaws/server/tests'
=== added file 'txaws/server/tests/__init__.py'
=== added file 'txaws/server/tests/test_call.py'
--- txaws/server/tests/test_call.py 1970-01-01 00:00:00 +0000
+++ txaws/server/tests/test_call.py 2011-05-14 08:15:55 +0000
@@ -0,0 +1,14 @@
+from twisted.trial.unittest import TestCase
+
+from txaws.server.call import Call
+
+
class CallTest(TestCase):

    def test_default_version(self):
        """
        A L{Call} created without an explicit version gets 2008-12-01, the
        earliest API version supported.
        """
        self.assertEqual("2008-12-01", Call().version)
=== added file 'txaws/server/tests/test_exception.py'
--- txaws/server/tests/test_exception.py 1970-01-01 00:00:00 +0000
+++ txaws/server/tests/test_exception.py 2011-05-14 08:15:55 +0000
@@ -0,0 +1,51 @@
+from unittest import TestCase
+
+from txaws.server.exception import APIError
+
+
class APIErrorTest(TestCase):

    def test_with_no_parameters(self):
        """
        The L{APIError} constructor must be passed either a code/message pair
        or a full response payload.
        """
        self.assertRaises(RuntimeError, APIError, 400)

    def test_with_response_and_code(self):
        """
        If the L{APIError} constructor is passed a full response payload, it
        can't be passed an error code.
        """
        self.assertRaises(RuntimeError, APIError, 400, code="FooBar",
                          response="foo bar")

    def test_with_response_and_message(self):
        """
        If the L{APIError} constructor is passed a full response payload, it
        can't be passed an error message.
        """
        self.assertRaises(RuntimeError, APIError, 400, message="Foo Bar",
                          response="foo bar")

    def test_with_code_and_no_message(self):
        """
        If the L{APIError} constructor is passed an error code, it must be
        passed an error message as well.
        """
        self.assertRaises(RuntimeError, APIError, 400, code="FooBar")

    def test_with_message_and_no_code(self):
        """
        If the L{APIError} constructor is passed an error message, it must be
        passed an error code as well.
        """
        self.assertRaises(RuntimeError, APIError, 400, message="Foo Bar")

    def test_with_string_status(self):
        """
        The L{APIError} constructor can be passed a C{str} as status code, and
        it will be converted to C{int}.
        """
        error = APIError("200", response="noes")
        self.assertEqual(200, error.status)
=== added file 'txaws/server/tests/test_resource.py'
--- txaws/server/tests/test_resource.py 1970-01-01 00:00:00 +0000
+++ txaws/server/tests/test_resource.py 2011-05-14 08:15:55 +0000
@@ -0,0 +1,428 @@
+from pytz import UTC
+from cStringIO import StringIO
+from datetime import datetime
+
+from twisted.trial.unittest import TestCase
+
+from txaws.credentials import AWSCredentials
+from txaws.service import AWSServiceEndpoint
+from txaws.ec2.client import Query
+from txaws.server.resource import QueryAPI
+
+
+class FakeRequest(object):
+
+ def __init__(self, params, endpoint):
+ self.params = params
+ self.endpoint = endpoint
+ self.written = StringIO()
+ self.finished = False
+ self.code = None
+ self.headers = {"Host": endpoint.get_canonical_host()}
+
+ @property
+ def args(self):
+ return dict((key, [value]) for key, value in self.params.iteritems())
+
+ @property
+ def method(self):
+ return self.endpoint.method
+
+ @property
+ def path(self):
+ return self.endpoint.path
+
+ def write(self, content):
+ assert isinstance(content, str), "Only strings should be written"
+ self.written.write(content)
+
+ def finish(self):
+ if self.code is None:
+ self.code = 200
+ self.finished = True
+
+ def setResponseCode(self, code):
+ self.code = code
+
+ def setHeader(self, key, value):
+ self.headers[key] = value
+
+ def getHeader(self, key):
+ return self.headers.get(key)
+
+ @property
+ def response(self):
+ return self.written.getvalue()
+
+
+class TestPrincipal(object):
+
+ def __init__(self, creds):
+ self.creds = creds
+
+ @property
+ def access_key(self):
+ return self.creds.access_key
+
+ @property
+ def secret_key(self):
+ return self.creds.secret_key
+
+
+class TestQueryAPI(QueryAPI):
+
+ actions = ["SomeAction"]
+ signature_versions = (1, 2)
+ content_type = "text/plain"
+
+ def __init__(self, *args, **kwargs):
+ QueryAPI.__init__(self, *args, **kwargs)
+ self.principal = None
+
+ def execute(self, call):
+ return "data"
+
+ def get_principal(self, access_key):
+ if self.principal and self.principal.access_key == access_key:
+ return self.principal
+
+ def dump_error(self, error, request):
+ return str("%s - %s" % (error.code, error.message))
+
+
+class QueryAPITest(TestCase):
+
+ def setUp(self):
+ super(QueryAPITest, self).setUp()
+ self.api = TestQueryAPI()
+
+ def test_handle(self):
+ """
+ L{QueryAPI.handle} forwards valid requests to L{QueryAPI.execute}.
+ """
+ creds = AWSCredentials("access", "secret")
+ endpoint = AWSServiceEndpoint("http://uri")
+ query = Query(action="SomeAction", creds=creds, endpoint=endpoint)
+ query.sign()
+ request = FakeRequest(query.params, endpoint)
+
+ def check(ignored):
+ self.assertTrue(request.finished)
+ self.assertEqual("data", request.response)
+ self.assertEqual("4", request.headers["Content-Length"])
+ self.assertEqual("text/plain", request.headers["Content-Type"])
+ self.assertEqual(200, request.code)
+
+ self.api.principal = TestPrincipal(creds)
+ return self.api.handle(request).addCallback(check)
+
+ def test_handle_pass_params_to_call(self):
+ """
+ L{QueryAPI.handle} creates a L{Call} object with the correct
+ parameters.
+ """
+ creds = AWSCredentials("access", "secret")
+ endpoint = AWSServiceEndpoint("http://uri")
+ query = Query(action="SomeAction", creds=creds, endpoint=endpoint,
+ other_params={"Foo": "bar", "Version": "1.2.3"})
+ query.sign()
+ request = FakeRequest(query.params, endpoint)
+
+ def execute(call):
+ self.assertEqual({"Foo": "bar"}, call.get_raw_params())
+ self.assertIdentical(self.api.principal, call.principal)
+ self.assertEqual("SomeAction", call.action)
+ self.assertEqual("1.2.3", call.version)
+ self.assertEqual(request.id, call.id)
+ return "ok"
+
+ def check(ignored):
+ self.assertEqual("ok", request.response)
+ self.assertEqual(200, request.code)
+
+ self.api.execute = execute
+ self.api.principal = TestPrincipal(creds)
+ return self.api.handle(request).addCallback(check)
+
+ def test_handle_empty_request(self):
+ """
+ If an empty request is received a message describing the API is
+ returned.
+ """
+ endpoint = AWSServiceEndpoint("http://uri")
+ request = FakeRequest({}, endpoint)
+ self.assertEqual("Query API Service", self.api.render(request))
+ self.assertEqual("text/plain", request.headers["Content-Type"])
+ self.assertEqual(None, request.code)
+
+ def test_handle_with_signature_version_1(self):
+ """SignatureVersion 1 is supported as well."""
+ creds = AWSCredentials("access", "secret")
+ endpoint = AWSServiceEndpoint("http://uri")
+ query = Query(action="SomeAction", creds=creds, endpoint=endpoint,
+ other_params={"SignatureVersion": "1"})
+ query.sign()
+ request = FakeRequest(query.params, endpoint)
+
+ def check(ignore):
+ self.assertEqual("data", request.response)
+ self.assertEqual(200, request.code)
+
+ self.api.principal = TestPrincipal(creds)
+ return self.api.handle(request).addCallback(check)
+
+ def test_handle_with_signature_sha1(self):
+ """
+ The C{HmacSHA1} signature method is supported, in which case the
+ signing using sha1 instead.
+ """
+ creds = AWSCredentials("access", "secret")
+ endpoint = AWSServiceEndpoint("http://uri")
+ query = Query(action="SomeAction", creds=creds, endpoint=endpoint)
+ query.sign(hash_type="sha1")
+ request = FakeRequest(query.params, endpoint)
+
+ def check(ignore):
+ self.assertEqual("data", request.response)
+ self.assertEqual(200, request.code)
+
+ self.api.principal = TestPrincipal(creds)
+ return self.api.handle(request).addCallback(check)
+
+ def test_handle_with_unsupported_version(self):
+ """If signature versions is not supported an error is raised."""
+ creds = AWSCredentials("access", "secret")
+ endpoint = AWSServiceEndpoint("http://uri")
+ query = Query(action="SomeAction", creds=creds, endpoint=endpoint)
+ query.sign()
+ request = FakeRequest(query.params, endpoint)
+
+ def check(ignored):
+ self.flushLoggedErrors()
+ self.assertEqual("InvalidSignature - SignatureVersion '2' "
+ "not supported", request.response)
+ self.assertEqual(403, request.code)
+
+ self.api.signature_versions = (1,)
+ return self.api.handle(request).addCallback(check)
+
+ def test_handle_with_internal_error(self):
+ """
+ If an unknown error occurs while handling the request,
+ L{QueryAPI.handle} responds with HTTP status 500.
+ """
+ creds = AWSCredentials("access", "secret")
+ endpoint = AWSServiceEndpoint("http://uri")
+ query = Query(action="SomeAction", creds=creds, endpoint=endpoint)
+ query.sign()
+ request = FakeRequest(query.params, endpoint)
+
+ self.api.execute = lambda call: 1 / 0
+
+ def check(ignored):
+ self.flushLoggedErrors()
+ self.assertTrue(request.finished)
+ self.assertEqual("integer division or modulo by zero",
+ request.response)
+ self.assertEqual(500, request.code)
+
+ self.api.principal = TestPrincipal(creds)
+ return self.api.handle(request).addCallback(check)
+
+ def test_handle_with_parameter_error(self):
+ """
+ If an error occurs while parsing the parameters, L{QueryAPI.handle}
+ responds with HTTP status 400.
+ """
+ creds = AWSCredentials("access", "secret")
+ endpoint = AWSServiceEndpoint("http://uri")
+ query = Query(action="SomeAction", creds=creds, endpoint=endpoint)
+ query.sign()
+ query.params.pop("Action")
+ request = FakeRequest(query.params, endpoint)
+
+ def check(ignored):
+ self.flushLoggedErrors()
+ self.assertEqual("MissingParameter - The request must contain "
+ "the parameter Action", request.response)
+ self.assertEqual(400, request.code)
+
+ return self.api.handle(request).addCallback(check)
+
+ def test_handle_with_unsupported_action(self):
+ """Only actions listed in L{QueryAPI.actions} are supported."""
+ creds = AWSCredentials("access", "secret")
+ endpoint = AWSServiceEndpoint("http://uri")
+ query = Query(action="FooBar", creds=creds, endpoint=endpoint)
+ query.sign()
+ request = FakeRequest(query.params, endpoint)
+
+ def check(ignored):
+ self.flushLoggedErrors()
+ self.assertEqual("InvalidAction - The action FooBar is not valid"
+ " for this web service.", request.response)
+ self.assertEqual(400, request.code)
+
+ return self.api.handle(request).addCallback(check)
+
+ def test_handle_with_non_existing_user(self):
+ """
+ If no L{Principal} can be found with the given access key ID,
+ L{QueryAPI.handle} responds with HTTP status 400.
+ """
+ creds = AWSCredentials("access", "secret")
+ endpoint = AWSServiceEndpoint("http://uri")
+ query = Query(action="SomeAction", creds=creds, endpoint=endpoint)
+ query.sign()
+ request = FakeRequest(query.params, endpoint)
+
+ def check(ignored):
+ self.flushLoggedErrors()
+ self.assertEqual("AuthFailure - No user with access key 'access'",
+ request.response)
+ self.assertEqual(401, request.code)
+
+ return self.api.handle(request).addCallback(check)
+
+ def test_handle_with_wrong_signature(self):
+ """
+ If the signature in the request doesn't match the one calculated with
+ the locally stored secret access key, and error is returned.
+ """
+ creds = AWSCredentials("access", "secret")
+ endpoint = AWSServiceEndpoint("http://uri")
+ query = Query(action="SomeAction", creds=creds, endpoint=endpoint)
+ query.sign()
+ query.params["Signature"] = "wrong"
+ request = FakeRequest(query.params, endpoint)
+
+ def check(ignored):
+ self.flushLoggedErrors()
+ self.assertEqual("SignatureDoesNotMatch - The request signature "
+ "we calculated does not match the signature you "
+ "provided. Check your key and signing method.",
+ request.response)
+ self.assertEqual(403, request.code)
+
+ self.api.principal = TestPrincipal(creds)
+ return self.api.handle(request).addCallback(check)
+
+ def test_handle_with_timestamp_and_expires(self):
+ """
+ If the request contains both Expires and Timestamp parameters,
+ an error is returned.
+ """
+ creds = AWSCredentials("access", "secret")
+ endpoint = AWSServiceEndpoint("http://uri")
+ query = Query(action="SomeAction", creds=creds, endpoint=endpoint,
+ other_params={"Timestamp": "2010-01-01T12:00:00Z",
+ "Expires": "2010-01-01T12:00:00Z"})
+ query.sign()
+ request = FakeRequest(query.params, endpoint)
+
+ def check(ignored):
+ self.flushLoggedErrors()
+ self.assertEqual(
+ "InvalidParameterCombination - The parameter Timestamp"
+ " cannot be used with the parameter Expires",
+ request.response)
+ self.assertEqual(400, request.code)
+
+ return self.api.handle(request).addCallback(check)
+
+ def test_handle_with_non_expired_signature(self):
+ """
+ If the request contains an Expires parameter with a time that is before
+ the current time, everything is fine.
+ """
+ creds = AWSCredentials("access", "secret")
+ endpoint = AWSServiceEndpoint("http://uri")
+ query = Query(action="SomeAction", creds=creds, endpoint=endpoint,
+ other_params={"Expires": "2010-01-01T12:00:00Z"})
+ query.sign()
+ request = FakeRequest(query.params, endpoint)
+
+ def check(ignored):
+ self.assertEqual("data", request.response)
+ self.assertEqual(200, request.code)
+
+ now = datetime(2009, 12, 31, tzinfo=UTC)
+ self.api.get_utc_time = lambda: now
+ self.api.principal = TestPrincipal(creds)
+ return self.api.handle(request).addCallback(check)
+
+ def test_handle_with_expired_signature(self):
+ """
+ If the request contains an Expires parameter with a time that is before
+ the current time, an error is returned.
+ """
+ creds = AWSCredentials("access", "secret")
+ endpoint = AWSServiceEndpoint("http://uri")
+ query = Query(action="SomeAction", creds=creds, endpoint=endpoint,
+ other_params={"Expires": "2010-01-01T12:00:00Z"})
+ query.sign()
+ request = FakeRequest(query.params, endpoint)
+
+ def check(ignored):
+ self.flushLoggedErrors()
+ self.assertEqual(
+ "RequestExpired - Request has expired. Expires date is"
+ " 2010-01-01T12:00:00Z", request.response)
+ self.assertEqual(400, request.code)
+
+ now = datetime(2010, 1, 1, 12, 0, 1, tzinfo=UTC)
+ self.api.get_utc_time = lambda: now
+ return self.api.handle(request).addCallback(check)
+
+ def test_handle_with_post_method(self):
+ """
+ L{QueryAPI.handle} forwards valid requests using the HTTP POST method
+ to L{QueryAPI.execute}.
+ """
+ creds = AWSCredentials("access", "secret")
+ endpoint = AWSServiceEndpoint("http://uri", method="POST")
+ query = Query(action="SomeAction", creds=creds, endpoint=endpoint)
+ query.sign()
+ request = FakeRequest(query.params, endpoint)
+
+ def check(ignored):
+ self.assertEqual("data", request.response)
+ self.assertEqual(200, request.code)
+
+ self.api.principal = TestPrincipal(creds)
+ return self.api.handle(request).addCallback(check)
+
+ def test_handle_with_port_number(self):
+ """
+ If the request Host header includes a port number, it's included
+ in the text that get signed when checking the signature.
+ """
+ creds = AWSCredentials("access", "secret")
+ endpoint = AWSServiceEndpoint("http://endpoint:1234")
+ query = Query(action="SomeAction", creds=creds, endpoint=endpoint)
+ query.sign()
+ request = FakeRequest(query.params, endpoint)
+
+ def check(ignored):
+ self.assertEqual("data", request.response)
+ self.assertEqual(200, request.code)
+
+ self.api.principal = TestPrincipal(creds)
+ return self.api.handle(request).addCallback(check)
+
+ def test_handle_with_endpoint_with_terminating_slash(self):
+ """
+ Check signature should handle a urs with a terminating slash.
+ """
+ creds = AWSCredentials("access", "secret")
+ endpoint = AWSServiceEndpoint("http://endpoint/")
+ query = Query(action="SomeAction", creds=creds, endpoint=endpoint)
+ query.sign()
+ request = FakeRequest(query.params, endpoint)
+
+ def check(ignored):
+ self.assertEqual("data", request.response)
+ self.assertEqual(200, request.code)
+
+ self.api.principal = TestPrincipal(creds)
+ return self.api.handle(request).addCallback(check)
=== added file 'txaws/server/tests/test_schema.py'
--- txaws/server/tests/test_schema.py 1970-01-01 00:00:00 +0000
+++ txaws/server/tests/test_schema.py 2011-05-14 08:15:55 +0000
@@ -0,0 +1,493 @@
+from datetime import datetime
+
+from pytz import UTC, FixedOffset
+
+from twisted.trial.unittest import TestCase
+
+from txaws.server.exception import APIError
+from txaws.server.schema import (
+ Arguments, Bool, Date, Enum, Integer, Parameter, RawStr, Schema, Unicode)
+
+
+class ArgumentsTest(TestCase):
+
+ def test_instantiate_empty(self):
+ """Creating an L{Arguments} object."""
+ arguments = Arguments({})
+ self.assertEqual({}, arguments.__dict__)
+
+ def test_instantiate_non_empty(self):
+ """Creating an L{Arguments} object with some arguments."""
+ arguments = Arguments({"foo": 123, "bar": 456})
+ self.assertEqual(123, arguments.foo)
+ self.assertEqual(456, arguments.bar)
+
+ def test_iterate(self):
+ """L{Arguments} returns an iterator with both keys and values."""
+ arguments = Arguments({"foo": 123, "bar": 456})
+ self.assertEqual([("foo", 123), ("bar", 456)], list(arguments))
+
+ def test_getitem(self):
+ """Values can be looked up using C{[index]} notation."""
+ arguments = Arguments({1: "a", 2: "b", "foo": "bar"})
+ self.assertEqual("b", arguments[2])
+ self.assertEqual("bar", arguments["foo"])
+
+ def test_getitem_error(self):
+ """L{KeyError} is raised when the argument is not found."""
+ arguments = Arguments({})
+ self.assertRaises(KeyError, arguments.__getitem__, 1)
+
+ def test_len(self):
+ """C{len()} can be used with an L{Arguments} instance."""
+ self.assertEqual(0, len(Arguments({})))
+ self.assertEqual(1, len(Arguments({1: 2})))
+
+ def test_nested_data(self):
+ """L{Arguments} can cope fine with nested data structures."""
+ arguments = Arguments({"foo": Arguments({"bar": "egg"})})
+ self.assertEqual("egg", arguments.foo.bar)
+
+ def test_nested_data_with_numbers(self):
+ """L{Arguments} can cope fine with list items."""
+ arguments = Arguments({"foo": {1: "egg"}})
+ self.assertEqual("egg", arguments.foo[0])
+
+
+class ParameterTest(TestCase):
+
+ def test_coerce(self):
+ """
+ L{Parameter.coerce} coerces a request argument with a single value.
+ """
+ parameter = Parameter("Test")
+ parameter.parse = lambda value: value
+ self.assertEqual("foo", parameter.coerce("foo"))
+
+ def test_coerce_with_optional(self):
+ """L{Parameter.coerce} returns C{None} if the parameter is optional."""
+ parameter = Parameter("Test", optional=True)
+ self.assertEqual(None, parameter.coerce(None))
+
+ def test_coerce_with_required(self):
+ """
+ L{Parameter.coerce} raises an L{APIError} if the parameter is
+ required but not present in the request.
+ """
+ parameter = Parameter("Test")
+ error = self.assertRaises(APIError, parameter.coerce, None)
+ self.assertEqual(400, error.status)
+ self.assertEqual("MissingParameter", error.code)
+ self.assertEqual("The request must contain the parameter Test",
+ error.message)
+
+ def test_coerce_with_default(self):
+ """
+ L{Parameter.coerce} returns F{Parameter.default} if the parameter is
+ optional and not present in the request.
+ """
+ parameter = Parameter("Test", optional=True, default=123)
+ self.assertEqual(123, parameter.coerce(None))
+
+ def test_coerce_with_parameter_error(self):
+ """
+ L{Parameter.coerce} raises an L{APIError} if an invalid value is
+ passed as request argument.
+ """
+ parameter = Parameter("Test")
+ parameter.parse = lambda value: int(value)
+ parameter.kind = "integer"
+ error = self.assertRaises(APIError, parameter.coerce, "foo")
+ self.assertEqual(400, error.status)
+ self.assertEqual("InvalidParameterValue", error.code)
+ self.assertEqual("Invalid integer value foo", error.message)
+
+ def test_coerce_with_empty_strings(self):
+ """
+ L{Parameter.coerce} returns C{None} if the value is an empty string and
+ C{allow_none} is C{True}.
+ """
+ parameter = Parameter("Test", allow_none=True)
+ self.assertEqual(None, parameter.coerce(""))
+
+ def test_coerce_with_empty_strings_error(self):
+ """
+ L{Parameter.coerce} raises an error if the value is an empty string and
+ C{allow_none} is not C{True}.
+ """
+ parameter = Parameter("Test")
+ error = self.assertRaises(APIError, parameter.coerce, "")
+ self.assertEqual(400, error.status)
+ self.assertEqual("MissingParameter", error.code)
+ self.assertEqual("The request must contain the parameter Test",
+ error.message)
+
+ def test_coerce_with_min(self):
+ """
+ L{Parameter.coerce} raises an error if the given value is lower than
+ the lower bound.
+ """
+ parameter = Parameter("Test", min=50)
+ parameter.measure = lambda value: int(value)
+ parameter.lower_than_min_template = "Please give me at least %s"
+ error = self.assertRaises(APIError, parameter.coerce, "4")
+ self.assertEqual(400, error.status)
+ self.assertEqual("InvalidParameterValue", error.code)
+ self.assertEqual("Value (4) for parameter Test is invalid. "
+ "Please give me at least 50", error.message)
+
+ def test_coerce_with_max(self):
+ """
+ L{Parameter.coerce} raises an error if the given value is greater than
+ the upper bound.
+ """
+ parameter = Parameter("Test", max=3)
+ parameter.measure = lambda value: len(value)
+ parameter.greater_than_max_template = "%s should be enough for anybody"
+ error = self.assertRaises(APIError, parameter.coerce, "longish")
+ self.assertEqual(400, error.status)
+ self.assertEqual("InvalidParameterValue", error.code)
+ self.assertEqual("Value (longish) for parameter Test is invalid. "
+ "3 should be enough for anybody", error.message)
+
+
+class UnicodeTest(TestCase):
+
+ def test_parse(self):
+ """L{Unicode.parse} converts the given raw C{value} to C{unicode}."""
+ parameter = Unicode("Test")
+ self.assertEqual(u"foo", parameter.parse("foo"))
+
+ def test_format(self):
+ """L{Unicode.format} encodes the given C{unicode} with utf-8."""
+ parameter = Unicode("Test")
+ value = parameter.format(u"fo\N{TAGBANWA LETTER SA}")
+ self.assertEqual("fo\xe1\x9d\xb0", value)
+ self.assertTrue(isinstance(value, str))
+
+ def test_min_and_max(self):
+ """The L{Unicode} parameter properly supports ranges."""
+ parameter = Unicode("Test", min=2, max=4)
+
+ error = self.assertRaises(APIError, parameter.coerce, "a")
+ self.assertEqual(400, error.status)
+ self.assertEqual("InvalidParameterValue", error.code)
+ self.assertIn("Length must be at least 2.", error.message)
+
+ error = self.assertRaises(APIError, parameter.coerce, "abcde")
+ self.assertIn("Length exceeds maximum of 4.", error.message)
+ self.assertEqual(400, error.status)
+ self.assertEqual("InvalidParameterValue", error.code)
+
+
+class RawStrTest(TestCase):
+
+ def test_parse(self):
+ """L{RawStr.parse checks that the given raw C{value} is a string."""
+ parameter = RawStr("Test")
+ self.assertEqual("foo", parameter.parse("foo"))
+
+ def test_format(self):
+ """L{RawStr.format} simply returns the given string."""
+ parameter = RawStr("Test")
+ value = parameter.format("foo")
+ self.assertEqual("foo", value)
+ self.assertTrue(isinstance(value, str))
+
+
+class IntegerTest(TestCase):
+
+ def test_parse(self):
+ """L{Integer.parse} converts the given raw C{value} to C{int}."""
+ parameter = Integer("Test")
+ self.assertEqual(123, parameter.parse("123"))
+
+ def test_parse_wiith_negative(self):
+ """L{Integer.parse} converts the given raw C{value} to C{int}."""
+ parameter = Integer("Test")
+ self.assertRaises(ValueError, parameter.parse, "-1")
+
+ def test_format(self):
+ """L{Integer.format} converts the given integer to a string."""
+ parameter = Integer("Test")
+ self.assertEqual("123", parameter.format(123))
+
+
+class BoolTest(TestCase):
+
+ def test_parse(self):
+ """L{Bool.parse} converts 'true' to C{True}."""
+ parameter = Bool("Test")
+ self.assertEqual(True, parameter.parse("true"))
+
+ def test_parse_with_false(self):
+ """L{Bool.parse} converts 'false' to C{False}."""
+ parameter = Bool("Test")
+ self.assertEqual(False, parameter.parse("false"))
+
+ def test_parse_with_error(self):
+ """
+ L{Bool.parse} raises C{ValueError} if the given value is neither 'true'
+ or 'false'.
+ """
+ parameter = Bool("Test")
+ self.assertRaises(ValueError, parameter.parse, "0")
+
+ def test_format(self):
+ """L{Bool.format} converts the given boolean to either '0' or '1'."""
+ parameter = Bool("Test")
+ self.assertEqual("true", parameter.format(True))
+ self.assertEqual("false", parameter.format(False))
+
+
+class EnumTest(TestCase):
+
+ def test_parse(self):
+ """L{Enum.parse} accepts a map for translating values."""
+ parameter = Enum("Test", {"foo": "bar"})
+ self.assertEqual("bar", parameter.parse("foo"))
+
+ def test_parse_with_error(self):
+ """
+ L{Bool.parse} raises C{ValueError} if the given value is not
+ present in the mapping.
+ """
+ parameter = Enum("Test", {})
+ self.assertRaises(ValueError, parameter.parse, "bar")
+
+ def test_format(self):
+ """L{Enum.format} converts back the given value to the original map."""
+ parameter = Enum("Test", {"foo": "bar"})
+ self.assertEqual("foo", parameter.format("bar"))
+
+
+class DateTest(TestCase):
+
+ def test_parse(self):
+ """L{Date.parse checks that the given raw C{value} is a date/time."""
+ parameter = Date("Test")
+ date = datetime(2010, 9, 15, 23, 59, 59, tzinfo=UTC)
+ self.assertEqual(date, parameter.parse("2010-09-15T23:59:59Z"))
+
+ def test_format(self):
+ """
+ L{Date.format} returns a string representation of the given datetime
+ instance.
+ """
+ parameter = Date("Test")
+ date = datetime(2010, 9, 15, 23, 59, 59,
+ tzinfo=FixedOffset(120))
+ self.assertEqual("2010-09-15T21:59:59Z", parameter.format(date))
+
+
+class SchemaTest(TestCase):
+
+ def test_extract(self):
+ """
+ L{Schema.extract} returns an L{Argument} object whose attributes are
+ the arguments extracted from the given C{request}, as specified.
+ """
+ schema = Schema(Unicode("name"))
+ arguments, _ = schema.extract({"name": "value"})
+ self.assertEqual("value", arguments.name)
+
+ def test_extract_with_rest(self):
+ """
+ L{Schema.extract} stores unknown parameters in the 'rest' return
+ dictionary.
+ """
+ schema = Schema()
+ _, rest = schema.extract({"name": "value"})
+ self.assertEqual(rest, {"name": "value"})
+
+ def test_extract_with_many_arguments(self):
+ """L{Schema.extract} can handle multiple parameters."""
+ schema = Schema(Unicode("name"), Integer("count"))
+ arguments, _ = schema.extract({"name": "value", "count": "123"})
+ self.assertEqual(u"value", arguments.name)
+ self.assertEqual(123, arguments.count)
+
+ def test_extract_with_optional(self):
+ """L{Schema.extract} can handle optional parameters."""
+ schema = Schema(Unicode("name"), Integer("count", optional=True))
+ arguments, _ = schema.extract({"name": "value"})
+ self.assertEqual(u"value", arguments.name)
+ self.assertEqual(None, arguments.count)
+
+ def test_extract_with_numbered(self):
+ """
+ L{Schema.extract} can handle parameters with numbered values.
+ """
+ schema = Schema(Unicode("name.n"))
+ arguments, _ = schema.extract({"name.0": "Joe", "name.1": "Tom"})
+ self.assertEqual("Joe", arguments.name[0])
+ self.assertEqual("Tom", arguments.name[1])
+
+ def test_extract_with_single_numbered(self):
+ """
+ L{Schema.extract} can handle a single parameter with a numbered value.
+ """
+ schema = Schema(Unicode("name.n"))
+ arguments, _ = schema.extract({"name.0": "Joe"})
+ self.assertEqual("Joe", arguments.name[0])
+
+ def test_extract_complex(self):
+ """L{Schema} can cope with complex schemas."""
+ schema = Schema(
+ Unicode("GroupName"),
+ RawStr("IpPermissions.n.IpProtocol"),
+ Integer("IpPermissions.n.FromPort"),
+ Integer("IpPermissions.n.ToPort"),
+ Unicode("IpPermissions.n.Groups.m.UserId", optional=True),
+ Unicode("IpPermissions.n.Groups.m.GroupName", optional=True))
+
+ arguments, _ = schema.extract(
+ {"GroupName": "Foo",
+ "IpPermissions.1.IpProtocol": "tcp",
+ "IpPermissions.1.FromPort": "1234",
+ "IpPermissions.1.ToPort": "5678",
+ "IpPermissions.1.Groups.1.GroupName": "Bar",
+ "IpPermissions.1.Groups.2.GroupName": "Egg"})
+
+ self.assertEqual(u"Foo", arguments.GroupName)
+ self.assertEqual(1, len(arguments.IpPermissions))
+ self.assertEqual(1234, arguments.IpPermissions[0].FromPort)
+ self.assertEqual(5678, arguments.IpPermissions[0].ToPort)
+ self.assertEqual(2, len(arguments.IpPermissions[0].Groups))
+ self.assertEqual("Bar", arguments.IpPermissions[0].Groups[0].GroupName)
+ self.assertEqual("Egg", arguments.IpPermissions[0].Groups[1].GroupName)
+
+ def test_extract_with_multiple_parameters_in_singular_schema(self):
+ """
+ If multiple parameters are passed in to a Schema element that is not
+ flagged as supporting multiple values then we should throw an
+ C{APIError}.
+ """
+ schema = Schema(Unicode("name"))
+ params = {"name.1": "value", "name.2": "value2"}
+ error = self.assertRaises(APIError, schema.extract, params)
+ self.assertEqual(400, error.status)
+ self.assertEqual("InvalidParameterCombination", error.code)
+ self.assertEqual("The parameter 'name' may only be specified once.",
+ error.message)
+
+ def test_extract_with_mixed(self):
+ """
+ L{Schema.extract} stores in the rest result all numbered parameters
+ given without an index.
+ """
+ schema = Schema(Unicode("name.n"))
+ _, rest = schema.extract({"name": "foo", "name.1": "bar"})
+ self.assertEqual(rest, {"name": "foo"})
+
+ def test_extract_with_non_numbered_template(self):
+ """
+ L{Schema.extract} accepts a single numbered argument even if the
+ associated template is not numbered.
+ """
+ schema = Schema(Unicode("name"))
+ arguments, _ = schema.extract({"name.1": "foo"})
+ self.assertEqual("foo", arguments.name)
+
+ def test_extract_with_non_integer_index(self):
+ """
+ L{Schema.extract} raises an error when trying to pass a numbered
+ parameter with a non-integer index.
+ """
+ schema = Schema(Unicode("name.n"))
+ params = {"name.one": "foo"}
+ error = self.assertRaises(APIError, schema.extract, params)
+ self.assertEqual(400, error.status)
+ self.assertEqual("UnknownParameter", error.code)
+ self.assertEqual("The parameter name.one is not recognized",
+ error.message)
+
+ def test_extract_with_negative_index(self):
+ """
+ L{Schema.extract} raises an error when trying to pass a numbered
+ parameter with a negative index.
+ """
+ schema = Schema(Unicode("name.n"))
+ params = {"name.-1": "foo"}
+ error = self.assertRaises(APIError, schema.extract, params)
+ self.assertEqual(400, error.status)
+ self.assertEqual("UnknownParameter", error.code)
+ self.assertEqual("The parameter name.-1 is not recognized",
+ error.message)
+
+ def test_bundle(self):
+ """
+ L{Schema.bundle} returns a dictionary of raw parameters that
+ can be used for an EC2-style query.
+ """
+ schema = Schema(Unicode("name"))
+ params = schema.bundle(name="foo")
+ self.assertEqual({"name": "foo"}, params)
+
+ def test_bundle_with_numbered(self):
+ """
+ L{Schema.bundle} correctly handles numbered arguments.
+ """
+ schema = Schema(Unicode("name.n"))
+ params = schema.bundle(name=["foo", "bar"])
+ self.assertEqual({"name.1": "foo", "name.2": "bar"}, params)
+
+ def test_bundle_with_none(self):
+ """L{None} values are discarded in L{Schema.bundle}."""
+ schema = Schema(Unicode("name.n", optional=True))
+ params = schema.bundle(name=None)
+ self.assertEqual({}, params)
+
+ def test_bundle_with_empty_numbered(self):
+ """
+ L{Schema.bundle} correctly handles an empty numbered arguments list.
+ """
+ schema = Schema(Unicode("name.n"))
+ params = schema.bundle(names=[])
+ self.assertEqual({}, params)
+
+ def test_bundle_with_numbered_not_supplied(self):
+ """
+ L{Schema.bundle} ignores parameters that are not present.
+ """
+ schema = Schema(Unicode("name.n"))
+ params = schema.bundle()
+ self.assertEqual({}, params)
+
+ def test_bundle_with_multiple(self):
+ """
+ L{Schema.bundle} correctly handles multiple arguments.
+ """
+ schema = Schema(Unicode("name.n"), Integer("count"))
+ params = schema.bundle(name=["Foo", "Bar"], count=123)
+ self.assertEqual({"name.1": "Foo", "name.2": "Bar", "count": "123"},
+ params)
+
+ def test_bundle_with_arguments(self):
+ """L{Schema.bundle} can bundle L{Arguments} too."""
+ schema = Schema(Unicode("name.n"), Integer("count"))
+ arguments = Arguments({"name": Arguments({1: "Foo", 7: "Bar"}),
+ "count": 123})
+ params = schema.bundle(arguments)
+ self.assertEqual({"name.1": "Foo", "name.7": "Bar", "count": "123"},
+ params)
+
+ def test_bundle_with_arguments_and_extra(self):
+ """
+ L{Schema.bundle} can bundle L{Arguments} with keyword arguments too.
+
+ Keyword arguments take precedence.
+ """
+ schema = Schema(Unicode("name.n"), Integer("count"))
+ arguments = Arguments({"name": {1: "Foo", 7: "Bar"}, "count": 321})
+ params = schema.bundle(arguments, count=123)
+ self.assertEqual({"name.1": "Foo", "name.2": "Bar", "count": "123"},
+ params)
+
+ def test_bundle_with_missing_parameter(self):
+ """
+ L{Schema.bundle} raises an exception one of the given parameters
+ doesn't exist in the schema.
+ """
+ schema = Schema(Integer("count"))
+ self.assertRaises(RuntimeError, schema.bundle, name="foo")