← Back to team overview

dulwich-users team mailing list archive

[PATCH 5/6] serialize_graftpoints

 

https://git.wiki.kernel.org/index.php/GraftPoint

Add serialize_graftpoints(), which converts a graftpoint dictionary back into the info/grafts line format:

  <commit sha1> <parent sha1> [<parent sha1>]*
---
 dulwich/object_store.py |   22 ++++++++++++-----
 dulwich/repo.py         |   57 +++++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 68 insertions(+), 11 deletions(-)

diff --git a/dulwich/object_store.py b/dulwich/object_store.py
index d5dda89..84abb1e 100644
--- a/dulwich/object_store.py
+++ b/dulwich/object_store.py
@@ -27,6 +27,7 @@ import itertools
 import os
 import stat
 import tempfile
+import warnings
 
 from dulwich.diff_tree import (
     tree_changes,
@@ -113,8 +114,7 @@ class BaseObjectStore(object):
         """
         raise NotImplementedError(self.get_raw)
 
-    @property
-    def grafts(self):
+    def _get_grafts(self):
         """Graftpoints are commits with parents "rewritten"
 
         https://git.wiki.kernel.org/index.php/GraftPoint
@@ -123,12 +123,28 @@ class BaseObjectStore(object):
             self._grafts = {}
         return self._grafts
 
-    def add_grafts(self, grafts={}):
-        self.grafts.update(grafts)
+    def _set_grafts(self, value):
+        """Replace the graftpoints known to this object store.
+
+        :param value: dictionary mapping a commit sha1 to a list of its
+            replacement parent sha1s
+
+        A graft is kept only if its commit and every parent are present
+        in this store; any other graft is skipped with a warning.
+        """
+        grafts = {}
+        for commit, parents in value.iteritems():
+            missing = [sha for sha in [commit] + list(parents)
+                       if sha not in self]
+            if missing:
+                warnings.warn(
+                    'object_store._set_grafts - Skipping invalid graft:'
+                    ' %s %s' % (commit, ' '.join(parents)))
+            else:
+                grafts[commit] = parents
+        self._grafts = grafts
 
-    def remove_grafts(self, shas=[]):
-        for sha in shas:
-            del self.grafts[sha]
+    grafts = property(_get_grafts, _set_grafts)
 
     def __getitem__(self, sha):
         """Obtain an object by SHA1."""
diff --git a/dulwich/repo.py b/dulwich/repo.py
index 3e754f2..d406729 100644
--- a/dulwich/repo.py
+++ b/dulwich/repo.py
@@ -798,6 +798,9 @@ def parse_graftpoints(graft_lines=[]):
     Each line is formatted as:
         <commit sha1> <parent sha1> [<parent sha1>]*
 
+    The resulting dictionary has the form:
+        <commit sha1>: [<parent sha1>*]
+
     https://git.wiki.kernel.org/index.php/GraftPoint
     """
     grafts = {}
@@ -813,6 +816,24 @@ def parse_graftpoints(graft_lines=[]):
     return grafts
 
 
+def serialize_graftpoints(grafts):
+    """Convert a dictionary of grafts into a string.
+
+    The graft dictionary is:
+        <commit sha1>: [<parent sha1>*]
+
+    Each line is formatted as (an empty parent list yields the bare
+    commit sha1, with no trailing space):
+        <commit sha1> [<parent sha1>]*
+
+    https://git.wiki.kernel.org/index.php/GraftPoint
+    """
+    graft_lines = []
+    for commit, parents in grafts.iteritems():
+        graft_lines.append(' '.join([commit] + list(parents)) + '\n')
+    return ''.join(graft_lines)
+
+
 class BaseRepo(object):
     """Base class for a git repository.
 
@@ -834,6 +855,8 @@ class BaseRepo(object):
         self.object_store = object_store
         self.refs = refs
 
+        # Maps commit sha1 -> list of parent sha1s (see refresh_graftpoints)
+        self.graftpoints = {}
         self.hooks = {}
 
     def _init_files(self, bare):
@@ -1183,6 +1206,47 @@ class BaseRepo(object):
             config.get(("user", ), "name"),
             config.get(("user", ), "email"))
 
+    def add_grafts(self, updated_grafts):
+        """Merge graftpoints into this repo and refresh the object store.
+
+        :param updated_grafts: dictionary mapping a commit sha1 to a list of
+            parent sha1s; an existing graft for the same commit is replaced
+        """
+        self.graftpoints.update(updated_grafts)
+        self.refresh_graftpoints()
+
+    def remove_grafts(self, to_remove=None):
+        """Remove graftpoints from this repo and refresh the object store.
+
+        :param to_remove: iterable of commit sha1s whose grafts are removed;
+            None (the default) removes nothing
+        :raise KeyError: if a commit sha1 has no graftpoint; the object
+            store is still refreshed with whatever was removed before it
+        """
+        if to_remove is None:
+            to_remove = []
+        try:
+            for sha in to_remove:
+                del self.graftpoints[sha]
+        finally:
+            self.refresh_graftpoints()
+
+    def serialize_grafts(self):
+        """Get the string representation of the graftpoints.
+
+        This format is writable to a graftpoint file.
+        """
+        return serialize_graftpoints(self.graftpoints)
+
+    def refresh_graftpoints(self):
+        """Set all the known graftpoints in the object store.
+
+        This method must be called before GraftedCommits can
+        be retrieved from the repo.  The object store may skip grafts
+        that reference objects it does not contain.
+        """
+        self.object_store.grafts = self.graftpoints
+
     def do_commit(self, message=None, committer=None,
                   author=None, commit_timestamp=None,
                   commit_timezone=None, author_timestamp=None,
@@ -1321,8 +1367,13 @@ class Repo(BaseRepo):
 
         graft_file = self.get_named_file(os.path.join("info", "grafts"))
         if graft_file:
-            grafts = parse_graftpoints(graft_file.read().splitlines())
-            self.object_store.add_grafts(grafts)
+            # Close the graft file even if reading or parsing it fails.
+            try:
+                self.graftpoints = parse_graftpoints(
+                    graft_file.read().splitlines())
+            finally:
+                graft_file.close()
+            self.refresh_graftpoints()
 
         self.hooks['pre-commit'] = PreCommitShellHook(self.controldir())
         self.hooks['commit-msg'] = CommitMsgShellHook(self.controldir())
@@ -1347,7 +1394,8 @@ class Repo(BaseRepo):
 
         if path == os.path.join("info", "grafts"):
             grafts = parse_graftpoints(contents.splitlines())
-            self.object_store.add_grafts(grafts)
+            # add_grafts() also pushes the merged grafts to the object store.
+            self.add_grafts(grafts)
 
     def get_named_file(self, path):
         """Get a file from the control dir with a specific name.
@@ -1564,7 +1612,8 @@ class MemoryRepo(BaseRepo):
 
         if path == os.path.join("info", "grafts"):
             grafts = parse_graftpoints(contents.splitlines())
-            self.object_store.add_grafts(grafts)
+            # add_grafts() also pushes the merged grafts to the object store.
+            self.add_grafts(grafts)
 
     def get_named_file(self, path):
         """Get a file from the control dir with a specific name.
-- 
1.7.7.1.9.g13da8



References