From 3606ecfd884ce7d7aee581fe2d6d74073f4301dc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florian=20Schr=C3=B6der?=
 <fschroeder@techfak.uni-bielefeld.de>
Date: Fri, 26 Jan 2024 21:19:09 +0100
Subject: [PATCH] Refactor simulator environment and add counter factory

This commit simplifies the overcooked environment by refactoring the way counters are created and managed. A new CounterFactory class is introduced, offloading logic from the environment class. In addition, the symbol to character mapping and other environment-related configurations are moved to a separate YAML file. The .gitignore file is also updated to ignore the 'playground' directory. Making these changes enhances code maintainability and readability.
---
 .gitignore                                    |   2 +
 overcooked_simulator/counter_factory.py       | 173 ++++++++++++++++
 overcooked_simulator/counters.py              |  14 +-
 .../game_content/environment_config.yaml      |  25 +++
 overcooked_simulator/game_server.py           |   7 +-
 .../overcooked_environment.py                 | 185 ++++--------------
 6 files changed, 251 insertions(+), 155 deletions(-)
 create mode 100644 overcooked_simulator/counter_factory.py

diff --git a/.gitignore b/.gitignore
index 3c384f42..eb6bba84 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,8 @@
 # Created by https://www.toptal.com/developers/gitignore/api/python,intellij,visualstudiocode,pycharm,git,flask,django,docusaurus,ros,ros2,linux,macos,windows
 # Edit at https://www.toptal.com/developers/gitignore?templates=python,intellij,visualstudiocode,pycharm,git,flask,django,docusaurus,ros,ros2,linux,macos,windows
 
+playground
+
 ### Django ###
 *.log
 *.pot
diff --git a/overcooked_simulator/counter_factory.py b/overcooked_simulator/counter_factory.py
new file mode 100644
index 00000000..caa9fb23
--- /dev/null
+++ b/overcooked_simulator/counter_factory.py
@@ -0,0 +1,173 @@
+import inspect
+import sys
+from typing import Any
+
+import numpy.typing as npt
+
+from overcooked_simulator.counters import (
+    Counter,
+    CookingCounter,
+    Dispenser,
+    ServingWindow,
+    CuttingBoard,
+    PlateDispenser,
+    Sink,
+    PlateConfig,
+)
+from overcooked_simulator.game_items import ItemInfo, ItemType, CookingEquipment
+
+
+def convert_words_to_chars(layout_chars_config: dict[str, str]) -> dict[str, str]:
+    word_refs = {
+        "underscore": "_",
+        "hash": "#",
+        "space": " ",
+        "dot": ".",
+        "comma": ",",
+        "semicolon": ";",
+        "colon": ":",
+        "plus": "+",
+        "minus": "-",
+        "exclamation": "!",
+        "question": "?",
+        "dquote": '"',
+        "squote": "'",
+        "star": "*",
+        "dollar": "$",
+        "euro": "€",
+        "ampersand": "&",
+        "slash": "/",
+        "oparentheses": "(",
+        "cparentheses": ")",
+        "equal": "=",
+        "right": ">",
+        "left": "<",
+        "pipe": "|",
+        "at": "@",
+        "top": "^",
+        "tilde": "~",
+    }
+    return {word_refs.get(c, c): name for c, name in layout_chars_config.items()}
+
+
+class CounterFactory:
+    additional_counter_names = {"Counter"}
+
+    def __init__(
+        self,
+        layout_chars_config: dict[str, str],
+        item_info: dict[str, ItemInfo],
+        serving_window_additional_kwargs: dict[str, Any],
+        plate_config: PlateConfig,
+    ) -> None:
+        self.layout_chars_config = convert_words_to_chars(layout_chars_config)
+        self.item_info = item_info
+        self.serving_window_additional_kwargs = serving_window_additional_kwargs
+        self.plate_config = plate_config
+
+        self.no_counter_chars = set(
+            c
+            for c, name in self.layout_chars_config.items()
+            if name in ["Agent", "Free"]
+        )
+
+        self.counter_classes = dict(
+            inspect.getmembers(
+                sys.modules["overcooked_simulator.counters"], inspect.isclass
+            )
+        )
+
+        self.cooking_counter_equipments = {
+            cooking_counter: [
+                equipment
+                for equipment, e_info in self.item_info.items()
+                if e_info.equipment and e_info.equipment.name == cooking_counter
+            ]
+            for cooking_counter, info in self.item_info.items()
+            if info.type == ItemType.Equipment and info.equipment is None
+        }
+
+    def get_counter_object(self, c: str, pos: npt.NDArray[float]) -> Counter:
+        assert self.can_map(c), f"Can't map counter char {c}"
+        counter_class = None
+        if self.layout_chars_config[c] in self.item_info:
+            item_info = self.item_info[self.layout_chars_config[c]]
+            if item_info.type == ItemType.Equipment and item_info.equipment:
+                if item_info.equipment.name in self.counter_classes:
+                    counter_class = self.counter_classes[item_info.equipment.name]
+                else:
+                    return CookingCounter(
+                        name=item_info.equipment.name,
+                        cooking_counter_equipments=self.cooking_counter_equipments,
+                        pos=pos,
+                        occupied_by=CookingEquipment(
+                            name=item_info.name,
+                            item_info=item_info,
+                            transitions=self.filter_item_info(
+                                by_equipment_name=item_info.name
+                            ),
+                        ),
+                    )
+            elif item_info.type == ItemType.Ingredient:
+                return Dispenser(pos=pos, dispensing=item_info)
+
+        if counter_class is None:
+            counter_class = self.counter_classes[self.layout_chars_config[c]]
+        kwargs = {
+            "pos": pos,
+        }
+        if counter_class.__name__ in [CuttingBoard.__name__, Sink.__name__]:
+            kwargs["transitions"] = self.filter_item_info(
+                by_equipment_name=counter_class.__name__
+            )
+        elif counter_class.__name__ == PlateDispenser.__name__:
+            kwargs.update(
+                {
+                    "plate_transitions": self.filter_item_info(
+                        by_item_type=ItemType.Meal
+                    ),
+                    "plate_config": self.plate_config,
+                    "dispensing": self.item_info["Plate"],
+                }
+            )
+        elif counter_class.__name__ == ServingWindow.__name__:
+            kwargs.update(self.serving_window_additional_kwargs)
+
+        return counter_class(**kwargs)
+
+    def can_map(self, char) -> bool:
+        return char in self.layout_chars_config and (
+            not self.is_counter(char)
+            or self.layout_chars_config[char] in self.item_info
+            or self.layout_chars_config[char] in self.counter_classes
+        )
+
+    def is_counter(self, c: str) -> bool:
+        return c in self.layout_chars_config and c not in self.no_counter_chars
+
+    def map_not_counter(self, c: str) -> str:
+        assert self.can_map(c) and not self.is_counter(
+            c
+        ), "Cannot map char {c} as a 'not counter'"
+        return self.layout_chars_config[c]
+
+    def filter_item_info(
+        self,
+        by_item_type: ItemType = None,
+        by_equipment_name: str = None,
+    ) -> dict[str, ItemInfo]:
+        """Filter the item info dict by item type or equipment name"""
+        if by_item_type is not None:
+            return {
+                name: info
+                for name, info in self.item_info.items()
+                if info.type == by_item_type
+            }
+        if by_equipment_name is not None:
+            return {
+                name: info
+                for name, info in self.item_info.items()
+                if info.equipment is not None
+                and info.equipment.name == by_equipment_name
+            }
+        return self.item_info
diff --git a/overcooked_simulator/counters.py b/overcooked_simulator/counters.py
index 75152544..fa278bb2 100644
--- a/overcooked_simulator/counters.py
+++ b/overcooked_simulator/counters.py
@@ -73,6 +73,7 @@ class Counter:
         pos: npt.NDArray[float],
         occupied_by: Optional[Item] = None,
         uid: hex = None,
+        **kwargs,
     ):
         """Constructor setting the arguments as attributes.
 
@@ -178,13 +179,13 @@ class CuttingBoard(Counter):
     The character `C` in the `layout` file represents the CuttingBoard.
     """
 
-    def __init__(self, pos: np.ndarray, transitions: dict[str, ItemInfo]):
+    def __init__(self, pos: np.ndarray, transitions: dict[str, ItemInfo], **kwargs):
         self.progressing = False
         self.transitions = transitions
         self.inverted_transition_dict = {
             info.needs[0]: info for name, info in self.transitions.items()
         }
-        super().__init__(pos=pos)
+        super().__init__(pos=pos, **kwargs)
 
     def progress(self, passed_time: timedelta, now: datetime):
         """Called by environment step function for time progression.
@@ -256,12 +257,13 @@ class ServingWindow(Counter):
         meals: set[str],
         env_time_func: Callable[[], datetime],
         plate_dispenser: PlateDispenser = None,
+        **kwargs,
     ):
         self.order_and_score = order_and_score
         self.plate_dispenser = plate_dispenser
         self.meals = meals
         self.env_time_func = env_time_func
-        super().__init__(pos=pos)
+        super().__init__(pos=pos, **kwargs)
 
     def drop_off(self, item) -> Item | None:
         env_time = self.env_time_func()
@@ -303,11 +305,12 @@ class Dispenser(Counter):
     Which also is easier for the visualization of the dispenser.
     """
 
-    def __init__(self, pos: npt.NDArray[float], dispensing: ItemInfo):
+    def __init__(self, pos: npt.NDArray[float], dispensing: ItemInfo, **kwargs):
         self.dispensing = dispensing
         super().__init__(
             pos=pos,
             occupied_by=self.create_item(),
+            **kwargs,
         )
 
     def pick_up(self, on_hands: bool = True) -> Item | None:
@@ -548,8 +551,9 @@ class Sink(Counter):
         pos: npt.NDArray[float],
         transitions: dict[str, ItemInfo],
         sink_addon: SinkAddon = None,
+        **kwargs,
     ):
-        super().__init__(pos=pos)
+        super().__init__(pos=pos, **kwargs)
         self.progressing = False
         self.sink_addon: SinkAddon = sink_addon
         """The connected sink addon which will receive the clean plates"""
diff --git a/overcooked_simulator/game_content/environment_config.yaml b/overcooked_simulator/game_content/environment_config.yaml
index cad2888d..3cfabc29 100644
--- a/overcooked_simulator/game_content/environment_config.yaml
+++ b/overcooked_simulator/game_content/environment_config.yaml
@@ -16,6 +16,31 @@ meals:
     - OnionSoup
     - Salad
 
+layout_chars:
+  underscore: Free
+  hash: Counter
+  A: Agent
+  P: PlateDispenser
+  C: CuttingBoard
+  X: Trashcan
+  W: ServingWindow
+  S: Sink
+  plus: SinkAddon
+  U: Pot  # with Stove
+  Q: Pan  # with Stove
+  O: Peel  # with Oven
+  F: Basket  # with DeepFryer
+  T: Tomato
+  N: Onion  # oNioN
+  L: Lettuce
+  K: Potato  # Kartoffel
+  I: Fish  # fIIIsh
+  D: Dough
+  E: Cheese  # chEEEse
+  G: Sausage  # sausaGe
+  B: Bun
+  M: Meat
+
 orders:
   order_gen_class: !!python/name:overcooked_simulator.order.RandomOrderGeneration ''
   # the class to that receives the kwargs. Should be a child class of OrderGeneration in order.py
diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py
index 35db3885..65862c17 100644
--- a/overcooked_simulator/game_server.py
+++ b/overcooked_simulator/game_server.py
@@ -35,7 +35,11 @@ from overcooked_simulator.server_results import (
     PlayerInfo,
     PlayerRequestResult,
 )
-from overcooked_simulator.utils import setup_logging, url_and_port_arguments
+from overcooked_simulator.utils import (
+    setup_logging,
+    url_and_port_arguments,
+    disable_websocket_logging_arguments,
+)
 
 log = logging.getLogger(__name__)
 
@@ -728,6 +732,7 @@ if __name__ == "__main__":
     )
 
     url_and_port_arguments(parser)
+    disable_websocket_logging_arguments(parser)
     args = parser.parse_args()
     setup_logging(args.enable_websocket_logging)
     main(args.url, args.port)
diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py
index 90694f30..c023fa6e 100644
--- a/overcooked_simulator/overcooked_environment.py
+++ b/overcooked_simulator/overcooked_environment.py
@@ -14,11 +14,10 @@ import numpy.typing as npt
 import yaml
 from scipy.spatial import distance_matrix
 
+from overcooked_simulator.counter_factory import CounterFactory
 from overcooked_simulator.counters import (
     Counter,
     CuttingBoard,
-    Trashcan,
-    Dispenser,
     ServingWindow,
     CookingCounter,
     Sink,
@@ -29,7 +28,6 @@ from overcooked_simulator.counters import (
 from overcooked_simulator.game_items import (
     ItemInfo,
     ItemType,
-    CookingEquipment,
 )
 from overcooked_simulator.order import OrderAndScoreManager
 from overcooked_simulator.player import Player, PlayerConfig
@@ -114,7 +112,7 @@ class Environment:
         self.layout_config = layout_config
         # self.counter_side_length = 1  # -> this changed! is 1 now
 
-        self.item_info = self.load_item_info(item_info)
+        self.item_info: dict[str, ItemInfo] = self.load_item_info(item_info)
         """The loaded item info dict. Keys are the item names."""
         # self.validate_item_info()
         if self.environment_config["meals"]["all"]:
@@ -139,128 +137,35 @@ class Environment:
         )
         """The manager for the orders and score update."""
 
-        cooking_counter_equipments = {
-            cooking_counter: [
-                equipment
-                for equipment, e_info in self.item_info.items()
-                if e_info.equipment and e_info.equipment.name == cooking_counter
-            ]
-            for cooking_counter, info in self.item_info.items()
-            if info.type == ItemType.Equipment and info.equipment is None
-        }
-
-        self.SYMBOL_TO_CHARACTER_MAP = {
-            "#": Counter,
-            "C": lambda pos: CuttingBoard(
-                pos=pos,
-                transitions=self.filter_item_info(
-                    self.item_info, by_equipment_name="CuttingBoard"
-                ),
-            ),
-            "X": Trashcan,
-            "W": lambda pos: ServingWindow(
-                pos,
-                self.order_and_score,
-                meals=self.allowed_meal_names,
-                env_time_func=self.get_env_time,
-            ),
-            "T": lambda pos: Dispenser(pos, self.item_info["Tomato"]),
-            "L": lambda pos: Dispenser(pos, self.item_info["Lettuce"]),
-            "K": lambda pos: Dispenser(pos, self.item_info["Potato"]),  # Kartoffel
-            "I": lambda pos: Dispenser(pos, self.item_info["Fish"]),  # fIIIsh
-            "D": lambda pos: Dispenser(pos, self.item_info["Dough"]),
-            "E": lambda pos: Dispenser(pos, self.item_info["Cheese"]),  # chEEEEse
-            "G": lambda pos: Dispenser(pos, self.item_info["Sausage"]),  # sausaGe
-            "P": lambda pos: PlateDispenser(
-                plate_transitions=self.filter_item_info(
-                    item_info=self.item_info, by_item_type=ItemType.Meal
-                ),
-                pos=pos,
-                dispensing=self.item_info["Plate"],
-                plate_config=PlateConfig(
-                    **(
-                        self.environment_config["plates"]
-                        if "plates" in self.environment_config
-                        else {}
-                    )
-                ),
-            ),
-            "N": lambda pos: Dispenser(pos, self.item_info["Onion"]),  # N for oNioN
-            "_": "Free",
-            "A": "Agent",
-            "U": lambda pos: CookingCounter(
-                name="Stove",
-                cooking_counter_equipments=cooking_counter_equipments,
-                pos=pos,
-                occupied_by=CookingEquipment(
-                    name="Pot",
-                    item_info=self.item_info["Pot"],
-                    transitions=self.filter_item_info(
-                        self.item_info, by_equipment_name="Pot"
-                    ),
-                ),
-            ),  # Stove with pot: U because it looks like a pot
-            "Q": lambda pos: CookingCounter(
-                name="Stove",
-                cooking_counter_equipments=cooking_counter_equipments,
-                pos=pos,
-                occupied_by=CookingEquipment(
-                    name="Pan",
-                    item_info=self.item_info["Pan"],
-                    transitions=self.filter_item_info(
-                        self.item_info, by_equipment_name="Pan"
-                    ),
-                ),
-            ),  # Stove with pan: Q because it looks like a pan
-            "O": lambda pos: CookingCounter(
-                name="Oven",
-                cooking_counter_equipments=cooking_counter_equipments,
-                pos=pos,
-                occupied_by=CookingEquipment(
-                    name="Peel",
-                    item_info=self.item_info["Peel"],
-                    transitions=self.filter_item_info(
-                        self.item_info, by_equipment_name="Peel"
-                    ),
-                ),
-            ),
-            "F": lambda pos: CookingCounter(
-                name="DeepFryer",
-                cooking_counter_equipments=cooking_counter_equipments,
-                pos=pos,
-                occupied_by=CookingEquipment(
-                    name="Basket",
-                    item_info=self.item_info["Basket"],
-                    transitions=self.filter_item_info(
-                        self.item_info, by_equipment_name="Basket"
-                    ),
-                ),
-            ),  # Stove with pan: Q because it looks like a pan
-            "B": lambda pos: Dispenser(pos, self.item_info["Bun"]),
-            "M": lambda pos: Dispenser(pos, self.item_info["Meat"]),
-            "S": lambda pos: Sink(
-                pos,
-                transitions=self.filter_item_info(
-                    item_info=self.item_info, by_equipment_name="Sink"
-                ),
-            ),
-            "+": SinkAddon,
-        }
-        """Map of the characters in the layout file to callables returning the object/counter. In the future, 
-        maybe replaced with a factory and the characters defined elsewhere in an config."""
-
         self.kitchen_height: int = 0
         """The height of the kitchen, is set by the `Environment.parse_layout_file` method"""
         self.kitchen_width: int = 0
         """The width of the kitchen, is set by the `Environment.parse_layout_file` method"""
 
+        self.counter_factory = CounterFactory(
+            layout_chars_config=self.environment_config["layout_chars"],
+            item_info=self.item_info,
+            serving_window_additional_kwargs={
+                "order_and_score": self.order_and_score,
+                "meals": self.allowed_meal_names,
+                "env_time_func": self.get_env_time,
+            },
+            plate_config=PlateConfig(
+                **(
+                    self.environment_config["plates"]
+                    if "plates" in self.environment_config
+                    else {}
+                )
+            ),
+        )
+
         (
             self.counters,
             self.designated_player_positions,
             self.free_positions,
         ) = self.parse_layout_file()
 
-        self.init_counters()
+        self.post_counter_setup()
 
         self.env_time: datetime = create_init_env_time()
         """the internal time of the environment. An environment starts always with the time from 
@@ -379,6 +284,7 @@ class Environment:
         else:
             lines = self.layout_config.split("\n")
         self.kitchen_height = len(lines)
+        print(self.kitchen_height)
 
         for line in lines:
             line = line.replace("\n", "").replace(" ", "")  # remove newline char
@@ -386,17 +292,20 @@ class Environment:
             for character in line:
                 character = character.capitalize()
                 pos = np.array([current_x, current_y])
-                counter_class = self.SYMBOL_TO_CHARACTER_MAP[character]
-                if not isinstance(counter_class, str):
-                    counter = counter_class(pos)
-                    counters.append(counter)
+                assert self.counter_factory.can_map(
+                    character
+                ), f"{character=} in layout file can not be mapped"
+                if self.counter_factory.is_counter(character):
+                    counters.append(
+                        self.counter_factory.get_counter_object(character, pos)
+                    )
                 else:
-                    if counter_class == "Agent":
-                        designated_player_positions.append(
-                            np.array([current_x, current_y])
-                        )
-                    elif counter_class == "Free":
-                        free_positions.append(np.array([current_x, current_y]))
+                    match self.counter_factory.map_not_counter(character):
+                        case "Agent":
+                            designated_player_positions.append(pos)
+                        case "Free":
+                            free_positions.append(np.array([current_x, current_y]))
+
                 current_x += 1
                 if current_x > self.kitchen_width:
                     self.kitchen_width = current_x
@@ -715,14 +624,14 @@ class Environment:
         assert StateRepresentation.model_validate_json(json_data=json_data)
         return json_data
 
-    def init_counters(self):
+    def post_counter_setup(self):
         """Initialize the counters in the environment.
 
         Connect the `ServingWindow`(s) with the `PlateDispenser`.
         Find and connect the `SinkAddon`s with the `Sink`s
         """
         plate_dispenser = self.get_counter_of_type(PlateDispenser)
-        assert len(plate_dispenser) > 0, "No Plate Return in the environment"
+        assert len(plate_dispenser) > 0, "No Plate Dispenser in the environment"
 
         sink_addons = self.get_counter_of_type(SinkAddon)
 
@@ -760,25 +669,3 @@ class Environment:
         """Reset the env time to the initial time, defined by `create_init_env_time`."""
         self.env_time = create_init_env_time()
         log.debug(f"Reset env time to {self.env_time}")
-
-    @staticmethod
-    def filter_item_info(
-        item_info: dict[str, ItemInfo],
-        by_item_type: ItemType = None,
-        by_equipment_name: str = None,
-    ) -> dict[str, ItemInfo]:
-        """Filter the item info dict by item type or equipment name"""
-        if by_item_type is not None:
-            return {
-                name: info
-                for name, info in item_info.items()
-                if info.type == by_item_type
-            }
-        if by_equipment_name is not None:
-            return {
-                name: info
-                for name, info in item_info.items()
-                if info.equipment is not None
-                and info.equipment.name == by_equipment_name
-            }
-        return item_info
-- 
GitLab