Implemented encode_element_size and write_element_size, and added tests. This should...
authorJoseph Spiros <joseph.spiros@ithinksw.com>
Tue, 12 Apr 2011 03:27:38 +0000 (23:27 -0400)
committerJoseph Spiros <joseph.spiros@ithinksw.com>
Tue, 12 Apr 2011 03:27:38 +0000 (23:27 -0400)
ebml/core.py
ebml/tests/__init__.py [new file with mode: 0644]
ebml/tests/test_core.py [new file with mode: 0644]

index 6e8d338..d709268 100644 (file)
@@ -1,5 +1,6 @@
 import struct
 import datetime
 import struct
 import datetime
+from math import log
 from .exceptions import *
 
 
 from .exceptions import *
 
 
@@ -7,15 +8,15 @@ EBMLMaxSizeLength = 8
 EBMLMaxIDLength = 4
 
 
 EBMLMaxIDLength = 4
 
 
-def _read_vint_to_bytearray(stream, max_width=EBMLMaxSizeLength):
+def _read_vint_to_bytearray(stream, max_length=EBMLMaxSizeLength):
        """
        
        Reads a vint from stream and returns a bytearray containing all of the bytes without doing any decoding.
        
        :arg stream: the source of the bytes
        :type stream: a file-like object
        """
        
        Reads a vint from stream and returns a bytearray containing all of the bytes without doing any decoding.
        
        :arg stream: the source of the bytes
        :type stream: a file-like object
-       :arg max_width: the maximum length, in bytes, of the vint (defaults to :data:`EBMLMaxSizeLength`)
-       :type max_width: int
+       :arg max_length: the maximum length, in bytes, of the vint (defaults to :data:`EBMLMaxSizeLength`)
+       :type max_length: int
        :returns: bytearray
        
        """
        :returns: bytearray
        
        """
@@ -25,9 +26,10 @@ def _read_vint_to_bytearray(stream, max_width=EBMLMaxSizeLength):
        vint_len = -7
        while not marker_found:
                vint_len += 8
        vint_len = -7
        while not marker_found:
                vint_len += 8
-               if vint_len > max_width:
-                       raise ParseError('vint exceeds max_width (%(max_width)i)' % {
-                               'max_width': max_width
+               if vint_len > max_length:
+                       raise ParseError('vint length (%(vint_len)i) exceeds max_length (%(max_length)i)' % {
+                               'vint_len': vint_len,
+                               'max_length': max_length
                        })
                byte = ord(stream.read(1))
                vint_bytes.append(byte)
                        })
                byte = ord(stream.read(1))
                vint_bytes.append(byte)
@@ -43,7 +45,7 @@ def _read_vint_to_bytearray(stream, max_width=EBMLMaxSizeLength):
                vint_bytes.extend(ord(remaining_byte) for remaining_byte in stream.read(remaining_bytes_len))
        
        if len(vint_bytes) != vint_len:
                vint_bytes.extend(ord(remaining_byte) for remaining_byte in stream.read(remaining_bytes_len))
        
        if len(vint_bytes) != vint_len:
-               raise ParseError('Unable to read truncated vint of width %(vint_len)s from stream (%(vint_bytes)s bytes available)' % {
+               raise ParseError('Unable to read truncated vint of length %(vint_len)s from stream (%(vint_bytes)s bytes available)' % {
                        'vint_len': vint_len,
                        'vint_bytes': len(vint_bytes)
                })
                        'vint_len': vint_len,
                        'vint_bytes': len(vint_bytes)
                })
@@ -51,7 +53,7 @@ def _read_vint_to_bytearray(stream, max_width=EBMLMaxSizeLength):
        return vint_bytes
 
 
        return vint_bytes
 
 
-def read_element_size(stream, max_width=EBMLMaxSizeLength):
+def read_element_size(stream, max_length=EBMLMaxSizeLength):
        """
        
        Reads an EBML element size vint from stream and returns a tuple containing:
        """
        
        Reads an EBML element size vint from stream and returns a tuple containing:
@@ -61,13 +63,13 @@ def read_element_size(stream, max_width=EBMLMaxSizeLength):
        
        :arg stream: the source of the bytes
        :type stream: a file-like object
        
        :arg stream: the source of the bytes
        :type stream: a file-like object
-       :arg max_width: the maximum length, in bytes, of the vint storing the element size (defaults to :data:`EBMLMaxSizeLength`)
-       :type max_width: int
+       :arg max_length: the maximum length, in bytes, of the vint storing the element size (defaults to :data:`EBMLMaxSizeLength`)
+       :type max_length: int
        :returns: tuple
        
        """
        
        :returns: tuple
        
        """
        
-       vint_bytes = _read_vint_to_bytearray(stream, max_width)
+       vint_bytes = _read_vint_to_bytearray(stream, max_length)
        vint_len = len(vint_bytes)
        
        int_bytes = vint_bytes[((vint_len - 1) // 8):]
        vint_len = len(vint_bytes)
        
        int_bytes = vint_bytes[((vint_len - 1) // 8):]
@@ -90,7 +92,104 @@ def read_element_size(stream, max_width=EBMLMaxSizeLength):
        return value, vint_len
 
 
        return value, vint_len
 
 
-def read_element_id(stream, max_width=EBMLMaxIDLength):
+def encode_element_size(size, min_length=None, max_length=EBMLMaxSizeLength):
+       """
+       
+       Encode the size of an EBML element as a vint, optionally with a minimum length.
+       
+       :arg size: the element size, or None if undefined
+       :type size: int or None
+       :arg min_length: the minimum length, in bytes, of the resultant vint
+       :type min_length: int
+       :arg max_length: the maximum length, in bytes, of the vint storing the element size (defaults to :data:`EBMLMaxSizeLength`)
+       :type max_length: int
+       :returns: bytearray
+       
+       """
+       
+       
+       if size is not None:
+               size_bits = bin(size).lstrip('-0b')
+               size_bit_length = len(size_bits)
+               length_required = (abs(size_bit_length - 1) // 7) + 1
+               if size_bit_length % 7 == 0 and '1' in size_bits and '0' not in size_bits:
+                       length_required += 1
+               length = max(length_required, min_length)
+               
+               alignment_bit_length = 0
+               while ((length + alignment_bit_length + size_bit_length) // 8) < length:
+                       alignment_bit_length += 1
+       else:
+               length = min_length or 1
+               required_bits = (length * 8) - length
+               size_bit_length = required_bits
+               size = (2**required_bits) - 1
+               alignment_bit_length = 0
+       
+       if length > max_length:
+               raise ValueError('Unable to encode size (%i) with length %i (longer than limit of %i)' % (size, length, max_length))
+       
+       data = bytearray(length)
+       bytes_written = 0
+       marker_written = False
+       while bytes_written < length:
+               index = (length - bytes_written) - 1
+               if size:
+                       data[index] = size & 0b11111111
+                       size = size >> 8
+                       if not size and not size_bit_length % 8 == 0:
+                               if alignment_bit_length < (8 - (size_bit_length % 8)):
+                                       mask = 0b10000000 >> ((length - 1) % 8)
+                                       data[index] = data[index] | mask
+                                       alignment_bit_length = 0
+                                       marker_written = True
+                               else:
+                                       alignment_bit_length -= (8 - (size_bit_length % 8))
+                       bytes_written += 1
+               else:
+                       if alignment_bit_length:
+                               if alignment_bit_length < 8:
+                                       data[index] = 0b10000000 >> ((length - 1) % 8)
+                                       alignment_bit_length = 0
+                                       bytes_written += 1
+                                       marker_written = True
+                               else:
+                                       data[index] = 0b00000000
+                                       alignment_bit_length -= 8
+                                       bytes_written += 1
+                       else:
+                               remaining_bytes = length - bytes_written
+                               if not marker_written:
+                                       data[(remaining_bytes - 1)] = 0b00000001
+                                       zero_range = range(0, (remaining_bytes - 1))
+                               else:
+                                       zero_range = range(0, remaining_bytes)
+                               for index in zero_range:
+                                       data[index] = 0b00000000
+                               bytes_written += remaining_bytes
+       
+       return data
+
+
+def write_element_size(size, stream, min_length=None, max_length=EBMLMaxSizeLength):
+       """
+       
+       Write the size of an EBML element to stream, optionally with a minimum length.
+       
+       :arg size: the element size, or None if undefined
+       :type size: int or None
+       :arg min_length: the minimum length, in bytes, to write
+       :type min_length: int
+       :arg max_length: the maximum length, in bytes, to write (defaults to :data:`EBMLMaxSizeLength`)
+       :type max_length: int
+       :returns: None
+       
+       """
+       
+       stream.write(encode_element_size(size, min_length, max_length))
+
+
+def read_element_id(stream, max_length=EBMLMaxIDLength):
        """
        
        Reads an EBML element ID vint from stream and returns a tuple containing:
        """
        
        Reads an EBML element ID vint from stream and returns a tuple containing:
@@ -100,34 +199,44 @@ def read_element_id(stream, max_width=EBMLMaxIDLength):
        
        :arg stream: the source of the bytes
        :type stream: a file-like object
        
        :arg stream: the source of the bytes
        :type stream: a file-like object
-       :arg max_width: the maximum length, in bytes, of the vint storing the element ID (defaults to :data:`EBMLMaxIDLength`)
-       :type max_width: int
+       :arg max_length: the maximum length, in bytes, of the vint storing the element ID (defaults to :data:`EBMLMaxIDLength`)
+       :type max_length: int
        :returns: tuple
        
        """
        
        :returns: tuple
        
        """
        
-       vint_bytes = _read_vint_to_bytearray(stream, max_width)
+       vint_bytes = _read_vint_to_bytearray(stream, max_length)
        vint_len = len(vint_bytes)
        
        value = 0
        vint_len = len(vint_bytes)
        
        value = 0
-       max_bytes = 0
-       min_bytes = 0
        
        for vint_byte in vint_bytes:
        
        for vint_byte in vint_bytes:
-               if vint_byte == 0b11111111:
-                       max_bytes += 1
-               elif vint_byte == 0:
-                       min_bytes += 1
                value = (value << 8) | vint_byte
        
                value = (value << 8) | vint_byte
        
-       if max_bytes == vint_len:
-               raise ReservedElementIDError('All value bits set to 1')
-       elif min_bytes == vint_len:
-               raise ReservedElementIDError('All value bits set to 0')
-       
        return value, vint_len
 
 
        return value, vint_len
 
 
+# def encode_element_id(class_id, max_length=EBMLMaxIDLength):
+#      length = int(((log(class_id, 2) - 1) // 7) + 1)
+#      
+#      if length > max_length:
+#              raise ValueError('Unable to encode ID (%x) with length %i (longer than limit of %i)' % (class_id, length, max_length))
+#      
+#      data = bytearray(length)
+#      
+#      bytes_written = 0
+#      while bytes_written < length:
+#              data[(length - bytes_written) - 1] = class_id & 0b11111111
+#              class_id >> 8
+#              bytes_written += 1
+#      
+#      return data
+# 
+# 
+# def write_element_id(class_id, stream, max_length=EBMLMaxIDLength):
+#      stream.write(encode_element_id(class_id, max_length))
+
+
 def read_int(stream, size):
        value = 0
        if size > 0:
 def read_int(stream, size):
        value = 0
        if size > 0:
diff --git a/ebml/tests/__init__.py b/ebml/tests/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/ebml/tests/test_core.py b/ebml/tests/test_core.py
new file mode 100644 (file)
index 0000000..bed71f3
--- /dev/null
@@ -0,0 +1,63 @@
+import unittest
+try:
+       from cStringIO import StringIO
+except ImportError:
+       from StringIO import StringIO
+from random import randint
+from ..core import *
+
+
+class ElementSizeTestsBase(object):
+       def assert_roundtrip(self, value, min_length=1, max_length=EBMLMaxSizeLength):
+               encoded = encode_element_size(value, min_length=min_length, max_length=max_length)
+               encoded_stream = StringIO(encoded)
+               self.assertEqual(value, read_element_size(encoded_stream, max_length=max_length)[0])
+
+
+class ElementSizeTests(unittest.TestCase, ElementSizeTestsBase):
+       def test_undefined(self):
+               for length in xrange(1, 9):
+                       self.assert_roundtrip(None, min_length=length)
+       
+       def test_base_10(self):
+               for value in (10**exp for exp in xrange(1, 16)):
+                       self.assert_roundtrip(value)
+       
+       def test_base_2(self):
+               for value in (2**exp for exp in xrange(1, 56)):
+                       self.assert_roundtrip(value)
+       
+       def test_max_base_2(self):
+               for value in ((2**exp) - 2 for exp in xrange(1, 57)):
+                       self.assert_roundtrip(value)
+               
+       def test_random(self):
+               maximum = 2**56 - 2
+               for value in (randint(0, maximum) for i in xrange(0, 200)):
+                       self.assert_roundtrip(value)
+
+
+class LargeElementSizeTests(unittest.TestCase, ElementSizeTestsBase): # tests values that WILL be longer than 8 bytes (EBMLMaxSizeLength)
+       def test_base_10(self):
+               for value in (10**exp for exp in xrange(17, 300)):
+                       self.assert_roundtrip(value, max_length=1024)
+       
+       def test_base_2(self):
+               for value in (2**exp for exp in xrange(56, 1024)):
+                       self.assert_roundtrip(value, max_length=1024)
+       
+       def test_max_base_2(self):
+               for value in ((2**exp) - 2 for exp in xrange(57, 1024)):
+                       self.assert_roundtrip(value, max_length=1024)
+       
+       def test_random(self):
+               for value in (randint(2**56 - 1, 2**10240) for i in xrange(0, 200)):
+                       self.assert_roundtrip(value, max_length=10240)
+
+
+# class ElementIDTests(unittest.TestCase):
+#      def assert_roundtrip(self, value, max_length=EBMLMaxIDLength):
+#              encoded = encode_element_id(value, max_length=max_length)
+
+# Run this module's tests when it is executed directly. Because the module uses a relative
+# import (from ..core), it presumably has to be run as part of the package, e.g. with
+# ``python -m ebml.tests.test_core`` -- confirm.
+if __name__ == '__main__':
+       unittest.main()
\ No newline at end of file