From 6a57f19a0b94ce734031c6c8d8a0589bfd8a88e0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florian=20Schr=C3=B6der?=
 <fschroeder@techfak.uni-bielefeld.de>
Date: Wed, 24 Jan 2024 17:11:34 +0100
Subject: [PATCH] Refactor and enhance Overcooked simulator codebase

Codebase for Overcooked simulator has been refactored. File rearrangement includes deletions and renaming to better reflect their functionalities. Within the code, several updates are made like support for using both file paths and direct strings as configurations. These amendments improve the organization, modularity, and overall efficiency of the code.
---
 overcooked_simulator/api_call.py              |  14 -
 overcooked_simulator/fastapi_game_server.py   | 166 --------
 overcooked_simulator/game_server.py           | 393 ++++++++++++++++++
 overcooked_simulator/game_server_OLD.py       | 123 ------
 overcooked_simulator/main.py                  |  25 +-
 .../overcooked_environment.py                 |  48 ++-
 overcooked_simulator/player.py                |   4 +-
 7 files changed, 452 insertions(+), 321 deletions(-)
 delete mode 100644 overcooked_simulator/api_call.py
 delete mode 100644 overcooked_simulator/fastapi_game_server.py
 create mode 100644 overcooked_simulator/game_server.py
 delete mode 100644 overcooked_simulator/game_server_OLD.py

diff --git a/overcooked_simulator/api_call.py b/overcooked_simulator/api_call.py
deleted file mode 100644
index ea733f53..00000000
--- a/overcooked_simulator/api_call.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# websocket_client.py
-import asyncio
-
-import websockets
-
-
-async def send_message():
-    uri = "ws://127.0.0.1:8000/ws"
-    async with websockets.connect(uri) as websocket:
-        await websocket.send("Hello, server!")
-        response = await websocket.recv()
-        print(response)
-
-asyncio.run(send_message())
diff --git a/overcooked_simulator/fastapi_game_server.py b/overcooked_simulator/fastapi_game_server.py
deleted file mode 100644
index eb2695e0..00000000
--- a/overcooked_simulator/fastapi_game_server.py
+++ /dev/null
@@ -1,166 +0,0 @@
-import json
-import logging
-import threading
-from contextlib import asynccontextmanager
-
-import numpy as np
-import uvicorn
-from fastapi import FastAPI
-from fastapi import WebSocket
-from starlette.websockets import WebSocketDisconnect
-
-from overcooked_simulator import ROOT_DIR
-from overcooked_simulator.game_server_OLD import setup_logging
-from overcooked_simulator.overcooked_environment import Action
-from overcooked_simulator.simulation_runner import Simulator
-
-log = logging.getLogger(__name__)
-setup_logging()
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    setup_logging()
-    yield
-    for thread in threading.enumerate():
-        if isinstance(thread, Simulator):
-            thread.stop()
-            thread.join()
-
-
-app = FastAPI(lifespan=lifespan)
-
-
-WEBSOCKET_URL = "localhost"
-WEBSOCKET_PORT = 8000
-
-
-class GameServer:
-    simulator: Simulator
-
-    def __init__(self):
-        self.setup_game()
-
-        self.envs = {int: Simulator}
-
-    def create_env(self, n_players: int) -> (int, list[WebSocket]):
-        pass
-
-    def add_player(self, env_id) -> (int, WebSocket):
-        pass
-
-    def setup_game(self):
-        self.simulator = Simulator(
-            ROOT_DIR / "game_content" / "environment_config.yaml",
-            ROOT_DIR / "game_content" / "layouts" / "basic.layout",
-            600,
-        )
-        number_player = 2
-        for i in range(number_player):
-            player_name = f"p{i}"
-            self.simulator.register_player(player_name)
-        self.simulator.start()
-
-    def get_state(self):
-        return self.simulator.get_state_simple_json()
-
-    def reset_game(self):
-        self.simulator.stop()
-        self.setup_game()
-
-
-class ConnectionManager:
-    def __init__(self):
-        self.active_connections: list[WebSocket] = []
-
-    async def connect(self, websocket: WebSocket):
-        await websocket.accept()
-        self.active_connections.append(websocket)
-
-    def disconnect(self, websocket: WebSocket):
-        self.active_connections.remove(websocket)
-
-    async def send_personal_message(self, message: str, websocket: WebSocket):
-        await websocket.send_text(message)
-
-    async def broadcast(self, message: str):
-        for connection in self.active_connections:
-            await connection.send_text(message)
-
-
-manager = ConnectionManager()
-oc_api: GameServer = GameServer()
-
-
-def parse_websocket_action(message: str) -> Action:
-    if message.replace('"', "") != "get_state":
-        message_dict = json.loads(message)
-        if message_dict["act_type"] == "movement":
-            if isinstance(message_dict["value"], list):
-                x, y = message_dict["value"]
-            elif isinstance(message_dict["value"], str):
-                x, y = (
-                    message_dict["value"]
-                    .replace(" ", "")
-                    .replace("[", "")
-                    .replace("]", "")
-                    .split(",")
-                )
-            else:
-                x, y = 0, 0
-            value = np.array([x, y], dtype=float)
-        else:
-            value = None
-        action = Action(
-            message_dict["player_name"],
-            message_dict["act_type"],
-            value,
-            duration=message_dict["duration"],
-        )
-        return action
-
-
-def manage_websocket_message(message: str):
-    if "get_state" in message:
-        return oc_api.get_state()
-
-    if "reset_game" in message:
-        oc_api.reset_game()
-        return "Reset game."
-
-    action = parse_websocket_action(message)
-    oc_api.simulator.enter_action(action)
-    answer = oc_api.get_state()
-    return answer
-
-
-@app.get("/")
-def read_root():
-    return {"OVER": "COOKED"}
-
-
-@app.websocket("/ws/{client_id}")
-async def websocket_endpoint(websocket: WebSocket, client_id: int):
-    await manager.connect(websocket)
-    log.debug(f"Client #{client_id} connected")
-    try:
-        while True:
-            message = await websocket.receive_text()
-            answer = manage_websocket_message(message)
-            await manager.send_personal_message(answer, websocket)
-
-    except WebSocketDisconnect:
-        manager.disconnect(websocket)
-        log.debug(f"Client #{client_id} disconnected")
-
-
-def main():
-    uvicorn.run(app, host=WEBSOCKET_URL, port=WEBSOCKET_PORT)
-
-
-if __name__ == "__main__":
-    main()
-    """
-    Or in console: 
-    uvicorn overcooked_simulator.fastapi_game_server:app --reload
-    """
diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py
new file mode 100644
index 00000000..9fe683d4
--- /dev/null
+++ b/overcooked_simulator/game_server.py
@@ -0,0 +1,393 @@
+from __future__ import annotations
+
+import asyncio
+import dataclasses
+import json
+import logging
+import time
+import uuid
+from collections import defaultdict
+from datetime import datetime, timedelta
+from enum import Enum
+from typing import Set
+
+import numpy as np
+import uvicorn
+from fastapi import FastAPI
+from fastapi import WebSocket
+from overcooked_simulator.game_server_OLD import setup_logging
+from pydantic import BaseModel
+from starlette.websockets import WebSocketDisconnect
+from typing_extensions import TypedDict
+
+from overcooked_simulator.overcooked_environment import Action, Environment
+
+log = logging.getLogger(__name__)
+setup_logging()
+
+
+app = FastAPI()
+
+
+WEBSOCKET_URL = "localhost"
+WEBSOCKET_PORT = 8000
+
+
+@dataclasses.dataclass
+class PlayerData:
+    player_id: int
+    env_id: str
+    websocket_id: str | None = None
+    connected: bool = False
+    ready: bool = False
+    last_action: datetime | None = None
+    name: str = ""
+
+
+class EnvironmentSettings(TypedDict):
+    all_player_can_pause_game: bool
+    # env_steps_per_second: int
+
+
+class EnvironmentStatus(Enum):
+    WAITING_FOR_PLAYERS = "waitingForPlayers"
+    PAUSED = "paused"
+    RUNNING = "running"
+    STOPPED = "stopped"
+
+
+@dataclasses.dataclass
+class EnvironmentData:
+    environment: Environment
+    player_hashes: Set[str] = dataclasses.field(default_factory=set)
+    environment_settings: EnvironmentSettings = dataclasses.field(default_factory=dict)
+    status: EnvironmentStatus = EnvironmentStatus.WAITING_FOR_PLAYERS
+    stop_reason: str = ""
+    start_time: datetime | None = None
+    last_step_time: int | None = None
+
+
+class GameServer:
+    def __init__(self, env_step_frequency: int = 200):
+        self.envs: dict[str, EnvironmentData] = {}
+        self.player_data: dict[str, PlayerData] = {}
+        self.manager_envs: dict[str, Set[str]] = defaultdict(set)
+        self.env_step_frequency = env_step_frequency
+        self.preferred_sleep_time_ns = 1e9 / self.env_step_frequency
+
+    def create_env(self, environment_config: CreateEnvironmentConfig):
+        env_id = uuid.uuid4().hex
+
+        env = Environment(
+            env_config=environment_config.environment_config,
+            layout_config=environment_config.layout_config,
+            item_info=environment_config.item_info_config,
+            as_files=False,
+        )
+        player_info = {}
+        for player_id in range(environment_config.number_players):
+            player_info[player_id] = self.create_player(env, env_id, player_id)
+
+        self.envs[env_id] = EnvironmentData(environment=env)
+
+        self.manager_envs[environment_config.manager_id].update([env_id])
+
+        return {"env_id": env_id}
+
+    def create_player(self, env, env_id, player_id):
+        player_hash = uuid.uuid4().hex
+        client_id = uuid.uuid4().hex
+        player_data = PlayerData(
+            player_id=player_id,
+            env_id=env_id,
+            websocket_id=client_id,
+        )
+        self.player_data[player_hash] = player_data
+        env.add_player(player_id)
+
+        return {
+            "client_id": client_id,
+            "player_hash": player_hash,
+            "player_id": player_id,
+        }
+
+    def add_player(self, config: AdditionalPlayer):
+        new_player_info = {}
+        if (
+            config.manager_id in self.manager_envs
+            and config.env_id in self.manager_envs[config.manager_id]
+            and self.envs[config.env_id].status != EnvironmentStatus.STOPPED
+        ):
+            n_players = len(self.envs[config.env_id].player_hashes)
+            for player_id in range(n_players, n_players + config.number_players):
+                new_player_info[player_id] = self.create_player(
+                    env=self.envs[config.env_id].environment,
+                    env_id=config.env_id,
+                    player_id=player_id,
+                )
+        return new_player_info
+
+    def start_env(self, env_id: str):
+        if env_id in self.envs:
+            start_time = datetime.now()
+            self.envs[env_id].status = EnvironmentStatus.RUNNING
+            self.envs[env_id].start_time = start_time
+            self.envs[env_id].last_step_time = time.time_ns()
+            self.envs[env_id].environment.reset_env_time()
+
+    def get_state(self):
+        ...
+
+    def pause_env(self, manager_id: str, env_id: str, reason: str):
+        if (
+            manager_id in self.manager_envs
+            and env_id in self.manager_envs[manager_id]
+            and self.envs[env_id].status
+            not in [EnvironmentStatus.STOPPED, Environment.PAUSED]
+        ):
+            self.envs[env_id].status = EnvironmentStatus.PAUSED
+
+    def unpause_env(self, manager_id: str, env_id: str, reason: str):
+        if (
+            manager_id in self.manager_envs
+            and env_id in self.manager_envs[manager_id]
+            and self.envs[env_id].status
+            not in [EnvironmentStatus.STOPPED, Environment.PAUSED]
+        ):
+            self.envs[env_id].status = EnvironmentStatus.PAUSED
+            self.envs[env_id].last_step_time = time.time_ns()
+
+    def stop_env(self, manager_id: str, env_id: str, reason: str):
+        if (
+            manager_id in self.manager_envs
+            and env_id in self.manager_envs[manager_id]
+            and self.envs[env_id].status != EnvironmentStatus.STOPPED
+        ):
+            self.envs[env_id].status = EnvironmentStatus.STOPPED
+            self.envs[env_id].stop_reason = reason
+
+    def set_player_ready(self, env_id: str, player_hash, player_id: int):
+        if (
+            player_hash in self.player_data
+            and self.player_data[player_hash].player_id == player_id
+            and self.player_data[player_hash].env_id == env_id
+        ):
+            self.player_data[player_hash].ready = True
+            return True
+        return False
+
+    def set_player_connected(self, env_id: str, player_hash, player_id: int) -> bool:
+        if (
+            player_hash in self.player_data
+            and self.player_data[player_hash].player_id == player_id
+            and self.player_data[player_hash].env_id == env_id
+        ):
+            self.player_data[player_hash].connected = True
+            return True
+        return False
+
+    def set_player_disconnected(
+        self, env_id: str, player_hash: str, player_id: int
+    ) -> bool:
+        if (
+            player_hash in self.player_data
+            and self.player_data[player_hash].player_id == player_id
+            and self.player_data[player_hash].env_id == env_id
+        ):
+            self.player_data[player_hash].connected = False
+            return True
+        return False
+
+    def check_all_player_ready(self, env_id: str) -> bool:
+        return env_id in self.envs and all(
+            self.player_data[player_hash].connected
+            and self.player_data[player_hash].ready
+            for player_hash in self.envs[env_id].player_hashes
+        )
+
+    def check_all_players_connected(self, env_id: str) -> bool:
+        return env_id in self.envs and all(
+            self.player_data[player_hash].connected
+            for player_hash in self.envs[env_id].player_hashes
+        )
+
+    def get_not_connected_players(self, env_id: str) -> list[int]:
+        if env_id in self.envs:
+            return [
+                self.player_data[player_hash].player_id
+                for player_hash in self.envs[env_id].player_hashes
+                if not self.player_data[player_hash].connected
+            ]
+
+    def get_not_ready_players(self, env_id: str) -> list[int]:
+        if env_id in self.envs:
+            return [
+                self.player_data[player_hash].player_id
+                for player_hash in self.envs[env_id].player_hashes
+                if not self.player_data[player_hash].ready
+            ]
+
+    async def environment_steps(self):
+        overslept_in_ns = 0
+        while True:
+            pre_step_start = time.time_ns()
+            for env_id, env_data in self.envs.items():
+                if env_data.status == EnvironmentStatus.RUNNING:
+                    step_start = time.time_ns()
+                    env_data.environment.step(
+                        timedelta(
+                            seconds=(step_start - env_data.last_step_time)
+                            / 1_000_000_000
+                        )
+                    )
+                    env_data.last_step_time = step_start
+            step_duration = time.time_ns() - pre_step_start
+
+            time_to_sleep_ns = self.preferred_sleep_time_ns - (
+                step_duration + overslept_in_ns
+            )
+
+            sleep_start = time.time_ns()
+            await asyncio.sleep(max(time_to_sleep_ns / 1e9, 0))
+            sleep_function_duration = time.time_ns() - sleep_start
+            overslept_in_ns = sleep_function_duration - time_to_sleep_ns
+
+
+class PlayerConnectionManager:
+    def __init__(self):
+        self.player_connections: dict[str, WebSocket] = {}
+
+    async def connect_player(self, websocket: WebSocket, player_id: str) -> bool:
+        if player_id not in self.player_connections:
+            await websocket.accept()
+            self.player_connections[player_id] = websocket
+            return True
+        return False
+
+    def disconnect(self, id_: str):
+        if id_ in self.player_connections:
+            del self.player_connections[id_]
+
+    @staticmethod
+    async def send_personal_message(message: str, websocket: WebSocket):
+        await websocket.send_text(message)
+
+    async def broadcast(self, message: str):
+        for connection in self.player_connections.values():
+            await connection.send_text(message)
+
+
+manager = PlayerConnectionManager()
+oc_api: GameServer = GameServer()
+
+
+def parse_websocket_action(message: str) -> Action:
+    if message.replace('"', "") != "get_state":
+        message_dict = json.loads(message)
+        if message_dict["act_type"] == "movement":
+            if isinstance(message_dict["value"], list):
+                x, y = message_dict["value"]
+            elif isinstance(message_dict["value"], str):
+                x, y = (
+                    message_dict["value"]
+                    .replace(" ", "")
+                    .replace("[", "")
+                    .replace("]", "")
+                    .split(",")
+                )
+            else:
+                x, y = 0, 0
+            value = np.array([x, y], dtype=float)
+        else:
+            value = None
+        action = Action(
+            message_dict["player_name"],
+            message_dict["act_type"],
+            value,
+            duration=message_dict["duration"],
+        )
+        return action
+
+
+def manage_websocket_message(message: str):
+    if "get_state" in message:
+        return oc_api.get_state()
+
+    if "reset_game" in message:
+        oc_api.reset_game()
+        return "Reset game."
+
+    action = parse_websocket_action(message)
+    oc_api.simulator.enter_action(action)
+    answer = oc_api.get_state()
+    return answer
+
+
+@app.get("/")
+def read_root():
+    return {"OVER": "COOKED"}
+
+
+class CreateEnvironmentConfig(BaseModel):
+    manager_id: str
+    number_players: int
+    same_websocket_player: list[list[int]] | None = None
+    environment_settings: EnvironmentSettings
+    item_info_config: str
+    environment_config: str
+    layout_config: str
+
+
+class AdditionalPlayer(BaseModel):
+    manager_id: str
+    env_id: str
+    number_players: int
+    existing_websocket: str | None = None
+
+
+@app.post("/manage/create_env")
+async def register_manger(creation: CreateEnvironmentConfig):
+    result = oc_api.create_env(creation)
+    return result
+
+
+@app.post("/manage/additional_player")
+async def additional_player(creation: AdditionalPlayer):
+    result = oc_api.add_player(creation)
+    return result
+
+
+@app.post("manage/stop_env")
+async def stop_env(manager_id: str, env_id: str, reason: str):
+    result = oc_api.stop_env(manager_id, env_id, reason)
+    return result
+
+
+# control access / functions / data
+
+
+@app.websocket("/ws/player/{client_id}")
+async def websocket_player_endpoint(websocket: WebSocket, client_id: int):
+    await manager.connect(websocket)
+    log.debug(f"Client #{client_id} connected")
+    try:
+        while True:
+            message = await websocket.receive_text()
+            answer = manage_websocket_message(message)
+            await manager.send_personal_message(answer, websocket)
+
+    except WebSocketDisconnect:
+        manager.disconnect(websocket)
+        log.debug(f"Client #{client_id} disconnected")
+
+
+def main():
+    uvicorn.run(app, host=WEBSOCKET_URL, port=WEBSOCKET_PORT)
+
+
+if __name__ == "__main__":
+    main()
+    """
+    Or in console: 
+    uvicorn overcooked_simulator.fastapi_game_server:app --reload
+    """
diff --git a/overcooked_simulator/game_server_OLD.py b/overcooked_simulator/game_server_OLD.py
deleted file mode 100644
index 4497a07c..00000000
--- a/overcooked_simulator/game_server_OLD.py
+++ /dev/null
@@ -1,123 +0,0 @@
-import asyncio
-import json
-import logging
-import os
-import sys
-import threading
-from datetime import datetime
-
-import numpy as np
-from websockets.server import serve
-
-from overcooked_simulator import ROOT_DIR
-from overcooked_simulator.overcooked_environment import Action
-from overcooked_simulator.simulation_runner import Simulator
-
-log = logging.getLogger(__name__)
-
-
-WEBSOCKET_URL = "localhost"
-WEBSOCKET_PORT = 8765
-
-
-class Connector:
-    def __init__(self, simulator: Simulator):
-        self.simulator: Simulator = simulator
-
-        self.last_message_time = datetime.now()
-
-        super().__init__()
-
-    async def process_message(self, websocket):
-        """
-
-        Args:
-            websocket:
-
-        Returns:
-
-        """
-        async for message in websocket:
-            if message.replace('"', "") != "get_state":
-                message_dict = json.loads(message)
-                if message_dict["act_type"] == "movement":
-                    if isinstance(message_dict["value"], list):
-                        x, y = message_dict["value"]
-                    elif isinstance(message_dict["value"], str):
-                        x, y = (
-                            message_dict["value"]
-                            .replace(" ", "")
-                            .replace("[", "")
-                            .replace("]", "")
-                            .split(",")
-                        )
-                    else:
-                        x, y = 0, 0
-                    value = np.array([x, y], dtype=float)
-                else:
-                    value = None
-                action = Action(
-                    message_dict["player_name"], message_dict["act_type"], value
-                )
-                self.simulator.enter_action(action)
-
-            json_answer = self.simulator.get_state_simple_json()
-
-            # print("json:", json_answer, type(json_answer))
-            await websocket.send(json_answer)
-
-    async def connection_server(self):
-        async with serve(self.process_message, WEBSOCKET_URL, WEBSOCKET_PORT):
-            await asyncio.Future()  # run forever
-
-    def set_sim(self, simulation_runner: Simulator):
-        self.simulator = simulation_runner
-
-    def start_connector(self):
-        asyncio.run(self.connection_server())
-
-
-def setup_logging():
-    path_logs = ROOT_DIR.parent / "logs"
-    os.makedirs(path_logs, exist_ok=True)
-    logging.basicConfig(
-        level=logging.DEBUG,
-        format="%(asctime)s %(levelname)-8s %(name)-50s %(message)s",
-        handlers=[
-            logging.FileHandler(
-                path_logs / f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_debug.log",
-                encoding="utf-8",
-            ),
-            logging.StreamHandler(sys.stdout),
-        ],
-    )
-
-
-def main():
-    simulator = Simulator(
-        ROOT_DIR / "game_content" / "environment_config.yaml",
-        ROOT_DIR / "game_content" / "layouts" / "basic.layout",
-        600,
-    )
-    number_player = 2
-    for i in range(number_player):
-        player_name = f"p{i}"
-        simulator.register_player(player_name)
-    simulator.start()
-
-    print(simulator.get_state_simple_json())
-    connector = Connector(simulator)
-    connector.start_connector()
-
-
-if __name__ == "__main__":
-    setup_logging()
-    try:
-        main()
-    except Exception as e:
-        log.exception(e)
-        for thread in threading.enumerate():
-            if isinstance(thread, Simulator):
-                thread.stop()
-                thread.join()
-        sys.exit(1)
diff --git a/overcooked_simulator/main.py b/overcooked_simulator/main.py
index ae9bcfbd..ff5fdcf0 100644
--- a/overcooked_simulator/main.py
+++ b/overcooked_simulator/main.py
@@ -1,10 +1,33 @@
 import logging
+import os
+import sys
+from datetime import datetime
+
+from overcooked_simulator import ROOT_DIR
 
 log = logging.getLogger(__name__)
 
+
+def setup_logging():
+    path_logs = ROOT_DIR.parent / "logs"
+    os.makedirs(path_logs, exist_ok=True)
+    logging.basicConfig(
+        level=logging.DEBUG,
+        format="%(asctime)s %(levelname)-8s %(name)-50s %(message)s",
+        handlers=[
+            logging.FileHandler(
+                path_logs / f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_debug.log",
+                encoding="utf-8",
+            ),
+            logging.StreamHandler(sys.stdout),
+        ],
+    )
+    logging.getLogger("matplotlib").setLevel(logging.WARNING)
+
+
 if __name__ == "__main__":
     # os.popen(
     #     "mamba activate overooked-simulator & uvicorn overcooked_simulator.fastapi_game_server:app"
     # )
     # gui_main()
-    pass
\ No newline at end of file
+    pass
diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py
index ae5928dd..a0d856cd 100644
--- a/overcooked_simulator/overcooked_environment.py
+++ b/overcooked_simulator/overcooked_environment.py
@@ -62,17 +62,29 @@ class Environment:
     # TODO Abstract base class for different environments
     """
 
-    def __init__(self, env_config_path: Path, layout_path, item_info_path: Path):
+    PAUSED = None
+
+    def __init__(
+        self,
+        env_config: Path | str,
+        layout_config: Path | str,
+        item_info: Path | str,
+        as_files: bool = True,
+    ):
         self.lock = Lock()
         self.players: dict[str, Player] = {}
 
-        with open(env_config_path, "r") as file:
-            self.environment_config = yaml.load(file, Loader=yaml.Loader)
-        self.layout_path: Path = layout_path
+        self.as_files = as_files
+
+        if self.as_files:
+            with open(env_config, "r") as file:
+                self.environment_config = yaml.load(file, Loader=yaml.Loader)
+        else:
+            self.environment_config = yaml.load(env_config, Loader=yaml.Loader)
+        self.layout_config = layout_config
         # self.counter_side_length = 1  # -> this changed! is 1 now
 
-        self.item_info_path: Path = item_info_path
-        self.item_info = self.load_item_info()
+        self.item_info = self.load_item_info(item_info)
         self.validate_item_info()
         if self.environment_config["meals"]["all"]:
             self.allowed_meal_names = set(
@@ -186,7 +198,7 @@ class Environment:
             self.counters,
             self.designated_player_positions,
             self.free_positions,
-        ) = self.parse_layout_file(self.layout_path)
+        ) = self.parse_layout_file()
 
         self.init_counters()
 
@@ -205,9 +217,12 @@ class Environment:
     def game_ended(self) -> bool:
         return self.env_time >= self.env_time_end
 
-    def load_item_info(self) -> dict[str, ItemInfo]:
-        with open(self.item_info_path, "r") as file:
-            item_lookup = yaml.safe_load(file)
+    def load_item_info(self, data) -> dict[str, ItemInfo]:
+        if self.as_files:
+            with open(data, "r") as file:
+                item_lookup = yaml.safe_load(file)
+        else:
+            item_lookup = yaml.safe_load(data)
         for item_name in item_lookup:
             item_lookup[item_name] = ItemInfo(name=item_name, **item_lookup[item_name])
 
@@ -277,7 +292,7 @@ class Environment:
         # TODO add colors for ingredients, equipment and meals
         # plt.show()
 
-    def parse_layout_file(self, layout_file: Path):
+    def parse_layout_file(self):
         """Creates layout of kitchen counters in the environment based on layout file.
         Counters are arranged in a fixed size grid starting at [0,0]. The center of the first counter is at
         [counter_size/2, counter_size/2], counters are directly next to each other (of no empty space is specified
@@ -293,9 +308,12 @@ class Environment:
 
         self.kitchen_width = 0
 
-        with open(layout_file, "r") as layout_file:
-            lines = layout_file.readlines()
-            self.kitchen_height = len(lines)
+        if self.as_files:
+            with open(self.layout_config, "r") as layout_file:
+                lines = layout_file.readlines()
+        else:
+            lines = self.layout_config.split("\n")
+        self.kitchen_height = len(lines)
 
         for line in lines:
             line = line.replace("\n", "").replace(" ", "")  # remove newline char
@@ -535,7 +553,7 @@ class Environment:
         distance = np.linalg.norm([dx, dy])
         return distance < (player.radius)
 
-    def add_player(self, player_name: str, pos: npt.NDArray = None):
+    def add_player(self, player_name: int | str, pos: npt.NDArray = None):
         log.debug(f"Add player {player_name} to the game")
         player = Player(
             player_name, player_config=self.environment_config["player_config"], pos=pos
diff --git a/overcooked_simulator/player.py b/overcooked_simulator/player.py
index de97d0e2..ed1c29a9 100644
--- a/overcooked_simulator/player.py
+++ b/overcooked_simulator/player.py
@@ -21,11 +21,11 @@ class Player:
 
     def __init__(
         self,
-        name: str,
+        name: int | str,
         player_config: dict[str, Any],
         pos: Optional[npt.NDArray[float]] = None,
     ):
-        self.name: str = name
+        self.name: int | str = name
         self.player_config = player_config
         if pos is not None:
             self.pos: npt.NDArray[float] = np.array(pos, dtype=float)
-- 
GitLab