← Back to team overview

dulwich-users team mailing list archive

[PATCH 1/4] Add eof() and unread_pkt_line() methods to Protocol.

 

From: Dave Borowitz <dborowitz@xxxxxxxxxx>

Change-Id: I2d64436952aee91aa4396a9dc2f028b98890343e
---
 NEWS                           |    2 ++
 dulwich/protocol.py            |   40 ++++++++++++++++++++++++++++++++++++++--
 dulwich/tests/test_protocol.py |   34 +++++++++++++++++++++++++++++++++-
 3 files changed, 73 insertions(+), 3 deletions(-)

diff --git a/NEWS b/NEWS
index 5dc4369..14cb030 100644
--- a/NEWS
+++ b/NEWS
@@ -28,6 +28,8 @@
 
   * Delegate SHA peeling to the object store.  (Dave Borowitz)
 
+  * Add eof() and unread_pkt_line() methods to Protocol. (Dave Borowitz)
+
  TESTS
 
   * Use GitFile when modifying packed-refs in tests. (Dave Borowitz)
diff --git a/dulwich/protocol.py b/dulwich/protocol.py
index af0113c..0b4e433 100644
--- a/dulwich/protocol.py
+++ b/dulwich/protocol.py
@@ -82,15 +82,24 @@ class Protocol(object):
         self.read = read
         self.write = write
         self.report_activity = report_activity
+        self._readahead = None
 
     def read_pkt_line(self):
         """Reads a pkt-line from the remote git process.
 
+        This method may read from the readahead buffer; see unread_pkt_line.
+
         :return: The next string from the stream, without the length prefix, or
             None for a flush-pkt ('0000').
         """
+        if self._readahead is None:
+            read = self.read
+        else:
+            read = self._readahead.read
+            self._readahead = None
+
         try:
-            sizestr = self.read(4)
+            sizestr = read(4)
             if not sizestr:
                 raise HangupException()
             size = int(sizestr, 16)
@@ -100,10 +109,37 @@ class Protocol(object):
                 return None
             if self.report_activity:
                 self.report_activity(size, 'read')
-            return self.read(size-4)
+            return read(size-4)
         except socket.error, e:
             raise GitProtocolError(e)
 
+    def eof(self):
+        """Test whether the protocol stream has reached EOF.
+
+        Note that this refers to the actual stream EOF and not just a flush-pkt.
+
+        :return: True if the stream is at EOF, False otherwise.
+        """
+        try:
+            next_line = self.read_pkt_line()
+        except HangupException:
+            return True
+        self.unread_pkt_line(next_line)
+        return False
+
+    def unread_pkt_line(self, data):
+        """Unread a single line of data into the readahead buffer.
+
+        This method can be used to unread a single pkt-line into a fixed
+        readahead buffer.
+
+        :param data: The data to unread, without the length prefix.
+        :raise ValueError: If more than one pkt-line is unread.
+        """
+        if self._readahead is not None:
+            raise ValueError('Attempted to unread multiple pkt-lines.')
+        self._readahead = StringIO(pkt_line(data))
+
     def read_pkt_seq(self):
         """Read a sequence of pkt-lines from the remote git process.
 
diff --git a/dulwich/tests/test_protocol.py b/dulwich/tests/test_protocol.py
index 78011e4..5c43f68 100644
--- a/dulwich/tests/test_protocol.py
+++ b/dulwich/tests/test_protocol.py
@@ -21,6 +21,9 @@
 
 from StringIO import StringIO
 
+from dulwich.errors import (
+    HangupException,
+    )
 from dulwich.protocol import (
     Protocol,
     ReceivableProtocol,
@@ -50,6 +53,24 @@ class BaseProtocolTests(object):
         self.rin.seek(0)
         self.assertEquals('cmd ', self.proto.read_pkt_line())
 
+    def test_eof(self):
+        self.rin.write('0000')
+        self.rin.seek(0)
+        self.assertFalse(self.proto.eof())
+        self.assertEquals(None, self.proto.read_pkt_line())
+        self.assertTrue(self.proto.eof())
+        self.assertRaises(HangupException, self.proto.read_pkt_line)
+
+    def test_unread_pkt_line(self):
+        self.rin.write('0007foo0000')
+        self.rin.seek(0)
+        self.assertEquals('foo', self.proto.read_pkt_line())
+        self.proto.unread_pkt_line('bar')
+        self.assertEquals('bar', self.proto.read_pkt_line())
+        self.assertEquals(None, self.proto.read_pkt_line())
+        self.proto.unread_pkt_line('baz1')
+        self.assertRaises(ValueError, self.proto.unread_pkt_line, 'baz2')
+
     def test_read_pkt_seq(self):
         self.rin.write('0008cmd 0005l0000')
         self.rin.seek(0)
@@ -91,10 +112,14 @@ class ProtocolTests(BaseProtocolTests, TestCase):
 class ReceivableStringIO(StringIO):
     """StringIO with socket-like recv semantics for testing."""
 
+    def __init__(self):
+        StringIO.__init__(self)
+        self.allow_read_past_eof = False
+
     def recv(self, size):
         # fail fast if no bytes are available; in a real socket, this would
         # block forever
-        if self.tell() == len(self.getvalue()):
+        if self.tell() == len(self.getvalue()) and not self.allow_read_past_eof:
             raise AssertionError('Blocking read past end of socket')
         if size == 1:
             return self.read(1)
@@ -111,6 +136,13 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         self.proto = ReceivableProtocol(self.rin.recv, self.rout.write)
         self.proto._rbufsize = 8
 
+    def test_eof(self):
+        # Allow blocking reads past EOF just for this test. The only parts of
+        # the protocol that might check for EOF do not depend on the recv()
+        # semantics anyway.
+        self.rin.allow_read_past_eof = True
+        BaseProtocolTests.test_eof(self)
+
     def test_recv(self):
         all_data = '1234567' * 10  # not a multiple of bufsize
         self.rin.write(all_data)
-- 
1.7.2




References