← Back to team overview

dulwich-users team mailing list archive

[PATCH 4/4] Add shallow clone support to the server.

 

From: Dave Borowitz <dborowitz@xxxxxxxxxx>

This includes two categories of change:
 - There is an additional shallow discovery stage of the negotiation
   protocol.
 - MissingObjectFinder must understand the client's shallow state when
   determining which objects to pack.

Change-Id: I7fdf6886cfc44c9d7111633c19d97425f4f166b5
---
 NEWS                                 |    2 +
 dulwich/object_store.py              |   35 ++++++++--
 dulwich/repo.py                      |   15 +++-
 dulwich/server.py                    |   98 ++++++++++++++++++++++++---
 dulwich/tests/compat/server_utils.py |   61 ++++++++++++++++
 dulwich/tests/compat/test_web.py     |    7 ++-
 dulwich/tests/test_server.py         |  125 ++++++++++++++++++++++++++++++++++
 7 files changed, 323 insertions(+), 20 deletions(-)

diff --git a/NEWS b/NEWS
index 2405824..1473bbd 100644
--- a/NEWS
+++ b/NEWS
@@ -14,6 +14,8 @@
   * Correct short-circuiting operation for no-op fetches in the server.
     (Dave Borowitz)
 
+  * Add shallow clone support to the server. (Dave Borowitz)
+
  FEATURES
 
   * Use slots for core objects to save up on memory. (Jelmer Vernooij)
diff --git a/dulwich/object_store.py b/dulwich/object_store.py
index 162f102..8cb2aaa 100644
--- a/dulwich/object_store.py
+++ b/dulwich/object_store.py
@@ -198,7 +198,8 @@ class BaseObjectStore(object):
                     todo.append((entry_path, entry_mode, entry_hexsha))
 
     def find_missing_objects(self, haves, wants, progress=None,
-                             get_tagged=None):
+                             get_tagged=None,
+                             shallow=None, client_shallow=None):
         """Find the missing objects required for a set of revisions.
 
         :param haves: Iterable over SHAs already in common.
@@ -207,9 +208,15 @@ class BaseObjectStore(object):
             updated progress strings.
         :param get_tagged: Function that returns a dict of pointed-to sha -> tag
             sha for including tags.
+        :param shallow: Set of shallow commit SHAs, commits whose parents
+            should not be included.
+        :param client_shallow: Set of commit SHAs that the client already has
+            marked as shallow.
         :return: Iterator over (sha, path) pairs.
         """
-        finder = MissingObjectFinder(self, haves, wants, progress, get_tagged)
+        finder = MissingObjectFinder(self, haves, wants, progress, get_tagged,
+                                     shallow=shallow,
+                                     client_shallow=client_shallow)
         return iter(finder.next, None)
 
     def find_common_revisions(self, graphwalker):
@@ -706,13 +713,15 @@ class MissingObjectFinder(object):
     :param progress: Optional function to report progress to.
     :param get_tagged: Function that returns a dict of pointed-to sha -> tag
         sha for including tags.
-    :param tagged: dict of pointed-to sha -> tag sha for including tags
+    :param shallow: Set of shallow commit SHAs, commits whose parents
+        should not be included.
+    :param client_shallow: Set of commit SHAs that the client already has
+        marked as shallow.
     """
 
     def __init__(self, object_store, haves, wants, progress=None,
-                 get_tagged=None):
-        haves = set(haves)
-        self.sha_done = haves
+                 get_tagged=None, shallow=None, client_shallow=None):
+        self.sha_done = set(haves)
         self.objects_to_send = set([(w, None, False) for w in wants
                                     if w not in haves])
         self.object_store = object_store
@@ -722,6 +731,17 @@ class MissingObjectFinder(object):
             self.progress = progress
         self._tagged = get_tagged and get_tagged() or {}
 
+        # Include everything between the client's shallow commits and our new
+        # shallow commits.
+        if client_shallow is not None:
+            for sha in client_shallow:
+                if sha not in shallow:
+                    self.add_todo([(p, None, False)
+                                   for p in self.object_store[sha].parents])
+        if shallow is None:
+            shallow = []
+        self._shallow = shallow
+
     def add_todo(self, entries):
         self.objects_to_send.update([e for e in entries
                                      if not e[0] in self.sha_done])
@@ -733,7 +753,8 @@ class MissingObjectFinder(object):
 
     def parse_commit(self, commit):
         self.add_todo([(commit.tree, "", False)])
-        self.add_todo([(p, None, False) for p in commit.parents])
+        if commit.id not in self._shallow:
+            self.add_todo([(p, None, False) for p in commit.parents])
 
     def parse_tag(self, tag):
         self.add_todo([(tag.object[1], None, False)])
diff --git a/dulwich/repo.py b/dulwich/repo.py
index 40c18bf..aa0946f 100644
--- a/dulwich/repo.py
+++ b/dulwich/repo.py
@@ -818,7 +818,8 @@ class BaseRepo(object):
         return self.get_refs()
 
     def fetch_objects(self, determine_wants, graph_walker, progress,
-                      get_tagged=None):
+                      get_tagged=None,
+                      shallow=None, client_shallow=None):
         """Fetch the missing objects required for a set of revisions.
 
         :param determine_wants: Function that takes a dictionary with heads
@@ -830,8 +831,15 @@ class BaseRepo(object):
             updated progress strings.
         :param get_tagged: Function that returns a dict of pointed-to sha -> tag
             sha for including tags.
+        :param shallow: Set of shallow commit SHAs, commits whose parents
+            should not be included.
+        :param client_shallow: Set of commit SHAs that the client already has
+            marked as shallow.
         :return: iterator over objects, with __len__ implemented
         """
+        # TODO(dborowitz): Change this interface in a way that doesn't involve
+        # injecting arbitrary new arguments, e.g. by passing in a
+        # MissingObjectFinder or an iterator thereon.
         wants = determine_wants(self.get_refs())
         if wants is None:
             # TODO(dborowitz): find a way to short-circuit that doesn't change
@@ -839,8 +847,9 @@ class BaseRepo(object):
             return None
         haves = self.object_store.find_common_revisions(graph_walker)
         return self.object_store.iter_shas(
-          self.object_store.find_missing_objects(haves, wants, progress,
-                                                 get_tagged))
+          self.object_store.find_missing_objects(
+            haves, wants, progress=progress, get_tagged=get_tagged,
+            shallow=shallow, client_shallow=client_shallow))
 
     def get_graph_walker(self, heads=None):
         if heads is None:
diff --git a/dulwich/server.py b/dulwich/server.py
index a734628..b29813a 100644
--- a/dulwich/server.py
+++ b/dulwich/server.py
@@ -42,6 +42,7 @@ from dulwich.errors import (
 from dulwich import log_utils
 from dulwich.objects import (
     hex_to_sha,
+    Commit,
     )
 from dulwich.pack import (
     PackStreamReader,
@@ -215,7 +216,7 @@ class UploadPackHandler(Handler):
     @classmethod
     def capabilities(cls):
         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
-                "ofs-delta", "no-progress", "include-tag")
+                "ofs-delta", "no-progress", "include-tag", "shallow")
 
     @classmethod
     def required_capabilities(cls):
@@ -262,7 +263,8 @@ class UploadPackHandler(Handler):
             self.repo.get_peeled)
         objects_iter = self.repo.fetch_objects(
           graph_walker.determine_wants, graph_walker, self.progress,
-          get_tagged=self.get_tagged)
+          get_tagged=self.get_tagged, shallow=graph_walker.shallow,
+          client_shallow=graph_walker.client_shallow)
 
         # Did the process short-circuit (e.g. in a stateless RPC call)? Note
         # that the client still expects a 0-object pack in most cases.
@@ -305,14 +307,55 @@ def _split_proto_line(line, allowed):
     try:
         if len(fields) == 1 and command in ('done', None):
             return (command, None)
-        elif len(fields) == 2 and command in ('want', 'have'):
-            hex_to_sha(fields[1])
-            return tuple(fields)
-    except (TypeError, AssertionError), e:
+        elif len(fields) == 2:
+            if command in ('want', 'have', 'shallow', 'unshallow'):
+                hex_to_sha(fields[1])
+                return tuple(fields)
+            elif command == 'deepen':
+                return command, int(fields[1])
+    except (TypeError, AssertionError, ValueError), e:
         raise GitProtocolError(e)
     raise GitProtocolError('Received invalid line from client: %s' % line)
 
 
+def _find_shallow(store, heads, depth):
+    """Find shallow commits according to a given depth.
+
+    :param store: An ObjectStore for looking up objects.
+    :param heads: Iterable of head SHAs to start walking from.
+    :param depth: The depth of ancestors to include.
+    :return: A tuple of (shallow, not_shallow), sets of SHAs that should be
+        considered shallow and unshallow according to the arguments. Note that
+        these sets may overlap if a commit is reachable along multiple paths.
+    """
+    parents = {}
+    def get_parents(sha):
+        result = parents.get(sha, None)
+        if not result:
+            result = store[sha].parents
+            parents[sha] = result
+        return result
+
+    todo = []  # stack of (sha, depth)
+    for head_sha in heads:
+        obj = store.peel_sha(head_sha)
+        if isinstance(obj, Commit):
+            todo.append((obj.id, 0))
+
+    not_shallow = set()
+    shallow = set()
+    while todo:
+        sha, cur_depth = todo.pop()
+        if cur_depth < depth:
+            not_shallow.add(sha)
+            new_depth = cur_depth + 1
+            todo.extend((p, new_depth) for p in get_parents(sha))
+        else:
+            shallow.add(sha)
+
+    return shallow, not_shallow
+
+
 class ProtocolGraphWalker(object):
     """A graph walker that knows the git protocol.
 
@@ -334,6 +377,8 @@ class ProtocolGraphWalker(object):
         self.stateless_rpc = handler.stateless_rpc
         self.advertise_refs = handler.advertise_refs
         self._wants = []
+        self.shallow = set()
+        self.client_shallow = set()
         self._cached = False
         self._cache = []
         self._cache_index = 0
@@ -347,6 +392,12 @@ class ProtocolGraphWalker(object):
         same regardless of ack type, and in fact is used to set the ack type of
         the ProtocolGraphWalker.
 
+        If the client has the 'shallow' capability, this method also reads and
+        responds to the 'shallow' and 'deepen' lines from the client. These are
+        not part of the wants per se, but they set up necessary state for
+        walking the graph. Additionally, later code depends on this method
+        consuming everything up to the first 'have' line.
+
         :param heads: a dict of refname->SHA1 to advertise
         :return: a list of SHA1s requested by the client
         """
@@ -377,11 +428,11 @@ class ProtocolGraphWalker(object):
         line, caps = extract_want_line_capabilities(want)
         self.handler.set_client_capabilities(caps)
         self.set_ack_type(ack_type(caps))
-        allowed = ('want', None)
+        allowed = ('want', 'shallow', 'deepen', None)
         command, sha = _split_proto_line(line, allowed)
 
         want_revs = []
-        while command != None:
+        while command == 'want':
             if sha not in values:
                 raise GitProtocolError(
                   'Client wants invalid object %s' % sha)
@@ -389,6 +440,9 @@ class ProtocolGraphWalker(object):
             command, sha = self.read_proto_line(allowed)
 
         self.set_wants(want_revs)
+        if command in ('shallow', 'deepen'):
+            self.unread_proto_line(command, sha)
+            self._handle_shallow_request(want_revs)
 
         if self.stateless_rpc and self.proto.eof():
             # The client may close the socket at this point, expecting a
@@ -398,6 +452,9 @@ class ProtocolGraphWalker(object):
 
         return want_revs
 
+    def unread_proto_line(self, command, value):
+        self.proto.unread_pkt_line('%s %s' % (command, value))
+
     def ack(self, have_ref):
         return self._impl.ack(have_ref)
 
@@ -420,10 +477,33 @@ class ProtocolGraphWalker(object):
 
         :param allowed: An iterable of command names that should be allowed.
         :return: A tuple of (command, value); see _split_proto_line.
-        :raise GitProtocolError: If an error occurred reading the line.
+        :raise UnexpectedCommandError: If an error occurred reading the line.
         """
         return _split_proto_line(self.proto.read_pkt_line(), allowed)
 
+    def _handle_shallow_request(self, wants):
+        while True:
+            command, val = self.read_proto_line(('deepen', 'shallow'))
+            if command == 'deepen':
+                depth = val
+                break
+            self.client_shallow.add(val)
+        self.read_proto_line((None,))  # consume client's flush-pkt
+
+        shallow, not_shallow = _find_shallow(self.store, wants, depth)
+
+        # Update self.shallow instead of reassigning it since we passed a
+        # reference to it before this method was called.
+        self.shallow.update(shallow - not_shallow)
+        new_shallow = self.shallow - self.client_shallow
+        unshallow = not_shallow & self.client_shallow
+
+        for sha in sorted(new_shallow):
+            self.proto.write_pkt_line('shallow %s' % sha)
+        for sha in sorted(unshallow):
+            self.proto.write_pkt_line('unshallow %s' % sha)
+        self.proto.write_pkt_line(None)
+
     def send_ack(self, sha, ack_type=''):
         if ack_type:
             ack_type = ' %s' % ack_type
diff --git a/dulwich/tests/compat/server_utils.py b/dulwich/tests/compat/server_utils.py
index 4a9da06..ef76197 100644
--- a/dulwich/tests/compat/server_utils.py
+++ b/dulwich/tests/compat/server_utils.py
@@ -20,10 +20,18 @@
 """Utilities for testing git server compatibility."""
 
 
+import os
 import select
 import socket
+import tempfile
 import threading
 
+from dulwich.objects import (
+    hex_to_sha,
+    )
+from dulwich.repo import (
+    Repo,
+    )
 from dulwich.server import (
     ReceivePackHandler,
     )
@@ -36,6 +44,32 @@ from utils import (
     )
 
 
+class _StubRepo(object):
+    """A stub repo that just contains a path to tear down."""
+
+    def __init__(self, name):
+        temp_dir = tempfile.mkdtemp()
+        self.path = os.path.join(temp_dir, name)
+        os.mkdir(self.path)
+
+
+def _get_shallow(repo):
+    shallow_file = repo.get_named_file('shallow')
+    if not shallow_file:
+        return []
+    shallows = []
+    try:
+        for line in shallow_file:
+            sha = line.strip()
+            if not sha:
+                continue
+            hex_to_sha(sha)
+            shallows.append(sha)
+    finally:
+        shallow_file.close()
+    return shallows
+
+
 class ServerTests(object):
     """Base tests for testing servers.
 
@@ -100,6 +134,33 @@ class ServerTests(object):
         self._old_repo.object_store._pack_cache = None
         self.assertReposEqual(self._old_repo, self._new_repo)
 
+    def test_shallow_clone_from_dulwich(self):
+        self._new_repo = import_repo('server_new.export')
+        self._old_repo = _StubRepo('shallow')
+        port = self._start_server(self._new_repo)
+
+        run_git_or_fail(['clone', '--mirror', '--depth=1', self.url(port),
+                         self._old_repo.path])
+        clone = self._old_repo = Repo(self._old_repo.path)
+        expected_shallow = ['94de09a530df27ac3bb613aaecdd539e0a0655e1',
+                            'da5cd81e1883c62a25bb37c4d1f8ad965b29bf8d']
+        self.assertEqual(expected_shallow, _get_shallow(clone))
+        self.assertReposNotEqual(clone, self._new_repo)
+
+        # Fetching at the same depth is a no-op.
+        run_git_or_fail(
+          ['fetch', '--depth=1', self.url(port)] + self.branch_args(),
+          cwd=self._old_repo.path)
+        self.assertEqual(expected_shallow, _get_shallow(clone))
+        self.assertReposNotEqual(clone, self._new_repo)
+
+        # The whole repo only has depth 3, so it should equal server_new.
+        run_git_or_fail(
+          ['fetch', '--depth=3', self.url(port)] + self.branch_args(),
+          cwd=self._old_repo.path)
+        self.assertEqual([], _get_shallow(clone))
+        self.assertReposEqual(clone, self._new_repo)
+
 
 class ShutdownServerMixIn:
     """Mixin that allows serve_forever to be shut down.
diff --git a/dulwich/tests/compat/test_web.py b/dulwich/tests/compat/test_web.py
index d1593db..ab5c60b 100644
--- a/dulwich/tests/compat/test_web.py
+++ b/dulwich/tests/compat/test_web.py
@@ -144,5 +144,10 @@ class DumbWebTestCase(WebTests, CompatTestCase):
         return HTTPGitApplication(backend, dumb=True)
 
     def test_push_to_dulwich(self):
-        # Note: remove this if dumb pushing is supported
+        # Note: remove this if dulwich implements dumb web pushing.
         raise TestSkipped('Dumb web pushing not supported.')
+
+    def test_shallow_clone_from_dulwich(self):
+        # Note: remove this if C git and dulwich implement dumb web shallow
+        # clones.
+        raise TestSkipped('Dumb web shallow cloning not supported.')
diff --git a/dulwich/tests/test_server.py b/dulwich/tests/test_server.py
index 5972349..ee1cd95 100644
--- a/dulwich/tests/test_server.py
+++ b/dulwich/tests/test_server.py
@@ -23,6 +23,13 @@ from dulwich.errors import (
     GitProtocolError,
     UnexpectedCommandError,
     )
+from dulwich.objects import (
+    Commit,
+    Tag,
+    )
+from dulwich.object_store import (
+    MemoryObjectStore,
+    )
 from dulwich.repo import (
     MemoryRepo,
     )
@@ -33,6 +40,7 @@ from dulwich.server import (
     MultiAckGraphWalkerImpl,
     MultiAckDetailedGraphWalkerImpl,
     _split_proto_line,
+    _find_shallow,
     ProtocolGraphWalker,
     SingleAckGraphWalkerImpl,
     UploadPackHandler,
@@ -40,6 +48,7 @@ from dulwich.server import (
 from dulwich.tests import TestCase
 from utils import (
     make_commit,
+    make_object,
     )
 
 
@@ -191,6 +200,81 @@ class UploadPackHandlerTestCase(TestCase):
         self.assertEquals({}, self._handler.get_tagged(refs, repo=self._repo))
 
 
+class FindShallowTests(TestCase):
+
+    def setUp(self):
+        self._store = MemoryObjectStore()
+
+    def make_commit(self, **attrs):
+        commit = make_commit(**attrs)
+        self._store.add_object(commit)
+        return commit
+
+    def make_linear_commits(self, n, message=''):
+        commits = []
+        parents = []
+        for _ in xrange(n):
+            commits.append(self.make_commit(parents=parents, message=message))
+            parents = [commits[-1].id]
+        return commits
+
+    def assertSameElements(self, expected, actual):
+        self.assertEqual(set(expected), set(actual))
+
+    def test_linear(self):
+        c1, c2, c3 = self.make_linear_commits(3)
+
+        self.assertEqual((set([c3.id]), set([])),
+                         _find_shallow(self._store, [c3.id], 0))
+        self.assertEqual((set([c2.id]), set([c3.id])),
+                         _find_shallow(self._store, [c3.id], 1))
+        self.assertEqual((set([c1.id]), set([c2.id, c3.id])),
+                         _find_shallow(self._store, [c3.id], 2))
+        self.assertEqual((set([]), set([c1.id, c2.id, c3.id])),
+                         _find_shallow(self._store, [c3.id], 3))
+
+    def test_multiple_independent(self):
+        a = self.make_linear_commits(2, message='a')
+        b = self.make_linear_commits(2, message='b')
+        c = self.make_linear_commits(2, message='c')
+        heads = [a[1].id, b[1].id, c[1].id]
+
+        self.assertEqual((set([a[0].id, b[0].id, c[0].id]), set(heads)),
+                         _find_shallow(self._store, heads, 1))
+
+    def test_multiple_overlapping(self):
+        # Create the following commit tree:
+        # 1--2
+        #  \
+        #   3--4
+        c1, c2 = self.make_linear_commits(2)
+        c3 = self.make_commit(parents=[c1.id])
+        c4 = self.make_commit(parents=[c3.id])
+
+        # 1 is shallow along the path from 4, but not along the path from 2.
+        self.assertEqual((set([c1.id]), set([c1.id, c2.id, c3.id, c4.id])),
+                         _find_shallow(self._store, [c2.id, c4.id], 2))
+
+    def test_merge(self):
+        c1 = self.make_commit()
+        c2 = self.make_commit()
+        c3 = self.make_commit(parents=[c1.id, c2.id])
+
+        self.assertEqual((set([c1.id, c2.id]), set([c3.id])),
+                         _find_shallow(self._store, [c3.id], 1))
+
+    def test_tag(self):
+        c1, c2 = self.make_linear_commits(2)
+        tag = make_object(Tag, name='tag', message='',
+                          tagger='Tagger <test@xxxxxxxxxxx>',
+                          tag_time=12345, tag_timezone=0,
+                          object=(Commit, c2.id))
+        self._store.add_object(tag)
+
+        self.assertEqual((set([c1.id]), set([c2.id])),
+                         _find_shallow(self._store, [tag.id], 1))
+
+
 class TestUploadPackHandler(UploadPackHandler):
     @classmethod
     def required_capabilities(self):
@@ -322,6 +406,47 @@ class ProtocolGraphWalkerTestCase(TestCase):
 
     # TODO: test commit time cutoff
 
+    def _handle_shallow_request(self, lines, heads):
+        self._walker.proto.set_output(lines)
+        self._walker._handle_shallow_request(heads)
+
+    def assertReceived(self, expected):
+        self.assertEquals(
+          expected, list(iter(self._walker.proto.get_received_line, None)))
+
+    def test_handle_shallow_request_no_client_shallows(self):
+        self._handle_shallow_request(['deepen 1\n'], [FOUR, FIVE])
+        self.assertEquals(set([TWO, THREE]), self._walker.shallow)
+        self.assertReceived([
+          'shallow %s' % TWO,
+          'shallow %s' % THREE,
+          'None',
+          ])
+
+    def test_handle_shallow_request_no_new_shallows(self):
+        lines = [
+          'shallow %s\n' % TWO,
+          'shallow %s\n' % THREE,
+          'deepen 1\n',
+          ]
+        self._handle_shallow_request(lines, [FOUR, FIVE])
+        self.assertEquals(set([TWO, THREE]), self._walker.shallow)
+        self.assertReceived(['None'])
+
+    def test_handle_shallow_request_unshallows(self):
+        lines = [
+          'shallow %s\n' % TWO,
+          'deepen 2\n',
+          ]
+        self._handle_shallow_request(lines, [FOUR, FIVE])
+        self.assertEquals(set([ONE]), self._walker.shallow)
+        self.assertReceived([
+          'shallow %s' % ONE,
+          'unshallow %s' % TWO,
+          # THREE is unshallow but was not shallow in the client.
+          'None',
+          ])
+
 
 class TestProtocolGraphWalker(object):
 
-- 
1.7.2




References