Enabled Embed rule inheritance from parent templates and to included templates with...
author: Stephen Burrows <stephen.r.burrows@gmail.com>
Fri, 19 Nov 2010 20:06:19 +0000 (15:06 -0500)
committer: Stephen Burrows <stephen.r.burrows@gmail.com>
Fri, 19 Nov 2010 20:06:19 +0000 (15:06 -0500)
templatetags/embed.py
tests.py

index 8fb240d..d7a8466 100644 (file)
 from django import template
 from django.contrib.contenttypes.models import ContentType
 from django.conf import settings
 from django import template
 from django.contrib.contenttypes.models import ContentType
 from django.conf import settings
+from django.template.loader_tags import ExtendsNode, BlockContext, BLOCK_CONTEXT_KEY, TextNode, BlockNode
 from philo.utils import LOADED_TEMPLATE_ATTR
 
 
 register = template.Library()
 from philo.utils import LOADED_TEMPLATE_ATTR
 
 
 register = template.Library()
EMBED_CONTEXT_KEY = 'embed_context'


class EmbedContext(object):
    """Per-render bookkeeping for ``{% embed %}`` nodes.

    Inspired by django.template.loader_tags.BlockContext. An instance is
    stored in ``context.render_context`` under ``EMBED_CONTEXT_KEY`` and holds:

    * ``embeds``   -- dict mapping a content type to the ordered list of embed
                      nodes known for that content type in this render context.
    * ``rendered`` -- embed nodes in the order they were marked as rendered.
    """

    def __init__(self):
        self.embeds = {}
        self.rendered = []

    def add_embeds(self, embeds):
        """Merge ``embeds`` (content type -> list of embed nodes) into this context.

        Newly added nodes go *before* any nodes already known for the same
        content type. The incoming lists are copied so that later appends to
        ``self.embeds`` never mutate the caller's lists (for example the
        ``embeds`` dict computed once when an ExtendsNode is parsed).
        """
        for content_type, embed_list in embeds.items():
            self.embeds[content_type] = list(embed_list) + self.embeds.get(content_type, [])

    def get_embed_template(self, embed, context):
        """Return the template to use when rendering the instance node ``embed``.

        1. Walk backwards from ``embed``'s position in this context's list of
           embeds for its content type and return the first template defined
           by an earlier node. If ``embed`` is not registered here, no local
           nodes are considered.
        2. Otherwise, search the render-context layers outside the one that
           holds this EmbedContext (e.g. when rendering inside an included
           template), innermost first. In each outer layer only embed nodes of
           the same content type that have already been rendered are
           considered, most recently rendered first.

        Raises:
            IndexError: if no template can be found.
        """
        content_type = embed.content_type

        local = self.embeds.get(content_type, [])
        position = local.index(embed) if embed in local else 0
        for candidate in reversed(local[:position]):
            found = candidate.get_template(context)
            if found:
                return found

        # No template in this render context; look in enclosing layers.
        # Layers up to and including the one holding ``self`` are skipped.
        passed_self = False
        for layer in reversed(context.render_context.dicts):
            if not passed_self:
                passed_self = self in layer.values()
                continue
            if EMBED_CONTEXT_KEY not in layer:
                continue
            outer = layer[EMBED_CONTEXT_KEY]
            # Only nodes that already rendered precede the current point in
            # the outer template. Filter by content type so that embeds of
            # other types do not skew the position; a layer without embeds of
            # this type simply contributes nothing and the search continues.
            for candidate in reversed(outer.rendered):
                if candidate.content_type != content_type:
                    continue
                found = candidate.get_template(context)
                if found:
                    return found

        raise IndexError('No embed template found for content type %r' % (content_type,))
+
+
# ExtendsNode is patched further down (its __init__ and render are replaced)
# so that it gathers embed nodes the same way it gathers BlockNodes. Keep a
# reference to Django's original initializer so the replacement can
# delegate to it.
old_extends_node_init = ExtendsNode.__init__
+
+
def get_embed_dict(nodelist):
    """Group the ConstantEmbedNodes found in ``nodelist`` by content type.

    Returns a dict mapping each content type to a list of its embed nodes,
    in the order ``nodelist.get_nodes_by_type`` yields them.
    """
    grouped = {}
    for node in nodelist.get_nodes_by_type(ConstantEmbedNode):
        grouped.setdefault(node.content_type, []).append(node)
    return grouped
+
+
def extends_node_init(self, nodelist, *args, **kwargs):
    """Replacement for ``ExtendsNode.__init__`` that also records embed nodes.

    The embed nodes in ``nodelist`` are grouped by content type and stored on
    ``self.embeds`` before delegating to Django's original initializer.
    """
    embeds_by_type = get_embed_dict(nodelist)
    self.embeds = embeds_by_type
    old_extends_node_init(self, nodelist, *args, **kwargs)
+
+
def render_extends_node(self, context):
    """Replacement for ``ExtendsNode.render`` that also tracks embed nodes.

    This template's blocks are added to the shared BlockContext and its embeds
    to the shared EmbedContext (both created in ``context.render_context`` if
    missing). When the parent template is the root of the inheritance chain
    (its first non-text node is not an ExtendsNode), the parent's blocks and
    embeds are added as well. Returns the rendered parent template.
    """
    compiled_parent = self.get_parent(context)
    render_context = context.render_context

    if BLOCK_CONTEXT_KEY not in render_context:
        render_context[BLOCK_CONTEXT_KEY] = BlockContext()
    block_context = render_context[BLOCK_CONTEXT_KEY]

    if EMBED_CONTEXT_KEY not in render_context:
        render_context[EMBED_CONTEXT_KEY] = EmbedContext()
    embed_context = render_context[EMBED_CONTEXT_KEY]

    # This node's own blocks and embeds are registered first.
    block_context.add_blocks(self.blocks)
    embed_context.add_embeds(self.embeds)

    # The ExtendsNode has to be the first non-text node, so the parent is the
    # root exactly when its first non-text node is something else.
    parent_nodelist = compiled_parent.nodelist
    parent_is_root = False
    for node in parent_nodelist:
        if isinstance(node, TextNode):
            continue
        parent_is_root = not isinstance(node, ExtendsNode)
        break

    if parent_is_root:
        root_blocks = dict(
            (block_node.name, block_node)
            for block_node in parent_nodelist.get_nodes_by_type(BlockNode)
        )
        block_context.add_blocks(root_blocks)
        embed_context.add_embeds(get_embed_dict(parent_nodelist))

    # Call Template._render explicitly so the parser context stays the same.
    return compiled_parent._render(context)
+
+
# Install the embed-aware initializer and renderer on Django's ExtendsNode.
# This rebinds class attributes, so it applies globally once this module runs.
setattr(ExtendsNode, '__init__', extends_node_init)
setattr(ExtendsNode, 'render', render_extends_node)
 
 
 class ConstantEmbedNode(template.Node):
        """Analogous to the ConstantIncludeNode, this node precompiles several variables necessary for correct rendering - namely the referenced instance or the included template."""
 
 
 class ConstantEmbedNode(template.Node):
        """Analogous to the ConstantIncludeNode, this node precompiles several variables necessary for correct rendering - namely the referenced instance or the included template."""
-       def __init__(self, content_type, varname, object_pk=None, template_name=None, kwargs=None):
+       def __init__(self, content_type, object_pk=None, template_name=None, kwargs=None):
                assert template_name is not None or object_pk is not None
                self.content_type = content_type
                assert template_name is not None or object_pk is not None
                self.content_type = content_type
-               self.varname = varname
                
                kwargs = kwargs or {}
                for k, v in kwargs.items():
                
                kwargs = kwargs or {}
                for k, v in kwargs.items():
-                       kwargs[k] = template.Variable(v)
+                       kwargs[k] = v
                self.kwargs = kwargs
                
                if object_pk is not None:
                self.kwargs = kwargs
                
                if object_pk is not None:
-                       self.compile_instance(object_pk)
+                       self.instance = self.compile_instance(object_pk)
                else:
                        self.instance = None
                
                if template_name is not None:
                else:
                        self.instance = None
                
                if template_name is not None:
-                       self.compile_template(template_name[1:-1])
+                       self.template = self.compile_template(template_name[1:-1])
                else:
                        self.template = None
        
                else:
                        self.template = None
        
-       def compile_instance(self, object_pk):
+       def compile_instance(self, object_pk, context=None):
                self.object_pk = object_pk
                model = self.content_type.model_class()
                try:
                self.object_pk = object_pk
                model = self.content_type.model_class()
                try:
-                       self.instance = model.objects.get(pk=object_pk)
+                       return model.objects.get(pk=object_pk)
                except model.DoesNotExist:
                        if not hasattr(self, 'object_pk') and settings.TEMPLATE_DEBUG:
                                # Then it's a constant node.
                                raise
                except model.DoesNotExist:
                        if not hasattr(self, 'object_pk') and settings.TEMPLATE_DEBUG:
                                # Then it's a constant node.
                                raise
-                       self.instance = False
+                       return False
+       
+       def get_instance(self, context):
+               return self.instance
        
        
-       def compile_template(self, template_name):
+       def compile_template(self, template_name, context=None):
                try:
                try:
-                       self.template = template.loader.get_template(template_name)
+                       return template.loader.get_template(template_name)
                except template.TemplateDoesNotExist:
                        if not hasattr(self, 'template_name') and settings.TEMPLATE_DEBUG:
                                # Then it's a constant node.
                                raise
                except template.TemplateDoesNotExist:
                        if not hasattr(self, 'template_name') and settings.TEMPLATE_DEBUG:
                                # Then it's a constant node.
                                raise
-                       self.template = False
+                       return False
+       
+       def get_template(self, context):
+               return self.template
+       
+       def check_context(self, context):
+               if EMBED_CONTEXT_KEY not in context.render_context:
+                       context.render_context[EMBED_CONTEXT_KEY] = EmbedContext()
+               embed_context = context.render_context[EMBED_CONTEXT_KEY]
+               
+               
+               if self.content_type not in embed_context.embeds:
+                       embed_context.embeds[self.content_type] = [self]
+               elif self not in embed_context.embeds[self.content_type]:
+                       embed_context.embeds[self.content_type].append(self)
+       
+       def mark_rendered(self, context):
+               context.render_context[EMBED_CONTEXT_KEY].rendered.append(self)
        
        def render(self, context):
        
        def render(self, context):
+               self.check_context(context)
+               
                if self.template is not None:
                        if self.template is False:
                                return settings.TEMPLATE_STRING_IF_INVALID
                if self.template is not None:
                        if self.template is False:
                                return settings.TEMPLATE_STRING_IF_INVALID
-                       
-                       if self.varname not in context:
-                               context[self.varname] = {}
-                       context[self.varname][self.content_type] = self.template
-                       
+                       self.mark_rendered(context)
                        return ''
                
                # Otherwise self.instance should be set. Render the instance with the appropriate template!
                if self.instance is None or self.instance is False:
                        return ''
                
                # Otherwise self.instance should be set. Render the instance with the appropriate template!
                if self.instance is None or self.instance is False:
+                       self.mark_rendered(context)
                        return settings.TEMPLATE_STRING_IF_INVALID
                
                        return settings.TEMPLATE_STRING_IF_INVALID
                
-               return self.render_template(context, self.instance)
+               return self.render_instance(context, self.instance)
        
        
-       def render_template(self, context, instance):
+       def render_instance(self, context, instance):
                try:
                try:
-                       t = context[self.varname][self.content_type]
-               except KeyError:
+                       t = context.render_context[EMBED_CONTEXT_KEY].get_embed_template(self, context)
+               except (KeyError, IndexError):
+                       if settings.TEMPLATE_DEBUG:
+                               raise
                        return settings.TEMPLATE_STRING_IF_INVALID
                
                context.push()
                        return settings.TEMPLATE_STRING_IF_INVALID
                
                context.push()
@@ -80,42 +205,54 @@ class ConstantEmbedNode(template.Node):
                context.update(kwargs)
                t_rendered = t.render(context)
                context.pop()
                context.update(kwargs)
                t_rendered = t.render(context)
                context.pop()
+               self.mark_rendered(context)
                return t_rendered
 
 
 class EmbedNode(ConstantEmbedNode):
                return t_rendered
 
 
 class EmbedNode(ConstantEmbedNode):
-       def __init__(self, content_type, varname, object_pk=None, template_name=None, kwargs=None):
+       def __init__(self, content_type, object_pk=None, template_name=None, kwargs=None):
                assert template_name is not None or object_pk is not None
                self.content_type = content_type
                assert template_name is not None or object_pk is not None
                self.content_type = content_type
-               self.varname = varname
                
                kwargs = kwargs or {}
                for k, v in kwargs.items():
                
                kwargs = kwargs or {}
                for k, v in kwargs.items():
-                       kwargs[k] = template.Variable(v)
+                       kwargs[k] = v
                self.kwargs = kwargs
                
                if object_pk is not None:
                self.kwargs = kwargs
                
                if object_pk is not None:
-                       self.object_pk = template.Variable(object_pk)
+                       self.object_pk = object_pk
                else:
                        self.object_pk = None
                        self.instance = None
                
                if template_name is not None:
                else:
                        self.object_pk = None
                        self.instance = None
                
                if template_name is not None:
-                       self.template_name = template.Variable(template_name)
+                       self.template_name = template_name
                else:
                        self.template_name = None
                        self.template = None
        
                else:
                        self.template_name = None
                        self.template = None
        
+       def get_instance(self, context):
+               return self.compile_instance(self.object_pk, context)
+       
+       def get_template(self, context):
+               return self.compile_template(self.template_name, context)
+       
        def render(self, context):
        def render(self, context):
+               self.check_context(context)
+               
                if self.template_name is not None:
                if self.template_name is not None:
-                       template_name = self.template_name.resolve(context)
-                       self.compile_template(template_name)
+                       self.mark_rendered(context)
+                       return ''
                
                
-               if self.object_pk is not None:
-                       object_pk = self.object_pk.resolve(context)
-                       self.compile_instance(object_pk)
+               if self.object_pk is None:
+                       if settings.TEMPLATE_DEBUG:
+                               raise ValueError("NoneType is not a valid object_pk value")
+                       self.mark_rendered(context)
+                       return settings.TEMPLATE_STRING_IF_INVALID
                
                
-               return super(EmbedNode, self).render(context)
+               instance = self.compile_instance(self.object_pk.resolve(context))
+               
+               return self.render_instance(context, instance)
 
 
 def get_embedded(self):
 
 
 def get_embedded(self):
@@ -127,8 +264,7 @@ setattr(ConstantEmbedNode, LOADED_TEMPLATE_ATTR, property(get_embedded))
 
 def do_embed(parser, token):
        """
 
 def do_embed(parser, token):
        """
-       The {% embed %} tag can be used in three ways:
-       {% embed as <varname> %} :: This sets which variable will be used to track embedding template names for the current context. Default: "embed"
+       The {% embed %} tag can be used in two ways:
        {% embed <app_label>.<model_name> with <template> %} :: Sets which template will be used to render a particular model.
        {% embed <app_label>.<model_name> <object_pk> [<argname>=<value> ...]%} :: Embeds the instance specified by the given parameters in the document with the previously-specified template. Any kwargs provided will be passed into the context of the template.
        """
        {% embed <app_label>.<model_name> with <template> %} :: Sets which template will be used to render a particular model.
        {% embed <app_label>.<model_name> <object_pk> [<argname>=<value> ...]%} :: Embeds the instance specified by the given parameters in the document with the previously-specified template. Any kwargs provided will be passed into the context of the template.
        """
@@ -137,9 +273,6 @@ def do_embed(parser, token):
        
        if len(args) < 2:
                raise template.TemplateSyntaxError('"%s" template tag must have at least three arguments.' % tag)
        
        if len(args) < 2:
                raise template.TemplateSyntaxError('"%s" template tag must have at least three arguments.' % tag)
-       elif len(args) == 3 and args[1] == "as":
-               parser._embedNodeVarName = args[2]
-               return template.defaulttags.CommentNode()
        else:
                if '.' not in args[1]:
                        raise template.TemplateSyntaxError('"%s" template tag expects the first argument to be of the form app_label.model' % tag)
        else:
                if '.' not in args[1]:
                        raise template.TemplateSyntaxError('"%s" template tag expects the first argument to be of the form app_label.model' % tag)
@@ -150,16 +283,14 @@ def do_embed(parser, token):
                except ContentType.DoesNotExist:
                        raise template.TemplateSyntaxError('"%s" template tag option "references" requires an argument of the form app_label.model which refers to an installed content type (see django.contrib.contenttypes)' % tag)
                
                except ContentType.DoesNotExist:
                        raise template.TemplateSyntaxError('"%s" template tag option "references" requires an argument of the form app_label.model which refers to an installed content type (see django.contrib.contenttypes)' % tag)
                
-               varname = getattr(parser, '_embedNodeVarName', 'embed')
-               
                if args[2] == "with":
                        if len(args) > 4:
                                raise template.TemplateSyntaxError('"%s" template tag may have no more than four arguments.' % tag)
                        
                        if args[3][0] in ['"', "'"] and args[3][0] == args[3][-1]:
                if args[2] == "with":
                        if len(args) > 4:
                                raise template.TemplateSyntaxError('"%s" template tag may have no more than four arguments.' % tag)
                        
                        if args[3][0] in ['"', "'"] and args[3][0] == args[3][-1]:
-                               return ConstantEmbedNode(ct, template_name=args[3], varname=varname)
+                               return ConstantEmbedNode(ct, template_name=args[3])
                        
                        
-                       return EmbedNode(ct, template_name=args[3], varname=varname)
+                       return EmbedNode(ct, template_name=args[3])
                
                object_pk = args[2]
                remaining_args = args[3:]
                
                object_pk = args[2]
                remaining_args = args[3:]
@@ -168,9 +299,14 @@ def do_embed(parser, token):
                        if '=' not in arg:
                                raise template.TemplateSyntaxError("Invalid keyword argument for '%s' template tag: %s" % (tag, arg))
                        k, v = arg.split('=')
                        if '=' not in arg:
                                raise template.TemplateSyntaxError("Invalid keyword argument for '%s' template tag: %s" % (tag, arg))
                        k, v = arg.split('=')
-                       kwargs[k] = v
+                       kwargs[k] = parser.compile_filter(v)
                
                
-               return EmbedNode(ct, object_pk=object_pk, varname=varname, kwargs=kwargs)
+               try:
+                       int(object_pk)
+               except ValueError:
+                       return EmbedNode(ct, object_pk=parser.compile_filter(object_pk), kwargs=kwargs)
+               else:
+                       return ConstantEmbedNode(ct, object_pk=object_pk, kwargs=kwargs)
 
 
 register.tag('embed', do_embed)
\ No newline at end of file
 
 
 register.tag('embed', do_embed)
\ No newline at end of file
index 874f62f..b79534f 100644 (file)
--- a/tests.py
+++ b/tests.py
@@ -1,9 +1,100 @@
 from django.test import TestCase
 from django import template
 from django.conf import settings
 from django.test import TestCase
 from django import template
 from django.conf import settings
+from django.template import loader
+from django.template.loaders import cached
 from philo.exceptions import AncestorDoesNotExist
 from philo.models import Node, Page, Template
 from philo.contrib.penfield.models import Blog, BlogView, BlogEntry
 from philo.exceptions import AncestorDoesNotExist
 from philo.models import Node, Page, Template
 from philo.contrib.penfield.models import Blog, BlogView, BlogEntry
+import sys, traceback
+
+
class TemplateTestCase(TestCase):
    """Checks how ``{% embed %}`` interacts with extends, blocks and includes.

    Templates are served from the in-memory table returned by
    get_template_tests() through a cached loader that is installed only for
    the duration of the test.
    """
    fixtures = ['test_fixtures.json']

    def test_templates(self):
        "Tests to make sure that embed behaves with complex includes and extends"
        old_td = settings.TEMPLATE_DEBUG
        old_invalid = settings.TEMPLATE_STRING_IF_INVALID
        old_template_loaders = loader.template_source_loaders
        failures = []

        try:
            # Turn TEMPLATE_DEBUG off, because tests assume that.
            settings.TEMPLATE_DEBUG = False
            # Set TEMPLATE_STRING_IF_INVALID to a known string *before* the
            # expectations are built (get_template_tests reads it), so that
            # "rendered as invalid" is distinguishable from "rendered nothing".
            settings.TEMPLATE_STRING_IF_INVALID = 'INVALID'

            template_tests = self.get_template_tests()

            # Register our custom template loader. Shamelessly cribbed from
            # django core regressiontests.
            def test_template_loader(template_name, template_dirs=None):
                "A custom template loader that loads the unit-test templates."
                try:
                    return (template_tests[template_name][0], "test:%s" % template_name)
                except KeyError:
                    raise template.TemplateDoesNotExist(template_name)

            cache_loader = cached.Loader(('test_template_loader',))
            cache_loader._cached_loaders = (test_template_loader,)
            loader.template_source_loaders = [cache_loader]

            # Run tests in a stable order so failure reports are reproducible.
            for name, (_, context_dict, result) in sorted(template_tests.items()):
                try:
                    test_template = loader.get_template(name)
                    output = test_template.render(template.Context(context_dict))
                except Exception:
                    exc_type, exc_value, exc_tb = sys.exc_info()
                    if exc_type != result:
                        tb = ''.join(traceback.format_exception(exc_type, exc_value, exc_tb))
                        failures.append("Template test %s -- FAILED. Got %s, exception: %s\n%s" % (name, exc_type, exc_value, tb))
                    continue
                if output != result:
                    failures.append("Template test %s -- FAILED. Expected %r, got %r" % (name, result, output))
        finally:
            # Always restore global state, even if setup or a render escapes.
            settings.TEMPLATE_DEBUG = old_td
            settings.TEMPLATE_STRING_IF_INVALID = old_invalid
            loader.template_source_loaders = old_template_loaders

        self.assertEqual(failures, [], "Tests failed:\n%s\n%s" % ('-'*70, ("\n%s\n" % ('-'*70)).join(failures)))

    def get_template_tests(self):
        """Return the table of template tests.

        SYNTAX --
        'template_name': ('template contents', 'context dict', 'expected string output' or Exception class)

        Expectations that involve invalid rendering use the value of
        settings.TEMPLATE_STRING_IF_INVALID at the time this is called.
        """
        # The templates embed blog pk 1 explicitly, so fetch exactly that row
        # rather than the (unordered) first row of the table.
        blog = Blog.objects.get(pk=1)
        invalid = settings.TEMPLATE_STRING_IF_INVALID
        return {
            # EMBED INCLUSION HANDLING

            'embed01': ('{{ embedded.title|safe }}', {'embedded': blog}, blog.title),
            # var1 and var2 are undefined here, so each renders as invalid.
            'embed02': ('{{ embedded.title|safe }}{{ var1 }}{{ var2 }}', {'embedded': blog}, '%s%s%s' % (blog.title, invalid, invalid)),
            'embed03': ('{{ embedded.title|safe }} is a lie!', {'embedded': blog}, '%s is a lie!' % blog.title),

            # Simple template structure with embed
            'simple01': ('{% embed penfield.blog with "embed01" %}{% embed penfield.blog 1 %}Simple{% block one %}{% endblock %}', {'blog': blog}, '%sSimple' % blog.title),
            'simple02': ('{% extends "simple01" %}', {}, '%sSimple' % blog.title),
            'simple03': ('{% embed penfield.blog with "embed000" %}', {}, invalid),
            'simple04': ('{% embed penfield.blog 1 %}', {}, invalid),

            # Kwargs
            'kwargs01': ('{% embed penfield.blog with "embed02" %}{% embed penfield.blog 1 var1="hi" var2=lo %}', {'lo': 'lo'}, '%shilo' % blog.title),

            # Filters/variables
            'filters01': ('{% embed penfield.blog with "embed02" %}{% embed penfield.blog 1 var1=hi|first var2=lo|slice:"3" %}', {'hi': ["These", "words"], 'lo': 'lower'}, '%sTheselow' % blog.title),
            'filters02': ('{% embed penfield.blog with "embed01" %}{% embed penfield.blog entry %}', {'entry': 1}, blog.title),

            # Blocky structure
            'block01': ('{% block one %}Hello{% endblock %}', {}, 'Hello'),
            'block02': ('{% extends "simple01" %}{% block one %}{% embed penfield.blog 1 %}{% endblock %}', {}, "%sSimple%s" % (blog.title, blog.title)),
            'block03': ('{% extends "simple01" %}{% embed penfield.blog with "embed03" %}{% block one %}{% embed penfield.blog 1 %}{% endblock %}', {}, "%sSimple%s is a lie!" % (blog.title, blog.title)),

            # Blocks and includes
            'block-include01': ('{% extends "simple01" %}{% embed penfield.blog with "embed03" %}{% block one %}{% include "simple01" %}{% embed penfield.blog 1 %}{% endblock %}', {}, "%sSimple%sSimple%s is a lie!" % (blog.title, blog.title, blog.title)),
            'block-include02': ('{% extends "simple01" %}{% block one %}{% include "simple04" %}{% embed penfield.blog with "embed03" %}{% include "simple04" %}{% embed penfield.blog 1 %}{% endblock %}', {}, "%sSimple%s%s is a lie!%s is a lie!" % (blog.title, blog.title, blog.title, blog.title)),
        }
 
 
 class NodeURLTestCase(TestCase):
 
 
 class NodeURLTestCase(TestCase):