From a0b38d57f9794bde5804be94c50e9671bea25833 Mon Sep 17 00:00:00 2001 From: Joseph Spiros Date: Mon, 11 Apr 2011 23:27:38 -0400 Subject: [PATCH] Implemented encode_element_size and write_element_size, and added tests. This should support element size lengths greater than 8 bytes. --- ebml/core.py | 161 +++++++++++++++++++++++++++++++++------- ebml/tests/__init__.py | 0 ebml/tests/test_core.py | 63 ++++++++++++++++ 3 files changed, 198 insertions(+), 26 deletions(-) create mode 100644 ebml/tests/__init__.py create mode 100644 ebml/tests/test_core.py diff --git a/ebml/core.py b/ebml/core.py index 6e8d338..d709268 100644 --- a/ebml/core.py +++ b/ebml/core.py @@ -1,5 +1,6 @@ import struct import datetime +from math import log from .exceptions import * @@ -7,15 +8,15 @@ EBMLMaxSizeLength = 8 EBMLMaxIDLength = 4 -def _read_vint_to_bytearray(stream, max_width=EBMLMaxSizeLength): +def _read_vint_to_bytearray(stream, max_length=EBMLMaxSizeLength): """ Reads a vint from stream and returns a bytearray containing all of the bytes without doing any decoding. :arg stream: the source of the bytes :type stream: a file-like object - :arg max_width: the maximum length, in bytes, of the vint (defaults to :data:`EBMLMaxSizeLength`) - :type max_width: int + :arg max_length: the maximum length, in bytes, of the vint (defaults to :data:`EBMLMaxSizeLength`) + :type max_length: int :returns: bytearray """ @@ -25,9 +26,10 @@ def _read_vint_to_bytearray(stream, max_width=EBMLMaxSizeLength): vint_len = -7 while not marker_found: vint_len += 8 - if vint_len > max_width: - raise ParseError('vint exceeds max_width (%(max_width)i)' % { - 'max_width': max_width + if vint_len > max_length: + raise ParseError('vint length (%(vint_len)i) exceeds max_length (%(max_length)i)' % { + 'vint_len': vint_len, + 'max_length': max_length }) byte = ord(stream.read(1)) vint_bytes.append(byte) @@ -43,7 +45,7 @@ def _read_vint_to_bytearray(stream, max_width=EBMLMaxSizeLength): vint_bytes.extend(ord(remaining_byte) for remaining_byte in stream.read(remaining_bytes_len)) if len(vint_bytes) != vint_len: - raise ParseError('Unable to read truncated vint of width %(vint_len)s from stream (%(vint_bytes)s bytes available)' % { + raise ParseError('Unable to read truncated vint of length %(vint_len)s from stream (%(vint_bytes)s bytes available)' % { 'vint_len': vint_len, 'vint_bytes': len(vint_bytes) }) @@ -51,7 +53,7 @@ def _read_vint_to_bytearray(stream, max_width=EBMLMaxSizeLength): return vint_bytes -def read_element_size(stream, max_width=EBMLMaxSizeLength): +def read_element_size(stream, max_length=EBMLMaxSizeLength): """ Reads an EBML element size vint from stream and returns a tuple containing: @@ -61,13 +63,13 @@ def read_element_size(stream, max_width=EBMLMaxSizeLength): :arg stream: the source of the bytes :type stream: a file-like object - :arg max_width: the maximum length, in bytes, of the vint storing the element size (defaults to :data:`EBMLMaxSizeLength`) - :type max_width: int + :arg max_length: the maximum length, in bytes, of the vint storing the element size (defaults to :data:`EBMLMaxSizeLength`) + :type max_length: int :returns: tuple """ - vint_bytes = _read_vint_to_bytearray(stream, max_width) + vint_bytes = _read_vint_to_bytearray(stream, max_length) vint_len = len(vint_bytes) int_bytes = vint_bytes[((vint_len - 1) // 8):] @@ -90,7 +92,104 @@ def read_element_size(stream, max_width=EBMLMaxSizeLength): return value, vint_len -def read_element_id(stream, max_width=EBMLMaxIDLength): +def encode_element_size(size, min_length=None, max_length=EBMLMaxSizeLength): + """ + + Encode the size of an EBML element as a vint, optionally with a minimum length. + + :arg size: the element size, or None if undefined + :type size: int or None + :arg min_length: the minimum length, in bytes, of the resultant vint + :type min_length: int + :arg max_length: the maximum length, in bytes, of the vint storing the element size (defaults to :data:`EBMLMaxSizeLength`) + :type max_length: int + :returns: bytearray + + """ + + + if size is not None: + size_bits = bin(size).lstrip('-0b') + size_bit_length = len(size_bits) + length_required = (abs(size_bit_length - 1) // 7) + 1 + if size_bit_length % 7 == 0 and '1' in size_bits and '0' not in size_bits: + length_required += 1 + length = max(length_required, min_length) + + alignment_bit_length = 0 + while ((length + alignment_bit_length + size_bit_length) // 8) < length: + alignment_bit_length += 1 + else: + length = min_length or 1 + required_bits = (length * 8) - length + size_bit_length = required_bits + size = (2**required_bits) - 1 + alignment_bit_length = 0 + + if length > max_length: + raise ValueError('Unable to encode size (%i) with length %i (longer than limit of %i)' % (size, length, max_length)) + + data = bytearray(length) + bytes_written = 0 + marker_written = False + while bytes_written < length: + index = (length - bytes_written) - 1 + if size: + data[index] = size & 0b11111111 + size = size >> 8 + if not size and not size_bit_length % 8 == 0: + if alignment_bit_length < (8 - (size_bit_length % 8)): + mask = 0b10000000 >> ((length - 1) % 8) + data[index] = data[index] | mask + alignment_bit_length = 0 + marker_written = True + else: + alignment_bit_length -= (8 - (size_bit_length % 8)) + bytes_written += 1 + else: + if alignment_bit_length: + if alignment_bit_length < 8: + data[index] = 0b10000000 >> ((length - 1) % 8) + alignment_bit_length = 0 + bytes_written += 1 + marker_written = True + else: + data[index] = 0b00000000 + alignment_bit_length -= 8 + bytes_written += 1 + else: + remaining_bytes = length - bytes_written + if not marker_written: + data[(remaining_bytes - 1)] = 0b00000001 + zero_range = range(0, (remaining_bytes - 1)) + else: + zero_range = range(0, remaining_bytes) + for index in zero_range: + data[index] = 0b00000000 + bytes_written += remaining_bytes + + return data + + +def write_element_size(size, stream, min_length=None, max_length=EBMLMaxSizeLength): + """ + + Write the size of an EBML element to stream, optionally with a minimum length. + + :arg size: the element size, or None if undefined + :type size: int or None + :arg min_length: the minimum length, in bytes, to write + :type min_length: int + :arg max_length: the maximum length, in bytes, to write (defaults to :data:`EBMLMaxSizeLength`) + :type max_length: int + :returns: None + + """ + + stream.write(encode_element_size(size, min_length, max_length)) + + +def read_element_id(stream, max_length=EBMLMaxIDLength): """ Reads an EBML element ID vint from stream and returns a tuple containing: @@ -100,34 +199,44 @@ def read_element_id(stream, max_width=EBMLMaxIDLength): :arg stream: the source of the bytes :type stream: a file-like object - :arg max_width: the maximum length, in bytes, of the vint storing the element ID (defaults to :data:`EBMLMaxIDLength`) - :type max_width: int + :arg max_length: the maximum length, in bytes, of the vint storing the element ID (defaults to :data:`EBMLMaxIDLength`) + :type max_length: int :returns: tuple """ - vint_bytes = _read_vint_to_bytearray(stream, max_width) + vint_bytes = _read_vint_to_bytearray(stream, max_length) vint_len = len(vint_bytes) value = 0 - max_bytes = 0 - min_bytes = 0 for vint_byte in vint_bytes: - if vint_byte == 0b11111111: - max_bytes += 1 - elif vint_byte == 0: - min_bytes += 1 value = (value << 8) | vint_byte - if max_bytes == vint_len: - raise ReservedElementIDError('All value bits set to 1') - elif min_bytes == vint_len: - raise ReservedElementIDError('All value bits set to 0') - return value, vint_len +# def encode_element_id(class_id, max_length=EBMLMaxIDLength): +# length = int(((log(class_id, 2) - 1) // 7) + 1) +# +# if length > max_length: +# raise ValueError('Unable to encode ID (%x) with length %i (longer than limit of %i)' % (class_id, length, max_length)) +# +# data = bytearray(length) +# +# bytes_written = 0 +# while bytes_written < length: +# data[(length - bytes_written) - 1] = class_id & 0b11111111 +# class_id >> 8 +# bytes_written += 1 +# +# return data +# +# +# def write_element_id(class_id, stream, max_length=EBMLMaxIDLength): +# stream.write(encode_element_id(class_id, max_length)) + + def read_int(stream, size): value = 0 if size > 0: diff --git a/ebml/tests/__init__.py b/ebml/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ebml/tests/test_core.py b/ebml/tests/test_core.py new file mode 100644 index 0000000..bed71f3 --- /dev/null +++ b/ebml/tests/test_core.py @@ -0,0 +1,63 @@ +import unittest +try: + from cStringIO import StringIO +except ImportError: + from StringIO import StringIO +from random import randint +from ..core import * + + +class ElementSizeTestsBase(object): + def assert_roundtrip(self, value, min_length=1, max_length=EBMLMaxSizeLength): + encoded = encode_element_size(value, min_length=min_length, max_length=max_length) + encoded_stream = StringIO(encoded) + self.assertEqual(value, read_element_size(encoded_stream, max_length=max_length)[0]) + + +class ElementSizeTests(unittest.TestCase, ElementSizeTestsBase): + def test_undefined(self): + for length in xrange(1, 9): + self.assert_roundtrip(None, min_length=length) + + def test_base_10(self): + for value in (10**exp for exp in xrange(1, 16)): + self.assert_roundtrip(value) + + def test_base_2(self): + for value in (2**exp for exp in xrange(1, 56)): + self.assert_roundtrip(value) + + def test_max_base_2(self): + for value in ((2**exp) - 2 for exp in xrange(1, 57)): + self.assert_roundtrip(value) + + def test_random(self): + maximum = 2**56 - 2 + for value in (randint(0, maximum) for i in xrange(0, 200)): + self.assert_roundtrip(value) + + +class LargeElementSizeTests(unittest.TestCase, ElementSizeTestsBase): # tests values that WILL be longer than 8 bytes (EBMLMaxSizeLength) + def test_base_10(self): + for value in (10**exp for exp in xrange(17, 300)): + self.assert_roundtrip(value, max_length=1024) + + def test_base_2(self): + for value in (2**exp for exp in xrange(56, 1024)): + self.assert_roundtrip(value, max_length=1024) + + def test_max_base_2(self): + for value in ((2**exp) - 2 for exp in xrange(57, 1024)): + self.assert_roundtrip(value, max_length=1024) + + def test_random(self): + for value in (randint(2**56 - 1, 2**10240) for i in xrange(0, 200)): + self.assert_roundtrip(value, max_length=10240) + + +# class ElementIDTests(unittest.TestCase): +# def assert_roundtrip(self, value, max_length=EBMLMaxIDLength): +# encoded = encode_element_id(value, max_length=max_length) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file -- 2.20.1