diff --git a/mygpo/administration/group.py b/mygpo/administration/group.py index 0adaf98b6..95729a1ef 100644 --- a/mygpo/administration/group.py +++ b/mygpo/administration/group.py @@ -29,15 +29,18 @@ def __get_episodes(self): def group(self, get_features): + """ Groups the episodes by features extracted using ``get_features`` + + get_features is a callable that expects an episode as parameter, and + returns a value representing the extracted feature(s). + """ episodes = self.__get_episodes() episode_groups = defaultdict(list) - episode_features = map(get_features, episodes.items()) - - for features, episode_id in episode_features: - episode = episodes[episode_id] + for episode in episodes.values(): + features = get_features(episode) episode_groups[features].append(episode) groups = sorted(episode_groups.values(), key=_SORT_KEY) diff --git a/mygpo/administration/templates/admin/merge-grouping.html b/mygpo/administration/templates/admin/merge-grouping.html index 99bbcc526..dc4354884 100644 --- a/mygpo/administration/templates/admin/merge-grouping.html +++ b/mygpo/administration/templates/admin/merge-grouping.html @@ -47,7 +47,7 @@

{% trans "Merge Podcasts and Episodes" %}

{% for episode in episodes %} {% if episode.podcast.get_id == podcast.get_id %} - + {% episode_link episode podcast %}
{% endif %} {% endfor %} diff --git a/mygpo/administration/views.py b/mygpo/administration/views.py index 2c42d83c5..27e93e257 100644 --- a/mygpo/administration/views.py +++ b/mygpo/administration/views.py @@ -140,11 +140,10 @@ def post(self, request): grouper = PodcastGrouper(podcasts) - get_features = lambda id_e: ((id_e[1].url, id_e[1].title), id_e[0]) + get_features = lambda episode: (episode.url, episode.title) num_groups = grouper.group(get_features) - except InvalidPodcast as ip: messages.error(request, _('No podcast with URL {url}').format(url=str(ip))) @@ -178,10 +177,10 @@ def post(self, request): for key, feature in request.POST.items(): m = self.RE_EPISODE.match(key) if m: - episode_id = m.group(1) + episode_id = uuid.UUID(m.group(1)) features[episode_id] = feature - get_features = lambda id_e: (features.get(id_e[0], id_e[0]), id_e[0]) + get_features = lambda episode: features[episode.id] num_groups = grouper.group(get_features) queue_id = request.POST.get('queue_id', '') diff --git a/mygpo/maintenance/merge.py b/mygpo/maintenance/merge.py index d8d950cef..7c435d95a 100644 --- a/mygpo/maintenance/merge.py +++ b/mygpo/maintenance/merge.py @@ -12,6 +12,7 @@ from mygpo.history.models import HistoryEntry, EpisodeHistoryEntry from mygpo.publisher.models import PublishedPodcast from mygpo.subscriptions.models import Subscription +from . import models import logging logger = logging.getLogger(__name__) @@ -68,7 +69,7 @@ def merge_episodes(self): # based on https://djangosnippets.org/snippets/2283/ @transaction.atomic -def merge_model_objects(primary_object, alias_objects=[], keep_old=False): +def merge_model_objects(primary_object, alias_objects, keep_old=False): """ Use this function to merge model objects (i.e. Users, Organizations, Polls, etc.) and migrate all of the related fields from the alias objects to the @@ -78,10 +79,8 @@ def merge_model_objects(primary_object, alias_objects=[], keep_old=False): from django.contrib.auth.models import User primary_user = User.objects.get(email='good_email@example.com') duplicate_user = User.objects.get(email='good_email+duplicate@example.com') - merge_model_objects(primary_user, duplicate_user) + merge_model_objects(primary_user, [duplicate_user]) """ - if not isinstance(alias_objects, list): - alias_objects = [alias_objects] # check that all aliases are the same class as primary one and that # they are subclass of model @@ -105,11 +104,6 @@ def merge_model_objects(primary_object, alias_objects=[], keep_old=False): for field_name, field in fields: generic_fields.append(field) - blank_local_fields = set( - [field.attname for field - in primary_object._meta.local_fields - if getattr(primary_object, field.attname) in [None, '']]) - # Loop through all alias objects and migrate their data to # the primary object. for alias_object in alias_objects: @@ -123,8 +117,9 @@ def merge_model_objects(primary_object, alias_objects=[], keep_old=False): related_objects = getattr(alias_object, alias_varname) for obj in related_objects.all(): setattr(obj, obj_varname, primary_object) - reassigned(obj, primary_object) - obj.save() + deleted = reassigned(obj, primary_object) + if not deleted: + obj.save() # Migrate all many to many references from alias object to # primary object. @@ -143,8 +138,9 @@ def merge_model_objects(primary_object, alias_objects=[], keep_old=False): obj_varname).all() for obj in related_many_objects.all(): getattr(obj, obj_varname).remove(alias_object) - reassigned(obj, primary_object) - getattr(obj, obj_varname).add(primary_object) + deleted = reassigned(obj, primary_object) + if not deleted: + getattr(obj, obj_varname).add(primary_object) # Migrate all generic foreign key references from alias # object to primary object. @@ -156,7 +152,10 @@ def merge_model_objects(primary_object, alias_objects=[], keep_old=False): related = field.model.objects.filter(**filter_kwargs) for generic_related_object in related: setattr(generic_related_object, field.name, primary_object) - reassigned(generic_related_object, primary_object) + deleted = reassigned(generic_related_object, primary_object) + if deleted: + continue + try: # execute save in a savepoint, so we can resume in the # transaction @@ -166,20 +165,10 @@ def merge_model_objects(primary_object, alias_objects=[], keep_old=False): if ie.__cause__.pgcode == PG_UNIQUE_VIOLATION: merge(generic_related_object, primary_object) - # Try to fill all missing values in primary object by - # values of duplicates - filled_up = set() - for field_name in blank_local_fields: - val = getattr(alias_object, field_name) - if val not in [None, '']: - setattr(primary_object, field_name, val) - filled_up.add(field_name) - blank_local_fields -= filled_up - if not keep_old: before_delete(alias_object, primary_object) alias_object.delete() - primary_object.save() + return primary_object @@ -199,6 +188,17 @@ def _get_all_related_many_to_many_objects(obj): def reassigned(obj, new): + """ handles changes necessary when reassigning `obj` to `new` + + Some objects have a dependent object (eg URL has a Podcast or Episode. + During merging, the object might be assigned from to a new Episode. + The re-assignment requires the "scope" field to be set to the value + of the new episode. In some cases it might require the existing object to + be deleted, to preserve uniqueness. + + Returns whether the object was deleted. + """ + if isinstance(obj, URL): # a URL has its parent's scope obj.scope = new.scope @@ -207,11 +207,27 @@ def reassigned(obj, new): max_order = max([-1] + [u.order for u in existing_urls]) obj.order = max_order+1 + elif isinstance(obj, Slug): + # a Slug has its parent's scope + obj.scope = new.scope + + existing_slugs = new.slugs.all() + max_order = max([-1] + [s.order for s in existing_slugs]) + obj.order = max_order+1 + elif isinstance(obj, Episode): # obj is an Episode, new is a podcast for url in obj.urls.all(): url.scope = new.as_scope - url.save() + try: + with transaction.atomic(): + url.save() + except IntegrityError as ie: + if 'podcasts_url_url_scope_key' in str(ie): + url.delete() + return True + else: + raise elif isinstance(obj, Subscription): pass @@ -222,10 +238,17 @@ def reassigned(obj, new): elif isinstance(obj, HistoryEntry): pass + elif isinstance(obj, models.MergeQueueEntry): + obj.delete() + return True + else: raise TypeError('unknown type for reassigning: {objtype}'.format( objtype=type(obj))) + # Object was not deleted + return False + def before_delete(old, new): diff --git a/mygpo/maintenance/models.py b/mygpo/maintenance/models.py index 877f96ed9..17c28b5a2 100644 --- a/mygpo/maintenance/models.py +++ b/mygpo/maintenance/models.py @@ -7,6 +7,14 @@ class MergeQueue(UUIDModel): """ A Group of podcasts that could be merged """ + @property + def podcasts(self): + """ Returns the podcasts of the queue, sorted by subscribers """ + podcasts = [entry.podcast for entry in self.entries.all()] + podcasts = sorted(podcasts, + key=lambda p: p.subscribers, reverse=True) + return podcasts + class MergeQueueEntry(UUIDModel): """ An entry in a MergeQueue """