From 7ba326c6fb0f7ab53ed8cafe86688020f82e57ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20Schr=C3=B6der?= <fschroeder@techfak.uni-bielefeld.de> Date: Wed, 24 Jan 2024 19:34:48 +0100 Subject: [PATCH] Update server and client communication in game simulator The server (game_server.py) and client communication (overcooked_gui.py) now uses the WebSocket communication protocol. In setup.py, the 'requests' module was added as a new requirement. The game simulator now waits for a player to be ready before starting, stops a game environment if no step is taken within a minute, and pauses or unpauses a game environment. Player actions are now handled based on a new 'Action' type. The server also now handles several client types that can send messages. --- overcooked_simulator/game_server.py | 203 ++++++++------- .../gui_2d_vis/overcooked_gui.py | 241 +++++++++++------- setup.py | 1 + 3 files changed, 261 insertions(+), 184 deletions(-) diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py index 9fe683d4..e5d50146 100644 --- a/overcooked_simulator/game_server.py +++ b/overcooked_simulator/game_server.py @@ -11,15 +11,14 @@ from datetime import datetime, timedelta from enum import Enum from typing import Set -import numpy as np import uvicorn from fastapi import FastAPI from fastapi import WebSocket -from overcooked_simulator.game_server_OLD import setup_logging from pydantic import BaseModel from starlette.websockets import WebSocketDisconnect from typing_extensions import TypedDict +from overcooked_simulator.main import setup_logging from overcooked_simulator.overcooked_environment import Action, Environment log = logging.getLogger(__name__) @@ -67,13 +66,14 @@ class EnvironmentData: last_step_time: int | None = None -class GameServer: +class EnvironmentHandler: def __init__(self, env_step_frequency: int = 200): self.envs: dict[str, EnvironmentData] = {} self.player_data: dict[str, PlayerData] = {} self.manager_envs: dict[str, Set[str]] = defaultdict(set) self.env_step_frequency = env_step_frequency self.preferred_sleep_time_ns = 1e9 / self.env_step_frequency + self.client_ids_to_player_hashes = {} def create_env(self, environment_config: CreateEnvironmentConfig): env_id = uuid.uuid4().hex @@ -92,7 +92,7 @@ class GameServer: self.manager_envs[environment_config.manager_id].update([env_id]) - return {"env_id": env_id} + return {"env_id": env_id, "player_info": player_info} def create_player(self, env, env_id, player_id): player_hash = uuid.uuid4().hex @@ -103,6 +103,7 @@ class GameServer: websocket_id=client_id, ) self.player_data[player_hash] = player_data + self.client_ids_to_player_hashes[client_id] = player_hash env.add_player(player_id) return { @@ -135,8 +136,15 @@ class GameServer: self.envs[env_id].last_step_time = time.time_ns() self.envs[env_id].environment.reset_env_time() - def get_state(self): - ... + def get_state(self, player_hash: str): + if ( + player_hash in self.player_data + and self.player_data[player_hash].env_id in self.envs + ): + # TODO normal json state + return self.envs[ + self.player_data[player_hash].env_id + ].environment.get_state_simple_json() def pause_env(self, manager_id: str, env_id: str, reason: str): if ( @@ -166,35 +174,25 @@ class GameServer: self.envs[env_id].status = EnvironmentStatus.STOPPED self.envs[env_id].stop_reason = reason - def set_player_ready(self, env_id: str, player_hash, player_id: int): - if ( - player_hash in self.player_data - and self.player_data[player_hash].player_id == player_id - and self.player_data[player_hash].env_id == env_id - ): + def set_player_ready(self, player_hash): + if player_hash in self.player_data: self.player_data[player_hash].ready = True return True return False - def set_player_connected(self, env_id: str, player_hash, player_id: int) -> bool: - if ( - player_hash in self.player_data - and self.player_data[player_hash].player_id == player_id - and self.player_data[player_hash].env_id == env_id - ): - self.player_data[player_hash].connected = True + def set_player_connected(self, client_id: str) -> bool: + if client_id in self.client_ids_to_player_hashes: + self.player_data[ + self.client_ids_to_player_hashes[client_id] + ].connected = True return True return False - def set_player_disconnected( - self, env_id: str, player_hash: str, player_id: int - ) -> bool: - if ( - player_hash in self.player_data - and self.player_data[player_hash].player_id == player_id - and self.player_data[player_hash].env_id == env_id - ): - self.player_data[player_hash].connected = False + def set_player_disconnected(self, client_id: str) -> bool: + if client_id in self.client_ids_to_player_hashes: + self.player_data[ + self.client_ids_to_player_hashes[client_id] + ].connected = False return True return False @@ -228,9 +226,11 @@ class GameServer: ] async def environment_steps(self): + # TODO environment dependent steps. overslept_in_ns = 0 while True: pre_step_start = time.time_ns() + to_remove = [] for env_id, env_data in self.envs.items(): if env_data.status == EnvironmentStatus.RUNNING: step_start = time.time_ns() @@ -241,6 +241,15 @@ class GameServer: ) ) env_data.last_step_time = step_start + elif ( + env_data.status == EnvironmentStatus.STOPPED + and env_data.last_step_time + (60 * 1e9) < pre_step_start + ): + to_remove.append(env_id) + + if to_remove: + for env_id in to_remove: + del self.envs[env_id] step_duration = time.time_ns() - pre_step_start time_to_sleep_ns = self.preferred_sleep_time_ns - ( @@ -252,21 +261,36 @@ class GameServer: sleep_function_duration = time.time_ns() - sleep_start overslept_in_ns = sleep_function_duration - time_to_sleep_ns + def is_known_client_id(self, client_id: str) -> bool: + return client_id in self.client_ids_to_player_hashes + + def player_action(self, player_hash: str, action: Action): + if ( + player_hash in self.player_data + and action.player == self.player_data[player_hash].player_id + and self.player_data[player_hash].env_id in self.envs + and player_hash + in self.envs[self.player_data[player_hash].env_id].player_hashes + ): + self.envs[self.player_data[player_hash].env_id].environment.perform_action( + action + ) + class PlayerConnectionManager: def __init__(self): self.player_connections: dict[str, WebSocket] = {} - async def connect_player(self, websocket: WebSocket, player_id: str) -> bool: - if player_id not in self.player_connections: + async def connect_player(self, websocket: WebSocket, client_id: str) -> bool: + if client_id not in self.player_connections: await websocket.accept() - self.player_connections[player_id] = websocket + self.player_connections[client_id] = websocket return True return False - def disconnect(self, id_: str): - if id_ in self.player_connections: - del self.player_connections[id_] + def disconnect(self, client_id: str): + if client_id in self.player_connections: + del self.player_connections[client_id] @staticmethod async def send_personal_message(message: str, websocket: WebSocket): @@ -278,49 +302,43 @@ class PlayerConnectionManager: manager = PlayerConnectionManager() -oc_api: GameServer = GameServer() - - -def parse_websocket_action(message: str) -> Action: - if message.replace('"', "") != "get_state": - message_dict = json.loads(message) - if message_dict["act_type"] == "movement": - if isinstance(message_dict["value"], list): - x, y = message_dict["value"] - elif isinstance(message_dict["value"], str): - x, y = ( - message_dict["value"] - .replace(" ", "") - .replace("[", "") - .replace("]", "") - .split(",") - ) - else: - x, y = 0, 0 - value = np.array([x, y], dtype=float) - else: - value = None - action = Action( - message_dict["player_name"], - message_dict["act_type"], - value, - duration=message_dict["duration"], - ) - return action - +environment_handler: EnvironmentHandler = EnvironmentHandler() -def manage_websocket_message(message: str): - if "get_state" in message: - return oc_api.get_state() - if "reset_game" in message: - oc_api.reset_game() - return "Reset game." - - action = parse_websocket_action(message) - oc_api.simulator.enter_action(action) - answer = oc_api.get_state() - return answer +@dataclasses.dataclass +class PlayerAction: + player_hash: str + action: Action + + +def manage_websocket_message(message: str, client_id: str): + message_dict = json.loads(message) + + assert "type" in message_dict, "message needs a type" + + match message_dict["type"]: + case "ready": + assert "player_hash" in message_dict, "needs player hash for ready" + environment_handler.set_player_ready(message_dict["player_hash"]) + return { + "status": "ready accepted", + "player_hash": message_dict["player_hash"], + } + case "get_state": + assert "player_hash" in message_dict, "needs player hash for environment" + return environment_handler.get_state(message_dict["player_hash"]) + case "action": + assert "action" in message_dict, "action type needs action data" + assert "player_hash" in message_dict, "action type needs player hash" + environment_handler.player_action( + message_dict["player_hash"], Action(**message_dict["action"]) + ) + return { + "status": "action accepted", + "player_hash": message_dict["player_hash"], + } + # TODO setup error enums or class + return {"status": "error", "info": "unknown message type"} @app.get("/") @@ -345,44 +363,55 @@ class AdditionalPlayer(BaseModel): existing_websocket: str | None = None -@app.post("/manage/create_env") -async def register_manger(creation: CreateEnvironmentConfig): - result = oc_api.create_env(creation) +@app.post("/manage/create_env/") +async def create_env(creation: CreateEnvironmentConfig): + print(creation) + result = environment_handler.create_env(creation) return result -@app.post("/manage/additional_player") +@app.post("/manage/additional_player/") async def additional_player(creation: AdditionalPlayer): - result = oc_api.add_player(creation) + result = environment_handler.add_player(creation) return result -@app.post("manage/stop_env") +@app.post("/manage/stop_env/") async def stop_env(manager_id: str, env_id: str, reason: str): - result = oc_api.stop_env(manager_id, env_id, reason) + result = environment_handler.stop_env(manager_id, env_id, reason) return result +# pause / unpause # control access / functions / data @app.websocket("/ws/player/{client_id}") -async def websocket_player_endpoint(websocket: WebSocket, client_id: int): - await manager.connect(websocket) +async def websocket_player_endpoint(websocket: WebSocket, client_id: str): + if client_id not in environment_handler.is_known_client_id(client_id): + return + await manager.connect_player(websocket, client_id) log.debug(f"Client #{client_id} connected") + environment_handler.set_player_connected(client_id) try: while True: message = await websocket.receive_text() - answer = manage_websocket_message(message) + answer = manage_websocket_message(message, client_id) await manager.send_personal_message(answer, websocket) except WebSocketDisconnect: - manager.disconnect(websocket) + manager.disconnect(client_id) + environment_handler.set_player_disconnected(client_id) log.debug(f"Client #{client_id} disconnected") def main(): - uvicorn.run(app, host=WEBSOCKET_URL, port=WEBSOCKET_PORT) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.create_task(environment_handler.environment_steps()) + config = uvicorn.Config(app, host=WEBSOCKET_URL, port=WEBSOCKET_PORT, loop=loop) + server = uvicorn.Server(config) + loop.run_until_complete(server.serve()) if __name__ == "__main__": diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py index 65e55d49..0c70cac4 100644 --- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py +++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py @@ -3,6 +3,7 @@ import json import logging import math import sys +import time from datetime import timedelta from enum import Enum @@ -10,6 +11,7 @@ import numpy as np import numpy.typing as npt import pygame import pygame_gui +import requests import yaml from scipy.spatial import KDTree from websockets.sync.client import connect @@ -20,6 +22,7 @@ from overcooked_simulator.game_items import ( CookingEquipment, Plate, ) +from overcooked_simulator.game_server import CreateEnvironmentConfig from overcooked_simulator.gui_2d_vis.game_colors import BLUE from overcooked_simulator.gui_2d_vis.game_colors import colors, Color from overcooked_simulator.order import Order @@ -36,6 +39,9 @@ class MenuStates(Enum): End = "End" +MANAGER_ID = "1233245425" + + def create_polygon(n, length): if n == 1: return np.array([0, 0]) @@ -107,7 +113,8 @@ class PyGameGUI: ] # self.websocket_url = "ws://localhost:8765" - self.websocket_url = "ws://localhost:8000/ws/29" + self.websocket_url = "ws://localhost:8000/ws/" + self.websockets = {} # TODO cache loaded images? with open(ROOT_DIR / "gui_2d_vis" / "visualization.yaml", "r") as file: @@ -821,8 +828,50 @@ class PyGameGUI: def start_button_press(self): self.menu_state = MenuStates.Game - with connect(self.websocket_url) as websocket: - state = self.request_state() + environment_config_path = ROOT_DIR / "game_content" / "environment_config.yaml" + layout_path = ROOT_DIR / "game_content" / "layouts" / "basic.layout" + item_info_path = ROOT_DIR / "game_content" / "item_info.yaml" + with open(item_info_path, "r") as file: + item_info = file.read() + with open(layout_path, "r") as file: + layout = file.read() + with open(environment_config_path, "r") as file: + environment_config = file.read() + print( + CreateEnvironmentConfig( + manager_id=MANAGER_ID, + number_players=2, + environment_settings={"all_player_can_pause_game": False}, + item_info_config=item_info, + environment_config=environment_config, + layout_config=layout, + ).model_dump_json() + ) + env_info = requests.post( + "http://localhost:8000/manage/create_env/", + json=CreateEnvironmentConfig( + manager_id=MANAGER_ID, + number_players=2, + environment_settings={"all_player_can_pause_game": False}, + item_info_config=item_info, + environment_config=environment_config, + layout_config=layout, + ).model_dump_json(), + ) + print(env_info) + assert isinstance(env_info, dict), "Env info must be a dictionary" + self.current_env_id = env_info["env_id"] + self.player_info = env_info["player_info"] + for player_id, player_info in env_info["player_info"].items(): + websocket = connect(self.websocket_url + player_info["client_id"]) + websocket.send({"type": "ready", "player_hash": player_info["player_hash"]}) + self.websockets[player_id] = websocket + time.sleep(0.1) + + state = websocket.send( + {"type": "get_state", "player_hash": player_info["player_hash"]} + ) + self.state_player_id = player_id ( self.window_width, @@ -849,7 +898,14 @@ class PyGameGUI: log.debug("Pressed quit button") def reset_button_press(self): - _ = self.websocket_communicate("reset_game") + requests.post( + "http://localhost:8000/manage/stop_env", + json={ + "manager_id": MANAGER_ID, + "env_id": self.current_env_id, + "reason": "reset button pressed", + }, + ) # self.websocket.send(json.dumps("reset_game")) # answer = self.websocket.recv() @@ -867,31 +923,28 @@ class PyGameGUI: action: The action to be sent. Contains the player, action type and move direction if action is a movement. """ if isinstance(action.action, np.ndarray): - value = [float(action.action[0]), float(action.action[1])] + action.action = [float(action.action[0]), float(action.action[1])] else: - value = action.action - message_dict = { - "player_name": action.player, - "act_type": action.act_type, - "value": value, - "duration": action.duration, - } - _ = self.websocket_communicate(message_dict) - - def websocket_communicate(self, message_dict: dict | str): - self.websocket.send(json.dumps(message_dict)) - answer = self.websocket.recv() - try: - answer = json.loads(answer) - except json.decoder.JSONDecodeError: - answer = None - return answer + action.action = action.action + ret = self.websockets[action.player].send( + { + "type": "action", + "action": action, + "player_hash": self.player_info[action.player]["player_hash"], + } + ) + print(ret) def request_state(self): - state_dict = self.websocket_communicate("get_state") + state_dict = self.websockets[self.state_player_id].send( + { + "type": "get_state", + "player_hash": self.player_info[self.state_player_id]["player_hash"], + } + ) # self.websocket.send(json.dumps("get_state")) # state_dict = json.loads(self.websocket.recv()) - return state_dict + return json.loads(state_dict) def start_pygame(self): """Starts pygame and the gui loop. Each frame the game state is visualized and keyboard inputs are read.""" @@ -907,97 +960,91 @@ class PyGameGUI: self.init_ui_elements() self.manage_button_visibility() - with connect(self.websocket_url) as websocket: - self.websocket = websocket - # Game loop - self.running = True - while self.running: - try: - time_delta = clock.tick(self.FPS) / 1000.0 - - for event in pygame.event.get(): - if event.type == pygame.QUIT: - self.running = False - - # UI Buttons: - if event.type == pygame_gui.UI_BUTTON_PRESSED: - match event.ui_element: - case self.start_button: - self.start_button_press() - case self.back_button: - self.start_button_press() - case self.finished_button: - self.finished_button_press() - case self.quit_button: - self.quit_button_press() - case self.reset_button: - self.reset_button_press() - self.start_button_press() + # Game loop + self.running = True + while self.running: + try: + time_delta = clock.tick(self.FPS) / 1000.0 + + for event in pygame.event.get(): + if event.type == pygame.QUIT: + self.running = False + + # UI Buttons: + if event.type == pygame_gui.UI_BUTTON_PRESSED: + match event.ui_element: + case self.start_button: + self.start_button_press() + case self.back_button: + self.start_button_press() + case self.finished_button: + self.finished_button_press() + case self.quit_button: + self.quit_button_press() + case self.reset_button: + self.reset_button_press() + self.start_button_press() - self.manage_button_visibility() + self.manage_button_visibility() - if ( - event.type in [pygame.KEYDOWN, pygame.KEYUP] - and self.menu_state == MenuStates.Game - ): - pass - self.handle_key_event(event) + if ( + event.type in [pygame.KEYDOWN, pygame.KEYUP] + and self.menu_state == MenuStates.Game + ): + pass + self.handle_key_event(event) - self.manager.process_events(event) + self.manager.process_events(event) - # drawing: + # drawing: - # state = self.simulator.get_state() + # state = self.simulator.get_state() - self.main_window.fill( - colors[ - self.visualization_config["GameWindow"]["background_color"] - ] - ) - self.manager.draw_ui(self.main_window) + self.main_window.fill( + colors[self.visualization_config["GameWindow"]["background_color"]] + ) + self.manager.draw_ui(self.main_window) - match self.menu_state: - case MenuStates.Start: - pass + match self.menu_state: + case MenuStates.Start: + pass - case MenuStates.Game: - state = self.request_state() + case MenuStates.Game: + state = self.request_state() - self.draw_background() + self.draw_background() - self.handle_keys() + self.handle_keys() - # state = self.simulator.get_state() - self.draw(state) + # state = self.simulator.get_state() + self.draw(state) - if state["ended"]: - self.finished_button_press() - self.manage_button_visibility() - else: - self.draw(state) + if state["ended"]: + self.finished_button_press() + self.manage_button_visibility() + else: + self.draw(state) - game_screen_rect = self.game_screen.get_rect() - game_screen_rect.center = [ - self.window_width // 2, - self.window_height // 2, - ] + game_screen_rect = self.game_screen.get_rect() + game_screen_rect.center = [ + self.window_width // 2, + self.window_height // 2, + ] - self.main_window.blit( - self.game_screen, game_screen_rect - ) + self.main_window.blit(self.game_screen, game_screen_rect) - case MenuStates.End: - self.update_conclusion_label(state) + case MenuStates.End: + self.update_conclusion_label(state) - self.manager.update(time_delta) - pygame.display.flip() + self.manager.update(time_delta) + pygame.display.flip() - except KeyboardInterrupt: - pygame.quit() - sys.exit() + except KeyboardInterrupt: + pygame.quit() + sys.exit() - pygame.quit() - sys.exit() + pygame.quit() + sys.exit() def main(): diff --git a/setup.py b/setup.py index 55c09597..42a86447 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ requirements = [ "fastapi", "uvicorn", "websockets", + "requests", ] test_requirements = [ -- GitLab