diff --git a/overcooked_simulator/counter_factory.py b/overcooked_simulator/counter_factory.py index 9b1487c3706a70dde8b224283a1877c126c41b56..22f56cf9835e7f08beb8b31d798f719359550742 100644 --- a/overcooked_simulator/counter_factory.py +++ b/overcooked_simulator/counter_factory.py @@ -60,7 +60,7 @@ from overcooked_simulator.game_items import ( Item, ) from overcooked_simulator.hooks import Hooks -from overcooked_simulator.order import OrderAndScoreManager +from overcooked_simulator.order import OrderManager from overcooked_simulator.utils import get_closest T = TypeVar("T") @@ -119,7 +119,7 @@ class CounterFactory: item_info: dict[str, ItemInfo], serving_window_additional_kwargs: dict[str, Any], plate_config: PlateConfig, - order_and_score: OrderAndScoreManager, + order_manager: OrderManager, effect_manager_config: dict, hook: Hooks, random: Random, @@ -150,7 +150,7 @@ class CounterFactory: """The additional keyword arguments for the serving window.""" self.plate_config: PlateConfig = plate_config """The plate config from the `environment_config`""" - self.order_and_score: OrderAndScoreManager = order_and_score + self.order_manager: OrderManager = order_manager """The order and score manager to pass to `ServingWindow` and the `Tashcan` which can affect the scores.""" self.effect_manager_config = effect_manager_config """The effect manager config to setup the effect manager based on the defined effects in the item info.""" @@ -246,8 +246,8 @@ class CounterFactory: kwargs.update(self.serving_window_additional_kwargs) if issubclass(counter_class, (ServingWindow, Trashcan)): kwargs[ - "order_and_score" - ] = self.order_and_score # individual because for the later trash scorer + "order_manager" + ] = self.order_manager # individual because for the later trash scorer return counter_class(**kwargs) def can_map(self, char) -> bool: diff --git a/overcooked_simulator/counters.py b/overcooked_simulator/counters.py index a7f126d7c9c489c71363bf9451964118e23746b2..8ce066ab7e944d131dcc87255fc77b94a5680878 100644 --- a/overcooked_simulator/counters.py +++ b/overcooked_simulator/counters.py @@ -66,7 +66,7 @@ from overcooked_simulator.hooks import ( if TYPE_CHECKING: from overcooked_simulator.effect_manager import Effect from overcooked_simulator.overcooked_environment import ( - OrderAndScoreManager, + OrderManager, ) import numpy as np @@ -339,13 +339,13 @@ class ServingWindow(Counter): def __init__( self, - order_and_score: OrderAndScoreManager, + order_manager: OrderManager, meals: set[str], env_time_func: Callable[[], datetime], plate_dispenser: PlateDispenser = None, **kwargs, ): - self.order_and_score: OrderAndScoreManager = order_and_score + self.order_manager: OrderManager = order_manager """Reference to the OrderAndScoreManager class. It determines which meals can be served and it manages the score.""" self.plate_dispenser: PlateDispenser = plate_dispenser @@ -360,7 +360,7 @@ class ServingWindow(Counter): def drop_off(self, item) -> Item | None: env_time = self.env_time_func() self.hook(PRE_SERVING, counter=self, item=item, env_time=env_time) - if self.order_and_score.serve_meal(item=item, env_time=env_time): + if self.order_manager.serve_meal(item=item, env_time=env_time): if self.plate_dispenser is not None: self.plate_dispenser.update_plate_out_of_kitchen(env_time=env_time) self.hook(POST_SERVING, counter=self, item=item, env_time=env_time) @@ -464,6 +464,7 @@ class PlateConfig: return_dirty: bool = True """Specifies if plates are returned dirty or clean to the plate dispenser.""" + class PlateDispenser(Counter): """At the moment, one and only one plate dispenser must exist in an environment, because only at one place the dirty plates should arrive. @@ -523,7 +524,9 @@ class PlateDispenser(Counter): def add_dirty_plate(self): """Add a dirty plate after a timer is completed.""" - self.occupied_by.appendleft(self.create_item(clean=not self.plate_config.return_dirty)) + self.occupied_by.appendleft( + self.create_item(clean=not self.plate_config.return_dirty) + ) def update_plate_out_of_kitchen(self, env_time: datetime): """Is called from the serving window to add a plate out of kitchen.""" @@ -597,10 +600,8 @@ class Trashcan(Counter): The character `X` in the `layout` file represents the Trashcan. """ - def __init__(self, order_and_score: OrderAndScoreManager, **kwargs): + def __init__(self, **kwargs): super().__init__(**kwargs) - self.order_and_score: OrderAndScoreManager = order_and_score - """Reference to the `OrderAndScoreManager`, because unnecessary removed items can affect the score.""" def pick_up(self, on_hands: bool = True) -> Item | None: pass @@ -613,14 +614,9 @@ class Trashcan(Counter): ): return item if isinstance(item, CookingEquipment): - penalty = self.order_and_score.apply_penalty_for_using_trash( - item.content_list - ) item.reset_content() return item - else: - penalty = self.order_and_score.apply_penalty_for_using_trash(item) - self.hook(TRASHCAN_USAGE, counter=self, item=item, penalty=penalty) + self.hook(TRASHCAN_USAGE, counter=self, item=item) return None def can_drop_off(self, item: Item) -> bool: diff --git a/overcooked_simulator/game_content/environment_config.yaml b/overcooked_simulator/game_content/environment_config.yaml index 0b24c354f1e07beda8d5cd286f3cd7c5c8f40843..6d290b823d5cf65c5ca42d7b40b7091b523208b2 100644 --- a/overcooked_simulator/game_content/environment_config.yaml +++ b/overcooked_simulator/game_content/environment_config.yaml @@ -1,8 +1,7 @@ plates: - clean_plates: 0 + clean_plates: 1 dirty_plates: 2 plate_delay: [ 5, 10 ] - return_dirty: True # range of seconds until the dirty plate arrives. game: @@ -67,28 +66,14 @@ orders: b: 20 sample_on_serving: false # Sample the delay for the next order only after a meal was served. - score_calc_gen_func: !!python/name:overcooked_simulator.order.simple_score_calc_gen_func '' - score_calc_gen_kwargs: - # the kwargs for the score_calc_gen_func - other: 20 - scores: - Burger: 15 - OnionSoup: 10 - Salad: 5 - TomatoSoup: 10 - expired_penalty_func: !!python/name:overcooked_simulator.order.simple_expired_penalty '' - expired_penalty_kwargs: - default: -5 - serving_not_ordered_meals: !!python/name:overcooked_simulator.order.serving_not_ordered_meals_with_zero_score '' - # a func that calcs a store for not ordered but served meals. Input: meal - penalty_for_trash: !!python/name:overcooked_simulator.order.penalty_for_each_item '' - # a func that calcs the penalty for items that the player puts into the trashcan. + serving_not_ordered_meals: true + # can meals that are not ordered be served / dropped on the serving window player_config: radius: 0.4 player_speed_units_per_seconds: 6 interaction_range: 1.6 - restricted_view: False + restricted_view: false view_angle: 95 effect_manager: @@ -100,34 +85,71 @@ effect_manager: extra_setup_functions: + # # --------------- Scoring --------------- + orders: + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' + kwargs: + hooks: [ completed_order ] + callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' + callback_class_kwargs: + static_score: 20 + score_on_specific_kwarg: meal_name + score_map: + Burger: 15 + OnionSoup: 10 + Salad: 5 + TomatoSoup: 10 + not_ordered_meals: + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' + kwargs: + hooks: [ serve_not_ordered_meal ] + callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' + callback_class_kwargs: + static_score: 2 + trashcan_usages: + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' + kwargs: + hooks: [ trashcan_usage ] + callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' + callback_class_kwargs: + static_score: -5 + expired_orders: + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' + kwargs: + hooks: [ order_expired ] + callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' + callback_class_kwargs: + static_score: -10 + # # --------------- Recording --------------- # json_states: - # func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' + # func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' # kwargs: # hooks: [ json_state ] - # log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' - # log_class_kwargs: + # callback_class: !!python/name:overcooked_simulator.recording.FileRecorder '' + # callback_class_kwargs: # log_path: USER_LOG_DIR/ENV_NAME/json_states.jsonl actions: - func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' kwargs: hooks: [ pre_perform_action ] - log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' - log_class_kwargs: + callback_class: !!python/name:overcooked_simulator.recording.FileRecorder '' + callback_class_kwargs: log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl random_env_events: - func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' kwargs: hooks: [ order_duration_sample, plate_out_of_kitchen_time ] - log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' - log_class_kwargs: + callback_class: !!python/name:overcooked_simulator.recording.FileRecorder '' + callback_class_kwargs: log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl add_hook_ref: true env_configs: - func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' kwargs: hooks: [ env_initialized, item_info_config ] - log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' - log_class_kwargs: + callback_class: !!python/name:overcooked_simulator.recording.FileRecorder '' + callback_class_kwargs: log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl add_hook_ref: true + diff --git a/overcooked_simulator/game_content/environment_config_rl.yaml b/overcooked_simulator/game_content/environment_config_rl.yaml index 6235a971e1fb53a0569e82fa63602b9a2e8427c9..7702f835dd2a51c2eb75b7a370b77e087db17f22 100644 --- a/overcooked_simulator/game_content/environment_config_rl.yaml +++ b/overcooked_simulator/game_content/environment_config_rl.yaml @@ -71,7 +71,7 @@ orders: score_calc_gen_kwargs: # the kwargs for the score_calc_gen_func other: 0 - scores: [] + scores: [ ] expired_penalty_func: !!python/name:overcooked_simulator.order.simple_expired_penalty '' expired_penalty_kwargs: default: 0 @@ -87,7 +87,7 @@ player_config: restricted_view: False view_angle: 95 -effect_manager: {} +effect_manager: { } # FireManager: # class: !!python/name:overcooked_simulator.effect_manager.FireEffectManager '' # kwargs: @@ -96,6 +96,41 @@ effect_manager: {} extra_setup_functions: + # # --------------- Scoring --------------- + orders: + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' + kwargs: + hooks: [ completed_order ] + callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' + callback_class_kwargs: + static_score: 20 + score_on_specific_kwarg: meal_name + score_map: + Burger: 15 + OnionSoup: 10 + Salad: 5 + TomatoSoup: 10 + not_ordered_meals: + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' + kwargs: + hooks: [ serve_not_ordered_meal ] + callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' + callback_class_kwargs: + static_score: 2 + trashcan_usages: + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' + kwargs: + hooks: [ trashcan_usage ] + callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' + callback_class_kwargs: + static_score: -5 + expired_orders: + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' + kwargs: + hooks: [ order_expired ] + callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' + callback_class_kwargs: + static_score: -10 # json_states: # func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' # kwargs: diff --git a/overcooked_simulator/hooks.py b/overcooked_simulator/hooks.py index 285ed735fb58190183dce24f051d2c4f1b8bcd24..9d907fcd93f2c4ba6d08cbdbd034c118496ac6c1 100644 --- a/overcooked_simulator/hooks.py +++ b/overcooked_simulator/hooks.py @@ -1,6 +1,12 @@ +from __future__ import annotations + +from abc import abstractmethod from collections import defaultdict from functools import partial -from typing import Callable +from typing import Callable, Any, TYPE_CHECKING, Type + +if TYPE_CHECKING: + from overcooked_simulator.overcooked_environment import Environment # TODO add player_id as kwarg to all hooks -> pass player id to all methods @@ -90,6 +96,8 @@ INIT_ORDERS = "init_orders" NEW_ORDERS = "new_orders" +ORDER_EXPIRED = "order_expired" + ACTION_ON_NOT_REACHABLE_COUNTER = "action_on_not_reachable_counter" ACTION_PUT = "action_put" @@ -118,6 +126,28 @@ def print_hook_callback(text, env, **kwargs): print(env.env_time, text) +class HookCallbackClass: + def __init__(self, name: str, env: Environment, **kwargs): + self.name = name + self.env = env + + @abstractmethod + def __call__(self, hook_ref: str, env: Environment, **kwargs): + ... + + +def hooks_via_callback_class( + name: str, + env: Environment, + hooks: list[str], + callback_class: Type[HookCallbackClass], + callback_class_kwargs: dict[str, Any], +): + recorder = callback_class(name=name, env=env, **callback_class_kwargs) + for hook in hooks: + env.register_callback_for_hook(hook, recorder) + + def add_dummy_callbacks(env): env.register_callback_for_hook( SERVE_NOT_ORDERED_MEAL, diff --git a/overcooked_simulator/order.py b/overcooked_simulator/order.py index 990be494c37bc59747e286c4cd31dca3d9e7c866..7f60abb5d9b1719aeaebebdbbfc45dc7fb7aca41 100644 --- a/overcooked_simulator/order.py +++ b/overcooked_simulator/order.py @@ -25,19 +25,10 @@ This file defines the following classes: Further, it defines same implementations for the basic order generation based on random sampling: - `RandomOrderGeneration` -- `simple_score_calc_gen_func` -- `simple_expired_penalty` -- `zero` For an easier usage of the random orders, also some classes for type hints and dataclasses are defined: - `RandomOrderKwarg` - `RandomFuncConfig` -- `ScoreCalcFuncType` -- `ScoreCalcGenFuncType` -- `ExpiredPenaltyFuncType` - -For the scoring of using the trashcan the `penalty_for_each_item` example function is defined. You can set/replace it -in the `environment_config`. ## Code Documentation @@ -51,7 +42,7 @@ from abc import abstractmethod from collections import deque from datetime import datetime, timedelta from random import Random -from typing import Callable, Tuple, Any, Deque, Protocol, TypedDict, Type +from typing import Callable, Tuple, Any, Deque, TypedDict, Type from overcooked_simulator.game_items import Item, Plate, ItemInfo from overcooked_simulator.hooks import ( @@ -62,6 +53,7 @@ from overcooked_simulator.hooks import ( INIT_ORDERS, NEW_ORDERS, ORDER_DURATION_SAMPLE, + ORDER_EXPIRED, ) log = logging.getLogger(__name__) @@ -92,37 +84,10 @@ class Order: """The start time relative to the env_time. On which the order is returned from the get_orders func.""" max_duration: timedelta """The duration after which the order expires.""" - score_calc: ScoreCalcFuncType - """The function that calculates the score of the served meal/fulfilled order.""" - timed_penalties: list[ - Tuple[timedelta, float] | Tuple[timedelta, float, int, timedelta] - ] - """List of timed penalties when the order is not fulfilled.""" - expired_penalty: float - """The penalty to the score if the order expires""" uuid: str = dataclasses.field(default_factory=lambda: uuid.uuid4().hex) """The unique identifier for the order.""" finished_info: dict[str, Any] = dataclasses.field(default_factory=dict) """Is set after the order is completed.""" - _timed_penalties: list[Tuple[datetime, float]] = dataclasses.field( - default_factory=list - ) - """Converted penalties the env is working with from the `timed_penalties`""" - - def create_penalties(self, env_time: datetime): - """Create the general timed penalties list to check for penalties after some time the order is still not - fulfilled.""" - for penalty_info in self.timed_penalties: - match penalty_info: - case (offset, penalty): - self._timed_penalties.append((env_time + offset, penalty)) - case (duration, penalty, number_repeat, offset): - self._timed_penalties.extend( - [ - (env_time + offset + (duration * i), penalty) - for i in range(number_repeat) - ] - ) class OrderGeneration: @@ -138,7 +103,13 @@ class OrderGeneration: ``` """ - def __init__(self, available_meals: dict[str, ItemInfo], hook: Hooks, random: Random, **kwargs): + def __init__( + self, + available_meals: dict[str, ItemInfo], + hook: Hooks, + random: Random, + **kwargs, + ): self.available_meals: list[ItemInfo] = list(available_meals.values()) """Available meals restricted through the `environment_config.yml`.""" self.hook = hook @@ -163,7 +134,7 @@ class OrderGeneration: ... -class OrderAndScoreManager: +class OrderManager: """The Order and Score Manager that is called from the serving window.""" def __init__( @@ -175,8 +146,6 @@ class OrderAndScoreManager: ): self.random = random """Random instance.""" - self.score: float = 0.0 - """The current score of the environment.""" self.order_gen: OrderGeneration = order_config["order_gen_class"]( available_meals=available_meals, hook=hook, @@ -189,13 +158,6 @@ class OrderAndScoreManager: ] = order_config["serving_not_ordered_meals"] """Function that decides if not ordered meals can be served and what score it gives""" - self.penalty_for_trash: Callable[[Item | list[Item]], float] = ( - order_config["penalty_for_trash"] - if "penalty_for_trash" in order_config - else zero - ) - """Function that calculates the penalty for items which were put into the trashcan.""" - self.available_meals = available_meals """The meals for that orders can be sampled from.""" self.open_orders: Deque[Order] = deque() @@ -207,7 +169,7 @@ class OrderAndScoreManager: self.last_finished: list[Order] = [] """Cache last finished orders for `OrderGeneration.get_orders` call. From the served meals.""" self.next_relevant_time: datetime = datetime.max - """For reduced order checking. Store the next time when to create an order or check for penalties.""" + """For reduced order checking. Store the next time when to create an order.""" self.last_expired: list[Order] = [] """Cache last expired orders for `OrderGeneration.get_orders` call.""" @@ -221,8 +183,6 @@ class OrderAndScoreManager: next_relevant_time = min( next_relevant_time, order.start_time + order.max_duration ) - for penalty in order._timed_penalties: - next_relevant_time = min(next_relevant_time, penalty[0]) self.next_relevant_time = next_relevant_time def serve_meal(self, item: Item, env_time: datetime) -> bool: @@ -235,57 +195,40 @@ class OrderAndScoreManager: order = self.find_order_for_meal(meal) if order is None: if self.serving_not_ordered_meals: - accept, score = self.serving_not_ordered_meals(meal) self.hook( SERVE_NOT_ORDERED_MEAL, - accept=accept, - score=score, meal=meal, + meal_name=meal.name, ) - if accept: - log.info( - f"Serving meal without order {meal.name!r} with score {score}" - ) - self.increment_score(score) - self.served_meals.append((meal, env_time)) - return accept + log.info(f"Serving meal without order {meal.name!r}") + self.served_meals.append((meal, env_time)) + return True log.info( f"Do not serve meal {meal.name!r} because it is not ordered" ) return False order, index = order - score = order.score_calc( - relative_order_time=env_time - order.start_time, - order=order, - ) - self.increment_score(score) - order.finished_info = { - "end_time": env_time, - "score": score, - } - log.info( - f"Serving meal {meal.name!r} with order with score {score}" - ) + log.info(f"Serving meal {meal.name!r} with order") self.last_finished.append(order) del self.open_orders[index] self.served_meals.append((meal, env_time)) - self.hook(COMPLETED_ORDER, score=score, order=order, meal=meal) + self.hook( + COMPLETED_ORDER, + order=order, + meal=meal, + relative_order_time=env_time - order.start_time, + meal_name=meal.name, + ) return True else: self.hook(SERVE_WITHOUT_PLATE, item=item) log.info(f"Do not serve item {item}") return False - def increment_score(self, score: int | float): - """Add a value to the current score and log it.""" - self.score += score - log.debug(f"Score: {self.score}") - def create_init_orders(self, env_time): """Create the initial orders in an environment.""" init_orders = self.order_gen.init_orders(env_time) self.hook(INIT_ORDERS) - self.setup_penalties(new_orders=init_orders, env_time=env_time) self.open_orders.extend(init_orders) def progress(self, passed_time: timedelta, now: datetime): @@ -298,7 +241,6 @@ class OrderAndScoreManager: ) if new_orders: self.hook(NEW_ORDERS, new_orders=new_orders) - self.setup_penalties(new_orders=new_orders, env_time=now) self.open_orders.extend(new_orders) self.last_finished = [] self.last_expired = [] @@ -309,20 +251,9 @@ class OrderAndScoreManager: for index, order in enumerate(self.open_orders): if now >= order.start_time + order.max_duration: # orders expired - self.increment_score(order.expired_penalty) + self.hook(ORDER_EXPIRED, order=order) remove_orders.append(index) - continue # no penalties for expired orders - remove_penalties = [] - for i, (penalty_time, penalty) in enumerate(order.timed_penalties): - # check penalties - if penalty_time < now: - # TODO add hook - self.score -= penalty - remove_penalties.append(i) - - for i in reversed(remove_penalties): - # or del order.timed_penalties[index] - order.timed_penalties.pop(i) + continue expired_orders: list[Order] = [] for remove_order in reversed(remove_orders): @@ -339,12 +270,6 @@ class OrderAndScoreManager: if order.meal.name == meal.name: return order, index - @staticmethod - def setup_penalties(new_orders: list[Order], env_time: datetime): - """Call the `Order.create_penalties` method for new orders.""" - for order in new_orders: - order.create_penalties(env_time) - def order_state(self) -> list[dict]: """Similar to the `to_dict` in `Item` and `Counter`. Relevant for the state of the environment""" return [ @@ -358,80 +283,6 @@ class OrderAndScoreManager: for order in self.open_orders ] - def apply_penalty_for_using_trash(self, remove: Item | list[Item]) -> float: - """Is called if a item is put into the trashcan.""" - penalty = self.penalty_for_trash(remove) - self.increment_score(penalty) - return penalty - - -class ScoreCalcFuncType(Protocol): - """Typed kwargs of the expected `Order.score_calc` function. Which is also returned by the - `RandomOrderKwarg.score_calc_gen_func`. - - The function should calculate the score for the completed orders. - - Args: - relative_order_time: `timedelta` the duration how long the order was active. - order: `Order` the order that was completed. - - Returns: - `float`: the score for a completed order and duration of the order. - """ - - def __call__(self, relative_order_time: timedelta, order: Order) -> float: - ... - - -class ScoreCalcGenFuncType(Protocol): - """Typed kwargs of the expected function for the `RandomOrderKwarg.score_calc_gen_func`. - - Generate the ScoreCalcFunc for an order based on its meal, duration etc. - - Args: - meal: `ItemInfo` the type of meal the order orders. - duration: `timedelta` the duration after the order expires. - now: `datetime` the environment time the order is created. - kwargs: `dict` the static kwargs defined in the `environment_config.yml` - - Returns: - `ScoreCalcFuncType` a reference to a function that calculates the score for a completed meal. - """ - - def __call__( - self, - meal: ItemInfo, - duration: timedelta, - now: datetime, - kwargs: dict, - **other_kwargs, - ) -> ScoreCalcFuncType: - ... - - -class ExpiredPenaltyFuncType(Protocol): - """Typed kwargs of the expected function for the `RandomOrderKwarg.expired_penalty_func`. - - An example is the `zero` function. - - Args: - item: `ItemInfo` the meal of the order that expired. It is calculated before the order is active. - """ - - def __call__(self, item: ItemInfo, **kwargs) -> float: - ... - - -def zero(item: ItemInfo, **kwargs) -> float: - """Example and default for the `RandomOrderKwarg.expired_penalty_func` function. - - Just no penalty for expired orders. - - Returns: - zero / 0.0 - """ - return 0.0 - class RandomFuncConfig(TypedDict): """Types of the dict for sampling with different random functions from the [`random` library](https://docs.python.org/3/library/random.html). @@ -470,14 +321,6 @@ class RandomOrderKwarg: """How many orders can maximally be active at the same time.""" order_duration_random_func: RandomFuncConfig """How long the order is alive until it expires. If `sample_on_serving` is `true` all orders have no expire time.""" - score_calc_gen_func: ScoreCalcGenFuncType - """The function that generates the `Order.score_calc` for each order.""" - score_calc_gen_kwargs: dict - """The additional static kwargs for `score_calc_gen_func`.""" - expired_penalty_func: Callable[[ItemInfo], float] = zero - """The function that calculates the penalty for a meal that was not served.""" - expired_penalty_kwargs: dict = dataclasses.field(default_factory=dict) - """The additional static kwargs for the `expired_penalty_func`.""" class RandomOrderGeneration(OrderGeneration): @@ -492,42 +335,39 @@ class RandomOrderGeneration(OrderGeneration): ```yaml orders: order_gen_class: !!python/name:overcooked_simulator.order.RandomOrderGeneration '' - kwargs: - order_duration_random_func: - # how long should the orders be alive - # 'random' library call with getattr, kwargs are passed to the function - func: uniform - kwargs: - a: 40 - b: 60 - max_orders: 6 - # maximum number of active orders at the same time - num_start_meals: 3 - # number of orders generated at the start of the environment - sample_on_dur_random_func: - # 'random' library call with getattr, kwargs are passed to the function - func: uniform - kwargs: - a: 10 - b: 20 - sample_on_serving: false - # Sample the delay for the next order only after a meal was served. - score_calc_gen_func: !!python/name:overcooked_simulator.order.simple_score_calc_gen_func '' - score_calc_gen_kwargs: - # the kwargs for the score_calc_gen_func - other: 0 - scores: - Burger: 15 - OnionSoup: 10 - Salad: 5 - TomatoSoup: 10 - expired_penalty_func: !!python/name:overcooked_simulator.order.simple_expired_penalty '' - expired_penalty_kwargs: - default: -5 + # the class to that receives the kwargs. Should be a child class of OrderGeneration in order.py + order_gen_kwargs: + order_duration_random_func: + # how long should the orders be alive + # 'random' library call with getattr, kwargs are passed to the function + func: uniform + kwargs: + a: 40 + b: 60 + max_orders: 6 + # maximum number of active orders at the same time + num_start_meals: 2 + # number of orders generated at the start of the environment + sample_on_dur_random_func: + # 'random' library call with getattr, kwargs are passed to the function + func: uniform + kwargs: + a: 10 + b: 20 + sample_on_serving: false + # Sample the delay for the next order only after a meal was served. + serving_not_ordered_meals: true + # can meals that are not ordered be served / dropped on the serving window ``` """ - def __init__(self, available_meals: dict[str, ItemInfo], hook: Hooks, random: Random, **kwargs): + def __init__( + self, + available_meals: dict[str, ItemInfo], + hook: Hooks, + random: Random, + **kwargs, + ): super().__init__(available_meals, hook, random, **kwargs) self.kwargs: RandomOrderKwarg = RandomOrderKwarg(**kwargs["kwargs"]) self.next_order_time: datetime | None = datetime.max @@ -604,16 +444,6 @@ class RandomOrderGeneration(OrderGeneration): meal=meal, start_time=now, max_duration=duration, - score_calc=self.kwargs.score_calc_gen_func( - meal=meal, - duration=duration, - now=now, - kwargs=self.kwargs.score_calc_gen_kwargs, - ), - timed_penalties=[], - expired_penalty=self.kwargs.expired_penalty_func( - meal, **self.kwargs.expired_penalty_kwargs - ), ) ) @@ -626,61 +456,3 @@ class RandomOrderGeneration(OrderGeneration): ) ) log.info(f"Next order in {self.next_order_time}") - - -def simple_score_calc_gen_func( - meal: Item, duration: timedelta, now: datetime, kwargs: dict, **other_kwargs -) -> Callable: - """An example for the `RandomOrderKwarg.score_calc_gen_func` that selects the score for an order based on its meal from a list. - - Example: - ```yaml - score_calc_gen_func: !!python/name:overcooked_simulator.order.simple_score_calc_gen_func '' - score_calc_gen_kwargs: - # the kwargs for the score_calc_gen_func - other: 0 - scores: - Burger: 15 - OnionSoup: 10 - Salad: 5 - TomatoSoup: 10 - ``` - """ - scores = kwargs["scores"] - other = kwargs["other"] - - def score_calc(relative_order_time: timedelta, order: Order) -> float: - if order.meal.name in scores: - return scores[order.meal.name] - return other - - return score_calc - - -def simple_expired_penalty(item: ItemInfo, default: float, **kwargs) -> float: - """Example for the `RandomOrderKwarg.expired_penalty_func` function. - - A static default. - - Example: - ```yaml - expired_penalty_func: !!python/name:overcooked_simulator.order.simple_expired_penalty '' - expired_penalty_kwargs: - default: -5 - ``` - """ - return default - - -def serving_not_ordered_meals_with_zero_score(meal: Item) -> Tuple[bool, float | int]: - """Not ordered meals are accepted but do not affect the score.""" - return True, 0 - -def serving_not_ordered_meals_with_five_score(meal: Item) -> Tuple[bool, float | int]: - """Not ordered meals are accepted but do not affect the score.""" - return True, 5 - -def penalty_for_each_item(remove: Item | list[Item]) -> float: - if isinstance(remove, list): - return -len(remove) * 5 - return -5 diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py index d46cec89b65462156f3a2ba61e4463ff6c77f11e..6d8dd883f32c9c82da5134e119177bff926b179c 100644 --- a/overcooked_simulator/overcooked_environment.py +++ b/overcooked_simulator/overcooked_environment.py @@ -46,11 +46,10 @@ from overcooked_simulator.hooks import ( ITEM_INFO_CONFIG, ) from overcooked_simulator.order import ( - OrderAndScoreManager, + OrderManager, OrderConfig, ) from overcooked_simulator.player import Player, PlayerConfig -from overcooked_simulator.state_representation import StateRepresentation from overcooked_simulator.utils import create_init_env_time, get_closest log = logging.getLogger(__name__) @@ -147,6 +146,9 @@ class Environment: self.hook: Hooks = Hooks(self) """Hook manager. Register callbacks and create hook points with additional kwargs.""" + self.score: float = 0.0 + """The current score of the environment.""" + self.players: dict[str, Player] = {} """the player, keyed by their id/name.""" @@ -192,7 +194,7 @@ class Environment: """The allowed meals depend on the `environment_config.yml` configured behaviour. Either all meals that are possible or only a limited subset.""" - self.order_and_score = OrderAndScoreManager( + self.order_and_score = OrderManager( order_config=self.environment_config["orders"], available_meals={ item: info @@ -223,7 +225,7 @@ class Environment: else {} ) ), - order_and_score=self.order_and_score, + order_manager=self.order_and_score, effect_manager_config=self.environment_config["effect_manager"], hook=self.hook, random=self.random, @@ -754,7 +756,7 @@ class Environment: return { "players": self.players, "counters": self.counters, - "score": self.order_and_score.score, + "score": self.score, "orders": self.order_and_score.open_orders, "ended": self.game_ended, "env_time": self.env_time, @@ -768,7 +770,7 @@ class Environment: "players": [p.to_dict() for p in self.players.values()], "counters": [c.to_dict() for c in self.counters], "kitchen": {"width": self.kitchen_width, "height": self.kitchen_height}, - "score": self.order_and_score.score, + "score": self.score, "orders": self.order_and_score.order_state(), "ended": self.game_ended, "env_time": self.env_time.isoformat(), @@ -810,3 +812,8 @@ class Environment: function_def["func"]( name=function_name, env=self, **function_def["kwargs"] ) + + def increment_score(self, score: int | float, info: str = ""): + """Add a value to the current score and log it.""" + self.score += score + log.debug(f"Score: {self.score} ({score}) - {info}") diff --git a/overcooked_simulator/recording.py b/overcooked_simulator/recording.py index 17a65ce0fa233a4013b6872bde7afbc33ccb584c..930f2c1c3d593b4700f46befc9132db8d5adf713 100644 --- a/overcooked_simulator/recording.py +++ b/overcooked_simulator/recording.py @@ -3,37 +3,27 @@ import logging import os import traceback from pathlib import Path -from typing import Any import platformdirs from overcooked_simulator import ROOT_DIR +from overcooked_simulator.hooks import HookCallbackClass from overcooked_simulator.overcooked_environment import Environment from overcooked_simulator.utils import NumpyAndDataclassEncoder log = logging.getLogger(__name__) -def class_recording_with_hooks( - name: str, - env: Environment, - hooks: list[str], - log_class, - log_class_kwargs: dict[str, Any], -): - recorder = log_class(name=name, env=env, **log_class_kwargs) - for hook in hooks: - env.register_callback_for_hook(hook, recorder) - - -class LogRecorder: +class FileRecorder(HookCallbackClass): def __init__( self, name: str, env: Environment, log_path: str = "USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl", add_hook_ref: bool = False, + **kwargs, ): + super().__init__(name, env, **kwargs) self.add_hook_ref = add_hook_ref log_path = log_path.replace("ENV_NAME", env.env_name).replace( "LOG_RECORD_NAME", name diff --git a/overcooked_simulator/scores.py b/overcooked_simulator/scores.py new file mode 100644 index 0000000000000000000000000000000000000000..d0e5799eee5ced60d9063329246fa92fd7ccf8fa --- /dev/null +++ b/overcooked_simulator/scores.py @@ -0,0 +1,43 @@ +from typing import Any + +from overcooked_simulator.hooks import HookCallbackClass +from overcooked_simulator.overcooked_environment import Environment + + +class ScoreViaHooks(HookCallbackClass): + def __init__( + self, + name: str, + env: Environment, + static_score: float = 0, + score_map: dict[str, float] = None, + score_on_specific_kwarg: str = None, + kwarg_filter: dict[str, Any] = None, + **kwargs, + ): + super().__init__(name, env, **kwargs) + self.score_map = score_map + self.static_score = static_score + self.kwarg_filter = kwarg_filter + self.score_on_specific_kwarg = score_on_specific_kwarg + + def __call__(self, hook_ref: str, env: Environment, **kwargs): + if self.score_on_specific_kwarg: + if kwargs[self.score_on_specific_kwarg] in self.score_map: + self.env.increment_score( + self.score_map[kwargs[self.score_on_specific_kwarg]], + info=f"{hook_ref} - {kwargs[self.score_on_specific_kwarg]}", + ) + else: + self.env.increment_score(self.static_score, info=hook_ref) + elif self.score_map and hook_ref in self.score_map: + if self.kwarg_filter: + if kwargs.items() <= self.kwarg_filter.items(): + self.env.increment_score( + self.score_map[hook_ref], + info=f"{hook_ref} - {self.kwarg_filter}", + ) + else: + self.env.increment_score(self.score_map[hook_ref], info=hook_ref) + else: + self.env.increment_score(self.static_score, info=hook_ref)