Further polished embedding system - allowed for context-dependent embed nodes to...
[philo.git] / templatetags / embed.py
index d7a8466..901e163 100644 (file)
@@ -26,7 +26,8 @@ class EmbedContext(object):
                """To return a template for an embed node, find the node's position in the stack
                and then progress up the stack until a template-defining node is found
                """
-               embeds = self.embeds[embed.content_type]
+               ct = embed.get_content_type(context)
+               embeds = self.embeds[ct]
                embeds = embeds[:embeds.index(embed)][::-1]
                for e in embeds:
                        template = e.get_template(context)
@@ -46,7 +47,7 @@ class EmbedContext(object):
                        else:
                                embed_context = context_dict[EMBED_CONTEXT_KEY]
                                # We can tell where we are in the list of embeds by which have already been rendered.
-                               embeds = embed_context.embeds[embed.content_type][:len(embed_context.rendered)][::-1]
+                               embeds = embed_context.embeds[ct][:len(embed_context.rendered)][::-1]
                                for e in embeds:
                                        template = e.get_template(context)
                                        if template:
@@ -60,23 +61,25 @@ class EmbedContext(object):
 old_extends_node_init = ExtendsNode.__init__
 
 
-def get_embed_dict(nodelist):
+def get_embed_dict(embed_list, context):
        embeds = {}
-       for n in nodelist.get_nodes_by_type(ConstantEmbedNode):
-               if n.content_type not in embeds:
-                       embeds[n.content_type] = [n]
+       for e in embed_list:
+               ct = e.get_content_type(context)
+               if ct not in embeds:
+                       embeds[ct] = [e]
                else:
-                       embeds[n.content_type].append(n)
+                       embeds[ct].append(e)
        return embeds
 
 
 def extends_node_init(self, nodelist, *args, **kwargs):
-       self.embeds = get_embed_dict(nodelist)
+       self.embed_list = nodelist.get_nodes_by_type(ConstantEmbedNode)
        old_extends_node_init(self, nodelist, *args, **kwargs)
 
 
 def render_extends_node(self, context):
        compiled_parent = self.get_parent(context)
+       embeds = get_embed_dict(self.embed_list, context)
        
        if BLOCK_CONTEXT_KEY not in context.render_context:
                context.render_context[BLOCK_CONTEXT_KEY] = BlockContext()
@@ -89,7 +92,7 @@ def render_extends_node(self, context):
        # Add the block nodes from this node to the block context
        # Do the equivalent for embed nodes
        block_context.add_blocks(self.blocks)
-       embed_context.add_embeds(self.embeds)
+       embed_context.add_embeds(embeds)
        
        # If this block's parent doesn't have an extends node it is the root,
        # and its block nodes also need to be added to the block context.
@@ -99,10 +102,16 @@ def render_extends_node(self, context):
                        if not isinstance(node, ExtendsNode):
                                blocks = dict([(n.name, n) for n in compiled_parent.nodelist.get_nodes_by_type(BlockNode)])
                                block_context.add_blocks(blocks)
-                               embeds = get_embed_dict(compiled_parent.nodelist)
+                               embeds = get_embed_dict(compiled_parent.nodelist.get_nodes_by_type(ConstantEmbedNode), context)
                                embed_context.add_embeds(embeds)
                        break
-
+       
+       # Explicitly render all direct embed children of this node.
+       if self.embed_list:
+               for node in self.nodelist:
+                       if isinstance(node, ConstantEmbedNode):
+                               node.render(context)
+       
        # Call Template._render explicitly so the parser context stays
        # the same.
        return compiled_parent._render(context)
@@ -133,7 +142,7 @@ class ConstantEmbedNode(template.Node):
                else:
                        self.template = None
        
-       def compile_instance(self, object_pk, context=None):
+       def compile_instance(self, object_pk):
                self.object_pk = object_pk
                model = self.content_type.model_class()
                try:
@@ -147,11 +156,11 @@ class ConstantEmbedNode(template.Node):
        def get_instance(self, context):
                return self.instance
        
-       def compile_template(self, template_name, context=None):
+       def compile_template(self, template_name):
                try:
                        return template.loader.get_template(template_name)
                except template.TemplateDoesNotExist:
-                       if not hasattr(self, 'template_name') and settings.TEMPLATE_DEBUG:
+                       if not isinstance(self, EmbedNode) and settings.TEMPLATE_DEBUG:
                                # Then it's a constant node.
                                raise
                        return False
@@ -159,35 +168,40 @@ class ConstantEmbedNode(template.Node):
        def get_template(self, context):
                return self.template
        
+       def get_content_type(self, context):
+               return self.content_type
+       
        def check_context(self, context):
                if EMBED_CONTEXT_KEY not in context.render_context:
                        context.render_context[EMBED_CONTEXT_KEY] = EmbedContext()
                embed_context = context.render_context[EMBED_CONTEXT_KEY]
                
-               
-               if self.content_type not in embed_context.embeds:
-                       embed_context.embeds[self.content_type] = [self]
-               elif self not in embed_context.embeds[self.content_type]:
-                       embed_context.embeds[self.content_type].append(self)
+               ct = self.get_content_type(context)
+               if ct not in embed_context.embeds:
+                       embed_context.embeds[ct] = [self]
+               elif self not in embed_context.embeds[ct]:
+                       embed_context.embeds[ct].append(self)
        
-       def mark_rendered(self, context):
+       def mark_rendered_for(self, context):
                context.render_context[EMBED_CONTEXT_KEY].rendered.append(self)
        
        def render(self, context):
                self.check_context(context)
                
-               if self.template is not None:
-                       if self.template is False:
+               template = self.get_template(context)
+               if template is not None:
+                       self.mark_rendered_for(context)
+                       if template is False:
                                return settings.TEMPLATE_STRING_IF_INVALID
-                       self.mark_rendered(context)
                        return ''
                
-               # Otherwise self.instance should be set. Render the instance with the appropriate template!
-               if self.instance is None or self.instance is False:
-                       self.mark_rendered(context)
+               # Otherwise an instance should be available. Render the instance with the appropriate template!
+               instance = self.get_instance(context)
+               if instance is None or instance is False:
+                       self.mark_rendered_for(context)
                        return settings.TEMPLATE_STRING_IF_INVALID
                
-               return self.render_instance(context, self.instance)
+               return self.render_instance(context, instance)
        
        def render_instance(self, context, instance):
                try:
@@ -195,6 +209,7 @@ class ConstantEmbedNode(template.Node):
                except (KeyError, IndexError):
                        if settings.TEMPLATE_DEBUG:
                                raise
+                       self.mark_rendered_for(context)
                        return settings.TEMPLATE_STRING_IF_INVALID
                
                context.push()
@@ -205,7 +220,7 @@ class ConstantEmbedNode(template.Node):
                context.update(kwargs)
                t_rendered = t.render(context)
                context.pop()
-               self.mark_rendered(context)
+               self.mark_rendered_for(context)
                return t_rendered
 
 
@@ -213,46 +228,51 @@ class EmbedNode(ConstantEmbedNode):
        def __init__(self, content_type, object_pk=None, template_name=None, kwargs=None):
                assert template_name is not None or object_pk is not None
                self.content_type = content_type
-               
-               kwargs = kwargs or {}
-               for k, v in kwargs.items():
-                       kwargs[k] = v
-               self.kwargs = kwargs
+               self.kwargs = kwargs or {}
                
                if object_pk is not None:
                        self.object_pk = object_pk
                else:
                        self.object_pk = None
-                       self.instance = None
                
                if template_name is not None:
                        self.template_name = template_name
                else:
                        self.template_name = None
-                       self.template = None
        
        def get_instance(self, context):
-               return self.compile_instance(self.object_pk, context)
+               if self.object_pk is None:
+                       return None
+               return self.compile_instance(self.object_pk.resolve(context))
        
        def get_template(self, context):
-               return self.compile_template(self.template_name, context)
+               if self.template_name is None:
+                       return None
+               return self.compile_template(self.template_name.resolve(context))
+
+
+class InstanceEmbedNode(EmbedNode):
+       """Embeds a model instance that is given directly as a template variable."""
+       def __init__(self, instance, kwargs=None):
+               self.instance = instance
+               self.kwargs = kwargs or {}
        
-       def render(self, context):
-               self.check_context(context)
-               
-               if self.template_name is not None:
-                       self.mark_rendered(context)
-                       return ''
-               
-               if self.object_pk is None:
-                       if settings.TEMPLATE_DEBUG:
-                               raise ValueError("NoneType is not a valid object_pk value")
-                       self.mark_rendered(context)
-                       return settings.TEMPLATE_STRING_IF_INVALID
-               
-               instance = self.compile_instance(self.object_pk.resolve(context))
-               
-               return self.render_instance(context, instance)
+       def get_template(self, context):
+               return None
+       
+       def get_instance(self, context):
+               instance = self.instance.resolve(context)
+               # Only model instances can be embedded. Anything else (e.g. a variable that
+               # failed to resolve) is treated as a missing instance, so that render() emits
+               # TEMPLATE_STRING_IF_INVALID instead of get_for_model() raising AttributeError.
+               if not hasattr(instance, '_meta'):
+                       return None
+               return instance
+       
+       def get_content_type(self, context):
+               instance = self.get_instance(context)
+               if instance is None:
+                       return None
+               return ContentType.objects.get_for_model(instance)
 
 
 def get_embedded(self):
@@ -262,51 +282,68 @@ def get_embedded(self):
 setattr(ConstantEmbedNode, LOADED_TEMPLATE_ATTR, property(get_embedded))
 
 
+def get_content_type(bit, tag='embed'):
+       """Return the ContentType named by ``bit`` ("app_label.model").
+       
+       Raises TemplateSyntaxError (naming the calling tag ``tag``) if ``bit`` is malformed or unknown.
+       """
+       try:
+               app_label, model = bit.split('.')
+       except ValueError:
+               raise template.TemplateSyntaxError('"%s" template tag expects the first argument to be of the form app_label.model' % tag)
+       try:
+               ct = ContentType.objects.get(app_label=app_label, model=model)
+       except ContentType.DoesNotExist:
+               raise template.TemplateSyntaxError('"%s" template tag requires an argument of the form app_label.model which refers to an installed content type (see django.contrib.contenttypes)' % tag)
+       return ct
+
+
 def do_embed(parser, token):
        """
        The {% embed %} tag can be used in two ways:
        {% embed <app_label>.<model_name> with <template> %} :: Sets which template will be used to render a particular model.
-       {% embed <app_label>.<model_name> <object_pk> [<argname>=<value> ...]%} :: Embeds the instance specified by the given parameters in the document with the previously-specified template. Any kwargs provided will be passed into the context of the template.
+       {% embed (<app_label>.<model_name> <object_pk> || <instance>) [<argname>=<value> ...] %} :: Embeds the instance specified by the given parameters in the document with the previously-specified template. Any kwargs provided will be passed into the context of the template.
        """
-       args = token.split_contents()
-       tag = args[0]
+       bits = token.split_contents()
+       tag = bits.pop(0)
        
-       if len(args) < 2:
-               raise template.TemplateSyntaxError('"%s" template tag must have at least three arguments.' % tag)
-       else:
-               if '.' not in args[1]:
-                       raise template.TemplateSyntaxError('"%s" template tag expects the first argument to be of the form app_label.model' % tag)
-               
-               app_label, model = args[1].split('.')
-               try:
-                       ct = ContentType.objects.get(app_label=app_label, model=model)
-               except ContentType.DoesNotExist:
-                       raise template.TemplateSyntaxError('"%s" template tag option "references" requires an argument of the form app_label.model which refers to an installed content type (see django.contrib.contenttypes)' % tag)
-               
-               if args[2] == "with":
-                       if len(args) > 4:
-                               raise template.TemplateSyntaxError('"%s" template tag may have no more than four arguments.' % tag)
-                       
-                       if args[3][0] in ['"', "'"] and args[3][0] == args[3][-1]:
-                               return ConstantEmbedNode(ct, template_name=args[3])
-                       
-                       return EmbedNode(ct, template_name=args[3])
+       if len(bits) < 1:
+               raise template.TemplateSyntaxError('"%s" template tag must have at least two arguments.' % tag)
+       
+       if len(bits) == 3 and bits[-2] == 'with':
+               ct = get_content_type(bits[0], tag)
                
-               object_pk = args[2]
-               remaining_args = args[3:]
-               kwargs = {}
-               for arg in remaining_args:
-                       if '=' not in arg:
-                               raise template.TemplateSyntaxError("Invalid keyword argument for '%s' template tag: %s" % (tag, arg))
-                       k, v = arg.split('=')
+               if bits[2][0] in ['"', "'"] and bits[2][0] == bits[2][-1]:
+                       return ConstantEmbedNode(ct, template_name=bits[2])
+               return EmbedNode(ct, template_name=bits[2])
+       
+       # Otherwise they're trying to embed a certain instance.
+       kwargs = {}
+       try:
+               bit = bits.pop()
+               while '=' in bit:
+                       # Split on the first '=' only: the value is a filter expression and may contain '='.
+                       k, v = bit.split('=', 1)
                        kwargs[k] = parser.compile_filter(v)
-               
-               try:
-                       int(object_pk)
-               except ValueError:
-                       return EmbedNode(ct, object_pk=parser.compile_filter(object_pk), kwargs=kwargs)
-               else:
-                       return ConstantEmbedNode(ct, object_pk=object_pk, kwargs=kwargs)
+                       bit = bits.pop()
+               bits.append(bit)
+       except IndexError:
+               raise template.TemplateSyntaxError('"%s" template tag expects at least one non-keyword argument when embedding instances.' % tag)
+       
+       if len(bits) == 1:
+               instance = parser.compile_filter(bits[0])
+               return InstanceEmbedNode(instance, kwargs)
+       elif len(bits) > 2:
+               raise template.TemplateSyntaxError('"%s" template tag expects at most 2 non-keyword arguments when embedding instances.' % tag)
+       ct = get_content_type(bits[0], tag)
+       pk = bits[1]
+       
+       try:
+               int(pk)
+       except ValueError:
+               return EmbedNode(ct, object_pk=parser.compile_filter(pk), kwargs=kwargs)
+       else:
+               return ConstantEmbedNode(ct, object_pk=pk, kwargs=kwargs)
 
 
 register.tag('embed', do_embed)
\ No newline at end of file