Updated page admin/add form to use the same method for redirection as the 1.3 user...
[philo.git] / tests.py
index 81417ff..a0e0184 100644 (file)
--- a/tests.py
+++ b/tests.py
@@ -1,12 +1,17 @@
-from django.test import TestCase
+import sys
+import traceback
+
 from django import template
 from django.conf import settings
+from django.db import connection
 from django.template import loader
 from django.template.loaders import cached
+from django.test import TestCase
+from django.test.utils import setup_test_template_loader
+
+from philo.contrib.penfield.models import Blog, BlogView, BlogEntry
 from philo.exceptions import AncestorDoesNotExist
 from philo.models import Node, Page, Template
-from philo.contrib.penfield.models import Blog, BlogView, BlogEntry
-import sys, traceback
 
 
 class TemplateTestCase(TestCase):
@@ -16,19 +21,15 @@ class TemplateTestCase(TestCase):
                "Tests to make sure that embed behaves with complex includes and extends"
                template_tests = self.get_template_tests()
                
-               # Register our custom template loader. Shamelessly cribbed from django core regressiontests.
-               def test_template_loader(template_name, template_dirs=None):
-                       "A custom template loader that loads the unit-test templates."
-                       try:
-                               return (template_tests[template_name][0] , "test:%s" % template_name)
-                       except KeyError:
-                               raise template.TemplateDoesNotExist, template_name
-               
-               cache_loader = cached.Loader(('test_template_loader',))
-               cache_loader._cached_loaders = (test_template_loader,)
+               # Register our custom template loader. Shamelessly cribbed from django/tests/regressiontests/templates/tests.py:384.
+               cache_loader = setup_test_template_loader(
+                       dict([(name, t[0]) for name, t in template_tests.iteritems()]),
+                       use_cached_loader=True,
+               )
                
-               old_template_loaders = loader.template_source_loaders
-               loader.template_source_loaders = [cache_loader]
+               failures = []
+               tests = template_tests.items()
+               tests.sort()
                
                # Turn TEMPLATE_DEBUG off, because tests assume that.
                old_td, settings.TEMPLATE_DEBUG = settings.TEMPLATE_DEBUG, False
@@ -37,10 +38,8 @@ class TemplateTestCase(TestCase):
                old_invalid = settings.TEMPLATE_STRING_IF_INVALID
                expected_invalid_str = 'INVALID'
                
-               failures = []
-               
                # Run tests
-               for name, vals in template_tests.items():
+               for name, vals in tests:
                        xx, context, result = vals
                        try:
                                test_template = loader.get_template(name)
@@ -95,6 +94,13 @@ class TemplateTestCase(TestCase):
                        # Blocks and includes
                        'block-include01': ('{% extends "simple01" %}{% embed penfield.blog with "embed03" %}{% block one %}{% include "simple01" %}{% embed penfield.blog 1 %}{% endblock %}', {}, "%sSimple%sSimple%s is a lie!" % (blog.title, blog.title, blog.title)),
                        'block-include02': ('{% extends "simple01" %}{% block one %}{% include "simple04" %}{% embed penfield.blog with "embed03" %}{% include "simple04" %}{% embed penfield.blog 1 %}{% endblock %}', {}, "%sSimple%s%s is a lie!%s is a lie!" % (blog.title, blog.title, blog.title, blog.title)),
+                       
+                       # Tests for more complex situations...
+                       'complex01': ('{% block one %}{% endblock %}complex{% block two %}{% endblock %}', {}, 'complex'),
+                       'complex02': ('{% extends "complex01" %}', {}, 'complex'),
+                       'complex03': ('{% extends "complex02" %}{% embed penfield.blog with "embed01" %}', {}, 'complex'),
+                       'complex04': ('{% extends "complex03" %}{% block one %}{% embed penfield.blog 1 %}{% endblock %}', {}, '%scomplex' % blog.title),
+                       'complex05': ('{% extends "complex03" %}{% block one %}{% include "simple04" %}{% endblock %}', {}, '%scomplex' % blog.title),
                }
 
 
@@ -110,25 +116,25 @@ class NodeURLTestCase(TestCase):
                        command.handle(all_apps=True)
                
                self.templates = [
-                               ("{% node_url %}", "/root/never/"),
-                               ("{% node_url for node2 %}", "/root/blog/"),
-                               ("{% node_url as hello %}<p>{{ hello|slice:'1:' }}</p>", "<p>root/never/</p>"),
-                               ("{% node_url for nodes|first %}", "/root/never/"),
+                               ("{% node_url %}", "/root/second/"),
+                               ("{% node_url for node2 %}", "/root/second2/"),
+                               ("{% node_url as hello %}<p>{{ hello|slice:'1:' }}</p>", "<p>root/second/</p>"),
+                               ("{% node_url for nodes|first %}", "/root/"),
                                ("{% node_url with entry %}", settings.TEMPLATE_STRING_IF_INVALID),
-                               ("{% node_url with entry for node2 %}", "/root/blog/2010/10/20/first-entry"),
-                               ("{% node_url with tag for node2 %}", "/root/blog/tags/test-tag/"),
-                               ("{% node_url with date for node2 %}", "/root/blog/2010/10/20"),
-                               ("{% node_url entries_by_day year=date|date:'Y' month=date|date:'m' day=date|date:'d' for node2 as goodbye %}<em>{{ goodbye|upper }}</em>", "<em>/ROOT/BLOG/2010/10/20</em>"),
-                               ("{% node_url entries_by_month year=date|date:'Y' month=date|date:'m' for node2 %}", "/root/blog/2010/10"),
-                               ("{% node_url entries_by_year year=date|date:'Y' for node2 %}", "/root/blog/2010/"),
+                               ("{% node_url with entry for node2 %}", "/root/second2/2010/10/20/first-entry"),
+                               ("{% node_url with tag for node2 %}", "/root/second2/tags/test-tag/"),
+                               ("{% node_url with date for node2 %}", "/root/second2/2010/10/20"),
+                               ("{% node_url entries_by_day year=date|date:'Y' month=date|date:'m' day=date|date:'d' for node2 as goodbye %}<em>{{ goodbye|upper }}</em>", "<em>/ROOT/SECOND2/2010/10/20</em>"),
+                               ("{% node_url entries_by_month year=date|date:'Y' month=date|date:'m' for node2 %}", "/root/second2/2010/10"),
+                               ("{% node_url entries_by_year year=date|date:'Y' for node2 %}", "/root/second2/2010/"),
                ]
                
                nodes = Node.objects.all()
                blog = Blog.objects.all()[0]
                
                self.context = template.Context({
-                       'node': nodes[0],
-                       'node2': nodes[1],
+                       'node': nodes.get(slug='second'),
+                       'node2': nodes.get(slug='second2'),
                        'nodes': nodes,
                        'entry': BlogEntry.objects.all()[0],
                        'tag': blog.entry_tags.all()[0],
@@ -149,52 +155,71 @@ class TreePathTestCase(TestCase):
                        command = Command()
                        command.handle(all_apps=True)
        
-       def test_has_ancestor(self):
+       def assertQueryLimit(self, max, expected_result, *args, **kwargs):
+               # As a rough measure of efficiency, limit the number of queries required for a given operation.
+               settings.DEBUG = True
+               call = kwargs.pop('callable', Node.objects.get_with_path)
+               try:
+                       queries = len(connection.queries)
+                       if isinstance(expected_result, type) and issubclass(expected_result, Exception):
+                               self.assertRaises(expected_result, call, *args, **kwargs)
+                       else:
+                               self.assertEqual(call(*args, **kwargs), expected_result)
+                       queries = len(connection.queries) - queries
+                       if queries > max:
+                               raise AssertionError('"%d" unexpectedly not less than or equal to "%s"' % (queries, max))
+               finally:
+                       settings.DEBUG = False
+       
+       def test_get_with_path(self):
                root = Node.objects.get(slug='root')
                third = Node.objects.get(slug='third')
-               r1 = Node.objects.get(slug='recursive1')
-               r2 = Node.objects.get(slug='recursive2')
-               pr1 = Node.objects.get(slug='postrecursive1')
-               
-               # Simple case: straight path
-               self.assertEqual(third.has_ancestor(root), True)
-               self.assertEqual(root.has_ancestor(root), False)
-               self.assertEqual(root.has_ancestor(None), True)
-               self.assertEqual(third.has_ancestor(None), True)
-               self.assertEqual(root.has_ancestor(root, inclusive=True), True)
-               
-               # Recursive case
-               self.assertEqual(r1.has_ancestor(r1), True)
-               self.assertEqual(r1.has_ancestor(r2), True)
-               self.assertEqual(r2.has_ancestor(r1), True)
-               self.assertEqual(r2.has_ancestor(None), False)
-               
-               # Post-recursive case
-               self.assertEqual(pr1.has_ancestor(r1), True)
-               self.assertEqual(pr1.has_ancestor(pr1), False)
-               self.assertEqual(pr1.has_ancestor(pr1, inclusive=True), True)
-               self.assertEqual(pr1.has_ancestor(None), False)
-               self.assertEqual(pr1.has_ancestor(root), False)
+               second2 = Node.objects.get(slug='second2')
+               fifth = Node.objects.get(slug='fifth')
+               e = Node.DoesNotExist
+               
+               # Empty segments
+               self.assertQueryLimit(0, root, '', root=root)
+               self.assertQueryLimit(0, e, '')
+               self.assertQueryLimit(0, (root, None), '', root=root, absolute_result=False)
+               
+               # Absolute result
+               self.assertQueryLimit(1, third, 'root/second/third')
+               self.assertQueryLimit(1, third, 'second/third', root=root)
+               self.assertQueryLimit(1, third, 'root//////second/third///')
+               
+               self.assertQueryLimit(1, e, 'root/secont/third')
+               self.assertQueryLimit(1, e, 'second/third')
+               
+               # Non-absolute result (binary search)
+               self.assertQueryLimit(2, (second2, 'sub/path/tail'), 'root/second2/sub/path/tail', absolute_result=False)
+               self.assertQueryLimit(3, (second2, 'sub/'), 'root/second2/sub/', absolute_result=False)
+               self.assertQueryLimit(2, e, 'invalid/path/1/2/3/4/5/6/7/8/9/1/2/3/4/5/6/7/8/9/0', absolute_result=False)
+               self.assertQueryLimit(1, (root, None), 'root', absolute_result=False)
+               self.assertQueryLimit(2, (second2, None), 'root/second2', absolute_result=False)
+               self.assertQueryLimit(3, (third, None), 'root/second/third', absolute_result=False)
+               
+               # with root != None
+               self.assertQueryLimit(1, (second2, None), 'second2', root=root, absolute_result=False)
+               self.assertQueryLimit(2, (third, None), 'second/third', root=root, absolute_result=False)
+               
+               # Preserve trailing slash
+               self.assertQueryLimit(2, (second2, 'sub/path/tail/'), 'root/second2/sub/path/tail/', absolute_result=False)
+               
+               # Speed increase for leaf nodes - should this be tested?
+               self.assertQueryLimit(1, (fifth, 'sub/path/tail/len/five'), 'root/second/third/fourth/fifth/sub/path/tail/len/five', absolute_result=False)
        
        def test_get_path(self):
                root = Node.objects.get(slug='root')
+               root2 = Node.objects.get(slug='root')
                third = Node.objects.get(slug='third')
-               r1 = Node.objects.get(slug='recursive1')
-               r2 = Node.objects.get(slug='recursive2')
-               pr1 = Node.objects.get(slug='postrecursive1')
-               
-               # Simple case: straight path to None
-               self.assertEqual(root.get_path(), 'root')
-               self.assertEqual(third.get_path(), 'root/never/more/second/third')
-               
-               # Recursive case: Looped path to root None
-               self.assertEqual(r1.get_path(), u'\u2026/recursive1/recursive2/recursive3/recursive1')
-               self.assertEqual(pr1.get_path(), u'\u2026/recursive3/recursive1/recursive2/recursive3/postrecursive1')
-               
-               # Simple error case: straight invalid path
-               self.assertRaises(AncestorDoesNotExist, root.get_path, root=third)
-               self.assertRaises(AncestorDoesNotExist, third.get_path, root=pr1)
-               
-               # Recursive error case
-               self.assertRaises(AncestorDoesNotExist, r1.get_path, root=root)
-               self.assertRaises(AncestorDoesNotExist, pr1.get_path, root=third)
\ No newline at end of file
+               second2 = Node.objects.get(slug='second2')
+               fifth = Node.objects.get(slug='fifth')
+               e = AncestorDoesNotExist
+               
+               self.assertQueryLimit(0, 'root', callable=root.get_path)
+               self.assertQueryLimit(0, '', root2, callable=root.get_path)
+               self.assertQueryLimit(1, 'root/second/third', callable=third.get_path)
+               self.assertQueryLimit(1, 'second/third', root, callable=third.get_path)
+               self.assertQueryLimit(1, e, third, callable=second2.get_path)
+               self.assertQueryLimit(1, '? - ?', root, ' - ', 'title', callable=third.get_path)