← Back to team overview

yellow team mailing list archive

[Merge] lp:~frankban/python-shell-toolbox/helpers into lp:python-shell-toolbox

 

Francesco Banconi has proposed merging lp:~frankban/python-shell-toolbox/helpers into lp:python-shell-toolbox.

Requested reviews:
  Launchpad Yellow Squad (yellow)

For more details, see:
https://code.launchpad.net/~frankban/python-shell-toolbox/helpers/+merge/96180

== Changes ==

- Added several helper functions:
  - bzr_whois
  - file_append
  - file_prepend
  - generate_ssh_keys
  - get_su_command
  - join_command
  - mkdirs
  - ssh
  - user_exists

- Added `environ` context manager and updated `su` to use it.

- Changed `apt_get_install` helper: now the function uses environ to run dpkg
  in non-interactive mode.

- Changed `get_user_home` to handle non-existent users: instead of raising
  `KeyError`, it now returns a default `/home/[username]` home directory.

- Updated the `install_extra_repositories` helper to accept distribution
  placeholders.

- Updated `cd` context manager: yield in a try/finally block.

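- Fixed `grep` helper: it now uses `re.search` instead of `re.match`, so the
  given regular expression can match anywhere in a line, not only at its start.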

== Tests ==

python tests.py
.....................................................
----------------------------------------------------------------------
Ran 53 tests in 0.124s

OK
-- 
https://code.launchpad.net/~frankban/python-shell-toolbox/helpers/+merge/96180
Your team Launchpad Yellow Squad is requested to review the proposed merge of lp:~frankban/python-shell-toolbox/helpers into lp:python-shell-toolbox.
=== modified file 'setup.py'
--- setup.py	2012-03-01 20:49:48 +0000
+++ setup.py	2012-03-06 16:50:24 +0000
@@ -9,7 +9,7 @@
 ez_setup.use_setuptools()
 
 
-__version__ = '0.1.1'
+__version__ = '0.2.0'
 
 from setuptools import setup
 

=== modified file 'shelltoolbox/__init__.py'
--- shelltoolbox/__init__.py	2012-03-02 14:59:46 +0000
+++ shelltoolbox/__init__.py	2012-03-06 16:50:24 +0000
@@ -19,25 +19,39 @@
 __metaclass__ = type
 __all__ = [
     'apt_get_install',
+    'bzr_whois',
     'cd',
     'command',
     'DictDiffer',
+    'environ',
+    'file_append',
+    'file_prepend',
+    'generate_ssh_keys',
+    'get_su_command',
+    'get_user_home',
     'get_user_ids',
-    'get_user_home',
     'get_value_from_line',
     'grep',
     'install_extra_repositories',
+    'join_command',
+    'mkdirs',
     'run',
     'Serializer',
     'script_name',
+    'ssh',
     'su',
+    'user_exists',
+    'wait_for_page_contents',
     ]
 
 from collections import namedtuple
 from contextlib import contextmanager
+from email.Utils import parseaddr
+import errno
 import json
 import operator
 import os
+import pipes
 import pwd
 import re
 import subprocess
@@ -50,26 +64,50 @@
 Env = namedtuple('Env', 'uid gid home')
 
 
-def run(*args, **kwargs):
-    """Run the command with the given arguments.
-
-    The first argument is the path to the command to run.
-    Subsequent arguments are command-line arguments to be passed.
-
-    This function accepts all optional keyword arguments accepted by
-    `subprocess.Popen`.
-    """
-    args = [i for i in args if i is not None]
-    pipe = subprocess.PIPE
-    process = subprocess.Popen(
-        args, stdout=kwargs.pop('stdout', pipe),
-        stderr=kwargs.pop('stderr', pipe),
-        close_fds=kwargs.pop('close_fds', True), **kwargs)
-    stdout, stderr = process.communicate()
-    if process.returncode:
-        raise subprocess.CalledProcessError(
-            process.returncode, repr(args), output=stdout+stderr)
-    return stdout
+def apt_get_install(*args, **kwargs):
+    """Install given packages using apt.
+
+    It is possible to pass environment variables to be set during install
+    using keyword arguments.
+
+    :raises: subprocess.CalledProcessError
+    """
+    debian_frontend = kwargs.pop('DEBIAN_FRONTEND', 'noninteractive')
+    with environ(DEBIAN_FRONTEND=debian_frontend, **kwargs):
+        cmd = ('apt-get', '-y', 'install') + args
+        return run(*cmd)
+
+
+def bzr_whois(user):
+    """Return full name and email of bzr `user`.
+
+    Return None if the given `user` does not have a bzr user id.
+    """
+    with su(user):
+        try:
+            whoami = run('bzr', 'whoami')
+        except (subprocess.CalledProcessError, OSError):
+            return None
+    return parseaddr(whoami)
+
+
+@contextmanager
+def cd(directory):
+    """A context manager to temporarily change current working dir, e.g.::
+
+        >>> import os
+        >>> os.chdir('/tmp')
+        >>> with cd('/bin'): print os.getcwd()
+        /bin
+        >>> print os.getcwd()
+        /tmp
+    """
+    cwd = os.getcwd()
+    os.chdir(directory)
+    try:
+        yield
+    finally:
+        os.chdir(cwd)
 
 
 def command(*base_args):
@@ -97,39 +135,392 @@
     return callable_command
 
 
-apt_get_install = command('apt-get', 'install', '-y', '--force-yes')
+@contextmanager
+def environ(**kwargs):
+    """A context manager to temporarily change environment variables.
+
+    If an existing environment variable is changed, it is restored during
+    context cleanup::
+
+        >>> import os
+        >>> os.environ['MY_VARIABLE'] = 'foo'
+        >>> with environ(MY_VARIABLE='bar'): print os.getenv('MY_VARIABLE')
+        bar
+        >>> print os.getenv('MY_VARIABLE')
+        foo
+        >>> del os.environ['MY_VARIABLE']
+
+    If we are adding environment variables, they are removed during context
+    cleanup::
+
+        >>> import os
+        >>> with environ(MY_VAR1='foo', MY_VAR2='bar'):
+        ...     print os.getenv('MY_VAR1'), os.getenv('MY_VAR2')
+        foo bar
+        >>> os.getenv('MY_VAR1') == os.getenv('MY_VAR2') == None
+        True
+    """
+    backup = {}
+    for key, value in kwargs.items():
+        backup[key] = os.getenv(key)
+        os.environ[key] = value
+    try:
+        yield
+    finally:
+        for key, value in backup.items():
+            if value is None:
+                del os.environ[key]
+            else:
+                os.environ[key] = value
+
+
+def file_append(filename, line):
+    r"""Append given `line`, if not present, at the end of `filename`.
+
+    Usage example::
+
+        >>> import tempfile
+        >>> f = tempfile.NamedTemporaryFile('w', delete=False)
+        >>> f.write('line1\n')
+        >>> f.close()
+        >>> file_append(f.name, 'new line\n')
+        >>> open(f.name).read()
+        'line1\nnew line\n'
+
+    Nothing happens if the file already contains the given `line`::
+
+        >>> file_append(f.name, 'new line\n')
+        >>> open(f.name).read()
+        'line1\nnew line\n'
+
+    A new line is automatically added before the given `line` if it is not
+    present at the end of current file content::
+
+        >>> import tempfile
+        >>> f = tempfile.NamedTemporaryFile('w', delete=False)
+        >>> f.write('line1')
+        >>> f.close()
+        >>> file_append(f.name, 'new line\n')
+        >>> open(f.name).read()
+        'line1\nnew line\n'
+
+    The file is created if it does not exist::
+
+        >>> import tempfile
+        >>> filename = tempfile.mktemp()
+        >>> file_append(filename, 'line1\n')
+        >>> open(filename).read()
+        'line1\n'
+    """
+    with open(filename, 'a+') as f:
+        content = f.read()
+        if line not in content:
+            if content.endswith('\n') or not content:
+                f.write(line)
+            else:
+                f.write('\n' + line)
+
+
+def file_prepend(filename, line):
+    r"""Insert given `line`, if not present, at the beginning of `filename`.
+
+    Usage example::
+
+        >>> import tempfile
+        >>> f = tempfile.NamedTemporaryFile('w', delete=False)
+        >>> f.write('line1\n')
+        >>> f.close()
+        >>> file_prepend(f.name, 'line0\n')
+        >>> open(f.name).read()
+        'line0\nline1\n'
+
+    If the file starts with the given `line`, nothing happens::
+
+        >>> file_prepend(f.name, 'line0\n')
+        >>> open(f.name).read()
+        'line0\nline1\n'
+
+    If the file contains the given `line`, but not at the beginning,
+    the line is moved on top::
+
+        >>> file_prepend(f.name, 'line1\n')
+        >>> open(f.name).read()
+        'line1\nline0\n'
+    """
+    with open(filename, 'r+') as f:
+        lines = f.readlines()
+        if lines[0] != line:
+            if line in lines:
+                lines.remove(line)
+            lines.insert(0, line)
+            f.seek(0)
+            f.writelines(lines)
+
+
+def generate_ssh_keys(path, passphrase=''):
+    """Generate an ssh key pair, saving the keys at `path` and `path`.pub.
+
+        >>> generate_ssh_keys('/tmp/id_rsa')
+        0
+        >>> open('/tmp/id_rsa').readlines()[0].strip()
+        '-----BEGIN RSA PRIVATE KEY-----'
+        >>> open('/tmp/id_rsa.pub').read().startswith('ssh-rsa')
+        True
+        >>> os.remove('/tmp/id_rsa')
+        >>> os.remove('/tmp/id_rsa.pub')
+    """
+    return subprocess.call([
+        'ssh-keygen', '-q', '-t', 'rsa', '-N', passphrase, '-f', path])
+
+
+def get_su_command(user, args):
+    """Return a command line as a sequence, prepending "su" if necessary.
+
+    This can be used together with `run` when the `su` context manager is not
+    enough (e.g. an external program uses uid rather than euid).
+
+        >>> import getpass
+        >>> current_user = getpass.getuser()
+
+    If the su is requested as current user, the arguments are returned as
+    given::
+
+        >>> get_su_command(current_user, ('ls', '-l'))
+        ('ls', '-l')
+
+    Otherwise, "su" is prepended::
+
+        >>> get_su_command('nobody', ('ls', '-l', 'my file'))
+        ('su', 'nobody', '-c', "ls -l 'my file'")
+    """
+    if get_user_ids(user)[0] != os.getuid():
+        args = [i for i in args if i is not None]
+        return ('su', user, '-c', join_command(args))
+    return args
+
+
+def get_user_home(user):
+    """Return the home directory of the given `user`.
+
+        >>> get_user_home('root')
+        '/root'
+
+    If the user does not exist, return a default /home/[username] home::
+
+        >>> get_user_home('_this_user_does_not_exist_')
+        '/home/_this_user_does_not_exist_'
+    """
+    try:
+        return pwd.getpwnam(user).pw_dir
+    except KeyError:
+        return os.path.join(os.path.sep, 'home', user)
+
+
+def get_user_ids(user):
+    """Return the uid and gid of given `user`, e.g.::
+
+        >>> get_user_ids('root')
+        (0, 0)
+    """
+    userdata = pwd.getpwnam(user)
+    return userdata.pw_uid, userdata.pw_gid
+
+
+def get_value_from_line(line):
+    """Return the value from a line representing a Python assignment."""
+    return line.split('=')[1].strip('"\' ')
+
+
+def grep(content, filename):
+    """Grep `filename` using `content` regular expression."""
+    with open(filename) as f:
+        for line in f:
+            if re.search(content, line):
+                return line.strip()
 
 
 def install_extra_repositories(*repositories):
     """Install all of the extra repositories and update apt.
 
+    Given repositories can contain a "{distribution}" placeholder, that will
+    be replaced by current distribution codename.
+
     :raises: subprocess.CalledProcessError
     """
     distribution = run('lsb_release', '-cs').strip()
-    # Starting from Oneiric, the `apt-add-repository` is interactive by
+    # Starting from Oneiric, `apt-add-repository` is interactive by
     # default, and requires a "-y" flag to be set.
     assume_yes = None if distribution == 'lucid' else '-y'
     for repo in repositories:
-        run('apt-add-repository', assume_yes, repo)
+        repository = repo.format(distribution=distribution)
+        run('apt-add-repository', assume_yes, repository)
     run('apt-get', 'clean')
     run('apt-get', 'update')
 
 
-def grep(content, filename):
-    with open(filename) as f:
-        for line in f:
-            if re.match(content, line):
-                return line.strip()
-
-
-def get_value_from_line(line):
-    return line.split('=')[1].strip('"\' ')
+def join_command(args):
+    """Return a valid Unix command line from `args`.
+
+        >>> join_command(['ls', '-l'])
+        'ls -l'
+
+    Arguments containing spaces and empty args are correctly quoted::
+
+        >>> join_command(['command', 'arg1', 'arg containing spaces', ''])
+        "command arg1 'arg containing spaces' ''"
+    """
+    return ' '.join(pipes.quote(arg) for arg in args)
+
+
+def mkdirs(*args):
+    """Create leaf directories (given as `args`) and all intermediate ones.
+
+        >>> import tempfile
+        >>> base_dir = tempfile.mktemp(suffix='/')
+        >>> dir1 = tempfile.mktemp(prefix=base_dir)
+        >>> dir2 = tempfile.mktemp(prefix=base_dir)
+        >>> mkdirs(dir1, dir2)
+        >>> os.path.isdir(dir1)
+        True
+        >>> os.path.isdir(dir2)
+        True
+
+    If the leaf directory already exists the function returns without errors::
+
+        >>> mkdirs(dir1)
+
+    An `OSError` is raised if the leaf path exists and it is a file::
+
+        >>> f = tempfile.NamedTemporaryFile(
+        ...     'w', delete=False, prefix=base_dir)
+        >>> f.close()
+        >>> mkdirs(f.name) # doctest: +ELLIPSIS
+        Traceback (most recent call last):
+        OSError: ...
+    """
+    for directory in args:
+        try:
+            os.makedirs(directory)
+        except OSError as err:
+            if err.errno != errno.EEXIST or os.path.isfile(directory):
+                raise
+
+
+def run(*args, **kwargs):
+    """Run the command with the given arguments.
+
+    The first argument is the path to the command to run.
+    Subsequent arguments are command-line arguments to be passed.
+
+    This function accepts all optional keyword arguments accepted by
+    `subprocess.Popen`.
+    """
+    args = [i for i in args if i is not None]
+    pipe = subprocess.PIPE
+    process = subprocess.Popen(
+        args, stdout=kwargs.pop('stdout', pipe),
+        stderr=kwargs.pop('stderr', pipe),
+        close_fds=kwargs.pop('close_fds', True), **kwargs)
+    stdout, stderr = process.communicate()
+    if process.returncode:
+        raise subprocess.CalledProcessError(
+            process.returncode, repr(args), output=stdout+stderr)
+    return stdout
 
 
 def script_name():
+    """Return the name of this script."""
     return os.path.basename(sys.argv[0])
 
 
+def ssh(location, user=None, key=None, caller=subprocess.call):
+    """Return a callable that can be used to run ssh shell commands.
+
+    The ssh `location` and, optionally, `user` must be given.
+    If the user is None then the current user is used for the connection.
+
+    The callable internally uses the given `caller`::
+
+        >>> def caller(cmd):
+        ...     print tuple(cmd)
+        >>> sshcall = ssh('example.com', 'myuser', caller=caller)
+        >>> root_sshcall = ssh('example.com', caller=caller)
+        >>> sshcall('ls -l') # doctest: +ELLIPSIS
+        ('ssh', '-t', ..., 'myuser@example.com', '--', 'ls -l')
+        >>> root_sshcall('ls -l') # doctest: +ELLIPSIS
+        ('ssh', '-t', ..., 'example.com', '--', 'ls -l')
+
+    The ssh key path can be optionally provided::
+
+        >>> root_sshcall = ssh('example.com', key='/tmp/foo', caller=caller)
+        >>> root_sshcall('ls -l') # doctest: +ELLIPSIS
+        ('ssh', '-t', ..., '-i', '/tmp/foo', 'example.com', '--', 'ls -l')
+
+    If the ssh command exits with an error code,
+    a `subprocess.CalledProcessError` is raised::
+
+        >>> ssh('loc', caller=lambda cmd: 1)('ls -l') # doctest: +ELLIPSIS
+        Traceback (most recent call last):
+        CalledProcessError: ...
+
+    If ignore_errors is set to True when executing the command, no error
+    will be raised, even if the command itself returns an error code.
+
+        >>> sshcall = ssh('loc', caller=lambda cmd: 1)
+        >>> sshcall('ls -l', ignore_errors=True)
+    """
+    sshcmd = [
+        'ssh',
+        '-t',
+        '-t',  # Yes, this second -t is deliberate. See `man ssh`.
+        '-o', 'StrictHostKeyChecking=no',
+        '-o', 'UserKnownHostsFile=/dev/null',
+        ]
+    if key is not None:
+        sshcmd.extend(['-i', key])
+    if user is not None:
+        location = '{}@{}'.format(user, location)
+    sshcmd.extend([location, '--'])
+
+    def _sshcall(cmd, ignore_errors=False):
+        command = sshcmd + [cmd]
+        retcode = caller(command)
+        if retcode and not ignore_errors:
+            raise subprocess.CalledProcessError(retcode, ' '.join(command))
+
+    return _sshcall
+
+
+@contextmanager
+def su(user):
+    """A context manager to temporarily run the script as a different user."""
+    uid, gid = get_user_ids(user)
+    os.setegid(gid)
+    os.seteuid(uid)
+    home = get_user_home(user)
+    with environ(HOME=home):
+        try:
+            yield Env(uid, gid, home)
+        finally:
+            os.setegid(os.getgid())
+            os.seteuid(os.getuid())
+
+
+def user_exists(username):
+    """Return True if given `username` exists, e.g.::
+
+        >>> user_exists('root')
+        True
+        >>> user_exists('_this_user_does_not_exist_')
+        False
+    """
+    try:
+        pwd.getpwnam(username)
+    except KeyError:
+        return False
+    return True
+
+
 def wait_for_page_contents(url, contents, timeout=120, validate=None):
     if validate is None:
         validate = operator.contains
@@ -148,83 +539,6 @@
         time.sleep(0.1)
 
 
-class Serializer:
-    """Handle JSON (de)serialization."""
-
-    def __init__(self, path, default=None, serialize=None, deserialize=None):
-        self.path = path
-        self.default = default or {}
-        self.serialize = serialize or json.dump
-        self.deserialize = deserialize or json.load
-
-    def exists(self):
-        return os.path.exists(self.path)
-
-    def get(self):
-        if self.exists():
-            with open(self.path) as f:
-                return self.deserialize(f)
-        return self.default
-
-    def set(self, data):
-        with open(self.path, 'w') as f:
-            self.serialize(data, f)
-
-
-def get_user_ids(user):
-    """Return the uid and gid of given `user`, e.g.::
-
-        >>> get_user_ids('root')
-        (0, 0)
-    """
-    userdata = pwd.getpwnam(user)
-    return userdata.pw_uid, userdata.pw_gid
-
-
-def get_user_home(user):
-    """Return the home directory of the given `user`.
-
-        >>> get_user_home('root')
-        '/root'
-    """
-    return pwd.getpwnam(user).pw_dir
-
-
-@contextmanager
-def cd(directory):
-    """A context manager to temporary change current working dir, e.g.::
-
-        >>> import os
-        >>> os.chdir('/tmp')
-        >>> with cd('/bin'): print os.getcwd()
-        /bin
-        >>> os.getcwd()
-        '/tmp'
-    """
-    cwd = os.getcwd()
-    os.chdir(directory)
-    yield
-    os.chdir(cwd)
-
-
-@contextmanager
-def su(user):
-    """A context manager to temporary run the script as a different user."""
-    uid, gid = get_user_ids(user)
-    os.setegid(gid)
-    os.seteuid(uid)
-    current_home = os.getenv('HOME')
-    home = get_user_home(user)
-    os.environ['HOME'] = home
-    try:
-        yield Env(uid, gid, home)
-    finally:
-        os.setegid(os.getgid())
-        os.seteuid(os.getuid())
-        if current_home is not None:
-            os.environ['HOME'] = current_home
-
-
 class DictDiffer:
     """
     Calculate the difference between two dictionaries as:

=== modified file 'tests.py'
--- tests.py	2012-03-02 13:58:36 +0000
+++ tests.py	2012-03-06 16:50:24 +0000
@@ -6,44 +6,43 @@
 __metaclass__ = type
 
 
+import getpass
 import os
 from subprocess import CalledProcessError
+import tempfile
 import unittest
 
 from shelltoolbox import (
     cd,
     command,
     DictDiffer,
+    environ,
+    file_append,
+    file_prepend,
+    generate_ssh_keys,
+    get_su_command,
+    get_user_home,
+    get_user_ids,
+    get_value_from_line,
+    grep,
+    join_command,
+    mkdirs,
     run,
+    Serializer,
+    ssh,
     su,
+    user_exists,
     )
 
 
-class TestRun(unittest.TestCase):
-
-    def testSimpleCommand(self):
-        # Running a simple command (ls) works and running the command
-        # produces a string.
-        self.assertIsInstance(run('/bin/ls'), str)
-
-    def testStdoutReturned(self):
-        # Running a simple command (ls) works and running the command
-        # produces a string.
-        self.assertIn('Usage:', run('/bin/ls', '--help'))
-
-    def testCalledProcessErrorRaised(self):
-        # If an error occurs a CalledProcessError is raised with the return
-        # code, command executed, and the output of the command.
-        with self.assertRaises(CalledProcessError) as info:
-            run('ls', '--not a valid switch')
-        exception = info.exception
-        self.assertEqual(2, exception.returncode)
-        self.assertEqual("['ls', '--not a valid switch']", exception.cmd)
-        self.assertIn('unrecognized option', exception.output)
-
-    def testNoneArguments(self):
-        # Ensure None is ignored when passed as positional argument.
-        self.assertIn('Usage:', run('/bin/ls', None, '--help', None))
+class TestCdContextManager(unittest.TestCase):
+
+    def test_cd(self):
+        curdir = os.getcwd()
+        self.assertNotEqual('/var', curdir)
+        with cd('/var'):
+            self.assertEqual('/var', os.getcwd())
+        self.assertEqual(curdir, os.getcwd())
 
 
 class TestCommand(unittest.TestCase):
@@ -112,6 +111,336 @@
         self.assertEquals(expected, diff.added_or_changed)
 
 
+class TestEnviron(unittest.TestCase):
+
+    def test_existing(self):
+        # If an existing environment variable is changed, it is
+        # restored during context cleanup.
+        os.environ['MY_VARIABLE'] = 'foo'
+        with environ(MY_VARIABLE='bar'):
+            self.assertEqual('bar', os.getenv('MY_VARIABLE'))
+        self.assertEqual('foo', os.getenv('MY_VARIABLE'))
+        del os.environ['MY_VARIABLE']
+
+    def test_new(self):
+        # If a new environment variable is added, it is removed during
+        # context cleanup.
+        with environ(MY_VAR1='foo', MY_VAR2='bar'):
+            self.assertEqual('foo', os.getenv('MY_VAR1'))
+            self.assertEqual('bar', os.getenv('MY_VAR2'))
+        self.assertIsNone(os.getenv('MY_VAR1'))
+        self.assertIsNone(os.getenv('MY_VAR2'))
+
+
+class BaseCreateFile(object):
+
+    def create_file(self, content):
+        f = tempfile.NamedTemporaryFile('w', delete=False)
+        f.write(content)
+        f.close()
+        return f
+
+
+class BaseTestFile(BaseCreateFile):
+
+    base_content = 'line1\n'
+    new_content = 'new line\n'
+
+    def check_file_content(self, content, filename):
+        self.assertEqual(content, open(filename).read())
+
+
+class TestFileAppend(unittest.TestCase, BaseTestFile):
+
+    def test_append(self):
+        # Ensure the new content is correctly added at the end of the file.
+        f = self.create_file(self.base_content)
+        file_append(f.name, self.new_content)
+        self.check_file_content(self.base_content + self.new_content, f.name)
+
+    def test_existing_content(self):
+        # Ensure nothing happens if the file already contains the given
+        # content.
+        content = self.base_content + self.new_content
+        f = self.create_file(content)
+        file_append(f.name, self.new_content)
+        self.check_file_content(content, f.name)
+
+    def test_new_line(self):
+        # A new line is automatically added before the given content if it
+        # is not present at the end of current file.
+        f = self.create_file(self.base_content.strip())
+        file_append(f.name, self.new_content)
+        self.check_file_content(self.base_content + self.new_content, f.name)
+
+    def test_non_existent_file(self):
+        # Ensure the file is created if it does not exist.
+        filename = tempfile.mktemp()
+        file_append(filename, self.base_content)
+        self.check_file_content(self.base_content, filename)
+
+
+class TestFilePrepend(unittest.TestCase, BaseTestFile):
+
+    def test_prpend(self):
+        # Ensure the new content is correctly prepended at the beginning of
+        # the file.
+        f = self.create_file(self.base_content)
+        file_prepend(f.name, self.new_content)
+        self.check_file_content(self.new_content + self.base_content, f.name)
+
+    def test_existing_content(self):
+        # Ensure nothing happens if the file already starts with the given
+        # content.
+        content = self.base_content + self.new_content
+        f = self.create_file(content)
+        file_prepend(f.name, self.base_content)
+        self.check_file_content(content, f.name)
+
+    def test_move_content(self):
+        # If the file contains the given content, but not at the beginning,
+        # the content is moved on top.
+        f = self.create_file(self.base_content + self.new_content)
+        file_prepend(f.name, self.new_content)
+        self.check_file_content(self.new_content + self.base_content, f.name)
+
+
+class TestGenerateSSHKeys(unittest.TestCase):
+
+    def test_generation(self):
+        # Ensure ssh keys are correctly generated.
+        filename = tempfile.mktemp()
+        generate_ssh_keys(filename)
+        first_line = open(filename).readlines()[0].strip()
+        self.assertEqual('-----BEGIN RSA PRIVATE KEY-----', first_line)
+        pub_content = open(filename + '.pub').read()
+        self.assertTrue(pub_content.startswith('ssh-rsa'))
+
+
+class TestGetSuCommand(unittest.TestCase):
+
+    def test_current_user(self):
+        # If the su is requested as current user, the arguments are
+        # returned as given.
+        cmd = ('ls', '-l')
+        command = get_su_command(getpass.getuser(), cmd)
+        self.assertSequenceEqual(cmd, command)
+
+    def test_another_user(self):
+        # Ensure "su" is prepended and arguments are correctly quoted.
+        command = get_su_command('nobody', ('ls', '-l', 'my file'))
+        self.assertSequenceEqual(
+            ('su', 'nobody', '-c', "ls -l 'my file'"), command)
+
+
+class TestGetUserHome(unittest.TestCase):
+
+    def test_existent(self):
+        # Ensure the real home directory is returned for existing users.
+        self.assertEqual('/root', get_user_home('root'))
+
+    def test_non_existent(self):
+        # If the user does not exist, return a default /home/[username] home.
+        user = '_this_user_does_not_exist_'
+        self.assertEqual('/home/' + user, get_user_home(user))
+
+
+class TestGetUserIds(unittest.TestCase):
+
+    def test_get_user_ids(self):
+        # Ensure the correct uid and gid are returned.
+        uid, gid = get_user_ids('root')
+        self.assertEqual(0, uid)
+        self.assertEqual(0, gid)
+
+
+class TestGetValueFromLine(unittest.TestCase):
+
+    def test_get_value_from_line(self):
+        # Ensure the correct value is returned.
+        self.assertEqual('value', get_value_from_line("name = 'value'"))
+        self.assertEqual('47', get_value_from_line('name = 47'))
+
+
+class TestGrep(unittest.TestCase, BaseCreateFile):
+
+    def setUp(self):
+        self.filename = self.create_file('content1\ncontent2\n').name
+
+    def tearDown(self):
+        os.remove(self.filename)
+
+    def test_grep(self):
+        # Ensure plain text is correctly matched.
+        self.assertEqual('content2', grep('ent2', self.filename))
+        self.assertEqual('content1', grep('content', self.filename))
+
+    def test_no_match(self):
+        # Ensure the function does not return false positives.
+        self.assertIsNone(grep('no_match', self.filename))
+
+    def test_regexp(self):
+        # Ensure the function works with regular expressions.
+        self.assertEqual('content2', grep('\w2', self.filename))
+
+
+class TestJoinCommand(unittest.TestCase):
+
+    def test_normal(self):
+        # Ensure a normal command is correctly parsed.
+        command = 'ls -l'
+        self.assertEqual(command, join_command(command.split()))
+
+    def test_containing_spaces(self):
+        # Ensure args containing spaces are correctly quoted.
+        args = ('command', 'arg containig spaces')
+        self.assertEqual("command 'arg containig spaces'", join_command(args))
+
+    def test_empty(self):
+        # Ensure empty args are correctly quoted.
+        args = ('command', '')
+        self.assertEqual("command ''", join_command(args))
+
+
+class TestMkdirs(unittest.TestCase):
+
+    def test_intermediate_dirs(self):
+        # Ensure the leaf directory and all intermediate ones are created.
+        base_dir = tempfile.mktemp(suffix='/')
+        dir1 = tempfile.mktemp(prefix=base_dir)
+        dir2 = tempfile.mktemp(prefix=base_dir)
+        mkdirs(dir1, dir2)
+        self.assertTrue(os.path.isdir(dir1))
+        self.assertTrue(os.path.isdir(dir2))
+
+    def test_existing_dir(self):
+        # If the leaf directory already exists the function returns
+        # without errors.
+        mkdirs('/tmp')
+
+    def test_existing_file(self):
+        # An `OSError` is raised if the leaf path exists and it is a file.
+        f = tempfile.NamedTemporaryFile('w', delete=False)
+        f.close()
+        with self.assertRaises(OSError):
+            mkdirs(f.name)
+
+
+class TestRun(unittest.TestCase):
+
+    def testSimpleCommand(self):
+        # Running a simple command (ls) works and running the command
+        # produces a string.
+        self.assertIsInstance(run('/bin/ls'), str)
+
+    def testStdoutReturned(self):
+        # Running a simple command (ls) works and running the command
+        # produces a string.
+        self.assertIn('Usage:', run('/bin/ls', '--help'))
+
+    def testCalledProcessErrorRaised(self):
+        # If an error occurs a CalledProcessError is raised with the return
+        # code, command executed, and the output of the command.
+        with self.assertRaises(CalledProcessError) as info:
+            run('ls', '--not a valid switch')
+        exception = info.exception
+        self.assertEqual(2, exception.returncode)
+        self.assertEqual("['ls', '--not a valid switch']", exception.cmd)
+        self.assertIn('unrecognized option', exception.output)
+
+    def testNoneArguments(self):
+        # Ensure None is ignored when passed as positional argument.
+        self.assertIn('Usage:', run('/bin/ls', None, '--help', None))
+
+
+class TestSerializer(unittest.TestCase):
+
+    def setUp(self):
+        self.path = tempfile.mktemp()
+        self.data = {'key': 'value'}
+
+    def tearDown(self):
+        if os.path.exists(self.path):
+            os.remove(self.path)
+
+    def test_serializer(self):
+        # Ensure data is correctly serialized and deserialized.
+        s = Serializer(self.path)
+        s.set(self.data)
+        self.assertEqual(self.data, s.get())
+
+    def test_existence(self):
+        # Ensure the file is created only when needed.
+        s = Serializer(self.path)
+        self.assertFalse(s.exists())
+        s.set(self.data)
+        self.assertTrue(s.exists())
+
+    def test_default_value(self):
+        # If the file does not exist, the serializer returns a default value.
+        s = Serializer(self.path)
+        self.assertEqual({}, s.get())
+        s = Serializer(self.path, default=47)
+        self.assertEqual(47, s.get())
+
+    def test_another_serializer(self):
+        # It is possible to use a custom serializer (e.g. pickle).
+        import pickle
+        s = Serializer(
+            self.path, serialize=pickle.dump, deserialize=pickle.load)
+        s.set(self.data)
+        self.assertEqual(self.data, s.get())
+
+
+class TestSSH(unittest.TestCase):
+
+    def setUp(self):
+        self.last_command = None
+
+    def remove_command_options(self, cmd):
+        cmd = list(cmd)
+        del cmd[1:7]
+        return cmd
+
+    def caller(self, cmd):
+        self.last_command = self.remove_command_options(cmd)
+
+    def check_last_command(self, expected):
+        self.assertSequenceEqual(expected, self.last_command)
+
+    def test_current_user(self):
+        # Ensure ssh command is correctly generated for current user.
+        sshcall = ssh('example.com', caller=self.caller)
+        sshcall('ls -l')
+        self.check_last_command(['ssh', 'example.com', '--', 'ls -l'])
+
+    def test_another_user(self):
+        # Ensure ssh command is correctly generated for a different user.
+        sshcall = ssh('example.com', 'myuser', caller=self.caller)
+        sshcall('ls -l')
+        self.check_last_command(['ssh', 'myuser@example.com', '--', 'ls -l'])
+
+    def test_ssh_key(self):
+        # The ssh key path can be optionally provided.
+        sshcall = ssh('example.com', key='/tmp/foo', caller=self.caller)
+        sshcall('ls -l')
+        self.check_last_command([
+            'ssh', '-i', '/tmp/foo', 'example.com', '--', 'ls -l'])
+
+    def test_error(self):
+        # If the ssh command exits with an error code, a
+        # `subprocess.CalledProcessError` is raised.
+        sshcall = ssh('example.com', caller=lambda cmd: 1)
+        with self.assertRaises(CalledProcessError):
+            sshcall('ls -l')
+
+    def test_ignore_errors(self):
+        # If ignore_errors is set to True when executing the command, no error
+        # will be raised, even if the command itself returns an error code.
+        sshcall = ssh('example.com', caller=lambda cmd: 1)
+        sshcall('ls -l', ignore_errors=True)
+
+
 current_euid = os.geteuid()
 current_egid = os.getegid()
 current_home = os.environ['HOME']
@@ -180,13 +509,11 @@
             self.assertEqual(current_home, os.environ['HOME'])
 
 
-class TestCdContextManager(unittest.TestCase):
-    def test_cd(self):
-        curdir = os.getcwd()
-        self.assertNotEqual('/var', curdir)
-        with cd('/var'):
-            self.assertEqual('/var', os.getcwd())
-        self.assertEqual(curdir, os.getcwd())
+class TestUserExists(unittest.TestCase):
+
+    def test_user_exists(self):
+        self.assertTrue(user_exists('root'))
+        self.assertFalse(user_exists('_this_user_does_not_exist_'))
 
 
 if __name__ == '__main__':


Follow ups