Merge branch 'master' of git://github.com/melinath/philo
[philo.git] / models / base.py
index 20693b7..af1e880 100644 (file)
@@ -5,7 +5,7 @@ from django.contrib.contenttypes import generic
 from django.core.exceptions import ObjectDoesNotExist
 from django.core.validators import RegexValidator
 from django.utils import simplejson as json
-from django.utils.encoding import smart_str
+from django.utils.encoding import force_unicode
 from philo.exceptions import AncestorDoesNotExist
 from philo.models.fields import JSONField
 from philo.utils import ContentTypeRegistryLimiter, ContentTypeSubclassLimiter
@@ -45,6 +45,9 @@ def register_value_model(model):
        value_content_type_limiter.register_class(model)
 
 
+register_value_model(Tag)
+
+
 def unregister_value_model(model):
        value_content_type_limiter.unregister_class(model)
 
@@ -52,10 +55,6 @@ def unregister_value_model(model):
 class AttributeValue(models.Model):
        attribute_set = generic.GenericRelation('Attribute', content_type_field='value_content_type', object_id_field='value_object_id')
        
-       @property
-       def attribute(self):
-               return self.attribute_set.all()[0]
-       
        def set_value(self, value):
                raise NotImplementedError
        
@@ -78,10 +77,10 @@ attribute_value_limiter = ContentTypeSubclassLimiter(AttributeValue)
 
 
 class JSONValue(AttributeValue):
-       value = JSONField(verbose_name='Value (JSON)', help_text='This value must be valid JSON.', default='null')
+       value = JSONField(verbose_name='Value (JSON)', help_text='This value must be valid JSON.', default='null', db_index=True)
        
        def __unicode__(self):
-               return smart_str(self.value)
+               return force_unicode(self.value)
        
        def value_formfields(self):
                kwargs = {'initial': self.value_json}
@@ -101,7 +100,7 @@ class JSONValue(AttributeValue):
 
 class ForeignKeyValue(AttributeValue):
        content_type = models.ForeignKey(ContentType, limit_choices_to=value_content_type_limiter, verbose_name='Value type', null=True, blank=True)
-       object_id = models.PositiveIntegerField(verbose_name='Value ID', null=True, blank=True)
+       object_id = models.PositiveIntegerField(verbose_name='Value ID', null=True, blank=True, db_index=True)
        value = generic.GenericForeignKey()
        
        def value_formfields(self):
@@ -140,41 +139,48 @@ class ManyToManyValue(AttributeValue):
        content_type = models.ForeignKey(ContentType, limit_choices_to=value_content_type_limiter, verbose_name='Value type', null=True, blank=True)
        values = models.ManyToManyField(ForeignKeyValue, blank=True, null=True)
        
-       def get_object_id_list(self):
-               if not self.values.count():
-                       return []
-               else:
-                       return self.values.values_list('object_id', flat=True)
-       
-       def get_value(self):
-               if self.content_type is None:
-                       return None
-               
-               return self.content_type.model_class()._default_manager.filter(id__in=self.get_object_id_list())
+       def get_object_ids(self):
+               return self.values.values_list('object_id', flat=True)
+       object_ids = property(get_object_ids)
        
        def set_value(self, value):
-               # Value is probably a queryset - but allow any iterable.
+               # Value must be a queryset. Watch out for ModelMultipleChoiceField;
+               # it returns its value as a list if empty.
                
-               # These lines shouldn't be necessary; however, if value is an EmptyQuerySet,
-               # the code (specifically the object_id__in query) won't work without them. Unclear why...
-               if not value:
-                       value = []
+               self.content_type = ContentType.objects.get_for_model(value.model)
                
                # Before we can fiddle with the many-to-many to foreignkeyvalues, we need
                # a pk.
                if self.pk is None:
                        self.save()
                
-               if isinstance(value, models.query.QuerySet):
-                       value = value.values_list('id', flat=True)
+               object_ids = value.values_list('id', flat=True)
                
-               self.values.filter(~models.Q(object_id__in=value)).delete()
-               current = self.get_object_id_list()
+               # These lines shouldn't be necessary; however, if object_ids is an EmptyQuerySet,
+               # the code (specifically the object_id__in query) won't work without them. Unclear why...
+               # TODO: is this still the case?
+               if not object_ids:
+                       self.values.all().delete()
+               else:
+                       self.values.exclude(object_id__in=object_ids, content_type=self.content_type).delete()
+                       
+                       current_ids = self.object_ids
+                       
+                       for object_id in object_ids:
+                               if object_id in current_ids:
+                                       continue
+                               self.values.create(content_type=self.content_type, object_id=object_id)
+       
+       def get_value(self):
+               if self.content_type is None:
+                       return None
                
-               for v in value:
-                       if v in current:
-                               continue
-                       self.values.create(content_type=self.content_type, object_id=v)
+               # HACK to be safely explicit until http://code.djangoproject.com/ticket/15145 is resolved
+               object_ids = self.object_ids
+               manager = self.content_type.model_class()._default_manager
+               if not object_ids:
+                       return manager.none()
+               return manager.filter(id__in=self.object_ids)
        
        value = property(get_value, set_value)
        
@@ -184,7 +190,7 @@ class ManyToManyValue(AttributeValue):
                
                if self.content_type:
                        kwargs = {
-                               'initial': self.get_object_id_list(),
+                               'initial': self.object_ids,
                                'required': False,
                                'queryset': self.content_type.model_class()._default_manager.all()
                        }
@@ -198,7 +204,9 @@ class ManyToManyValue(AttributeValue):
                        self.values.clear()
                        self.content_type = ct
                else:
-                       value = kwargs.get('value', self.content_type.model_class()._default_manager.none())
+                       value = kwargs.get('value', None)
+                       if not value:
+                               value = self.content_type.model_class()._default_manager.none()
                        self.set_value(value)
        construct_instance.alters_data = True
        
@@ -208,14 +216,14 @@ class ManyToManyValue(AttributeValue):
 
 class Attribute(models.Model):
        entity_content_type = models.ForeignKey(ContentType, related_name='attribute_entity_set', verbose_name='Entity type')
-       entity_object_id = models.PositiveIntegerField(verbose_name='Entity ID')
+       entity_object_id = models.PositiveIntegerField(verbose_name='Entity ID', db_index=True)
        entity = generic.GenericForeignKey('entity_content_type', 'entity_object_id')
        
        value_content_type = models.ForeignKey(ContentType, related_name='attribute_value_set', limit_choices_to=attribute_value_limiter, verbose_name='Value type', null=True, blank=True)
-       value_object_id = models.PositiveIntegerField(verbose_name='Value ID', null=True, blank=True)
+       value_object_id = models.PositiveIntegerField(verbose_name='Value ID', null=True, blank=True, db_index=True)
        value = generic.GenericForeignKey('value_content_type', 'value_object_id')
        
-       key = models.CharField(max_length=255, validators=[RegexValidator("\w+")], help_text="Must contain one or more alphanumeric characters or underscores.")
+       key = models.CharField(max_length=255, validators=[RegexValidator("\w+")], help_text="Must contain one or more alphanumeric characters or underscores.", db_index=True)
        
        def __unicode__(self):
                return u'"%s": %s' % (self.key, self.value)
@@ -263,9 +271,9 @@ class EntityOptions(object):
 
 class EntityBase(models.base.ModelBase):
        def __new__(cls, name, bases, attrs):
+               entity_meta = attrs.pop('EntityMeta', None)
                new = super(EntityBase, cls).__new__(cls, name, bases, attrs)
-               entity_options = attrs.pop('EntityMeta', None)
-               setattr(new, '_entity_meta', EntityOptions(entity_options))
+               new.add_to_class('_entity_meta', EntityOptions(entity_meta))
                entity_class_prepared.send(sender=new)
                return new
 
@@ -303,11 +311,6 @@ class TreeManager(models.Manager):
                # tree structure won't be that deep.
                segments = path.split(pathsep)
                
-               # Check for a trailing pathsep so we can restore it later.
-               trailing_pathsep = False
-               if segments[-1] == '':
-                       trailing_pathsep = True
-               
                # Clean out blank segments. Handles multiple consecutive pathseps.
                while True:
                        try:
@@ -339,12 +342,6 @@ class TreeManager(models.Manager):
                        
                        return kwargs
                
-               def build_path(segments):
-                       path = pathsep.join(segments)
-                       if trailing_pathsep and segments and segments[-1] != '':
-                               path += pathsep
-                       return path
-               
                def find_obj(segments, depth, deepest_found=None):
                        if deepest_found is None:
                                deepest_level = 0
@@ -365,7 +362,7 @@ class TreeManager(models.Manager):
                                if deepest_level == depth:
                                        # This should happen if nothing is found with any part of the given path.
                                        if root is not None and deepest_found is None:
-                                               return root, build_path(segments)
+                                               return root, pathsep.join(segments)
                                        raise
                                
                                return find_obj(segments, depth, deepest_found)
@@ -378,7 +375,7 @@ class TreeManager(models.Manager):
                                
                                # Could there be a deeper one?
                                if obj.is_leaf_node():
-                                       return obj, build_path(segments[deepest_level:]) or None
+                                       return obj, pathsep.join(segments[deepest_level:]) or None
                                
                                depth += (len(segments) - depth)/2 or len(segments) - depth
                                
@@ -386,13 +383,13 @@ class TreeManager(models.Manager):
                                        depth = deepest_level + obj.get_descendant_count()
                                
                                if deepest_level == depth:
-                                       return obj, build_path(segments[deepest_level:]) or None
+                                       return obj, pathsep.join(segments[deepest_level:]) or None
                                
                                try:
                                        return find_obj(segments, depth, obj)
                                except self.model.DoesNotExist:
                                        # Then this was the deepest.
-                                       return obj, build_path(segments[deepest_level:])
+                                       return obj, pathsep.join(segments[deepest_level:])
                
                if absolute_result:
                        return self.get(**make_query_kwargs(segments, root))
@@ -416,12 +413,12 @@ class TreeModel(MPTTModel):
                if root is not None and not self.is_descendant_of(root):
                        raise AncestorDoesNotExist(root)
                
-               qs = self.get_ancestors()
+               qs = self.get_ancestors(include_self=True)
                
                if root is not None:
                        qs = qs.filter(**{'%s__gt' % self._mptt_meta.level_attr: root.get_level()})
                
-               return pathsep.join([getattr(parent, field, '?') for parent in list(qs) + [self]])
+               return pathsep.join([getattr(parent, field, '?') for parent in qs])
        path = property(get_path)
        
        def __unicode__(self):