From 4636abf568a3e6bc7579ef567ad6256a659f5edf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florian=20Schr=C3=B6der?=
 <fschroeder@techfak.uni-bielefeld.de>
Date: Sun, 4 Feb 2024 00:35:35 +0100
Subject: [PATCH] Implement additional logging and recording features

This update introduces comprehensive hook management, improved environment setup functions, and a new class for logging and recording simulations. The enhanced hook management ensures better callback registration and tracing of specific simulation events. Extra setup functions allow more flexibility in setting parameters from the environment config. The new logging and recording class makes it easier to capture simulation data in a structured format for analysis. Various minor fixes and code refactoring are also included.
---
 overcooked_simulator/__main__.py              |  2 +-
 overcooked_simulator/counters.py              |  2 +
 .../game_content/environment_config.yaml      | 33 ++++++++++
 overcooked_simulator/game_server.py           |  9 ++-
 .../gui_2d_vis/overcooked_gui.py              | 11 ++--
 overcooked_simulator/hooks.py                 |  6 ++
 overcooked_simulator/order.py                 | 17 +++--
 .../overcooked_environment.py                 | 57 ++++++++++------
 overcooked_simulator/recording.py             | 66 +++++++++++++++++++
 overcooked_simulator/utils.py                 | 27 +++++++-
 setup.py                                      |  1 +
 11 files changed, 197 insertions(+), 34 deletions(-)
 create mode 100644 overcooked_simulator/recording.py

diff --git a/overcooked_simulator/__main__.py b/overcooked_simulator/__main__.py
index 436201f2..f48180ac 100644
--- a/overcooked_simulator/__main__.py
+++ b/overcooked_simulator/__main__.py
@@ -50,7 +50,7 @@ def main(cli_args=None):
         print("Received Keyboard interrupt")
     finally:
         if game_server is not None and game_server.is_alive():
-            print("Terminate gparserame server")
+            print("Terminate game server")
             game_server.terminate()
         if pygame_gui is not None and pygame_gui.is_alive():
             print("Terminate pygame gui")
diff --git a/overcooked_simulator/counters.py b/overcooked_simulator/counters.py
index d4b72eec..4b417a2f 100644
--- a/overcooked_simulator/counters.py
+++ b/overcooked_simulator/counters.py
@@ -63,6 +63,7 @@ from overcooked_simulator.hooks import (
     ADDED_PLATE_TO_SINK,
     DROP_ON_SINK_ADDON,
     PICK_UP_FROM_SINK_ADDON,
+    PLATE_OUT_OF_KITCHEN_TIME,
 )
 
 if TYPE_CHECKING:
@@ -506,6 +507,7 @@ class PlateDispenser(Counter):
         self.out_of_kitchen_timer.append(time_plate_to_add)
         if time_plate_to_add < self.next_plate_time:
             self.next_plate_time = time_plate_to_add
+        self.hook(PLATE_OUT_OF_KITCHEN_TIME, time_plate_to_add=time_plate_to_add)
 
     def setup_plates(self):
         """Create plates based on the config. Clean and dirty ones."""
diff --git a/overcooked_simulator/game_content/environment_config.yaml b/overcooked_simulator/game_content/environment_config.yaml
index c0de86d2..671900ea 100644
--- a/overcooked_simulator/game_content/environment_config.yaml
+++ b/overcooked_simulator/game_content/environment_config.yaml
@@ -85,3 +85,36 @@ player_config:
   radius: 0.4
   player_speed_units_per_seconds: 8
   interaction_range: 1.6
+
+extra_setup_functions:
+  #  json_states:
+  #    func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks ''
+  #    kwargs:
+  #      hooks: [ json_state ]
+  #      log_class: !!python/name:overcooked_simulator.recording.LogRecorder ''
+  #      log_class_kwargs:
+  #        log_path: USER_LOG_DIR/ENV_NAME/json_states.jsonl
+  actions:
+    func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks ''
+    kwargs:
+      hooks: [ pre_perform_action ]
+      log_class: !!python/name:overcooked_simulator.recording.LogRecorder ''
+      log_class_kwargs:
+        log_path: USER_LOG_DIR/ENV_NAME/actions.jsonl
+  random_env_events:
+    func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks ''
+    kwargs:
+      hooks: [ order_duration_sample, plate_out_of_kitchen_time ]
+      log_class: !!python/name:overcooked_simulator.recording.LogRecorder ''
+      log_class_kwargs:
+        log_path: USER_LOG_DIR/ENV_NAME/random_env_events.jsonl
+        add_hook_ref: true
+  env_configs:
+    func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks ''
+    kwargs:
+      hooks: [ env_initialized, item_info_config ]
+      log_class: !!python/name:overcooked_simulator.recording.LogRecorder ''
+      log_class_kwargs:
+        log_path: USER_LOG_DIR/ENV_NAME/env_configs.jsonl
+        add_hook_ref: true
+
diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py
index 1848a532..8e663b1f 100644
--- a/overcooked_simulator/game_server.py
+++ b/overcooked_simulator/game_server.py
@@ -125,6 +125,7 @@ class EnvironmentHandler:
             layout_config=environment_config.layout_config,
             item_info=environment_config.item_info_config,
             as_files=False,
+            env_name=env_id,
         )
         player_info = {}
         for player_id in range(environment_config.number_players):
@@ -724,7 +725,10 @@ async def websocket_player_endpoint(websocket: WebSocket, client_id: str):
         log.debug(f"Client #{client_id} disconnected")
 
 
-def main(host: str, port: int, manager_ids: list[str]):
+def main(
+    host: str, port: int, manager_ids: list[str], enable_websocket_logging: bool = False
+):
+    setup_logging(enable_websocket_logging)
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
     environment_handler.extend_allowed_manager(manager_ids)
@@ -745,8 +749,7 @@ if __name__ == "__main__":
     disable_websocket_logging_arguments(parser)
     add_list_of_manager_ids_arguments(parser)
     args = parser.parse_args()
-    setup_logging(args.enable_websocket_logging)
-    main(args.url, args.port, args.manager_ids)
+    main(args.url, args.port, args.manager_ids, args.enable_websocket_logging)
     """
     Or in console: 
     uvicorn overcooked_simulator.fastapi_game_server:app --reload
diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
index 7922bc2f..7248ef4e 100644
--- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py
+++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
@@ -24,10 +24,10 @@ from overcooked_simulator.overcooked_environment import (
 )
 from overcooked_simulator.utils import (
     custom_asdict_factory,
-    setup_logging,
     url_and_port_arguments,
     disable_websocket_logging_arguments,
     add_list_of_manager_ids_arguments,
+    setup_logging,
 )
 
 
@@ -692,8 +692,12 @@ class PyGameGUI:
         sys.exit()
 
 
-def main(url: str, port: int, manager_ids: list[str]):
+def main(
+    url: str, port: int, manager_ids: list[str], enable_websocket_logging: bool = False
+):
     # TODO maybe read the player names and keyboard keys from config file?
+    setup_logging(enable_websocket_logging)
+
     keys1 = [
         pygame.K_LEFT,
         pygame.K_RIGHT,
@@ -726,5 +730,4 @@ if __name__ == "__main__":
     disable_websocket_logging_arguments(parser)
     add_list_of_manager_ids_arguments(parser)
     args = parser.parse_args()
-    setup_logging(enable_websocket_logging=args.enable_websocket_logging)
-    main(args.url, args.port, args.manager_ids)
+    main(args.url, args.port, args.manager_ids, args.enable_websocket_logging)
diff --git a/overcooked_simulator/hooks.py b/overcooked_simulator/hooks.py
index 93391eda..285ed735 100644
--- a/overcooked_simulator/hooks.py
+++ b/overcooked_simulator/hooks.py
@@ -11,6 +11,8 @@ ITEM_INFO_LOADED = "item_info_load"
 LAYOUT_FILE_PARSED = "layout_file_parsed"
 """After the layout file was parsed. No additional kwargs. Everything is stored in the env."""
 
+ITEM_INFO_CONFIG = "item_info_config"
+
 ENV_INITIALIZED = "env_initialized"
 """At the end of the __init__ method. No additional kwargs. Everything is stored in the env."""
 
@@ -58,6 +60,8 @@ NO_SERVING = "no_serving"
 
 # TODO drop off
 
+PLATE_OUT_OF_KITCHEN_TIME = "plate_out_of_kitchen_time"
+
 DIRTY_PLATE_ARRIVES = "dirty_plate_arrives"
 
 TRASHCAN_USAGE = "trashcan_usage"
@@ -78,6 +82,8 @@ SERVE_NOT_ORDERED_MEAL = "serve_not_ordered_meal"
 
 SERVE_WITHOUT_PLATE = "serve_without_plate"
 
+ORDER_DURATION_SAMPLE = "order_duration_sample"
+
 COMPLETED_ORDER = "completed_order"
 
 INIT_ORDERS = "init_orders"
diff --git a/overcooked_simulator/order.py b/overcooked_simulator/order.py
index 299bcdd6..e4837ea1 100644
--- a/overcooked_simulator/order.py
+++ b/overcooked_simulator/order.py
@@ -61,6 +61,7 @@ from overcooked_simulator.hooks import (
     COMPLETED_ORDER,
     INIT_ORDERS,
     NEW_ORDERS,
+    ORDER_DURATION_SAMPLE,
 )
 
 log = logging.getLogger(__name__)
@@ -137,9 +138,11 @@ class OrderGeneration:
         ```
     """
 
-    def __init__(self, available_meals: dict[str, ItemInfo], **kwargs):
+    def __init__(self, available_meals: dict[str, ItemInfo], hook: Hooks, **kwargs):
         self.available_meals: list[ItemInfo] = list(available_meals.values())
         """Available meals restricted through the `environment_config.yml`."""
+        self.hook = hook
+        """Reference to the hook manager."""
 
     @abstractmethod
     def init_orders(self, now) -> list[Order]:
@@ -165,7 +168,9 @@ class OrderAndScoreManager:
         self.score: float = 0.0
         """The current score of the environment."""
         self.order_gen: OrderGeneration = order_config["order_gen_class"](
-            available_meals=available_meals, kwargs=order_config["order_gen_kwargs"]
+            available_meals=available_meals,
+            hook=hook,
+            kwargs=order_config["order_gen_kwargs"],
         )
         """The order generation."""
         self.serving_not_ordered_meals: Callable[
@@ -511,8 +516,8 @@ class RandomOrderGeneration(OrderGeneration):
     ```
     """
 
-    def __init__(self, available_meals: dict[str, ItemInfo], **kwargs):
-        super().__init__(available_meals, **kwargs)
+    def __init__(self, available_meals: dict[str, ItemInfo], hook: Hooks, **kwargs):
+        super().__init__(available_meals, hook, **kwargs)
         self.kwargs: RandomOrderKwarg = RandomOrderKwarg(**kwargs["kwargs"])
         self.next_order_time: datetime | None = datetime.max
         self.number_cur_orders: int = 0
@@ -578,6 +583,10 @@ class RandomOrderGeneration(OrderGeneration):
                         random, self.kwargs.order_duration_random_func["func"]
                     )(**self.kwargs.order_duration_random_func["kwargs"])
                 )
+            self.hook(
+                ORDER_DURATION_SAMPLE,
+                duration=duration,
+            )
             log.info(f"Create order for meal {meal} with duration {duration}")
             orders.append(
                 Order(
diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py
index 53d8559f..626e3cda 100644
--- a/overcooked_simulator/overcooked_environment.py
+++ b/overcooked_simulator/overcooked_environment.py
@@ -41,6 +41,7 @@ from overcooked_simulator.hooks import (
     ACTION_ON_NOT_REACHABLE_COUNTER,
     ACTION_PUT,
     ACTION_INTERACT_START,
+    ITEM_INFO_CONFIG,
 )
 from overcooked_simulator.order import (
     OrderAndScoreManager,
@@ -107,6 +108,7 @@ class EnvironmentConfig(TypedDict):
     orders: OrderConfig
     player_config: PlayerConfig
     layout_chars: dict[str, str]
+    extra_setup_functions: dict[str, dict]
 
 
 class Environment:
@@ -123,13 +125,17 @@ class Environment:
         layout_config: Path | str,
         item_info: Path | str,
         as_files: bool = True,
+        env_name: str = "overcooked_sim",
     ):
+        self.env_name = env_name
+        """Reference to the run. E.g, the env id."""
+        self.env_time: datetime = create_init_env_time()
+        """the internal time of the environment. An environment starts always with the time from 
+        `create_init_env_time`."""
+
         self.hook: Hooks = Hooks(self)
         """Hook manager. Register callbacks and create hook points with additional kwargs."""
 
-        # init callbacks here from config
-        # add_dummy_callbacks(self)
-
         self.players: dict[str, Player] = {}
         """the player, keyed by their id/name."""
 
@@ -137,13 +143,14 @@ class Environment:
         """Are the configs just the path to the files."""
         if self.as_files:
             with open(env_config, "r") as file:
-                self.environment_config: EnvironmentConfig = yaml.load(
-                    file, Loader=yaml.Loader
-                )
-        else:
-            self.environment_config: EnvironmentConfig = yaml.load(
-                env_config, Loader=yaml.Loader
-            )
+                env_config = file.read()
+        self.environment_config: EnvironmentConfig = yaml.load(
+            env_config, Loader=yaml.Loader
+        )
+        """The config of the environment. All environment specific attributes is configured here."""
+
+        self.extra_setup_functions()
+
         self.layout_config = layout_config
         """The layout config for the environment"""
         # self.counter_side_length = 1  # -> this changed! is 1 now
@@ -228,9 +235,6 @@ class Environment:
         )
         """Counters that needs to be called in the step function via the `progress` method."""
 
-        self.env_time: datetime = create_init_env_time()
-        """the internal time of the environment. An environment starts always with the time from 
-        `create_init_env_time`."""
         self.order_and_score.create_init_orders(self.env_time)
         self.start_time = self.env_time
         """The relative env time when it started."""
@@ -240,7 +244,11 @@ class Environment:
         """The relative env time when it will stop/end"""
         log.debug(f"End time: {self.env_time_end}")
 
-        self.hook(ENV_INITIALIZED)
+        self.hook(
+            ENV_INITIALIZED,
+            environment_config=env_config,
+            layout_config=self.layout_config,
+        )
 
     @property
     def game_ended(self) -> bool:
@@ -257,9 +265,9 @@ class Environment:
         """Load `item_info.yml`, create ItemInfo classes and replace equipment strings with item infos."""
         if self.as_files:
             with open(data, "r") as file:
-                item_lookup = yaml.safe_load(file)
-        else:
-            item_lookup = yaml.safe_load(data)
+                data = file.read()
+        self.hook(ITEM_INFO_CONFIG, item_info_config=data)
+        item_lookup = yaml.safe_load(data)
         for item_name in item_lookup:
             item_lookup[item_name] = ItemInfo(name=item_name, **item_lookup[item_name])
 
@@ -340,9 +348,8 @@ class Environment:
 
         if self.as_files:
             with open(self.layout_config, "r") as layout_file:
-                lines = layout_file.readlines()
-        else:
-            lines = self.layout_config.split("\n")
+                self.layout_config = layout_file.read()
+        lines = self.layout_config.split("\n")
 
         for line in lines:
             line = line.replace("\n", "").replace(" ", "")  # remove newline char
@@ -696,3 +703,13 @@ class Environment:
 
     def register_callback_for_hook(self, hook_ref: str | list[str], callback: Callable):
         self.hook.register_callback(hook_ref, callback)
+
+    def extra_setup_functions(self):
+        if self.environment_config["extra_setup_functions"]:
+            for function_name, function_def in self.environment_config[
+                "extra_setup_functions"
+            ].items():
+                log.info(f"Setup function {function_name}")
+                function_def["func"](
+                    name=function_name, env=self, **function_def["kwargs"]
+                )
diff --git a/overcooked_simulator/recording.py b/overcooked_simulator/recording.py
new file mode 100644
index 00000000..1211a997
--- /dev/null
+++ b/overcooked_simulator/recording.py
@@ -0,0 +1,66 @@
+import json
+import logging
+import os
+import traceback
+from pathlib import Path
+from typing import Any
+
+import platformdirs
+
+from overcooked_simulator.overcooked_environment import Environment
+from overcooked_simulator.utils import NumpyAndDataclassEncoder
+
+log = logging.getLogger(__name__)
+
+
+def class_recording_with_hooks(
+    name: str,
+    env: Environment,
+    hooks: list[str],
+    log_class,
+    log_class_kwargs: dict[str, Any],
+):
+    recorder = log_class(name=name, env=env, **log_class_kwargs)
+    for hook in hooks:
+        env.register_callback_for_hook(hook, recorder)
+
+
+class LogRecorder:
+    def __init__(
+        self,
+        name: str,
+        env: Environment,
+        log_path: str,
+        add_hook_ref: bool = False,
+    ):
+        self.add_hook_ref = add_hook_ref
+        log_path = log_path.replace("ENV_NAME", env.env_name)
+        if log_path.startswith("USER_LOG_DIR/"):
+            log_path = (
+                Path(platformdirs.user_log_dir("overcooked_simulator"))
+                / log_path[len("USER_LOG_DIR/") :]
+            )
+        else:
+            log_path = Path(log_path)
+        self.log_path = log_path
+
+        os.makedirs(log_path.parent, exist_ok=True)
+
+    def __call__(self, hook_ref: str, env: Environment, **kwargs):
+        try:
+            record = (
+                json.dumps(
+                    {
+                        "env_time": env.env_time.isoformat(),
+                        **kwargs,
+                        **({"hook_ref": hook_ref} if self.add_hook_ref else {}),
+                    },
+                    cls=NumpyAndDataclassEncoder,
+                )
+                + "\n"
+            )
+            with open(self.log_path, "a") as log_file:
+                log_file.write(record)
+        except TypeError as e:
+            traceback.print_exception(e)
+            log.info(f"Not JSON serializable Record {kwargs}")
diff --git a/overcooked_simulator/utils.py b/overcooked_simulator/utils.py
index 2754f089..6805a149 100644
--- a/overcooked_simulator/utils.py
+++ b/overcooked_simulator/utils.py
@@ -1,12 +1,13 @@
 """
 Some utility functions.
 """
-
+import dataclasses
+import json
 import logging
 import os
 import sys
 import uuid
-from datetime import datetime
+from datetime import datetime, timedelta
 from enum import Enum
 
 import numpy as np
@@ -106,3 +107,25 @@ def add_list_of_manager_ids_arguments(parser):
         default=[uuid.uuid4().hex],
         help="List of manager IDs that can create environments.",
     )
+
+
+class NumpyAndDataclassEncoder(json.JSONEncoder):
+    """Special json encoder for numpy types"""
+
+    def default(self, obj):
+        if isinstance(obj, np.integer):
+            return int(obj)
+        elif isinstance(obj, np.floating):
+            return float(obj)
+        elif isinstance(obj, np.ndarray):
+            return obj.tolist()
+        elif isinstance(obj, timedelta):
+            return obj.total_seconds()
+        elif isinstance(obj, datetime):
+            return obj.isoformat()
+        elif dataclasses.is_dataclass(obj):
+            return dataclasses.asdict(obj, dict_factory=custom_asdict_factory)
+        # elif callable(obj):
+        #     return getattr(obj, "__name__", "Unknown")
+
+        return json.JSONEncoder.default(self, obj)
diff --git a/setup.py b/setup.py
index d57d47c3..8daf673c 100644
--- a/setup.py
+++ b/setup.py
@@ -21,6 +21,7 @@ requirements = [
     "uvicorn",
     "websockets",
     "requests",
+    "platformdirs",
 ]
 
 test_requirements = [
-- 
GitLab