import dataclasses import logging import random import uuid from abc import abstractmethod from collections import deque from datetime import datetime, timedelta from typing import Callable, Tuple, Any, Deque from overcooked_simulator.game_items import Item, Plate, ItemInfo log = logging.getLogger(__name__) ORDER_CATEGORY = "Order" @dataclasses.dataclass class Order: meal: ItemInfo start_time: datetime max_duration: timedelta score_calc: Callable[[timedelta, ...], float] timed_penalties: list[ Tuple[timedelta, float] | Tuple[timedelta, float, int, timedelta] ] expired_penalty: float uuid: str = dataclasses.field(default_factory=lambda: uuid.uuid4().hex) finished_info: dict[str, Any] = dataclasses.field(default_factory=dict) _timed_penalties: list[Tuple[datetime, float]] = dataclasses.field( default_factory=list ) def order_time(self, env_time: datetime) -> timedelta: return self.start_time - env_time def create_penalties(self, env_time: datetime): for penalty_info in self.timed_penalties: match penalty_info: case (offset, penalty): self._timed_penalties.append((env_time + offset, penalty)) case (duration, penalty, number_repeat, offset): self._timed_penalties.extend( [ (env_time + offset + (duration * i), penalty) for i in range(number_repeat) ] ) class OrderGeneration: def __init__(self, available_meals: dict[str, ItemInfo], **kwargs): self.available_meals: list[ItemInfo] = list(available_meals.values()) @abstractmethod def init_orders(self, now) -> list[Order]: ... @abstractmethod def get_orders( self, passed_time: timedelta, now: datetime, new_finished_orders: list[Order], expired_orders: list[Order], ) -> list[Order]: ... 
def zero(item: ItemInfo, **kwargs) -> float:
    """Return ``0.0`` regardless of the arguments (default expired penalty)."""
    return 0.0


@dataclasses.dataclass
class RandomOrderKwarg:
    """Configuration consumed by ``RandomOrderGeneration``."""

    # Number of orders created by ``init_orders``.
    num_start_meals: int
    # When true, a served order schedules the next order time, and the initial
    # orders are created without a time limit.
    sample_on_serving: bool
    # When true, the next order time is re-sampled each time an order is created.
    sample_on_dur: bool
    # ``{"func": <name of a function in ``random``>, "kwargs": {...}}`` giving
    # the seconds until the next order.
    sample_on_dur_func: dict
    # Upper bound on simultaneously tracked orders for time-based creation.
    max_orders: int
    # ``{"func": <name of a function in ``random``>, "kwargs": {...}}`` giving
    # an order's duration in seconds.
    duration_sample: dict
    # Factory called with ``meal``, ``duration``, ``now`` and ``kwargs`` that
    # returns the order's score function.
    score_calc_gen_func: Callable[
        [ItemInfo, timedelta, datetime, Any], Callable[[timedelta, Order], float]
    ]
    # Passed to ``score_calc_gen_func`` as its ``kwargs`` argument.
    score_calc_gen_kwargs: dict
    # Computes the expired penalty of a new order from its meal.
    expired_penalty_func: Callable[[ItemInfo], float] = zero
    # Extra keyword arguments for ``expired_penalty_func``.
    expired_penalty_kwargs: dict = dataclasses.field(default_factory=dict)


class RandomOrderGeneration(OrderGeneration):
    """Order generation that picks meals at random from the available meals."""

    def __init__(self, available_meals: dict[str, ItemInfo], **kwargs):
        """Read the configuration from ``kwargs["kwargs"]`` and reset counters."""
        super().__init__(available_meals, **kwargs)
        self.kwargs: RandomOrderKwarg = RandomOrderKwarg(**kwargs["kwargs"])
        # ``datetime.max`` means that no next order is scheduled.
        self.next_order_time: datetime | None = datetime.max
        self.number_cur_orders = 0
        # For the sample on dur but when it was restricted due to max order number.
        self.needed_orders: int = 0

    def init_orders(self, now) -> list[Order]:
        """Create the ``num_start_meals`` initial orders.

        The initial orders have no time limit when ``sample_on_serving`` is set.
        """
        start_count = self.kwargs.num_start_meals
        self.number_cur_orders = start_count
        if self.kwargs.sample_on_dur:
            self.create_random_next_time_delta(now)
        return self.create_orders_for_meals(
            random.choices(self.available_meals, k=start_count),
            now,
            self.kwargs.sample_on_serving,
        )

    def get_orders(
        self,
        passed_time: timedelta,
        now: datetime,
        new_finished_orders: list[Order],
        expired_orders: list[Order],
    ) -> list[Order]:
        """Return the orders to add at ``now`` and update the bookkeeping.

        Args:
            passed_time: Time since the previous call (not used here).
            now: Current environment time.
            new_finished_orders: Orders served since the previous call.
            expired_orders: Orders that expired since the previous call.
        """
        finished_count = len(new_finished_orders)
        self.number_cur_orders -= finished_count
        self.number_cur_orders -= len(expired_orders)

        # Serving an order only schedules the next one; nothing is created now.
        if self.kwargs.sample_on_serving and new_finished_orders:
            self.create_random_next_time_delta(now)
            return []

        # Orders were postponed because of ``max_orders``: replace each newly
        # finished order directly.
        if self.needed_orders:
            self.needed_orders = max(self.needed_orders - finished_count, 0)
            self.number_cur_orders += finished_count
            return self.create_orders_for_meals(
                random.choices(self.available_meals, k=finished_count),
                now,
            )

        if self.next_order_time <= now:
            if self.number_cur_orders >= self.kwargs.max_orders:
                # No capacity: remember that an order is still owed.
                self.needed_orders += 1
            else:
                if self.kwargs.sample_on_dur:
                    self.create_random_next_time_delta(now)
                else:
                    self.next_order_time = datetime.max
                self.number_cur_orders += 1
                return self.create_orders_for_meals(
                    [random.choice(self.available_meals)],
                    now,
                )
        return []

    def create_orders_for_meals(
        self, meals: list[ItemInfo], now: datetime, no_time_limit: bool = False
    ) -> list[Order]:
        """Build one ``Order`` per entry of ``meals``, all starting at ``now``.

        Args:
            meals: Meals to create orders for.
            now: Start time of the created orders.
            no_time_limit: If true, the duration is ``datetime.max - now`` so
                that the order ends at ``datetime.max``; otherwise the duration
                is sampled via ``duration_sample``.
        """
        orders = []
        for meal in meals:
            if no_time_limit:
                duration = datetime.max - now
            else:
                sampler = getattr(random, self.kwargs.duration_sample["func"])
                duration = timedelta(
                    seconds=sampler(**self.kwargs.duration_sample["kwargs"])
                )
            log.info(f"Create order for meal {meal} with duration {duration}")
            orders.append(
                Order(
                    meal=meal,
                    start_time=now,
                    max_duration=duration,
                    score_calc=self.kwargs.score_calc_gen_func(
                        meal=meal,
                        duration=duration,
                        now=now,
                        kwargs=self.kwargs.score_calc_gen_kwargs,
                    ),
                    timed_penalties=[],
                    expired_penalty=self.kwargs.expired_penalty_func(
                        meal, **self.kwargs.expired_penalty_kwargs
                    ),
                )
            )
        return orders

    def create_random_next_time_delta(self, now: datetime):
        """Set ``next_order_time`` to ``now`` plus a delay from ``sample_on_dur_func``."""
        sampler = getattr(random, self.kwargs.sample_on_dur_func["func"])
        self.next_order_time = now + timedelta(
            seconds=sampler(**self.kwargs.sample_on_dur_func["kwargs"])
        )
        log.info(f"Next order in {self.next_order_time}")


def simple_score_calc_gen_func(
    meal: ItemInfo, duration: timedelta, now: datetime, kwargs: dict, **other_kwargs
) -> Callable:
    """Return a score function based on a fixed per-meal score table.

    Args:
        meal: Meal of the order (not used by this generator).
        duration: Duration of the order (not used by this generator).
        now: Creation time of the order (not used by this generator).
        kwargs: Must contain ``"scores"`` (meal name -> score) and ``"other"``
            (score for meal names missing from ``"scores"``).

    Raises:
        KeyError: If ``"scores"`` or ``"other"`` is missing from ``kwargs``.
    """
    scores = kwargs["scores"]
    other = kwargs["other"]

    def score_calc(relative_order_time: timedelta, order: Order) -> float:
        # The elapsed time is ignored; only the meal name determines the score.
        if order.meal.name in scores:
            return scores[order.meal.name]
        return other

    return score_calc


def simple_expired_penalty(item: ItemInfo, default: float, **kwargs) -> float:
    """Return ``default`` as the expired penalty, independent of ``item``."""
    return default


class OrderAndScoreManager:
    """Tracks the open orders and the score, and accepts served meals."""

    def __init__(self, order_config, available_meals: dict[str, ItemInfo]):
        """Instantiate the order generator and initialise the bookkeeping.

        Args:
            order_config: Mapping with the keys ``"order_gen_class"``,
                ``"kwargs"`` and ``"serving_not_ordered_meals"``.
            available_meals: Meals that can be served, keyed by name.
        """
        self.score = 0
        self.order_gen: OrderGeneration = order_config["order_gen_class"](
            available_meals=available_meals, kwargs=order_config["kwargs"]
        )
        self.kwargs_for_func = order_config["kwargs"]
        # Optional callable ``meal -> (accept, score)`` for meals without order.
        self.serving_not_ordered_meals = order_config["serving_not_ordered_meals"]
        self.available_meals = available_meals
        self.open_orders: Deque[Order] = deque()

        # for logs or history in the future
        # TODO log who / which player served which meal -> for split scores
        self.served_meals: list[Tuple[Item, datetime]] = []
        # Orders finished / expired since the last ``progress`` call.
        self.last_finished = []
        self.next_relevant_time = datetime.max
        self.last_expired = []

    def update_next_relevant_time(self):
        """Set ``next_relevant_time`` to the earliest order end or penalty time."""
        earliest = datetime.max
        for order in self.open_orders:
            earliest = min(earliest, order.start_time + order.max_duration)
            for penalty_time, _ in order._timed_penalties:
                earliest = min(earliest, penalty_time)
        self.next_relevant_time = earliest

    def serve_meal(self, item: Item, env_time: datetime) -> bool:
        """Try to serve ``item`` and return whether it was accepted.

        Only a ``Plate`` whose potential meal is one of the available meals can
        be served. A meal with a matching open order completes the first such
        order and adds its score. A meal without an order is passed to
        ``serving_not_ordered_meals`` if that is configured.
        """
        meal = item.get_potential_meal() if isinstance(item, Plate) else None
        if meal is None or meal.name not in self.available_meals:
            log.info(f"Do not serve item {item}")
            return False

        match_result = self.find_order_for_meal(meal)
        if match_result is None:
            if self.serving_not_ordered_meals:
                accept, score = self.serving_not_ordered_meals(meal)
                if accept:
                    log.info(
                        f"Serving meal without order {meal.name} with score {score}"
                    )
                    self.score += score
                    self.served_meals.append((meal, env_time))
                return accept
            log.info(f"Do not serve meal {meal.name} because it is not ordered")
            return False

        order, index = match_result
        score = order.score_calc(
            relative_order_time=env_time - order.start_time,
            order=order,
        )
        self.score += score
        order.finished_info = {
            "end_time": env_time,
            "score": score,
        }
        log.info(f"Serving meal {meal.name} with order with score {score}")
        self.last_finished.append(order)
        del self.open_orders[index]
        self.served_meals.append((meal, env_time))
        return True

    def increment_score(self, score: int):
        """Add ``score`` to the total score and log the new total."""
        self.score += score
        log.debug(f"Score: {self.score}")

    def create_init_orders(self, env_time):
        """Add the generator's initial orders to the open orders."""
        self.open_orders.extend(self.order_gen.init_orders(env_time))

    def progress(self, passed_time: timedelta, now: datetime):
        """Advance the order state to ``now``.

        Fetches new orders from the generator, then (if there are new orders or
        a relevant time has been reached) expires overdue orders, applies due
        timed penalties and recomputes ``next_relevant_time``.
        """
        new_orders = self.order_gen.get_orders(
            passed_time=passed_time,
            now=now,
            new_finished_orders=self.last_finished,
            expired_orders=self.last_expired,
        )
        self.open_orders.extend(new_orders)
        self.last_finished = []
        self.last_expired = []

        if not (new_orders or self.next_relevant_time <= now):
            return

        expired_indices = []
        for index, order in enumerate(self.open_orders):
            if now >= order.start_time + order.max_duration:
                # The expired penalty is added as configured, while timed
                # penalties below are subtracted.
                self.score += order.expired_penalty
                expired_indices.append(index)

            # Use the absolute penalty times in ``_timed_penalties``; the
            # configured ``timed_penalties`` hold timedeltas, which cannot be
            # compared with ``now``.
            pending = []
            for penalty_time, penalty in order._timed_penalties:
                if penalty_time < now:
                    self.score -= penalty
                else:
                    pending.append((penalty_time, penalty))
            # Mutate in place so other references to the list stay valid.
            order._timed_penalties[:] = pending

        # Delete from the back so the remaining indices stay valid.
        expired_orders = []
        for index in reversed(expired_indices):
            expired_orders.append(self.open_orders[index])
            del self.open_orders[index]
        self.last_expired = expired_orders
        self.update_next_relevant_time()

    def find_order_for_meal(self, meal) -> Tuple[Order, int] | None:
        """Return the first open order for ``meal`` with its index, or ``None``."""
        for index, order in enumerate(self.open_orders):
            if order.meal.name == meal.name:
                return order, index
        return None

    def setup_penalties(self, new_orders: list[Order], env_time: datetime):
        """Create the absolute penalty times for each order in ``new_orders``."""
        for order in new_orders:
            order.create_penalties(env_time)

    def order_state(self) -> list[dict]:
        """Return a serialisable description of every open order."""
        return [
            {
                "id": order.uuid,
                "category": ORDER_CATEGORY,
                "meal": order.meal.name,
                "start_time": order.start_time.isoformat(),
                "max_duration": order.max_duration.total_seconds(),
            }
            for order in self.open_orders
        ]


if __name__ == "__main__":
    # Demo: load an example configuration, plug in the Python objects that
    # cannot be expressed in YAML, and print the result.
    import yaml

    order_config = yaml.safe_load(
        """orders:
  kwargs:
    duration_sample:
      func: uniform
      kwargs:
        a: 30
        b: 50
    max_orders: 5
    num_start_meals: 3
    sample_on_dur: false
    sample_on_dur_func:
      func: uniform
      kwargs:
        a: 30
        b: 50
    sample_on_serving: true
    score_calc_gen_func: null
    score_calc_gen_kwargs:
      other: 0
      scores:
        Burger: 15
        OnionSoup: 10
        Salad: 5
        TomatoSoup: 10
    score_calc_gen_func: ~''
  order_gen_class: ~
  serving_not_ordered_meals: null"""
    )
    order_config["orders"]["order_gen_class"] = RandomOrderGeneration
    order_config["orders"]["kwargs"]["score_calc_gen_func"] = simple_score_calc_gen_func
    print(yaml.dump(order_config))