From da6c58e058d31e6c884b3fcfc1de55181ea2f347 Mon Sep 17 00:00:00 2001 From: Felix Stupp Date: Sat, 19 Nov 2022 23:23:47 +0100 Subject: [PATCH] generate_preference_list: Add list prefilter --- server/entertainment_decider/models.py | 51 ++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/server/entertainment_decider/models.py b/server/entertainment_decider/models.py index bec9662..925c37b 100644 --- a/server/entertainment_decider/models.py +++ b/server/entertainment_decider/models.py @@ -161,6 +161,7 @@ class TagRootElement: children: List[TagTreeElement] = dataclasses.field(default_factory=lambda: []) def share_score(self, points: float) -> PreferenceScoreAppender: + # influences PreferenceScore.max_score_increase if points == 0 or len(self.children) <= 0: return PreferenceScoreAppender() single_share = points / len(self.children) @@ -174,6 +175,7 @@ class TagTreeElement: children: List[TagTreeElement] = dataclasses.field(default_factory=lambda: []) def share_score(self, points: float) -> PreferenceScoreAppender: + # influences PreferenceScore.max_score_increase children = [elem for elem in self.children if elem.base.use_for_preferences] if len(children) <= 0: return PreferenceScoreAppender(PreferenceScore({self.base: points})) @@ -266,6 +268,7 @@ class Tagable: return used def share_score_flat(self, score: float) -> PreferenceScoreAppender: + # influences PreferenceScore.max_score_increase direct_tags = [tag for tag in self.direct_tags if tag.use_for_preferences] if len(direct_tags) <= 0: return PreferenceScoreAppender() @@ -295,6 +298,11 @@ class PreferenceScore: def __neg__(self) -> PreferenceScore: return self * -1 + @staticmethod + def max_score_increase(score: float, adapt_count: int) -> float: + # depends on Tag(Root|Tree)Element.share_score / Tagable.share_score_flat + return score * adapt_count + def adapt_score( self, tagable: Tagable, @@ -388,6 +396,8 @@ PreferenceScoreCompatible: TypeAlias = Union[ PreferenceScoreCompatibleSimple, Iterable[PreferenceScoreCompatibleSimple] ] +ScoreCalc: TypeAlias = Callable[["MediaElement"], float] + def generate_preference_list( object_gen: Callable[[], List[MediaElement]], @@ -475,6 +485,47 @@ def generate_preference_list( pref_score = preference.calculate_iter_score(all_tags(element)) return static_score + pref_score + # pre filter list + # - elements which have a too low current score may never possible appear + # TODO add test that this does not change end result + elem_count = len(element_list) + if limit is not None and limit < elem_count: + # cache pref score for this + gen_pre_score = cache(gen_score) + # biggest possible score increase by adaption + max_score_inc = preference.max_score_increase( + score=score_adapt, + adapt_count=limit, + ) + logging.debug(f"Max adaption possible: {max_score_inc}") + # differenciate adapted buffing and adapted nerfing + without_max_adapt: ScoreCalc = lambda elem: gen_pre_score(elem) + with_max_adapt: ScoreCalc = lambda elem: without_max_adapt(elem) + max_score_inc + is_nerfing = score_adapt >= 0 + if is_nerfing: + best_case = without_max_adapt + worst_case = with_max_adapt + else: # is buffing + best_case = with_max_adapt + worst_case = without_max_adapt + # (limit)ths best's score in the worst adaption for it + limitths_best_worst = sorted(worst_case(elem) for elem in element_list)[limit] + logging.debug(f"(limit)ths best's worst case score: {limitths_best_worst}") + # extract worst's element's score in best case as well + worsts_best = best_case(max(element_list, key=gen_pre_score)) + logging.debug(f"Worsts best case score is {worsts_best}") + # check if reducing element count is possible + if limitths_best_worst < worsts_best: + # throw away all elements which's best adaption is not better than the (limit)ths one + element_list = { + elem for elem in element_list if best_case(elem) < limitths_best_worst + } + logging.debug( + f"Prefilter reduced set from {elem_count} to {len(element_list)} elements" + ) + else: + logging.debug(f"Prefilter couldn't reduce the element count ({elem_count})") + # gen elements res_ids = list[int]() while 0 < len(element_list):