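Added support for recursive (cyclic) trees, i.e. recursion checks in TreeModel.has_ancestor and TreeModel.get_path that prevent infinite loops when a node's parent chain loops back on itself.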
author Stephen Burrows <stephen.r.burrows@gmail.com>
Wed, 10 Nov 2010 16:10:53 +0000 (11:10 -0500)
committer Stephen Burrows <stephen.r.burrows@gmail.com>
Wed, 10 Nov 2010 16:10:53 +0000 (11:10 -0500)
fixtures/test_fixtures.json
models/base.py
tests.py

index 18f6962..14f5a27 100644 (file)
             ]
         }
     }, 
+    {
+        "pk": 4, 
+        "model": "philo.node", 
+        "fields": {
+            "view_object_id": 1, 
+            "slug": "more", 
+            "parent": 1, 
+            "view_content_type": [
+                "philo", 
+                "page"
+            ]
+        }
+    }, 
+    {
+        "pk": 5, 
+        "model": "philo.node", 
+        "fields": {
+            "view_object_id": 1, 
+            "slug": "second", 
+            "parent": 4, 
+            "view_content_type": [
+                "philo", 
+                "page"
+            ]
+        }
+    }, 
+    {
+        "pk": 6, 
+        "model": "philo.node", 
+        "fields": {
+            "view_object_id": 1, 
+            "slug": "third", 
+            "parent": 5, 
+            "view_content_type": [
+                "philo", 
+                "page"
+            ]
+        }
+    }, 
+    {
+        "pk": 7, 
+        "model": "philo.node", 
+        "fields": {
+            "view_object_id": 1, 
+            "slug": "recursive1", 
+            "parent": 9, 
+            "view_content_type": [
+                "philo", 
+                "page"
+            ]
+        }
+    }, 
+    {
+        "pk": 8, 
+        "model": "philo.node", 
+        "fields": {
+            "view_object_id": 1, 
+            "slug": "recursive2", 
+            "parent": 7, 
+            "view_content_type": [
+                "philo", 
+                "page"
+            ]
+        }
+    }, 
+    {
+        "pk": 9, 
+        "model": "philo.node", 
+        "fields": {
+            "view_object_id": 1, 
+            "slug": "recursive3", 
+            "parent": 8, 
+            "view_content_type": [
+                "philo", 
+                "page"
+            ]
+        }
+    }, 
+    {
+        "pk": 10, 
+        "model": "philo.node", 
+        "fields": {
+            "view_object_id": 1, 
+            "slug": "postrecursive1", 
+            "parent": 9, 
+            "view_content_type": [
+                "philo", 
+                "page"
+            ]
+        }
+    }, 
     {
         "pk": 1, 
         "model": "philo.redirect", 
index ae15d16..2338d72 100644 (file)
@@ -368,24 +368,47 @@ class TreeModel(models.Model):
        parent = models.ForeignKey('self', related_name='children', null=True, blank=True)
        slug = models.SlugField(max_length=255)
        
-       def has_ancestor(self, ancestor):
-               parent = self
+       def has_ancestor(self, ancestor, inclusive=False):
+               """
+               Return True if ``ancestor`` is reached by following ``parent`` links upward.
+               
+               With ``inclusive=False`` (the default) the walk starts at ``self.parent``,
+               so ``self`` only counts if the parent chain loops back to it; with
+               ``inclusive=True`` the walk starts at ``self``.
+               
+               Passing ``ancestor=None`` asks whether the chain ends at a node with no
+               parent. If the chain loops back on itself without meeting ``ancestor``,
+               the walk stops and False is returned instead of looping forever.
+               """
+               # NOTE(review): the previous version always started the walk at ``self``;
+               # callers that relied on ``node.has_ancestor(node)`` being True now need
+               # ``inclusive=True``. Confirm every call site.
+               if inclusive:
+                       parent = self
+               else:
+                       parent = self.parent
+               
+               # Nodes already visited, used to detect a loop in the parent chain.
+               parents = []
+               
                while parent:
                        if parent == ancestor:
                                return True
+                       # If we've found this parent before, the path is recursive and ancestor wasn't on it.
+                       if parent in parents:
+                               return False
+                       parents.append(parent)
                        parent = parent.parent
+               # The loop only ends here once ``parent`` is empty (the top of the chain),
+               # so this matches when the caller asked for ``ancestor=None``.
+               if parent == ancestor:
+                       return True
                return False
        
        def get_path(self, root=None, pathsep='/', field='slug'):
-               if root is not None and not self.has_ancestor(root):
-                       raise AncestorDoesNotExist(root)
-               
-               path = getattr(self, field, '?')
+               """
+               Return this node's path: the ``field`` value of each node joined by ``pathsep``.
+               
+               The path runs top-down from just below ``root`` (``root`` itself is excluded)
+               to this node. With ``root=None`` it starts at the first node whose ``parent``
+               is empty. Nodes lacking ``field`` contribute '?'.
+               
+               If the parent chain loops back on itself and ``root`` is None, the walk is cut
+               where the loop closes and the result is prefixed with an ellipsis character
+               (U+2026) followed by ``pathsep``.
+               
+               Raises AncestorDoesNotExist if ``root`` is given but never reached, either
+               because the chain ends or because it loops without passing through ``root``.
+               """
+               # NOTE(review): get_path(root=self) now raises unless self lies on a loop;
+               # the previous version returned the full path in that case. Confirm intended.
                parent = self.parent
+               # Nodes collected so far, starting at self and moving upward.
+               parents = [self]
+               
+               def compile_path(parents):
+                       # Expects nodes ordered top-down (callers pass the reversed list).
+                       return pathsep.join([getattr(parent, field, '?') for parent in parents])
+               
                while parent and parent != root:
-                       path = getattr(parent, field, '?') + pathsep + path
+                       # A node seen twice means the chain loops without having reached root.
+                       if parent in parents:
+                               if root is not None:
+                                       raise AncestorDoesNotExist(root)
+                               # Repeat the node that closes the loop, then mark the cut with an ellipsis.
+                               parents.append(parent)
+                               return u"\u2026%s%s" % (pathsep, compile_path(parents[::-1]))
+                       parents.append(parent)
                        parent = parent.parent
+               
+               # The chain ended at a top-level node without passing through root.
+               if root is not None and parent is None:
+                       raise AncestorDoesNotExist(root)
+               
+               return compile_path(parents[::-1])
        path = property(get_path)
        
        def __unicode__(self):
index d9f743f..874f62f 100644 (file)
--- a/tests.py
+++ b/tests.py
@@ -1,6 +1,7 @@
 from django.test import TestCase
 from django import template
 from django.conf import settings
+from philo.exceptions import AncestorDoesNotExist
 from philo.models import Node, Page, Template
 from philo.contrib.penfield.models import Blog, BlogView, BlogEntry
 
@@ -16,20 +17,18 @@ class NodeURLTestCase(TestCase):
                        command = Command()
                        command.handle(all_apps=True)
                
-               self.templates = [template.Template(string) for string in
-                       [
-                               "{% node_url %}", # 0
-                               "{% node_url for node2 %}", # 1
-                               "{% node_url as hello %}<p>{{ hello|slice:'1:' }}</p>", # 2
-                               "{% node_url for nodes|first %}", # 3
-                               "{% node_url with entry %}", # 4
-                               "{% node_url with entry for node2 %}", # 5
-                               "{% node_url with tag for node2 %}", # 6
-                               "{% node_url with date for node2 %}", # 7
-                               "{% node_url entries_by_day year=date|date:'Y' month=date|date:'m' day=date|date:'d' for node2 as goodbye %}<em>{{ goodbye|upper }}</em>", # 8
-                               "{% node_url entries_by_month year=date|date:'Y' month=date|date:'m' for node2 %}", # 9
-                               "{% node_url entries_by_year year=date|date:'Y' for node2 %}", # 10
-                       ]
+               self.templates = [
+                               ("{% node_url %}", "/root/never/"),
+                               ("{% node_url for node2 %}", "/root/blog/"),
+                               ("{% node_url as hello %}<p>{{ hello|slice:'1:' }}</p>", "<p>root/never/</p>"),
+                               ("{% node_url for nodes|first %}", "/root/never/"),
+                               ("{% node_url with entry %}", settings.TEMPLATE_STRING_IF_INVALID),
+                               ("{% node_url with entry for node2 %}", "/root/blog/2010/10/20/first-entry"),
+                               ("{% node_url with tag for node2 %}", "/root/blog/tags/test-tag/"),
+                               ("{% node_url with date for node2 %}", "/root/blog/2010/10/20"),
+                               ("{% node_url entries_by_day year=date|date:'Y' month=date|date:'m' day=date|date:'d' for node2 as goodbye %}<em>{{ goodbye|upper }}</em>", "<em>/ROOT/BLOG/2010/10/20</em>"),
+                               ("{% node_url entries_by_month year=date|date:'Y' month=date|date:'m' for node2 %}", "/root/blog/2010/10"),
+                               ("{% node_url entries_by_year year=date|date:'Y' for node2 %}", "/root/blog/2010/"),
                ]
                
                nodes = Node.objects.all()
@@ -45,30 +44,65 @@ class NodeURLTestCase(TestCase):
                })
        
        def test_nodeurl(self):
-               for i, template in enumerate(self.templates):
-                       t = template.render(self.context)
-                       
-                       if i == 0:
-                               self.assertEqual(t, "/root/never/")
-                       elif i == 1:
-                               self.assertEqual(t, "/root/blog/")
-                       elif i == 2:
-                               self.assertEqual(t, "<p>root/never/</p>")
-                       elif i == 3:
-                               self.assertEqual(t, "/root/never/")
-                       elif i == 4:
-                               self.assertEqual(t, settings.TEMPLATE_STRING_IF_INVALID)
-                       elif i == 5:
-                               self.assertEqual(t, "/root/blog/2010/10/20/first-entry")
-                       elif i == 6:
-                               self.assertEqual(t, "/root/blog/tags/test-tag/")
-                       elif i == 7:
-                               self.assertEqual(t, "/root/blog/2010/10/20")
-                       elif i == 8:
-                               self.assertEqual(t, "<em>/ROOT/BLOG/2010/10/20</em>")
-                       elif i == 9:
-                               self.assertEqual(t, "/root/blog/2010/10")
-                       elif i == 10:
-                               self.assertEqual(t, "/root/blog/2010/")
-                       else:
-                               print "Rendered as:\n%s\n\n" % t
\ No newline at end of file
+               for string, result in self.templates:
+                       self.assertEqual(template.Template(string).render(self.context), result)
+
+class TreePathTestCase(TestCase):
+       """
+       Tests for TreeModel.has_ancestor and TreeModel.get_path using Node fixtures.
+       
+       The fixture provides a straight chain (root/never/more/second/third), a loop
+       (recursive1 -> recursive3 -> recursive2 -> recursive1, following ``parent``),
+       and postrecursive1, whose parent is recursive3 on that loop.
+       """
+       urls = 'philo.urls'
+       fixtures = ['test_fixtures.json']
+       
+       def setUp(self):
+               # Run South migrations when South is installed, as NodeURLTestCase does.
+               if 'south' in settings.INSTALLED_APPS:
+                       from south.management.commands.migrate import Command
+                       command = Command()
+                       command.handle(all_apps=True)
+       
+       def test_has_ancestor(self):
+               root = Node.objects.get(slug='root')
+               third = Node.objects.get(slug='third')
+               r1 = Node.objects.get(slug='recursive1')
+               r2 = Node.objects.get(slug='recursive2')
+               pr1 = Node.objects.get(slug='postrecursive1')
+               
+               # Simple case: straight path
+               self.assertEqual(third.has_ancestor(root), True)
+               self.assertEqual(root.has_ancestor(root), False)
+               self.assertEqual(root.has_ancestor(None), True)
+               self.assertEqual(third.has_ancestor(None), True)
+               self.assertEqual(root.has_ancestor(root, inclusive=True), True)
+               
+               # Recursive case: every node on the loop is an ancestor of every other,
+               # and the chain never ends at a parentless node.
+               self.assertEqual(r1.has_ancestor(r1), True)
+               self.assertEqual(r1.has_ancestor(r2), True)
+               self.assertEqual(r2.has_ancestor(r1), True)
+               self.assertEqual(r2.has_ancestor(None), False)
+               
+               # Post-recursive case: a node hanging off the loop
+               self.assertEqual(pr1.has_ancestor(r1), True)
+               self.assertEqual(pr1.has_ancestor(pr1), False)
+               self.assertEqual(pr1.has_ancestor(pr1, inclusive=True), True)
+               self.assertEqual(pr1.has_ancestor(None), False)
+               self.assertEqual(pr1.has_ancestor(root), False)
+       
+       def test_get_path(self):
+               root = Node.objects.get(slug='root')
+               third = Node.objects.get(slug='third')
+               r1 = Node.objects.get(slug='recursive1')
+               pr1 = Node.objects.get(slug='postrecursive1')
+               
+               # Simple case: straight path up to a parentless node (root=None)
+               self.assertEqual(root.get_path(), 'root')
+               self.assertEqual(third.get_path(), 'root/never/more/second/third')
+               
+               # Simple case with an explicit root: root itself is excluded from the path
+               self.assertEqual(third.get_path(root=root), 'never/more/second/third')
+               
+               # Recursive case: the chain loops and never reaches a parentless node, so
+               # the path is cut where the loop closes and prefixed with an ellipsis
+               self.assertEqual(r1.get_path(), u'\u2026/recursive1/recursive2/recursive3/recursive1')
+               self.assertEqual(pr1.get_path(), u'\u2026/recursive3/recursive1/recursive2/recursive3/postrecursive1')
+               
+               # Recursive case with a root that lies on the loop: the walk stops at root
+               self.assertEqual(pr1.get_path(root=r1), 'recursive2/recursive3/postrecursive1')
+               
+               # Simple error case: straight invalid path
+               self.assertRaises(AncestorDoesNotExist, root.get_path, root=third)
+               self.assertRaises(AncestorDoesNotExist, third.get_path, root=pr1)
+               
+               # Recursive error case: root is not on the (looping) chain
+               self.assertRaises(AncestorDoesNotExist, r1.get_path, root=root)
+               self.assertRaises(AncestorDoesNotExist, pr1.get_path, root=third)
\ No newline at end of file