Modified the get_start and get_end methods to return datetimes if times are specified.
[philo.git] / templatetags / embed.py
index ef2eeb2..eb4cd68 100644 (file)
@@ -61,23 +61,28 @@ class EmbedContext(object):
 old_extends_node_init = ExtendsNode.__init__
 
 
-def get_embed_dict(nodelist):
+def get_embed_dict(embed_list, context):
        embeds = {}
-       for n in nodelist.get_nodes_by_type(ConstantEmbedNode):
-               if n.content_type not in embeds:
-                       embeds[n.content_type] = [n]
+       for e in embed_list:
+               ct = e.get_content_type(context)
+               if ct is None:
+                       # Then the embed doesn't exist for this context.
+                       continue
+               if ct not in embeds:
+                       embeds[ct] = [e]
                else:
-                       embeds[n.content_type].append(n)
+                       embeds[ct].append(e)
        return embeds
 
 
 def extends_node_init(self, nodelist, *args, **kwargs):
-       self.embeds = get_embed_dict(nodelist)
+       self.embed_list = nodelist.get_nodes_by_type(ConstantEmbedNode)
        old_extends_node_init(self, nodelist, *args, **kwargs)
 
 
 def render_extends_node(self, context):
        compiled_parent = self.get_parent(context)
+       embeds = get_embed_dict(self.embed_list, context)
        
        if BLOCK_CONTEXT_KEY not in context.render_context:
                context.render_context[BLOCK_CONTEXT_KEY] = BlockContext()
@@ -90,7 +95,7 @@ def render_extends_node(self, context):
        # Add the block nodes from this node to the block context
        # Do the equivalent for embed nodes
        block_context.add_blocks(self.blocks)
-       embed_context.add_embeds(self.embeds)
+       embed_context.add_embeds(embeds)
        
        # If this block's parent doesn't have an extends node it is the root,
        # and its block nodes also need to be added to the block context.
@@ -100,10 +105,16 @@ def render_extends_node(self, context):
                        if not isinstance(node, ExtendsNode):
                                blocks = dict([(n.name, n) for n in compiled_parent.nodelist.get_nodes_by_type(BlockNode)])
                                block_context.add_blocks(blocks)
-                               embeds = get_embed_dict(compiled_parent.nodelist)
+                               embeds = get_embed_dict(compiled_parent.nodelist.get_nodes_by_type(ConstantEmbedNode), context)
                                embed_context.add_embeds(embeds)
                        break
-
+       
+       # Explicitly render all direct embed children of this node.
+       if self.embed_list:
+               for node in self.nodelist:
+                       if isinstance(node, ConstantEmbedNode):
+                               node.render(context)
+       
        # Call Template._render explicitly so the parser context stays
        # the same.
        return compiled_parent._render(context)
@@ -135,7 +146,6 @@ class ConstantEmbedNode(template.Node):
                        self.template = None
        
        def compile_instance(self, object_pk):
-               self.object_pk = object_pk
                model = self.content_type.model_class()
                try:
                        return model.objects.get(pk=object_pk)
@@ -152,7 +162,7 @@ class ConstantEmbedNode(template.Node):
                try:
                        return template.loader.get_template(template_name)
                except template.TemplateDoesNotExist:
-                       if not hasattr(self, 'template_name') and settings.TEMPLATE_DEBUG:
+                       if hasattr(self, 'template') and settings.TEMPLATE_DEBUG:
                                # Then it's a constant node.
                                raise
                        return False
@@ -199,16 +209,13 @@ class ConstantEmbedNode(template.Node):
                try:
                        t = context.render_context[EMBED_CONTEXT_KEY].get_embed_template(self, context)
                except (KeyError, IndexError):
-                       if settings.TEMPLATE_DEBUG:
-                               raise
+                       self.mark_rendered_for(context)
                        return settings.TEMPLATE_STRING_IF_INVALID
                
                context.push()
                context['embedded'] = instance
-               kwargs = {}
                for k, v in self.kwargs.items():
-                       kwargs[k] = v.resolve(context)
-               context.update(kwargs)
+                       context[k] = v.resolve(context)
                t_rendered = t.render(context)
                context.pop()
                self.mark_rendered_for(context)
@@ -225,13 +232,11 @@ class EmbedNode(ConstantEmbedNode):
                        self.object_pk = object_pk
                else:
                        self.object_pk = None
-                       self.instance = None
                
                if template_name is not None:
                        self.template_name = template_name
                else:
                        self.template_name = None
-                       self.template = None
        
        def get_instance(self, context):
                if self.object_pk is None:
@@ -256,7 +261,10 @@ class InstanceEmbedNode(EmbedNode):
                return self.instance.resolve(context)
        
        def get_content_type(self, context):
-               return ContentType.objects.get_for_model(self.get_instance(context))
+               instance = self.get_instance(context)
+               if not instance:
+                       return None
+               return ContentType.objects.get_for_model(instance)
 
 
 def get_embedded(self):
@@ -266,15 +274,15 @@ def get_embedded(self):
 setattr(ConstantEmbedNode, LOADED_TEMPLATE_ATTR, property(get_embedded))
 
 
-def get_content_type(bit):
+def parse_content_type(bit, tagname):
        try:
                app_label, model = bit.split('.')
        except ValueError:
-               raise template.TemplateSyntaxError('"%s" template tag expects the first argument to be of the form app_label.model' % tag)
+               raise template.TemplateSyntaxError('"%s" template tag expects the first argument to be of the form app_label.model' % tagname)
        try:
                ct = ContentType.objects.get(app_label=app_label, model=model)
        except ContentType.DoesNotExist:
-               raise template.TemplateSyntaxError('"%s" template tag requires an argument of the form app_label.model which refers to an installed content type (see django.contrib.contenttypes)' % tag)
+               raise template.TemplateSyntaxError('"%s" template tag requires an argument of the form app_label.model which refers to an installed content type (see django.contrib.contenttypes)' % tagname)
        return ct
 
 
@@ -291,7 +299,7 @@ def do_embed(parser, token):
                raise template.TemplateSyntaxError('"%s" template tag must have at least two arguments.' % tag)
        
        if len(bits) == 3 and bits[-2] == 'with':
-               ct = get_content_type(bits[0])
+               ct = parse_content_type(bits[0], tag)
                
                if bits[2][0] in ['"', "'"] and bits[2][0] == bits[2][-1]:
                        return ConstantEmbedNode(ct, template_name=bits[2])
@@ -314,7 +322,7 @@ def do_embed(parser, token):
                return InstanceEmbedNode(instance, kwargs)
        elif len(bits) > 2:
                raise template.TemplateSyntaxError('"%s" template tag expects at most 2 non-keyword arguments when embedding instances.')
-       ct = get_content_type(bits[0])
+       ct = parse_content_type(bits[0], tag)
        pk = bits[1]
        
        try: