Skip to content
Snippets Groups Projects
Commit 7ba326c6 authored by Florian Schröder's avatar Florian Schröder
Browse files

Update server and client communication in game simulator

The server (game_server.py) and client communication (overcooked_gui.py) now uses the WebSocket communication protocol. In setup.py, the 'requests' module was added as a new requirement. The game simulator now waits for a player to be ready before starting, stops a game environment if no step is taken within a minute, and pauses or unpauses a game environment. Player actions are now handled based on a new 'Action' type. The server also now handles several client types that can send messages.
parent 6a57f19a
No related branches found
No related tags found
1 merge request!26Resolve "api"
Pipeline #44604 failed
...@@ -11,15 +11,14 @@ from datetime import datetime, timedelta ...@@ -11,15 +11,14 @@ from datetime import datetime, timedelta
from enum import Enum from enum import Enum
from typing import Set from typing import Set
import numpy as np
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from fastapi import WebSocket from fastapi import WebSocket
from overcooked_simulator.game_server_OLD import setup_logging
from pydantic import BaseModel from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect from starlette.websockets import WebSocketDisconnect
from typing_extensions import TypedDict from typing_extensions import TypedDict
from overcooked_simulator.main import setup_logging
from overcooked_simulator.overcooked_environment import Action, Environment from overcooked_simulator.overcooked_environment import Action, Environment
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -67,13 +66,14 @@ class EnvironmentData: ...@@ -67,13 +66,14 @@ class EnvironmentData:
last_step_time: int | None = None last_step_time: int | None = None
class GameServer: class EnvironmentHandler:
def __init__(self, env_step_frequency: int = 200): def __init__(self, env_step_frequency: int = 200):
self.envs: dict[str, EnvironmentData] = {} self.envs: dict[str, EnvironmentData] = {}
self.player_data: dict[str, PlayerData] = {} self.player_data: dict[str, PlayerData] = {}
self.manager_envs: dict[str, Set[str]] = defaultdict(set) self.manager_envs: dict[str, Set[str]] = defaultdict(set)
self.env_step_frequency = env_step_frequency self.env_step_frequency = env_step_frequency
self.preferred_sleep_time_ns = 1e9 / self.env_step_frequency self.preferred_sleep_time_ns = 1e9 / self.env_step_frequency
self.client_ids_to_player_hashes = {}
def create_env(self, environment_config: CreateEnvironmentConfig): def create_env(self, environment_config: CreateEnvironmentConfig):
env_id = uuid.uuid4().hex env_id = uuid.uuid4().hex
...@@ -92,7 +92,7 @@ class GameServer: ...@@ -92,7 +92,7 @@ class GameServer:
self.manager_envs[environment_config.manager_id].update([env_id]) self.manager_envs[environment_config.manager_id].update([env_id])
return {"env_id": env_id} return {"env_id": env_id, "player_info": player_info}
def create_player(self, env, env_id, player_id): def create_player(self, env, env_id, player_id):
player_hash = uuid.uuid4().hex player_hash = uuid.uuid4().hex
...@@ -103,6 +103,7 @@ class GameServer: ...@@ -103,6 +103,7 @@ class GameServer:
websocket_id=client_id, websocket_id=client_id,
) )
self.player_data[player_hash] = player_data self.player_data[player_hash] = player_data
self.client_ids_to_player_hashes[client_id] = player_hash
env.add_player(player_id) env.add_player(player_id)
return { return {
...@@ -135,8 +136,15 @@ class GameServer: ...@@ -135,8 +136,15 @@ class GameServer:
self.envs[env_id].last_step_time = time.time_ns() self.envs[env_id].last_step_time = time.time_ns()
self.envs[env_id].environment.reset_env_time() self.envs[env_id].environment.reset_env_time()
def get_state(self): def get_state(self, player_hash: str):
... if (
player_hash in self.player_data
and self.player_data[player_hash].env_id in self.envs
):
# TODO normal json state
return self.envs[
self.player_data[player_hash].env_id
].environment.get_state_simple_json()
def pause_env(self, manager_id: str, env_id: str, reason: str): def pause_env(self, manager_id: str, env_id: str, reason: str):
if ( if (
...@@ -166,35 +174,25 @@ class GameServer: ...@@ -166,35 +174,25 @@ class GameServer:
self.envs[env_id].status = EnvironmentStatus.STOPPED self.envs[env_id].status = EnvironmentStatus.STOPPED
self.envs[env_id].stop_reason = reason self.envs[env_id].stop_reason = reason
def set_player_ready(self, env_id: str, player_hash, player_id: int): def set_player_ready(self, player_hash):
if ( if player_hash in self.player_data:
player_hash in self.player_data
and self.player_data[player_hash].player_id == player_id
and self.player_data[player_hash].env_id == env_id
):
self.player_data[player_hash].ready = True self.player_data[player_hash].ready = True
return True return True
return False return False
def set_player_connected(self, env_id: str, player_hash, player_id: int) -> bool: def set_player_connected(self, client_id: str) -> bool:
if ( if client_id in self.client_ids_to_player_hashes:
player_hash in self.player_data self.player_data[
and self.player_data[player_hash].player_id == player_id self.client_ids_to_player_hashes[client_id]
and self.player_data[player_hash].env_id == env_id ].connected = True
):
self.player_data[player_hash].connected = True
return True return True
return False return False
def set_player_disconnected( def set_player_disconnected(self, client_id: str) -> bool:
self, env_id: str, player_hash: str, player_id: int if client_id in self.client_ids_to_player_hashes:
) -> bool: self.player_data[
if ( self.client_ids_to_player_hashes[client_id]
player_hash in self.player_data ].connected = False
and self.player_data[player_hash].player_id == player_id
and self.player_data[player_hash].env_id == env_id
):
self.player_data[player_hash].connected = False
return True return True
return False return False
...@@ -228,9 +226,11 @@ class GameServer: ...@@ -228,9 +226,11 @@ class GameServer:
] ]
async def environment_steps(self): async def environment_steps(self):
# TODO environment dependent steps.
overslept_in_ns = 0 overslept_in_ns = 0
while True: while True:
pre_step_start = time.time_ns() pre_step_start = time.time_ns()
to_remove = []
for env_id, env_data in self.envs.items(): for env_id, env_data in self.envs.items():
if env_data.status == EnvironmentStatus.RUNNING: if env_data.status == EnvironmentStatus.RUNNING:
step_start = time.time_ns() step_start = time.time_ns()
...@@ -241,6 +241,15 @@ class GameServer: ...@@ -241,6 +241,15 @@ class GameServer:
) )
) )
env_data.last_step_time = step_start env_data.last_step_time = step_start
elif (
env_data.status == EnvironmentStatus.STOPPED
and env_data.last_step_time + (60 * 1e9) < pre_step_start
):
to_remove.append(env_id)
if to_remove:
for env_id in to_remove:
del self.envs[env_id]
step_duration = time.time_ns() - pre_step_start step_duration = time.time_ns() - pre_step_start
time_to_sleep_ns = self.preferred_sleep_time_ns - ( time_to_sleep_ns = self.preferred_sleep_time_ns - (
...@@ -252,21 +261,36 @@ class GameServer: ...@@ -252,21 +261,36 @@ class GameServer:
sleep_function_duration = time.time_ns() - sleep_start sleep_function_duration = time.time_ns() - sleep_start
overslept_in_ns = sleep_function_duration - time_to_sleep_ns overslept_in_ns = sleep_function_duration - time_to_sleep_ns
def is_known_client_id(self, client_id: str) -> bool:
return client_id in self.client_ids_to_player_hashes
def player_action(self, player_hash: str, action: Action):
if (
player_hash in self.player_data
and action.player == self.player_data[player_hash].player_id
and self.player_data[player_hash].env_id in self.envs
and player_hash
in self.envs[self.player_data[player_hash].env_id].player_hashes
):
self.envs[self.player_data[player_hash].env_id].environment.perform_action(
action
)
class PlayerConnectionManager: class PlayerConnectionManager:
def __init__(self): def __init__(self):
self.player_connections: dict[str, WebSocket] = {} self.player_connections: dict[str, WebSocket] = {}
async def connect_player(self, websocket: WebSocket, player_id: str) -> bool: async def connect_player(self, websocket: WebSocket, client_id: str) -> bool:
if player_id not in self.player_connections: if client_id not in self.player_connections:
await websocket.accept() await websocket.accept()
self.player_connections[player_id] = websocket self.player_connections[client_id] = websocket
return True return True
return False return False
def disconnect(self, id_: str): def disconnect(self, client_id: str):
if id_ in self.player_connections: if client_id in self.player_connections:
del self.player_connections[id_] del self.player_connections[client_id]
@staticmethod @staticmethod
async def send_personal_message(message: str, websocket: WebSocket): async def send_personal_message(message: str, websocket: WebSocket):
...@@ -278,49 +302,43 @@ class PlayerConnectionManager: ...@@ -278,49 +302,43 @@ class PlayerConnectionManager:
manager = PlayerConnectionManager() manager = PlayerConnectionManager()
oc_api: GameServer = GameServer() environment_handler: EnvironmentHandler = EnvironmentHandler()
def parse_websocket_action(message: str) -> Action:
if message.replace('"', "") != "get_state":
message_dict = json.loads(message)
if message_dict["act_type"] == "movement":
if isinstance(message_dict["value"], list):
x, y = message_dict["value"]
elif isinstance(message_dict["value"], str):
x, y = (
message_dict["value"]
.replace(" ", "")
.replace("[", "")
.replace("]", "")
.split(",")
)
else:
x, y = 0, 0
value = np.array([x, y], dtype=float)
else:
value = None
action = Action(
message_dict["player_name"],
message_dict["act_type"],
value,
duration=message_dict["duration"],
)
return action
def manage_websocket_message(message: str):
if "get_state" in message:
return oc_api.get_state()
if "reset_game" in message: @dataclasses.dataclass
oc_api.reset_game() class PlayerAction:
return "Reset game." player_hash: str
action: Action
action = parse_websocket_action(message)
oc_api.simulator.enter_action(action)
answer = oc_api.get_state() def manage_websocket_message(message: str, client_id: str):
return answer message_dict = json.loads(message)
assert "type" in message_dict, "message needs a type"
match message_dict["type"]:
case "ready":
assert "player_hash" in message_dict, "needs player hash for ready"
environment_handler.set_player_ready(message_dict["player_hash"])
return {
"status": "ready accepted",
"player_hash": message_dict["player_hash"],
}
case "get_state":
assert "player_hash" in message_dict, "needs player hash for environment"
return environment_handler.get_state(message_dict["player_hash"])
case "action":
assert "action" in message_dict, "action type needs action data"
assert "player_hash" in message_dict, "action type needs player hash"
environment_handler.player_action(
message_dict["player_hash"], Action(**message_dict["action"])
)
return {
"status": "action accepted",
"player_hash": message_dict["player_hash"],
}
# TODO setup error enums or class
return {"status": "error", "info": "unknown message type"}
@app.get("/") @app.get("/")
...@@ -345,44 +363,55 @@ class AdditionalPlayer(BaseModel): ...@@ -345,44 +363,55 @@ class AdditionalPlayer(BaseModel):
existing_websocket: str | None = None existing_websocket: str | None = None
@app.post("/manage/create_env") @app.post("/manage/create_env/")
async def register_manger(creation: CreateEnvironmentConfig): async def create_env(creation: CreateEnvironmentConfig):
result = oc_api.create_env(creation) print(creation)
result = environment_handler.create_env(creation)
return result return result
@app.post("/manage/additional_player") @app.post("/manage/additional_player/")
async def additional_player(creation: AdditionalPlayer): async def additional_player(creation: AdditionalPlayer):
result = oc_api.add_player(creation) result = environment_handler.add_player(creation)
return result return result
@app.post("manage/stop_env") @app.post("/manage/stop_env/")
async def stop_env(manager_id: str, env_id: str, reason: str): async def stop_env(manager_id: str, env_id: str, reason: str):
result = oc_api.stop_env(manager_id, env_id, reason) result = environment_handler.stop_env(manager_id, env_id, reason)
return result return result
# pause / unpause
# control access / functions / data # control access / functions / data
@app.websocket("/ws/player/{client_id}") @app.websocket("/ws/player/{client_id}")
async def websocket_player_endpoint(websocket: WebSocket, client_id: int): async def websocket_player_endpoint(websocket: WebSocket, client_id: str):
await manager.connect(websocket) if client_id not in environment_handler.is_known_client_id(client_id):
return
await manager.connect_player(websocket, client_id)
log.debug(f"Client #{client_id} connected") log.debug(f"Client #{client_id} connected")
environment_handler.set_player_connected(client_id)
try: try:
while True: while True:
message = await websocket.receive_text() message = await websocket.receive_text()
answer = manage_websocket_message(message) answer = manage_websocket_message(message, client_id)
await manager.send_personal_message(answer, websocket) await manager.send_personal_message(answer, websocket)
except WebSocketDisconnect: except WebSocketDisconnect:
manager.disconnect(websocket) manager.disconnect(client_id)
environment_handler.set_player_disconnected(client_id)
log.debug(f"Client #{client_id} disconnected") log.debug(f"Client #{client_id} disconnected")
def main(): def main():
uvicorn.run(app, host=WEBSOCKET_URL, port=WEBSOCKET_PORT) loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(environment_handler.environment_steps())
config = uvicorn.Config(app, host=WEBSOCKET_URL, port=WEBSOCKET_PORT, loop=loop)
server = uvicorn.Server(config)
loop.run_until_complete(server.serve())
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -3,6 +3,7 @@ import json ...@@ -3,6 +3,7 @@ import json
import logging import logging
import math import math
import sys import sys
import time
from datetime import timedelta from datetime import timedelta
from enum import Enum from enum import Enum
...@@ -10,6 +11,7 @@ import numpy as np ...@@ -10,6 +11,7 @@ import numpy as np
import numpy.typing as npt import numpy.typing as npt
import pygame import pygame
import pygame_gui import pygame_gui
import requests
import yaml import yaml
from scipy.spatial import KDTree from scipy.spatial import KDTree
from websockets.sync.client import connect from websockets.sync.client import connect
...@@ -20,6 +22,7 @@ from overcooked_simulator.game_items import ( ...@@ -20,6 +22,7 @@ from overcooked_simulator.game_items import (
CookingEquipment, CookingEquipment,
Plate, Plate,
) )
from overcooked_simulator.game_server import CreateEnvironmentConfig
from overcooked_simulator.gui_2d_vis.game_colors import BLUE from overcooked_simulator.gui_2d_vis.game_colors import BLUE
from overcooked_simulator.gui_2d_vis.game_colors import colors, Color from overcooked_simulator.gui_2d_vis.game_colors import colors, Color
from overcooked_simulator.order import Order from overcooked_simulator.order import Order
...@@ -36,6 +39,9 @@ class MenuStates(Enum): ...@@ -36,6 +39,9 @@ class MenuStates(Enum):
End = "End" End = "End"
MANAGER_ID = "1233245425"
def create_polygon(n, length): def create_polygon(n, length):
if n == 1: if n == 1:
return np.array([0, 0]) return np.array([0, 0])
...@@ -107,7 +113,8 @@ class PyGameGUI: ...@@ -107,7 +113,8 @@ class PyGameGUI:
] ]
# self.websocket_url = "ws://localhost:8765" # self.websocket_url = "ws://localhost:8765"
self.websocket_url = "ws://localhost:8000/ws/29" self.websocket_url = "ws://localhost:8000/ws/"
self.websockets = {}
# TODO cache loaded images? # TODO cache loaded images?
with open(ROOT_DIR / "gui_2d_vis" / "visualization.yaml", "r") as file: with open(ROOT_DIR / "gui_2d_vis" / "visualization.yaml", "r") as file:
...@@ -821,8 +828,50 @@ class PyGameGUI: ...@@ -821,8 +828,50 @@ class PyGameGUI:
def start_button_press(self): def start_button_press(self):
self.menu_state = MenuStates.Game self.menu_state = MenuStates.Game
with connect(self.websocket_url) as websocket: environment_config_path = ROOT_DIR / "game_content" / "environment_config.yaml"
state = self.request_state() layout_path = ROOT_DIR / "game_content" / "layouts" / "basic.layout"
item_info_path = ROOT_DIR / "game_content" / "item_info.yaml"
with open(item_info_path, "r") as file:
item_info = file.read()
with open(layout_path, "r") as file:
layout = file.read()
with open(environment_config_path, "r") as file:
environment_config = file.read()
print(
CreateEnvironmentConfig(
manager_id=MANAGER_ID,
number_players=2,
environment_settings={"all_player_can_pause_game": False},
item_info_config=item_info,
environment_config=environment_config,
layout_config=layout,
).model_dump_json()
)
env_info = requests.post(
"http://localhost:8000/manage/create_env/",
json=CreateEnvironmentConfig(
manager_id=MANAGER_ID,
number_players=2,
environment_settings={"all_player_can_pause_game": False},
item_info_config=item_info,
environment_config=environment_config,
layout_config=layout,
).model_dump_json(),
)
print(env_info)
assert isinstance(env_info, dict), "Env info must be a dictionary"
self.current_env_id = env_info["env_id"]
self.player_info = env_info["player_info"]
for player_id, player_info in env_info["player_info"].items():
websocket = connect(self.websocket_url + player_info["client_id"])
websocket.send({"type": "ready", "player_hash": player_info["player_hash"]})
self.websockets[player_id] = websocket
time.sleep(0.1)
state = websocket.send(
{"type": "get_state", "player_hash": player_info["player_hash"]}
)
self.state_player_id = player_id
( (
self.window_width, self.window_width,
...@@ -849,7 +898,14 @@ class PyGameGUI: ...@@ -849,7 +898,14 @@ class PyGameGUI:
log.debug("Pressed quit button") log.debug("Pressed quit button")
def reset_button_press(self): def reset_button_press(self):
_ = self.websocket_communicate("reset_game") requests.post(
"http://localhost:8000/manage/stop_env",
json={
"manager_id": MANAGER_ID,
"env_id": self.current_env_id,
"reason": "reset button pressed",
},
)
# self.websocket.send(json.dumps("reset_game")) # self.websocket.send(json.dumps("reset_game"))
# answer = self.websocket.recv() # answer = self.websocket.recv()
...@@ -867,31 +923,28 @@ class PyGameGUI: ...@@ -867,31 +923,28 @@ class PyGameGUI:
action: The action to be sent. Contains the player, action type and move direction if action is a movement. action: The action to be sent. Contains the player, action type and move direction if action is a movement.
""" """
if isinstance(action.action, np.ndarray): if isinstance(action.action, np.ndarray):
value = [float(action.action[0]), float(action.action[1])] action.action = [float(action.action[0]), float(action.action[1])]
else: else:
value = action.action action.action = action.action
message_dict = { ret = self.websockets[action.player].send(
"player_name": action.player, {
"act_type": action.act_type, "type": "action",
"value": value, "action": action,
"duration": action.duration, "player_hash": self.player_info[action.player]["player_hash"],
} }
_ = self.websocket_communicate(message_dict) )
print(ret)
def websocket_communicate(self, message_dict: dict | str):
self.websocket.send(json.dumps(message_dict))
answer = self.websocket.recv()
try:
answer = json.loads(answer)
except json.decoder.JSONDecodeError:
answer = None
return answer
def request_state(self): def request_state(self):
state_dict = self.websocket_communicate("get_state") state_dict = self.websockets[self.state_player_id].send(
{
"type": "get_state",
"player_hash": self.player_info[self.state_player_id]["player_hash"],
}
)
# self.websocket.send(json.dumps("get_state")) # self.websocket.send(json.dumps("get_state"))
# state_dict = json.loads(self.websocket.recv()) # state_dict = json.loads(self.websocket.recv())
return state_dict return json.loads(state_dict)
def start_pygame(self): def start_pygame(self):
"""Starts pygame and the gui loop. Each frame the game state is visualized and keyboard inputs are read.""" """Starts pygame and the gui loop. Each frame the game state is visualized and keyboard inputs are read."""
...@@ -907,97 +960,91 @@ class PyGameGUI: ...@@ -907,97 +960,91 @@ class PyGameGUI:
self.init_ui_elements() self.init_ui_elements()
self.manage_button_visibility() self.manage_button_visibility()
with connect(self.websocket_url) as websocket: # Game loop
self.websocket = websocket self.running = True
# Game loop while self.running:
self.running = True try:
while self.running: time_delta = clock.tick(self.FPS) / 1000.0
try:
time_delta = clock.tick(self.FPS) / 1000.0 for event in pygame.event.get():
if event.type == pygame.QUIT:
for event in pygame.event.get(): self.running = False
if event.type == pygame.QUIT:
self.running = False # UI Buttons:
if event.type == pygame_gui.UI_BUTTON_PRESSED:
# UI Buttons: match event.ui_element:
if event.type == pygame_gui.UI_BUTTON_PRESSED: case self.start_button:
match event.ui_element: self.start_button_press()
case self.start_button: case self.back_button:
self.start_button_press() self.start_button_press()
case self.back_button: case self.finished_button:
self.start_button_press() self.finished_button_press()
case self.finished_button: case self.quit_button:
self.finished_button_press() self.quit_button_press()
case self.quit_button: case self.reset_button:
self.quit_button_press() self.reset_button_press()
case self.reset_button: self.start_button_press()
self.reset_button_press()
self.start_button_press()
self.manage_button_visibility() self.manage_button_visibility()
if ( if (
event.type in [pygame.KEYDOWN, pygame.KEYUP] event.type in [pygame.KEYDOWN, pygame.KEYUP]
and self.menu_state == MenuStates.Game and self.menu_state == MenuStates.Game
): ):
pass pass
self.handle_key_event(event) self.handle_key_event(event)
self.manager.process_events(event) self.manager.process_events(event)
# drawing: # drawing:
# state = self.simulator.get_state() # state = self.simulator.get_state()
self.main_window.fill( self.main_window.fill(
colors[ colors[self.visualization_config["GameWindow"]["background_color"]]
self.visualization_config["GameWindow"]["background_color"] )
] self.manager.draw_ui(self.main_window)
)
self.manager.draw_ui(self.main_window)
match self.menu_state: match self.menu_state:
case MenuStates.Start: case MenuStates.Start:
pass pass
case MenuStates.Game: case MenuStates.Game:
state = self.request_state() state = self.request_state()
self.draw_background() self.draw_background()
self.handle_keys() self.handle_keys()
# state = self.simulator.get_state() # state = self.simulator.get_state()
self.draw(state) self.draw(state)
if state["ended"]: if state["ended"]:
self.finished_button_press() self.finished_button_press()
self.manage_button_visibility() self.manage_button_visibility()
else: else:
self.draw(state) self.draw(state)
game_screen_rect = self.game_screen.get_rect() game_screen_rect = self.game_screen.get_rect()
game_screen_rect.center = [ game_screen_rect.center = [
self.window_width // 2, self.window_width // 2,
self.window_height // 2, self.window_height // 2,
] ]
self.main_window.blit( self.main_window.blit(self.game_screen, game_screen_rect)
self.game_screen, game_screen_rect
)
case MenuStates.End: case MenuStates.End:
self.update_conclusion_label(state) self.update_conclusion_label(state)
self.manager.update(time_delta) self.manager.update(time_delta)
pygame.display.flip() pygame.display.flip()
except KeyboardInterrupt: except KeyboardInterrupt:
pygame.quit() pygame.quit()
sys.exit() sys.exit()
pygame.quit() pygame.quit()
sys.exit() sys.exit()
def main(): def main():
......
...@@ -20,6 +20,7 @@ requirements = [ ...@@ -20,6 +20,7 @@ requirements = [
"fastapi", "fastapi",
"uvicorn", "uvicorn",
"websockets", "websockets",
"requests",
] ]
test_requirements = [ test_requirements = [
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment