diff --git a/README.md b/README.md index ce6b0966e29b243df4b6e08c86616f5dd1d80706..41178079ab45f382edc3ba1c8cb9b84af288801e 100644 --- a/README.md +++ b/README.md @@ -46,9 +46,9 @@ You can also start the **Game Server**m **Study Server** (Matchmaking),and the * terminals. ```bash -python3 cooperative_cuisine/game_server.py -g localhost -gp 8000 --manager_ids SECRETKEY1 SECRETKEY2 +python3 cooperative_cuisine/game_server.py -g localhost -gp 8000 --manager-ids SECRETKEY1 SECRETKEY2 -python3 cooperative_cuisine/study_server.py -s localhost -sp 8080 --manager_ids SECRETKEY1 +python3 cooperative_cuisine/study_server.py -s localhost -sp 8080 --manager-ids SECRETKEY1 python3 cooperative_cuisine/pygame_2d_vis/gui.py -s localhost -sp 8080 -g localhost -gp 8000 ``` diff --git a/cooperative_cuisine/__init__.py b/cooperative_cuisine/__init__.py index a8f49b412dcd526542bb9d6b740753b34e8e772a..a0577c01a74d76250c0152a73dcd51e1ab5093a9 100644 --- a/cooperative_cuisine/__init__.py +++ b/cooperative_cuisine/__init__.py @@ -49,9 +49,9 @@ cooperative_cuisine -s localhost -sp 8080 -g localhost -gp 8000 You can also start the **Game Server**, **Study Server** (Matchmaking),and the **PyGame GUI** individually in different terminals. ```bash -python3 cooperative_cuisine/game_server.py -g localhost -gp 8000 --manager_ids SECRETKEY1 SECRETKEY2 +python3 cooperative_cuisine/game_server.py -g localhost -gp 8000 --manager-ids SECRETKEY1 SECRETKEY2 -python3 cooperative_cuisine/study_server.py -s localhost -sp 8080 --manager_ids SECRETKEY1 +python3 cooperative_cuisine/study_server.py -s localhost -sp 8080 --manager-ids SECRETKEY1 python3 cooperative_cuisine/pygame_2d_vis/gui.py -s localhost -sp 8080 -g localhost -gp 8000 ``` diff --git a/cooperative_cuisine/configs/environment_config.yaml b/cooperative_cuisine/configs/environment_config.yaml index 608eb9728bb07f5e49e8d9812864dd5a06c00cd1..a6c7c752d93e7e593c1a913ce32291f7495ac0fe 100644 --- a/cooperative_cuisine/configs/environment_config.yaml +++ b/cooperative_cuisine/configs/environment_config.yaml @@ -202,7 +202,13 @@ extra_setup_functions: log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl add_hook_ref: true - + empty_info_msg: + func: !!python/name:cooperative_cuisine.hooks.hooks_via_callback_class '' + kwargs: + hooks: [ action_put ] + callback_class: !!python/name:cooperative_cuisine.info_msg.InfoMsgManager '' + callback_class_kwargs: + msg: "" # info_msg: # func: !!python/name:cooperative_cuisine.hooks.hooks_via_callback_class '' # kwargs: diff --git a/cooperative_cuisine/game_server.py b/cooperative_cuisine/game_server.py index d27db578bf1c98ddfc7fc4d9d88ddd56a58e178c..9f4271d6f083a6c9170aa9a44386552be16b5ddf 100644 --- a/cooperative_cuisine/game_server.py +++ b/cooperative_cuisine/game_server.py @@ -159,6 +159,11 @@ class EnvironmentHandler: return 1 env_id = uuid.uuid4().hex + if environment_config.number_players < 1: + raise HTTPException( + status_code=409, detail="Number players need to be positive." + ) + env = Environment( env_config=environment_config.environment_config, layout_config=environment_config.layout_config, @@ -322,10 +327,9 @@ class EnvironmentHandler: if ( manager_id in self.manager_envs and env_id in self.manager_envs[manager_id] - and self.envs[env_id].status - not in [EnvironmentStatus.STOPPED, EnvironmentStatus.PAUSED] + and self.envs[env_id].status == EnvironmentStatus.PAUSED ): - self.envs[env_id].status = EnvironmentStatus.PAUSED + self.envs[env_id].status = EnvironmentStatus.RUNNING self.envs[env_id].last_step_time = time.time_ns() def stop_env(self, manager_id: str, env_id: str, reason: str) -> int: @@ -676,7 +680,9 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul "player_hash": ws_message.player_hash, } case PlayerRequestType.ACTION: - assert ws_message.action is not None + assert ( + ws_message.action is not None + ), "websocket msg type action needs field action filled" if isinstance(ws_message.action.action_data, list): ws_message.action.action_data = np.array( ws_message.action.action_data, dtype=float @@ -714,7 +720,7 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul @app.get("/") def read_root(): - return {"OVER": "COOKED"} + return {"Cooperative": "Cuisine"} class CreateEnvironmentConfig(BaseModel): diff --git a/cooperative_cuisine/study_server.py b/cooperative_cuisine/study_server.py index ba95e28e4b783db31f7e97ecefe092a24ca5506f..2f794f0e94dc0e03476aff7fd1e5ef158a8e7bdb 100644 --- a/cooperative_cuisine/study_server.py +++ b/cooperative_cuisine/study_server.py @@ -4,7 +4,7 @@ - Run this script. Copy the manager id that is printed - Run the game_server.py script with the manager id copied from the terminal ``` -python game_server.py --manager_ids COPIED_UUID +python game_server.py --manager-ids COPIED_UUID ``` - Run 2 gui.py scripts in different terminals. For more players change `NUMBER_PLAYER_PER_ENV` and start more guis. @@ -50,6 +50,10 @@ USE_AAAMBOS_AGENT = False """Use the aaambos random agents instead of the simpler python script agents.""" +def request_game_server(game_server: str, json_data: dict): + return requests.post(game_server, json=json_data) + + class LevelConfig(BaseModel): """Configuration of a level in the study.""" @@ -190,8 +194,9 @@ class Study: seed=seed, ).model_dump(mode="json") - env_info = requests.post( - study_manager.game_server_url + "/manage/create_env/", json=creation_json + env_info = request_game_server( + study_manager.game_server_url + "/manage/create_env/", + json_data=creation_json, ) if env_info.status_code == 403: @@ -217,9 +222,9 @@ class Study: """Stops the last environment, starts the next one and remaps the participants to the new player infos. """ - requests.post( + request_game_server( f"{study_manager.game_server_url}/manage/stop_env/", - json={ + json_data={ "manager_id": study_manager.server_manager_id, "env_id": self.current_running_env["env_id"], "reason": "Next level", @@ -364,13 +369,15 @@ class StudyManager: def __init__(self): """Constructor of the StudyManager class.""" - self.game_host: str + self.game_host: str = "localhost" """Host address of the game server where the studies are running their environments.""" - self.game_port: int + self.game_port: int = 8000 """Port of the game server where the studies are running their environments.""" - self.game_server_url: str + self.game_server_url: str = "" """Combined URL of the game server where the studies are running their environments.""" - self.server_manager_id: str + self.create_game_server_url() + + self.server_manager_id: str = "" """Manager id of this manager which will be registered in the game server.""" self.running_studies: list[Study] = [] """List of currently running studies.""" @@ -381,7 +388,7 @@ class StudyManager: self.running_tutorials: dict[str, CreateEnvResult] = {} """Dict which saves currently running tutorial envs, as these do not need advanced player management.""" - self.study_config_path = ROOT_DIR / "configs" / "study" / "study_config.yml" + self.study_config_path = ROOT_DIR / "configs" / "study" / "study_config.yaml" """Path to the configuration file for the studies.""" def create_study(self): @@ -470,6 +477,9 @@ class StudyManager: """ self.game_host = game_host self.game_port = game_port + self.create_game_server_url() + + def create_game_server_url(self): self.game_server_url = f"http://{self.game_host}:{self.game_port}" def set_manager_id(self, manager_id: str): @@ -513,8 +523,9 @@ class StudyManager: seed=1234567890, ).model_dump(mode="json") # todo async - env_info = requests.post( - study_manager.game_server_url + "/manage/create_env/", json=creation_json + env_info = request_game_server( + study_manager.game_server_url + "/manage/create_env/", + json_data=creation_json, ) match env_info.status_code: case 200: @@ -534,9 +545,9 @@ class StudyManager: def end_tutorial(self, participant_id: str): env = study_manager.running_tutorials[participant_id] - answer = requests.post( + answer = request_game_server( f"{study_manager.game_server_url}/manage/stop_env/", - json={ + json_data={ "manager_id": study_manager.server_manager_id, "env_id": env["env_id"], "reason": "Finished tutorial", diff --git a/cooperative_cuisine/utils.py b/cooperative_cuisine/utils.py index 7cb4e28eab9d1929a382142c5ed3c28769ec83dd..fd1effc447897b7b5475e546f30e899b277e5022 100644 --- a/cooperative_cuisine/utils.py +++ b/cooperative_cuisine/utils.py @@ -364,7 +364,7 @@ def add_list_of_manager_ids_arguments(parser): """ parser.add_argument( "-m", - "--manager_ids", + "--manager-ids", nargs="+", type=str, default=[uuid.uuid4().hex], @@ -439,17 +439,23 @@ class NumpyAndDataclassEncoder(json.JSONEncoder): return json.JSONEncoder.default(self, obj) -def create_layout_with_counters(w, h): +def create_layout_with_counters(w, h) -> str: """Print a layout string that has counters at the world borders. Args: w: The width of the layout. h: The height of the layout. + + Returns: + str of the layout """ + string = "" for y in range(h): for x in range(w): if x == 0 or y == 0 or x == w - 1 or y == h - 1: - print("#", end="") + string += "#" else: - print("_", end="") - print("") + string += "_" + string += "\n" + print(string) + return string diff --git a/setup.py b/setup.py index bd15680e42e22ed732f2ba8f9056aef547fb163c..bd89c44bbebaceae88a2c0b1e728eb3b95821e12 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ requirements = [ "networkx", ] -test_requirements = ["pytest>=3", "pytest-cov>=4.1"] +test_requirements = ["pytest>=3", "pytest-cov>=4.1", "httpx"] setup( author="Annika Österdiekhoff, Dominik Battefeld, Fabian Heinrich, Florian Schröder", diff --git a/tests/test_game_server.py b/tests/test_game_server.py new file mode 100644 index 0000000000000000000000000000000000000000..f072fed2dd5413ad016e64406146262c93723e53 --- /dev/null +++ b/tests/test_game_server.py @@ -0,0 +1,284 @@ +import asyncio +import json + +import pytest +from fastapi import status +from fastapi.testclient import TestClient + +from cooperative_cuisine import ROOT_DIR +from cooperative_cuisine.action import ActionType +from cooperative_cuisine.game_server import ( + app, + environment_handler, + CreateEnvironmentConfig, + ManageEnv, +) +from cooperative_cuisine.server_results import CreateEnvResult +from cooperative_cuisine.state_representation import StateRepresentation + +environment_handler.extend_allowed_manager(["123"]) + + +@pytest.fixture +def create_env_config(): + layout_path = ROOT_DIR / "configs" / "layouts" / "tutorial.layout" + environment_config_path = ROOT_DIR / "configs" / "tutorial_env_config.yaml" + item_info_path = ROOT_DIR / "configs" / "item_info.yaml" + with open(item_info_path, "r") as file: + item_info = file.read() + with open(layout_path, "r") as file: + layout = file.read() + with open(environment_config_path, "r") as file: + environment_config = file.read() + + return CreateEnvironmentConfig( + manager_id="123", + number_players=1, + environment_settings={"all_player_can_pause_game": False}, + item_info_config=item_info, + layout_config=layout, + environment_config=environment_config, + seed=123, + ) + + +def test_create_env(create_env_config): + with TestClient(app) as client: + res = client.post( + "/manage/create_env/", + json=create_env_config.model_dump(mode="json"), + ) + + assert res.status_code == status.HTTP_200_OK + create_env_result = CreateEnvResult(**res.json()) + assert len(create_env_result["player_info"]) == 1 + assert isinstance(create_env_result["env_id"], str) + + +def test_invalid_manager_id_create_env(create_env_config): + create_env_config.manager_id = "!" + with TestClient(app) as client: + res = client.post( + "/manage/create_env/", + json=create_env_config.model_dump(mode="json"), + ) + assert res.status_code == status.HTTP_403_FORBIDDEN + assert res.json() == {"detail": "Manager ID not known/registered."} + + +def test_invalid_create_env_config(create_env_config): + create_env_config.number_players = -1 + with TestClient(app) as client: + res = client.post( + "/manage/create_env/", + json=create_env_config.model_dump(mode="json"), + ) + assert res.status_code == status.HTTP_409_CONFLICT + assert res.json() == {"detail": "Number players need to be positive."} + + +def test_stop_env(create_env_config): + with TestClient(app) as client: + res = client.post( + "/manage/create_env/", + json=create_env_config.model_dump(mode="json"), + ) + + with TestClient(app) as client: + res = client.post( + "/manage/stop_env/", + json=ManageEnv( + manager_id="123", env_id=res.json()["env_id"], reason="test" + ).model_dump(mode="json"), + ) + + assert res.status_code == status.HTTP_200_OK + with TestClient(app) as client: + res = client.post( + "/manage/stop_env/", + json=ManageEnv(manager_id="123", env_id="123456", reason="test").model_dump( + mode="json" + ), + ) + + assert res.status_code == status.HTTP_403_FORBIDDEN + + +def test_websocket(create_env_config): + with TestClient(app) as client: + environment_handler.envs = {} + res = client.post( + "/manage/create_env/", + json=create_env_config.model_dump(mode="json"), + ) + player_hash = res.json()["player_info"]["0"]["player_hash"] + loop = asyncio.new_event_loop() + task = loop.create_task(environment_handler.environment_steps()) + try: + with client.websocket_connect( + f"/ws/player/{res.json()['player_info']['0']['client_id']}" + ) as websocket: + assert environment_handler.check_all_players_connected( + res.json()["env_id"] + ) + websocket.send_json({"player_hash": player_hash, "type": "ready"}) + assert websocket.receive_json() == { + "request_type": "ready", + "msg": f"ready accepted", + "status": 200, + "player_hash": player_hash, + } + loop.run_until_complete(asyncio.sleep(0.001)) + websocket.send_json({"player_hash": player_hash, "type": "get_state"}) + state = websocket.receive_json() + assert state["all_players_ready"] + del state["all_players_ready"] + StateRepresentation.model_validate_json(json_data=json.dumps(state)) + + websocket.send_json( + { + "player_hash": player_hash, + "type": "action", + "action": { + "player": "0", + "action_type": ActionType.PICK_UP_DROP.value, + "action_data": None, + }, + } + ) + assert websocket.receive_json() == { + "request_type": "action", + "status": 200, + "msg": f"action accepted", + "player_hash": player_hash, + } + + assert ( + len( + environment_handler.list_not_ready_players(res.json()["env_id"]) + ) + == 0 + ) + assert ( + len( + environment_handler.list_not_connected_players( + res.json()["env_id"] + ) + ) + == 0 + ) + finally: + task.cancel() + loop.close() + + +def test_websocket_wrong_inputs(create_env_config): + with TestClient(app) as client: + environment_handler.envs = {} + res = client.post( + "/manage/create_env/", + json=create_env_config.model_dump(mode="json"), + ) + player_hash = res.json()["player_info"]["0"]["player_hash"] + wrong_player_hash = player_hash + "-------" + loop = asyncio.new_event_loop() + task = loop.create_task(environment_handler.environment_steps()) + assert ( + len(environment_handler.list_not_connected_players(res.json()["env_id"])) + == 1 + ) + try: + with client.websocket_connect( + f"/ws/player/{res.json()['player_info']['0']['client_id']}" + ) as websocket: + assert ( + len( + environment_handler.list_not_ready_players(res.json()["env_id"]) + ) + == 1 + ) + assert ( + len( + environment_handler.list_not_connected_players( + res.json()["env_id"] + ) + ) + == 0 + ) + + websocket.send_json({"player_hash": wrong_player_hash, "type": "ready"}) + assert websocket.receive_json() == { + "request_type": "ready", + "msg": f"ready not accepted", + "status": 400, + "player_hash": wrong_player_hash, + } + loop.run_until_complete(asyncio.sleep(0.001)) + websocket.send_json( + {"player_hash": wrong_player_hash, "type": "get_state"} + ) + state = websocket.receive_json() + assert state == { + "request_type": "get_state", + "status": 400, + "msg": "player hash unknown", + "player_hash": None, + } + + websocket.send_json( + { + "player_hash": wrong_player_hash, + "type": "action", + "action": { + "player": "0", + "action_type": ActionType.PICK_UP_DROP.value, + "action_data": None, + }, + } + ) + assert websocket.receive_json() == { + "request_type": "action", + "status": 400, + "msg": f"action not accepted", + "player_hash": wrong_player_hash, + } + + websocket.send_json( + { + "player_hash": wrong_player_hash, + "type": "delta_v", + "action": { + "player": "0", + "action_type": ActionType.PICK_UP_DROP.value, + "action_data": None, + }, + } + ) + assert websocket.receive_json()["status"] == 400 + + websocket.send_json( + { + "player_hash": wrong_player_hash, + "type": "action", + } + ) + assert websocket.receive_json()["status"] == 400 + + assert ( + len( + environment_handler.list_not_ready_players(res.json()["env_id"]) + ) + == 1 + ) + assert ( + len( + environment_handler.list_not_connected_players( + res.json()["env_id"] + ) + ) + == 0 + ) + + finally: + task.cancel() + loop.close() diff --git a/tests/test_study_server.py b/tests/test_study_server.py new file mode 100644 index 0000000000000000000000000000000000000000..0d979633495d970274dad50e005dfa61f4513fcc --- /dev/null +++ b/tests/test_study_server.py @@ -0,0 +1,95 @@ +import json +from unittest import mock + +from fastapi import status +from fastapi.testclient import TestClient +from requests import Response + +import cooperative_cuisine.study_server as study_server_module +from cooperative_cuisine.study_server import app + + +def test_valid_post_requests(): + test_response = Response() + test_response.status_code = status.HTTP_200_OK + test_response.encoding = "utf8" + test_response._content = json.dumps( + { + "player_info": { + "0": { + "player_id": "0", + "client_id": "ksjdhfkjsdfn", + "player_hash": "shdfbmsndfb", + } + }, + "env_id": "123456789", + "recipe_graphs": [], + } + ).encode() + with mock.patch.object( + study_server_module, "request_game_server", return_value=test_response + ) as mock_call: + with TestClient(app) as client: + res = client.post("/start_study/124/1") + + assert res.status_code == status.HTTP_200_OK + + mock_call.assert_called_once() + + with mock.patch.object( + study_server_module, "request_game_server", return_value=test_response + ) as mock_call: + with TestClient(app) as client: + res = client.post("/get_game_connection/124") + + assert res.status_code == status.HTTP_200_OK + assert res.json()["player_info"] == { + "0": { + "player_id": "0", + "client_id": "ksjdhfkjsdfn", + "player_hash": "shdfbmsndfb", + } + } + + with mock.patch.object( + study_server_module, "request_game_server", return_value=test_response + ) as mock_call: + with TestClient(app) as client: + res = client.post("/level_done/124") + + assert res.status_code == status.HTTP_200_OK + + +def test_invalid_post_requests(): + test_response = "" + with mock.patch.object( + study_server_module, "request_game_server", return_value=test_response + ) as mock_call: + with TestClient(app) as client: + res = client.post("/level_done/125") + + assert res.status_code == status.HTTP_409_CONFLICT + + with mock.patch.object( + study_server_module, "request_game_server", return_value=test_response + ) as mock_call: + with TestClient(app) as client: + res = client.post("/get_game_connection/125") + + assert res.status_code == status.HTTP_409_CONFLICT + + +def test_game_server_crashed(): + test_response = Response() + test_response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + + with mock.patch.object( + study_server_module, "request_game_server", return_value=test_response + ) as mock_call: + with TestClient(app) as client: + res = client.post("/start_study/124/1") + + assert res.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + + +# TOOD test bots diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ad9fb6b67b7af5a3a281072cfe59ea838ae9c38e --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,46 @@ +from argparse import ArgumentParser + +from cooperative_cuisine.utils import ( + url_and_port_arguments, + add_list_of_manager_ids_arguments, + disable_websocket_logging_arguments, + add_study_arguments, + add_gui_arguments, + create_layout_with_counters, + setup_logging, +) + + +def test_parser_gen(): + parser = ArgumentParser() + url_and_port_arguments(parser) + disable_websocket_logging_arguments(parser) + add_list_of_manager_ids_arguments(parser) + add_study_arguments(parser) + add_gui_arguments(parser) + + parser.parse_args( + [ + "-s", + "localhost", + "-sp", + "8000", + "-g", + "localhost", + "-gp", + "8080", + "--manager-ids", + "123", + "123123", + "--do-study", + ] + ) + + +def test_layout_creation(): + assert """###\n#_#\n###\n""" == create_layout_with_counters(3, 3) + assert """###\n#_#\n#_#\n###\n""" == create_layout_with_counters(3, 4) + + +def test_setup_logging(): + setup_logging()