Skip to content
Snippets Groups Projects
  • Florian Schröder's avatar
    ff409005
    Add docstrings and type hinting · ff409005
    Florian Schröder authored
    The updates include adding docstrings and type hinting to several classes and methods in multiple Python files. This commit also includes fixes for inconsistent code formatting and minor bugs. The docstrings provide essential details about the classes and methods, improving readability and understanding for other developers. Furthermore, the added type hinting will enable better IDE assistance, static analysis, and clarity on expected input and output types. Lastly, the code formatting fixes and bug fixes enhance the overall code quality and maintainability.
    ff409005
    History
    Add docstrings and type hinting
    Florian Schröder authored
    The updates include adding docstrings and type hinting to several classes and methods in multiple Python files. This commit also includes fixes for inconsistent code formatting and minor bugs. The docstrings provide essential details about the classes and methods, improving readability and understanding for other developers. Furthermore, the added type hinting will enable better IDE assistance, static analysis, and clarity on expected input and output types. Lastly, the code formatting fixes and bug fixes enhance the overall code quality and maintainability.
drawing.py 18.37 KiB
import argparse
import colorsys
import json
import math
from datetime import datetime, timedelta
from pathlib import Path

import numpy as np
import numpy.typing as npt
import pygame
import yaml
from scipy.spatial import KDTree

from overcooked_simulator import ROOT_DIR
from overcooked_simulator.gui_2d_vis.game_colors import colors
from overcooked_simulator.state_representation import (
    PlayerState,
    CookingEquipmentState,
    ItemState,
)

# When True, draw_players renders a coloured circle plus the rotated cook
# sprite from the config; when False it falls back to plain circles with a
# triangle for the facing direction.
USE_PLAYER_COOK_SPRITES = True
# When True, draw_players additionally outlines a circle (and a small dot at
# its centre) in front of each player - presumably a debugging aid.
SHOW_INTERACTION_RANGE = False
# When True, draw_counters marks the centre of every counter with a small dot.
SHOW_COUNTER_CENTERS = False


def create_polygon(n, length):
    """Return the corner offsets of a regular polygon centred on the origin.

    The first corner lies at ``(length, 0)``; every following corner is the
    previous one rotated by ``2 * pi / n``.

    Args:
        n: Number of corners (offsets) to generate.
        length: Distance of every corner from the origin.

    Returns:
        A list with ``n`` 2D numpy vectors. For ``n == 1`` the single offset
        is the origin itself, for ``n < 1`` the list is empty.
    """
    if n < 1:
        # Nothing to place; also avoids the division by zero below.
        return []
    if n == 1:
        # A single element sits exactly on the centre. Return it wrapped in a
        # list so that indexing the result always yields a 2D offset vector.
        return [np.zeros(2)]

    vector = np.array([length, 0], dtype=float)
    angle = (2 * np.pi) / n

    rot_matrix = np.array(
        [[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]
    )

    vecs = [vector]
    for _ in range(n - 1):
        vector = np.dot(rot_matrix, vector)
        vecs.append(vector)

    return vecs


class Visualizer:
    def __init__(self, config):
        self.image_cache_dict = {}
        self.player_colors = []
        self.config = config

    def create_player_colors(self, n) -> None:
        """Pick named colours for the players, spread evenly over the hue circle.

        ``n + 1`` hues between 0 and 1 (both inclusive) are sampled; each is
        converted to RGB and replaced by the name of the nearest colour in
        the ``colors`` table. The names are stored in ``self.player_colors``.

        Args:
            n: Number of players.
        """
        names = list(colors.keys())
        palette = np.array(list(colors.values()))
        nearest_color = KDTree(palette)

        def closest_name(hue):
            # Fully saturated, full-brightness colour for this hue, in 0..255.
            channels = colorsys.hsv_to_rgb(hue, 1, 1)
            target = np.array([int(channel * 255) for channel in channels])
            _, hit = nearest_color.query(target, k=1)
            return names[hit]

        self.player_colors = [closest_name(hue) for hue in np.linspace(0, 1, n + 1)]

    def draw_gamescreen(
        self,
        screen,
        state,
        grid_size,
    ):
        width = int(np.ceil(state["kitchen"]["width"] * grid_size))
        height = int(np.ceil(state["kitchen"]["height"] * grid_size))
        self.draw_background(
            surface=screen,
            width=width,
            height=height,
            grid_size=grid_size,
        )
        self.draw_counters(
            screen,
            state["counters"],
            grid_size,
        )

        self.draw_players(
            screen,
            state["players"],
            grid_size,
        )

    def draw_background(self, surface, width, height, grid_size):
        """Fill the surface with the ground colour and overlay a grid of lines.

        The grid blocks are half a grid cell wide.

        Args:
            surface: The surface to draw on.
            width: Width of the area to cover in pixels (integer, used as a
                ``range`` bound).
            height: Height of the area to cover in pixels (integer, used as a
                ``range`` bound).
            grid_size: Size of a grid cell in pixels.
        """
        # range() needs a positive integer step: a float grid_size would make
        # `grid_size // 2` a float, and a grid_size below 2 would make it 0.
        block_size = max(1, int(grid_size // 2))
        surface.fill(colors[self.config["Kitchen"]["ground_tiles_color"]])
        line_color = self.config["Kitchen"]["background_lines"]
        for x in range(0, width, block_size):
            for y in range(0, height, block_size):
                rect = pygame.Rect(x, y, block_size, block_size)
                # Line width 1: only the outline of each block is drawn.
                pygame.draw.rect(surface, line_color, rect, 1)

    def draw_image(
        self,
        screen: pygame.Surface,
        img_path: Path | str,
        size: float,
        pos: npt.NDArray,
        rot_angle=0,
    ):
        """Blit an image, scaled to a square and centred on ``pos``.

        Images are loaded relative to ``ROOT_DIR / "gui_2d_vis"`` and the
        loaded (unscaled) surface is cached under the path string.

        Args:
            screen: The surface to draw on.
            img_path: Path of the image, relative to the gui_2d_vis directory.
            size: Edge length in pixels the image is scaled to.
            pos: Centre position of the image on the screen.
            rot_angle: Angle handed to ``pygame.transform.rotate``; 0 skips
                the rotation.
        """
        cache_key = f"{img_path}"
        sprite = self.image_cache_dict.get(cache_key)
        if sprite is None:
            full_path = ROOT_DIR / "gui_2d_vis" / img_path
            sprite = pygame.image.load(full_path).convert_alpha()
            self.image_cache_dict[cache_key] = sprite

        sprite = pygame.transform.scale(sprite, (size, size))
        if rot_angle != 0:
            sprite = pygame.transform.rotate(sprite, rot_angle)

        target = sprite.get_rect()
        target.center = pos
        screen.blit(sprite, target)

    def draw_players(
        self,
        screen: pygame.Surface,
        players: dict,
        grid_size: float,
    ):
        """Draw every player, the item they hold and their nearest-counter marker.

        With ``USE_PLAYER_COOK_SPRITES`` a player is a coloured circle plus the
        rotated cook sprite from the config; otherwise the player is drawn as
        circles with a triangle for the facing direction.

        Args:
            screen: The surface to draw on.
            players: The player states; they are enumerated and the index
                selects the player's entry in ``self.player_colors``.
            grid_size: Size of a grid cell in pixels.
        """
        for p_idx, player_dict in enumerate(players):
            player_dict: PlayerState
            # dtype=float: integer positions would otherwise give an integer
            # array and the in-place float offset below would raise.
            pos = np.array(player_dict["pos"], dtype=float) * grid_size
            pos += grid_size / 2  # correct for grid offset

            facing = np.array(player_dict["facing_direction"])

            if USE_PLAYER_COOK_SPRITES:
                pygame.draw.circle(
                    screen,
                    # Resolve the colour name through `colors`, like the other
                    # branches of this method do.
                    colors[self.player_colors[p_idx]],
                    pos - facing * grid_size * 0.25,
                    grid_size * 0.2,
                )

                cook_part = self.config["Cook"]["parts"][0]
                img_path = cook_part["path"]
                rel_x, rel_y = facing
                angle = -np.rad2deg(math.atan2(rel_y, rel_x)) + 90
                size = cook_part["size"] * grid_size
                self.draw_image(screen, img_path, size, pos, angle)

            else:
                size = 0.4 * grid_size
                color1 = self.player_colors[p_idx]
                color2 = colors["white"]

                pygame.draw.circle(screen, color2, pos, size)
                pygame.draw.circle(screen, colors["blue"], pos, size, width=1)
                pygame.draw.circle(screen, colors[color1], pos, size // 2)

                # Triangle pointing in the facing direction: two base corners
                # perpendicular to it, the tip half a cell ahead.
                pygame.draw.polygon(
                    screen,
                    colors["blue"],
                    (
                        (
                            pos[0] + (facing[1] * 0.1 * grid_size),
                            pos[1] - (facing[0] * 0.1 * grid_size),
                        ),
                        (
                            pos[0] - (facing[1] * 0.1 * grid_size),
                            pos[1] + (facing[0] * 0.1 * grid_size),
                        ),
                        pos + (facing * 0.5 * grid_size),
                    ),
                )

            if SHOW_INTERACTION_RANGE:
                pygame.draw.circle(
                    screen,
                    colors["blue"],
                    pos + (facing * grid_size * 0.4),
                    1.6 * grid_size,
                    width=1,
                )
                pygame.draw.circle(
                    screen, colors["red1"], pos + (facing * grid_size * 0.4), 4
                )

            if player_dict["holding"] is not None:
                # NOTE(review): fixed 20 pixel offset, not scaled by grid_size.
                holding_item_pos = pos + (20 * facing)
                self.draw_item(
                    pos=holding_item_pos,
                    grid_size=grid_size,
                    item=player_dict["holding"],
                    screen=screen,
                )

            if player_dict["current_nearest_counter_pos"]:
                nearest_pos = (
                    np.array(player_dict["current_nearest_counter_pos"]) * grid_size
                )

                # Outline the cell of the nearest counter in the player's colour.
                pygame.draw.rect(
                    screen,
                    colors[self.player_colors[p_idx]],
                    rect=pygame.Rect(
                        *nearest_pos,
                        grid_size,
                        grid_size,
                    ),
                    width=2,
                )

    def draw_thing(
        self,
        screen: pygame.Surface,
        pos: npt.NDArray[float],
        grid_size: float,
        parts: list[dict[str]],
        scale: float = 1.0,
    ):
        """Draws an item, based on its visual parts specified in the visualization config.

        Args:
            screen: the game screen to draw on.
            grid_size: size of a grid cell.
            pos: Where to draw the item parts.
            parts: The visual parts to draw.
            scale: Rescale the item by this factor.
        """
        for part in parts:
            part_type = part["type"]

            draw_pos = pos.copy()
            if "center_offset" in part:
                draw_pos += np.array(part["center_offset"]) * grid_size

            match part_type:
                case "image":
                    self.draw_image(
                        screen,
                        part["path"],
                        part["size"] * scale * grid_size,
                        draw_pos,
                    )
                case "rect":
                    height = part["height"] * grid_size
                    width = part["width"] * grid_size
                    color = part["color"]
                    rect = pygame.Rect(
                        draw_pos[0] - (height / 2),
                        draw_pos[1] - (width / 2),
                        height,
                        width,
                    )
                    pygame.draw.rect(screen, color, rect)
                case "circle":
                    radius = part["radius"] * grid_size
                    color = colors[part["color"]]
                    pygame.draw.circle(screen, color, draw_pos, radius)

    def draw_item(
        self,
        pos: npt.NDArray[float] | list[float],
        grid_size: float,
        item: ItemState | CookingEquipmentState,
        scale: float = 1.0,
        plate=False,
        screen=None,
    ):
        """Visualization of an item at the specified position. On a counter or in the hands of the player.
        The visual composition of the item is read in from visualization.yaml file, where it is specified as
        different parts to be drawn.

        Args:
            grid_size: size of a grid cell.
            pos: The position of the item to draw.
            item: The item do be drawn in the game.
            scale: Rescale the item by this factor.
            screen: the pygame screen to draw on.
            plate: item is on a plate (soup are is different on a plate and pot)
        """

        if not isinstance(item, list):  # can we remove this check?
            if item["type"] in self.config:
                item_key = item["type"]
                if "Soup" in item_key and plate:
                    item_key += "Plate"
                self.draw_thing(
                    pos=pos,
                    parts=self.config[item_key]["parts"],
                    scale=scale,
                    screen=screen,
                    grid_size=grid_size,
                )
                #
        if "progress_percentage" in item and item["progress_percentage"] > 0.0:
            self.draw_progress_bar(
                screen, pos, item["progress_percentage"], grid_size=grid_size
            )

        if (
            "content_ready" in item
            and item["content_ready"]
            and item["content_ready"]["type"] in self.config
        ):
            self.draw_thing(
                pos=pos,
                parts=self.config[item["content_ready"]["type"]]["parts"],
                screen=screen,
                grid_size=grid_size,
            )
        elif "content_list" in item and item["content_list"]:
            triangle_offsets = create_polygon(len(item["content_list"]), length=10)
            scale = 1 if len(item["content_list"]) == 1 else 0.6
            for idx, o in enumerate(item["content_list"]):
                self.draw_item(
                    pos=np.array(pos) + triangle_offsets[idx],
                    item=o,
                    scale=scale,
                    plate="Plate" in item["type"],
                    screen=screen,
                    grid_size=grid_size,
                )

    @staticmethod
    def draw_progress_bar(
        screen: pygame.Surface,
        pos: npt.NDArray[float],
        percent: float,
        grid_size: float,
    ):
        """Visualize progress of progressing item as a green bar under the item.

        The bar starts at the left edge of the cell centred on ``pos``, sits
        at the bottom of that cell and is ``percent * grid_size`` wide.

        Args:
            screen: The surface to draw on.
            pos: Centre of the cell the bar belongs to. It is not modified.
            percent: Progress as a fraction; scales the bar width.
            grid_size: Size of a grid cell in pixels.
        """
        # Compute the cell's top-left corner without `-=`: an in-place
        # subtraction would shift the caller's array (and fail on a list).
        left = pos[0] - grid_size / 2
        top = pos[1] - grid_size / 2

        bar_height = grid_size * 0.2
        progress_width = percent * grid_size
        progress_bar = pygame.Rect(
            left,
            top + grid_size - bar_height,
            progress_width,
            bar_height,
        )
        pygame.draw.rect(screen, colors["green1"], progress_bar)

    def draw_counter(
        self, screen: pygame.Surface, counter_dict: dict, grid_size: float
    ):
        """Visualization of a counter at its position. If it is occupied by an item, it is also shown.
        The visual composition of the counter is read in from visualization.yaml file, where it is specified as
        different parts to be drawn.
        Args:            counter: The counter to visualize.
        """
        pos = np.array(counter_dict["pos"], dtype=float) * grid_size
        counter_type = counter_dict["type"]

        pos += grid_size // 2  # correct for grid offset

        self.draw_thing(screen, pos, grid_size, self.config["Counter"]["parts"])
        if counter_type in self.config:
            self.draw_thing(screen, pos, grid_size, self.config[counter_type]["parts"])
        else:
            if counter_type in self.config:
                parts = self.config[counter_type]["parts"]
            elif counter_type.endswith("Dispenser"):
                parts = self.config["Dispenser"]["parts"]
            else:
                raise ValueError(f"Can not draw counter type {counter_type}")
            self.draw_thing(
                screen=screen,
                pos=pos,
                parts=parts,
                grid_size=grid_size,
            )

    def draw_counter_occupier(
        self,
        screen: pygame.Surface,
        occupied_by: dict | list,
        grid_size,
        pos: npt.NDArray[float],
    ):
        # Multiple plates on plate return:
        if isinstance(occupied_by, list):
            for i, o in enumerate(occupied_by):
                self.draw_item(
                    screen=screen,
                    pos=np.abs([pos[0], pos[1] - (i * 3)]),
                    grid_size=grid_size,
                    item=o,
                )
        # All other items:
        else:
            self.draw_item(
                pos=pos,
                grid_size=grid_size,
                item=occupied_by,
                screen=screen,
            )

    def draw_counters(self, screen: pygame.Surface, counters, grid_size):
        """Visualizes the counters in the environment.

        All counters are drawn first and the items on them in a second pass,
        so that an item is never covered by a counter drawn later.

        Args:
            screen: The surface to draw on.
            counters: The counter states; their "pos", "type" and
                "occupied_by" entries are read.
            grid_size: Size of a grid cell in pixels.
        """
        for counter in counters:
            self.draw_counter(screen, counter, grid_size)

        for counter in counters:
            occupier = counter["occupied_by"]
            if occupier:
                self.draw_counter_occupier(
                    screen,
                    occupier,
                    grid_size,
                    # Centre of the counter's cell.
                    np.array(counter["pos"]) * grid_size + (grid_size / 2),
                )
            if SHOW_COUNTER_CENTERS:
                # Computed afresh on purpose: the array handed to the
                # occupier drawing above is not reused here.
                pygame.draw.circle(
                    screen,
                    colors["green1"],
                    np.array(counter["pos"]) * grid_size + (grid_size / 2),
                    3,
                )

    def draw_orders(
        self, screen, state, grid_size, width, height, screen_margin, config
    ):
        """Draw the open orders as a row of framed meals with a remaining-time bar.

        The orders are rendered onto a separate surface that is then blitted
        onto ``screen``.

        Args:
            screen: The surface the order row is blitted onto.
            state: The game state; its "orders" and "env_time" entries are read.
            grid_size: Size of a grid cell in pixels.
            width: Reference width; the order surface is 100 pixels narrower.
            height: Not used by this method.
            screen_margin: Height of the order surface; also used for its
                horizontal placement.
            config: The visualization config used for the background colour
                and the plate parts.
        """
        orders_width = width - 100
        orders_height = screen_margin
        order_screen = pygame.Surface(
            (orders_width, orders_height),
        )

        bg_color = colors[config["GameWindow"]["background_color"]]
        pygame.draw.rect(order_screen, bg_color, order_screen.get_rect())

        # Offset that centres a grid cell vertically; also used as left margin.
        first_slot = (orders_height // 2) - (grid_size // 2)
        for slot, order in enumerate(state["orders"]):
            left = first_slot + slot * grid_size * 1.2
            top = first_slot

            # Frame of the order.
            frame = pygame.Rect(left, top, grid_size, grid_size)
            pygame.draw.rect(order_screen, colors["red"], frame, width=2)

            corner = np.array([left, top])
            self.draw_thing(
                pos=corner + (grid_size / 2),
                parts=config["Plate"]["parts"],
                screen=order_screen,
                grid_size=grid_size,
            )
            self.draw_item(
                pos=corner + (grid_size / 2),
                item={"type": order["meal"]},
                plate=True,
                screen=order_screen,
                grid_size=grid_size,
            )

            # Seconds left until the order runs out.
            deadline = datetime.fromisoformat(order["start_time"]) + timedelta(
                seconds=order["max_duration"]
            )
            now = datetime.fromisoformat(state["env_time"])
            seconds_left = (deadline - now).total_seconds()

            self.draw_progress_bar(
                pos=corner + (grid_size / 2),
                percent=seconds_left / order["max_duration"],
                screen=order_screen,
                grid_size=grid_size,
            )

        orders_rect = order_screen.get_rect()
        orders_rect.center = [
            screen_margin + (orders_width // 2),
            orders_height // 2,
        ]
        screen.blit(order_screen, orders_rect)

    def save_state_image(
        self, grid_size: int, state: dict, filename: str | Path
    ) -> None:
        """Render the kitchen of ``state`` on a hidden display and save it as an image.

        Args:
            grid_size: Size of a grid cell in pixels.
            state: The game state to draw.
            filename: Where the image is written.
        """
        kitchen = state["kitchen"]
        # Display size in whole pixels, rounded up.
        display_size = (
            int(np.ceil(kitchen["width"] * grid_size)),
            int(np.ceil(kitchen["height"] * grid_size)),
        )
        screen = pygame.display.set_mode(display_size, flags=pygame.HIDDEN)

        self.draw_gamescreen(screen, state, grid_size)
        pygame.image.save(screen, filename)


def save_screenshot(state: dict, config: dict, filename: str | Path) -> None:
    """Draw ``state`` with a fresh Visualizer and write the image to ``filename``.

    Args:
        state: The game state to draw.
        config: The visualization config.
        filename: Where the image is written.
    """
    player_count = len(state["players"])

    visualizer = Visualizer(config)
    visualizer.create_player_colors(player_count)

    pygame.init()
    pygame.font.init()

    # Screenshots are rendered with 40 pixels per grid cell.
    visualizer.save_state_image(grid_size=40, state=state, filename=filename)


if __name__ == "__main__":
    # Command line entry point: render a JSON game state to an image file.
    parser = argparse.ArgumentParser(
        prog="Overcooked Simulator Image Generation",
        description="Generate images for a state in json.",
        epilog="For further information, see https://scs.pages.ub.uni-bielefeld.de/cocosy/overcooked-simulator/overcooked_simulator.html",
    )
    # The file arguments are parsed as paths and opened below. With
    # argparse.FileType a user-supplied value would already be an open file
    # object, which open() rejects, so only the Path defaults would work.
    parser.add_argument(
        "-s",
        "--state",
        type=Path,
        default=ROOT_DIR / "gui_2d_vis" / "sample_state.json",
    )
    parser.add_argument(
        "-v",
        "--visualization_config",
        type=Path,
        default=ROOT_DIR / "gui_2d_vis" / "visualization.yaml",
    )
    parser.add_argument(
        "-o",
        "--output_filename",
        type=str,
        default="screenshot.jpg",
    )
    args = parser.parse_args()
    with open(args.visualization_config, "r", encoding="UTF-8") as f:
        viz_config = yaml.safe_load(f)
    with open(args.state, "r", encoding="UTF-8") as f:
        state = json.load(f)
    save_screenshot(state, viz_config, args.output_filename)