diff --git a/overcooked_simulator/__init__.py b/overcooked_simulator/__init__.py index fb941f2fadd719d7f44a3c1977aac8ee5f416538..9a2ff77f21ee63b168b6e4513f95c953e12be649 100644 --- a/overcooked_simulator/__init__.py +++ b/overcooked_simulator/__init__.py @@ -4,7 +4,7 @@ This is the documentation of the Overcooked Simulator. # About the package -The package contains of an environment for cooperation between players/agents. A PyGameGUI visualizes the game to +The package contains an environment for cooperation between players/agents. A PyGameGUI visualizes the game to human or visual agents in 2D. A 3D web-enabled version (for example for online studies, currently under development) can be found [here](https://gitlab.ub.uni-bielefeld.de/scs/cocosy/godot-overcooked-3d-visualization) diff --git a/overcooked_simulator/counter_factory.py b/overcooked_simulator/counter_factory.py index 5653552c1cbd9d38971338f59f12c660799f2212..46a96a17a973c00e6e5820f83ad886fb0d3a17bc 100644 --- a/overcooked_simulator/counter_factory.py +++ b/overcooked_simulator/counter_factory.py @@ -1,7 +1,8 @@ import inspect import sys -from typing import Any +from typing import Any, Type, TypeVar +import numpy as np import numpy.typing as npt from overcooked_simulator.counters import ( @@ -13,8 +14,13 @@ from overcooked_simulator.counters import ( PlateDispenser, Sink, PlateConfig, + SinkAddon, ) -from overcooked_simulator.game_items import ItemInfo, ItemType, CookingEquipment +from overcooked_simulator.game_items import ItemInfo, ItemType, CookingEquipment, Plate +from overcooked_simulator.order import OrderAndScoreManager +from overcooked_simulator.utils import get_closest + +T = TypeVar("T") def convert_words_to_chars(layout_chars_config: dict[str, str]) -> dict[str, str]: @@ -72,6 +78,7 @@ class CounterFactory: item_info: dict[str, ItemInfo], serving_window_additional_kwargs: dict[str, Any], plate_config: PlateConfig, + order_and_score: OrderAndScoreManager, ) -> None: """Constructor for the `CounterFactory` class. Set up the attributes necessary to instantiate the counters. @@ -99,6 +106,7 @@ class CounterFactory: self.item_info = item_info self.serving_window_additional_kwargs = serving_window_additional_kwargs self.plate_config = plate_config + self.order_and_score = order_and_score self.no_counter_chars = set( c @@ -162,12 +170,14 @@ class CounterFactory: by_item_type=ItemType.Meal ), "plate_config": self.plate_config, - "dispensing": self.item_info["Plate"], + "dispensing": self.item_info[Plate.__name__], } ) elif issubclass(counter_class, ServingWindow): kwargs.update(self.serving_window_additional_kwargs) - + kwargs[ + "order_and_score" + ] = self.order_and_score # individual because for the later trash scorer return counter_class(**kwargs) def can_map(self, char) -> bool: @@ -209,3 +219,36 @@ class CounterFactory: and info.equipment.name == by_equipment_name } return self.item_info + + def post_counter_setup(self, counters: list[Counter]): + """Initialize the counters in the environment. + + Connect the `ServingWindow`(s) with the `PlateDispenser`. + Find and connect the `SinkAddon`s with the `Sink`s + + Args: + counters: list of counters to perform the post setup on. + """ + plate_dispenser = self.get_counter_of_type(PlateDispenser, counters) + assert len(plate_dispenser) > 0, "No Plate Dispenser in the environment" + + sink_addons = self.get_counter_of_type(SinkAddon, counters) + + for counter in counters: + match counter: + case ServingWindow(): + counter: ServingWindow # Pycharm type checker does now work for match statements? + counter.add_plate_dispenser(plate_dispenser[0]) + case Sink(pos=pos): + counter: Sink # Pycharm type checker does now work for match statements? + assert len(sink_addons) > 0, "No SinkAddon but normal Sink" + closest_addon = get_closest(pos, sink_addons) + assert 1 - (1 * 0.05) <= np.linalg.norm( + closest_addon.pos - pos + ), f"No SinkAddon connected to Sink at pos {pos}" + counter.set_addon(closest_addon) + + @staticmethod + def get_counter_of_type(counter_type: Type[T], counters: list[Counter]) -> list[T]: + """Filter all counters in the environment for a counter type.""" + return list(filter(lambda counter: isinstance(counter, counter_type), counters)) diff --git a/overcooked_simulator/counters.py b/overcooked_simulator/counters.py index fa278bb2fcba792a1d7eadc9c431819c8e51479c..8fd1e658d173e158ec4ed840832d9fa0b70b2bba 100644 --- a/overcooked_simulator/counters.py +++ b/overcooked_simulator/counters.py @@ -240,7 +240,7 @@ class CuttingBoard(Counter): class ServingWindow(Counter): """The orders and scores are updated based on completed and dropped off meals. The plate dispenser is pinged for - the info about a plate outside of the kitchen. + the info about a plate outside the kitchen. All items in the `item_info.yml` with the type meal are considered to be servable, if they are ordered. Not ordered meals can also be served, if a `serving_not_ordered_meals` function is set in the `environment_config.yml`. diff --git a/overcooked_simulator/game_content/environment_config.yaml b/overcooked_simulator/game_content/environment_config.yaml index 3cfabc29d50b32492b265afa3b4e2c33f986132b..d8d2d5f69ab32f03d74b67498b72a0227277687d 100644 --- a/overcooked_simulator/game_content/environment_config.yaml +++ b/overcooked_simulator/game_content/environment_config.yaml @@ -67,7 +67,7 @@ orders: score_calc_gen_func: !!python/name:overcooked_simulator.order.simple_score_calc_gen_func '' score_calc_gen_kwargs: # the kwargs for the score_calc_gen_func - other: 0 + other: 20 scores: Burger: 15 OnionSoup: 10 diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py index dd2698d96a57037a83fddc48b81a90c023535c11..d3c27dc5b59482fbdd2098215b7163e720bfe434 100644 --- a/overcooked_simulator/overcooked_environment.py +++ b/overcooked_simulator/overcooked_environment.py @@ -14,15 +14,10 @@ from typing import Literal import numpy as np import numpy.typing as npt import yaml -from scipy.spatial import distance_matrix from overcooked_simulator.counter_factory import CounterFactory from overcooked_simulator.counters import ( Counter, - ServingWindow, - Sink, - PlateDispenser, - SinkAddon, PlateConfig, ) from overcooked_simulator.game_items import ( @@ -32,7 +27,7 @@ from overcooked_simulator.game_items import ( from overcooked_simulator.order import OrderAndScoreManager from overcooked_simulator.player import Player, PlayerConfig from overcooked_simulator.state_representation import StateRepresentation -from overcooked_simulator.utils import create_init_env_time +from overcooked_simulator.utils import create_init_env_time, get_closest log = logging.getLogger(__name__) @@ -148,7 +143,6 @@ class Environment: layout_chars_config=self.environment_config["layout_chars"], item_info=self.item_info, serving_window_additional_kwargs={ - "order_and_score": self.order_and_score, "meals": self.allowed_meal_names, "env_time_func": self.get_env_time, }, @@ -159,6 +153,7 @@ class Environment: else {} ) ), + order_and_score=self.order_and_score, ) ( @@ -185,8 +180,6 @@ class Environment: ) """Counters that needs to be called in the step function via the `progress` method.""" - self.post_counter_setup() - self.env_time: datetime = create_init_env_time() """the internal time of the environment. An environment starts always with the time from `create_init_env_time`.""" @@ -329,6 +322,8 @@ class Environment: self.kitchen_width -= 0.5 + self.counter_factory.post_counter_setup(counters) + return counters, designated_player_positions, free_positions def perform_action(self, action: Action): @@ -361,21 +356,6 @@ class Environment: if player.last_interacted_counter: player.perform_interact_hold_stop(player.last_interacted_counter) - def get_closest_counter(self, point: np.ndarray): - """Determines the closest counter for a given 2d-coordinate point in the env. - - Args: - point: The point in the env for which to find the closest counter - - Returns: The closest counter for the given point. - """ - counter_distances = distance_matrix( - [point], [counter.pos for counter in self.counters] - )[0] - - closest_counter_idx = np.argmin(counter_distances) - return self.counters[closest_counter_idx] - def get_facing_counter(self, player: Player): """Determines the counter which the player is looking at. Adds a multiple of the player facing direction onto the player position and finds the closest @@ -387,7 +367,7 @@ class Environment: Returns: """ - facing_counter = self.get_closest_counter(player.facing_point) + facing_counter = get_closest(player.facing_point, self.counters) return facing_counter def perform_movement(self, player: Player, duration: timedelta): @@ -637,47 +617,6 @@ class Environment: assert StateRepresentation.model_validate_json(json_data=json_data) return json_data - def post_counter_setup(self): - """Initialize the counters in the environment. - - Connect the `ServingWindow`(s) with the `PlateDispenser`. - Find and connect the `SinkAddon`s with the `Sink`s - """ - plate_dispenser = self.get_counter_of_type(PlateDispenser) - assert len(plate_dispenser) > 0, "No Plate Dispenser in the environment" - - sink_addons = self.get_counter_of_type(SinkAddon) - - for counter in self.counters: - match counter: - case ServingWindow(): - counter: ServingWindow # Pycharm type checker does now work for match statements? - counter.add_plate_dispenser(plate_dispenser[0]) - case Sink(pos=pos): - counter: Sink # Pycharm type checker does now work for match statements? - assert len(sink_addons) > 0, "No SinkAddon but normal Sink" - closest_addon = self.get_closest(pos, sink_addons) - assert 1 - (1 * 0.05) <= np.linalg.norm( - closest_addon.pos - pos - ), f"No SinkAddon connected to Sink at pos {pos}" - counter.set_addon(closest_addon) - - @staticmethod - def get_closest(pos: npt.NDArray[float], counters: list[Counter]): - """Find the closest counter for a position - - Args: - pos: the position to find the closest one from. Needs to be the same shape as the `Counter.pos` array. - counters: target to find the closest one. - """ - return min(counters, key=lambda c: np.linalg.norm(c.pos - pos)) - - def get_counter_of_type(self, counter_type) -> list[Counter]: - """Filter all counters in the environment for a counter type.""" - return list( - filter(lambda counter: isinstance(counter, counter_type), self.counters) - ) - def reset_env_time(self): """Reset the env time to the initial time, defined by `create_init_env_time`.""" self.env_time = create_init_env_time() diff --git a/overcooked_simulator/player.py b/overcooked_simulator/player.py index 0f5f25424210d595d8290c61231d641b16d934a7..4c7257ddb13a1d7e2ca4f28ca8358659ce10643b 100644 --- a/overcooked_simulator/player.py +++ b/overcooked_simulator/player.py @@ -78,7 +78,7 @@ class Player: """A point on the "circle" of the players border in the `facing_direction` with which the closest counter is calculated with.""" - self.current_movement: npt.NDArray[2] = np.zeros(2, float) + self.current_movement: npt.NDArray[float] = np.zeros(2, float) self.movement_until: datetime.datetime = datetime.datetime.min def set_movement(self, move_vector, move_until): diff --git a/overcooked_simulator/utils.py b/overcooked_simulator/utils.py index ecfb5958a982c084ce46db849fa9568483fea3e5..766abc3faeaf41582544c570d970ca416984e265 100644 --- a/overcooked_simulator/utils.py +++ b/overcooked_simulator/utils.py @@ -4,7 +4,12 @@ import sys from datetime import datetime from enum import Enum +import numpy as np +import numpy.typing as npt +from scipy.spatial import distance_matrix + from overcooked_simulator import ROOT_DIR +from overcooked_simulator.counters import Counter def create_init_env_time(): @@ -14,6 +19,21 @@ def create_init_env_time(): ) +def get_closest(point: npt.NDArray[float], counters: list[Counter]): + """Determines the closest counter for a given 2d-coordinate point in the env. + + Args: + point: The point in the env for which to find the closest counter + counters: List of objects with a `pos` attribute to compare to. + + Returns: The closest counter for the given point. + """ + + return counters[ + np.argmin(distance_matrix([point], [counter.pos for counter in counters])[0]) + ] + + def custom_asdict_factory(data): def convert_value(obj): if isinstance(obj, Enum):