Skip to content
Snippets Groups Projects
Commit 7ba326c6 authored by Florian Schröder's avatar Florian Schröder
Browse files

Update server and client communication in game simulator

The server (game_server.py) and client communication (overcooked_gui.py) now uses the WebSocket communication protocol. In setup.py, the 'requests' module was added as a new requirement. The game simulator now waits for a player to be ready before starting, stops a game environment if no step is taken within a minute, and pauses or unpauses a game environment. Player actions are now handled based on a new 'Action' type. The server also now handles several client types that can send messages.
parent 6a57f19a
No related branches found
No related tags found
1 merge request!26Resolve "api"
Pipeline #44604 failed
......@@ -11,15 +11,14 @@ from datetime import datetime, timedelta
from enum import Enum
from typing import Set
import numpy as np
import uvicorn
from fastapi import FastAPI
from fastapi import WebSocket
from overcooked_simulator.game_server_OLD import setup_logging
from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect
from typing_extensions import TypedDict
from overcooked_simulator.main import setup_logging
from overcooked_simulator.overcooked_environment import Action, Environment
log = logging.getLogger(__name__)
......@@ -67,13 +66,14 @@ class EnvironmentData:
last_step_time: int | None = None
class GameServer:
class EnvironmentHandler:
def __init__(self, env_step_frequency: int = 200):
self.envs: dict[str, EnvironmentData] = {}
self.player_data: dict[str, PlayerData] = {}
self.manager_envs: dict[str, Set[str]] = defaultdict(set)
self.env_step_frequency = env_step_frequency
self.preferred_sleep_time_ns = 1e9 / self.env_step_frequency
self.client_ids_to_player_hashes = {}
def create_env(self, environment_config: CreateEnvironmentConfig):
env_id = uuid.uuid4().hex
......@@ -92,7 +92,7 @@ class GameServer:
self.manager_envs[environment_config.manager_id].update([env_id])
return {"env_id": env_id}
return {"env_id": env_id, "player_info": player_info}
def create_player(self, env, env_id, player_id):
player_hash = uuid.uuid4().hex
......@@ -103,6 +103,7 @@ class GameServer:
websocket_id=client_id,
)
self.player_data[player_hash] = player_data
self.client_ids_to_player_hashes[client_id] = player_hash
env.add_player(player_id)
return {
......@@ -135,8 +136,15 @@ class GameServer:
self.envs[env_id].last_step_time = time.time_ns()
self.envs[env_id].environment.reset_env_time()
def get_state(self):
...
def get_state(self, player_hash: str):
if (
player_hash in self.player_data
and self.player_data[player_hash].env_id in self.envs
):
# TODO normal json state
return self.envs[
self.player_data[player_hash].env_id
].environment.get_state_simple_json()
def pause_env(self, manager_id: str, env_id: str, reason: str):
if (
......@@ -166,35 +174,25 @@ class GameServer:
self.envs[env_id].status = EnvironmentStatus.STOPPED
self.envs[env_id].stop_reason = reason
def set_player_ready(self, env_id: str, player_hash, player_id: int):
if (
player_hash in self.player_data
and self.player_data[player_hash].player_id == player_id
and self.player_data[player_hash].env_id == env_id
):
def set_player_ready(self, player_hash):
if player_hash in self.player_data:
self.player_data[player_hash].ready = True
return True
return False
def set_player_connected(self, env_id: str, player_hash, player_id: int) -> bool:
if (
player_hash in self.player_data
and self.player_data[player_hash].player_id == player_id
and self.player_data[player_hash].env_id == env_id
):
self.player_data[player_hash].connected = True
def set_player_connected(self, client_id: str) -> bool:
if client_id in self.client_ids_to_player_hashes:
self.player_data[
self.client_ids_to_player_hashes[client_id]
].connected = True
return True
return False
def set_player_disconnected(
self, env_id: str, player_hash: str, player_id: int
) -> bool:
if (
player_hash in self.player_data
and self.player_data[player_hash].player_id == player_id
and self.player_data[player_hash].env_id == env_id
):
self.player_data[player_hash].connected = False
def set_player_disconnected(self, client_id: str) -> bool:
if client_id in self.client_ids_to_player_hashes:
self.player_data[
self.client_ids_to_player_hashes[client_id]
].connected = False
return True
return False
......@@ -228,9 +226,11 @@ class GameServer:
]
async def environment_steps(self):
# TODO environment dependent steps.
overslept_in_ns = 0
while True:
pre_step_start = time.time_ns()
to_remove = []
for env_id, env_data in self.envs.items():
if env_data.status == EnvironmentStatus.RUNNING:
step_start = time.time_ns()
......@@ -241,6 +241,15 @@ class GameServer:
)
)
env_data.last_step_time = step_start
elif (
env_data.status == EnvironmentStatus.STOPPED
and env_data.last_step_time + (60 * 1e9) < pre_step_start
):
to_remove.append(env_id)
if to_remove:
for env_id in to_remove:
del self.envs[env_id]
step_duration = time.time_ns() - pre_step_start
time_to_sleep_ns = self.preferred_sleep_time_ns - (
......@@ -252,21 +261,36 @@ class GameServer:
sleep_function_duration = time.time_ns() - sleep_start
overslept_in_ns = sleep_function_duration - time_to_sleep_ns
def is_known_client_id(self, client_id: str) -> bool:
return client_id in self.client_ids_to_player_hashes
def player_action(self, player_hash: str, action: Action):
if (
player_hash in self.player_data
and action.player == self.player_data[player_hash].player_id
and self.player_data[player_hash].env_id in self.envs
and player_hash
in self.envs[self.player_data[player_hash].env_id].player_hashes
):
self.envs[self.player_data[player_hash].env_id].environment.perform_action(
action
)
class PlayerConnectionManager:
def __init__(self):
self.player_connections: dict[str, WebSocket] = {}
async def connect_player(self, websocket: WebSocket, player_id: str) -> bool:
if player_id not in self.player_connections:
async def connect_player(self, websocket: WebSocket, client_id: str) -> bool:
if client_id not in self.player_connections:
await websocket.accept()
self.player_connections[player_id] = websocket
self.player_connections[client_id] = websocket
return True
return False
def disconnect(self, id_: str):
if id_ in self.player_connections:
del self.player_connections[id_]
def disconnect(self, client_id: str):
if client_id in self.player_connections:
del self.player_connections[client_id]
@staticmethod
async def send_personal_message(message: str, websocket: WebSocket):
......@@ -278,49 +302,43 @@ class PlayerConnectionManager:
manager = PlayerConnectionManager()
oc_api: GameServer = GameServer()
def parse_websocket_action(message: str) -> Action:
if message.replace('"', "") != "get_state":
message_dict = json.loads(message)
if message_dict["act_type"] == "movement":
if isinstance(message_dict["value"], list):
x, y = message_dict["value"]
elif isinstance(message_dict["value"], str):
x, y = (
message_dict["value"]
.replace(" ", "")
.replace("[", "")
.replace("]", "")
.split(",")
)
else:
x, y = 0, 0
value = np.array([x, y], dtype=float)
else:
value = None
action = Action(
message_dict["player_name"],
message_dict["act_type"],
value,
duration=message_dict["duration"],
)
return action
environment_handler: EnvironmentHandler = EnvironmentHandler()
def manage_websocket_message(message: str):
if "get_state" in message:
return oc_api.get_state()
if "reset_game" in message:
oc_api.reset_game()
return "Reset game."
action = parse_websocket_action(message)
oc_api.simulator.enter_action(action)
answer = oc_api.get_state()
return answer
@dataclasses.dataclass
class PlayerAction:
player_hash: str
action: Action
def manage_websocket_message(message: str, client_id: str):
message_dict = json.loads(message)
assert "type" in message_dict, "message needs a type"
match message_dict["type"]:
case "ready":
assert "player_hash" in message_dict, "needs player hash for ready"
environment_handler.set_player_ready(message_dict["player_hash"])
return {
"status": "ready accepted",
"player_hash": message_dict["player_hash"],
}
case "get_state":
assert "player_hash" in message_dict, "needs player hash for environment"
return environment_handler.get_state(message_dict["player_hash"])
case "action":
assert "action" in message_dict, "action type needs action data"
assert "player_hash" in message_dict, "action type needs player hash"
environment_handler.player_action(
message_dict["player_hash"], Action(**message_dict["action"])
)
return {
"status": "action accepted",
"player_hash": message_dict["player_hash"],
}
# TODO setup error enums or class
return {"status": "error", "info": "unknown message type"}
@app.get("/")
......@@ -345,44 +363,55 @@ class AdditionalPlayer(BaseModel):
existing_websocket: str | None = None
@app.post("/manage/create_env")
async def register_manger(creation: CreateEnvironmentConfig):
result = oc_api.create_env(creation)
@app.post("/manage/create_env/")
async def create_env(creation: CreateEnvironmentConfig):
print(creation)
result = environment_handler.create_env(creation)
return result
@app.post("/manage/additional_player")
@app.post("/manage/additional_player/")
async def additional_player(creation: AdditionalPlayer):
result = oc_api.add_player(creation)
result = environment_handler.add_player(creation)
return result
@app.post("manage/stop_env")
@app.post("/manage/stop_env/")
async def stop_env(manager_id: str, env_id: str, reason: str):
result = oc_api.stop_env(manager_id, env_id, reason)
result = environment_handler.stop_env(manager_id, env_id, reason)
return result
# pause / unpause
# control access / functions / data
@app.websocket("/ws/player/{client_id}")
async def websocket_player_endpoint(websocket: WebSocket, client_id: int):
await manager.connect(websocket)
async def websocket_player_endpoint(websocket: WebSocket, client_id: str):
if client_id not in environment_handler.is_known_client_id(client_id):
return
await manager.connect_player(websocket, client_id)
log.debug(f"Client #{client_id} connected")
environment_handler.set_player_connected(client_id)
try:
while True:
message = await websocket.receive_text()
answer = manage_websocket_message(message)
answer = manage_websocket_message(message, client_id)
await manager.send_personal_message(answer, websocket)
except WebSocketDisconnect:
manager.disconnect(websocket)
manager.disconnect(client_id)
environment_handler.set_player_disconnected(client_id)
log.debug(f"Client #{client_id} disconnected")
def main():
uvicorn.run(app, host=WEBSOCKET_URL, port=WEBSOCKET_PORT)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(environment_handler.environment_steps())
config = uvicorn.Config(app, host=WEBSOCKET_URL, port=WEBSOCKET_PORT, loop=loop)
server = uvicorn.Server(config)
loop.run_until_complete(server.serve())
if __name__ == "__main__":
......
......@@ -3,6 +3,7 @@ import json
import logging
import math
import sys
import time
from datetime import timedelta
from enum import Enum
......@@ -10,6 +11,7 @@ import numpy as np
import numpy.typing as npt
import pygame
import pygame_gui
import requests
import yaml
from scipy.spatial import KDTree
from websockets.sync.client import connect
......@@ -20,6 +22,7 @@ from overcooked_simulator.game_items import (
CookingEquipment,
Plate,
)
from overcooked_simulator.game_server import CreateEnvironmentConfig
from overcooked_simulator.gui_2d_vis.game_colors import BLUE
from overcooked_simulator.gui_2d_vis.game_colors import colors, Color
from overcooked_simulator.order import Order
......@@ -36,6 +39,9 @@ class MenuStates(Enum):
End = "End"
MANAGER_ID = "1233245425"
def create_polygon(n, length):
if n == 1:
return np.array([0, 0])
......@@ -107,7 +113,8 @@ class PyGameGUI:
]
# self.websocket_url = "ws://localhost:8765"
self.websocket_url = "ws://localhost:8000/ws/29"
self.websocket_url = "ws://localhost:8000/ws/"
self.websockets = {}
# TODO cache loaded images?
with open(ROOT_DIR / "gui_2d_vis" / "visualization.yaml", "r") as file:
......@@ -821,8 +828,50 @@ class PyGameGUI:
def start_button_press(self):
self.menu_state = MenuStates.Game
with connect(self.websocket_url) as websocket:
state = self.request_state()
environment_config_path = ROOT_DIR / "game_content" / "environment_config.yaml"
layout_path = ROOT_DIR / "game_content" / "layouts" / "basic.layout"
item_info_path = ROOT_DIR / "game_content" / "item_info.yaml"
with open(item_info_path, "r") as file:
item_info = file.read()
with open(layout_path, "r") as file:
layout = file.read()
with open(environment_config_path, "r") as file:
environment_config = file.read()
print(
CreateEnvironmentConfig(
manager_id=MANAGER_ID,
number_players=2,
environment_settings={"all_player_can_pause_game": False},
item_info_config=item_info,
environment_config=environment_config,
layout_config=layout,
).model_dump_json()
)
env_info = requests.post(
"http://localhost:8000/manage/create_env/",
json=CreateEnvironmentConfig(
manager_id=MANAGER_ID,
number_players=2,
environment_settings={"all_player_can_pause_game": False},
item_info_config=item_info,
environment_config=environment_config,
layout_config=layout,
).model_dump_json(),
)
print(env_info)
assert isinstance(env_info, dict), "Env info must be a dictionary"
self.current_env_id = env_info["env_id"]
self.player_info = env_info["player_info"]
for player_id, player_info in env_info["player_info"].items():
websocket = connect(self.websocket_url + player_info["client_id"])
websocket.send({"type": "ready", "player_hash": player_info["player_hash"]})
self.websockets[player_id] = websocket
time.sleep(0.1)
state = websocket.send(
{"type": "get_state", "player_hash": player_info["player_hash"]}
)
self.state_player_id = player_id
(
self.window_width,
......@@ -849,7 +898,14 @@ class PyGameGUI:
log.debug("Pressed quit button")
def reset_button_press(self):
_ = self.websocket_communicate("reset_game")
requests.post(
"http://localhost:8000/manage/stop_env",
json={
"manager_id": MANAGER_ID,
"env_id": self.current_env_id,
"reason": "reset button pressed",
},
)
# self.websocket.send(json.dumps("reset_game"))
# answer = self.websocket.recv()
......@@ -867,31 +923,28 @@ class PyGameGUI:
action: The action to be sent. Contains the player, action type and move direction if action is a movement.
"""
if isinstance(action.action, np.ndarray):
value = [float(action.action[0]), float(action.action[1])]
action.action = [float(action.action[0]), float(action.action[1])]
else:
value = action.action
message_dict = {
"player_name": action.player,
"act_type": action.act_type,
"value": value,
"duration": action.duration,
}
_ = self.websocket_communicate(message_dict)
def websocket_communicate(self, message_dict: dict | str):
self.websocket.send(json.dumps(message_dict))
answer = self.websocket.recv()
try:
answer = json.loads(answer)
except json.decoder.JSONDecodeError:
answer = None
return answer
action.action = action.action
ret = self.websockets[action.player].send(
{
"type": "action",
"action": action,
"player_hash": self.player_info[action.player]["player_hash"],
}
)
print(ret)
def request_state(self):
state_dict = self.websocket_communicate("get_state")
state_dict = self.websockets[self.state_player_id].send(
{
"type": "get_state",
"player_hash": self.player_info[self.state_player_id]["player_hash"],
}
)
# self.websocket.send(json.dumps("get_state"))
# state_dict = json.loads(self.websocket.recv())
return state_dict
return json.loads(state_dict)
def start_pygame(self):
"""Starts pygame and the gui loop. Each frame the game state is visualized and keyboard inputs are read."""
......@@ -907,97 +960,91 @@ class PyGameGUI:
self.init_ui_elements()
self.manage_button_visibility()
with connect(self.websocket_url) as websocket:
self.websocket = websocket
# Game loop
self.running = True
while self.running:
try:
time_delta = clock.tick(self.FPS) / 1000.0
for event in pygame.event.get():
if event.type == pygame.QUIT:
self.running = False
# UI Buttons:
if event.type == pygame_gui.UI_BUTTON_PRESSED:
match event.ui_element:
case self.start_button:
self.start_button_press()
case self.back_button:
self.start_button_press()
case self.finished_button:
self.finished_button_press()
case self.quit_button:
self.quit_button_press()
case self.reset_button:
self.reset_button_press()
self.start_button_press()
# Game loop
self.running = True
while self.running:
try:
time_delta = clock.tick(self.FPS) / 1000.0
for event in pygame.event.get():
if event.type == pygame.QUIT:
self.running = False
# UI Buttons:
if event.type == pygame_gui.UI_BUTTON_PRESSED:
match event.ui_element:
case self.start_button:
self.start_button_press()
case self.back_button:
self.start_button_press()
case self.finished_button:
self.finished_button_press()
case self.quit_button:
self.quit_button_press()
case self.reset_button:
self.reset_button_press()
self.start_button_press()
self.manage_button_visibility()
self.manage_button_visibility()
if (
event.type in [pygame.KEYDOWN, pygame.KEYUP]
and self.menu_state == MenuStates.Game
):
pass
self.handle_key_event(event)
if (
event.type in [pygame.KEYDOWN, pygame.KEYUP]
and self.menu_state == MenuStates.Game
):
pass
self.handle_key_event(event)
self.manager.process_events(event)
self.manager.process_events(event)
# drawing:
# drawing:
# state = self.simulator.get_state()
# state = self.simulator.get_state()
self.main_window.fill(
colors[
self.visualization_config["GameWindow"]["background_color"]
]
)
self.manager.draw_ui(self.main_window)
self.main_window.fill(
colors[self.visualization_config["GameWindow"]["background_color"]]
)
self.manager.draw_ui(self.main_window)
match self.menu_state:
case MenuStates.Start:
pass
match self.menu_state:
case MenuStates.Start:
pass
case MenuStates.Game:
state = self.request_state()
case MenuStates.Game:
state = self.request_state()
self.draw_background()
self.draw_background()
self.handle_keys()
self.handle_keys()
# state = self.simulator.get_state()
self.draw(state)
# state = self.simulator.get_state()
self.draw(state)
if state["ended"]:
self.finished_button_press()
self.manage_button_visibility()
else:
self.draw(state)
if state["ended"]:
self.finished_button_press()
self.manage_button_visibility()
else:
self.draw(state)
game_screen_rect = self.game_screen.get_rect()
game_screen_rect.center = [
self.window_width // 2,
self.window_height // 2,
]
game_screen_rect = self.game_screen.get_rect()
game_screen_rect.center = [
self.window_width // 2,
self.window_height // 2,
]
self.main_window.blit(
self.game_screen, game_screen_rect
)
self.main_window.blit(self.game_screen, game_screen_rect)
case MenuStates.End:
self.update_conclusion_label(state)
case MenuStates.End:
self.update_conclusion_label(state)
self.manager.update(time_delta)
pygame.display.flip()
self.manager.update(time_delta)
pygame.display.flip()
except KeyboardInterrupt:
pygame.quit()
sys.exit()
except KeyboardInterrupt:
pygame.quit()
sys.exit()
pygame.quit()
sys.exit()
pygame.quit()
sys.exit()
def main():
......
......@@ -20,6 +20,7 @@ requirements = [
"fastapi",
"uvicorn",
"websockets",
"requests",
]
test_requirements = [
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment