diff --git a/cooperative_cuisine/__init__.py b/cooperative_cuisine/__init__.py index 765bf5cb4b978abc408ecd3cdec7acecd9dbc923..e3bd8db959d95c7964ed58ec041e2577f2561113 100644 --- a/cooperative_cuisine/__init__.py +++ b/cooperative_cuisine/__init__.py @@ -46,7 +46,7 @@ cooperative_cuisine -s localhost -sp 8080 -g localhost -gp 8000 *The arguments shown are the defaults.* -You can also start the **Game Server**m **Study Server** (Matchmaking),and the **PyGame GUI** individually in different terminals. +You can also start the **Game Server**, **Study Server** (Matchmaking),and the **PyGame GUI** individually in different terminals. ```bash python3 cooperative_cuisine/game_server.py -g localhost -gp 8000 --manager_ids SECRETKEY1 SECRETKEY2 diff --git a/cooperative_cuisine/recording.py b/cooperative_cuisine/recording.py index 10597e8040674dcb85008c4e6a82a26b56613f31..7f540b055f933a5b6410afc12ec785a337b00d4f 100644 --- a/cooperative_cuisine/recording.py +++ b/cooperative_cuisine/recording.py @@ -55,18 +55,7 @@ log = logging.getLogger(__name__) class FileRecorder(HookCallbackClass): - """ - Class: FileRecorder - - This class is responsible for recording data to a file. - - Attributes: - name (str): The name of the recorder. - env (Environment): The environment instance. - log_path (str): The path to the log file. Default value is "USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl". - add_hook_ref (bool): Indicates whether to add a hook reference to the recorded data. Default value is False. - - """ + """This class is responsible for recording data to a file.""" def __init__( self, @@ -78,11 +67,11 @@ class FileRecorder(HookCallbackClass): ): super().__init__(name, env, **kwargs) self.add_hook_ref = add_hook_ref - """If the name of the hook (the reference) should be included in the recording.""" + """Indicates whether to add a hook reference to the recorded data. Default value is False.""" log_path = log_path.replace("LOG_RECORD_NAME", name) log_path = Path(expand_path(log_path, env_name=env.env_name)) self.log_path = log_path - """The path to the recording file.""" + """The path to the log file. Default value is "USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl".""" log.info(f"Recorder record for {name} in file://{log_path}") os.makedirs(log_path.parent, exist_ok=True) diff --git a/cooperative_cuisine/study_server.py b/cooperative_cuisine/study_server.py index 50d233f108afb889da416f91d432459aa9240b67..7db833228357f9bff8cecf4e089ca3ed1a15770a 100644 --- a/cooperative_cuisine/study_server.py +++ b/cooperative_cuisine/study_server.py @@ -439,8 +439,6 @@ if __name__ == "__main__": url_and_port_arguments( parser=parser, server_name="Study Server", - default_study_port=8080, - default_game_port=8000, ) add_list_of_manager_ids_arguments(parser=parser) add_study_arguments(parser=parser) diff --git a/cooperative_cuisine/utils.py b/cooperative_cuisine/utils.py index d6e0bfdb341d1686cc99a2fb94d8b04e6952745a..1b28523bbc3a14f059ffcd15fa1dc3a0bce4269e 100644 --- a/cooperative_cuisine/utils.py +++ b/cooperative_cuisine/utils.py @@ -26,9 +26,35 @@ if TYPE_CHECKING: from cooperative_cuisine.player import Player DEFAULT_SERVER_URL = "localhost" +"""Default server URL of game and server study.""" + +DEFAULT_SERVER_PORT = 8080 +"""Default study server port.""" + +DEFAULT_GAME_PORT = 8000 +"""Default game server port.""" def expand_path(path: str, env_name: str = "") -> str: + """Expand a path with VARIABLES to the path variables based on the user's OS or installation location of the Cooperative Cuisine. + Args: + path: A string representing the path to be expanded. This can contain placeholders like "ROOT_DIR", "ENV_NAME", "USER_LOG_DIR", "LAYOUTS_DIR", "STUDY_DIR", and "CONFIGS_DIR" which will be replaced with their corresponding values. + env_name (optional): A string representing the environment name to be used for expanding the path. This will be used to replace the "ENV_NAME" placeholder. + + Returns: + A string representing the expanded path, where all placeholders have been replaced with their corresponding values. + + Example: + expand_path("~/ROOT_DIR/ENV_NAME", "development") + -> "/home/user/path/to/ROOT_DIR/development" + + Note: + - The "ROOT_DIR" placeholder will be replaced with the value of the `ROOT_DIR` constant. + - The "USER_LOG_DIR" placeholder will be replaced with the user-specific directory for log files. + - The "LAYOUTS_DIR" placeholder will be replaced with the directory path to layouts config files. + - The "STUDY_DIR" placeholder will be replaced with the directory path to study config files. + - The "CONFIGS_DIR" placeholder will be replaced with the directory path to general config files. + """ return os.path.expanduser( path.replace("ROOT_DIR", str(ROOT_DIR)) .replace("ENV_NAME", env_name) @@ -41,6 +67,18 @@ def expand_path(path: str, env_name: str = "") -> str: @dataclasses.dataclass class VectorStateGenerationData: + """ + A class representing data used for vector state generation. + + Attributes: + grid_base_array (numpy.ndarray): A 2D array representing the state grid. + oh_len (int): The length of the one-hot encoding vector. + number_normal_ingredients (int): The number of normal ingredients. + meals (List[str]): A list of meal names. + equipments (List[str]): A list of equipment names. + ingredients (List[str]): A list of ingredient names. + """ + grid_base_array: npt.NDArray[npt.NDArray[float]] oh_len: int @@ -81,6 +119,24 @@ class VectorStateGenerationData: @dataclasses.dataclass class VectorStateGenerationDataSimple: + """Relevant for reinforcment learning. + + VectorStateGenerationDataSimple class represents the data required for generating vector states. It includes the + grid base array, the length of the one-hot encoded representations, and * other information related to meals, + equipments, and ingredients. + + Attributes: + - grid_base_array (numpy.ndarray): A 2D NumPy array representing the grid base. + - oh_len (int): The length of the one-hot encoded representations. + + Constants: + - number_normal_ingredients (int): The number of normal ingredients. + - meals (list): A list of meal names. + - equipments (list): A list of equipment names. + - ingredients (list): A list of ingredient names. + + """ + grid_base_array: npt.NDArray[npt.NDArray[float]] oh_len: int @@ -125,6 +181,16 @@ def get_closest(point: npt.NDArray[float], counters: list[Counter]): def get_collided_players( player_idx, players: list[Player], player_radius: float ) -> list[Player]: + """Filter players if they collide. + + Args: + player_idx: The index of the player for which to find collided players. + players: A list of Player objects representing all the players. + player_radius: The radius of the player. + + Returns: + A list of Player objects representing the players that have collided with the player at the given index. + """ player_positions = np.array([p.pos for p in players], dtype=float) distances = distance_matrix(player_positions, player_positions)[player_idx] player_radiuses = np.array([player_radius for p in players], dtype=float) @@ -135,6 +201,16 @@ def get_collided_players( def get_touching_counters(target: Counter, counters: list[Counter]) -> list[Counter]: + """Filter the list of counters if they touch the target counter. + + Args: + target: A Counter object representing the target counter. + counters: A list of Counter objects representing the counters to be checked. + + Returns: + A list of Counter objects that are touching the target counter. + + """ return list( filter( lambda counter: np.linalg.norm(counter.pos - target.pos) == 1.0, counters @@ -143,6 +219,21 @@ def get_touching_counters(target: Counter, counters: list[Counter]) -> list[Coun def find_item_on_counters(item_uuid: str, counters: list[Counter]) -> Counter | None: + """This method searches for a specific item with the given UUID on a list of counters. + + It iterates through each counter and checks if it is occupied. If the counter is occupied by a deque, it further + iterates through each item in the deque to find a match with the given UUID. If a match is found, the respective + counter is returned. If the counter is occupied by a single, item (not a deque), it directly compares the UUID of + the occupied item with the given UUID. If they match, the respective counter is returned. If no match is found + for the given UUID on any counter, None is returned. + + Args: + item_uuid (str): The UUID of the item to be searched for on counters. + counters (list[Counter]): The list of counters to search for the item. + + Returns: + Counter | None: The counter where the item was found, or None if the item was not found. + """ for counter in counters: if counter.occupied_by: if isinstance(counter.occupied_by, deque): @@ -155,7 +246,15 @@ def find_item_on_counters(item_uuid: str, counters: list[Counter]) -> Counter | def custom_asdict_factory(data): - """Convert enums to their value.""" + """Converts enums to their value. + + Args: + data: The data to be converted to a dictionary. + + Returns: + dict: A dictionary where the values in the data are converted based on the `convert_value` function. + + """ def convert_value(obj): if isinstance(obj, Enum): @@ -166,6 +265,11 @@ def custom_asdict_factory(data): def setup_logging(enable_websocket_logging=False): + """Setup logging configuration. + + Args: + enable_websocket_logging (bool, optional): Flag to enable websocket logging. Default is False. + """ path_logs = ROOT_DIR.parent / "logs" os.makedirs(path_logs, exist_ok=True) logging.basicConfig( @@ -189,14 +293,27 @@ def setup_logging(enable_websocket_logging=False): def url_and_port_arguments( - parser, server_name="game server", default_study_port=8080, default_game_port=8000 + parser, + server_name="game server", + default_study_port=DEFAULT_SERVER_PORT, + default_game_port=DEFAULT_GAME_PORT, + default_server_url=DEFAULT_SERVER_URL, ): + """Adds arguments to the given parser for the URL and port configuration of a server. + + Args: + parser: The argument parser to add arguments to. + server_name: (Optional) The name of the server. Defaults to "game server". + default_study_port: (Optional) The default port number for the study URL. Defaults to 8080. + default_game_port: (Optional) The default port number for the game URL. Defaults to 8000. + default_server_url: (Optional) The default url for the server. Defaults to "localhost". + """ parser.add_argument( "-s", "--study-url", "--study-host", type=str, - default=DEFAULT_SERVER_URL, + default=default_server_url, help=f"Overcooked {server_name} study host url.", ) parser.add_argument( @@ -224,12 +341,27 @@ def url_and_port_arguments( def disable_websocket_logging_arguments(parser): + """Disables the logging of WebSocket arguments in the provided parser. + + Args: + parser: The argument parser object (argparse.ArgumentParser) to which the + "--enable-websocket-logging" argument will be added. + + """ parser.add_argument( "--enable-websocket-logging" "", action="store_true", default=True ) def add_list_of_manager_ids_arguments(parser): + """This function adds the manager ids argument to the given argument parser. + + Args: + parser: An ArgumentParser object used to parse command line arguments. + + Returns: + None + """ parser.add_argument( "-m", "--manager_ids", @@ -241,6 +373,19 @@ def add_list_of_manager_ids_arguments(parser): def add_study_arguments(parser): + """This function adds the study configuration argument to the given argument parser. + + Args: + parser (argparse.ArgumentParser): The argument parser object. + + + Example: + ```python + import argparse + parser = argparse.ArgumentParser() + add_study_arguments(parser) + ``` + """ parser.add_argument( "--study-config", type=str, @@ -271,7 +416,13 @@ class NumpyAndDataclassEncoder(json.JSONEncoder): return json.JSONEncoder.default(self, obj) -def create_layout(w, h): +def create_layout_with_counters(w, h): + """Print a layout string that has counters at the world borders. + + Args: + w: The width of the layout. + h: The height of the layout. + """ for y in range(h): for x in range(w): if x == 0 or y == 0 or x == w - 1 or y == h - 1: