diff --git a/overcooked_simulator/api_call.py b/overcooked_simulator/api_call.py deleted file mode 100644 index ea733f530cf7dfb1759238c8a0452d3e7028cb94..0000000000000000000000000000000000000000 --- a/overcooked_simulator/api_call.py +++ /dev/null @@ -1,14 +0,0 @@ -# websocket_client.py -import asyncio - -import websockets - - -async def send_message(): - uri = "ws://127.0.0.1:8000/ws" - async with websockets.connect(uri) as websocket: - await websocket.send("Hello, server!") - response = await websocket.recv() - print(response) - -asyncio.run(send_message()) diff --git a/overcooked_simulator/fastapi_game_server.py b/overcooked_simulator/fastapi_game_server.py deleted file mode 100644 index eb2695e02111f6451ccad077ab43e0ae5a4bfc9a..0000000000000000000000000000000000000000 --- a/overcooked_simulator/fastapi_game_server.py +++ /dev/null @@ -1,166 +0,0 @@ -import json -import logging -import threading -from contextlib import asynccontextmanager - -import numpy as np -import uvicorn -from fastapi import FastAPI -from fastapi import WebSocket -from starlette.websockets import WebSocketDisconnect - -from overcooked_simulator import ROOT_DIR -from overcooked_simulator.game_server_OLD import setup_logging -from overcooked_simulator.overcooked_environment import Action -from overcooked_simulator.simulation_runner import Simulator - -log = logging.getLogger(__name__) -setup_logging() - - -@asynccontextmanager -async def lifespan(app: FastAPI): - setup_logging() - yield - for thread in threading.enumerate(): - if isinstance(thread, Simulator): - thread.stop() - thread.join() - - -app = FastAPI(lifespan=lifespan) - - -WEBSOCKET_URL = "localhost" -WEBSOCKET_PORT = 8000 - - -class GameServer: - simulator: Simulator - - def __init__(self): - self.setup_game() - - self.envs = {int: Simulator} - - def create_env(self, n_players: int) -> (int, list[WebSocket]): - pass - - def add_player(self, env_id) -> (int, WebSocket): - pass - - def setup_game(self): - self.simulator = Simulator( - ROOT_DIR / "game_content" / "environment_config.yaml", - ROOT_DIR / "game_content" / "layouts" / "basic.layout", - 600, - ) - number_player = 2 - for i in range(number_player): - player_name = f"p{i}" - self.simulator.register_player(player_name) - self.simulator.start() - - def get_state(self): - return self.simulator.get_state_simple_json() - - def reset_game(self): - self.simulator.stop() - self.setup_game() - - -class ConnectionManager: - def __init__(self): - self.active_connections: list[WebSocket] = [] - - async def connect(self, websocket: WebSocket): - await websocket.accept() - self.active_connections.append(websocket) - - def disconnect(self, websocket: WebSocket): - self.active_connections.remove(websocket) - - async def send_personal_message(self, message: str, websocket: WebSocket): - await websocket.send_text(message) - - async def broadcast(self, message: str): - for connection in self.active_connections: - await connection.send_text(message) - - -manager = ConnectionManager() -oc_api: GameServer = GameServer() - - -def parse_websocket_action(message: str) -> Action: - if message.replace('"', "") != "get_state": - message_dict = json.loads(message) - if message_dict["act_type"] == "movement": - if isinstance(message_dict["value"], list): - x, y = message_dict["value"] - elif isinstance(message_dict["value"], str): - x, y = ( - message_dict["value"] - .replace(" ", "") - .replace("[", "") - .replace("]", "") - .split(",") - ) - else: - x, y = 0, 0 - value = np.array([x, y], dtype=float) - else: - value = None - action = Action( - message_dict["player_name"], - message_dict["act_type"], - value, - duration=message_dict["duration"], - ) - return action - - -def manage_websocket_message(message: str): - if "get_state" in message: - return oc_api.get_state() - - if "reset_game" in message: - oc_api.reset_game() - return "Reset game." - - action = parse_websocket_action(message) - oc_api.simulator.enter_action(action) - answer = oc_api.get_state() - return answer - - -@app.get("/") -def read_root(): - return {"OVER": "COOKED"} - - -@app.websocket("/ws/{client_id}") -async def websocket_endpoint(websocket: WebSocket, client_id: int): - await manager.connect(websocket) - log.debug(f"Client #{client_id} connected") - try: - while True: - message = await websocket.receive_text() - answer = manage_websocket_message(message) - await manager.send_personal_message(answer, websocket) - - except WebSocketDisconnect: - manager.disconnect(websocket) - log.debug(f"Client #{client_id} disconnected") - - -def main(): - uvicorn.run(app, host=WEBSOCKET_URL, port=WEBSOCKET_PORT) - - -if __name__ == "__main__": - main() - """ - Or in console: - uvicorn overcooked_simulator.fastapi_game_server:app --reload - """ diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py new file mode 100644 index 0000000000000000000000000000000000000000..9fe683d4339516ad354edb24e01e282e502f98f4 --- /dev/null +++ b/overcooked_simulator/game_server.py @@ -0,0 +1,393 @@ +from __future__ import annotations + +import asyncio +import dataclasses +import json +import logging +import time +import uuid +from collections import defaultdict +from datetime import datetime, timedelta +from enum import Enum +from typing import Set + +import numpy as np +import uvicorn +from fastapi import FastAPI +from fastapi import WebSocket +from overcooked_simulator.game_server_OLD import setup_logging +from pydantic import BaseModel +from starlette.websockets import WebSocketDisconnect +from typing_extensions import TypedDict + +from overcooked_simulator.overcooked_environment import Action, Environment + +log = logging.getLogger(__name__) +setup_logging() + + +app = FastAPI() + + +WEBSOCKET_URL = "localhost" +WEBSOCKET_PORT = 8000 + + +@dataclasses.dataclass +class PlayerData: + player_id: int + env_id: str + websocket_id: str | None = None + connected: bool = False + ready: bool = False + last_action: datetime | None = None + name: str = "" + + +class EnvironmentSettings(TypedDict): + all_player_can_pause_game: bool + # env_steps_per_second: int + + +class EnvironmentStatus(Enum): + WAITING_FOR_PLAYERS = "waitingForPlayers" + PAUSED = "paused" + RUNNING = "running" + STOPPED = "stopped" + + +@dataclasses.dataclass +class EnvironmentData: + environment: Environment + player_hashes: Set[str] = dataclasses.field(default_factory=set) + environment_settings: EnvironmentSettings = dataclasses.field(default_factory=dict) + status: EnvironmentStatus = EnvironmentStatus.WAITING_FOR_PLAYERS + stop_reason: str = "" + start_time: datetime | None = None + last_step_time: int | None = None + + +class GameServer: + def __init__(self, env_step_frequency: int = 200): + self.envs: dict[str, EnvironmentData] = {} + self.player_data: dict[str, PlayerData] = {} + self.manager_envs: dict[str, Set[str]] = defaultdict(set) + self.env_step_frequency = env_step_frequency + self.preferred_sleep_time_ns = 1e9 / self.env_step_frequency + + def create_env(self, environment_config: CreateEnvironmentConfig): + env_id = uuid.uuid4().hex + + env = Environment( + env_config=environment_config.environment_config, + layout_config=environment_config.layout_config, + item_info=environment_config.item_info_config, + as_files=False, + ) + player_info = {} + for player_id in range(environment_config.number_players): + player_info[player_id] = self.create_player(env, env_id, player_id) + + self.envs[env_id] = EnvironmentData(environment=env) + + self.manager_envs[environment_config.manager_id].update([env_id]) + + return {"env_id": env_id} + + def create_player(self, env, env_id, player_id): + player_hash = uuid.uuid4().hex + client_id = uuid.uuid4().hex + player_data = PlayerData( + player_id=player_id, + env_id=env_id, + websocket_id=client_id, + ) + self.player_data[player_hash] = player_data + env.add_player(player_id) + + return { + "client_id": client_id, + "player_hash": player_hash, + "player_id": player_id, + } + + def add_player(self, config: AdditionalPlayer): + new_player_info = {} + if ( + config.manager_id in self.manager_envs + and config.env_id in self.manager_envs[config.manager_id] + and self.envs[config.env_id].status != EnvironmentStatus.STOPPED + ): + n_players = len(self.envs[config.env_id].player_hashes) + for player_id in range(n_players, n_players + config.number_players): + new_player_info[player_id] = self.create_player( + env=self.envs[config.env_id].environment, + env_id=config.env_id, + player_id=player_id, + ) + return new_player_info + + def start_env(self, env_id: str): + if env_id in self.envs: + start_time = datetime.now() + self.envs[env_id].status = EnvironmentStatus.RUNNING + self.envs[env_id].start_time = start_time + self.envs[env_id].last_step_time = time.time_ns() + self.envs[env_id].environment.reset_env_time() + + def get_state(self): + ... + + def pause_env(self, manager_id: str, env_id: str, reason: str): + if ( + manager_id in self.manager_envs + and env_id in self.manager_envs[manager_id] + and self.envs[env_id].status + not in [EnvironmentStatus.STOPPED, Environment.PAUSED] + ): + self.envs[env_id].status = EnvironmentStatus.PAUSED + + def unpause_env(self, manager_id: str, env_id: str, reason: str): + if ( + manager_id in self.manager_envs + and env_id in self.manager_envs[manager_id] + and self.envs[env_id].status + not in [EnvironmentStatus.STOPPED, Environment.PAUSED] + ): + self.envs[env_id].status = EnvironmentStatus.PAUSED + self.envs[env_id].last_step_time = time.time_ns() + + def stop_env(self, manager_id: str, env_id: str, reason: str): + if ( + manager_id in self.manager_envs + and env_id in self.manager_envs[manager_id] + and self.envs[env_id].status != EnvironmentStatus.STOPPED + ): + self.envs[env_id].status = EnvironmentStatus.STOPPED + self.envs[env_id].stop_reason = reason + + def set_player_ready(self, env_id: str, player_hash, player_id: int): + if ( + player_hash in self.player_data + and self.player_data[player_hash].player_id == player_id + and self.player_data[player_hash].env_id == env_id + ): + self.player_data[player_hash].ready = True + return True + return False + + def set_player_connected(self, env_id: str, player_hash, player_id: int) -> bool: + if ( + player_hash in self.player_data + and self.player_data[player_hash].player_id == player_id + and self.player_data[player_hash].env_id == env_id + ): + self.player_data[player_hash].connected = True + return True + return False + + def set_player_disconnected( + self, env_id: str, player_hash: str, player_id: int + ) -> bool: + if ( + player_hash in self.player_data + and self.player_data[player_hash].player_id == player_id + and self.player_data[player_hash].env_id == env_id + ): + self.player_data[player_hash].connected = False + return True + return False + + def check_all_player_ready(self, env_id: str) -> bool: + return env_id in self.envs and all( + self.player_data[player_hash].connected + and self.player_data[player_hash].ready + for player_hash in self.envs[env_id].player_hashes + ) + + def check_all_players_connected(self, env_id: str) -> bool: + return env_id in self.envs and all( + self.player_data[player_hash].connected + for player_hash in self.envs[env_id].player_hashes + ) + + def get_not_connected_players(self, env_id: str) -> list[int]: + if env_id in self.envs: + return [ + self.player_data[player_hash].player_id + for player_hash in self.envs[env_id].player_hashes + if not self.player_data[player_hash].connected + ] + + def get_not_ready_players(self, env_id: str) -> list[int]: + if env_id in self.envs: + return [ + self.player_data[player_hash].player_id + for player_hash in self.envs[env_id].player_hashes + if not self.player_data[player_hash].ready + ] + + async def environment_steps(self): + overslept_in_ns = 0 + while True: + pre_step_start = time.time_ns() + for env_id, env_data in self.envs.items(): + if env_data.status == EnvironmentStatus.RUNNING: + step_start = time.time_ns() + env_data.environment.step( + timedelta( + seconds=(step_start - env_data.last_step_time) + / 1_000_000_000 + ) + ) + env_data.last_step_time = step_start + step_duration = time.time_ns() - pre_step_start + + time_to_sleep_ns = self.preferred_sleep_time_ns - ( + step_duration + overslept_in_ns + ) + + sleep_start = time.time_ns() + await asyncio.sleep(max(time_to_sleep_ns / 1e9, 0)) + sleep_function_duration = time.time_ns() - sleep_start + overslept_in_ns = sleep_function_duration - time_to_sleep_ns + + +class PlayerConnectionManager: + def __init__(self): + self.player_connections: dict[str, WebSocket] = {} + + async def connect_player(self, websocket: WebSocket, player_id: str) -> bool: + if player_id not in self.player_connections: + await websocket.accept() + self.player_connections[player_id] = websocket + return True + return False + + def disconnect(self, id_: str): + if id_ in self.player_connections: + del self.player_connections[id_] + + @staticmethod + async def send_personal_message(message: str, websocket: WebSocket): + await websocket.send_text(message) + + async def broadcast(self, message: str): + for connection in self.player_connections.values(): + await connection.send_text(message) + + +manager = PlayerConnectionManager() +oc_api: GameServer = GameServer() + + +def parse_websocket_action(message: str) -> Action: + if message.replace('"', "") != "get_state": + message_dict = json.loads(message) + if message_dict["act_type"] == "movement": + if isinstance(message_dict["value"], list): + x, y = message_dict["value"] + elif isinstance(message_dict["value"], str): + x, y = ( + message_dict["value"] + .replace(" ", "") + .replace("[", "") + .replace("]", "") + .split(",") + ) + else: + x, y = 0, 0 + value = np.array([x, y], dtype=float) + else: + value = None + action = Action( + message_dict["player_name"], + message_dict["act_type"], + value, + duration=message_dict["duration"], + ) + return action + + +def manage_websocket_message(message: str): + if "get_state" in message: + return oc_api.get_state() + + if "reset_game" in message: + oc_api.reset_game() + return "Reset game." + + action = parse_websocket_action(message) + oc_api.simulator.enter_action(action) + answer = oc_api.get_state() + return answer + + +@app.get("/") +def read_root(): + return {"OVER": "COOKED"} + + +class CreateEnvironmentConfig(BaseModel): + manager_id: str + number_players: int + same_websocket_player: list[list[int]] | None = None + environment_settings: EnvironmentSettings + item_info_config: str + environment_config: str + layout_config: str + + +class AdditionalPlayer(BaseModel): + manager_id: str + env_id: str + number_players: int + existing_websocket: str | None = None + + +@app.post("/manage/create_env") +async def register_manger(creation: CreateEnvironmentConfig): + result = oc_api.create_env(creation) + return result + + +@app.post("/manage/additional_player") +async def additional_player(creation: AdditionalPlayer): + result = oc_api.add_player(creation) + return result + + +@app.post("manage/stop_env") +async def stop_env(manager_id: str, env_id: str, reason: str): + result = oc_api.stop_env(manager_id, env_id, reason) + return result + + +# control access / functions / data + + +@app.websocket("/ws/player/{client_id}") +async def websocket_player_endpoint(websocket: WebSocket, client_id: int): + await manager.connect(websocket) + log.debug(f"Client #{client_id} connected") + try: + while True: + message = await websocket.receive_text() + answer = manage_websocket_message(message) + await manager.send_personal_message(answer, websocket) + + except WebSocketDisconnect: + manager.disconnect(websocket) + log.debug(f"Client #{client_id} disconnected") + + +def main(): + uvicorn.run(app, host=WEBSOCKET_URL, port=WEBSOCKET_PORT) + + +if __name__ == "__main__": + main() + """ + Or in console: + uvicorn overcooked_simulator.fastapi_game_server:app --reload + """ diff --git a/overcooked_simulator/game_server_OLD.py b/overcooked_simulator/game_server_OLD.py deleted file mode 100644 index 4497a07c1a55396cca60cc1b0207022a7dc3e5c3..0000000000000000000000000000000000000000 --- a/overcooked_simulator/game_server_OLD.py +++ /dev/null @@ -1,123 +0,0 @@ -import asyncio -import json -import logging -import os -import sys -import threading -from datetime import datetime - -import numpy as np -from websockets.server import serve - -from overcooked_simulator import ROOT_DIR -from overcooked_simulator.overcooked_environment import Action -from overcooked_simulator.simulation_runner import Simulator - -log = logging.getLogger(__name__) - - -WEBSOCKET_URL = "localhost" -WEBSOCKET_PORT = 8765 - - -class Connector: - def __init__(self, simulator: Simulator): - self.simulator: Simulator = simulator - - self.last_message_time = datetime.now() - - super().__init__() - - async def process_message(self, websocket): - """ - - Args: - websocket: - - Returns: - - """ - async for message in websocket: - if message.replace('"', "") != "get_state": - message_dict = json.loads(message) - if message_dict["act_type"] == "movement": - if isinstance(message_dict["value"], list): - x, y = message_dict["value"] - elif isinstance(message_dict["value"], str): - x, y = ( - message_dict["value"] - .replace(" ", "") - .replace("[", "") - .replace("]", "") - .split(",") - ) - else: - x, y = 0, 0 - value = np.array([x, y], dtype=float) - else: - value = None - action = Action( - message_dict["player_name"], message_dict["act_type"], value - ) - self.simulator.enter_action(action) - - json_answer = self.simulator.get_state_simple_json() - - # print("json:", json_answer, type(json_answer)) - await websocket.send(json_answer) - - async def connection_server(self): - async with serve(self.process_message, WEBSOCKET_URL, WEBSOCKET_PORT): - await asyncio.Future() # run forever - - def set_sim(self, simulation_runner: Simulator): - self.simulator = simulation_runner - - def start_connector(self): - asyncio.run(self.connection_server()) - - -def setup_logging(): - path_logs = ROOT_DIR.parent / "logs" - os.makedirs(path_logs, exist_ok=True) - logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s %(levelname)-8s %(name)-50s %(message)s", - handlers=[ - logging.FileHandler( - path_logs / f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_debug.log", - encoding="utf-8", - ), - logging.StreamHandler(sys.stdout), - ], - ) - - -def main(): - simulator = Simulator( - ROOT_DIR / "game_content" / "environment_config.yaml", - ROOT_DIR / "game_content" / "layouts" / "basic.layout", - 600, - ) - number_player = 2 - for i in range(number_player): - player_name = f"p{i}" - simulator.register_player(player_name) - simulator.start() - - print(simulator.get_state_simple_json()) - connector = Connector(simulator) - connector.start_connector() - - -if __name__ == "__main__": - setup_logging() - try: - main() - except Exception as e: - log.exception(e) - for thread in threading.enumerate(): - if isinstance(thread, Simulator): - thread.stop() - thread.join() - sys.exit(1) diff --git a/overcooked_simulator/main.py b/overcooked_simulator/main.py index ae9bcfbd575f60801b643556e7be25a3d672133e..ff5fdcf01f2ddcb78e5523695ea09ca93e8af809 100644 --- a/overcooked_simulator/main.py +++ b/overcooked_simulator/main.py @@ -1,10 +1,33 @@ import logging +import os +import sys +from datetime import datetime + +from overcooked_simulator import ROOT_DIR log = logging.getLogger(__name__) + +def setup_logging(): + path_logs = ROOT_DIR.parent / "logs" + os.makedirs(path_logs, exist_ok=True) + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s %(levelname)-8s %(name)-50s %(message)s", + handlers=[ + logging.FileHandler( + path_logs / f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_debug.log", + encoding="utf-8", + ), + logging.StreamHandler(sys.stdout), + ], + ) + logging.getLogger("matplotlib").setLevel(logging.WARNING) + + if __name__ == "__main__": # os.popen( # "mamba activate overooked-simulator & uvicorn overcooked_simulator.fastapi_game_server:app" # ) # gui_main() - pass \ No newline at end of file + pass diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py index ae5928ddd5927ff3ba149ac53af1483bfeb7c278..a0d856cd61a990e59419d860006d19e144da37ea 100644 --- a/overcooked_simulator/overcooked_environment.py +++ b/overcooked_simulator/overcooked_environment.py @@ -62,17 +62,29 @@ class Environment: # TODO Abstract base class for different environments """ - def __init__(self, env_config_path: Path, layout_path, item_info_path: Path): + PAUSED = None + + def __init__( + self, + env_config: Path | str, + layout_config: Path | str, + item_info: Path | str, + as_files: bool = True, + ): self.lock = Lock() self.players: dict[str, Player] = {} - with open(env_config_path, "r") as file: - self.environment_config = yaml.load(file, Loader=yaml.Loader) - self.layout_path: Path = layout_path + self.as_files = as_files + + if self.as_files: + with open(env_config, "r") as file: + self.environment_config = yaml.load(file, Loader=yaml.Loader) + else: + self.environment_config = yaml.load(env_config, Loader=yaml.Loader) + self.layout_config = layout_config # self.counter_side_length = 1 # -> this changed! is 1 now - self.item_info_path: Path = item_info_path - self.item_info = self.load_item_info() + self.item_info = self.load_item_info(item_info) self.validate_item_info() if self.environment_config["meals"]["all"]: self.allowed_meal_names = set( @@ -186,7 +198,7 @@ class Environment: self.counters, self.designated_player_positions, self.free_positions, - ) = self.parse_layout_file(self.layout_path) + ) = self.parse_layout_file() self.init_counters() @@ -205,9 +217,12 @@ class Environment: def game_ended(self) -> bool: return self.env_time >= self.env_time_end - def load_item_info(self) -> dict[str, ItemInfo]: - with open(self.item_info_path, "r") as file: - item_lookup = yaml.safe_load(file) + def load_item_info(self, data) -> dict[str, ItemInfo]: + if self.as_files: + with open(data, "r") as file: + item_lookup = yaml.safe_load(file) + else: + item_lookup = yaml.safe_load(data) for item_name in item_lookup: item_lookup[item_name] = ItemInfo(name=item_name, **item_lookup[item_name]) @@ -277,7 +292,7 @@ class Environment: # TODO add colors for ingredients, equipment and meals # plt.show() - def parse_layout_file(self, layout_file: Path): + def parse_layout_file(self): """Creates layout of kitchen counters in the environment based on layout file. Counters are arranged in a fixed size grid starting at [0,0]. The center of the first counter is at [counter_size/2, counter_size/2], counters are directly next to each other (of no empty space is specified @@ -293,9 +308,12 @@ class Environment: self.kitchen_width = 0 - with open(layout_file, "r") as layout_file: - lines = layout_file.readlines() - self.kitchen_height = len(lines) + if self.as_files: + with open(self.layout_config, "r") as layout_file: + lines = layout_file.readlines() + else: + lines = self.layout_config.split("\n") + self.kitchen_height = len(lines) for line in lines: line = line.replace("\n", "").replace(" ", "") # remove newline char @@ -535,7 +553,7 @@ class Environment: distance = np.linalg.norm([dx, dy]) return distance < (player.radius) - def add_player(self, player_name: str, pos: npt.NDArray = None): + def add_player(self, player_name: int | str, pos: npt.NDArray = None): log.debug(f"Add player {player_name} to the game") player = Player( player_name, player_config=self.environment_config["player_config"], pos=pos diff --git a/overcooked_simulator/player.py b/overcooked_simulator/player.py index de97d0e2014df26077d30dad61279f4a08084867..ed1c29a950b423c5d673683dc676e9313035b292 100644 --- a/overcooked_simulator/player.py +++ b/overcooked_simulator/player.py @@ -21,11 +21,11 @@ class Player: def __init__( self, - name: str, + name: int | str, player_config: dict[str, Any], pos: Optional[npt.NDArray[float]] = None, ): - self.name: str = name + self.name: int | str = name self.player_config = player_config if pos is not None: self.pos: npt.NDArray[float] = np.array(pos, dtype=float)