Added methods and functions to support syncing embedded models in post-save.
author Stephen Burrows <stephen.r.burrows@gmail.com>
Mon, 27 Sep 2010 21:42:47 +0000 (17:42 -0400)
committer Stephen Burrows <stephen.r.burrows@gmail.com>
Mon, 4 Oct 2010 14:46:06 +0000 (10:46 -0400)
contrib/penfield/embed.py
contrib/penfield/templatetags/embed.py

index b279fc7..ea0199f 100644 (file)
@@ -32,7 +32,48 @@ class TemplateField(models.TextField):
                return Template(value)
 
 
+class Embed(models.Model):
+       """Link between an "embedder" object and an object it embeds.
+
+       Both sides are generic relations (ContentType + object id), so any model
+       can appear as either the embedder or the embedded object.
+       """
+       embedder_content_type = models.ForeignKey(ContentType, related_name="embedder_related")
+       embedder_object_id = models.PositiveIntegerField()
+       embedder = generic.GenericForeignKey("embedder_content_type", "embedder_object_id")
+       
+       embedded_content_type = models.ForeignKey(ContentType, related_name="embedded_related")
+       embedded_object_id = models.PositiveIntegerField()
+       embedded = generic.GenericForeignKey("embedded_content_type", "embedded_object_id")
+       
+       def delete(self, *args, **kwargs):
+               """Delete this Embed record.
+
+               Any positional/keyword arguments (e.g. Django's ``using``) are passed
+               through unchanged to ``Model.delete``.
+               """
+               # NOTE: Django does not call an overridden delete() for QuerySet.delete()
+               # or for cascading deletions, so cleanup placed here is skipped in those
+               # cases (including the QuerySet delete in sync_embedded_objects).
+               super(Embed, self).delete(*args, **kwargs)
+               # TODO: cycle through all the fields in the embedder and remove all
+               # references to the embedded object. Not implemented yet.
+       
+       def get_embed_tag(self):
+               """Return the embed tag that would create this instance.
+
+               The result has the form ``{% embed <app_label>.<model> <object_id> %}``.
+               """
+               ct = self.embedded_content_type
+               return "{%% embed %s.%s %s %%}" % (ct.app_label, ct.model, self.embedded_object_id)
+       
+       class Meta:
+               app_label = 'penfield'
+
+
+def sync_embedded_objects(model_instance, embedded_instances):
+       """Make the Embed rows for ``model_instance`` match ``embedded_instances``.
+
+       For every object in ``embedded_instances`` an Embed row with
+       ``model_instance`` as the embedder is fetched or created; every other Embed
+       row whose embedder is ``model_instance`` is deleted. Passing an empty
+       iterable therefore removes all of ``model_instance``'s embeds.
+
+       ``model_instance`` is expected to be saved already, since its primary key is
+       stored in the (non-nullable) ``embedder_object_id`` column.
+       """
+       model_instance_ct = ContentType.objects.get_for_model(model_instance)
+       
+       kept_embed_pks = []
+       for embedded_instance in embedded_instances:
+               embedded_instance_ct = ContentType.objects.get_for_model(embedded_instance)
+               # Use .pk rather than .id so models with a custom primary key work too.
+               embed, _ = Embed.objects.get_or_create(
+                       embedder_content_type=model_instance_ct,
+                       embedder_object_id=model_instance.pk,
+                       embedded_content_type=embedded_instance_ct,
+                       embedded_object_id=embedded_instance.pk,
+               )
+               kept_embed_pks.append(embed.pk)
+       
+       # Delete stale embeds: rows for this embedder that no longer correspond to
+       # any of the embedded instances. This is a QuerySet delete, so Embed.delete()
+       # is not invoked for these rows.
+       Embed.objects.filter(
+               embedder_content_type=model_instance_ct,
+               embedder_object_id=model_instance.pk,
+       ).exclude(pk__in=kept_embed_pks).delete()
+
+
 class EmbedField(TemplateField):
+       _embedded_instances = set()
+       
        def validate_template(self, template):
                """Check to be sure that the embedded instances and templates all exist."""
                for node in template.nodelist:
@@ -47,30 +88,15 @@ class EmbedField(TemplateField):
                                        embedded_template = loader.get_template(node.template_name)
                                        self.validate_template(embedded_template)
                                elif node.object_pk is not None:
-                                       embedded_instance = node.model.objects.get(pk=node.object_pk)
-       
-       def to_template(self, value):
-               return Template("{% load embed %}" + value)
-
-
-class Embed(models.Model):
-       embedder_embed_field = models.CharField(max_length=255)
-       
-       embedder_contenttype = models.ForeignKey(ContentType, related_name="embedder_related")
-       embedder_object_id = models.PositiveIntegerField()
-       embedder = generic.GenericForeignKey("embedder_contenttype", "embedder_object_id")
-       
-       embedded_contenttype = models.ForeignKey(ContentType, related_name="embedded_related")
-       embedded_object_id = models.PositiveIntegerField()
-       embedded = generic.GenericForeignKey("embedded_contenttype", "embedded_object_id")
+                                       self._embedded_instances.add(node.model.objects.get(pk=node.object_pk))
        
-       def delete(self):
-               # Unclear whether this would be called by a cascading deletion.
-               
-               super(Embed, self).delete()
+       #def to_template(self, value):
+       #       return Template("{% load embed %}" + value)
        
-       class Meta:
-               app_label = 'penfield'
+       def pre_save(self, model_instance, add):
+               """Copy the instances collected during validation onto ``model_instance``.
+
+               They are unioned into ``model_instance._embedded_instances``; presumably a
+               post-save hook passes them to sync_embedded_objects (that wiring is not
+               visible here). Returns the value to be saved for this field, as Django
+               requires of ``Field.pre_save``.
+               """
+               # NOTE(review): the field's ``_embedded_instances`` is a class-level set
+               # shared by every EmbedField and is not cleared anywhere visible, so this
+               # may attach instances collected while validating other objects.
+               if not hasattr(model_instance, '_embedded_instances'):
+                       model_instance._embedded_instances = set()
+               model_instance._embedded_instances |= self._embedded_instances
+               # Field.pre_save must return the value to store; returning nothing would
+               # make Django save this field as None.
+               return super(EmbedField, self).pre_save(model_instance, add)
 
 
 class Test(models.Model):
index abdb1de..f8ab3f0 100644 (file)
@@ -9,7 +9,6 @@ register = template.Library()
 class EmbedNode(template.Node):
        def __init__(self, model, varname, object_pk=None, template_name=None):
                assert template_name is not None or object_pk is not None
-               
                app_label, model = model.split('.')
                self.model = ContentType.objects.get(app_label=app_label, model=model).model_class()
                self.varname = varname