diff --git a/server/entertainment_decider/models.py b/server/entertainment_decider/models.py index 51058c7..765c85c 100644 --- a/server/entertainment_decider/models.py +++ b/server/entertainment_decider/models.py @@ -24,6 +24,7 @@ from typing import ( ) import magic +import numpy import requests from pony import orm from pony.orm.core import Query as PonyQuery @@ -276,6 +277,60 @@ PreferenceScoreCompatible = Union[ ] +class ConsideredMediaGenerator: + media_list: Dict[int, MediaElement] + dependencies: numpy.ndarray + + def __init__(self, media_list: Iterable[MediaElement]) -> None: + self.media_list = {elem.id: elem for elem in media_list} + max_id = orm.max(elem.id for elem in MediaElement) + 1 + self.dependencies = numpy.zeros((max_id, max_id), dtype=numpy.bool_) + for left_id, right_id in self.__get_data(): + self.dependencies[left_id, right_id] = True + + def __get_data(self) -> Iterable[Tuple[int, int]]: + return db.select( + """ + SELECT left_elem.id AS left_id, right_elem.id AS right_id + FROM mediaelement left_elem + INNER JOIN mediacollectionlink left_link + on left_elem.id = left_link.element + INNER JOIN mediacollection coll + on left_link.collection = coll.id and coll.watch_in_order + INNER JOIN mediacollectionlink right_link + on left_link.collection = right_link.collection + INNER JOIN mediaelement right_elem + on right_link.element = right_elem.id + WHERE NOT (left_elem.watched OR left_elem.ignored) AND NOT (right_elem.watched OR right_elem.ignored) + AND (left_link.season, left_link.episode, left_elem.release_date, left_elem.id) > + (right_link.season, right_link.episode, right_elem.release_date, right_elem.id) + """.strip() + ) + + def __len__(self) -> int: + return len(self.media_list) + + def iter_media(self) -> Iterable[MediaElement]: + return [ + elem + for elem_id, elem in self.media_list.items() + if not self.dependencies[elem_id].any() + ] + + def is_considered_id(self, elem_id: int) -> bool: + return elem_id in self.media_list and not self.dependencies[elem_id].any() + + def is_considered(self, element: MediaElement) -> bool: + return self.is_considered_id(element.id) + + def mark_as_watched_id(self, elem_id: id) -> None: + del self.media_list[elem_id] + self.dependencies[:, elem_id] = False + + def mark_as_watched(self, element: MediaElement) -> None: + return self.mark_as_watched_id(element.id) + + def generate_preference_list( base: PreferenceScore, object_gen: Callable[[], List[MediaElement]],