Implemented encode_element_size and write_element_size, and added tests. This should...
[~jspiros/python-ebml.git] / ebml / core.py
1 import struct
2 import datetime
3 from math import log
4 from .exceptions import *
5
6
7 EBMLMaxSizeLength = 8
8 EBMLMaxIDLength = 4
9
10
11 def _read_vint_to_bytearray(stream, max_length=EBMLMaxSizeLength):
12         """
13         
14         Reads a vint from stream and returns a bytearray containing all of the bytes without doing any decoding.
15         
16         :arg stream: the source of the bytes
17         :type stream: a file-like object
18         :arg max_length: the maximum length, in bytes, of the vint (defaults to :data:`EBMLMaxSizeLength`)
19         :type max_length: int
20         :returns: bytearray
21         
22         """
23         
24         marker_found = False
25         vint_bytes = bytearray()
26         vint_len = -7
27         while not marker_found:
28                 vint_len += 8
29                 if vint_len > max_length:
30                         raise ParseError('vint length (%(vint_len)i) exceeds max_length (%(max_length)i)' % {
31                                 'vint_len': vint_len,
32                                 'max_length': max_length
33                         })
34                 byte = ord(stream.read(1))
35                 vint_bytes.append(byte)
36                 for pos in range(0, 8):
37                         mask = 0b10000000 >> pos
38                         if byte & mask:
39                                 vint_len += pos
40                                 marker_found = True
41                                 break
42         
43         remaining_bytes_len = vint_len - len(vint_bytes)
44         if remaining_bytes_len > 0:
45                 vint_bytes.extend(ord(remaining_byte) for remaining_byte in stream.read(remaining_bytes_len))
46         
47         if len(vint_bytes) != vint_len:
48                 raise ParseError('Unable to read truncated vint of length %(vint_len)s from stream (%(vint_bytes)s bytes available)' % {
49                         'vint_len': vint_len,
50                         'vint_bytes': len(vint_bytes)
51                 })
52         
53         return vint_bytes
54
55
56 def read_element_size(stream, max_length=EBMLMaxSizeLength):
57         """
58         
59         Reads an EBML element size vint from stream and returns a tuple containing:
60         
61                 * the size as an integer, or None if the size is undefined
62                 * the length in bytes of the size descriptor (the vint) itself
63         
64         :arg stream: the source of the bytes
65         :type stream: a file-like object
66         :arg max_length: the maximum length, in bytes, of the vint storing the element size (defaults to :data:`EBMLMaxSizeLength`)
67         :type max_length: int
68         :returns: tuple
69         
70         """
71         
72         vint_bytes = _read_vint_to_bytearray(stream, max_length)
73         vint_len = len(vint_bytes)
74         
75         int_bytes = vint_bytes[((vint_len - 1) // 8):]
76         first_byte_mask = 0b10000000 >> ((vint_len - 1) % 8)
77         max_bytes = 0
78         
79         value = int_bytes[0] & (first_byte_mask - 1)
80         
81         if value == (first_byte_mask - 1):
82                 max_bytes += 1
83         
84         for int_byte in int_bytes[1:]:
85                 if int_byte == 0b11111111:
86                         max_bytes += 1
87                 value = (value << 8) | int_byte
88         
89         if max_bytes == len(int_bytes):
90                 value = None
91         
92         return value, vint_len
93
94
95 def encode_element_size(size, min_length=None, max_length=EBMLMaxSizeLength):
96         """
97         
98         Encode the size of an EBML element as a vint, optionally with a minimum length.
99         
100         :arg size: the element size, or None if undefined
101         :type size: int or None
102         :arg min_length: the minimum length, in bytes, of the resultant vint
103         :type min_length: int
104         :arg max_length: the maximum length, in bytes, of the vint storing the element size (defaults to :data:`EBMLMaxSizeLength`)
105         :type max_length: int
106         :returns: bytearray
107         
108         """
109         
110         
111         if size is not None:
112                 size_bits = bin(size).lstrip('-0b')
113                 size_bit_length = len(size_bits)
114                 length_required = (abs(size_bit_length - 1) // 7) + 1
115                 if size_bit_length % 7 == 0 and '1' in size_bits and '0' not in size_bits:
116                         length_required += 1
117                 length = max(length_required, min_length)
118                 
119                 alignment_bit_length = 0
120                 while ((length + alignment_bit_length + size_bit_length) // 8) < length:
121                         alignment_bit_length += 1
122         else:
123                 length = min_length or 1
124                 required_bits = (length * 8) - length
125                 size_bit_length = required_bits
126                 size = (2**required_bits) - 1
127                 alignment_bit_length = 0
128         
129         if length > max_length:
130                 raise ValueError('Unable to encode size (%i) with length %i (longer than limit of %i)' % (size, length, max_length))
131         
132         data = bytearray(length)
133         bytes_written = 0
134         marker_written = False
135         while bytes_written < length:
136                 index = (length - bytes_written) - 1
137                 if size:
138                         data[index] = size & 0b11111111
139                         size = size >> 8
140                         if not size and not size_bit_length % 8 == 0:
141                                 if alignment_bit_length < (8 - (size_bit_length % 8)):
142                                         mask = 0b10000000 >> ((length - 1) % 8)
143                                         data[index] = data[index] | mask
144                                         alignment_bit_length = 0
145                                         marker_written = True
146                                 else:
147                                         alignment_bit_length -= (8 - (size_bit_length % 8))
148                         bytes_written += 1
149                 else:
150                         if alignment_bit_length:
151                                 if alignment_bit_length < 8:
152                                         data[index] = 0b10000000 >> ((length - 1) % 8)
153                                         alignment_bit_length = 0
154                                         bytes_written += 1
155                                         marker_written = True
156                                 else:
157                                         data[index] = 0b00000000
158                                         alignment_bit_length -= 8
159                                         bytes_written += 1
160                         else:
161                                 remaining_bytes = length - bytes_written
162                                 if not marker_written:
163                                         data[(remaining_bytes - 1)] = 0b00000001
164                                         zero_range = range(0, (remaining_bytes - 1))
165                                 else:
166                                         zero_range = range(0, remaining_bytes)
167                                 for index in zero_range:
168                                         data[index] = 0b00000000
169                                 bytes_written += remaining_bytes
170         
171         return data
172
173
174 def write_element_size(size, stream, min_length=None, max_length=EBMLMaxSizeLength):
175         """
176         
177         Write the size of an EBML element to stream, optionally with a minimum length.
178         
179         :arg size: the element size, or None if undefined
180         :type size: int or None
181         :arg min_length: the minimum length, in bytes, to write
182         :type min_length: int
183         :arg max_length: the maximum length, in bytes, to write (defaults to :data:`EBMLMaxSizeLength`)
184         :type max_length: int
185         :returns: None
186         
187         """
188         
189         stream.write(encode_element_size(size, min_length, max_length))
190
191
192 def read_element_id(stream, max_length=EBMLMaxIDLength):
193         """
194         
195         Reads an EBML element ID vint from stream and returns a tuple containing:
196         
197                 * the ID as an integer
198                 * the length in bytes of the ID descriptor (the vint) itself
199         
200         :arg stream: the source of the bytes
201         :type stream: a file-like object
202         :arg max_length: the maximum length, in bytes, of the vint storing the element ID (defaults to :data:`EBMLMaxIDLength`)
203         :type max_length: int
204         :returns: tuple
205         
206         """
207         
208         vint_bytes = _read_vint_to_bytearray(stream, max_length)
209         vint_len = len(vint_bytes)
210         
211         value = 0
212         
213         for vint_byte in vint_bytes:
214                 value = (value << 8) | vint_byte
215         
216         return value, vint_len
217
218
219 # def encode_element_id(class_id, max_length=EBMLMaxIDLength):
220 #       length = int(((log(class_id, 2) - 1) // 7) + 1)
221 #       
222 #       if length > max_length:
223 #               raise ValueError('Unable to encode ID (%x) with length %i (longer than limit of %i)' % (class_id, length, max_length))
224 #       
225 #       data = bytearray(length)
226 #       
227 #       bytes_written = 0
228 #       while bytes_written < length:
229 #               data[(length - bytes_written) - 1] = class_id & 0b11111111
230 #               class_id >> 8
231 #               bytes_written += 1
232 #       
233 #       return data
234
235
236 # def write_element_id(class_id, stream, max_length=EBMLMaxIDLength):
237 #       stream.write(encode_element_id(class_id, max_length))
238
239
240 def read_int(stream, size):
241         value = 0
242         if size > 0:
243                 byte = ord(stream.read(1))
244                 if (byte & 0b10000000) == 0b10000000:
245                         value = -1 << 8
246                 value |= byte
247                 for i in range(1, size):
248                         byte = ord(stream.read(1))
249                         value = (value << 1) | byte
250         return value
251
252
253 def read_uint(stream, size):
254         value = 0
255         for i in range(0, size):
256                 byte = ord(stream.read(1))
257                 value = (value << 8) | byte
258         return value
259
260
261 def read_float(stream, size):
262         if size not in (0, 4, 8):
263                 # http://www.matroska.org/technical/specs/rfc/index.html allows for 10-byte floats.
264                 # http://www.matroska.org/technical/specs/index.html specifies 4-byte and 8-byte only.
265                 # I'm following the latter due to it being more up-to-date than the former, and because it's easier to implement.
266                 raise ValueError('floats must be 0, 4, or 8 bytes long')
267         value = 0.0
268         if size in (4, 8):
269                 data = stream.read(size)
270                 value = struct.unpack({
271                         4: '>f',
272                         8: '>d'
273                 }[size], data)[0]
274         return value
275
276
277 def read_string(stream, size):
278         value = ''
279         if size > 0:
280                 value = stream.read(size)
281         return value
282
283
284 def read_unicode(stream, size):
285         value = u''
286         if size > 0:
287                 data = stream.read(size)
288                 value = unicode(data, 'utf_8')
289         return value
290
291
292 def read_date(stream):
293         size = 8 # date is always an 8-byte signed integer
294         data = stream.read(size)
295         nanoseconds = struct.unpack('>q', data)[0]
296         delta = datetime.timedelta(microseconds=(nanoseconds // 1000))
297         return datetime.datetime(2001, 1, 1) + delta
298
299
300 def read_binary(stream, size):
301         return stream.read(size)