diff --git a/server/entertainment_decider/common.py b/server/entertainment_decider/common.py index 327cad7..1965d33 100644 --- a/server/entertainment_decider/common.py +++ b/server/entertainment_decider/common.py @@ -5,12 +5,15 @@ import sys from typing import ( IO, Iterable, + Iterator, List, Literal, Optional, Sequence, + Tuple, TypeVar, Union, + overload, ) @@ -78,5 +81,75 @@ def limit_iter(iter: Iterable[T], limit: int) -> List[T]: return list(itertools.islice(iter, limit)) +class _IterFixer(Iterator[T]): + __it: Iterator[T] + + def __init__(self, it: Iterator[T]) -> None: + super().__init__() + self.__it = it + + def __iter__(self) -> Iterator[T]: + return self + + def __next__(self) -> T: + return next(self.__it) + + +def fix_iter(iterable: Iterable[T]) -> Iterable[T]: + return _IterFixer(iter(iterable)) + + +@overload +def iter_lookahead( + iterable: Iterable[T], + get_first: Literal[False] = False, + get_last: Literal[False] = False, +) -> Iterable[Tuple[T, T]]: + ... + + +@overload +def iter_lookahead( + iterable: Iterable[T], + get_first: Literal[True], + get_last: Literal[False] = False, +) -> Iterable[Tuple[None, T] | Tuple[T, T]]: + ... + + +@overload +def iter_lookahead( + iterable: Iterable[T], + get_first: Literal[False] = False, + get_last: Literal[True] = True, # <- default only to satisfy python +) -> Iterable[Tuple[T, T] | Tuple[T, None]]: + ... + + +@overload +def iter_lookahead( + iterable: Iterable[T], + get_first: Literal[True], + get_last: Literal[True], +) -> Iterable[Tuple[None, T] | Tuple[T, T] | Tuple[T, None]]: + ... + + +def iter_lookahead( + iterable: Iterable[T], + get_first: bool = False, + get_last: bool = False, +) -> Iterable[Tuple[None, T] | Tuple[T, T] | Tuple[T, None]]: + it = iter(iterable) + last = next(it) + if get_first: + yield None, last + for cur in it: + yield last, cur + last = cur + if get_last: + yield last, None + + def date_to_datetime(d: date) -> datetime: return datetime(d.year, d.month, d.day)