← Back to team overview

dulwich team mailing list archive

[PATCH 1/3] Lazily read the contents of ShaFiles from disk.

 

From: Dave Borowitz <dborowitz@xxxxxxxxxx>

Previously, ShaFile.from_file read and inflated the entire contents of a
ShaFile, storing the inflated text in memory. Parsing that text was done
lazily, but this resulted in a confusing performance profile, since the
"lazy" part was in fact faster (memory/CPU-bound) than the eager part
(I/O-bound).

This change ensures that a file is read in full from disk only when it
is about to be parsed. The first few bytes are still read from disk to
determine the object type, but there is no need to slurp the whole file.
Moreover, we now distinguish between ShaFiles initialized from (possibly
raw) strings and ShaFiles initialized from filenames. To parse an
object's contents, we need one or the other. There is a third category
of ShaFiles that should never be parsed, whose properties are all
intended to be set explicitly (usually, but not always, during testing).
These objects are created as usual via __init__.

The tests all pass, but we do not currently have a way of testing the
point at which the implementation reads the whole file from disk.

Also, note that a call to foo.id still results in reading the whole file
and calculating its SHA-1; this behavior may change in the future.

Change-Id: Ib9f67ead0d2a812a2e43cffa54c37a1c4a219841
---
 dulwich/object_store.py       |    2 +-
 dulwich/objects.py            |  187 +++++++++++++++++++++++++----------------
 dulwich/tests/test_objects.py |   61 +++++++-------
 3 files changed, 148 insertions(+), 102 deletions(-)

diff --git a/dulwich/object_store.py b/dulwich/object_store.py
index 6456f6b..6eaa179 100644
--- a/dulwich/object_store.py
+++ b/dulwich/object_store.py
@@ -378,7 +378,7 @@ class DiskObjectStore(PackBasedObjectStore):
         path = self._get_shafile_path(sha)
         try:
             return ShaFile.from_file(path)
-        except OSError, e:
+        except (OSError, IOError), e:
             if e.errno == errno.ENOENT:
                 return None
             raise
diff --git a/dulwich/objects.py b/dulwich/objects.py
index c3d296b..350eda3 100644
--- a/dulwich/objects.py
+++ b/dulwich/objects.py
@@ -99,9 +99,10 @@ def object_class(type):
     """Get the object class corresponding to the given type.
 
     :param type: Either a type name string or a numeric type.
-    :return: The ShaFile subclass corresponding to the given type.
+    :return: The ShaFile subclass corresponding to the given type, or None if
+        type is not a valid type name/number.
     """
-    return _TYPE_MAP[type]
+    return _TYPE_MAP.get(type, None)
 
 
 def check_hexsha(hex, error_msg):
@@ -124,32 +125,40 @@ def check_identity(identity, error_msg):
 class ShaFile(object):
     """A git SHA file."""
 
-    @classmethod
-    def _parse_legacy_object(cls, map):
-        """Parse a legacy object, creating it and setting object._text"""
-        text = _decompress(map)
-        object = None
-        for cls in OBJECT_CLASSES:
-            if text.startswith(cls.type_name):
-                object = cls()
-                text = text[len(cls.type_name):]
-                break
-        assert object is not None, "%s is not a known object type" % text[:9]
-        assert text[0] == ' ', "%s is not a space" % text[0]
-        text = text[1:]
-        size = 0
-        i = 0
-        while text[0] >= '0' and text[0] <= '9':
-            if i > 0 and size == 0:
-                raise AssertionError("Size is not in canonical format")
-            size = (size * 10) + int(text[0])
-            text = text[1:]
-            i += 1
-        object._size = size
-        assert text[0] == "\0", "Size not followed by null"
-        text = text[1:]
-        object.set_raw_string(text)
-        return object
+    @staticmethod
+    def _parse_legacy_object_header(magic, f):
+        """Parse a legacy header and create the (still unparsed) object."""
+        decomp = zlib.decompressobj()
+        header = decomp.decompress(magic)
+        start, end = 0, -1
+        while end < 0:
+            data = f.read(1024)
+            if not data:  # EOF before the NUL terminator: avoid looping forever
+                raise ObjectFormatException("Truncated object header")
+            header += decomp.decompress(data)
+            end = header.find("\0", start)
+            start = len(header)
+        type_name, size = header[:end].split(" ", 1)
+        size = int(size)  # sanity check
+        obj_class = object_class(type_name)
+        if not obj_class:
+            raise ObjectFormatException("Not a known type: %s" % type_name)
+        obj = obj_class()
+        obj._filename = f.name
+        return obj
+
+    def _parse_legacy_object(self, f):
+        """Parse a legacy object, setting the raw string."""
+        size = os.path.getsize(f.name)
+        map = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ)
+        try:
+            text = _decompress(map)
+        finally:
+            map.close()
+        header_end = text.find('\0')
+        if header_end < 0:
+            raise ObjectFormatException("Invalid object header")
+        self.set_raw_string(text[header_end+1:])
 
     def as_legacy_object_chunks(self):
         compobj = zlib.compressobj()
@@ -162,9 +171,10 @@ class ShaFile(object):
         return "".join(self.as_legacy_object_chunks())
 
     def as_raw_chunks(self):
-        if self._needs_serialization:
+        if self._needs_parsing:
+            self._ensure_parsed()
+        else:
             self._chunked_text = self._serialize()
-            self._needs_serialization = False
         return self._chunked_text
 
     def as_raw_string(self):
@@ -181,6 +191,9 @@ class ShaFile(object):
 
     def _ensure_parsed(self):
         if self._needs_parsing:
+            if not self._chunked_text:
+                assert self._filename, "ShaFile needs either text or filename"
+                self._parse_file()
             self._deserialize(self._chunked_text)
             self._needs_parsing = False
 
@@ -195,35 +208,55 @@ class ShaFile(object):
         self._needs_parsing = True
         self._needs_serialization = False
 
-    @classmethod
-    def _parse_object(cls, map):
-        """Parse a new style object , creating it and setting object._text"""
-        used = 0
-        byte = ord(map[used])
-        used += 1
-        type_num = (byte >> 4) & 7
+    @staticmethod
+    def _parse_object_header(magic, f):
+        """Parse a new style object, creating it but not reading the file."""
+        num_type = (ord(magic[0]) >> 4) & 7
+        obj_class = object_class(num_type)
+        if not obj_class:
+            raise ObjectFormatException("Not a known type: %d" % num_type)
+        obj = obj_class()
+        obj._filename = f.name
+        return obj
+
+    def _parse_object(self, f):
+        """Parse a new style object, setting self._text."""
+        size = os.path.getsize(f.name)
+        map = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ)
         try:
-            object = object_class(type_num)()
-        except KeyError:
-            raise AssertionError("Not a known type: %d" % type_num)
-        while (byte & 0x80) != 0:
-            byte = ord(map[used])
-            used += 1
-        raw = map[used:]
-        object.set_raw_string(_decompress(raw))
-        return object
+            # skip type and size; type must have already been determined, and we
+            # trust zlib to fail if it's otherwise corrupted
+            byte = ord(map[0])
+            used = 1
+            while (byte & 0x80) != 0:
+                byte = ord(map[used])
+                used += 1
+            raw = map[used:]
+            self.set_raw_string(_decompress(raw))
+        finally:
+            map.close()
+
+    @classmethod
+    def _is_legacy_object(cls, magic):
+        b0, b1 = map(ord, magic)
+        word = (b0 << 8) + b1
+        return b0 == 0x78 and (word % 31) == 0
 
     @classmethod
-    def _parse_file(cls, map):
-        word = (ord(map[0]) << 8) + ord(map[1])
-        if ord(map[0]) == 0x78 and (word % 31) == 0:
-            return cls._parse_legacy_object(map)
+    def _parse_file_header(cls, f):
+        magic = f.read(2)
+        if cls._is_legacy_object(magic):
+            return cls._parse_legacy_object_header(magic, f)
         else:
-            return cls._parse_object(map)
+            return cls._parse_object_header(magic, f)
 
     def __init__(self):
         """Don't call this directly"""
         self._sha = None
+        self._filename = None
+        self._chunked_text = []
+        self._needs_parsing = False
+        self._needs_serialization = True
 
     def _deserialize(self, chunks):
         raise NotImplementedError(self._deserialize)
@@ -231,15 +264,29 @@ class ShaFile(object):
     def _serialize(self):
         raise NotImplementedError(self._serialize)
 
+    def _parse_file(self):
+        f = GitFile(self._filename, 'rb')
+        try:
+            magic = f.read(2)
+            if self._is_legacy_object(magic):
+                self._parse_legacy_object(f)
+            else:
+                self._parse_object(f)
+        finally:
+            f.close()
+
     @classmethod
     def from_file(cls, filename):
-        """Get the contents of a SHA file on disk"""
-        size = os.path.getsize(filename)
+        """Get the contents of a SHA file on disk."""
         f = GitFile(filename, 'rb')
         try:
-            map = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ)
-            shafile = cls._parse_file(map)
-            return shafile
+            try:
+                obj = cls._parse_file_header(f)
+                obj._needs_parsing = True
+                obj._needs_serialization = True
+                return obj
+            except (IndexError, ValueError), e:
+                raise ObjectFormatException("invalid object header")
         finally:
             f.close()
 
@@ -267,7 +314,7 @@ class ShaFile(object):
 
     @classmethod
     def from_string(cls, string):
-        """Create a blob from a string."""
+        """Create a ShaFile from a string."""
         obj = cls()
         obj.set_raw_string(string)
         return obj
@@ -367,14 +414,24 @@ class Blob(ShaFile):
         self.set_raw_string(data)
 
     data = property(_get_data, _set_data,
-            "The text contained within the blob object.")
+                    "The text contained within the blob object.")
 
     def _get_chunked(self):
+        self._ensure_parsed()
         return self._chunked_text
 
     def _set_chunked(self, chunks):
         self._chunked_text = chunks
 
+    def _serialize(self):
+        if not self._chunked_text:
+            self._ensure_parsed()
+        self._needs_serialization = False
+        return self._chunked_text
+
+    def _deserialize(self, chunks):
+        return "".join(chunks)
+
     chunked = property(_get_chunked, _set_chunked,
         "The text within the blob object, as chunks (not necessarily lines).")
 
@@ -424,8 +481,6 @@ class Tag(ShaFile):
 
     def __init__(self):
         super(Tag, self).__init__()
-        self._needs_parsing = False
-        self._needs_serialization = True
         self._tag_timezone_neg_utc = False
 
     @classmethod
@@ -435,13 +490,6 @@ class Tag(ShaFile):
             raise NotTagError(filename)
         return tag
 
-    @classmethod
-    def from_string(cls, string):
-        """Create a blob from a string."""
-        shafile = cls()
-        shafile.set_raw_string(string)
-        return shafile
-
     def check(self):
         """Check this object for internal consistency.
 
@@ -600,8 +648,6 @@ class Tree(ShaFile):
     def __init__(self):
         super(Tree, self).__init__()
         self._entries = {}
-        self._needs_parsing = False
-        self._needs_serialization = True
 
     @classmethod
     def from_file(cls, filename):
@@ -668,7 +714,6 @@ class Tree(ShaFile):
         # TODO: list comprehension is for efficiency in the common (small) case;
         # if memory efficiency in the large case is a concern, use a genexp.
         self._entries = dict([(n, (m, s)) for n, m, s in parsed_entries])
-        self._needs_parsing = False
 
     def check(self):
         """Check this object for internal consistency.
@@ -746,8 +791,6 @@ class Commit(ShaFile):
         super(Commit, self).__init__()
         self._parents = []
         self._encoding = None
-        self._needs_parsing = False
-        self._needs_serialization = True
         self._extra = {}
         self._author_timezone_neg_utc = False
         self._commit_timezone_neg_utc = False
diff --git a/dulwich/tests/test_objects.py b/dulwich/tests/test_objects.py
index 8d8b007..7cf87a9 100644
--- a/dulwich/tests/test_objects.py
+++ b/dulwich/tests/test_objects.py
@@ -212,11 +212,13 @@ class BlobReadTests(unittest.TestCase):
 
 class ShaFileCheckTests(unittest.TestCase):
 
-    def assertCheckFails(self, obj, data):
+    def assertCheckFails(self, cls, data):
+        obj = cls()
         obj.set_raw_string(data)
         self.assertRaises(ObjectFormatException, obj.check)
 
-    def assertCheckSucceeds(self, obj, data):
+    def assertCheckSucceeds(self, cls, data):
+        obj = cls()
         obj.set_raw_string(data)
         try:
             obj.check()
@@ -343,22 +345,22 @@ class CommitParseTests(ShaFileCheckTests):
         self.assertEquals('UTF-8', c.encoding)
 
     def test_check(self):
-        self.assertCheckSucceeds(Commit(), self.make_commit_text())
-        self.assertCheckSucceeds(Commit(), self.make_commit_text(parents=None))
-        self.assertCheckSucceeds(Commit(),
+        self.assertCheckSucceeds(Commit, self.make_commit_text())
+        self.assertCheckSucceeds(Commit, self.make_commit_text(parents=None))
+        self.assertCheckSucceeds(Commit,
                                  self.make_commit_text(encoding='UTF-8'))
 
-        self.assertCheckFails(Commit(), self.make_commit_text(tree='xxx'))
-        self.assertCheckFails(Commit(), self.make_commit_text(
+        self.assertCheckFails(Commit, self.make_commit_text(tree='xxx'))
+        self.assertCheckFails(Commit, self.make_commit_text(
           parents=[a_sha, 'xxx']))
         bad_committer = "some guy without an email address 1174773719 +0000"
-        self.assertCheckFails(Commit(),
+        self.assertCheckFails(Commit,
                               self.make_commit_text(committer=bad_committer))
-        self.assertCheckFails(Commit(),
+        self.assertCheckFails(Commit,
                               self.make_commit_text(author=bad_committer))
-        self.assertCheckFails(Commit(), self.make_commit_text(author=None))
-        self.assertCheckFails(Commit(), self.make_commit_text(committer=None))
-        self.assertCheckFails(Commit(), self.make_commit_text(
+        self.assertCheckFails(Commit, self.make_commit_text(author=None))
+        self.assertCheckFails(Commit, self.make_commit_text(committer=None))
+        self.assertCheckFails(Commit, self.make_commit_text(
           author=None, committer=None))
 
     def test_check_duplicates(self):
@@ -369,9 +371,9 @@ class CommitParseTests(ShaFileCheckTests):
             text = '\n'.join(lines)
             if lines[i].startswith('parent'):
                 # duplicate parents are ok for now
-                self.assertCheckSucceeds(Commit(), text)
+                self.assertCheckSucceeds(Commit, text)
             else:
-                self.assertCheckFails(Commit(), text)
+                self.assertCheckFails(Commit, text)
 
     def test_check_order(self):
         lines = self.make_commit_lines(parents=[a_sha], encoding='UTF-8')
@@ -382,9 +384,9 @@ class CommitParseTests(ShaFileCheckTests):
             perm = list(perm)
             text = '\n'.join(perm + rest)
             if perm == headers:
-                self.assertCheckSucceeds(Commit(), text)
+                self.assertCheckSucceeds(Commit, text)
             else:
-                self.assertCheckFails(Commit(), text)
+                self.assertCheckFails(Commit, text)
 
 
 class TreeTests(ShaFileCheckTests):
@@ -406,6 +408,7 @@ class TreeTests(ShaFileCheckTests):
     def _do_test_parse_tree(self, parse_tree):
         o = Tree.from_file(os.path.join(os.path.dirname(__file__), 'data',
                                         'trees', tree_sha))
+        o._parse_file()
         self.assertEquals([('a', 0100644, a_sha), ('b', 0100644, b_sha)],
                           list(parse_tree(o.as_raw_string())))
 
@@ -418,7 +421,7 @@ class TreeTests(ShaFileCheckTests):
         self._do_test_parse_tree(parse_tree)
 
     def test_check(self):
-        t = Tree()
+        t = Tree
         sha = hex_to_sha(a_sha)
 
         # filenames
@@ -530,26 +533,26 @@ class TagParseTests(ShaFileCheckTests):
         self.assertEquals("v2.6.22-rc7", x.name)
 
     def test_check(self):
-        self.assertCheckSucceeds(Tag(), self.make_tag_text())
-        self.assertCheckFails(Tag(), self.make_tag_text(object_sha=None))
-        self.assertCheckFails(Tag(), self.make_tag_text(object_type_name=None))
-        self.assertCheckFails(Tag(), self.make_tag_text(name=None))
-        self.assertCheckFails(Tag(), self.make_tag_text(name=''))
-        self.assertCheckFails(Tag(), self.make_tag_text(
+        self.assertCheckSucceeds(Tag, self.make_tag_text())
+        self.assertCheckFails(Tag, self.make_tag_text(object_sha=None))
+        self.assertCheckFails(Tag, self.make_tag_text(object_type_name=None))
+        self.assertCheckFails(Tag, self.make_tag_text(name=None))
+        self.assertCheckFails(Tag, self.make_tag_text(name=''))
+        self.assertCheckFails(Tag, self.make_tag_text(
           object_type_name="foobar"))
-        self.assertCheckFails(Tag(), self.make_tag_text(
+        self.assertCheckFails(Tag, self.make_tag_text(
           tagger="some guy without an email address 1183319674 -0700"))
-        self.assertCheckFails(Tag(), self.make_tag_text(
+        self.assertCheckFails(Tag, self.make_tag_text(
           tagger=("Linus Torvalds <torvalds@xxxxxxxxxxxxxxxxxxxxxxxxxx> "
                   "Sun 7 Jul 2007 12:54:34 +0700")))
-        self.assertCheckFails(Tag(), self.make_tag_text(object_sha="xxx"))
+        self.assertCheckFails(Tag, self.make_tag_text(object_sha="xxx"))
 
     def test_check_duplicates(self):
         # duplicate each of the header fields
         for i in xrange(4):
             lines = self.make_tag_lines()
             lines.insert(i, lines[i])
-            self.assertCheckFails(Tag(), '\n'.join(lines))
+            self.assertCheckFails(Tag, '\n'.join(lines))
 
     def test_check_order(self):
         lines = self.make_tag_lines()
@@ -560,9 +563,9 @@ class TagParseTests(ShaFileCheckTests):
             perm = list(perm)
             text = '\n'.join(perm + rest)
             if perm == headers:
-                self.assertCheckSucceeds(Tag(), text)
+                self.assertCheckSucceeds(Tag, text)
             else:
-                self.assertCheckFails(Tag(), text)
+                self.assertCheckFails(Tag, text)
 
 
 class CheckTests(unittest.TestCase):
-- 
1.7.0.3.295.gd8fa2




Follow ups