Skip to content
Snippets Groups Projects
  • Florian Schröder's avatar
    5c100dd0
    Add type hints and docstrings across multiple modules · 5c100dd0
    Florian Schröder authored
    Added type hints and docstrings throughout the project, notably in the effects, orders, study server, and game server modules. These additions provide better understanding and ease in navigating through the codebase. Minor code reorganization and syntax optimization were also performed.
    5c100dd0
    History
    Add type hints and docstrings across multiple modules
    Florian Schröder authored
    Added type hints and docstrings throughout the project, notably in the effects, orders, study server, and game server modules. These additions provide better understanding and ease in navigating through the codebase. Minor code reorganization and syntax optimization were also performed.
utils.py 11.16 KiB
"""
Some utility functions.
"""
from __future__ import annotations

import collections.abc
import dataclasses
import json
import logging
import os
import sys
from collections import deque
from datetime import datetime, timedelta
from enum import Enum
from typing import TYPE_CHECKING

import numpy as np
import numpy.typing as npt
import platformdirs
from scipy.spatial import distance_matrix

from cooperative_cuisine import ROOT_DIR

if TYPE_CHECKING:
    from cooperative_cuisine.counters import Counter
from cooperative_cuisine.player import Player

# Module-level constant; no code in the visible part of this file reads it.
# NOTE(review): presumably the number of leading UUID characters kept when
# shortening ids for display/logging — confirm against the callers.
UUID_CUTOFF = 8
"""The cutoff length for UUIDs."""


def expand_path(path: str, env_name: str = "") -> str:
    """Substitute the known placeholders in a path and expand a leading ``~``.

    The placeholders are plain substrings. They are replaced one after another in
    the order listed below, so text inserted by an earlier substitution is still
    scanned by the later ones. Afterwards ``os.path.expanduser`` is applied.

    Args:
        path: The path template. It may contain the placeholders "ROOT_DIR",
            "ENV_NAME", "USER_LOG_DIR", "LAYOUTS_DIR", "STUDY_DIR" and "CONFIGS_DIR".
        env_name (optional): The text that replaces the "ENV_NAME" placeholder.

    Returns:
        The path with every placeholder occurrence replaced and ``~`` expanded.

    Example:
        expand_path("USER_LOG_DIR/ENV_NAME", "development")
        -> "<user log dir of cooperative_cuisine>/development"

    Note:
        - "ROOT_DIR" becomes ``str(ROOT_DIR)`` (imported from ``cooperative_cuisine``).
        - "ENV_NAME" becomes ``env_name``.
        - "USER_LOG_DIR" becomes ``platformdirs.user_log_dir("cooperative_cuisine")``.
        - "LAYOUTS_DIR" becomes ``ROOT_DIR / "configs" / "layouts"``.
        - "STUDY_DIR" becomes ``ROOT_DIR / "configs" / "study"``.
        - "CONFIGS_DIR" becomes ``ROOT_DIR / "configs"``.
    """
    configs_dir = ROOT_DIR / "configs"
    # Order matters: it mirrors the sequence in which the placeholders are applied.
    substitutions = (
        ("ROOT_DIR", str(ROOT_DIR)),
        ("ENV_NAME", env_name),
        ("USER_LOG_DIR", platformdirs.user_log_dir("cooperative_cuisine")),
        ("LAYOUTS_DIR", str(configs_dir / "layouts")),
        ("STUDY_DIR", str(configs_dir / "study")),
        ("CONFIGS_DIR", str(configs_dir)),
    )

    expanded = path
    for placeholder, replacement in substitutions:
        expanded = expanded.replace(placeholder, replacement)
    return os.path.expanduser(expanded)


@dataclasses.dataclass
class VectorStateGenerationData:
    """
    A class representing data used for vector state generation.

    Only the two annotated names are dataclass fields and therefore constructor
    arguments. The remaining names carry no annotation, so ``dataclasses`` leaves
    them as plain class attributes that every instance shares.

    Attributes:
        grid_base_array: Floating point array representing the state grid
            (documented as 2D).
        oh_len: The length of the one-hot encoding vector.

    Class constants:
        number_normal_ingredients: The number of normal ingredients.
        meals: A list of meal names.
        equipments: A list of equipment names.
        ingredients: A list of ingredient names.
    """

    # `NDArray` is parameterised by the scalar type of the array elements, not by
    # another `NDArray`.
    grid_base_array: npt.NDArray[np.floating]
    oh_len: int

    # Class-level constants (not dataclass fields, not part of `__init__`).
    number_normal_ingredients = 10

    meals = [
        "Chips",
        "FriedFish",
        "Burger",
        "Salad",
        "TomatoSoup",
        "OnionSoup",
        "FishAndChips",
        "Pizza",
    ]
    equipments = [
        "Pot",
        "Pan",
        "Basket",
        "Peel",
        "Plate",
        "DirtyPlate",
        "Extinguisher",
    ]
    ingredients = [
        "Tomato",
        "Lettuce",
        "Onion",
        "Meat",
        "Bun",
        "Potato",
        "Fish",
        "Dough",
        "Cheese",
        "Sausage",
    ]


@dataclasses.dataclass
class VectorStateGenerationDataSimple:
    """Relevant for reinforcement learning.

    VectorStateGenerationDataSimple class represents the data required for generating vector states. It includes the
    grid base array, the length of the one-hot encoded representations, and other information related to meals,
    equipments, and ingredients.

    Only the two annotated names are dataclass fields and therefore constructor arguments. The constants below carry
    no annotation, so ``dataclasses`` leaves them as plain class attributes shared by every instance.

    Attributes:
    - grid_base_array: Floating point NumPy array representing the grid base (documented as 2D).
    - oh_len (int): The length of the one-hot encoded representations.

    Constants:
    - number_normal_ingredients (int): The number of normal ingredients.
    - meals (list): A list of meal names.
    - equipments (list): A list of equipment names.
    - ingredients (list): A list of ingredient names.

    """

    # `NDArray` is parameterised by the scalar type of the array elements, not by
    # another `NDArray`.
    grid_base_array: npt.NDArray[np.floating]
    oh_len: int

    # Class-level constants (not dataclass fields, not part of `__init__`).
    number_normal_ingredients = 1

    meals = [
        "TomatoSoup",
    ]
    equipments = [
        "Pot",
        "Plate",
        "DirtyPlate",
        "Extinguisher",
    ]
    ingredients = [
        "Tomato",
    ]


def create_init_env_time():
    """Return the fixed start time used as the environments' internal clock origin.

    Every call yields the same naive ``datetime``: midnight on 2000-01-01, so that
    all environments start from an identical internal time.
    """
    # Hour, minute, second and microsecond default to 0.
    return datetime(2000, 1, 1)


def get_closest(point: npt.NDArray[float], counters: list[Counter]):
    """Pick the counter whose position is nearest to a point.

    Args:
        point: The coordinate in the env to measure from.
        counters: Objects exposing a `pos` attribute; their positions are compared
            with `point`.

    Returns:
        The element of `counters` with the smallest distance to `point`. On ties
        the one that comes first in `counters` wins (first minimum of `np.argmin`).

    Note:
        An empty `counters` list is not handled here.
    """
    counter_positions = [candidate.pos for candidate in counters]
    # One row (the single query point) against all counter positions.
    distances_to_point = distance_matrix([point], counter_positions)[0]
    nearest_index = np.argmin(distances_to_point)
    return counters[nearest_index]


def get_collided_players(
    player_idx, players: list[Player], player_radius: float
) -> list[Player]:
    """Collect the players that overlap with one given player.

    Two players count as collided when the distance between their positions is at
    most the sum of their radii; every player uses the same `player_radius`.

    Args:
        player_idx: Index into `players` of the player to test against.
        players: All players; each must expose a `pos` attribute.
        player_radius: The radius shared by all players.

    Returns:
        The players (in their original order) that collide with the player at
        `player_idx`. That player itself is never part of the result.
    """
    positions = np.array([player.pos for player in players], dtype=float)
    # Row `player_idx` of the pairwise matrix: distances from that player to everyone.
    dist_to_others = distance_matrix(positions, positions)[player_idx]
    radii = np.full(len(players), player_radius, dtype=float)
    is_colliding = dist_to_others <= radii + player_radius
    # Distance to oneself is 0, so exclude the player explicitly.
    is_colliding[player_idx] = False

    return [players[hit] for hit in np.flatnonzero(is_colliding)]


def get_touching_counters(target: Counter, counters: list[Counter]) -> list[Counter]:
    """Filter the list of counters if they touch the target counter.

    A counter touches the target when the Euclidean distance between their `pos`
    attributes is 1. The comparison uses `np.isclose` (default tolerances) instead
    of exact equality, so that floating point rounding in the positions
    (e.g. 2.3 - 1.3 == 0.9999999999999998) does not hide a neighbouring counter.

    Args:
        target: A Counter object representing the target counter.
        counters: A list of Counter objects representing the counters to be checked.

    Returns:
        A list of Counter objects that are touching the target counter, in the
        order in which they appear in `counters`.

    """
    return [
        counter
        for counter in counters
        if np.isclose(np.linalg.norm(counter.pos - target.pos), 1.0)
    ]


def find_item_on_counters(item_uuid: str, counters: list[Counter]) -> Counter | None:
    """Locate the counter that holds the item with the given UUID.

    The counters are scanned in order. Counters whose `occupied_by` is falsy
    (for example `None` or an empty deque) are skipped. When `occupied_by` is a
    deque, each contained item is compared by its `uuid`; otherwise `occupied_by`
    is treated as a single item and its `uuid` is compared directly.

    Args:
        item_uuid (str): The UUID of the item to be searched for on counters.
        counters (list[Counter]): The counters to search, in scan order.

    Returns:
        Counter | None: The first counter holding a matching item, or None if no
        counter holds one.
    """
    for counter in counters:
        occupant = counter.occupied_by
        if not occupant:
            continue
        # Normalise to an iterable so both cases share one comparison.
        candidates = occupant if isinstance(occupant, deque) else (occupant,)
        if any(candidate.uuid == item_uuid for candidate in candidates):
            return counter
    return None


def custom_asdict_factory(data):
    """Build a dict from key/value pairs, replacing Enum members by their value.

    Args:
        data: An iterable of `(key, value)` pairs, e.g. what `dataclasses.asdict`
            hands to its `dict_factory`.

    Returns:
        dict: The pairs as a dictionary; every value that is an `Enum` member is
        replaced by its `.value`, all other values are kept unchanged.

    """
    return {
        key: (value.value if isinstance(value, Enum) else value)
        for key, value in data
    }


def setup_logging(enable_websocket_logging=False):
    """Configure the root logger to write to a timestamped file and to stdout.

    The log file is created in the `logs` directory next to `ROOT_DIR` (the
    directory is created when missing) and is named after the current local time.
    The root level is DEBUG. The "matplotlib" logger is always raised to WARNING.

    Args:
        enable_websocket_logging (bool, optional): When False (the default), the
            asyncio and websockets loggers are limited to ERROR. When True, they
            keep their levels.
    """
    log_dir = ROOT_DIR.parent / "logs"
    os.makedirs(log_dir, exist_ok=True)

    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    file_handler = logging.FileHandler(
        log_dir / f"{timestamp}_debug.log",
        encoding="utf-8",
    )
    console_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(levelname)-8s %(name)-50s %(message)s",
        handlers=[file_handler, console_handler],
    )

    logging.getLogger("matplotlib").setLevel(logging.WARNING)
    if enable_websocket_logging:
        return

    # These loggers are restricted to errors unless websocket logging is requested.
    restricted_loggers = (
        "asyncio",
        "asyncio.coroutines",
        "websockets.server",
        "websockets.protocol",
        "websockets.client",
    )
    for logger_name in restricted_loggers:
        logging.getLogger(logger_name).setLevel(logging.ERROR)


class NumpyAndDataclassEncoder(json.JSONEncoder):
    """JSON encoder for NumPy values, date/time values and dataclass instances.

    Conversions performed by `default`:
        - NumPy booleans -> `bool`
        - NumPy integers -> `int`
        - NumPy floats -> `float`
        - NumPy arrays -> nested lists (`ndarray.tolist()`)
        - `timedelta` -> total seconds as a float
        - `datetime` -> ISO 8601 string
        - dataclass instances -> dict via `dataclasses.asdict` with
          `custom_asdict_factory` (Enum values are replaced by their `.value`)
    """

    def default(self, obj):
        """Return a JSON-serializable replacement for `obj`.

        Args:
            obj: A value the standard encoder could not serialize.

        Returns:
            The converted value as listed in the class docstring.

        Raises:
            TypeError: Raised by the base class for every other type.
        """
        if isinstance(obj, np.bool_):
            # NumPy booleans are neither `bool` nor `np.integer`.
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, timedelta):
            return obj.total_seconds()
        if isinstance(obj, datetime):
            return obj.isoformat()
        # `is_dataclass` is also true for dataclass *classes*, which `asdict`
        # rejects; only instances are converted.
        if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
            return dataclasses.asdict(obj, dict_factory=custom_asdict_factory)

        return super().default(obj)


def create_layout_with_counters(w: int, h: int) -> str:
    """Create a layout string that has counters at the world borders.

    Border cells are written as "#", all inner cells as "_". Every row, including
    the last one, is terminated by a newline.

    Args:
        w: The width of the layout (characters per row).
        h: The height of the layout (number of rows).

    Returns:
        The layout as a string; the empty string if `h` is not positive.

    Example:
        create_layout_with_counters(4, 3) -> "####\\n#__#\\n####\\n"
    """
    border_row = "#" * w
    # With fewer than two columns every cell is a border cell.
    inner_row = "#" + "_" * (w - 2) + "#" if w > 1 else border_row
    # Build whole rows and join once instead of appending character by character.
    return "".join(
        (border_row if y in (0, h - 1) else inner_row) + "\n" for y in range(h)
    )


def deep_update(d, u):
    """Deep update of a nested dictionary.

    Mapping values in `u` are merged recursively into the corresponding value in
    `d`; every other value in `u` overwrites the value in `d`. If `d` holds no
    mapping under a key that `u` updates with a mapping (the key is missing, or
    its value is e.g. `None` or a scalar), a new dict is started for that key.

    Args:
        d: A dictionary to be updated. This dictionary will be modified in place.
        u: A dictionary containing the updates to be applied to d.

    Returns:
        The updated dictionary `d` (the same object that was passed in).

    """
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            current = d.get(k)
            if not isinstance(current, collections.abc.Mapping):
                # Nothing mergeable stored under this key: replace it with a fresh dict.
                current = {}
            d[k] = deep_update(current, v)
        else:
            d[k] = v
    return d