Skip to content
Snippets Groups Projects
  • Florian Schröder's avatar
    794a4414
    Update argument names in function calls · 794a4414
    Florian Schröder authored
    The argument names in the main function calls within study_server.py and game_server.py were updated for clarity. These changes better reflect their function within these scripts, making the code easier to understand and maintain. The modifications primarily involve changing 'url' and 'port' to 'game_url' and 'game_port', and 'port' to 'study_port'.
    794a4414
    History
    Update argument names in function calls
    Florian Schröder authored
    The argument names in the main function calls within study_server.py and game_server.py were updated for clarity. These changes better reflect their function within these scripts, making the code easier to understand and maintain. The modifications primarily involve changing 'url' and 'port' to 'game_url' and 'game_port', and 'port' to 'study_port'.
study_server.py 15.32 KiB
"""
# Usage
- Set `CONNECT_WITH_STUDY_SERVER` in gui.py to True.
- Run this script and copy the manager id that it prints.
- Run the game_server.py script with the manager id copied from the terminal:
```
python game_server.py --manager_ids COPIED_UUID
```
- Run 2 gui.py scripts in different terminals. For more players, change `NUMBER_PLAYER_PER_ENV` (and `num_players`
  in the study config, which this server uses for matchmaking) and start more GUIs.

The environment starts once all players have connected.
"""

import argparse
import asyncio
import logging
import os
import random
import signal
import subprocess
from pathlib import Path
from subprocess import Popen
from typing import Tuple, TypedDict

import requests
import uvicorn
import yaml
from fastapi import FastAPI

from cooperative_cuisine import ROOT_DIR
from cooperative_cuisine.environment import EnvironmentConfig
from cooperative_cuisine.game_server import CreateEnvironmentConfig
from cooperative_cuisine.server_results import PlayerInfo
from cooperative_cuisine.utils import (
    url_and_port_arguments,
    add_list_of_manager_ids_arguments,
    expand_path,
    add_study_arguments,
)

# NOTE(review): not referenced anywhere in this file; the number of human players per
# study environment is read from the study config's `num_players` instead.
NUMBER_PLAYER_PER_ENV = 2

# Module logger (currently unused; the code in this file prints instead).
log = logging.getLogger(__name__)

# FastAPI application that the route handlers below are registered on.
app = FastAPI()


# Intended switch between aaambos agents and the bundled random agent for bots;
# StudyState.use_aaambos_agent is the flag that create_and_connect_bot / kill_bots
# actually check.
USE_AAAMBOS_AGENT = False


# One level of a study: its display name plus the config, layout and item-info file
# paths (expanded with `expand_path` before they are read).
LevelConfig = TypedDict(
    "LevelConfig",
    {
        "name": str,
        "config_path": str,
        "layout_path": str,
        "item_info_path": str,
    },
)


# Level details handed to a participant: level name, whether it is the final level,
# the allowed recipes (["all"] or an explicit list) and the environment's recipe graphs.
LevelInfo = TypedDict(
    "LevelInfo",
    {
        "name": str,
        "last_level": bool,
        "recipes": list[str],
        "recipe_graphs": list[dict],
    },
)


# Parsed study config YAML: the ordered levels to play, the number of human player
# slots and the number of additional bot slots per environment.
StudyConfig = TypedDict(
    "StudyConfig",
    {
        "levels": list[LevelConfig],
        "num_players": int,
        "num_bots": int,
    },
)


class StudyState:
    def __init__(self, study_config_path: str | Path, game_url, game_port):
        with open(study_config_path, "r") as file:
            env_config_f = file.read()

        self.study_config: StudyConfig = yaml.load(
            str(env_config_f), Loader=yaml.SafeLoader
        )
        self.levels: list[LevelConfig] = self.study_config["levels"]
        self.current_level_idx: int = 0

        self.participant_id_to_player_info = {}
        self.player_ids = {}
        self.num_connected_players: int = 0

        self.current_running_env = None
        self.next_level_env = None
        self.players_done = {}

        self.use_aaambos_agent = False

        self.websocket_url = f"ws://{game_url}:{game_port}/ws/player/"
        print("WS:", self.websocket_url)
        self.sub_processes = []

        self.current_item_info = None
        self.current_config = None

    @property
    def study_done(self):
        return self.current_level_idx >= len(self.levels)

    @property
    def last_level(self):
        return self.current_level_idx >= len(self.levels) - 1

    @property
    def is_full(self):
        return (
            len(self.participant_id_to_player_info) == self.study_config["num_players"]
        )

    def can_add_participant(self, num_participants: int) -> bool:
        filled = (
            self.num_connected_players + num_participants
            <= self.study_config["num_players"]
        )
        return filled and not self.is_full

    def create_env(self, level):
        item_info_path = expand_path(level["item_info_path"])
        layout_path = expand_path(level["layout_path"])
        config_path = expand_path(level["config_path"])

        with open(item_info_path, "r") as file:
            item_info = file.read()
            self.current_item_info: EnvironmentConfig = yaml.load(
                item_info, Loader=yaml.Loader
            )
        with open(layout_path, "r") as file:
            layout = file.read()
        with open(config_path, "r") as file:
            environment_config = file.read()
            self.current_config: EnvironmentConfig = yaml.load(
                environment_config, Loader=yaml.Loader
            )
        seed = int(random.random() * 1000000)
        print(seed)
        creation_json = CreateEnvironmentConfig(
            manager_id=study_manager.server_manager_id,
            number_players=self.study_config["num_players"]
            + self.study_config["num_bots"],
            environment_settings={"all_player_can_pause_game": False},
            item_info_config=item_info,
            environment_config=environment_config,
            layout_config=layout,
            seed=seed,
        ).model_dump(mode="json")

        env_info = requests.post(
            study_manager.game_server_url + "/manage/create_env/", json=creation_json
        )

        if env_info.status_code == 403:
            raise ValueError(f"Forbidden Request: {env_info.json()['detail']}")
        env_info = env_info.json()

        player_info = env_info["player_info"]
        for idx, (player_id, player_info) in enumerate(player_info.items()):
            if idx >= self.study_config["num_players"]:
                self.create_and_connect_bot(player_id, player_info)
        return env_info

    def start_level(self):
        level = self.levels[self.current_level_idx]
        self.current_running_env = self.create_env(level)

    def next_level(self):
        requests.post(
            f"{study_manager.game_server_url}/manage/stop_env/",
            json={
                "manager_id": study_manager.server_manager_id,
                "env_id": self.current_running_env["env_id"],
                "reason": "Next level",
            },
        )

        self.current_level_idx += 1
        if not self.study_done:
            self.start_level()
            for (
                participant_id,
                player_info,
            ) in self.participant_id_to_player_info.items():
                new_player_info = {
                    player_name: self.current_running_env["player_info"][player_name]
                    for player_name in player_info.keys()
                }
                self.participant_id_to_player_info[participant_id] = new_player_info

            for key in self.players_done:
                self.players_done[key] = False

    def add_participant(self, participant_id: str, number_players: int):
        player_names = [
            str(self.num_connected_players + i) for i in range(number_players)
        ]
        player_info = {
            player_name: self.current_running_env["player_info"][player_name]
            for player_name in player_names
        }
        self.participant_id_to_player_info[participant_id] = player_info
        self.num_connected_players += number_players
        return player_info

    def player_finished_level(self, participant_id):
        self.players_done[participant_id] = True
        if all(self.players_done.values()):
            self.next_level()

    def get_connection(self, participant_id: str):
        player_info = self.participant_id_to_player_info[participant_id]
        current_level = self.levels[self.current_level_idx]
        if self.current_config["meals"]["all"]:
            recipes = ["all"]
        else:
            recipes = self.current_config["meals"]["list"]
        level_info = LevelInfo(
            name=current_level["name"],
            last_level=self.last_level,
            recipes=recipes,
            recipe_graphs=self.current_running_env["recipe_graphs"],
        )
        return player_info, level_info

    def create_and_connect_bot(self, player_id, player_info):
        player_hash = player_info["player_hash"]
        print(
            f'--general_plus="agent_websocket:{self.websocket_url + player_info["client_id"]};player_hash:{player_hash};agent_id:{player_id}"'
        )
        if self.use_aaambos_agent:
            sub = Popen(
                " ".join(
                    [
                        "exec",
                        "aaambos",
                        "run",
                        "--arch_config",
                        str(ROOT_DIR / "configs" / "agents" / "arch_config.yml"),
                        "--run_config",
                        str(ROOT_DIR / "configs" / "agents" / "run_config.yml"),
                        f'--general_plus="agent_websocket:{self.websocket_url + player_info["client_id"]};player_hash:{player_hash};agent_id:{player_id}"',
                        f"--instance={player_hash}",
                    ]
                ),
                shell=True,
            )
        else:
            sub = Popen(
                " ".join(
                    [
                        "python",
                        str(ROOT_DIR / "configs" / "agents" / "random_agent.py"),
                        f'--uri {self.websocket_url + player_info["client_id"]}',
                        f"--player_hash {player_hash}",
                        f"--player_id {player_id}",
                    ]
                ),
                shell=True,
            )
        self.sub_processes.append(sub)

    def kill_bots(self):
        for sub in self.sub_processes:
            try:
                if self.use_aaambos_agent:
                    pgrp = os.getpgid(sub.pid)
                    os.killpg(pgrp, signal.SIGINT)
                    subprocess.run(
                        "kill $(ps aux | grep 'aaambos' | awk '{print $2}')", shell=True
                    )
                else:
                    sub.kill()

            except ProcessLookupError:
                pass

        self.sub_processes = []

    def __repr__(self):
        return f"Study({self.current_running_env['env_id']})"


class StudyManager:
    """Process-wide matchmaker: assigns participants to studies and routes their calls.

    Also holds the game-server address and manager id that studies use when they
    create or stop environments, and the tutorial environments of participants.
    """

    def __init__(self):
        self.game_host: str | None = None
        self.game_port: str | None = None
        self.game_server_url: str | None = None
        self.server_manager_id: str | None = None

        self.running_studies: list[StudyState] = []

        # participant id -> the (single) study the participant was placed in.
        self.participant_id_to_study_map: dict[str, StudyState] = {}
        self.running_envs: dict[str, Tuple[int, dict[str, PlayerInfo], list[str]]] = {}
        self.current_free_envs = []

        # participant id -> create_env response of that participant's tutorial env.
        self.running_tutorials: dict[str, dict] = {}

        self.study_config_path = ROOT_DIR / "configs" / "study" / "study_config.yml"

    def create_study(self):
        """Start a new study from the configured study config and begin its first level."""
        study = StudyState(
            self.study_config_path,
            self.game_host,
            self.game_port,
        )
        study.start_level()
        self.running_studies.append(study)

    def add_participant(self, participant_id: str, number_players: int):
        """Place a participant bringing ``number_players`` players into one study.

        The participant joins the first running study with enough room; a new study
        is created when none has room.

        Returns:
            The participant's player info, or ``None`` if even a fresh study cannot
            take ``number_players`` players.
        """
        if not any(
            study.can_add_participant(number_players) for study in self.running_studies
        ):
            self.create_study()

        for study in self.running_studies:
            if study.can_add_participant(number_players):
                player_info = study.add_participant(participant_id, number_players)
                self.participant_id_to_study_map[participant_id] = study
                # Stop at the first match: a participant belongs to exactly one study.
                return player_info
        return None

    def player_finished_level(self, participant_id: str):
        """Forward a participant's "level finished" signal to their study."""
        assigned_study = self.participant_id_to_study_map[participant_id]
        assigned_study.player_finished_level(participant_id)

    def get_participant_game_connection(self, participant_id: str):
        """Return ``(player_info, level_info)`` for the participant's current level."""
        assigned_study = self.participant_id_to_study_map[participant_id]
        return assigned_study.get_connection(participant_id)

    def set_game_server_url(self, game_host, game_port):
        """Remember the game server's host/port and derive its HTTP base URL."""
        self.game_host = game_host
        self.game_port = game_port
        self.game_server_url = f"http://{self.game_host}:{self.game_port}"

    def set_manager_id(self, manager_id: str):
        """Set the manager id sent with every create/stop request to the game server."""
        self.server_manager_id = manager_id

    def set_study_config(self, study_config_path: str):
        """Set the study config file used for newly created studies."""
        # TODO validate study_config?
        self.study_config_path = study_config_path


# Single module-level manager shared by all route handlers below; StudyState also
# reads its game-server URL and manager id when creating/stopping environments.
study_manager = StudyManager()


@app.post("/start_study/{participant_id}/{number_players}")
async def start_study(participant_id: str, number_players: int):
    """Place the participant in a study; respond with their player info (or null)."""
    return study_manager.add_participant(participant_id, number_players)


@app.post("/level_done/{participant_id}")
async def level_done(participant_id: str):
    """Record that the participant finished the current level.

    The participant's study decides whether to advance to the next level
    (``StudyState.player_finished_level``). The response body is empty (null).
    """
    # The manager's method returns nothing, so there is no result to keep.
    study_manager.player_finished_level(participant_id)


@app.post("/get_game_connection/{participant_id}")
async def get_game_connection(participant_id: str):
    """Return the participant's player info and the info of their current level."""
    connection = study_manager.get_participant_game_connection(participant_id)
    player_info, level_info = connection
    return dict(player_info=player_info, level_info=level_info)


@app.post("/connect_to_tutorial/{participant_id}")
async def want_to_play_tutorial(participant_id: str):
    """Create a single-player tutorial environment for a participant.

    Reads the tutorial environment config, layout and item info under
    ``ROOT_DIR / "configs"``, asks the game server to create an environment with a
    fixed seed and stores the response so the tutorial can be stopped later.

    Returns:
        The player info of the tutorial's only player (``"0"``).

    Raises:
        ValueError: If the game server answers 403.
        requests.HTTPError: For any other non-success status.
    """
    environment_config_path = ROOT_DIR / "configs" / "tutorial_env_config.yaml"
    layout_path = ROOT_DIR / "configs" / "layouts" / "tutorial.layout"
    item_info_path = ROOT_DIR / "configs" / "item_info.yaml"

    with open(item_info_path, "r") as file:
        item_info = file.read()
    with open(layout_path, "r") as file:
        layout = file.read()
    with open(environment_config_path, "r") as file:
        environment_config = file.read()

    print("STUDY MANAGER ID", study_manager.server_manager_id)
    creation_json = CreateEnvironmentConfig(
        manager_id=study_manager.server_manager_id,
        number_players=1,
        environment_settings={"all_player_can_pause_game": False},
        item_info_config=item_info,
        environment_config=environment_config,
        layout_config=layout,
        seed=1234567890,
    ).model_dump(mode="json")
    # ``requests`` blocks; run it in a worker thread so the event loop (and every
    # other participant's request) is not stalled while the game server responds.
    response = await asyncio.to_thread(
        requests.post,
        study_manager.game_server_url + "/manage/create_env/",
        json=creation_json,
    )

    if response.status_code == 403:
        raise ValueError(f"Forbidden Request: {response.json()['detail']}")
    # Fail here on other errors instead of later on a missing "player_info" key.
    response.raise_for_status()
    env_info = response.json()
    study_manager.running_tutorials[participant_id] = env_info
    return env_info["player_info"]["0"]


@app.post("/disconnect_from_tutorial/{participant_id}")
async def disconnect_from_tutorial(participant_id: str):
    """Stop the participant's tutorial environment on the game server.

    Renamed from a second ``want_to_play_tutorial`` that shadowed the connect handler
    of the same name; the route path is unchanged.

    Raises:
        KeyError: If no tutorial was started for ``participant_id``.
    """
    env_id = study_manager.running_tutorials[participant_id]["env_id"]
    # Run the blocking HTTP call off the event loop.
    await asyncio.to_thread(
        requests.post,
        f"{study_manager.game_server_url}/manage/stop_env/",
        json={
            "manager_id": study_manager.server_manager_id,
            "env_id": env_id,
            "reason": "Finished tutorial",
        },
    )
    # Forget the stopped tutorial so running_tutorials does not grow forever.
    study_manager.running_tutorials.pop(participant_id, None)


def main(study_host, study_port, game_host, game_port, manager_ids, study_config_path):
    """Configure the global study manager and serve the study API with uvicorn.

    Args:
        study_host: Host/interface the study server binds to.
        study_port: Port the study server listens on.
        game_host: Host of the game server.
        game_port: Port of the game server.
        manager_ids: Manager ids for the game server; only the first one is used.
        study_config_path: Path to the study config YAML.
    """
    study_manager.set_game_server_url(game_host=game_host, game_port=game_port)
    study_manager.set_manager_id(manager_id=manager_ids[0])
    study_manager.set_study_config(study_config_path=study_config_path)

    print(
        f"Use {study_manager.server_manager_id=} for game_server_url=http://{game_host}:{game_port}"
    )
    # uvicorn's ``loop`` setting names a loop implementation ("auto", "asyncio",
    # "uvloop"), it does not take a loop object; the server is simply run to
    # completion on our own loop, which is closed afterwards.
    loop = asyncio.new_event_loop()
    config = uvicorn.Config(app, host=study_host, port=study_port)
    server = uvicorn.Server(config)
    try:
        loop.run_until_complete(server.serve())
    finally:
        loop.close()


if __name__ == "__main__":
    # Command-line entry point: parse host/port, manager-id and study arguments,
    # then start the study server.
    parser = argparse.ArgumentParser(
        prog="Cooperative Cuisine Study Server",
        description="Study Server: Match Making, client pre and post managing.",
        epilog="For further information, see https://scs.pages.ub.uni-bielefeld.de/cocosy/overcooked-simulator/overcooked_simulator.html",
    )
    url_and_port_arguments(
        parser=parser,
        server_name="Study Server",
        default_study_port=8080,
        default_game_port=8000,
    )
    add_list_of_manager_ids_arguments(parser=parser)
    add_study_arguments(parser=parser)
    args = parser.parse_args()

    # All arguments by keyword; the previous call was missing a comma after
    # ``game_port`` (SyntaxError).
    main(
        study_host=args.study_url,
        study_port=args.study_port,
        game_host=args.game_url,
        game_port=args.game_port,
        manager_ids=args.manager_ids,
        study_config_path=args.study_config,
    )