diff --git a/cooperative_cuisine/__main__.py b/cooperative_cuisine/__main__.py index 685d2402c807800d377b561eaef4c1010009aead..b29e85773f8c1021e5b0833bf7e63a992fa5cccd 100644 --- a/cooperative_cuisine/__main__.py +++ b/cooperative_cuisine/__main__.py @@ -2,13 +2,19 @@ import argparse import time from multiprocessing import Process -from cooperative_cuisine.utils import ( - url_and_port_arguments, +from cooperative_cuisine.argument_parser import ( + study_server_arguments, + ssl_argument, + game_server_arguments, + create_screenshot_parser, + create_study_server_parser, + create_game_server_parser, disable_websocket_logging_arguments, add_list_of_manager_ids_arguments, - add_study_arguments, add_gui_arguments, + add_study_arguments, ) +from cooperative_cuisine.pygame_2d_vis.video_replay import create_replay_parser USE_STUDY_SERVER = True @@ -54,14 +60,43 @@ def main(cli_args=None): description="Game Engine Server + PyGameGUI: Starts overcooked game engine server and a PyGame 2D Visualization window in two processes.", epilog="For further information, see https://scs.pages.ub.uni-bielefeld.de/cocosy/overcooked-simulator/overcooked_simulator.html", ) - url_and_port_arguments(parser) + game_server_arguments(parser) + study_server_arguments(parser) disable_websocket_logging_arguments(parser) add_list_of_manager_ids_arguments(parser) - add_study_arguments(parser) add_gui_arguments(parser) + add_study_arguments(parser) + ssl_argument(parser) + + subparsers = parser.add_subparsers( + help="Available CLI of Cooperative Cuisine", dest="command" + ) + screenshot_parser = subparsers.add_parser( + "screenshot", help="Create a screenshot from a json state." + ) + create_screenshot_parser(screenshot_parser) + + study_server_parser = subparsers.add_parser( + "study-server", help="Start a study server." + ) + create_study_server_parser(study_server_parser) + + game_server_parser = subparsers.add_parser( + "game-server", help="Start a game server." + ) + create_game_server_parser(game_server_parser) + + replay_parser = subparsers.add_parser( + "replay", help="Create replay from json states or recordings." + ) + create_replay_parser(replay_parser) cli_args = parser.parse_args() + if cli_args.command: + print(cli_args) + return + game_server = None pygame_gui = None try: diff --git a/cooperative_cuisine/argument_parser.py b/cooperative_cuisine/argument_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..f2999fce38b4e32b8b7bc69575ffd06184be0a71 --- /dev/null +++ b/cooperative_cuisine/argument_parser.py @@ -0,0 +1,196 @@ +import uuid +from argparse import ArgumentParser, FileType + +from cooperative_cuisine import ROOT_DIR + +DEFAULT_SERVER_URL = "localhost" +"""Default server URL of game and server study.""" + +DEFAULT_SERVER_PORT = 8080 +"""Default study server port.""" + +DEFAULT_GAME_PORT = 8000 +"""Default game server port.""" + + +def study_server_arguments( + parser: ArgumentParser, + default_study_port=DEFAULT_SERVER_PORT, + default_server_url=DEFAULT_SERVER_URL, +): + parser.add_argument( + "-s", + "--study-url", + "--study-host", + type=str, + default=default_server_url, + help=f"Overcooked Study Server host url.", + ) + parser.add_argument( + "-p", + "--study-port", + type=int, + default=default_study_port, + help=f"Port number for the Study Server", + ) + + +def game_server_arguments( + parser: ArgumentParser, + default_game_port=DEFAULT_GAME_PORT, + default_server_url=DEFAULT_SERVER_URL, +): + parser.add_argument( + "-g", + "--game-url", + "--game-host", + type=str, + default=default_server_url, + help=f"Overcooked Game Server url.", + ) + parser.add_argument( + "-gp", + "--game-port", + type=int, + default=default_game_port, + help=f"Port number for the Game Server", + ) + + +def disable_websocket_logging_arguments(parser): + """Disables the logging of WebSocket arguments in the provided parser. + + Args: + parser: The argument parser object (argparse.ArgumentParser) to which the + "--enable-websocket-logging" argument will be added. + + """ + parser.add_argument( + "--enable-websocket-logging" "", action="store_true", default=True + ) + + +def add_list_of_manager_ids_arguments(parser): + """This function adds the manager ids argument to the given argument parser. + + Args: + parser: An ArgumentParser object used to parse command line arguments. + + Returns: + None + """ + parser.add_argument( + "-m", + "--manager-ids", + nargs="+", + type=str, + default=[uuid.uuid4().hex], + help="List of manager IDs that can create environments.", + ) + + +def add_study_arguments(parser): + """This function adds the study configuration argument to the given argument parser. + + Args: + parser (argparse.ArgumentParser): The argument parser object. + + + Example: + ```python + import argparse + parser = argparse.ArgumentParser() + add_study_arguments(parser) + ``` + """ + parser.add_argument( + "--study-config", + type=str, + default=ROOT_DIR / "configs" / "study" / "study_config.yaml", + help="The config of the study.", + ) + + +def add_gui_arguments(parser): + """Adds the gui debug argument to the given argument parser. + If set, additional debug / admin elements are shown. + + Args: + parser (argparse.ArgumentParser): The argument parser object. + + + Example: + ```python + import argparse + parser = argparse.ArgumentParser() + add_gui_arguments(parser) + ``` + """ + parser.add_argument( + "--do-study", + default=True, + help="Disable additional debug / admin elements.", + action="store_false", + ) + + +def ssl_argument(parser: ArgumentParser): + parser.add_argument( + "--ssl", + action="store_true", + help="Use SSL certificate. Connect to https and wss.", + ) + + +def visualization_config_argument(parser: ArgumentParser): + parser.add_argument( + "-v", + "--visualization_config", + type=FileType("r", encoding="UTF-8"), + default=ROOT_DIR / "pygame_2d_vis" / "visualization.yaml", + ) + + +def output_file_argument(parser: ArgumentParser, default_file: str): + parser.add_argument( + "-o", + "--output_file", + type=str, + default=default_file, + ) + + +def create_game_server_parser(parser: ArgumentParser): + game_server_arguments(parser) + disable_websocket_logging_arguments(parser) + add_list_of_manager_ids_arguments(parser) + ssl_argument(parser) + + +def create_study_server_parser(parser: ArgumentParser): + study_server_arguments(parser) + # TODO study server can handle several game server + game_server_arguments(parser) + add_list_of_manager_ids_arguments(parser=parser) + add_study_arguments(parser=parser) + ssl_argument(parser) + + +def create_gui_parser(parser: ArgumentParser): + study_server_arguments(parser) + disable_websocket_logging_arguments(parser) + add_list_of_manager_ids_arguments(parser) + add_gui_arguments(parser) + ssl_argument(parser) + + +def create_screenshot_parser(parser: ArgumentParser): + parser.add_argument( + "-s", + "--state", + type=FileType("r", encoding="UTF-8"), + default=ROOT_DIR / "pygame_2d_vis" / "sample_state.json", + ) + visualization_config_argument(parser) + output_file_argument(parser, ROOT_DIR / "generated" / "screenshot.jpg") + return parser diff --git a/cooperative_cuisine/game_server.py b/cooperative_cuisine/game_server.py index 01a3b02cbecd8cf515d929bca83152bc81e71cc1..081588dd81180c3b15f40789d422b9170d99ab41 100644 --- a/cooperative_cuisine/game_server.py +++ b/cooperative_cuisine/game_server.py @@ -30,6 +30,7 @@ from starlette.websockets import WebSocketDisconnect from typing_extensions import TypedDict from cooperative_cuisine.action import Action +from cooperative_cuisine.argument_parser import create_game_server_parser from cooperative_cuisine.environment import Environment from cooperative_cuisine.server_results import ( CreateEnvResult, @@ -37,9 +38,6 @@ from cooperative_cuisine.server_results import ( PlayerRequestResult, ) from cooperative_cuisine.utils import ( - url_and_port_arguments, - add_list_of_manager_ids_arguments, - disable_websocket_logging_arguments, setup_logging, UUID_CUTOFF, ) @@ -848,10 +846,7 @@ if __name__ == "__main__": epilog="For further information, see " "https://scs.pages.ub.uni-bielefeld.de/cocosy/overcooked-simulator/cooperative_cuisine.html", ) - - url_and_port_arguments(parser) - disable_websocket_logging_arguments(parser) - add_list_of_manager_ids_arguments(parser) + create_game_server_parser(parser) args = parser.parse_args() main(args.game_url, args.game_port, args.manager_ids, args.enable_websocket_logging) """ diff --git a/cooperative_cuisine/pygame_2d_vis/drawing.py b/cooperative_cuisine/pygame_2d_vis/drawing.py index 2cc1f6f48309fc2e8ec0db5206d667d1d4e737ac..cea0ee330bac22d49e8b2495a4f4051732daa43e 100644 --- a/cooperative_cuisine/pygame_2d_vis/drawing.py +++ b/cooperative_cuisine/pygame_2d_vis/drawing.py @@ -13,6 +13,7 @@ import yaml from scipy.spatial import KDTree from cooperative_cuisine import ROOT_DIR +from cooperative_cuisine.argument_parser import create_screenshot_parser from cooperative_cuisine.environment import Environment from cooperative_cuisine.pygame_2d_vis.game_colors import colors, RGB from cooperative_cuisine.state_representation import ( @@ -1039,24 +1040,7 @@ def main(args): description="Generate images for a state in json.", epilog="For further information, see https://scs.pages.ub.uni-bielefeld.de/cocosy/overcooked-simulator/overcooked_simulator.html", ) - parser.add_argument( - "-s", - "--state", - type=argparse.FileType("r", encoding="UTF-8"), - default=ROOT_DIR / "pygame_2d_vis" / "sample_state.json", - ) - parser.add_argument( - "-v", - "--visualization_config", - type=argparse.FileType("r", encoding="UTF-8"), - default=ROOT_DIR / "pygame_2d_vis" / "visualization.yaml", - ) - parser.add_argument( - "-o", - "--output_file", - type=str, - default=ROOT_DIR / "generated" / "screenshot.jpg", - ) + create_screenshot_parser(parser) args = parser.parse_args(args) with open(args.visualization_config, "r") as f: viz_config = yaml.safe_load(f) diff --git a/cooperative_cuisine/pygame_2d_vis/gui.py b/cooperative_cuisine/pygame_2d_vis/gui.py index df472228b29051133df2872bc8b1b28db9d3d767..f296809ace66ac58ea8897293b4c82df8413a519 100644 --- a/cooperative_cuisine/pygame_2d_vis/gui.py +++ b/cooperative_cuisine/pygame_2d_vis/gui.py @@ -21,6 +21,7 @@ from websockets.sync.client import connect from cooperative_cuisine import ROOT_DIR from cooperative_cuisine.action import ActionType, InterActionData, Action +from cooperative_cuisine.argument_parser import create_gui_parser from cooperative_cuisine.game_server import ( CreateEnvironmentConfig, WebsocketMessage, @@ -31,11 +32,7 @@ from cooperative_cuisine.pygame_2d_vis.game_colors import colors from cooperative_cuisine.server_results import PlayerInfo from cooperative_cuisine.state_representation import StateRepresentation from cooperative_cuisine.utils import ( - url_and_port_arguments, - disable_websocket_logging_arguments, - add_list_of_manager_ids_arguments, setup_logging, - add_gui_arguments, ) @@ -1472,7 +1469,7 @@ class PyGameGUI: environment_config=environment_config, layout_config=layout, seed=seed, - env_name=layout_path.stem + env_name=layout_path.stem, ).model_dump(mode="json") # print(CreateEnvironmentConfig.model_validate_json(json_data=creation_json)) @@ -2210,11 +2207,7 @@ if __name__ == "__main__": epilog="For further information, " "see https://scs.pages.ub.uni-bielefeld.de/cocosy/overcooked-simulator/overcooked_simulator.html", ) - - url_and_port_arguments(parser) - disable_websocket_logging_arguments(parser) - add_list_of_manager_ids_arguments(parser) - add_gui_arguments(parser) + create_gui_parser(parser) args = parser.parse_args() main( args.study_url, diff --git a/cooperative_cuisine/pygame_2d_vis/video_replay.py b/cooperative_cuisine/pygame_2d_vis/video_replay.py index 46521e12524e1c0c2b2074cc5b1f68486d354364..4892b29f8eb3f9baff29151f6a4ff542b1ecd8e3 100644 --- a/cooperative_cuisine/pygame_2d_vis/video_replay.py +++ b/cooperative_cuisine/pygame_2d_vis/video_replay.py @@ -29,7 +29,6 @@ python video_replay.py -h # Code Documentation """ -import argparse import json import os import os.path @@ -42,8 +41,8 @@ import yaml from PIL import Image from tqdm import tqdm -from cooperative_cuisine import ROOT_DIR from cooperative_cuisine.action import Action +from cooperative_cuisine.argument_parser import visualization_config_argument from cooperative_cuisine.environment import Environment from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer from cooperative_cuisine.recording import FileRecorder @@ -297,19 +296,9 @@ def video_from_images(image_paths, video_name, fps): print("See:", video_name) -if __name__ == "__main__": - parser = ArgumentParser( - prog="Cooperative Cuisine Video Generation", - description="Generate videos from recorded data.", - epilog="For further information, see https://scs.pages.ub.uni-bielefeld.de/cocosy/overcooked-simulator/overcooked_simulator.html", - ) +def create_replay_parser(parser: ArgumentParser): parser.add_argument("-j", "--json_state", help="Json states file path", type=str) - parser.add_argument( - "-v", - "--visualization_config", - type=argparse.FileType("r", encoding="UTF-8"), - default=ROOT_DIR / "pygame_2d_vis" / "visualization.yaml", - ) + visualization_config_argument(parser) parser.add_argument( "-o", "--output", @@ -380,6 +369,15 @@ if __name__ == "__main__": type=str, help="Create a video from a folder full of images.", ) + + +if __name__ == "__main__": + parser = ArgumentParser( + prog="Cooperative Cuisine Video Generation", + description="Generate videos from recorded data.", + epilog="For further information, see https://scs.pages.ub.uni-bielefeld.de/cocosy/overcooked-simulator/overcooked_simulator.html", + ) + create_replay_parser(parser) args = parser.parse_args() with open(args.visualization_config, "r") as f: viz_config = yaml.safe_load(f) diff --git a/cooperative_cuisine/study_server.py b/cooperative_cuisine/study_server.py index 7988acf517e3b00fcd21055fa1edd1053e617238..9e1e87f526fc7ab89667d714c919e40c4163110b 100644 --- a/cooperative_cuisine/study_server.py +++ b/cooperative_cuisine/study_server.py @@ -10,7 +10,6 @@ python game_server.py --manager-ids COPIED_UUID The environment starts when all players connected. """ - import argparse import asyncio import json @@ -31,14 +30,12 @@ from fastapi import FastAPI, HTTPException, Request from pydantic import BaseModel from cooperative_cuisine import ROOT_DIR +from cooperative_cuisine.argument_parser import create_study_server_parser from cooperative_cuisine.environment import EnvironmentConfig from cooperative_cuisine.game_server import CreateEnvironmentConfig, EnvironmentData from cooperative_cuisine.server_results import PlayerInfo, CreateEnvResult from cooperative_cuisine.utils import ( - url_and_port_arguments, - add_list_of_manager_ids_arguments, expand_path, - add_study_arguments, deep_update, UUID_CUTOFF, ) @@ -705,15 +702,8 @@ if __name__ == "__main__": epilog="For further information, " "see https://scs.pages.ub.uni-bielefeld.de/cocosy/overcooked-simulator/overcooked_simulator.html", ) - url_and_port_arguments( - parser=parser, - server_name="Study Server", - ) - add_list_of_manager_ids_arguments(parser=parser) - add_study_arguments(parser=parser) + create_study_server_parser(parser) args = parser.parse_args() - - game_server_url = f"https://{args.game_url}:{args.game_port}" main( args.study_url, args.study_port, diff --git a/cooperative_cuisine/utils.py b/cooperative_cuisine/utils.py index 88719d54d8fc2fc00cf67d464af0a8c764f9457e..601a1f3dc946b7901665e0d74f16cfc63a6df1df 100644 --- a/cooperative_cuisine/utils.py +++ b/cooperative_cuisine/utils.py @@ -9,7 +9,6 @@ import json import logging import os import sys -import uuid from collections import deque from datetime import datetime, timedelta from enum import Enum @@ -26,15 +25,6 @@ if TYPE_CHECKING: from cooperative_cuisine.counters import Counter from cooperative_cuisine.player import Player -DEFAULT_SERVER_URL = "localhost" -"""Default server URL of game and server study.""" - -DEFAULT_SERVER_PORT = 8080 -"""Default study server port.""" - -DEFAULT_GAME_PORT = 8000 -"""Default game server port.""" - UUID_CUTOFF = 8 """The cutoff length for UUIDs.""" @@ -297,131 +287,6 @@ def setup_logging(enable_websocket_logging=False): logging.getLogger("websockets.client").setLevel(logging.ERROR) -def url_and_port_arguments( - parser, - server_name="game server", - default_study_port=DEFAULT_SERVER_PORT, - default_game_port=DEFAULT_GAME_PORT, - default_server_url=DEFAULT_SERVER_URL, -): - """Adds arguments to the given parser for the URL and port configuration of a server. - - Args: - parser: The argument parser to add arguments to. - server_name: (Optional) The name of the server. Defaults to "game server". - default_study_port: (Optional) The default port number for the study URL. Defaults to 8080. - default_game_port: (Optional) The default port number for the game URL. Defaults to 8000. - default_server_url: (Optional) The default url for the server. Defaults to "localhost". - """ - parser.add_argument( - "-s", - "--study-url", - "--study-host", - type=str, - default=default_server_url, - help=f"Overcooked {server_name} study host url.", - ) - parser.add_argument( - "-sp", - "--study-port", - type=int, - default=default_study_port, - help=f"Port number for the {server_name}", - ) - parser.add_argument( - "-g", - "--game-url", - "--game-host", - type=str, - default=DEFAULT_SERVER_URL, - help=f"Overcooked {server_name} game server url.", - ) - parser.add_argument( - "-gp", - "--game-port", - type=int, - default=default_game_port, - help=f"Port number for the {server_name}", - ) - - -def disable_websocket_logging_arguments(parser): - """Disables the logging of WebSocket arguments in the provided parser. - - Args: - parser: The argument parser object (argparse.ArgumentParser) to which the - "--enable-websocket-logging" argument will be added. - - """ - parser.add_argument( - "--enable-websocket-logging" "", action="store_true", default=True - ) - - -def add_list_of_manager_ids_arguments(parser): - """This function adds the manager ids argument to the given argument parser. - - Args: - parser: An ArgumentParser object used to parse command line arguments. - - Returns: - None - """ - parser.add_argument( - "-m", - "--manager-ids", - nargs="+", - type=str, - default=[uuid.uuid4().hex], - help="List of manager IDs that can create environments.", - ) - - -def add_study_arguments(parser): - """This function adds the study configuration argument to the given argument parser. - - Args: - parser (argparse.ArgumentParser): The argument parser object. - - - Example: - ```python - import argparse - parser = argparse.ArgumentParser() - add_study_arguments(parser) - ``` - """ - parser.add_argument( - "--study-config", - type=str, - default=ROOT_DIR / "configs" / "study" / "study_config.yaml", - help="The config of the study.", - ) - - -def add_gui_arguments(parser): - """Adds the gui debug argument to the given argument parser. - If set, additional debug / admin elements are shown. - - Args: - parser (argparse.ArgumentParser): The argument parser object. - - - Example: - ```python - import argparse - parser = argparse.ArgumentParser() - add_gui_arguments(parser) - ``` - """ - parser.add_argument( - "--do-study", - default=True, - help="Disable additional debug / admin elements.", - action="store_false", - ) - - class NumpyAndDataclassEncoder(json.JSONEncoder): """Special json encoder for numpy types""" diff --git a/tests/test_utils.py b/tests/test_utils.py index 2802bb3a73f7deea7d1ee169adc0eb132b399096..a629bc5732bb525c88b21c95fe024752f29e9f40 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,6 +4,18 @@ from argparse import ArgumentParser import networkx import pytest +from cooperative_cuisine.argument_parser import ( + game_server_arguments, + study_server_arguments, + disable_websocket_logging_arguments, + add_list_of_manager_ids_arguments, + add_gui_arguments, + add_study_arguments, + ssl_argument, + create_study_server_parser, + create_game_server_parser, + create_gui_parser, +) from cooperative_cuisine.environment import Environment from cooperative_cuisine.state_representation import ( create_movement_graph, @@ -11,11 +23,6 @@ from cooperative_cuisine.state_representation import ( astar_heuristic, ) from cooperative_cuisine.utils import ( - url_and_port_arguments, - add_list_of_manager_ids_arguments, - disable_websocket_logging_arguments, - add_study_arguments, - add_gui_arguments, create_layout_with_counters, setup_logging, ) @@ -23,19 +30,20 @@ from tests.test_start import env_config_no_validation from tests.test_start import layout_empty_config, item_info -def test_parser_gen(): +def test_arguments(): parser = ArgumentParser() - url_and_port_arguments(parser) + game_server_arguments(parser) + study_server_arguments(parser) disable_websocket_logging_arguments(parser) add_list_of_manager_ids_arguments(parser) - add_study_arguments(parser) add_gui_arguments(parser) - + add_study_arguments(parser) + ssl_argument(parser) parser.parse_args( [ "-s", "localhost", - "-sp", + "-p", "8000", "-g", "localhost", @@ -49,6 +57,15 @@ def test_parser_gen(): ) +def test_parser_creation(): + parser = ArgumentParser() + create_game_server_parser(parser) + parser = ArgumentParser() + create_study_server_parser(parser) + parser = ArgumentParser() + create_gui_parser(parser) + + def test_layout_creation(): assert """###\n#_#\n###\n""" == create_layout_with_counters(3, 3) assert """###\n#_#\n#_#\n###\n""" == create_layout_with_counters(3, 4)