Enabled Embed rule inheritance from parent templates and to included templates with...
authorStephen Burrows <stephen.r.burrows@gmail.com>
Fri, 19 Nov 2010 20:06:19 +0000 (15:06 -0500)
committerStephen Burrows <stephen.r.burrows@gmail.com>
Fri, 19 Nov 2010 20:06:19 +0000 (15:06 -0500)
templatetags/embed.py
tests.py

index 8fb240d..d7a8466 100644 (file)
 from django import template
 from django.contrib.contenttypes.models import ContentType
 from django.conf import settings
+from django.template.loader_tags import ExtendsNode, BlockContext, BLOCK_CONTEXT_KEY, TextNode, BlockNode
 from philo.utils import LOADED_TEMPLATE_ATTR
 
 
 register = template.Library()
+# Key under which the active EmbedContext is stored in context.render_context.
+EMBED_CONTEXT_KEY = 'embed_context'
+
+
+class EmbedContext(object):
+       """Per-render registry of embed nodes. Inspired by django.template.loader_tags.BlockContext.
+       
+       ``embeds`` maps a content type to the list of embed nodes registered for it;
+       ``rendered`` lists embed nodes in the order they were rendered.
+       """
+       def __init__(self):
+               self.embeds = {}
+               self.rendered = []
+       
+       def add_embeds(self, embeds):
+               """Merge ``embeds`` (content type -> list of nodes) into the registry.
+               
+               Incoming nodes are placed before any nodes already registered for the
+               same content type.
+               """
+               for content_type, embed_list in embeds.iteritems():
+                       if content_type in self.embeds:
+                               self.embeds[content_type] = embed_list + self.embeds[content_type]
+                       else:
+                               self.embeds[content_type] = embed_list
+       
+       def get_embed_template(self, embed, context):
+               """Return the template to use for rendering ``embed``.
+               
+               First walk backwards from ``embed``'s position in this context's list
+               until a node that defines a template is found. Failing that, search the
+               EmbedContexts of enclosing render contexts (``embed`` may live in an
+               included template), considering only as many of their embeds of the
+               same content type as have already been rendered there.
+               
+               Raises IndexError if no template is found.
+               """
+               embeds = self.embeds[embed.content_type]
+               for e in reversed(embeds[:embeds.index(embed)]):
+                       t = e.get_template(context)
+                       if t:
+                               return t
+               
+               # No template was found in the current render_context - but perhaps one level up? Or more?
+               # We may be in an inclusion tag.
+               self_found = False
+               for context_dict in reversed(context.render_context.dicts):
+                       if not self_found:
+                               # Skip dicts up to and including the one holding this EmbedContext;
+                               # only the enclosing levels are searched below.
+                               if self in context_dict.values():
+                                       self_found = True
+                               continue
+                       if EMBED_CONTEXT_KEY not in context_dict:
+                               continue
+                       
+                       embed_context = context_dict[EMBED_CONTEXT_KEY]
+                       # We can tell where we are in the outer list of embeds by how many embeds
+                       # *of this content type* have already been rendered there; counting every
+                       # rendered embed would overshoot into definitions made after this point.
+                       # A level with no embeds of this type simply has nothing to offer.
+                       position = sum(1 for node in embed_context.rendered if node.content_type == embed.content_type)
+                       outer_embeds = embed_context.embeds.get(embed.content_type, [])
+                       for e in reversed(outer_embeds[:position]):
+                               t = e.get_template(context)
+                               if t:
+                                       return t
+               
+               raise IndexError("No embed template found for %s" % embed.content_type)
+
+
+# Monkeypatch ExtendsNode so that, besides BlockNodes, it also collects the embed
+# nodes of every template in an {% extends %} chain into an EmbedContext.
+# Both ExtendsNode.__init__ and ExtendsNode.render are replaced (the assignments
+# follow the replacement functions); keep the original __init__ so it can still run.
+old_extends_node_init = ExtendsNode.__init__
+
+
+def get_embed_dict(nodelist):
+       """Group the ConstantEmbedNode instances found in ``nodelist`` by content type.
+       
+       Nodes keep the order in which get_nodes_by_type yields them.
+       """
+       grouped = {}
+       for node in nodelist.get_nodes_by_type(ConstantEmbedNode):
+               grouped.setdefault(node.content_type, []).append(node)
+       return grouped
+
+
+def extends_node_init(self, nodelist, *args, **kwargs):
+       """Replacement ExtendsNode.__init__: record this template's embed nodes,
+       grouped by content type, on ``self.embeds``, then run the original __init__."""
+       embed_dict = get_embed_dict(nodelist)
+       self.embeds = embed_dict
+       old_extends_node_init(self, nodelist, *args, **kwargs)
+
+
+def render_extends_node(self, context):
+       """Replacement ExtendsNode.render.
+       
+       Registers this template's blocks and embeds (and, when the parent is the
+       root of the chain, the root's blocks and embeds) in the BlockContext and
+       EmbedContext of the render context, then renders the parent template.
+       """
+       compiled_parent = self.get_parent(context)
+       render_context = context.render_context
+       
+       if BLOCK_CONTEXT_KEY not in render_context:
+               render_context[BLOCK_CONTEXT_KEY] = BlockContext()
+       if EMBED_CONTEXT_KEY not in render_context:
+               render_context[EMBED_CONTEXT_KEY] = EmbedContext()
+       block_context = render_context[BLOCK_CONTEXT_KEY]
+       embed_context = render_context[EMBED_CONTEXT_KEY]
+       
+       # Register this node's own blocks, and do the equivalent for embeds.
+       block_context.add_blocks(self.blocks)
+       embed_context.add_embeds(self.embeds)
+       
+       # If the parent has no ExtendsNode it is the root, and its blocks and embeds
+       # must be registered as well. An ExtendsNode has to be the first non-text
+       # node, so only that first non-text node is inspected.
+       for node in compiled_parent.nodelist:
+               if isinstance(node, TextNode):
+                       continue
+               if not isinstance(node, ExtendsNode):
+                       parent_nodelist = compiled_parent.nodelist
+                       block_context.add_blocks(dict((n.name, n) for n in parent_nodelist.get_nodes_by_type(BlockNode)))
+                       embed_context.add_embeds(get_embed_dict(parent_nodelist))
+               break
+       
+       # Call Template._render explicitly so the parser context stays the same.
+       return compiled_parent._render(context)
+
+
+# Install the replacements on the ExtendsNode class itself. This is an import-time,
+# global side effect: it applies to every ExtendsNode, not only to templates that
+# load this tag library.
+ExtendsNode.__init__ = extends_node_init
+ExtendsNode.render = render_extends_node
 
 
 class ConstantEmbedNode(template.Node):
        """Analogous to the ConstantIncludeNode, this node precompiles several variables necessary for correct rendering - namely the referenced instance or the included template."""
-       def __init__(self, content_type, varname, object_pk=None, template_name=None, kwargs=None):
+       def __init__(self, content_type, object_pk=None, template_name=None, kwargs=None):
+               """Precompile the embed.
+               
+               ``template_name`` is a quoted string literal (the quotes are stripped
+               before loading); ``object_pk`` is a constant primary key; ``kwargs``
+               maps argument names to compiled filter expressions (see do_embed).
+               At least one of ``object_pk`` and ``template_name`` must be given.
+               """
                assert template_name is not None or object_pk is not None
                self.content_type = content_type
-               self.varname = varname
                
-               kwargs = kwargs or {}
-               for k, v in kwargs.items():
-                       kwargs[k] = template.Variable(v)
-               self.kwargs = kwargs
+               # kwargs values are already compiled by do_embed, so store them as-is.
+               self.kwargs = kwargs or {}
                
                if object_pk is not None:
-                       self.compile_instance(object_pk)
+                       self.instance = self.compile_instance(object_pk)
                else:
                        self.instance = None
                
                if template_name is not None:
-                       self.compile_template(template_name[1:-1])
+                       self.template = self.compile_template(template_name[1:-1])
                else:
                        self.template = None
        
-       def compile_instance(self, object_pk):
+       def compile_instance(self, object_pk, context=None):
+               """Return the ``content_type`` model instance whose primary key is ``object_pk``.
+               
+               Returns False if no such instance exists. For a constant node (one that
+               has no ``object_pk`` attribute yet) the DoesNotExist error is re-raised
+               when TEMPLATE_DEBUG is on. ``context`` is currently unused.
+               """
-               self.object_pk = object_pk
+               # Decide whether this is a constant node *before* assigning self.object_pk:
+               # assigning first would make the check below always false, and would
+               # overwrite an EmbedNode's compiled object_pk expression with a resolved value.
+               is_constant = not hasattr(self, 'object_pk')
+               if is_constant:
+                       self.object_pk = object_pk
                model = self.content_type.model_class()
                try:
-                       self.instance = model.objects.get(pk=object_pk)
+                       return model.objects.get(pk=object_pk)
                except model.DoesNotExist:
-                       if not hasattr(self, 'object_pk') and settings.TEMPLATE_DEBUG:
+                       if is_constant and settings.TEMPLATE_DEBUG:
                                # Then it's a constant node.
                                raise
-                       self.instance = False
+                       return False
+       
+       def get_instance(self, context):
+               """Return the value stored by __init__: a model instance, False if it did
+               not exist, or None when no object_pk was given."""
+               return self.instance
        
-       def compile_template(self, template_name):
+       def compile_template(self, template_name, context=None):
+               """Load and return the template called ``template_name``.
+               
+               Returns False if it does not exist. For a constant node (one without a
+               ``template_name`` attribute) the TemplateDoesNotExist error is re-raised
+               when TEMPLATE_DEBUG is on. ``context`` is currently unused.
+               """
                try:
-                       self.template = template.loader.get_template(template_name)
+                       return template.loader.get_template(template_name)
                except template.TemplateDoesNotExist:
                        if not hasattr(self, 'template_name') and settings.TEMPLATE_DEBUG:
                                # Then it's a constant node.
                                raise
-                       self.template = False
+                       return False
+       
+       def get_template(self, context):
+               """Return the value stored by __init__: a Template, False if loading
+               failed, or None when this node does not define a template."""
+               return self.template
+       
+       def check_context(self, context):
+               """Ensure the current render context holds an EmbedContext and that this
+               node is registered in it under its content type."""
+               render_context = context.render_context
+               if EMBED_CONTEXT_KEY not in render_context:
+                       render_context[EMBED_CONTEXT_KEY] = EmbedContext()
+               
+               registered = render_context[EMBED_CONTEXT_KEY].embeds.setdefault(self.content_type, [])
+               if self not in registered:
+                       registered.append(self)
+       
+       def mark_rendered(self, context):
+               """Record in the current EmbedContext that this node has been rendered."""
+               embed_context = context.render_context[EMBED_CONTEXT_KEY]
+               embed_context.rendered.append(self)
        
        def render(self, context):
+               self.check_context(context)
+               
                if self.template is not None:
                        if self.template is False:
                                return settings.TEMPLATE_STRING_IF_INVALID
-                       
-                       if self.varname not in context:
-                               context[self.varname] = {}
-                       context[self.varname][self.content_type] = self.template
-                       
+                       self.mark_rendered(context)
                        return ''
                
                # Otherwise self.instance should be set. Render the instance with the appropriate template!
                if self.instance is None or self.instance is False:
+                       self.mark_rendered(context)
                        return settings.TEMPLATE_STRING_IF_INVALID
                
-               return self.render_template(context, self.instance)
+               return self.render_instance(context, self.instance)
        
-       def render_template(self, context, instance):
+       def render_instance(self, context, instance):
                try:
-                       t = context[self.varname][self.content_type]
-               except KeyError:
+                       t = context.render_context[EMBED_CONTEXT_KEY].get_embed_template(self, context)
+               except (KeyError, IndexError):
+                       if settings.TEMPLATE_DEBUG:
+                               raise
                        return settings.TEMPLATE_STRING_IF_INVALID
                
                context.push()
@@ -80,42 +205,54 @@ class ConstantEmbedNode(template.Node):
                context.update(kwargs)
                t_rendered = t.render(context)
                context.pop()
+               self.mark_rendered(context)
                return t_rendered
 
 
 class EmbedNode(ConstantEmbedNode):
-       def __init__(self, content_type, varname, object_pk=None, template_name=None, kwargs=None):
+       def __init__(self, content_type, object_pk=None, template_name=None, kwargs=None):
                assert template_name is not None or object_pk is not None
                self.content_type = content_type
-               self.varname = varname
                
                kwargs = kwargs or {}
                for k, v in kwargs.items():
-                       kwargs[k] = template.Variable(v)
+                       kwargs[k] = v
                self.kwargs = kwargs
                
                if object_pk is not None:
-                       self.object_pk = template.Variable(object_pk)
+                       self.object_pk = object_pk
                else:
                        self.object_pk = None
                        self.instance = None
                
                if template_name is not None:
-                       self.template_name = template.Variable(template_name)
+                       self.template_name = template_name
                else:
                        self.template_name = None
                        self.template = None
        
+       def get_instance(self, context):
+               return self.compile_instance(self.object_pk, context)
+       
+       def get_template(self, context):
+               return self.compile_template(self.template_name, context)
+       
        def render(self, context):
+               self.check_context(context)
+               
                if self.template_name is not None:
-                       template_name = self.template_name.resolve(context)
-                       self.compile_template(template_name)
+                       self.mark_rendered(context)
+                       return ''
                
-               if self.object_pk is not None:
-                       object_pk = self.object_pk.resolve(context)
-                       self.compile_instance(object_pk)
+               if self.object_pk is None:
+                       if settings.TEMPLATE_DEBUG:
+                               raise ValueError("NoneType is not a valid object_pk value")
+                       self.mark_rendered(context)
+                       return settings.TEMPLATE_STRING_IF_INVALID
                
-               return super(EmbedNode, self).render(context)
+               instance = self.compile_instance(self.object_pk.resolve(context))
+               
+               return self.render_instance(context, instance)
 
 
 def get_embedded(self):
@@ -127,8 +264,7 @@ setattr(ConstantEmbedNode, LOADED_TEMPLATE_ATTR, property(get_embedded))
 
 def do_embed(parser, token):
        """
-       The {% embed %} tag can be used in three ways:
-       {% embed as <varname> %} :: This sets which variable will be used to track embedding template names for the current context. Default: "embed"
+       The {% embed %} tag can be used in two ways:
        {% embed <app_label>.<model_name> with <template> %} :: Sets which template will be used to render a particular model.
        {% embed <app_label>.<model_name> <object_pk> [<argname>=<value> ...]%} :: Embeds the instance specified by the given parameters in the document with the previously-specified template. Any kwargs provided will be passed into the context of the template.
        """
@@ -137,9 +273,6 @@ def do_embed(parser, token):
        
        if len(args) < 2:
                raise template.TemplateSyntaxError('"%s" template tag must have at least three arguments.' % tag)
-       elif len(args) == 3 and args[1] == "as":
-               parser._embedNodeVarName = args[2]
-               return template.defaulttags.CommentNode()
        else:
                if '.' not in args[1]:
                        raise template.TemplateSyntaxError('"%s" template tag expects the first argument to be of the form app_label.model' % tag)
@@ -150,16 +283,14 @@ def do_embed(parser, token):
                except ContentType.DoesNotExist:
                        raise template.TemplateSyntaxError('"%s" template tag option "references" requires an argument of the form app_label.model which refers to an installed content type (see django.contrib.contenttypes)' % tag)
                
-               varname = getattr(parser, '_embedNodeVarName', 'embed')
-               
                if args[2] == "with":
                        if len(args) > 4:
                                raise template.TemplateSyntaxError('"%s" template tag may have no more than four arguments.' % tag)
                        
                        if args[3][0] in ['"', "'"] and args[3][0] == args[3][-1]:
-                               return ConstantEmbedNode(ct, template_name=args[3], varname=varname)
+                               return ConstantEmbedNode(ct, template_name=args[3])
                        
-                       return EmbedNode(ct, template_name=args[3], varname=varname)
+                       return EmbedNode(ct, template_name=args[3])
                
                object_pk = args[2]
                remaining_args = args[3:]
@@ -168,9 +299,14 @@ def do_embed(parser, token):
                        if '=' not in arg:
                                raise template.TemplateSyntaxError("Invalid keyword argument for '%s' template tag: %s" % (tag, arg))
                        k, v = arg.split('=')
-                       kwargs[k] = v
+                       kwargs[k] = parser.compile_filter(v)
                
-               return EmbedNode(ct, object_pk=object_pk, varname=varname, kwargs=kwargs)
+               try:
+                       int(object_pk)
+               except ValueError:
+                       return EmbedNode(ct, object_pk=parser.compile_filter(object_pk), kwargs=kwargs)
+               else:
+                       return ConstantEmbedNode(ct, object_pk=object_pk, kwargs=kwargs)
 
 
 register.tag('embed', do_embed)
\ No newline at end of file
index 874f62f..b79534f 100644 (file)
--- a/tests.py
+++ b/tests.py
@@ -1,9 +1,100 @@
 from django.test import TestCase
 from django import template
 from django.conf import settings
+from django.template import loader
+from django.template.loaders import cached
 from philo.exceptions import AncestorDoesNotExist
 from philo.models import Node, Page, Template
 from philo.contrib.penfield.models import Blog, BlogView, BlogEntry
+import sys, traceback
+
+
+class TemplateTestCase(TestCase):
+       fixtures = ['test_fixtures.json']
+       
+       def test_templates(self):
+               "Tests to make sure that embed behaves with complex includes and extends"
+               template_tests = self.get_template_tests()
+               
+               # Register our custom template loader. Shamelessly cribbed from django core regressiontests.
+               def test_template_loader(template_name, template_dirs=None):
+                       "A custom template loader that loads the unit-test templates."
+                       try:
+                               return (template_tests[template_name][0], "test:%s" % template_name)
+                       except KeyError:
+                               raise template.TemplateDoesNotExist(template_name)
+               
+               cache_loader = cached.Loader(('test_template_loader',))
+               cache_loader._cached_loaders = (test_template_loader,)
+               
+               # Remember the global state changed below; it is restored in the finally
+               # clause so that an unexpected error cannot leak into other tests.
+               old_template_loaders = loader.template_source_loaders
+               old_td = settings.TEMPLATE_DEBUG
+               old_invalid = settings.TEMPLATE_STRING_IF_INVALID
+               
+               failures = []
+               try:
+                       loader.template_source_loaders = [cache_loader]
+                       
+                       # Turn TEMPLATE_DEBUG off, because tests assume that.
+                       settings.TEMPLATE_DEBUG = False
+                       
+                       # TEMPLATE_STRING_IF_INVALID is deliberately left unchanged: the expected
+                       # outputs from get_template_tests() were built from its current value.
+                       
+                       # Run tests
+                       for name, (_, context, result) in template_tests.items():
+                               try:
+                                       test_template = loader.get_template(name)
+                                       output = test_template.render(template.Context(context))
+                               except Exception:
+                                       exc_type, exc_value, exc_tb = sys.exc_info()
+                                       if exc_type != result:
+                                               tb = '\n'.join(traceback.format_exception(exc_type, exc_value, exc_tb))
+                                               failures.append("Template test %s -- FAILED. Got %s, exception: %s\n%s" % (name, exc_type, exc_value, tb))
+                                       continue
+                               if output != result:
+                                       failures.append("Template test %s -- FAILED. Expected %r, got %r" % (name, result, output))
+               finally:
+                       # Cleanup
+                       settings.TEMPLATE_DEBUG = old_td
+                       settings.TEMPLATE_STRING_IF_INVALID = old_invalid
+                       loader.template_source_loaders = old_template_loaders
+               
+               self.assertEqual(failures, [], "Tests failed:\n%s\n%s" % ('-'*70, ("\n%s\n" % ('-'*70)).join(failures)))
+       
+       
+       def get_template_tests(self):
+               # SYNTAX --
+               # 'template_name': ('template contents', 'context dict', 'expected string output' or Exception class)
+               blog = Blog.objects.all()[0]
+               return {
+                       # EMBED INCLUSION HANDLING
+                       
+                       'embed01': ('{{ embedded.title|safe }}', {'embedded': blog}, blog.title),
+                       'embed02': ('{{ embedded.title|safe }}{{ var1 }}{{ var2 }}', {'embedded': blog}, blog.title),
+                       'embed03': ('{{ embedded.title|safe }} is a lie!', {'embedded': blog}, '%s is a lie!' % blog.title),
+                       
+                       # Simple template structure with embed
+                       'simple01': ('{% embed penfield.blog with "embed01" %}{% embed penfield.blog 1 %}Simple{% block one %}{% endblock %}', {'blog': blog}, '%sSimple' % blog.title),
+                       'simple02': ('{% extends "simple01" %}', {}, '%sSimple' % blog.title),
+                       'simple03': ('{% embed penfield.blog with "embed000" %}', {}, settings.TEMPLATE_STRING_IF_INVALID),
+                       'simple04': ('{% embed penfield.blog 1 %}', {}, settings.TEMPLATE_STRING_IF_INVALID),
+                       
+                       # Kwargs
+                       'kwargs01': ('{% embed penfield.blog with "embed02" %}{% embed penfield.blog 1 var1="hi" var2=lo %}', {'lo': 'lo'}, '%shilo' % blog.title),
+                       
+                       # Filters/variables
+                       'filters01': ('{% embed penfield.blog with "embed02" %}{% embed penfield.blog 1 var1=hi|first var2=lo|slice:"3" %}', {'hi': ["These", "words"], 'lo': 'lower'}, '%sTheselow' % blog.title),
+                       'filters02': ('{% embed penfield.blog with "embed01" %}{% embed penfield.blog entry %}', {'entry': 1}, blog.title),
+                       
+                       # Blocky structure
+                       'block01': ('{% block one %}Hello{% endblock %}', {}, 'Hello'),
+                       'block02': ('{% extends "simple01" %}{% block one %}{% embed penfield.blog 1 %}{% endblock %}', {}, "%sSimple%s" % (blog.title, blog.title)),
+                       'block03': ('{% extends "simple01" %}{% embed penfield.blog with "embed03" %}{% block one %}{% embed penfield.blog 1 %}{% endblock %}', {}, "%sSimple%s is a lie!" % (blog.title, blog.title)),
+                       
+                       # Blocks and includes
+                       'block-include01': ('{% extends "simple01" %}{% embed penfield.blog with "embed03" %}{% block one %}{% include "simple01" %}{% embed penfield.blog 1 %}{% endblock %}', {}, "%sSimple%sSimple%s is a lie!" % (blog.title, blog.title, blog.title)),
+                       'block-include02': ('{% extends "simple01" %}{% block one %}{% include "simple04" %}{% embed penfield.blog with "embed03" %}{% include "simple04" %}{% embed penfield.blog 1 %}{% endblock %}', {}, "%sSimple%s%s is a lie!%s is a lie!" % (blog.title, blog.title, blog.title, blog.title)),
+               }
 
 
 class NodeURLTestCase(TestCase):