diff --git a/overcooked_simulator/counter_factory.py b/overcooked_simulator/counter_factory.py index 46a96a17a973c00e6e5820f83ad886fb0d3a17bc..f1d02672ba836c235c957279dde7526f0dca2b48 100644 --- a/overcooked_simulator/counter_factory.py +++ b/overcooked_simulator/counter_factory.py @@ -15,6 +15,7 @@ from overcooked_simulator.counters import ( Sink, PlateConfig, SinkAddon, + Trashcan, ) from overcooked_simulator.game_items import ItemInfo, ItemType, CookingEquipment, Plate from overcooked_simulator.order import OrderAndScoreManager @@ -175,6 +176,7 @@ class CounterFactory: ) elif issubclass(counter_class, ServingWindow): kwargs.update(self.serving_window_additional_kwargs) + if issubclass(counter_class, (ServingWindow, Trashcan)): kwargs[ "order_and_score" ] = self.order_and_score # individual because for the later trash scorer diff --git a/overcooked_simulator/counters.py b/overcooked_simulator/counters.py index 8fd1e658d173e158ec4ed840832d9fa0b70b2bba..c3e33392d38b4331273061c8c21a1bdd5687eb1a 100644 --- a/overcooked_simulator/counters.py +++ b/overcooked_simulator/counters.py @@ -470,13 +470,22 @@ class Trashcan(Counter): The character `X` in the `layout` file represents the Trashcan. """ + def __init__( + self, order_and_score: OrderAndScoreManager, pos: npt.NDArray[float], **kwargs + ): + super().__init__(pos, **kwargs) + self.order_and_score = order_and_score + def pick_up(self, on_hands: bool = True) -> Item | None: pass def drop_off(self, item: Item) -> Item | None: if isinstance(item, CookingEquipment): + self.order_and_score.apply_penalty_for_using_trash(item.content_list) item.reset_content() return item + else: + self.order_and_score.apply_penalty_for_using_trash(item) return None def can_drop_off(self, item: Item) -> bool: diff --git a/overcooked_simulator/game_content/environment_config.yaml b/overcooked_simulator/game_content/environment_config.yaml index d8d2d5f69ab32f03d74b67498b72a0227277687d..89a3ceb4b5571767463ec34b517aeac00dbed456 100644 --- a/overcooked_simulator/game_content/environment_config.yaml +++ b/overcooked_simulator/game_content/environment_config.yaml @@ -78,6 +78,8 @@ orders: default: -5 serving_not_ordered_meals: !!python/name:overcooked_simulator.order.serving_not_ordered_meals_with_zero_score '' # a func that calcs a store for not ordered but served meals. Input: meal + penalty_for_trash: !!python/name:overcooked_simulator.order.penalty_for_each_item '' + # a func that calcs the penalty for items that the player puts into the trashcan. player_config: radius: 0.4 diff --git a/overcooked_simulator/order.py b/overcooked_simulator/order.py index 68e5d311cf9cca0f30795486787a5597b5d82f93..b44d014d24a1df78cefcc1b5cf55b536f166b465 100644 --- a/overcooked_simulator/order.py +++ b/overcooked_simulator/order.py @@ -146,6 +146,14 @@ class OrderAndScoreManager: [Item], Tuple[bool, float] ] = order_config["serving_not_ordered_meals"] """Function that decides if not ordered meals can be served and what score it gives""" + + self.penalty_for_trash: Callable[[Item | list[Item]], float] = ( + order_config["penalty_for_trash"] + if "penalty_for_trash" in order_config + else zero + ) + """Function that calculates the penalty for items which were put into the trashcan.""" + self.available_meals = available_meals """The meals for that orders can be sampled from.""" self.open_orders: Deque[Order] = deque() @@ -285,6 +293,9 @@ class OrderAndScoreManager: for order in self.open_orders ] + def apply_penalty_for_using_trash(self, remove: Item | list[Item]): + self.increment_score(self.penalty_for_trash(remove)) + class ScoreCalcFuncType(Protocol): """Typed kwargs of the expected `Order.score_calc` function. Which is also returned by the @@ -592,3 +603,9 @@ def simple_expired_penalty(item: ItemInfo, default: float, **kwargs) -> float: def serving_not_ordered_meals_with_zero_score(meal: Item) -> Tuple[bool, float | int]: """Not ordered meals are accepted but do not affect the score.""" return True, 0 + + +def penalty_for_each_item(remove: Item | list[Item]) -> float: + if isinstance(remove, list): + return -len(remove) * 5 + return -5