From 3606ecfd884ce7d7aee581fe2d6d74073f4301dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20Schr=C3=B6der?= <fschroeder@techfak.uni-bielefeld.de> Date: Fri, 26 Jan 2024 21:19:09 +0100 Subject: [PATCH] Refactor simulator environment and add counter factory This commit simplifies the overcooked environment by refactoring the way counters are created and managed. A new CounterFactory class is introduced, offloading logic from the environment class. In addition, the symbol to character mapping and other environment-related configurations are moved to a separate YAML file. The .gitignore file is also updated to ignore the 'playground' directory. Making these changes enhances code maintainability and readability. --- .gitignore | 2 + overcooked_simulator/counter_factory.py | 173 ++++++++++++++++ overcooked_simulator/counters.py | 14 +- .../game_content/environment_config.yaml | 25 +++ overcooked_simulator/game_server.py | 7 +- .../overcooked_environment.py | 185 ++++-------------- 6 files changed, 251 insertions(+), 155 deletions(-) create mode 100644 overcooked_simulator/counter_factory.py diff --git a/.gitignore b/.gitignore index 3c384f42..eb6bba84 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ # Created by https://www.toptal.com/developers/gitignore/api/python,intellij,visualstudiocode,pycharm,git,flask,django,docusaurus,ros,ros2,linux,macos,windows # Edit at https://www.toptal.com/developers/gitignore?templates=python,intellij,visualstudiocode,pycharm,git,flask,django,docusaurus,ros,ros2,linux,macos,windows +playground + ### Django ### *.log *.pot diff --git a/overcooked_simulator/counter_factory.py b/overcooked_simulator/counter_factory.py new file mode 100644 index 00000000..caa9fb23 --- /dev/null +++ b/overcooked_simulator/counter_factory.py @@ -0,0 +1,173 @@ +import inspect +import sys +from typing import Any + +import numpy.typing as npt + +from overcooked_simulator.counters import ( + Counter, + CookingCounter, + Dispenser, + ServingWindow, + CuttingBoard, + PlateDispenser, + Sink, + PlateConfig, +) +from overcooked_simulator.game_items import ItemInfo, ItemType, CookingEquipment + + +def convert_words_to_chars(layout_chars_config: dict[str, str]) -> dict[str, str]: + word_refs = { + "underscore": "_", + "hash": "#", + "space": " ", + "dot": ".", + "comma": ",", + "semicolon": ";", + "colon": ":", + "plus": "+", + "minus": "-", + "exclamation": "!", + "question": "?", + "dquote": '"', + "squote": "'", + "star": "*", + "dollar": "$", + "euro": "€", + "ampersand": "&", + "slash": "/", + "oparentheses": "(", + "cparentheses": ")", + "equal": "=", + "right": ">", + "left": "<", + "pipe": "|", + "at": "@", + "top": "^", + "tilde": "~", + } + return {word_refs.get(c, c): name for c, name in layout_chars_config.items()} + + +class CounterFactory: + additional_counter_names = {"Counter"} + + def __init__( + self, + layout_chars_config: dict[str, str], + item_info: dict[str, ItemInfo], + serving_window_additional_kwargs: dict[str, Any], + plate_config: PlateConfig, + ) -> None: + self.layout_chars_config = convert_words_to_chars(layout_chars_config) + self.item_info = item_info + self.serving_window_additional_kwargs = serving_window_additional_kwargs + self.plate_config = plate_config + + self.no_counter_chars = set( + c + for c, name in self.layout_chars_config.items() + if name in ["Agent", "Free"] + ) + + self.counter_classes = dict( + inspect.getmembers( + sys.modules["overcooked_simulator.counters"], inspect.isclass + ) + ) + + self.cooking_counter_equipments = { + cooking_counter: [ + equipment + for equipment, e_info in self.item_info.items() + if e_info.equipment and e_info.equipment.name == cooking_counter + ] + for cooking_counter, info in self.item_info.items() + if info.type == ItemType.Equipment and info.equipment is None + } + + def get_counter_object(self, c: str, pos: npt.NDArray[float]) -> Counter: + assert self.can_map(c), f"Can't map counter char {c}" + counter_class = None + if self.layout_chars_config[c] in self.item_info: + item_info = self.item_info[self.layout_chars_config[c]] + if item_info.type == ItemType.Equipment and item_info.equipment: + if item_info.equipment.name in self.counter_classes: + counter_class = self.counter_classes[item_info.equipment.name] + else: + return CookingCounter( + name=item_info.equipment.name, + cooking_counter_equipments=self.cooking_counter_equipments, + pos=pos, + occupied_by=CookingEquipment( + name=item_info.name, + item_info=item_info, + transitions=self.filter_item_info( + by_equipment_name=item_info.name + ), + ), + ) + elif item_info.type == ItemType.Ingredient: + return Dispenser(pos=pos, dispensing=item_info) + + if counter_class is None: + counter_class = self.counter_classes[self.layout_chars_config[c]] + kwargs = { + "pos": pos, + } + if counter_class.__name__ in [CuttingBoard.__name__, Sink.__name__]: + kwargs["transitions"] = self.filter_item_info( + by_equipment_name=counter_class.__name__ + ) + elif counter_class.__name__ == PlateDispenser.__name__: + kwargs.update( + { + "plate_transitions": self.filter_item_info( + by_item_type=ItemType.Meal + ), + "plate_config": self.plate_config, + "dispensing": self.item_info["Plate"], + } + ) + elif counter_class.__name__ == ServingWindow.__name__: + kwargs.update(self.serving_window_additional_kwargs) + + return counter_class(**kwargs) + + def can_map(self, char) -> bool: + return char in self.layout_chars_config and ( + not self.is_counter(char) + or self.layout_chars_config[char] in self.item_info + or self.layout_chars_config[char] in self.counter_classes + ) + + def is_counter(self, c: str) -> bool: + return c in self.layout_chars_config and c not in self.no_counter_chars + + def map_not_counter(self, c: str) -> str: + assert self.can_map(c) and not self.is_counter( + c + ), "Cannot map char {c} as a 'not counter'" + return self.layout_chars_config[c] + + def filter_item_info( + self, + by_item_type: ItemType = None, + by_equipment_name: str = None, + ) -> dict[str, ItemInfo]: + """Filter the item info dict by item type or equipment name""" + if by_item_type is not None: + return { + name: info + for name, info in self.item_info.items() + if info.type == by_item_type + } + if by_equipment_name is not None: + return { + name: info + for name, info in self.item_info.items() + if info.equipment is not None + and info.equipment.name == by_equipment_name + } + return self.item_info diff --git a/overcooked_simulator/counters.py b/overcooked_simulator/counters.py index 75152544..fa278bb2 100644 --- a/overcooked_simulator/counters.py +++ b/overcooked_simulator/counters.py @@ -73,6 +73,7 @@ class Counter: pos: npt.NDArray[float], occupied_by: Optional[Item] = None, uid: hex = None, + **kwargs, ): """Constructor setting the arguments as attributes. @@ -178,13 +179,13 @@ class CuttingBoard(Counter): The character `C` in the `layout` file represents the CuttingBoard. """ - def __init__(self, pos: np.ndarray, transitions: dict[str, ItemInfo]): + def __init__(self, pos: np.ndarray, transitions: dict[str, ItemInfo], **kwargs): self.progressing = False self.transitions = transitions self.inverted_transition_dict = { info.needs[0]: info for name, info in self.transitions.items() } - super().__init__(pos=pos) + super().__init__(pos=pos, **kwargs) def progress(self, passed_time: timedelta, now: datetime): """Called by environment step function for time progression. @@ -256,12 +257,13 @@ class ServingWindow(Counter): meals: set[str], env_time_func: Callable[[], datetime], plate_dispenser: PlateDispenser = None, + **kwargs, ): self.order_and_score = order_and_score self.plate_dispenser = plate_dispenser self.meals = meals self.env_time_func = env_time_func - super().__init__(pos=pos) + super().__init__(pos=pos, **kwargs) def drop_off(self, item) -> Item | None: env_time = self.env_time_func() @@ -303,11 +305,12 @@ class Dispenser(Counter): Which also is easier for the visualization of the dispenser. """ - def __init__(self, pos: npt.NDArray[float], dispensing: ItemInfo): + def __init__(self, pos: npt.NDArray[float], dispensing: ItemInfo, **kwargs): self.dispensing = dispensing super().__init__( pos=pos, occupied_by=self.create_item(), + **kwargs, ) def pick_up(self, on_hands: bool = True) -> Item | None: @@ -548,8 +551,9 @@ class Sink(Counter): pos: npt.NDArray[float], transitions: dict[str, ItemInfo], sink_addon: SinkAddon = None, + **kwargs, ): - super().__init__(pos=pos) + super().__init__(pos=pos, **kwargs) self.progressing = False self.sink_addon: SinkAddon = sink_addon """The connected sink addon which will receive the clean plates""" diff --git a/overcooked_simulator/game_content/environment_config.yaml b/overcooked_simulator/game_content/environment_config.yaml index cad2888d..3cfabc29 100644 --- a/overcooked_simulator/game_content/environment_config.yaml +++ b/overcooked_simulator/game_content/environment_config.yaml @@ -16,6 +16,31 @@ meals: - OnionSoup - Salad +layout_chars: + underscore: Free + hash: Counter + A: Agent + P: PlateDispenser + C: CuttingBoard + X: Trashcan + W: ServingWindow + S: Sink + plus: SinkAddon + U: Pot # with Stove + Q: Pan # with Stove + O: Peel # with Oven + F: Basket # with DeepFryer + T: Tomato + N: Onion # oNioN + L: Lettuce + K: Potato # Kartoffel + I: Fish # fIIIsh + D: Dough + E: Cheese # chEEEse + G: Sausage # sausaGe + B: Bun + M: Meat + orders: order_gen_class: !!python/name:overcooked_simulator.order.RandomOrderGeneration '' # the class to that receives the kwargs. Should be a child class of OrderGeneration in order.py diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py index 35db3885..65862c17 100644 --- a/overcooked_simulator/game_server.py +++ b/overcooked_simulator/game_server.py @@ -35,7 +35,11 @@ from overcooked_simulator.server_results import ( PlayerInfo, PlayerRequestResult, ) -from overcooked_simulator.utils import setup_logging, url_and_port_arguments +from overcooked_simulator.utils import ( + setup_logging, + url_and_port_arguments, + disable_websocket_logging_arguments, +) log = logging.getLogger(__name__) @@ -728,6 +732,7 @@ if __name__ == "__main__": ) url_and_port_arguments(parser) + disable_websocket_logging_arguments(parser) args = parser.parse_args() setup_logging(args.enable_websocket_logging) main(args.url, args.port) diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py index 90694f30..c023fa6e 100644 --- a/overcooked_simulator/overcooked_environment.py +++ b/overcooked_simulator/overcooked_environment.py @@ -14,11 +14,10 @@ import numpy.typing as npt import yaml from scipy.spatial import distance_matrix +from overcooked_simulator.counter_factory import CounterFactory from overcooked_simulator.counters import ( Counter, CuttingBoard, - Trashcan, - Dispenser, ServingWindow, CookingCounter, Sink, @@ -29,7 +28,6 @@ from overcooked_simulator.counters import ( from overcooked_simulator.game_items import ( ItemInfo, ItemType, - CookingEquipment, ) from overcooked_simulator.order import OrderAndScoreManager from overcooked_simulator.player import Player, PlayerConfig @@ -114,7 +112,7 @@ class Environment: self.layout_config = layout_config # self.counter_side_length = 1 # -> this changed! is 1 now - self.item_info = self.load_item_info(item_info) + self.item_info: dict[str, ItemInfo] = self.load_item_info(item_info) """The loaded item info dict. Keys are the item names.""" # self.validate_item_info() if self.environment_config["meals"]["all"]: @@ -139,128 +137,35 @@ class Environment: ) """The manager for the orders and score update.""" - cooking_counter_equipments = { - cooking_counter: [ - equipment - for equipment, e_info in self.item_info.items() - if e_info.equipment and e_info.equipment.name == cooking_counter - ] - for cooking_counter, info in self.item_info.items() - if info.type == ItemType.Equipment and info.equipment is None - } - - self.SYMBOL_TO_CHARACTER_MAP = { - "#": Counter, - "C": lambda pos: CuttingBoard( - pos=pos, - transitions=self.filter_item_info( - self.item_info, by_equipment_name="CuttingBoard" - ), - ), - "X": Trashcan, - "W": lambda pos: ServingWindow( - pos, - self.order_and_score, - meals=self.allowed_meal_names, - env_time_func=self.get_env_time, - ), - "T": lambda pos: Dispenser(pos, self.item_info["Tomato"]), - "L": lambda pos: Dispenser(pos, self.item_info["Lettuce"]), - "K": lambda pos: Dispenser(pos, self.item_info["Potato"]), # Kartoffel - "I": lambda pos: Dispenser(pos, self.item_info["Fish"]), # fIIIsh - "D": lambda pos: Dispenser(pos, self.item_info["Dough"]), - "E": lambda pos: Dispenser(pos, self.item_info["Cheese"]), # chEEEEse - "G": lambda pos: Dispenser(pos, self.item_info["Sausage"]), # sausaGe - "P": lambda pos: PlateDispenser( - plate_transitions=self.filter_item_info( - item_info=self.item_info, by_item_type=ItemType.Meal - ), - pos=pos, - dispensing=self.item_info["Plate"], - plate_config=PlateConfig( - **( - self.environment_config["plates"] - if "plates" in self.environment_config - else {} - ) - ), - ), - "N": lambda pos: Dispenser(pos, self.item_info["Onion"]), # N for oNioN - "_": "Free", - "A": "Agent", - "U": lambda pos: CookingCounter( - name="Stove", - cooking_counter_equipments=cooking_counter_equipments, - pos=pos, - occupied_by=CookingEquipment( - name="Pot", - item_info=self.item_info["Pot"], - transitions=self.filter_item_info( - self.item_info, by_equipment_name="Pot" - ), - ), - ), # Stove with pot: U because it looks like a pot - "Q": lambda pos: CookingCounter( - name="Stove", - cooking_counter_equipments=cooking_counter_equipments, - pos=pos, - occupied_by=CookingEquipment( - name="Pan", - item_info=self.item_info["Pan"], - transitions=self.filter_item_info( - self.item_info, by_equipment_name="Pan" - ), - ), - ), # Stove with pan: Q because it looks like a pan - "O": lambda pos: CookingCounter( - name="Oven", - cooking_counter_equipments=cooking_counter_equipments, - pos=pos, - occupied_by=CookingEquipment( - name="Peel", - item_info=self.item_info["Peel"], - transitions=self.filter_item_info( - self.item_info, by_equipment_name="Peel" - ), - ), - ), - "F": lambda pos: CookingCounter( - name="DeepFryer", - cooking_counter_equipments=cooking_counter_equipments, - pos=pos, - occupied_by=CookingEquipment( - name="Basket", - item_info=self.item_info["Basket"], - transitions=self.filter_item_info( - self.item_info, by_equipment_name="Basket" - ), - ), - ), # Stove with pan: Q because it looks like a pan - "B": lambda pos: Dispenser(pos, self.item_info["Bun"]), - "M": lambda pos: Dispenser(pos, self.item_info["Meat"]), - "S": lambda pos: Sink( - pos, - transitions=self.filter_item_info( - item_info=self.item_info, by_equipment_name="Sink" - ), - ), - "+": SinkAddon, - } - """Map of the characters in the layout file to callables returning the object/counter. In the future, - maybe replaced with a factory and the characters defined elsewhere in an config.""" - self.kitchen_height: int = 0 """The height of the kitchen, is set by the `Environment.parse_layout_file` method""" self.kitchen_width: int = 0 """The width of the kitchen, is set by the `Environment.parse_layout_file` method""" + self.counter_factory = CounterFactory( + layout_chars_config=self.environment_config["layout_chars"], + item_info=self.item_info, + serving_window_additional_kwargs={ + "order_and_score": self.order_and_score, + "meals": self.allowed_meal_names, + "env_time_func": self.get_env_time, + }, + plate_config=PlateConfig( + **( + self.environment_config["plates"] + if "plates" in self.environment_config + else {} + ) + ), + ) + ( self.counters, self.designated_player_positions, self.free_positions, ) = self.parse_layout_file() - self.init_counters() + self.post_counter_setup() self.env_time: datetime = create_init_env_time() """the internal time of the environment. An environment starts always with the time from @@ -379,6 +284,7 @@ class Environment: else: lines = self.layout_config.split("\n") self.kitchen_height = len(lines) + print(self.kitchen_height) for line in lines: line = line.replace("\n", "").replace(" ", "") # remove newline char @@ -386,17 +292,20 @@ class Environment: for character in line: character = character.capitalize() pos = np.array([current_x, current_y]) - counter_class = self.SYMBOL_TO_CHARACTER_MAP[character] - if not isinstance(counter_class, str): - counter = counter_class(pos) - counters.append(counter) + assert self.counter_factory.can_map( + character + ), f"{character=} in layout file can not be mapped" + if self.counter_factory.is_counter(character): + counters.append( + self.counter_factory.get_counter_object(character, pos) + ) else: - if counter_class == "Agent": - designated_player_positions.append( - np.array([current_x, current_y]) - ) - elif counter_class == "Free": - free_positions.append(np.array([current_x, current_y])) + match self.counter_factory.map_not_counter(character): + case "Agent": + designated_player_positions.append(pos) + case "Free": + free_positions.append(np.array([current_x, current_y])) + current_x += 1 if current_x > self.kitchen_width: self.kitchen_width = current_x @@ -715,14 +624,14 @@ class Environment: assert StateRepresentation.model_validate_json(json_data=json_data) return json_data - def init_counters(self): + def post_counter_setup(self): """Initialize the counters in the environment. Connect the `ServingWindow`(s) with the `PlateDispenser`. Find and connect the `SinkAddon`s with the `Sink`s """ plate_dispenser = self.get_counter_of_type(PlateDispenser) - assert len(plate_dispenser) > 0, "No Plate Return in the environment" + assert len(plate_dispenser) > 0, "No Plate Dispenser in the environment" sink_addons = self.get_counter_of_type(SinkAddon) @@ -760,25 +669,3 @@ class Environment: """Reset the env time to the initial time, defined by `create_init_env_time`.""" self.env_time = create_init_env_time() log.debug(f"Reset env time to {self.env_time}") - - @staticmethod - def filter_item_info( - item_info: dict[str, ItemInfo], - by_item_type: ItemType = None, - by_equipment_name: str = None, - ) -> dict[str, ItemInfo]: - """Filter the item info dict by item type or equipment name""" - if by_item_type is not None: - return { - name: info - for name, info in item_info.items() - if info.type == by_item_type - } - if by_equipment_name is not None: - return { - name: info - for name, info in item_info.items() - if info.equipment is not None - and info.equipment.name == by_equipment_name - } - return item_info -- GitLab