Skip to content
Snippets Groups Projects
Commit a9290be8 authored by Fabian Heinrich's avatar Fabian Heinrich
Browse files

Merge branch '72-record-relevant-env-events-in-a-file' into 'main'

Resolve "Record relevant env events in a file"

Closes #72

See merge request scs/cocosy/overcooked-simulator!43
parents ab5b30b9 8ada3ae3
No related branches found
No related tags found
1 merge request!43Resolve "Record relevant env events in a file"
Pipeline #45573 passed
...@@ -50,7 +50,7 @@ def main(cli_args=None): ...@@ -50,7 +50,7 @@ def main(cli_args=None):
print("Received Keyboard interrupt") print("Received Keyboard interrupt")
finally: finally:
if game_server is not None and game_server.is_alive(): if game_server is not None and game_server.is_alive():
print("Terminate gparserame server") print("Terminate game server")
game_server.terminate() game_server.terminate()
if pygame_gui is not None and pygame_gui.is_alive(): if pygame_gui is not None and pygame_gui.is_alive():
print("Terminate pygame gui") print("Terminate pygame gui")
......
...@@ -60,6 +60,7 @@ from overcooked_simulator.hooks import ( ...@@ -60,6 +60,7 @@ from overcooked_simulator.hooks import (
ADDED_PLATE_TO_SINK, ADDED_PLATE_TO_SINK,
DROP_ON_SINK_ADDON, DROP_ON_SINK_ADDON,
PICK_UP_FROM_SINK_ADDON, PICK_UP_FROM_SINK_ADDON,
PLATE_OUT_OF_KITCHEN_TIME,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -536,6 +537,7 @@ class PlateDispenser(Counter): ...@@ -536,6 +537,7 @@ class PlateDispenser(Counter):
self.out_of_kitchen_timer.append(time_plate_to_add) self.out_of_kitchen_timer.append(time_plate_to_add)
if time_plate_to_add < self.next_plate_time: if time_plate_to_add < self.next_plate_time:
self.next_plate_time = time_plate_to_add self.next_plate_time = time_plate_to_add
self.hook(PLATE_OUT_OF_KITCHEN_TIME, time_plate_to_add=time_plate_to_add)
def setup_plates(self): def setup_plates(self):
"""Create plates based on the config. Clean and dirty ones.""" """Create plates based on the config. Clean and dirty ones."""
......
...@@ -94,4 +94,38 @@ effect_manager: ...@@ -94,4 +94,38 @@ effect_manager:
class: !!python/name:overcooked_simulator.effect_manager.FireEffectManager '' class: !!python/name:overcooked_simulator.effect_manager.FireEffectManager ''
kwargs: kwargs:
spreading_duration: [ 5, 10 ] spreading_duration: [ 5, 10 ]
fire_burns_ingredients_and_meals: true fire_burns_ingredients_and_meals: true
\ No newline at end of file
extra_setup_functions:
# json_states:
# func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks ''
# kwargs:
# hooks: [ json_state ]
# log_class: !!python/name:overcooked_simulator.recording.LogRecorder ''
# log_class_kwargs:
# log_path: USER_LOG_DIR/ENV_NAME/json_states.jsonl
actions:
func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks ''
kwargs:
hooks: [ pre_perform_action ]
log_class: !!python/name:overcooked_simulator.recording.LogRecorder ''
log_class_kwargs:
log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl
random_env_events:
func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks ''
kwargs:
hooks: [ order_duration_sample, plate_out_of_kitchen_time ]
log_class: !!python/name:overcooked_simulator.recording.LogRecorder ''
log_class_kwargs:
log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl
add_hook_ref: true
env_configs:
func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks ''
kwargs:
hooks: [ env_initialized, item_info_config ]
log_class: !!python/name:overcooked_simulator.recording.LogRecorder ''
log_class_kwargs:
log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl
add_hook_ref: true
...@@ -125,6 +125,7 @@ class EnvironmentHandler: ...@@ -125,6 +125,7 @@ class EnvironmentHandler:
layout_config=environment_config.layout_config, layout_config=environment_config.layout_config,
item_info=environment_config.item_info_config, item_info=environment_config.item_info_config,
as_files=False, as_files=False,
env_name=env_id,
) )
player_info = {} player_info = {}
for player_id in range(environment_config.number_players): for player_id in range(environment_config.number_players):
...@@ -725,8 +726,10 @@ async def websocket_player_endpoint(websocket: WebSocket, client_id: str): ...@@ -725,8 +726,10 @@ async def websocket_player_endpoint(websocket: WebSocket, client_id: str):
log.debug(f"Client #{client_id} disconnected") log.debug(f"Client #{client_id} disconnected")
def main(host: str, port: int, manager_ids: list[str]): def main(
setup_logging() host: str, port: int, manager_ids: list[str], enable_websocket_logging: bool = False
):
setup_logging(enable_websocket_logging)
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
environment_handler.extend_allowed_manager(manager_ids) environment_handler.extend_allowed_manager(manager_ids)
...@@ -747,8 +750,7 @@ if __name__ == "__main__": ...@@ -747,8 +750,7 @@ if __name__ == "__main__":
disable_websocket_logging_arguments(parser) disable_websocket_logging_arguments(parser)
add_list_of_manager_ids_arguments(parser) add_list_of_manager_ids_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
setup_logging(args.enable_websocket_logging) main(args.url, args.port, args.manager_ids, args.enable_websocket_logging)
main(args.url, args.port, args.manager_ids)
""" """
Or in console: Or in console:
uvicorn overcooked_simulator.fastapi_game_server:app --reload uvicorn overcooked_simulator.fastapi_game_server:app --reload
......
...@@ -24,10 +24,10 @@ from overcooked_simulator.overcooked_environment import ( ...@@ -24,10 +24,10 @@ from overcooked_simulator.overcooked_environment import (
) )
from overcooked_simulator.utils import ( from overcooked_simulator.utils import (
custom_asdict_factory, custom_asdict_factory,
setup_logging,
url_and_port_arguments, url_and_port_arguments,
disable_websocket_logging_arguments, disable_websocket_logging_arguments,
add_list_of_manager_ids_arguments, add_list_of_manager_ids_arguments,
setup_logging,
) )
...@@ -696,8 +696,12 @@ class PyGameGUI: ...@@ -696,8 +696,12 @@ class PyGameGUI:
sys.exit() sys.exit()
def main(url: str, port: int, manager_ids: list[str]): def main(
url: str, port: int, manager_ids: list[str], enable_websocket_logging: bool = False
):
# TODO maybe read the player names and keyboard keys from config file? # TODO maybe read the player names and keyboard keys from config file?
setup_logging(enable_websocket_logging)
keys1 = [ keys1 = [
pygame.K_LEFT, pygame.K_LEFT,
pygame.K_RIGHT, pygame.K_RIGHT,
...@@ -730,5 +734,4 @@ if __name__ == "__main__": ...@@ -730,5 +734,4 @@ if __name__ == "__main__":
disable_websocket_logging_arguments(parser) disable_websocket_logging_arguments(parser)
add_list_of_manager_ids_arguments(parser) add_list_of_manager_ids_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
setup_logging(enable_websocket_logging=args.enable_websocket_logging) main(args.url, args.port, args.manager_ids, args.enable_websocket_logging)
main(args.url, args.port, args.manager_ids)
...@@ -11,6 +11,8 @@ ITEM_INFO_LOADED = "item_info_load" ...@@ -11,6 +11,8 @@ ITEM_INFO_LOADED = "item_info_load"
LAYOUT_FILE_PARSED = "layout_file_parsed" LAYOUT_FILE_PARSED = "layout_file_parsed"
"""After the layout file was parsed. No additional kwargs. Everything is stored in the env.""" """After the layout file was parsed. No additional kwargs. Everything is stored in the env."""
ITEM_INFO_CONFIG = "item_info_config"
ENV_INITIALIZED = "env_initialized" ENV_INITIALIZED = "env_initialized"
"""At the end of the __init__ method. No additional kwargs. Everything is stored in the env.""" """At the end of the __init__ method. No additional kwargs. Everything is stored in the env."""
...@@ -58,6 +60,8 @@ NO_SERVING = "no_serving" ...@@ -58,6 +60,8 @@ NO_SERVING = "no_serving"
# TODO drop off # TODO drop off
PLATE_OUT_OF_KITCHEN_TIME = "plate_out_of_kitchen_time"
DIRTY_PLATE_ARRIVES = "dirty_plate_arrives" DIRTY_PLATE_ARRIVES = "dirty_plate_arrives"
TRASHCAN_USAGE = "trashcan_usage" TRASHCAN_USAGE = "trashcan_usage"
...@@ -78,6 +82,8 @@ SERVE_NOT_ORDERED_MEAL = "serve_not_ordered_meal" ...@@ -78,6 +82,8 @@ SERVE_NOT_ORDERED_MEAL = "serve_not_ordered_meal"
SERVE_WITHOUT_PLATE = "serve_without_plate" SERVE_WITHOUT_PLATE = "serve_without_plate"
ORDER_DURATION_SAMPLE = "order_duration_sample"
COMPLETED_ORDER = "completed_order" COMPLETED_ORDER = "completed_order"
INIT_ORDERS = "init_orders" INIT_ORDERS = "init_orders"
......
...@@ -61,6 +61,7 @@ from overcooked_simulator.hooks import ( ...@@ -61,6 +61,7 @@ from overcooked_simulator.hooks import (
COMPLETED_ORDER, COMPLETED_ORDER,
INIT_ORDERS, INIT_ORDERS,
NEW_ORDERS, NEW_ORDERS,
ORDER_DURATION_SAMPLE,
) )
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -137,9 +138,11 @@ class OrderGeneration: ...@@ -137,9 +138,11 @@ class OrderGeneration:
``` ```
""" """
def __init__(self, available_meals: dict[str, ItemInfo], random: Random, **kwargs): def __init__(self, available_meals: dict[str, ItemInfo], hook: Hooks, random: Random, **kwargs):
self.available_meals: list[ItemInfo] = list(available_meals.values()) self.available_meals: list[ItemInfo] = list(available_meals.values())
"""Available meals restricted through the `environment_config.yml`.""" """Available meals restricted through the `environment_config.yml`."""
self.hook = hook
"""Reference to the hook manager."""
self.random = random self.random = random
"""Random instance.""" """Random instance."""
...@@ -176,6 +179,7 @@ class OrderAndScoreManager: ...@@ -176,6 +179,7 @@ class OrderAndScoreManager:
"""The current score of the environment.""" """The current score of the environment."""
self.order_gen: OrderGeneration = order_config["order_gen_class"]( self.order_gen: OrderGeneration = order_config["order_gen_class"](
available_meals=available_meals, available_meals=available_meals,
hook=hook,
random=random, random=random,
kwargs=order_config["order_gen_kwargs"], kwargs=order_config["order_gen_kwargs"],
) )
...@@ -523,8 +527,8 @@ class RandomOrderGeneration(OrderGeneration): ...@@ -523,8 +527,8 @@ class RandomOrderGeneration(OrderGeneration):
``` ```
""" """
def __init__(self, available_meals: dict[str, ItemInfo], random: Random, **kwargs): def __init__(self, available_meals: dict[str, ItemInfo], hook: Hooks, random: Random, **kwargs):
super().__init__(available_meals, random, **kwargs) super().__init__(available_meals, hook, random, **kwargs)
self.kwargs: RandomOrderKwarg = RandomOrderKwarg(**kwargs["kwargs"]) self.kwargs: RandomOrderKwarg = RandomOrderKwarg(**kwargs["kwargs"])
self.next_order_time: datetime | None = datetime.max self.next_order_time: datetime | None = datetime.max
self.number_cur_orders: int = 0 self.number_cur_orders: int = 0
...@@ -590,6 +594,10 @@ class RandomOrderGeneration(OrderGeneration): ...@@ -590,6 +594,10 @@ class RandomOrderGeneration(OrderGeneration):
self.random, self.kwargs.order_duration_random_func["func"] self.random, self.kwargs.order_duration_random_func["func"]
)(**self.kwargs.order_duration_random_func["kwargs"]) )(**self.kwargs.order_duration_random_func["kwargs"])
) )
self.hook(
ORDER_DURATION_SAMPLE,
duration=duration,
)
log.info(f"Create order for meal {meal} with duration {duration}") log.info(f"Create order for meal {meal} with duration {duration}")
orders.append( orders.append(
Order( Order(
......
...@@ -42,6 +42,7 @@ from overcooked_simulator.hooks import ( ...@@ -42,6 +42,7 @@ from overcooked_simulator.hooks import (
ACTION_ON_NOT_REACHABLE_COUNTER, ACTION_ON_NOT_REACHABLE_COUNTER,
ACTION_PUT, ACTION_PUT,
ACTION_INTERACT_START, ACTION_INTERACT_START,
ITEM_INFO_CONFIG,
) )
from overcooked_simulator.order import ( from overcooked_simulator.order import (
OrderAndScoreManager, OrderAndScoreManager,
...@@ -108,6 +109,7 @@ class EnvironmentConfig(TypedDict): ...@@ -108,6 +109,7 @@ class EnvironmentConfig(TypedDict):
orders: OrderConfig orders: OrderConfig
player_config: PlayerConfig player_config: PlayerConfig
layout_chars: dict[str, str] layout_chars: dict[str, str]
extra_setup_functions: dict[str, dict]
effect_manager: dict effect_manager: dict
...@@ -125,16 +127,20 @@ class Environment: ...@@ -125,16 +127,20 @@ class Environment:
layout_config: Path | str, layout_config: Path | str,
item_info: Path | str, item_info: Path | str,
as_files: bool = True, as_files: bool = True,
env_name: str = "overcooked_sim",
seed: int = 56789223842348, seed: int = 56789223842348,
): ):
self.env_name = env_name
"""Reference to the run. E.g, the env id."""
self.env_time: datetime = create_init_env_time()
"""the internal time of the environment. An environment starts always with the time from
`create_init_env_time`."""
self.random: Random = Random(seed) self.random: Random = Random(seed)
"""Random instance.""" """Random instance."""
self.hook: Hooks = Hooks(self) self.hook: Hooks = Hooks(self)
"""Hook manager. Register callbacks and create hook points with additional kwargs.""" """Hook manager. Register callbacks and create hook points with additional kwargs."""
# init callbacks here from config
# add_dummy_callbacks(self)
self.players: dict[str, Player] = {} self.players: dict[str, Player] = {}
"""the player, keyed by their id/name.""" """the player, keyed by their id/name."""
...@@ -142,13 +148,14 @@ class Environment: ...@@ -142,13 +148,14 @@ class Environment:
"""Are the configs just the path to the files.""" """Are the configs just the path to the files."""
if self.as_files: if self.as_files:
with open(env_config, "r") as file: with open(env_config, "r") as file:
self.environment_config: EnvironmentConfig = yaml.load( env_config = file.read()
file, Loader=yaml.Loader self.environment_config: EnvironmentConfig = yaml.load(
) env_config, Loader=yaml.Loader
else: )
self.environment_config: EnvironmentConfig = yaml.load( """The config of the environment. All environment specific attributes is configured here."""
env_config, Loader=yaml.Loader
) self.extra_setup_functions()
self.layout_config = layout_config self.layout_config = layout_config
"""The layout config for the environment""" """The layout config for the environment"""
# self.counter_side_length = 1 # -> this changed! is 1 now # self.counter_side_length = 1 # -> this changed! is 1 now
...@@ -236,9 +243,6 @@ class Environment: ...@@ -236,9 +243,6 @@ class Environment:
) )
"""Counters that needs to be called in the step function via the `progress` method.""" """Counters that needs to be called in the step function via the `progress` method."""
self.env_time: datetime = create_init_env_time()
"""the internal time of the environment. An environment starts always with the time from
`create_init_env_time`."""
self.order_and_score.create_init_orders(self.env_time) self.order_and_score.create_init_orders(self.env_time)
self.start_time = self.env_time self.start_time = self.env_time
"""The relative env time when it started.""" """The relative env time when it started."""
...@@ -252,7 +256,13 @@ class Environment: ...@@ -252,7 +256,13 @@ class Environment:
str, EffectManager str, EffectManager
] = self.counter_factory.setup_effect_manger(self.counters) ] = self.counter_factory.setup_effect_manger(self.counters)
self.hook(ENV_INITIALIZED) self.hook(
ENV_INITIALIZED,
environment_config=env_config,
layout_config=self.layout_config,
seed=seed,
env_start_time_worldtime=datetime.now()
)
@property @property
def game_ended(self) -> bool: def game_ended(self) -> bool:
...@@ -269,9 +279,9 @@ class Environment: ...@@ -269,9 +279,9 @@ class Environment:
"""Load `item_info.yml`, create ItemInfo classes and replace equipment strings with item infos.""" """Load `item_info.yml`, create ItemInfo classes and replace equipment strings with item infos."""
if self.as_files: if self.as_files:
with open(data, "r") as file: with open(data, "r") as file:
item_lookup = yaml.safe_load(file) data = file.read()
else: self.hook(ITEM_INFO_CONFIG, item_info_config=data)
item_lookup = yaml.safe_load(data) item_lookup = yaml.safe_load(data)
for item_name in item_lookup: for item_name in item_lookup:
item_lookup[item_name] = ItemInfo(name=item_name, **item_lookup[item_name]) item_lookup[item_name] = ItemInfo(name=item_name, **item_lookup[item_name])
...@@ -354,9 +364,8 @@ class Environment: ...@@ -354,9 +364,8 @@ class Environment:
if self.as_files: if self.as_files:
with open(self.layout_config, "r") as layout_file: with open(self.layout_config, "r") as layout_file:
lines = layout_file.readlines() self.layout_config = layout_file.read()
else: lines = self.layout_config.split("\n")
lines = self.layout_config.split("\n")
grid = [] grid = []
...@@ -780,3 +789,13 @@ class Environment: ...@@ -780,3 +789,13 @@ class Environment:
def register_callback_for_hook(self, hook_ref: str | list[str], callback: Callable): def register_callback_for_hook(self, hook_ref: str | list[str], callback: Callable):
self.hook.register_callback(hook_ref, callback) self.hook.register_callback(hook_ref, callback)
def extra_setup_functions(self):
if self.environment_config["extra_setup_functions"]:
for function_name, function_def in self.environment_config[
"extra_setup_functions"
].items():
log.info(f"Setup function {function_name}")
function_def["func"](
name=function_name, env=self, **function_def["kwargs"]
)
import json
import logging
import os
import traceback
from pathlib import Path
from typing import Any
import platformdirs
from overcooked_simulator import ROOT_DIR
from overcooked_simulator.overcooked_environment import Environment
from overcooked_simulator.utils import NumpyAndDataclassEncoder
log = logging.getLogger(__name__)
def class_recording_with_hooks(
name: str,
env: Environment,
hooks: list[str],
log_class,
log_class_kwargs: dict[str, Any],
):
recorder = log_class(name=name, env=env, **log_class_kwargs)
for hook in hooks:
env.register_callback_for_hook(hook, recorder)
class LogRecorder:
def __init__(
self,
name: str,
env: Environment,
log_path: str = "USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl",
add_hook_ref: bool = False,
):
self.add_hook_ref = add_hook_ref
log_path = log_path.replace("ENV_NAME", env.env_name).replace(
"LOG_RECORD_NAME", name
)
if log_path.startswith("USER_LOG_DIR/"):
log_path = (
Path(platformdirs.user_log_dir("overcooked_simulator"))
/ log_path[len("USER_LOG_DIR/") :]
)
elif log_path.startswith("ROOT_DIR/"):
log_path = ROOT_DIR / log_path[len("ROOT_DIR/") :]
else:
log_path = Path(log_path)
self.log_path = log_path
log.info(f"Recorder record for {name} in file://{log_path}")
os.makedirs(log_path.parent, exist_ok=True)
def __call__(self, hook_ref: str, env: Environment, **kwargs):
try:
record = (
json.dumps(
{
"env_time": env.env_time.isoformat(),
**kwargs,
**({"hook_ref": hook_ref} if self.add_hook_ref else {}),
},
cls=NumpyAndDataclassEncoder,
)
+ "\n"
)
with open(self.log_path, "a") as log_file:
log_file.write(record)
except TypeError as e:
traceback.print_exception(e)
log.info(f"Not JSON serializable Record {kwargs}")
...@@ -3,12 +3,14 @@ Some utility functions. ...@@ -3,12 +3,14 @@ Some utility functions.
""" """
from __future__ import annotations from __future__ import annotations
import dataclasses
import json
import logging import logging
import os import os
import sys import sys
import uuid import uuid
from collections import deque from collections import deque
from datetime import datetime from datetime import datetime, timedelta
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
...@@ -131,3 +133,25 @@ def add_list_of_manager_ids_arguments(parser): ...@@ -131,3 +133,25 @@ def add_list_of_manager_ids_arguments(parser):
default=[uuid.uuid4().hex], default=[uuid.uuid4().hex],
help="List of manager IDs that can create environments.", help="List of manager IDs that can create environments.",
) )
class NumpyAndDataclassEncoder(json.JSONEncoder):
"""Special json encoder for numpy types"""
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, timedelta):
return obj.total_seconds()
elif isinstance(obj, datetime):
return obj.isoformat()
elif dataclasses.is_dataclass(obj):
return dataclasses.asdict(obj, dict_factory=custom_asdict_factory)
# elif callable(obj):
# return getattr(obj, "__name__", "Unknown")
return json.JSONEncoder.default(self, obj)
...@@ -21,6 +21,7 @@ requirements = [ ...@@ -21,6 +21,7 @@ requirements = [
"uvicorn", "uvicorn",
"websockets", "websockets",
"requests", "requests",
"platformdirs",
] ]
test_requirements = [ test_requirements = [
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment