From 3c1fda71c3858f0d7a442929f5130960b2ca816a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20Schr=C3=B6der?= <fschroeder@techfak.uni-bielefeld.de> Date: Thu, 8 Feb 2024 11:56:47 +0100 Subject: [PATCH] Refactor scoring system and order management The scoring system and order management have been significantly updated. The score calculation has been moved out from 'OrderManager' to external 'ScoreViaHooks' class which works via hooks. The order management updates include separating scoring from orders and removing unused functions. Changes in the configuration file and other related files updated to accommodate these modifications. --- overcooked_simulator/counter_factory.py | 10 +- overcooked_simulator/counters.py | 24 +- .../game_content/environment_config.yaml | 84 +++-- .../game_content/environment_config_rl.yaml | 39 +- overcooked_simulator/hooks.py | 32 +- overcooked_simulator/order.py | 338 +++--------------- .../overcooked_environment.py | 19 +- overcooked_simulator/recording.py | 18 +- overcooked_simulator/scores.py | 43 +++ 9 files changed, 251 insertions(+), 356 deletions(-) create mode 100644 overcooked_simulator/scores.py diff --git a/overcooked_simulator/counter_factory.py b/overcooked_simulator/counter_factory.py index 9b1487c3..22f56cf9 100644 --- a/overcooked_simulator/counter_factory.py +++ b/overcooked_simulator/counter_factory.py @@ -60,7 +60,7 @@ from overcooked_simulator.game_items import ( Item, ) from overcooked_simulator.hooks import Hooks -from overcooked_simulator.order import OrderAndScoreManager +from overcooked_simulator.order import OrderManager from overcooked_simulator.utils import get_closest T = TypeVar("T") @@ -119,7 +119,7 @@ class CounterFactory: item_info: dict[str, ItemInfo], serving_window_additional_kwargs: dict[str, Any], plate_config: PlateConfig, - order_and_score: OrderAndScoreManager, + order_manager: OrderManager, effect_manager_config: dict, hook: Hooks, random: Random, @@ -150,7 +150,7 @@ class CounterFactory: """The additional keyword arguments for the serving window.""" self.plate_config: PlateConfig = plate_config """The plate config from the `environment_config`""" - self.order_and_score: OrderAndScoreManager = order_and_score + self.order_manager: OrderManager = order_manager """The order and score manager to pass to `ServingWindow` and the `Tashcan` which can affect the scores.""" self.effect_manager_config = effect_manager_config """The effect manager config to setup the effect manager based on the defined effects in the item info.""" @@ -246,8 +246,8 @@ class CounterFactory: kwargs.update(self.serving_window_additional_kwargs) if issubclass(counter_class, (ServingWindow, Trashcan)): kwargs[ - "order_and_score" - ] = self.order_and_score # individual because for the later trash scorer + "order_manager" + ] = self.order_manager # individual because for the later trash scorer return counter_class(**kwargs) def can_map(self, char) -> bool: diff --git a/overcooked_simulator/counters.py b/overcooked_simulator/counters.py index a7f126d7..8ce066ab 100644 --- a/overcooked_simulator/counters.py +++ b/overcooked_simulator/counters.py @@ -66,7 +66,7 @@ from overcooked_simulator.hooks import ( if TYPE_CHECKING: from overcooked_simulator.effect_manager import Effect from overcooked_simulator.overcooked_environment import ( - OrderAndScoreManager, + OrderManager, ) import numpy as np @@ -339,13 +339,13 @@ class ServingWindow(Counter): def __init__( self, - order_and_score: OrderAndScoreManager, + order_manager: OrderManager, meals: set[str], env_time_func: Callable[[], datetime], plate_dispenser: PlateDispenser = None, **kwargs, ): - self.order_and_score: OrderAndScoreManager = order_and_score + self.order_manager: OrderManager = order_manager """Reference to the OrderAndScoreManager class. It determines which meals can be served and it manages the score.""" self.plate_dispenser: PlateDispenser = plate_dispenser @@ -360,7 +360,7 @@ class ServingWindow(Counter): def drop_off(self, item) -> Item | None: env_time = self.env_time_func() self.hook(PRE_SERVING, counter=self, item=item, env_time=env_time) - if self.order_and_score.serve_meal(item=item, env_time=env_time): + if self.order_manager.serve_meal(item=item, env_time=env_time): if self.plate_dispenser is not None: self.plate_dispenser.update_plate_out_of_kitchen(env_time=env_time) self.hook(POST_SERVING, counter=self, item=item, env_time=env_time) @@ -464,6 +464,7 @@ class PlateConfig: return_dirty: bool = True """Specifies if plates are returned dirty or clean to the plate dispenser.""" + class PlateDispenser(Counter): """At the moment, one and only one plate dispenser must exist in an environment, because only at one place the dirty plates should arrive. @@ -523,7 +524,9 @@ class PlateDispenser(Counter): def add_dirty_plate(self): """Add a dirty plate after a timer is completed.""" - self.occupied_by.appendleft(self.create_item(clean=not self.plate_config.return_dirty)) + self.occupied_by.appendleft( + self.create_item(clean=not self.plate_config.return_dirty) + ) def update_plate_out_of_kitchen(self, env_time: datetime): """Is called from the serving window to add a plate out of kitchen.""" @@ -597,10 +600,8 @@ class Trashcan(Counter): The character `X` in the `layout` file represents the Trashcan. """ - def __init__(self, order_and_score: OrderAndScoreManager, **kwargs): + def __init__(self, **kwargs): super().__init__(**kwargs) - self.order_and_score: OrderAndScoreManager = order_and_score - """Reference to the `OrderAndScoreManager`, because unnecessary removed items can affect the score.""" def pick_up(self, on_hands: bool = True) -> Item | None: pass @@ -613,14 +614,9 @@ class Trashcan(Counter): ): return item if isinstance(item, CookingEquipment): - penalty = self.order_and_score.apply_penalty_for_using_trash( - item.content_list - ) item.reset_content() return item - else: - penalty = self.order_and_score.apply_penalty_for_using_trash(item) - self.hook(TRASHCAN_USAGE, counter=self, item=item, penalty=penalty) + self.hook(TRASHCAN_USAGE, counter=self, item=item) return None def can_drop_off(self, item: Item) -> bool: diff --git a/overcooked_simulator/game_content/environment_config.yaml b/overcooked_simulator/game_content/environment_config.yaml index 0b24c354..6d290b82 100644 --- a/overcooked_simulator/game_content/environment_config.yaml +++ b/overcooked_simulator/game_content/environment_config.yaml @@ -1,8 +1,7 @@ plates: - clean_plates: 0 + clean_plates: 1 dirty_plates: 2 plate_delay: [ 5, 10 ] - return_dirty: True # range of seconds until the dirty plate arrives. game: @@ -67,28 +66,14 @@ orders: b: 20 sample_on_serving: false # Sample the delay for the next order only after a meal was served. - score_calc_gen_func: !!python/name:overcooked_simulator.order.simple_score_calc_gen_func '' - score_calc_gen_kwargs: - # the kwargs for the score_calc_gen_func - other: 20 - scores: - Burger: 15 - OnionSoup: 10 - Salad: 5 - TomatoSoup: 10 - expired_penalty_func: !!python/name:overcooked_simulator.order.simple_expired_penalty '' - expired_penalty_kwargs: - default: -5 - serving_not_ordered_meals: !!python/name:overcooked_simulator.order.serving_not_ordered_meals_with_zero_score '' - # a func that calcs a store for not ordered but served meals. Input: meal - penalty_for_trash: !!python/name:overcooked_simulator.order.penalty_for_each_item '' - # a func that calcs the penalty for items that the player puts into the trashcan. + serving_not_ordered_meals: true + # can meals that are not ordered be served / dropped on the serving window player_config: radius: 0.4 player_speed_units_per_seconds: 6 interaction_range: 1.6 - restricted_view: False + restricted_view: false view_angle: 95 effect_manager: @@ -100,34 +85,71 @@ effect_manager: extra_setup_functions: + # # --------------- Scoring --------------- + orders: + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' + kwargs: + hooks: [ completed_order ] + callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' + callback_class_kwargs: + static_score: 20 + score_on_specific_kwarg: meal_name + score_map: + Burger: 15 + OnionSoup: 10 + Salad: 5 + TomatoSoup: 10 + not_ordered_meals: + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' + kwargs: + hooks: [ serve_not_ordered_meal ] + callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' + callback_class_kwargs: + static_score: 2 + trashcan_usages: + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' + kwargs: + hooks: [ trashcan_usage ] + callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' + callback_class_kwargs: + static_score: -5 + expired_orders: + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' + kwargs: + hooks: [ order_expired ] + callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' + callback_class_kwargs: + static_score: -10 + # # --------------- Recording --------------- # json_states: - # func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' + # func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' # kwargs: # hooks: [ json_state ] - # log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' - # log_class_kwargs: + # callback_class: !!python/name:overcooked_simulator.recording.FileRecorder '' + # callback_class_kwargs: # log_path: USER_LOG_DIR/ENV_NAME/json_states.jsonl actions: - func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' kwargs: hooks: [ pre_perform_action ] - log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' - log_class_kwargs: + callback_class: !!python/name:overcooked_simulator.recording.FileRecorder '' + callback_class_kwargs: log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl random_env_events: - func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' kwargs: hooks: [ order_duration_sample, plate_out_of_kitchen_time ] - log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' - log_class_kwargs: + callback_class: !!python/name:overcooked_simulator.recording.FileRecorder '' + callback_class_kwargs: log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl add_hook_ref: true env_configs: - func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' kwargs: hooks: [ env_initialized, item_info_config ] - log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' - log_class_kwargs: + callback_class: !!python/name:overcooked_simulator.recording.FileRecorder '' + callback_class_kwargs: log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl add_hook_ref: true + diff --git a/overcooked_simulator/game_content/environment_config_rl.yaml b/overcooked_simulator/game_content/environment_config_rl.yaml index 6235a971..7702f835 100644 --- a/overcooked_simulator/game_content/environment_config_rl.yaml +++ b/overcooked_simulator/game_content/environment_config_rl.yaml @@ -71,7 +71,7 @@ orders: score_calc_gen_kwargs: # the kwargs for the score_calc_gen_func other: 0 - scores: [] + scores: [ ] expired_penalty_func: !!python/name:overcooked_simulator.order.simple_expired_penalty '' expired_penalty_kwargs: default: 0 @@ -87,7 +87,7 @@ player_config: restricted_view: False view_angle: 95 -effect_manager: {} +effect_manager: { } # FireManager: # class: !!python/name:overcooked_simulator.effect_manager.FireEffectManager '' # kwargs: @@ -96,6 +96,41 @@ effect_manager: {} extra_setup_functions: + # # --------------- Scoring --------------- + orders: + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' + kwargs: + hooks: [ completed_order ] + callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' + callback_class_kwargs: + static_score: 20 + score_on_specific_kwarg: meal_name + score_map: + Burger: 15 + OnionSoup: 10 + Salad: 5 + TomatoSoup: 10 + not_ordered_meals: + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' + kwargs: + hooks: [ serve_not_ordered_meal ] + callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' + callback_class_kwargs: + static_score: 2 + trashcan_usages: + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' + kwargs: + hooks: [ trashcan_usage ] + callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' + callback_class_kwargs: + static_score: -5 + expired_orders: + func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' + kwargs: + hooks: [ order_expired ] + callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' + callback_class_kwargs: + static_score: -10 # json_states: # func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' # kwargs: diff --git a/overcooked_simulator/hooks.py b/overcooked_simulator/hooks.py index 285ed735..9d907fcd 100644 --- a/overcooked_simulator/hooks.py +++ b/overcooked_simulator/hooks.py @@ -1,6 +1,12 @@ +from __future__ import annotations + +from abc import abstractmethod from collections import defaultdict from functools import partial -from typing import Callable +from typing import Callable, Any, TYPE_CHECKING, Type + +if TYPE_CHECKING: + from overcooked_simulator.overcooked_environment import Environment # TODO add player_id as kwarg to all hooks -> pass player id to all methods @@ -90,6 +96,8 @@ INIT_ORDERS = "init_orders" NEW_ORDERS = "new_orders" +ORDER_EXPIRED = "order_expired" + ACTION_ON_NOT_REACHABLE_COUNTER = "action_on_not_reachable_counter" ACTION_PUT = "action_put" @@ -118,6 +126,28 @@ def print_hook_callback(text, env, **kwargs): print(env.env_time, text) +class HookCallbackClass: + def __init__(self, name: str, env: Environment, **kwargs): + self.name = name + self.env = env + + @abstractmethod + def __call__(self, hook_ref: str, env: Environment, **kwargs): + ... + + +def hooks_via_callback_class( + name: str, + env: Environment, + hooks: list[str], + callback_class: Type[HookCallbackClass], + callback_class_kwargs: dict[str, Any], +): + recorder = callback_class(name=name, env=env, **callback_class_kwargs) + for hook in hooks: + env.register_callback_for_hook(hook, recorder) + + def add_dummy_callbacks(env): env.register_callback_for_hook( SERVE_NOT_ORDERED_MEAL, diff --git a/overcooked_simulator/order.py b/overcooked_simulator/order.py index 990be494..7f60abb5 100644 --- a/overcooked_simulator/order.py +++ b/overcooked_simulator/order.py @@ -25,19 +25,10 @@ This file defines the following classes: Further, it defines same implementations for the basic order generation based on random sampling: - `RandomOrderGeneration` -- `simple_score_calc_gen_func` -- `simple_expired_penalty` -- `zero` For an easier usage of the random orders, also some classes for type hints and dataclasses are defined: - `RandomOrderKwarg` - `RandomFuncConfig` -- `ScoreCalcFuncType` -- `ScoreCalcGenFuncType` -- `ExpiredPenaltyFuncType` - -For the scoring of using the trashcan the `penalty_for_each_item` example function is defined. You can set/replace it -in the `environment_config`. ## Code Documentation @@ -51,7 +42,7 @@ from abc import abstractmethod from collections import deque from datetime import datetime, timedelta from random import Random -from typing import Callable, Tuple, Any, Deque, Protocol, TypedDict, Type +from typing import Callable, Tuple, Any, Deque, TypedDict, Type from overcooked_simulator.game_items import Item, Plate, ItemInfo from overcooked_simulator.hooks import ( @@ -62,6 +53,7 @@ from overcooked_simulator.hooks import ( INIT_ORDERS, NEW_ORDERS, ORDER_DURATION_SAMPLE, + ORDER_EXPIRED, ) log = logging.getLogger(__name__) @@ -92,37 +84,10 @@ class Order: """The start time relative to the env_time. On which the order is returned from the get_orders func.""" max_duration: timedelta """The duration after which the order expires.""" - score_calc: ScoreCalcFuncType - """The function that calculates the score of the served meal/fulfilled order.""" - timed_penalties: list[ - Tuple[timedelta, float] | Tuple[timedelta, float, int, timedelta] - ] - """List of timed penalties when the order is not fulfilled.""" - expired_penalty: float - """The penalty to the score if the order expires""" uuid: str = dataclasses.field(default_factory=lambda: uuid.uuid4().hex) """The unique identifier for the order.""" finished_info: dict[str, Any] = dataclasses.field(default_factory=dict) """Is set after the order is completed.""" - _timed_penalties: list[Tuple[datetime, float]] = dataclasses.field( - default_factory=list - ) - """Converted penalties the env is working with from the `timed_penalties`""" - - def create_penalties(self, env_time: datetime): - """Create the general timed penalties list to check for penalties after some time the order is still not - fulfilled.""" - for penalty_info in self.timed_penalties: - match penalty_info: - case (offset, penalty): - self._timed_penalties.append((env_time + offset, penalty)) - case (duration, penalty, number_repeat, offset): - self._timed_penalties.extend( - [ - (env_time + offset + (duration * i), penalty) - for i in range(number_repeat) - ] - ) class OrderGeneration: @@ -138,7 +103,13 @@ class OrderGeneration: ``` """ - def __init__(self, available_meals: dict[str, ItemInfo], hook: Hooks, random: Random, **kwargs): + def __init__( + self, + available_meals: dict[str, ItemInfo], + hook: Hooks, + random: Random, + **kwargs, + ): self.available_meals: list[ItemInfo] = list(available_meals.values()) """Available meals restricted through the `environment_config.yml`.""" self.hook = hook @@ -163,7 +134,7 @@ class OrderGeneration: ... -class OrderAndScoreManager: +class OrderManager: """The Order and Score Manager that is called from the serving window.""" def __init__( @@ -175,8 +146,6 @@ class OrderAndScoreManager: ): self.random = random """Random instance.""" - self.score: float = 0.0 - """The current score of the environment.""" self.order_gen: OrderGeneration = order_config["order_gen_class"]( available_meals=available_meals, hook=hook, @@ -189,13 +158,6 @@ class OrderAndScoreManager: ] = order_config["serving_not_ordered_meals"] """Function that decides if not ordered meals can be served and what score it gives""" - self.penalty_for_trash: Callable[[Item | list[Item]], float] = ( - order_config["penalty_for_trash"] - if "penalty_for_trash" in order_config - else zero - ) - """Function that calculates the penalty for items which were put into the trashcan.""" - self.available_meals = available_meals """The meals for that orders can be sampled from.""" self.open_orders: Deque[Order] = deque() @@ -207,7 +169,7 @@ class OrderAndScoreManager: self.last_finished: list[Order] = [] """Cache last finished orders for `OrderGeneration.get_orders` call. From the served meals.""" self.next_relevant_time: datetime = datetime.max - """For reduced order checking. Store the next time when to create an order or check for penalties.""" + """For reduced order checking. Store the next time when to create an order.""" self.last_expired: list[Order] = [] """Cache last expired orders for `OrderGeneration.get_orders` call.""" @@ -221,8 +183,6 @@ class OrderAndScoreManager: next_relevant_time = min( next_relevant_time, order.start_time + order.max_duration ) - for penalty in order._timed_penalties: - next_relevant_time = min(next_relevant_time, penalty[0]) self.next_relevant_time = next_relevant_time def serve_meal(self, item: Item, env_time: datetime) -> bool: @@ -235,57 +195,40 @@ class OrderAndScoreManager: order = self.find_order_for_meal(meal) if order is None: if self.serving_not_ordered_meals: - accept, score = self.serving_not_ordered_meals(meal) self.hook( SERVE_NOT_ORDERED_MEAL, - accept=accept, - score=score, meal=meal, + meal_name=meal.name, ) - if accept: - log.info( - f"Serving meal without order {meal.name!r} with score {score}" - ) - self.increment_score(score) - self.served_meals.append((meal, env_time)) - return accept + log.info(f"Serving meal without order {meal.name!r}") + self.served_meals.append((meal, env_time)) + return True log.info( f"Do not serve meal {meal.name!r} because it is not ordered" ) return False order, index = order - score = order.score_calc( - relative_order_time=env_time - order.start_time, - order=order, - ) - self.increment_score(score) - order.finished_info = { - "end_time": env_time, - "score": score, - } - log.info( - f"Serving meal {meal.name!r} with order with score {score}" - ) + log.info(f"Serving meal {meal.name!r} with order") self.last_finished.append(order) del self.open_orders[index] self.served_meals.append((meal, env_time)) - self.hook(COMPLETED_ORDER, score=score, order=order, meal=meal) + self.hook( + COMPLETED_ORDER, + order=order, + meal=meal, + relative_order_time=env_time - order.start_time, + meal_name=meal.name, + ) return True else: self.hook(SERVE_WITHOUT_PLATE, item=item) log.info(f"Do not serve item {item}") return False - def increment_score(self, score: int | float): - """Add a value to the current score and log it.""" - self.score += score - log.debug(f"Score: {self.score}") - def create_init_orders(self, env_time): """Create the initial orders in an environment.""" init_orders = self.order_gen.init_orders(env_time) self.hook(INIT_ORDERS) - self.setup_penalties(new_orders=init_orders, env_time=env_time) self.open_orders.extend(init_orders) def progress(self, passed_time: timedelta, now: datetime): @@ -298,7 +241,6 @@ class OrderAndScoreManager: ) if new_orders: self.hook(NEW_ORDERS, new_orders=new_orders) - self.setup_penalties(new_orders=new_orders, env_time=now) self.open_orders.extend(new_orders) self.last_finished = [] self.last_expired = [] @@ -309,20 +251,9 @@ class OrderAndScoreManager: for index, order in enumerate(self.open_orders): if now >= order.start_time + order.max_duration: # orders expired - self.increment_score(order.expired_penalty) + self.hook(ORDER_EXPIRED, order=order) remove_orders.append(index) - continue # no penalties for expired orders - remove_penalties = [] - for i, (penalty_time, penalty) in enumerate(order.timed_penalties): - # check penalties - if penalty_time < now: - # TODO add hook - self.score -= penalty - remove_penalties.append(i) - - for i in reversed(remove_penalties): - # or del order.timed_penalties[index] - order.timed_penalties.pop(i) + continue expired_orders: list[Order] = [] for remove_order in reversed(remove_orders): @@ -339,12 +270,6 @@ class OrderAndScoreManager: if order.meal.name == meal.name: return order, index - @staticmethod - def setup_penalties(new_orders: list[Order], env_time: datetime): - """Call the `Order.create_penalties` method for new orders.""" - for order in new_orders: - order.create_penalties(env_time) - def order_state(self) -> list[dict]: """Similar to the `to_dict` in `Item` and `Counter`. Relevant for the state of the environment""" return [ @@ -358,80 +283,6 @@ class OrderAndScoreManager: for order in self.open_orders ] - def apply_penalty_for_using_trash(self, remove: Item | list[Item]) -> float: - """Is called if a item is put into the trashcan.""" - penalty = self.penalty_for_trash(remove) - self.increment_score(penalty) - return penalty - - -class ScoreCalcFuncType(Protocol): - """Typed kwargs of the expected `Order.score_calc` function. Which is also returned by the - `RandomOrderKwarg.score_calc_gen_func`. - - The function should calculate the score for the completed orders. - - Args: - relative_order_time: `timedelta` the duration how long the order was active. - order: `Order` the order that was completed. - - Returns: - `float`: the score for a completed order and duration of the order. - """ - - def __call__(self, relative_order_time: timedelta, order: Order) -> float: - ... - - -class ScoreCalcGenFuncType(Protocol): - """Typed kwargs of the expected function for the `RandomOrderKwarg.score_calc_gen_func`. - - Generate the ScoreCalcFunc for an order based on its meal, duration etc. - - Args: - meal: `ItemInfo` the type of meal the order orders. - duration: `timedelta` the duration after the order expires. - now: `datetime` the environment time the order is created. - kwargs: `dict` the static kwargs defined in the `environment_config.yml` - - Returns: - `ScoreCalcFuncType` a reference to a function that calculates the score for a completed meal. - """ - - def __call__( - self, - meal: ItemInfo, - duration: timedelta, - now: datetime, - kwargs: dict, - **other_kwargs, - ) -> ScoreCalcFuncType: - ... - - -class ExpiredPenaltyFuncType(Protocol): - """Typed kwargs of the expected function for the `RandomOrderKwarg.expired_penalty_func`. - - An example is the `zero` function. - - Args: - item: `ItemInfo` the meal of the order that expired. It is calculated before the order is active. - """ - - def __call__(self, item: ItemInfo, **kwargs) -> float: - ... - - -def zero(item: ItemInfo, **kwargs) -> float: - """Example and default for the `RandomOrderKwarg.expired_penalty_func` function. - - Just no penalty for expired orders. - - Returns: - zero / 0.0 - """ - return 0.0 - class RandomFuncConfig(TypedDict): """Types of the dict for sampling with different random functions from the [`random` library](https://docs.python.org/3/library/random.html). @@ -470,14 +321,6 @@ class RandomOrderKwarg: """How many orders can maximally be active at the same time.""" order_duration_random_func: RandomFuncConfig """How long the order is alive until it expires. If `sample_on_serving` is `true` all orders have no expire time.""" - score_calc_gen_func: ScoreCalcGenFuncType - """The function that generates the `Order.score_calc` for each order.""" - score_calc_gen_kwargs: dict - """The additional static kwargs for `score_calc_gen_func`.""" - expired_penalty_func: Callable[[ItemInfo], float] = zero - """The function that calculates the penalty for a meal that was not served.""" - expired_penalty_kwargs: dict = dataclasses.field(default_factory=dict) - """The additional static kwargs for the `expired_penalty_func`.""" class RandomOrderGeneration(OrderGeneration): @@ -492,42 +335,39 @@ class RandomOrderGeneration(OrderGeneration): ```yaml orders: order_gen_class: !!python/name:overcooked_simulator.order.RandomOrderGeneration '' - kwargs: - order_duration_random_func: - # how long should the orders be alive - # 'random' library call with getattr, kwargs are passed to the function - func: uniform - kwargs: - a: 40 - b: 60 - max_orders: 6 - # maximum number of active orders at the same time - num_start_meals: 3 - # number of orders generated at the start of the environment - sample_on_dur_random_func: - # 'random' library call with getattr, kwargs are passed to the function - func: uniform - kwargs: - a: 10 - b: 20 - sample_on_serving: false - # Sample the delay for the next order only after a meal was served. - score_calc_gen_func: !!python/name:overcooked_simulator.order.simple_score_calc_gen_func '' - score_calc_gen_kwargs: - # the kwargs for the score_calc_gen_func - other: 0 - scores: - Burger: 15 - OnionSoup: 10 - Salad: 5 - TomatoSoup: 10 - expired_penalty_func: !!python/name:overcooked_simulator.order.simple_expired_penalty '' - expired_penalty_kwargs: - default: -5 + # the class to that receives the kwargs. Should be a child class of OrderGeneration in order.py + order_gen_kwargs: + order_duration_random_func: + # how long should the orders be alive + # 'random' library call with getattr, kwargs are passed to the function + func: uniform + kwargs: + a: 40 + b: 60 + max_orders: 6 + # maximum number of active orders at the same time + num_start_meals: 2 + # number of orders generated at the start of the environment + sample_on_dur_random_func: + # 'random' library call with getattr, kwargs are passed to the function + func: uniform + kwargs: + a: 10 + b: 20 + sample_on_serving: false + # Sample the delay for the next order only after a meal was served. + serving_not_ordered_meals: true + # can meals that are not ordered be served / dropped on the serving window ``` """ - def __init__(self, available_meals: dict[str, ItemInfo], hook: Hooks, random: Random, **kwargs): + def __init__( + self, + available_meals: dict[str, ItemInfo], + hook: Hooks, + random: Random, + **kwargs, + ): super().__init__(available_meals, hook, random, **kwargs) self.kwargs: RandomOrderKwarg = RandomOrderKwarg(**kwargs["kwargs"]) self.next_order_time: datetime | None = datetime.max @@ -604,16 +444,6 @@ class RandomOrderGeneration(OrderGeneration): meal=meal, start_time=now, max_duration=duration, - score_calc=self.kwargs.score_calc_gen_func( - meal=meal, - duration=duration, - now=now, - kwargs=self.kwargs.score_calc_gen_kwargs, - ), - timed_penalties=[], - expired_penalty=self.kwargs.expired_penalty_func( - meal, **self.kwargs.expired_penalty_kwargs - ), ) ) @@ -626,61 +456,3 @@ class RandomOrderGeneration(OrderGeneration): ) ) log.info(f"Next order in {self.next_order_time}") - - -def simple_score_calc_gen_func( - meal: Item, duration: timedelta, now: datetime, kwargs: dict, **other_kwargs -) -> Callable: - """An example for the `RandomOrderKwarg.score_calc_gen_func` that selects the score for an order based on its meal from a list. - - Example: - ```yaml - score_calc_gen_func: !!python/name:overcooked_simulator.order.simple_score_calc_gen_func '' - score_calc_gen_kwargs: - # the kwargs for the score_calc_gen_func - other: 0 - scores: - Burger: 15 - OnionSoup: 10 - Salad: 5 - TomatoSoup: 10 - ``` - """ - scores = kwargs["scores"] - other = kwargs["other"] - - def score_calc(relative_order_time: timedelta, order: Order) -> float: - if order.meal.name in scores: - return scores[order.meal.name] - return other - - return score_calc - - -def simple_expired_penalty(item: ItemInfo, default: float, **kwargs) -> float: - """Example for the `RandomOrderKwarg.expired_penalty_func` function. - - A static default. - - Example: - ```yaml - expired_penalty_func: !!python/name:overcooked_simulator.order.simple_expired_penalty '' - expired_penalty_kwargs: - default: -5 - ``` - """ - return default - - -def serving_not_ordered_meals_with_zero_score(meal: Item) -> Tuple[bool, float | int]: - """Not ordered meals are accepted but do not affect the score.""" - return True, 0 - -def serving_not_ordered_meals_with_five_score(meal: Item) -> Tuple[bool, float | int]: - """Not ordered meals are accepted but do not affect the score.""" - return True, 5 - -def penalty_for_each_item(remove: Item | list[Item]) -> float: - if isinstance(remove, list): - return -len(remove) * 5 - return -5 diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py index d46cec89..6d8dd883 100644 --- a/overcooked_simulator/overcooked_environment.py +++ b/overcooked_simulator/overcooked_environment.py @@ -46,11 +46,10 @@ from overcooked_simulator.hooks import ( ITEM_INFO_CONFIG, ) from overcooked_simulator.order import ( - OrderAndScoreManager, + OrderManager, OrderConfig, ) from overcooked_simulator.player import Player, PlayerConfig -from overcooked_simulator.state_representation import StateRepresentation from overcooked_simulator.utils import create_init_env_time, get_closest log = logging.getLogger(__name__) @@ -147,6 +146,9 @@ class Environment: self.hook: Hooks = Hooks(self) """Hook manager. Register callbacks and create hook points with additional kwargs.""" + self.score: float = 0.0 + """The current score of the environment.""" + self.players: dict[str, Player] = {} """the player, keyed by their id/name.""" @@ -192,7 +194,7 @@ class Environment: """The allowed meals depend on the `environment_config.yml` configured behaviour. Either all meals that are possible or only a limited subset.""" - self.order_and_score = OrderAndScoreManager( + self.order_and_score = OrderManager( order_config=self.environment_config["orders"], available_meals={ item: info @@ -223,7 +225,7 @@ class Environment: else {} ) ), - order_and_score=self.order_and_score, + order_manager=self.order_and_score, effect_manager_config=self.environment_config["effect_manager"], hook=self.hook, random=self.random, @@ -754,7 +756,7 @@ class Environment: return { "players": self.players, "counters": self.counters, - "score": self.order_and_score.score, + "score": self.score, "orders": self.order_and_score.open_orders, "ended": self.game_ended, "env_time": self.env_time, @@ -768,7 +770,7 @@ class Environment: "players": [p.to_dict() for p in self.players.values()], "counters": [c.to_dict() for c in self.counters], "kitchen": {"width": self.kitchen_width, "height": self.kitchen_height}, - "score": self.order_and_score.score, + "score": self.score, "orders": self.order_and_score.order_state(), "ended": self.game_ended, "env_time": self.env_time.isoformat(), @@ -810,3 +812,8 @@ class Environment: function_def["func"]( name=function_name, env=self, **function_def["kwargs"] ) + + def increment_score(self, score: int | float, info: str = ""): + """Add a value to the current score and log it.""" + self.score += score + log.debug(f"Score: {self.score} ({score}) - {info}") diff --git a/overcooked_simulator/recording.py b/overcooked_simulator/recording.py index 17a65ce0..930f2c1c 100644 --- a/overcooked_simulator/recording.py +++ b/overcooked_simulator/recording.py @@ -3,37 +3,27 @@ import logging import os import traceback from pathlib import Path -from typing import Any import platformdirs from overcooked_simulator import ROOT_DIR +from overcooked_simulator.hooks import HookCallbackClass from overcooked_simulator.overcooked_environment import Environment from overcooked_simulator.utils import NumpyAndDataclassEncoder log = logging.getLogger(__name__) -def class_recording_with_hooks( - name: str, - env: Environment, - hooks: list[str], - log_class, - log_class_kwargs: dict[str, Any], -): - recorder = log_class(name=name, env=env, **log_class_kwargs) - for hook in hooks: - env.register_callback_for_hook(hook, recorder) - - -class LogRecorder: +class FileRecorder(HookCallbackClass): def __init__( self, name: str, env: Environment, log_path: str = "USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl", add_hook_ref: bool = False, + **kwargs, ): + super().__init__(name, env, **kwargs) self.add_hook_ref = add_hook_ref log_path = log_path.replace("ENV_NAME", env.env_name).replace( "LOG_RECORD_NAME", name diff --git a/overcooked_simulator/scores.py b/overcooked_simulator/scores.py new file mode 100644 index 00000000..d0e5799e --- /dev/null +++ b/overcooked_simulator/scores.py @@ -0,0 +1,43 @@ +from typing import Any + +from overcooked_simulator.hooks import HookCallbackClass +from overcooked_simulator.overcooked_environment import Environment + + +class ScoreViaHooks(HookCallbackClass): + def __init__( + self, + name: str, + env: Environment, + static_score: float = 0, + score_map: dict[str, float] = None, + score_on_specific_kwarg: str = None, + kwarg_filter: dict[str, Any] = None, + **kwargs, + ): + super().__init__(name, env, **kwargs) + self.score_map = score_map + self.static_score = static_score + self.kwarg_filter = kwarg_filter + self.score_on_specific_kwarg = score_on_specific_kwarg + + def __call__(self, hook_ref: str, env: Environment, **kwargs): + if self.score_on_specific_kwarg: + if kwargs[self.score_on_specific_kwarg] in self.score_map: + self.env.increment_score( + self.score_map[kwargs[self.score_on_specific_kwarg]], + info=f"{hook_ref} - {kwargs[self.score_on_specific_kwarg]}", + ) + else: + self.env.increment_score(self.static_score, info=hook_ref) + elif self.score_map and hook_ref in self.score_map: + if self.kwarg_filter: + if kwargs.items() <= self.kwarg_filter.items(): + self.env.increment_score( + self.score_map[hook_ref], + info=f"{hook_ref} - {self.kwarg_filter}", + ) + else: + self.env.increment_score(self.score_map[hook_ref], info=hook_ref) + else: + self.env.increment_score(self.static_score, info=hook_ref) -- GitLab