Skip to content
Snippets Groups Projects
Commit 8d82c086 authored by Florian Schröder's avatar Florian Schröder
Browse files

Add manager ID argument to Overcooked simulator

An optional argument for the list of manager IDs was added across the Overcooked simulator. These IDs are now used for authorization purposes when creating environments. Unauthorized attempts now return HTTP error 403. This helps enhance security and control over environment management in the simulation.
parent eff27363
No related branches found
No related tags found
No related merge requests found
Pipeline #44867 passed
......@@ -5,19 +5,20 @@ from multiprocessing import Process
from overcooked_simulator.utils import (
url_and_port_arguments,
disable_websocket_logging_arguments,
add_list_of_manager_ids_arguments,
)
def start_game_server(cli_args):
from overcooked_simulator.game_server import main
main(cli_args.url, cli_args.port)
main(cli_args.url, cli_args.port, cli_args.manager_ids)
def start_pygame_gui(cli_args):
from overcooked_simulator.gui_2d_vis.overcooked_gui import main
main(cli_args.url, cli_args.port)
main(cli_args.url, cli_args.port, cli_args.manager_ids)
def main(cli_args=None):
......@@ -56,6 +57,7 @@ if __name__ == "__main__":
url_and_port_arguments(parser)
disable_websocket_logging_arguments(parser)
add_list_of_manager_ids_arguments(parser)
args = parser.parse_args()
print(args)
......
......@@ -35,7 +35,12 @@ from overcooked_simulator.server_results import (
PlayerInfo,
PlayerRequestResult,
)
from overcooked_simulator.utils import setup_logging, url_and_port_arguments
from overcooked_simulator.utils import (
setup_logging,
url_and_port_arguments,
add_list_of_manager_ids_arguments,
disable_websocket_logging_arguments,
)
log = logging.getLogger(__name__)
......@@ -97,10 +102,11 @@ class EnvironmentHandler:
"""The preferred sleep time between environment steps in nanoseconds based on the `env_step_frequency`."""
self.client_ids_to_player_hashes = {}
"""A dictionary mapping client IDs to player hashes."""
self.allowed_manager: list[str] = []
def create_env(
self, environment_config: CreateEnvironmentConfig
) -> CreateEnvResult:
) -> CreateEnvResult | int:
"""Create a new environment.
Args:
......@@ -110,6 +116,8 @@ class EnvironmentHandler:
A dictionary containing the created environment ID and player information.
"""
if environment_config.manager_id not in self.allowed_manager:
return 1
env_id = uuid.uuid4().hex
env = Environment(
......@@ -482,6 +490,9 @@ class EnvironmentHandler:
return True
return False
def extend_allowed_manager(self, manager: list[str]):
self.allowed_manager.extend(manager)
class PlayerConnectionManager:
"""
......@@ -540,7 +551,7 @@ class PlayerConnectionManager:
await connection.send_text(message)
manager = PlayerConnectionManager()
connection_manager = PlayerConnectionManager()
environment_handler: EnvironmentHandler = EnvironmentHandler()
......@@ -656,6 +667,8 @@ class AdditionalPlayer(BaseModel):
@app.post("/manage/create_env/")
async def create_env(creation: CreateEnvironmentConfig) -> CreateEnvResult:
result = environment_handler.create_env(creation)
if result == 1:
raise HTTPException(status_code=403, detail="Manager ID not known/registered.")
return result
......@@ -694,7 +707,7 @@ async def websocket_player_endpoint(websocket: WebSocket, client_id: str):
if not environment_handler.is_known_client_id(client_id):
log.warning(f"wrong websocket connection with {client_id=}")
return
await manager.connect_player(websocket, client_id)
await connection_manager.connect_player(websocket, client_id)
log.debug(f"Client #{client_id} connected")
environment_handler.set_player_connected(client_id)
try:
......@@ -703,17 +716,18 @@ async def websocket_player_endpoint(websocket: WebSocket, client_id: str):
answer = manage_websocket_message(message, client_id)
if isinstance(answer, dict):
answer = json.dumps(answer)
await manager.send_personal_message(answer, websocket)
await connection_manager.send_personal_message(answer, websocket)
except WebSocketDisconnect:
manager.disconnect(client_id)
connection_manager.disconnect(client_id)
environment_handler.set_player_disconnected(client_id)
log.debug(f"Client #{client_id} disconnected")
def main(host: str, port: int):
def main(host: str, port: int, manager_ids: list[str]):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
environment_handler.extend_allowed_manager(manager_ids)
loop.create_task(environment_handler.environment_steps())
config = uvicorn.Config(app, host=host, port=port, loop=loop)
server = uvicorn.Server(config)
......@@ -728,9 +742,11 @@ if __name__ == "__main__":
)
url_and_port_arguments(parser)
disable_websocket_logging_arguments(parser)
add_list_of_manager_ids_arguments(parser)
args = parser.parse_args()
setup_logging(args.enable_websocket_logging)
main(args.url, args.port)
main(args.url, args.port, args.manager_ids)
"""
Or in console:
uvicorn overcooked_simulator.fastapi_game_server:app --reload
......
......@@ -2,6 +2,7 @@ import argparse
import dataclasses
import json
import logging
import random
import sys
from enum import Enum
......@@ -26,6 +27,7 @@ from overcooked_simulator.utils import (
setup_logging,
url_and_port_arguments,
disable_websocket_logging_arguments,
add_list_of_manager_ids_arguments,
)
......@@ -35,9 +37,6 @@ class MenuStates(Enum):
End = "End"
MANAGER_ID = "1233245425"
log = logging.getLogger(__name__)
......@@ -73,7 +72,8 @@ class PyGameGUI:
player_names: list[str | int],
player_keys: list[pygame.key],
url: str,
port: str,
port: int,
manager_ids: list[str],
):
self.game_screen = None
self.FPS = 60
......@@ -93,6 +93,7 @@ class PyGameGUI:
self.websockets = {}
self.request_url = f"http://{url}:{port}"
self.manager_id = random.choice(manager_ids)
# TODO cache loaded images?
with open(ROOT_DIR / "gui_2d_vis" / "visualization.yaml", "r") as file:
......@@ -444,7 +445,7 @@ class PyGameGUI:
with open(environment_config_path, "r") as file:
environment_config = file.read()
creation_json = CreateEnvironmentConfig(
manager_id=MANAGER_ID,
manager_id=self.manager_id,
number_players=2,
environment_settings={"all_player_can_pause_game": False},
item_info_config=item_info,
......@@ -507,7 +508,7 @@ class PyGameGUI:
requests.post(
f"{self.request_url}/manage/stop_env",
json={
"manager_id": MANAGER_ID,
"manager_id": self.manager_id,
"env_id": self.current_env_id,
"reason": "reset button pressed",
},
......@@ -520,7 +521,7 @@ class PyGameGUI:
requests.post(
f"{self.request_url}/manage/stop_env",
json={
"manager_id": MANAGER_ID,
"manager_id": self.manager_id,
"env_id": self.current_env_id,
"reason": "finish button pressed",
},
......@@ -670,7 +671,7 @@ class PyGameGUI:
sys.exit()
def main(url, port):
def main(url: str, port: int, manager_ids: list[str]):
# TODO maybe read the player names and keyboard keys from config file?
keys1 = [
pygame.K_LEFT,
......@@ -684,7 +685,11 @@ def main(url, port):
number_players = 2
gui = PyGameGUI(
list(map(str, range(number_players))), [keys1, keys2], url=url, port=port
list(map(str, range(number_players))),
[keys1, keys2],
url=url,
port=port,
manager_ids=manager_ids,
)
gui.start_pygame()
......@@ -698,6 +703,7 @@ if __name__ == "__main__":
url_and_port_arguments(parser)
disable_websocket_logging_arguments(parser)
add_list_of_manager_ids_arguments(parser)
args = parser.parse_args()
setup_logging(enable_websocket_logging=args.enable_websocket_logging)
main(args.url, args.port)
main(args.url, args.port, args.manager_ids)
import logging
import os
import sys
import uuid
from datetime import datetime
from enum import Enum
......@@ -66,3 +67,14 @@ def url_and_port_arguments(parser):
def disable_websocket_logging_arguments(parser):
parser.add_argument("--enable-websocket-logging", action="store_true", default=True)
def add_list_of_manager_ids_arguments(parser):
parser.add_argument(
"-m",
"--manager_ids",
nargs="+",
type=list[str],
default=[uuid.uuid4().hex],
help="List of manager IDs that can create environments.",
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment