Implemented proper recursive element definition parsing from specdata files.
[~jspiros/python-ebml.git] / ebml / schema / specs.py
index 7320b6b..d6e7728 100644 (file)
@@ -1,5 +1,5 @@
 from xml.etree.ElementTree import parse as parse_xml
 from xml.etree.ElementTree import parse as parse_xml
-from .base import INT, UINT, FLOAT, STRING, UNICODE, DATE, BINARY, CONTAINER, Element, Schema
+from .base import INT, UINT, FLOAT, STRING, UNICODE, DATE, BINARY, CONTAINER, Element, Document
 
 
 SPECDATA_TYPES = {
 
 
 SPECDATA_TYPES = {
@@ -14,13 +14,13 @@ SPECDATA_TYPES = {
 }
 
 
 }
 
 
-def parse_specdata(source, schema_name):
+def parse_specdata(source, doc_name, doc_type, doc_version):
        """
        
        Reads a schema specification from a file (e.g., specdata.xml) or file-like object, and returns a tuple containing:
        
                * a mapping of class names to Element subclasses
        """
        
        Reads a schema specification from a file (e.g., specdata.xml) or file-like object, and returns a tuple containing:
        
                * a mapping of class names to Element subclasses
-               * a Schema subclass
+               * a Document subclass
        
        :arg source: the file or file-like object
        :type source: str or file-like object
        
        :arg source: the file or file-like object
        :type source: str or file-like object
@@ -32,43 +32,71 @@ def parse_specdata(source, schema_name):
        
        tree = parse_xml(source)
        elements = {}
        
        tree = parse_xml(source)
        elements = {}
-       parent_elements = []
+       globals = []
        
        
-       for element_element in tree.getiterator('element'):
-               raw_attrs = element_element.attrib
-               
-               element_name = '%sElement' % raw_attrs.get('cppname', raw_attrs.get('name'))
-               element_level = int(raw_attrs['level'])
-               element_attrs = {
-                       '__module__': None,
-                       'class_id': int(raw_attrs['id'], 0),
-                       'class_name': raw_attrs['name'],
-                       'data_type': SPECDATA_TYPES[raw_attrs['type']]
-               }
-               
-               while parent_elements and element_level <= parent_elements[-1][0]:
-                       parent_elements.pop()
-               
-               if element_level == -1:
-                       element_attrs['class_global'] = True
-                       parent_elements = []
-               elif element_level == 0:
-                       element_attrs['class_root'] = True
-                       parent_elements = []
-               else:
-                       if raw_attrs.get('recursive', '0') == '1':
-                               element_attrs['class_parents'] = (parent_elements[-1][1], 'self')
+       def child_elements(parent_level, element_list, upper_recursive=None):
+               children = []
+               while element_list:
+                       raw_element = element_list[0]
+                       raw_attrs = raw_element.attrib
+                       
+                       element_level = int(raw_attrs['level'])
+                       
+                       is_global = False
+                       if element_level == -1:
+                               is_global = True
+                       elif parent_level is not None and not element_level > parent_level:
+                               break
+                       element_list = element_list[1:]
+
+                       element_name = '%sElement' % raw_attrs.get('cppname', raw_attrs.get('name')).translate(None, '-')
+                       element_attrs = {
+                               '__module__': None,
+                               'id': int(raw_attrs['id'], 0),
+                               'name': raw_attrs['name'],
+                               'type': SPECDATA_TYPES[raw_attrs['type']],
+                               'mandatory': True if raw_attrs.get('mandatory', False) == '1' else False,
+                               'multiple': True if raw_attrs.get('multiple', False) == '1' else False
+                       }
+                       try:
+                               element_attrs['default'] = {
+                                       INT: lambda default: int(default),
+                                       UINT: lambda default: int(default),
+                                       FLOAT: lambda default: float(default),
+                                       STRING: lambda default: str(default),
+                                       UNICODE: lambda default: unicode(default)
+                               }.get(element_attrs['type'], lambda default: default)(raw_attrs['default'])
+                       except (KeyError, ValueError):
+                               element_attrs['default'] = None
+                       
+                       element = type(element_name, (Element,), element_attrs)
+                       elements[element_name] = element
+                       
+                       recursive = []
+                       if upper_recursive:
+                               recursive.extend(upper_recursive)
+                       if raw_attrs.get('recursive', False) == '1':
+                               recursive.append(element)
+                       
+                       element_children, element_list = child_elements(element_level if not is_global else 0, element_list, recursive)
+                       element_children += tuple(recursive)
+                       element.children = element_children
+                       
+                       if is_global:
+                               globals.append(element)
                        else:
                        else:
-                               element_attrs['class_parents'] = (parent_elements[-1][1],)
-               
-               element = type(element_name, (Element,), element_attrs)
-               elements[element_name] = element
-               parent_elements.append((element_level, element))
+                               children.append(element)
+               return tuple(children), element_list
+       
+       children = child_elements(None, tree.getroot().getchildren())[0]
        
        
-       schema_attrs = {
+       document_attrs = {
                '__module__': None,
                '__module__': None,
-               'elements': tuple(elements.values())
+               'type': doc_type,
+               'version': doc_version,
+               'children': children,
+               'globals': tuple(globals)
        }
        }
-       schema = type(schema_name, (Schema,), schema_attrs)
+       document = type(doc_name, (Document,), document_attrs)
        
        
-       return elements, schema
\ No newline at end of file
+       return elements, document
\ No newline at end of file