diff --git a/overcooked_simulator/__main__.py b/overcooked_simulator/__main__.py index 436201f2ec085f299699bdbf641f6f01029270b8..f48180ac4aa23e525646e844d9e53b9f49a71595 100644 --- a/overcooked_simulator/__main__.py +++ b/overcooked_simulator/__main__.py @@ -50,7 +50,7 @@ def main(cli_args=None): print("Received Keyboard interrupt") finally: if game_server is not None and game_server.is_alive(): - print("Terminate gparserame server") + print("Terminate game server") game_server.terminate() if pygame_gui is not None and pygame_gui.is_alive(): print("Terminate pygame gui") diff --git a/overcooked_simulator/counters.py b/overcooked_simulator/counters.py index d4b72eec6029dd96ca6c8d5604deb64b12b0c970..4b417a2f652286e3a9b566665b70da473ebebccb 100644 --- a/overcooked_simulator/counters.py +++ b/overcooked_simulator/counters.py @@ -63,6 +63,7 @@ from overcooked_simulator.hooks import ( ADDED_PLATE_TO_SINK, DROP_ON_SINK_ADDON, PICK_UP_FROM_SINK_ADDON, + PLATE_OUT_OF_KITCHEN_TIME, ) if TYPE_CHECKING: @@ -506,6 +507,7 @@ class PlateDispenser(Counter): self.out_of_kitchen_timer.append(time_plate_to_add) if time_plate_to_add < self.next_plate_time: self.next_plate_time = time_plate_to_add + self.hook(PLATE_OUT_OF_KITCHEN_TIME, time_plate_to_add=time_plate_to_add) def setup_plates(self): """Create plates based on the config. Clean and dirty ones.""" diff --git a/overcooked_simulator/game_content/environment_config.yaml b/overcooked_simulator/game_content/environment_config.yaml index c0de86d2afc35fad41a5c206b448a0f1eecdad6e..671900ea0075251003d08bb29b56d935eef6119c 100644 --- a/overcooked_simulator/game_content/environment_config.yaml +++ b/overcooked_simulator/game_content/environment_config.yaml @@ -85,3 +85,36 @@ player_config: radius: 0.4 player_speed_units_per_seconds: 8 interaction_range: 1.6 + +extra_setup_functions: + # json_states: + # func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' + # kwargs: + # hooks: [ json_state ] + # log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' + # log_class_kwargs: + # log_path: USER_LOG_DIR/ENV_NAME/json_states.jsonl + actions: + func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' + kwargs: + hooks: [ pre_perform_action ] + log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' + log_class_kwargs: + log_path: USER_LOG_DIR/ENV_NAME/actions.jsonl + random_env_events: + func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' + kwargs: + hooks: [ order_duration_sample, plate_out_of_kitchen_time ] + log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' + log_class_kwargs: + log_path: USER_LOG_DIR/ENV_NAME/random_env_events.jsonl + add_hook_ref: true + env_configs: + func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' + kwargs: + hooks: [ env_initialized, item_info_config ] + log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' + log_class_kwargs: + log_path: USER_LOG_DIR/ENV_NAME/env_configs.jsonl + add_hook_ref: true + diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py index 1848a5329552dc8e0f733fcc3b330fd78ada83dd..8e663b1f28bc5f2bf3d74c3adb5896254a30f6b1 100644 --- a/overcooked_simulator/game_server.py +++ b/overcooked_simulator/game_server.py @@ -125,6 +125,7 @@ class EnvironmentHandler: layout_config=environment_config.layout_config, item_info=environment_config.item_info_config, as_files=False, + env_name=env_id, ) player_info = {} for player_id in range(environment_config.number_players): @@ -724,7 +725,10 @@ async def websocket_player_endpoint(websocket: WebSocket, client_id: str): log.debug(f"Client #{client_id} disconnected") -def main(host: str, port: int, manager_ids: list[str]): +def main( + host: str, port: int, manager_ids: list[str], enable_websocket_logging: bool = False +): + setup_logging(enable_websocket_logging) loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) environment_handler.extend_allowed_manager(manager_ids) @@ -745,8 +749,7 @@ if __name__ == "__main__": disable_websocket_logging_arguments(parser) add_list_of_manager_ids_arguments(parser) args = parser.parse_args() - setup_logging(args.enable_websocket_logging) - main(args.url, args.port, args.manager_ids) + main(args.url, args.port, args.manager_ids, args.enable_websocket_logging) """ Or in console: uvicorn overcooked_simulator.fastapi_game_server:app --reload diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py index 7922bc2fedf5de237a620e01ef0ad2496e942da9..7248ef4eb05f9eab011198943faaeffcd5138796 100644 --- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py +++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py @@ -24,10 +24,10 @@ from overcooked_simulator.overcooked_environment import ( ) from overcooked_simulator.utils import ( custom_asdict_factory, - setup_logging, url_and_port_arguments, disable_websocket_logging_arguments, add_list_of_manager_ids_arguments, + setup_logging, ) @@ -692,8 +692,12 @@ class PyGameGUI: sys.exit() -def main(url: str, port: int, manager_ids: list[str]): +def main( + url: str, port: int, manager_ids: list[str], enable_websocket_logging: bool = False +): # TODO maybe read the player names and keyboard keys from config file? + setup_logging(enable_websocket_logging) + keys1 = [ pygame.K_LEFT, pygame.K_RIGHT, @@ -726,5 +730,4 @@ if __name__ == "__main__": disable_websocket_logging_arguments(parser) add_list_of_manager_ids_arguments(parser) args = parser.parse_args() - setup_logging(enable_websocket_logging=args.enable_websocket_logging) - main(args.url, args.port, args.manager_ids) + main(args.url, args.port, args.manager_ids, args.enable_websocket_logging) diff --git a/overcooked_simulator/hooks.py b/overcooked_simulator/hooks.py index 93391eda3df8c8e834330698d214657516fa8300..285ed735fb58190183dce24f051d2c4f1b8bcd24 100644 --- a/overcooked_simulator/hooks.py +++ b/overcooked_simulator/hooks.py @@ -11,6 +11,8 @@ ITEM_INFO_LOADED = "item_info_load" LAYOUT_FILE_PARSED = "layout_file_parsed" """After the layout file was parsed. No additional kwargs. Everything is stored in the env.""" +ITEM_INFO_CONFIG = "item_info_config" + ENV_INITIALIZED = "env_initialized" """At the end of the __init__ method. No additional kwargs. Everything is stored in the env.""" @@ -58,6 +60,8 @@ NO_SERVING = "no_serving" # TODO drop off +PLATE_OUT_OF_KITCHEN_TIME = "plate_out_of_kitchen_time" + DIRTY_PLATE_ARRIVES = "dirty_plate_arrives" TRASHCAN_USAGE = "trashcan_usage" @@ -78,6 +82,8 @@ SERVE_NOT_ORDERED_MEAL = "serve_not_ordered_meal" SERVE_WITHOUT_PLATE = "serve_without_plate" +ORDER_DURATION_SAMPLE = "order_duration_sample" + COMPLETED_ORDER = "completed_order" INIT_ORDERS = "init_orders" diff --git a/overcooked_simulator/order.py b/overcooked_simulator/order.py index 299bcdd6e3c7afbfcd9b2a57c816a911606106f2..e4837ea1681671a5607e9da9fc4c5b2a8d28097d 100644 --- a/overcooked_simulator/order.py +++ b/overcooked_simulator/order.py @@ -61,6 +61,7 @@ from overcooked_simulator.hooks import ( COMPLETED_ORDER, INIT_ORDERS, NEW_ORDERS, + ORDER_DURATION_SAMPLE, ) log = logging.getLogger(__name__) @@ -137,9 +138,11 @@ class OrderGeneration: ``` """ - def __init__(self, available_meals: dict[str, ItemInfo], **kwargs): + def __init__(self, available_meals: dict[str, ItemInfo], hook: Hooks, **kwargs): self.available_meals: list[ItemInfo] = list(available_meals.values()) """Available meals restricted through the `environment_config.yml`.""" + self.hook = hook + """Reference to the hook manager.""" @abstractmethod def init_orders(self, now) -> list[Order]: @@ -165,7 +168,9 @@ class OrderAndScoreManager: self.score: float = 0.0 """The current score of the environment.""" self.order_gen: OrderGeneration = order_config["order_gen_class"]( - available_meals=available_meals, kwargs=order_config["order_gen_kwargs"] + available_meals=available_meals, + hook=hook, + kwargs=order_config["order_gen_kwargs"], ) """The order generation.""" self.serving_not_ordered_meals: Callable[ @@ -511,8 +516,8 @@ class RandomOrderGeneration(OrderGeneration): ``` """ - def __init__(self, available_meals: dict[str, ItemInfo], **kwargs): - super().__init__(available_meals, **kwargs) + def __init__(self, available_meals: dict[str, ItemInfo], hook: Hooks, **kwargs): + super().__init__(available_meals, hook, **kwargs) self.kwargs: RandomOrderKwarg = RandomOrderKwarg(**kwargs["kwargs"]) self.next_order_time: datetime | None = datetime.max self.number_cur_orders: int = 0 @@ -578,6 +583,10 @@ class RandomOrderGeneration(OrderGeneration): random, self.kwargs.order_duration_random_func["func"] )(**self.kwargs.order_duration_random_func["kwargs"]) ) + self.hook( + ORDER_DURATION_SAMPLE, + duration=duration, + ) log.info(f"Create order for meal {meal} with duration {duration}") orders.append( Order( diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py index 53d8559f0be0642f3fa376789bea66e495eafb14..626e3cdac17eeb6dbca10ab09c4c36814919199f 100644 --- a/overcooked_simulator/overcooked_environment.py +++ b/overcooked_simulator/overcooked_environment.py @@ -41,6 +41,7 @@ from overcooked_simulator.hooks import ( ACTION_ON_NOT_REACHABLE_COUNTER, ACTION_PUT, ACTION_INTERACT_START, + ITEM_INFO_CONFIG, ) from overcooked_simulator.order import ( OrderAndScoreManager, @@ -107,6 +108,7 @@ class EnvironmentConfig(TypedDict): orders: OrderConfig player_config: PlayerConfig layout_chars: dict[str, str] + extra_setup_functions: dict[str, dict] class Environment: @@ -123,13 +125,17 @@ class Environment: layout_config: Path | str, item_info: Path | str, as_files: bool = True, + env_name: str = "overcooked_sim", ): + self.env_name = env_name + """Reference to the run. E.g, the env id.""" + self.env_time: datetime = create_init_env_time() + """the internal time of the environment. An environment starts always with the time from + `create_init_env_time`.""" + self.hook: Hooks = Hooks(self) """Hook manager. Register callbacks and create hook points with additional kwargs.""" - # init callbacks here from config - # add_dummy_callbacks(self) - self.players: dict[str, Player] = {} """the player, keyed by their id/name.""" @@ -137,13 +143,14 @@ class Environment: """Are the configs just the path to the files.""" if self.as_files: with open(env_config, "r") as file: - self.environment_config: EnvironmentConfig = yaml.load( - file, Loader=yaml.Loader - ) - else: - self.environment_config: EnvironmentConfig = yaml.load( - env_config, Loader=yaml.Loader - ) + env_config = file.read() + self.environment_config: EnvironmentConfig = yaml.load( + env_config, Loader=yaml.Loader + ) + """The config of the environment. All environment specific attributes is configured here.""" + + self.extra_setup_functions() + self.layout_config = layout_config """The layout config for the environment""" # self.counter_side_length = 1 # -> this changed! is 1 now @@ -228,9 +235,6 @@ class Environment: ) """Counters that needs to be called in the step function via the `progress` method.""" - self.env_time: datetime = create_init_env_time() - """the internal time of the environment. An environment starts always with the time from - `create_init_env_time`.""" self.order_and_score.create_init_orders(self.env_time) self.start_time = self.env_time """The relative env time when it started.""" @@ -240,7 +244,11 @@ class Environment: """The relative env time when it will stop/end""" log.debug(f"End time: {self.env_time_end}") - self.hook(ENV_INITIALIZED) + self.hook( + ENV_INITIALIZED, + environment_config=env_config, + layout_config=self.layout_config, + ) @property def game_ended(self) -> bool: @@ -257,9 +265,9 @@ class Environment: """Load `item_info.yml`, create ItemInfo classes and replace equipment strings with item infos.""" if self.as_files: with open(data, "r") as file: - item_lookup = yaml.safe_load(file) - else: - item_lookup = yaml.safe_load(data) + data = file.read() + self.hook(ITEM_INFO_CONFIG, item_info_config=data) + item_lookup = yaml.safe_load(data) for item_name in item_lookup: item_lookup[item_name] = ItemInfo(name=item_name, **item_lookup[item_name]) @@ -340,9 +348,8 @@ class Environment: if self.as_files: with open(self.layout_config, "r") as layout_file: - lines = layout_file.readlines() - else: - lines = self.layout_config.split("\n") + self.layout_config = layout_file.read() + lines = self.layout_config.split("\n") for line in lines: line = line.replace("\n", "").replace(" ", "") # remove newline char @@ -696,3 +703,13 @@ class Environment: def register_callback_for_hook(self, hook_ref: str | list[str], callback: Callable): self.hook.register_callback(hook_ref, callback) + + def extra_setup_functions(self): + if self.environment_config["extra_setup_functions"]: + for function_name, function_def in self.environment_config[ + "extra_setup_functions" + ].items(): + log.info(f"Setup function {function_name}") + function_def["func"]( + name=function_name, env=self, **function_def["kwargs"] + ) diff --git a/overcooked_simulator/recording.py b/overcooked_simulator/recording.py new file mode 100644 index 0000000000000000000000000000000000000000..1211a9970aca298ba2124fcd0a3eff74a246dcf1 --- /dev/null +++ b/overcooked_simulator/recording.py @@ -0,0 +1,66 @@ +import json +import logging +import os +import traceback +from pathlib import Path +from typing import Any + +import platformdirs + +from overcooked_simulator.overcooked_environment import Environment +from overcooked_simulator.utils import NumpyAndDataclassEncoder + +log = logging.getLogger(__name__) + + +def class_recording_with_hooks( + name: str, + env: Environment, + hooks: list[str], + log_class, + log_class_kwargs: dict[str, Any], +): + recorder = log_class(name=name, env=env, **log_class_kwargs) + for hook in hooks: + env.register_callback_for_hook(hook, recorder) + + +class LogRecorder: + def __init__( + self, + name: str, + env: Environment, + log_path: str, + add_hook_ref: bool = False, + ): + self.add_hook_ref = add_hook_ref + log_path = log_path.replace("ENV_NAME", env.env_name) + if log_path.startswith("USER_LOG_DIR/"): + log_path = ( + Path(platformdirs.user_log_dir("overcooked_simulator")) + / log_path[len("USER_LOG_DIR/") :] + ) + else: + log_path = Path(log_path) + self.log_path = log_path + + os.makedirs(log_path.parent, exist_ok=True) + + def __call__(self, hook_ref: str, env: Environment, **kwargs): + try: + record = ( + json.dumps( + { + "env_time": env.env_time.isoformat(), + **kwargs, + **({"hook_ref": hook_ref} if self.add_hook_ref else {}), + }, + cls=NumpyAndDataclassEncoder, + ) + + "\n" + ) + with open(self.log_path, "a") as log_file: + log_file.write(record) + except TypeError as e: + traceback.print_exception(e) + log.info(f"Not JSON serializable Record {kwargs}") diff --git a/overcooked_simulator/utils.py b/overcooked_simulator/utils.py index 2754f08912d31ff513ed891646d29571800e8960..6805a149bf69f48275703c90a1895609e6cb21f8 100644 --- a/overcooked_simulator/utils.py +++ b/overcooked_simulator/utils.py @@ -1,12 +1,13 @@ """ Some utility functions. """ - +import dataclasses +import json import logging import os import sys import uuid -from datetime import datetime +from datetime import datetime, timedelta from enum import Enum import numpy as np @@ -106,3 +107,25 @@ def add_list_of_manager_ids_arguments(parser): default=[uuid.uuid4().hex], help="List of manager IDs that can create environments.", ) + + +class NumpyAndDataclassEncoder(json.JSONEncoder): + """Special json encoder for numpy types""" + + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, timedelta): + return obj.total_seconds() + elif isinstance(obj, datetime): + return obj.isoformat() + elif dataclasses.is_dataclass(obj): + return dataclasses.asdict(obj, dict_factory=custom_asdict_factory) + # elif callable(obj): + # return getattr(obj, "__name__", "Unknown") + + return json.JSONEncoder.default(self, obj) diff --git a/setup.py b/setup.py index d57d47c30a8eefd7e427ff97be2c4884bd1a4cfb..8daf673c953090b4bbf6ce910949487412e300a5 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,7 @@ requirements = [ "uvicorn", "websockets", "requests", + "platformdirs", ] test_requirements = [