Implemented proper recursive element definition parsing from specdata files.
[~jspiros/python-ebml.git] / ebml / tests / test_core.py
1 import unittest
2 try:
3         from cStringIO import StringIO
4 except ImportError:
5         from StringIO import StringIO
6 import random
7 import sys
8 from ..core import *
9
10
class ElementSizeTests(unittest.TestCase):
    """Round-trip tests for element-size encoding.

    Every test encodes a value with ``encode_element_size`` and checks that
    ``read_element_size`` recovers the same value from an in-memory stream.
    """

    def assert_roundtrip(self, value, length=None):
        """Encode *value*, decode it again and assert nothing changed.

        :param value: size to encode; ``None`` is also exercised (see
            ``test_unknown``).
        :param length: optional byte length forwarded to the encoder; when
            given, the encoded output must be exactly this many bytes.
        """
        data = encode_element_size(value, length=length)
        if length is not None:
            self.assertEqual(length, len(data))
        # Only the first item of the reader's result is compared.
        decoded = read_element_size(StringIO(data))[0]
        self.assertEqual(value, decoded)

    def test_unknown(self):
        # The None ("unknown") size at every byte length from 1 through 8.
        for n_bytes in xrange(1, 9):
            self.assert_roundtrip(None, length=n_bytes)

    def test_base_10(self):
        # Powers of ten, 10**1 through 10**15.
        for exponent in xrange(1, 16):
            self.assert_roundtrip(10 ** exponent)

    def test_base_2(self):
        # Powers of two, 2**1 through 2**55.
        for exponent in xrange(1, 56):
            self.assert_roundtrip(2 ** exponent)

    def test_max_base_2(self):
        # 2**k - 2 for k = 1..56 (the first value is 0).
        for exponent in xrange(1, 57):
            self.assert_roundtrip(2 ** exponent - 2)

    def test_random(self):
        # 10000 random sizes in [0, 2**56 - 2]. NOTE(review): the upper bound
        # presumably is the largest non-"unknown" 8-byte size -- confirm in core.
        upper = 2 ** 56 - 2
        for _ in xrange(10000):
            self.assert_roundtrip(random.randint(0, upper))
39
40
class ElementIDTests(unittest.TestCase):
    """Round-trip tests for element-ID encoding."""

    # Element IDs exercised by test_ebml_ids. NOTE(review): these appear to be
    # the EBML header IDs (EBML, EBMLVersion, EBMLReadVersion, EBMLMaxIDLength,
    # EBMLMaxSizeLength, DocType, DocTypeVersion, DocTypeReadVersion, CRC-32,
    # Void) -- confirm against the EBML specification.
    ebml_ids = (
        0x1a45dfa3, 0x4286, 0x42f7, 0x42f2, 0x42f3,
        0x4282, 0x4287, 0x4285, 0xbf, 0xec,
    )

    def assert_roundtrip(self, value):
        """Encode *value* as an element ID and assert it decodes unchanged."""
        stream = StringIO(encode_element_id(value))
        self.assertEqual(value, read_element_id(stream)[0])

    def test_ebml_ids(self):
        for element_id in self.ebml_ids:
            self.assert_roundtrip(element_id)
63
64
class ValueTestCase(unittest.TestCase):
    """Shared base for value-codec round-trip tests.

    Subclasses set ``encoder`` and ``reader`` (wrapped in ``staticmethod``).
    The encoder is called as ``encoder(value, length)`` and the reader as
    ``reader(stream, size_in_bytes)``.
    """

    encoder = None
    reader = None

    def assert_roundtrip(self, value, length=None):
        """Encode *value*, read it back and assert it is unchanged.

        :param value: value to encode.
        :param length: optional encoded length; when given, the encoded output
            must be exactly this many bytes.
        :raises NotImplementedError: if ``encoder`` or ``reader`` is unset.
        """
        if self.encoder is None or self.reader is None:
            raise NotImplementedError
        data = self.encoder(value, length)
        if length is not None:
            self.assertEqual(length, len(data))
        # The reader is told how many bytes to consume.
        decoded = self.reader(StringIO(data), len(data))
        self.assertEqual(value, decoded)
78
79
class UnsignedIntegerTests(ValueTestCase):
    """Round-trip tests for unsigned-integer values."""

    encoder = staticmethod(encode_unsigned_integer)
    reader = staticmethod(read_unsigned_integer)
    maximum = 2**64 - 1  # largest value tested: 64 one-bits

    def test_random(self):
        # 10000 random values over the full tested range.
        for value in (random.randint(0, self.maximum) for i in xrange(0, 10000)):
            self.assert_roundtrip(value)

    def test_random_longer(self):
        # Values no larger than 2**32 - 1, but encoded into a fixed 8-byte
        # field, so the encoder must still emit exactly 8 bytes.
        # Floor division keeps the bound an exact int. With true division
        # (Python 3 or ``from __future__ import division``) it would be the
        # float 4294967296.0, which is rounded up to 2**32 and which recent
        # random.randint versions reject as a bound.
        upper = self.maximum // (2**32)
        for value in (random.randint(0, upper) for i in xrange(0, 10000)):
            self.assert_roundtrip(value, length=8)

    def test_maximum(self):
        self.assert_roundtrip(self.maximum)
95
96
class SignedIntegerTests(ValueTestCase):
    """Round-trip tests for signed-integer values."""

    encoder = staticmethod(encode_signed_integer)
    reader = staticmethod(read_signed_integer)
    minimum = -(2**63)      # smallest 64-bit two's-complement value
    maximum = (2**63) - 1   # largest 64-bit two's-complement value

    def test_random(self):
        # 10000 random values over the full tested range.
        for value in (random.randint(self.minimum, self.maximum) for i in xrange(0, 10000)):
            self.assert_roundtrip(value)

    def test_random_longer(self):
        # Values in [-(2**31), 2**31 - 1], but encoded into a fixed 8-byte
        # field, so the encoder must still emit exactly 8 bytes.
        # Floor division keeps both bounds exact ints. With true division the
        # upper bound would be the float 2147483648.0 (rounded up past
        # 2**31 - 1), and recent random.randint versions reject float bounds.
        lower = self.minimum // (2**32)
        upper = self.maximum // (2**32)
        for value in (random.randint(lower, upper) for i in xrange(0, 10000)):
            self.assert_roundtrip(value, length=8)

    def test_minimum(self):
        self.assert_roundtrip(self.minimum)

    def test_maximum(self):
        self.assert_roundtrip(self.maximum)
116
117
class FloatTests(ValueTestCase):
    """Round-trip tests for floating-point values.

    NOTE: exact equality after a round trip is only guaranteed if the encoded
    form keeps full double precision. The author was unsure whether this is a
    good idea; it works because Python floats are 64-bit IEEE 754 on the
    author's installation.
    """

    encoder = staticmethod(encode_float)
    reader = staticmethod(read_float)

    def test_random(self):
        # 1000 samples: draw an integer upper bound in [2, 1024], then a
        # uniform float between 1.0 and that bound.
        for _ in xrange(0, 1000):
            upper = float(random.randint(2, 2**10))
            self.assert_roundtrip(random.uniform(1.0, upper))
129
130
class StringTests(ValueTestCase):
    """Round-trip tests for byte-string values."""

    encoder = staticmethod(encode_string)
    reader = staticmethod(read_string)
    # Alphabet of chr(1) through chr(126). NOTE(review): chr(0) is excluded,
    # presumably because NUL has special meaning (e.g. padding) in the string
    # codec -- confirm in core.
    letters = ''.join(map(chr, xrange(1, 127)))

    def test_random(self):
        # 1000 random strings of length 0..1024. Each is checked without a
        # fixed length and again with an encoded length of twice its size.
        for _ in xrange(0, 1000):
            length = random.randint(0, 2**10)
            repeats = length // len(self.letters) + 1
            pool = self.letters * repeats
            astring = ''.join(random.sample(pool, length))
            self.assert_roundtrip(astring)
            self.assert_roundtrip(astring, length=length * 2)
141
142
class UnicodeStringTests(ValueTestCase):
    """Round-trip tests for Unicode string values."""

    encoder = staticmethod(encode_unicode_string)
    reader = staticmethod(read_unicode_string)
    # Every code point from U+0001 up to sys.maxunicode, built once when the
    # class is defined.
    letters = u''.join(unichr(code_point) for code_point in xrange(1, sys.maxunicode + 1))

    def test_random(self):
        # 1000 random strings of 0..1024 code points. Each is checked without a
        # fixed length and again with an encoded length of 5 bytes per sampled
        # code point.
        for _ in xrange(0, 1000):
            length = random.randint(0, 2**10)
            pool = self.letters * (length // len(self.letters) + 1)
            sample = u''.join(random.sample(pool, length))
            # NOTE(review): presumably normalizes the sample (e.g. surrogate
            # code points) to the form a UTF-8 decode produces, so it can be
            # compared with what the reader returns -- confirm.
            ustring = sample.encode('utf_8').decode('utf_8')
            self.assert_roundtrip(ustring)
            self.assert_roundtrip(ustring, length=length * 5)
154
155
class DateTests(ValueTestCase):
    """Round-trip tests for date values (not implemented yet)."""

    encoder = staticmethod(encode_date)
    reader = staticmethod(read_date)

    def test_random(self):
        # Skip explicitly rather than passing vacuously, so the missing
        # coverage is visible in the test report.
        # TODO: add random date round trips once the value type that
        # encode_date/read_date exchange is pinned down.
        self.skipTest('date round-trip tests are not implemented yet')
162
163
if __name__ == '__main__':
    # Run every TestCase defined in this module. Because the module uses a
    # relative import (``from ..core import *``), it must be run as part of its
    # package, e.g. ``python -m ebml.tests.test_core``.
    unittest.main()