From 28181e4d0556369aa673ae90f10e1c890797bde4 Mon Sep 17 00:00:00 2001 From: Stephen Burrows Date: Mon, 27 Sep 2010 17:42:47 -0400 Subject: [PATCH] Added methods and functions to support syncing embedded models in post-save. --- contrib/penfield/embed.py | 70 ++++++++++++++++++-------- contrib/penfield/templatetags/embed.py | 1 - 2 files changed, 48 insertions(+), 23 deletions(-) diff --git a/contrib/penfield/embed.py b/contrib/penfield/embed.py index b279fc7..ea0199f 100644 --- a/contrib/penfield/embed.py +++ b/contrib/penfield/embed.py @@ -32,7 +32,48 @@ class TemplateField(models.TextField): return Template(value) +class Embed(models.Model): + embedder_content_type = models.ForeignKey(ContentType, related_name="embedder_related") + embedder_object_id = models.PositiveIntegerField() + embedder = generic.GenericForeignKey("embedder_content_type", "embedder_object_id") + + embedded_content_type = models.ForeignKey(ContentType, related_name="embedded_related") + embedded_object_id = models.PositiveIntegerField() + embedded = generic.GenericForeignKey("embedded_content_type", "embedded_object_id") + + def delete(self): + # Unclear whether this would be called by a cascading deletion. + super(Embed, self).delete() + # Cycle through all the fields in the embedder and remove all references to the embedded object. + + def get_embed_tag(self): + """Convenience function to construct the embed tag that would create this instance.""" + ct = self.embedded_content_type + return "{%% embed %s.%s %s %%}" % (ct.app_label, ct.model, self.embedded_object_id) + + class Meta: + app_label = 'penfield' + + +def sync_embedded_objects(model_instance, embedded_instances): + # First, fetch all current embeds. + model_instance_ct = ContentType.objects.get_for_model(model_instance) + current_embeds = Embed.objects.filter() + + new_embed_pks = [] + for embedded_instance in embedded_instances: + embedded_instance_ct = ContentType.objects.get_for_model(embedded_instance) + new_embed = Embed.objects.get_or_create(embedder_content_type=model_instance_ct, embedder_object_id=model_instance.id, embedded_content_type=embedded_instance_ct, embedded_object_id=embedded_instance.id)[0] + new_embed_pks.append(new_embed.pk) + + # Then, delete all embed objects related to this model instance which do not relate + # to one of the embedded instances. + Embed.objects.filter(embedder_content_type=model_instance_ct, embedder_object_id=model_instance.id).exclude(pk__in=new_embed_pks).delete() + + class EmbedField(TemplateField): + _embedded_instances = set() + def validate_template(self, template): """Check to be sure that the embedded instances and templates all exist.""" for node in template.nodelist: @@ -47,30 +88,15 @@ class EmbedField(TemplateField): embedded_template = loader.get_template(node.template_name) self.validate_template(embedded_template) elif node.object_pk is not None: - embedded_instance = node.model.objects.get(pk=node.object_pk) - - def to_template(self, value): - return Template("{% load embed %}" + value) - - -class Embed(models.Model): - embedder_embed_field = models.CharField(max_length=255) - - embedder_contenttype = models.ForeignKey(ContentType, related_name="embedder_related") - embedder_object_id = models.PositiveIntegerField() - embedder = generic.GenericForeignKey("embedder_contenttype", "embedder_object_id") - - embedded_contenttype = models.ForeignKey(ContentType, related_name="embedded_related") - embedded_object_id = models.PositiveIntegerField() - embedded = generic.GenericForeignKey("embedded_contenttype", "embedded_object_id") + self._embedded_instances.add(node.model.objects.get(pk=node.object_pk)) - def delete(self): - # Unclear whether this would be called by a cascading deletion. - - super(Embed, self).delete() + #def to_template(self, value): + # return Template("{% load embed %}" + value) - class Meta: - app_label = 'penfield' + def pre_save(self, model_instance, add): + if not hasattr(model_instance, '_embedded_instances'): + model_instance._embedded_instances = set() + model_instance._embedded_instances |= self._embedded_instances class Test(models.Model): diff --git a/contrib/penfield/templatetags/embed.py b/contrib/penfield/templatetags/embed.py index abdb1de..f8ab3f0 100644 --- a/contrib/penfield/templatetags/embed.py +++ b/contrib/penfield/templatetags/embed.py @@ -9,7 +9,6 @@ register = template.Library() class EmbedNode(template.Node): def __init__(self, model, varname, object_pk=None, template_name=None): assert template_name is not None or object_pk is not None - app_label, model = model.split('.') self.model = ContentType.objects.get(app_label=app_label, model=model).model_class() self.varname = varname -- 2.20.1