Fixed handling of null-termination in the string and unicode string readers.
[~jspiros/python-ebml.git] / ebml / core.py
1 import datetime
2 import struct
3
4
5 __all__ = (
6         'read_element_id',
7         'read_element_size',
8         'read_unsigned_integer',
9         'read_signed_integer',
10         'read_float',
11         'read_string',
12         'read_unicode_string',
13         'read_date',
14         'encode_element_id',
15         'encode_element_size',
16         'encode_unsigned_integer',
17         'encode_signed_integer',
18         'encode_float',
19         'encode_string',
20         'encode_unicode_string',
21         'encode_date',
22 )
23
24
25 MAXIMUM_ELEMENT_ID_LENGTH = 4
26 MAXIMUM_ELEMENT_SIZE_LENGTH = 8
27 MAXIMUM_UNSIGNED_INTEGER_LENGTH = 8
28 MAXIMUM_SIGNED_INTEGER_LENGTH = 8
29
30
31 def maximum_element_size_for_length(length):
32         """
33         
34         Returns the maximum element size representable in a given number of bytes.
35         
36         :arg length: the limit on the length of the encoded representation in bytes
37         :type length: int
38         :returns: the maximum element size representable
39         :rtype: int
40         
41         """
42         
43         return (2**(7*length)) - 2
44
45
46 def decode_vint_length(byte, mask=True):
47         length = None
48         value_mask = None
49         for n in xrange(1, 9):
50                 if byte & (2**8 - (2**(8 - n))) == 2**(8 - n):
51                         length = n
52                         value_mask = (2**(8 - n)) - 1
53                         break
54         if length is None:
55                 raise IOError('Cannot decode invalid varible-length integer.')
56         if mask:
57                 byte = byte & value_mask
58         return length, byte
59
60
61 def read_element_id(stream):
62         """
63         
64         Reads an element ID from a file-like object.
65         
66         :arg stream: the file-like object
67         :returns: the decoded element ID and its length in bytes
68         :rtype: tuple
69         
70         """
71         
72         byte = ord(stream.read(1))
73         length, id_ = decode_vint_length(byte, False)
74         if length > 4:
75                 raise IOError('Cannot decode element ID with length > 8.')
76         for i in xrange(0, length - 1):
77                 byte = ord(stream.read(1))
78                 id_ = (id_ * 2**8) + byte
79         return id_, length
80
81
82 def read_element_size(stream):
83         """
84         
85         Reads an element size from a file-like object.
86         
87         :arg stream: the file-like object
88         :returns: the decoded size (or None if unknown) and the length of the descriptor in bytes
89         :rtype: tuple
90         
91         """
92         
93         byte = ord(stream.read(1))
94         length, size = decode_vint_length(byte)
95         
96         for i in xrange(0, length - 1):
97                 byte = ord(stream.read(1))
98                 size = (size * 2**8) + byte
99         
100         if size == maximum_element_size_for_length(length) + 1:
101                 size = None
102         
103         return size, length
104
105
106 def read_unsigned_integer(stream, size):
107         """
108         
109         Reads an encoded unsigned integer value from a file-like object.
110         
111         :arg stream: the file-like object
112         :arg size: the number of bytes to read and decode
113         :type size: int
114         :returns: the decoded unsigned integer value
115         :rtype: int
116         
117         """
118         
119         value = 0
120         for i in xrange(0, size):
121                 byte = ord(stream.read(1))
122                 value = (value << 8) | byte
123         return value
124
125
126 def read_signed_integer(stream, size):
127         """
128         
129         Reads an encoded signed integer value from a file-like object.
130         
131         :arg stream: the file-like object
132         :arg size: the number of bytes to read and decode
133         :type size: int
134         :returns: the decoded signed integer value
135         :rtype: int
136         
137         """
138         
139         value = 0
140         if size > 0:
141                 first_byte = ord(stream.read(1))
142                 value = first_byte
143                 for i in xrange(1, size):
144                         byte = ord(stream.read(1))
145                         value = (value << 8) | byte
146                 if (first_byte & 0b10000000) == 0b10000000:
147                         value = -(2**(size*8) - value)
148         return value
149
150
151 def read_float(stream, size):
152         """
153         
154         Reads an encoded floating point value from a file-like object.
155         
156         :arg stream: the file-like object
157         :arg size: the number of bytes to read and decode (must be 0, 4, or 8)
158         :type size: int
159         :returns: the decoded floating point value
160         :rtype: float
161         
162         """
163         
164         if size not in (0, 4, 8):
165                 raise IOError('Cannot read floating point values with lengths other than 0, 4, or 8 bytes.')
166         value = 0.0
167         if size in (4, 8):
168                 data = stream.read(size)
169                 value = struct.unpack({
170                         4: '>f',
171                         8: '>d'
172                 }[size], data)[0]
173         return value
174
175
176 def read_string(stream, size):
177         """
178         
179         Reads an encoded ASCII string value from a file-like object.
180         
181         :arg stream: the file-like object
182         :arg size: the number of bytes to read and decode
183         :type size: int
184         :returns: the decoded ASCII string value
185         :rtype: str
186         
187         """
188         
189         value = ''
190         if size > 0:
191                 value = stream.read(size)
192                 value = value.partition(chr(0))[0]
193         return value
194
195
196 def read_unicode_string(stream, size):
197         """
198         
199         Reads an encoded unicode string value from a file-like object.
200         
201         :arg stream: the file-like object
202         :arg size: the number of bytes to read and decode
203         :type size: int
204         :returns: the decoded unicode string value
205         :rtype: unicode
206         
207         """
208         
209         value = u''
210         if size > 0:
211                 data = stream.read(size)
212                 data = data.partition(chr(0))[0]
213                 value = unicode(data, 'utf_8')
214         return value
215
216
217 def read_date(stream, size):
218         """
219         
220         Reads an encoded date (and time) value from a file-like object.
221         
222         :arg stream: the file-like object
223         :arg size: the number of bytes to read and decode (must be 8)
224         :type size: int
225         :returns: the decoded date (and time) value
226         :rtype: datetime
227         
228         """
229         
230         if size != 8:
231                 raise IOError('Cannot read date values with lengths other than 8 bytes.')
232         data = stream.read(size)
233         nanoseconds = struct.unpack('>q', data)[0]
234         delta = datetime.timedelta(microseconds=(nanoseconds // 1000))
235         return datetime.datetime(2001, 1, 1, tzinfo=None) + delta
236
237
238 def octet(n):
239         """
240         
241         Limits an integer or byte to 8 bits.
242         
243         """
244         
245         return n & 0b11111111
246
247
248 def vint_mask_for_length(length):
249         """
250         
251         Returns the bitmask for the first byte of a variable-length integer (used for element ID and size descriptors).
252         
253         :arg length: the length of the variable-length integer
254         :type length: int
255         :returns: the bitmask for the first byte of the variable-length integer
256         :rtype: int
257         
258         """
259         
260         return 0b10000000 >> (length - 1)
261
262
263 def encode_element_id(element_id):
264         """
265         
266         Encodes an element ID.
267         
268         :arg element_id: an element ID
269         :type element_id: int
270         :returns: the encoded representation bytes
271         :rtype: bytearray
272         
273         """
274         
275         length = MAXIMUM_ELEMENT_ID_LENGTH
276         while length and not (element_id & (vint_mask_for_length(length) << ((length - 1) * 8))):
277                 length -= 1
278         if not length:
279                 raise ValueError('Cannot encode invalid element ID %s.' % hex(element_id))
280         
281         data = bytearray(length)
282         for index in reversed(xrange(length)):
283                 data[index] = octet(element_id)
284                 element_id >>= 8
285         
286         return data
287
288
289 def encode_element_size(element_size, length=None):
290         """
291         
292         Encodes an element size. If element_size is None, the size will be encoded as unknown. If length is not None, the size will be encoded in that many bytes; otherwise, the size will be encoded in the minimum number of bytes required, or in 8 bytes if the size is unknown (element_size is None).
293         
294         :arg element_size: the element size, or None if unknown
295         :type element_size: int or None
296         :arg length: the length of the encoded representation, or None for the minimum length required (defaults to None)
297         :type length: int or None
298         :returns: the encoded representation bytes
299         :rtype: bytearray
300         
301         """
302         
303         if length is not None and (length < 1 or length > MAXIMUM_ELEMENT_SIZE_LENGTH):
304                 raise ValueError('Cannot encode element sizes into representations shorter than one byte long or longer than %i bytes long.' % MAXIMUM_ELEMENT_SIZE_LENGTH)
305         if element_size is not None:
306                 if element_size > maximum_element_size_for_length(MAXIMUM_ELEMENT_SIZE_LENGTH if length is None else length):
307                         raise ValueError('Cannot encode element size %i as it would have an encoded representation longer than %i bytes.' % (element_size, (MAXIMUM_ELEMENT_SIZE_LENGTH if length is None else length)))
308                 req_length = 1
309                 while (element_size >> ((req_length - 1) * 8)) >= (vint_mask_for_length(req_length) - 1) and req_length < MAXIMUM_ELEMENT_SIZE_LENGTH:
310                         req_length += 1
311                 if length is None:
312                         length = req_length
313         else:
314                 if length is None:
315                         length = 8 # other libraries do this, so unless another length is specified for the unknown size descriptor, do as they do to avoid compatibility issues.
316                 element_size = maximum_element_size_for_length(length) + 1
317         
318         data = bytearray(length)
319         for index in reversed(xrange(length)):
320                 data[index] = octet(element_size)
321                 element_size >>= 8
322                 if not index:
323                         data[index] = data[index] | vint_mask_for_length(length)
324         
325         return data
326
327
328 def encode_unsigned_integer(uint, length=None):
329         """
330         
331         Encodes an unsigned integer value. If length is not None, uint will be encoded in that many bytes; otherwise, uint will be encoded in the minimum number of bytes required. If uint is None or 0, the minimum number of bytes required is 0.
332         
333         :arg uint: the unsigned integer value
334         :type uint: int
335         :arg length: the length of the encoded representation, or None for the minimum length required (defaults to None)
336         :type length: int or None
337         :returns: the encoded representation bytes
338         :rtype: bytearray
339         
340         """
341         
342         if uint is None:
343                 uint = 0
344         if uint > ((2**((MAXIMUM_UNSIGNED_INTEGER_LENGTH if length is None else length) * 8)) - 1):
345                 raise ValueError('Cannot encode unsigned integer value %i as it would have an encoded representation longer than %i bytes.' % (uint, (MAXIMUM_UNSIGNED_INTEGER_LENGTH if length is None else length)))
346         elif uint == 0:
347                 req_length = 0
348         else:
349                 req_length = 1
350                 while uint >= (1 << (req_length * 8)) and req_length < MAXIMUM_UNSIGNED_INTEGER_LENGTH:
351                         req_length += 1
352         if length is None:
353                 length = req_length
354         
355         data = bytearray(length)
356         for index in reversed(xrange(length)):
357                 data[index] = octet(uint)
358                 uint >>= 8
359         
360         return data
361
362
363 def encode_signed_integer(sint, length=None):
364         """
365         
366         Encodes a signed integer value. If length is not None, sint will be encoded in that many bytes; otherwise, sint will be encoded in the minimum number of bytes required. If sint is None or 0, the minimum number of bytes required is 0.
367         
368         :arg sint: the signed integer value
369         :type sint: int
370         :arg length: the length of the encoded representation, or None for the minimum length required (defaults to None)
371         :type length: int or None
372         :returns: the encoded representation bytes
373         :rtype: bytearray
374         
375         """
376         
377         if sint is None:
378                 sint = 0
379         if not (-(2**(7+(8*((MAXIMUM_SIGNED_INTEGER_LENGTH if length is None else length)-1)))) <= sint <= (2**(7+(8*((MAXIMUM_SIGNED_INTEGER_LENGTH if length is None else length)-1))))-1):
380                 raise ValueError('Cannot encode signed integer value %i as it would have an encoded representation longer than %i bytes.' % (sint, (MAXIMUM_SIGNED_INTEGER_LENGTH if length is None else length)))
381         elif sint == 0:
382                 req_length = 0
383                 uint = 0
384                 if length is None:
385                         length = req_length
386         else:
387                 uint = ((-sint - 1) << 1) if sint < 0 else (sint << 1)
388                 req_length = 1
389                 while uint >= (1 << (req_length * 8)) and req_length < MAXIMUM_UNSIGNED_INTEGER_LENGTH:
390                         req_length += 1
391                 if length is None:
392                         length = req_length
393                 if sint >= 0:
394                         uint = sint
395                 else:
396                         uint = 2**(length*8) - abs(sint)
397         
398         data = bytearray(length)
399         for index in reversed(xrange(length)):
400                 data[index] = octet(uint)
401                 uint >>= 8
402         
403         return data
404
405
406 def encode_float(float_, length=None):
407         """
408         
409         Encodes a floating point value. If length is not None, float_ will be encoded in that many bytes; otherwise, float_ will be encoded in 0 bytes if float_ is None or 0, and 8 bytes in all other cases. If float_ is not None or 0 and length is 0, ValueError will be raised.
410         
411         :arg float_: the floating point value
412         :type float_: float
413         :arg length: the length of the encoded representation, or None (defaults to None)
414         :type length: int or None
415         :returns: the encoded representation bytes
416         :rtype: bytearray
417         
418         """
419         
420         if length not in (None, 0, 4, 8):
421                 raise ValueError('Cannot encode floating point values with lengths other than 0, 4, or 8 bytes.')
422         if float_ is None:
423                 float_ = 0.0
424         if float_ == 0.0:
425                 if length is None:
426                         length = 0
427         else:
428                 if length is None:
429                         length = 8
430                 elif length == 0:
431                         raise ValueError('Cannot encode floating point value %f as it would have an encoded representation longer than 0 bytes.' % float_)
432         
433         if length in (4, 8):
434                 data = bytearray(struct.pack({
435                         4: '>f',
436                         8: '>d'
437                 }[length], float_))
438         else:
439                 data = bytearray()
440         
441         return data
442
443
444 def encode_string(string, length=None):
445         """
446         
447         Encodes an ASCII string value. If length is not None, string will be encoded in that many bytes by padding with zero bytes at the end if necessary; otherwise, string will be encoded in the minimum number of bytes required. If string is None or empty, the minimum number of bytes required is 0.
448         
449         :arg string: the ASCII string value
450         :type string: str
451         :arg length: the length of the encoded representation, or None for the minimum length required (defaults to None)
452         :type length: int or None
453         :returns: the encoded representation bytes
454         :rtype: bytearray
455         
456         """
457         
458         if string is None:
459                 string = ''
460         if length is None:
461                 length = len(string)
462         else:
463                 if length < len(string):
464                         raise ValueError('Cannot encode ASCII string value \'%s\' as it would have an encoded representation longer than %i bytes.' % (string, length))
465                 elif length > len(string):
466                         for i in xrange(0, (length - len(string))):
467                                 string += chr(0)
468         
469         return bytearray(string)
470
471
472 def encode_unicode_string(string, length=None):
473         """
474         
475         Encodes a unicode string value. If length is not None, string will be encoded in that many bytes by padding with zero bytes at the end if necessary; otherwise, string will be encoded in the minimum number of bytes required. If string is None or empty, the minimum number of bytes required is 0.
476         
477         :arg string: the unicode string value
478         :type string: unicode
479         :arg length: the length of the encoded representation, or None for the minimum length required (defaults to None)
480         :type length: int or None
481         :returns: the encoded representation bytes
482         :rtype: bytearray
483         
484         """
485         
486         if string is None:
487                 string = u''
488         return encode_string(string.encode('utf_8'), length)
489
490
491 def encode_date(date, length=None):
492         """
493         
494         Encodes a date (and time) value. If length is not None, it must be 8. If date is None, the current date (and time) will be encoded.
495         
496         :arg date: the date (and time) value
497         :type date: datetime.datettime
498         :arg length: the length of the encoded representation (must be 8), or None
499         :type length: int or None
500         :returns: the encoded representation bytes
501         :rtype: bytearray
502         
503         """
504         
505         if date is None:
506                 date = datetime.datetime.utcnow()
507         else:
508                 date = (date - date.utcoffset()).replace(tzinfo=None)
509         if length is None:
510                 length = 8
511         elif length != 8:
512                 raise ValueError('Cannot encode date value %s with any length other than 8 bytes.')
513         
514         delta = date - datetime.datetime(2001, 1, 1, tzinfo=None)
515         nanoseconds = (delta.microseconds + ((delta.seconds + (delta.days * 24 * 60 * 60)) * 10**6)) * 10**3
516         return encode_signed_integer(nanoseconds, length)