import argparse
import colorsys
import json
import os
import time
from datetime import datetime, timedelta
from enum import Flag, auto
from functools import reduce
from pathlib import Path
from threading import Lock

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pygame
import yaml
from scipy.spatial import KDTree

from cooperative_cuisine import ROOT_DIR
from cooperative_cuisine.argument_parser import create_screenshot_parser
from cooperative_cuisine.environment import Environment
from cooperative_cuisine.pygame_2d_vis.game_colors import colors, RGB
from cooperative_cuisine.state_representation import (
    PlayerState,
    CookingEquipmentState,
    ItemState,
    EffectState,
)


class CacheFlags(Flag):
    """Select which static layers (background and/or counters) are rendered once and cached between frames.

    You can combine them via the `|` operator: `CacheFlags.BACKGROUND | CacheFlags.COUNTERS`.
    You can specify the cache flags via the visualization config under GameWindow.cache_flags (just name the
    flags in a list; the names are upper-cased before lookup, so `Counters` selects `COUNTERS`):
    ```yaml
    GameWindow:
      # optimization
      cache_flags: [Counters, Background]  # [None]
      reduced_background: true
    ```
    """

    NONE = 0
    """No caching."""
    BACKGROUND = auto()
    """Cache the Background lines / texture."""
    COUNTERS = auto()
    """Cache the basic counter drawing. If used without Background flag, you need to set the config attribute GameWindow.reduced_background to true (or the class attribute reduced_background of the Visualizer)."""


def calc_angle(vec_a: list[float], vec_b: list[float]) -> float:
    """Return the angle in degrees from ``vec_a`` to ``vec_b``.

    Both inputs are wrapped in ``pygame.math.Vector2`` and the result is whatever
    ``Vector2.angle_to`` reports for them.

    Args:
        vec_a: Start vector (any 2-element sequence).
        vec_b: Target vector (any 2-element sequence).
    """
    return pygame.math.Vector2(vec_a).angle_to(pygame.math.Vector2(vec_b))


def create_polygon(n, start_vec):
    """Return the vertices of a regular ``n``-gon around the origin.

    The first vertex is a copy of ``start_vec``. Every following vertex is the previous
    one multiplied by the 2x2 rotation matrix for ``2 * pi / n`` radians, so the vertices
    are produced by repeated rotation.

    Special case: for ``n == 1`` the flat origin vector ``np.array([0, 0])`` is returned
    instead of a list, so a single item ends up in the centre.

    Args:
        n: Number of vertices.
        start_vec: First vertex as a numpy array (it must support ``.copy()``).

    Returns:
        A list of ``n`` numpy vectors, or ``np.array([0, 0])`` when ``n == 1``.
    """
    if n == 1:
        return np.array([0, 0])

    step = (2 * np.pi) / n
    cos_step, sin_step = np.cos(step), np.sin(step)
    rotation = np.array([[cos_step, -sin_step], [sin_step, cos_step]])

    vertices = [start_vec.copy()]
    while len(vertices) < n:
        vertices.append(np.dot(rotation, vertices[-1]))
    return vertices


class Visualizer:
    """Class for visualizing the game state retrieved from the gameserver.
    2D game screen is drawn with pygame shapes and images.

    Args:
        config: Visualization configuration (loaded from yaml file) given as a dict.

    """

    def __init__(self, config):
        def option(section: str, key: str, default):
            """Return ``config[section][key]``, or ``default`` when the section or the key is missing."""
            if section in config and key in config[section]:
                return config[section][key]
            return default

        # Caches: scaled/rotated images and pre-rendered static surfaces (keyed by env id reference).
        self.image_cache_dict = {}
        self.surface_cache_dict = {}

        self.player_colors = []
        self.config = config

        # Fire animation state; draw_counter_contents advances it modulo 3 * fire_time_steps.
        self.fire_state = 0
        self.fire_time_steps = 8
        self.observation_screen = None

        self.USE_PLAYER_COOK_SPRITES = option("Gui", "use_player_cook_sprites", True)
        self.SHOW_INTERACTION_RANGE = option("Gui", "show_interaction_range", False)
        self.SHOW_COUNTER_CENTERS = option("Gui", "show_counter_centers", False)

        pygame.font.init()
        self.font = pygame.font.SysFont("Arial", 20)
        self.init_get_state_image = False

        # A hidden display mode; presumably needed so Surface.convert_alpha() in draw_image works.
        self.observation_screen = pygame.display.set_mode(
            (100, 100), flags=pygame.HIDDEN
        )

        self.grid_size = 48

        flag_names = option("GameWindow", "cache_flags", ["Counters", "Background"])
        selected_flags = [CacheFlags[name.upper()] for name in flag_names]
        self.cache_flags = CacheFlags.NONE
        for flag in selected_flags:
            self.cache_flags = self.cache_flags | flag

        self.grid_size_lock = Lock()
        self.reduced_background = option("GameWindow", "reduced_background", True)

    def invalidate_surface_cache(self):
        """Forget all pre-rendered static surfaces so the next ``draw_gamescreen`` call re-renders them.

        The attribute is rebound to a new empty dict; the previous dict object itself is left untouched.
        """
        self.surface_cache_dict = dict()

    def create_player_colors(self, n) -> None:
        """Assign a named color to each player.

        ``n + 1`` hues are sampled evenly over ``[0, 1]`` (both ends included, so the last hue equals
        the first one). Each hue is converted to a fully saturated, full-value RGB color. The nearest
        entry of the ``colors`` table is then looked up with a KD-tree. The resulting color *names* are
        stored in ``self.player_colors``.

        Args:
            n: Number of players to create colors for.
        """
        hues = np.linspace(0, 1, n + 1)

        color_names = list(colors)
        tree = KDTree(np.array(list(colors.values())))

        def nearest_color_name(hue):
            target = np.array([int(channel * 255) for channel in colorsys.hsv_to_rgb(hue, 1, 1)])
            _, nearest_idx = tree.query(target, k=1)
            return color_names[nearest_idx]

        self.player_colors = []
        self.player_colors.extend(nearest_color_name(hue) for hue in hues)

    def set_grid_size(self, grid_size: float):
        """Set the size of one kitchen grid cell in pixels.

        Surfaces in ``surface_cache_dict`` were rendered at the previous grid size. Without clearing them,
        ``draw_gamescreen`` would keep blitting stale, wrongly scaled backgrounds/counters. They are
        therefore dropped whenever the size actually changes; setting the same size again keeps the cache.

        Args:
            grid_size: New cell size in pixels.
        """
        if grid_size != self.grid_size:
            self.invalidate_surface_cache()
        self.grid_size = grid_size

    def model_to_world_coords(self, pos):
        """Convert a position in grid units to pixel coordinates of that cell's centre.

        Args:
            pos: Position in grid units (anything ``np.array`` accepts).

        Returns:
            ``(pos + 0.5) * grid_size`` as a numpy array.
        """
        cell_center = np.array(pos) + 0.5
        return cell_center * self.grid_size

    def draw_gamescreen(
        self,
        state: dict,
        controlled_player_idxs: list[int],
        env_id_ref=None,
    ):
        """Render the complete game state onto a freshly created surface.

        Static layers (background and/or counters, selected by ``self.cache_flags``) are rendered once per
        ``env_id_ref`` and stored in ``self.surface_cache_dict``; later calls blit the cached surface instead
        of redrawing. A cached surface whose size no longer matches the current kitchen size in pixels
        (e.g. after the grid size changed) is re-rendered instead of being reused.

        Args:
            state: The game state retrieved from the environment.
            controlled_player_idxs: Indices into ``state["players"]`` of the locally controlled players; the
                first two are highlighted with a red / blue circle underneath their sprites.
            env_id_ref: Key for the static-surface cache.

        Returns:
            A ``pygame.SRCALPHA`` surface of ``ceil(kitchen width/height * grid_size)`` pixels.
        """
        width = int(np.ceil(state["kitchen"]["width"] * self.grid_size))
        height = int(np.ceil(state["kitchen"]["height"] * self.grid_size))
        screen = pygame.Surface((width, height), pygame.SRCALPHA)

        def draw_background_layer():
            # Integer grid cells occupied by counters (the reduced background skips them).
            fixed_counter_pos = {(int(c["pos"][0]), int(c["pos"][1])) for c in state["counters"]}
            self.draw_background(
                surface=screen,
                width=width,
                height=height,
                fixed_counter_pos=fixed_counter_pos,
            )

        cached_surface = self.surface_cache_dict.get(env_id_ref)
        if cached_surface is not None and cached_surface.get_size() == (width, height):
            screen.blit(cached_surface, (0, 0))
        else:
            # (Re-)render the cacheable static layers and remember them for the following frames.
            if CacheFlags.BACKGROUND in self.cache_flags:
                draw_background_layer()
            if CacheFlags.COUNTERS in self.cache_flags:
                self.draw_counters(screen, state["counters"])
            self.surface_cache_dict[env_id_ref] = screen.copy()

        # Layers that are not cached are drawn every frame.
        if CacheFlags.BACKGROUND not in self.cache_flags:
            draw_background_layer()
        if CacheFlags.COUNTERS not in self.cache_flags:
            self.draw_counters(screen, state["counters"])

        self.draw_counter_contents(screen, state["counters"])

        for idx, col in zip(controlled_player_idxs, [colors["red"], colors["blue"]]):
            pygame.draw.circle(
                screen,
                col,
                self.model_to_world_coords(state["players"][int(idx)]["pos"]),
                (self.grid_size / 2),
            )

        self.draw_players(screen, state["players"])

        if state.get("view_restrictions"):
            self.draw_lightcones(screen, width, height, state)

        return screen

    def draw_lightcones(self, screen, width, height, state):
        """Hide everything the players cannot see according to their view restrictions.

        For every entry of ``state["view_restrictions"]`` a full-screen opaque white mask is created. The
        parts that entry cannot see are painted fully transparent black into it:
        - the area outside its view cone, unless the half-angle reaches 180 degrees;
        - if a range is set, the area outside a circle of that range.
        The per-entry masks are merged with ``BLEND_MAX``, so a pixel stays visible if any entry sees it.
        The merged mask is then multiplied into ``screen`` with ``BLEND_RGBA_MULT``, which leaves visible
        pixels unchanged and turns hidden pixels transparent black.

        Args:
            screen: The rendered game surface; modified in place.
            width: Width of the game surface in pixels.
            height: Height of the game surface in pixels.
            state: Game state. ``state["view_restrictions"]`` is a list of dicts with the keys:
                - ``direction``;
                - ``position`` (grid units);
                - ``angle`` (full opening angle in degrees);
                - ``range`` (falsy means unlimited; otherwise presumably grid cells if ``direction`` is a
                  unit vector — TODO confirm).
        """
        view_restrictions = state["view_restrictions"]
        mask = pygame.Surface(screen.get_size(), pygame.SRCALPHA)
        mask.fill((0, 0, 0, 0))
        # Colour used to cut hidden regions out of the white per-entry mask.
        mask_color = (0, 0, 0, 0)
        # NOTE(review): idx is unused.
        for idx, restriction in enumerate(view_restrictions):
            direction = pygame.math.Vector2(restriction["direction"])
            pos = pygame.math.Vector2(restriction["position"])
            # Half of the opening angle: the cone spans +/- angle around direction.
            angle = restriction["angle"] / 2
            view_range = restriction["range"]

            # A half-angle of 180 degrees is a full circle; no cone is cut out below in that case.
            angle = min(angle, 180)

            # Grid position -> pixel centre of the player's cell.
            pos = pygame.math.Vector2(self.model_to_world_coords(pos).tolist())

            # Large enough that the polygons below reach past every edge of the screen.
            rect_scale = max(width, height) * 2

            left_beam = pos + (direction.rotate(angle) * rect_scale * 2)
            right_beam = pos + (direction.rotate(-angle) * rect_scale * 2)

            cone_mask = pygame.surface.Surface(screen.get_size(), pygame.SRCALPHA)
            cone_mask.fill((255, 255, 255, 255))

            # Cone apex is moved 0.7 cells behind the player (presumably so the player's own cell stays lit).
            offset_front = direction * self.grid_size * 0.7
            if angle != 180:
                # Polygon from the apex out along the left beam, around behind the player, and back along
                # the right beam: i.e. the area outside the view wedge, which gets hidden.
                shadow_cone_points = [
                    pos - offset_front,
                    left_beam - offset_front,
                    left_beam + (direction.rotate(90) * rect_scale),
                    pos
                    - (direction * rect_scale * 2)
                    + (direction.rotate(90) * rect_scale),
                    pos
                    - (direction * rect_scale * 2)
                    + (direction.rotate(-90) * rect_scale),
                    right_beam + (direction.rotate(-90) * rect_scale),
                    right_beam - offset_front,
                ]
                # NOTE(review): light_cone_points is computed but never used.
                light_cone_points = [pos - offset_front, left_beam, right_beam]
                pygame.draw.polygon(
                    cone_mask,
                    mask_color,
                    shadow_cone_points,
                )

            if view_range:
                n_circle_points = 40

                # 40-gon approximating a circle of radius |direction| * view_range cells around pos.
                start_vec = np.array(-direction * view_range)
                points = (
                             np.array(create_polygon(n_circle_points, start_vec)) * self.grid_size
                         ) + pos

                circle_closed = np.concatenate([points, points[0:1]], axis=0)

                # "Keyhole" polygon: a large square around the player with the circle cut out. The circle
                # is joined to the square through a slit at pos - direction * rect_scale, so everything
                # outside the circle is hidden.
                corners = [
                    pos - (direction * rect_scale),
                    *circle_closed,
                    pos - (direction * rect_scale),
                    pos
                    - (direction * rect_scale)
                    + (direction.rotate(90) * rect_scale),
                    pos
                    + (direction * rect_scale)
                    + (direction.rotate(90) * rect_scale),
                    pos
                    + (direction * rect_scale)
                    + (direction.rotate(-90) * rect_scale),
                    pos
                    - (direction * rect_scale)
                    + (direction.rotate(-90) * rect_scale),
                ]

                pygame.draw.polygon(cone_mask, mask_color, corners)

            # Union of visible areas over all restrictions.
            mask.blit(cone_mask, (0, 0), special_flags=pygame.BLEND_MAX)

        # White (visible) mask pixels keep the screen; transparent black ones zero it out.
        screen.blit(
            mask,
            mask.get_rect(),
            special_flags=pygame.BLEND_RGBA_MULT,
        )

    def draw_background(
        self, surface: pygame.Surface, width: int, height: int, fixed_counter_pos: set[tuple[int, int]] | None,
    ):
        """Visualizes a game background as a grid of half-cell tiles.

        With ``self.reduced_background`` only tiles whose grid cell is not in ``fixed_counter_pos`` are
        filled and outlined; otherwise the whole surface is filled and every tile outlined.

        Args:
            surface: The pygame surface to draw the background on.
            width: The kitchen width in pixels.
            height: The kitchen height in pixels.
            fixed_counter_pos: Integer grid cells occupied by counters; ``None`` is treated as "no counters".
        """
        block_size = int(np.ceil(self.grid_size / 2))  # Side length of one half-cell tile.
        kitchen_config = self.config["Kitchen"]
        ground_color = colors[kitchen_config["ground_tiles_color"]]
        line_color = kitchen_config["background_lines"]

        if self.reduced_background:
            occupied_cells = fixed_counter_pos if fixed_counter_pos is not None else set()
            half_cell = self.grid_size / 2
            for x_idx, x in enumerate(np.arange(0, width, half_cell)):
                for y_idx, y in enumerate(np.arange(0, height, half_cell)):
                    # Two half-cell tiles per grid cell in each direction.
                    if (x_idx // 2, y_idx // 2) in occupied_cells:
                        continue
                    rect = pygame.Rect(np.round(x), np.round(y), block_size, block_size)
                    surface.fill(ground_color, rect)
                    pygame.draw.rect(surface, line_color, rect, 1)
        else:
            surface.fill(ground_color)
            for x in range(0, width, block_size):
                for y in range(0, height, block_size):
                    rect = pygame.Rect(x, y, block_size, block_size)
                    pygame.draw.rect(surface, line_color, rect, 1)

    def draw_image(
        self,
        screen: pygame.Surface,
        img_path: Path | str,
        size: float,
        pos: npt.NDArray,
        rot_angle=0,
        burnt: bool = False,
    ):
        """Blit an image centred at ``pos``; every transformed variant is loaded and prepared only once.

        Variants are cached in ``self.image_cache_dict``. The key is built from the path, the rounded
        size, the rounded rotation angle and the burnt flag.

        Args:
            screen: The pygame surface to draw the image on.
            img_path: The path to the image file, relative to ``ROOT_DIR / "pygame_2d_vis"``.
            size: The side length of the (square) image in pixels; rounded to an int.
            pos: The position of the center of the image, in pixels.
            rot_angle: Rotation in degrees; rounded to an int, 0 means no rotation.
            burnt: Draw a grayscale version of the image.
        """
        pixel_size = int(np.round(size))
        degrees = int(np.round(rot_angle))
        burnt_suffix = '-burnt' if burnt else ''
        cache_key = f"{img_path}-{pixel_size}-{degrees}{burnt_suffix}"

        image = self.image_cache_dict.get(cache_key)
        if image is None:
            image = pygame.image.load(ROOT_DIR / "pygame_2d_vis" / img_path).convert_alpha()
            if burnt:
                image = pygame.transform.grayscale(image)
            # TODO: smoothscale or not???
            image = pygame.transform.scale(image, (pixel_size, pixel_size))
            if degrees:
                image = pygame.transform.rotate(image, degrees)
            self.image_cache_dict[cache_key] = image

        target_rect = image.get_rect()
        target_rect.center = np.round(pos)
        screen.blit(image, target_rect)

    def draw_cook(
        self,
        screen: pygame.Surface,
        pos: npt.NDArray[float] | list[float],
        color: RGB,
        facing: npt.NDArray[float] | list[float],
    ):
        """Draw a player as a colored body circle plus the cook parts from the ``Cook`` config entry.

        Args:
            screen: The pygame surface to draw on.
            pos: Player position in grid units (array or list).
            color: Color of the body circle.
            facing: Facing direction of the player (array or list).
        """
        # Accept plain lists as annotated: arithmetic and .tolist() below need numpy arrays.
        pos = np.asarray(pos, dtype=float)
        facing = np.asarray(facing, dtype=float)
        pygame.draw.circle(
            screen,
            color,
            # The body circle sits a quarter cell behind the position, opposite to the facing direction.
            self.model_to_world_coords(pos - facing * 0.25),
            self.grid_size * 0.2,
        )
        self.draw_thing(screen, pos, self.config["Cook"]["parts"], scale=1.0, orientation=facing.tolist())

    def draw_players(
        self,
        screen: pygame.Surface,
        players: dict,
    ):
        """Visualizes the players, the items they hold and their currently targeted counter.

        Players are drawn with cook sprites, or, if ``USE_PLAYER_COOK_SPRITES`` is false, as circles with
        a triangle pointing in the facing direction. A held item is drawn half a cell in front of the
        player. The nearest counter is outlined in the player's color. With ``SHOW_INTERACTION_RANGE``
        debug circles are added.

        Args:
            screen: The pygame surface to draw the players on.
            players: The player states returned by the environment (iterated in order, positions in grid units).
        """
        for p_idx, player_dict in enumerate(players):
            player_dict: PlayerState

            pos = np.array(player_dict["pos"], dtype=float)
            facing = np.array(player_dict["facing_direction"], dtype=float)
            # ``pos`` is in grid units; the primitive shapes below need the pixel centre of the cell.
            world_pos = self.model_to_world_coords(pos)

            if self.USE_PLAYER_COOK_SPRITES:
                self.draw_cook(
                    screen, pos, colors[self.player_colors[p_idx]], facing
                )
            else:
                player_radius = 0.4
                size = player_radius * self.grid_size
                color1 = self.player_colors[p_idx]
                color2 = colors["white"]

                pygame.draw.circle(screen, color2, world_pos, size)
                pygame.draw.circle(screen, colors["blue"], world_pos, size, width=1)
                pygame.draw.circle(screen, colors[color1], world_pos, size // 2)

                # Triangle: narrow base across the player, tip half a cell in the facing direction.
                pygame.draw.polygon(
                    screen,
                    colors["blue"],
                    (
                        (
                            world_pos[0] + (facing[1] * 0.1 * self.grid_size),
                            world_pos[1] - (facing[0] * 0.1 * self.grid_size),
                        ),
                        (
                            world_pos[0] - (facing[1] * 0.1 * self.grid_size),
                            world_pos[1] + (facing[0] * 0.1 * self.grid_size),
                        ),
                        world_pos + (facing * 0.5 * self.grid_size),
                    ),
                )

            if player_dict["holding"] is not None:
                holding_item_pos = pos + (facing * 0.5)
                self.draw_item(
                    pos=holding_item_pos,
                    item=player_dict["holding"],
                    screen=screen,
                )

            if player_dict["current_nearest_counter_pos"]:
                # Top-left pixel corner of the targeted counter's cell.
                nearest_pos = np.round(np.array(player_dict["current_nearest_counter_pos"]) * self.grid_size)
                interaction_marker_width = max(1, int(self.grid_size / 15))
                pygame.draw.rect(
                    screen,
                    colors[self.player_colors[p_idx]],
                    rect=pygame.Rect(
                        *nearest_pos,
                        self.grid_size,
                        self.grid_size,
                    ),
                    width=interaction_marker_width,
                )

            if self.SHOW_INTERACTION_RANGE:
                interaction_center = world_pos + (facing * self.grid_size * 0.4)
                pygame.draw.circle(
                    screen,
                    colors["blue"],
                    interaction_center,
                    1.6 * self.grid_size,
                    width=1,
                )
                pygame.draw.circle(
                    screen, colors["red1"], interaction_center, 4
                )

    def draw_thing(
        self,
        screen: pygame.Surface,
        pos: npt.NDArray[float],
        parts: list[dict[str]],
        scale: float = 1.0,
        burnt: bool = False,
        orientation: list[float] | None = None,
        absolute_size=None,
        absolute=False,
    ):
        """Draws an item, based on its visual parts specified in the visualization config.

        Args:
            screen: the game screen to draw on.
            grid_size: size of a grid cell.
            pos: Where to draw the item parts.
            parts: The visual parts to draw.
            scale: Rescale the item by this factor.
            orientation: Rotate the item to face this direction.
        """
        for part in parts:
            part_type = part["type"]
            angle, angle_offset = 0, 0

            if absolute:
                draw_pos = pos
            else:
                draw_pos = self.model_to_world_coords(pos)

            if orientation is not None:
                angle_offset = calc_angle(orientation, [0, 1])
                if "rotate_image" in part.keys():
                    if part["rotate_image"]:
                        angle = calc_angle(orientation, [0, 1])
                else:
                    angle = angle_offset
            if "rotate_offset" in part.keys():
                angle_offset = 0

            match part_type:
                case "image":
                    if "center_offset" in part:
                        d = pygame.math.Vector2(part["center_offset"]) * self.grid_size
                        d.rotate_ip(angle_offset)
                        d[0] = -d[0]
                        draw_pos += np.array(d)
                    size = (
                        absolute_size
                        if absolute_size is not None
                        else part["size"] * scale
                    )
                    if not absolute:
                        size *= self.grid_size

                    self.draw_image(
                        screen,
                        part["path"],
                        size,
                        draw_pos,
                        burnt=burnt,
                        rot_angle=angle,
                    )

                case "rect":
                    if "center_offset" in part:
                        d = pygame.math.Vector2(part["center_offset"]) * self.grid_size
                        d.rotate_ip(angle_offset)
                        d[0] = -d[0]

                        draw_pos += np.array(d)
                    height = part["height"] * self.grid_size
                    width = part["width"] * self.grid_size
                    color = part["color"]
                    rect = pygame.Rect(
                        draw_pos[0] - (height / 2),
                        draw_pos[1] - (width / 2),
                        height,
                        width,
                    )
                    pygame.draw.rect(screen, color, rect)

                case "circle":
                    if "center_offset" in part:
                        d = pygame.math.Vector2(part["center_offset"]) * self.grid_size
                        d.rotate_ip(-angle_offset)
                        draw_pos += np.array(d)
                    radius = part["radius"] * self.grid_size
                    color = colors[part["color"]]

                    pygame.draw.circle(screen, color, draw_pos, radius)

    def draw_item(
        self,
        pos: npt.NDArray[float] | list[float],
        item: ItemState | CookingEquipmentState | EffectState,
        scale: float = 1.0,
        plate=False,
        screen=None,
        absolute=False,
    ):
        """Visualization of an item at the specified position. On a counter or in the hands of the player.

        The visual composition of the item is read from the visualization config, where it is specified as
        different parts to be drawn. On top of the item itself, these are drawn recursively:
        - its progress bar;
        - either its ready content or its content list;
        - its active effects.

        Args:
            pos: The position of the item to draw (grid units, or pixels if ``absolute``).
            item: The item to be drawn in the game.
            scale: Rescale the item by this factor.
            plate: Item is on a plate (soups are drawn differently on a plate and in a pot).
            screen: the pygame screen to draw on.
            absolute: ``pos`` is given in pixels. Only applied to the item itself and its progress bar.
        """
        if not isinstance(item, list):  # can we remove this check?
            item_type = item["type"]
            is_burnt = item_type.startswith("Burnt")
            if item_type in self.config or (
                is_burnt and item_type.replace("Burnt", "") in self.config
            ):
                item_key = item_type
                if "Soup" in item_key and plate:
                    item_key += "Plate"
                elif is_burnt:
                    # Burnt items reuse the normal item's parts and are drawn in grayscale instead.
                    item_key = item_key.replace("Burnt", "")
                elif item_key == "Fire":
                    # Animation frames Fire1..Fire3, advanced by draw_counter_contents.
                    item_key = (
                        f"{item_key}{int(self.fire_state / self.fire_time_steps) + 1}"
                    )

                self.draw_thing(
                    pos=pos,
                    parts=self.config[item_key]["parts"],
                    scale=scale,
                    screen=screen,
                    burnt=is_burnt,
                    absolute=absolute,
                )

        if "progress_percentage" in item and item["progress_percentage"] > 0.0:
            if item["inverse_progress"]:
                percentage = 1 - item["progress_percentage"]
            else:
                percentage = item["progress_percentage"]
            self.draw_progress_bar(
                screen,
                pos,
                percentage,
                attention=item["inverse_progress"],
                absolute=absolute,
            )

        content_ready = item["content_ready"] if "content_ready" in item else None
        if content_ready and (
            content_ready["type"] in self.config
            or (
                content_ready["type"].startswith("Burnt")
                and content_ready["type"].replace("Burnt", "") in self.config
            )
        ):
            self.draw_thing(
                pos=pos,
                parts=self.config[content_ready["type"].replace("Burnt", "")]["parts"],
                screen=screen,
                # Grayscale is decided by the ready content's own type, not by the container's type.
                burnt=content_ready["type"].startswith("Burnt"),
            )
        elif "content_list" in item and item["content_list"]:
            # Spread several contents on a small polygon around the centre; a single one stays centred.
            triangle_offsets = create_polygon(
                len(item["content_list"]), np.array([0, 0.15])
            )
            content_scale = 1 if len(item["content_list"]) == 1 else 0.6
            for idx, o in enumerate(item["content_list"]):
                self.draw_item(
                    pos=np.array(pos) + triangle_offsets[idx],
                    item=o,
                    scale=content_scale,
                    plate="Plate" in item["type"],
                    screen=screen,
                )
        if "active_effects" in item and item["active_effects"]:
            for effect in item["active_effects"]:
                self.draw_item(pos=pos, item=effect, screen=screen)

    def draw_progress_bar(
        self,
        screen: pygame.Surface,
        pos: npt.NDArray[float],
        percent: float,
        attention: bool = False,
        absolute: bool = False,
        size: float | None = None
    ):
        """Visualize progress of a progressing item as a bar along the bottom of its tile.

        Args:
            screen: The pygame surface to draw the progress bar on.
            pos: Centre of the tile: grid units, or pixels if ``absolute`` is true.
            percent: Progressed fraction; the bar is ``percent * size`` pixels wide.
            attention: Draw the bar red instead of green.
            absolute: ``pos`` is given in pixels; ``size`` must then be provided.
            size: Tile side length in pixels (only used with ``absolute``; otherwise the grid size is used).

        Raises:
            ValueError: If ``absolute`` is true and no (non-zero) ``size`` is given.
        """
        if absolute:
            # Explicit check instead of ``assert`` so the validation also runs under ``python -O``.
            if not size:
                raise ValueError("Size must be given if absolute is True.")
            bar_pos = np.asarray(pos, dtype=float) - (size / 2)
        else:
            size = self.grid_size
            # Top-left pixel corner of the tile.
            bar_pos = self.model_to_world_coords(np.asarray(pos, dtype=float) - 0.5)

        bar_height = size * 0.2
        progress_width = percent * size
        progress_bar = pygame.Rect(
            bar_pos[0],
            bar_pos[1] + size - bar_height,
            progress_width,
            bar_height,
        )
        pygame.draw.rect(screen, colors["red" if attention else "green1"], progress_bar)

    def draw_counter(
        self, screen: pygame.Surface, counter_dict: dict
    ):
        """Visualization of a counter at its position (without what lies on it).

        The generic ``Counter`` parts are drawn first, then the parts configured for the counter's type.
        A counter type without its own config entry whose name ends in ``Dispenser`` falls back to the
        generic ``Dispenser`` parts.

        Args:
            screen: The pygame surface to draw the counter on.
            counter_dict: The counter to visualize, given as a dict from the game state.

        Raises:
            ValueError: If no visual parts are configured for the counter type.
        """
        pos = np.array(counter_dict["pos"], dtype=float)
        counter_type = counter_dict["type"]
        orientation = counter_dict["orientation"] if "orientation" in counter_dict else None

        self.draw_thing(
            screen,
            pos,
            self.config["Counter"]["parts"],
            orientation=orientation,
        )

        if counter_type in self.config:
            parts = self.config[counter_type]["parts"]
        elif counter_type.endswith("Dispenser"):
            parts = self.config["Dispenser"]["parts"]
        else:
            raise ValueError(f"Can not draw counter type {counter_type}")
        self.draw_thing(
            screen=screen,
            pos=pos,
            parts=parts,
            orientation=orientation,
        )

    def draw_counter_occupier(
        self,
        screen: pygame.Surface,
        occupied_by: dict | list,
        pos: npt.NDArray[float],
        item_scale: float,
    ):
        """Visualization of a thing (or a stack of things) lying on a counter.

        Args:
            screen: The pygame surface to draw the item on the counter on.
            occupied_by: The thing that occupies the counter, or a list of things (e.g. stacked plates).
            pos: The position (grid units) of the counter which the thing lies on.
            item_scale: Relative scaling of the item.
        """
        # Multiple plates on plate return:
        if isinstance(occupied_by, list):
            for i, o in enumerate(occupied_by):
                # Each further plate is shifted slightly up (smaller y). No np.abs here: it used to flip
                # the stacking direction for counters near the top edge (y close to 0).
                stack_pos = np.array([pos[0], pos[1] - (i * 0.075)], dtype=float)
                self.draw_item(
                    screen=screen,
                    pos=stack_pos,
                    item=o,
                    scale=item_scale,
                )
        # All other items:
        else:
            self.draw_item(
                pos=pos,
                item=occupied_by,
                screen=screen,
                scale=item_scale,
            )

    def draw_counters(self, screen: pygame.Surface, counters: list[dict]):
        """Draw the static look of every counter (body and type decoration, not their contents).

        Args:
            screen: The pygame surface to draw the counters on.
            counters: The counter states (``state["counters"]``) returned by the environment.
        """
        for counter in counters:
            self.draw_counter(screen, counter)

    def draw_counter_contents(self, screen: pygame.Surface, counters: list[dict]):
        """Visualizes the contents and active effects of the counters in the environment.

        Items on non-plate dispensers are shifted by the configured ``Dispenser.item_offset`` and scaled by
        ``Dispenser.item_scale``. With ``SHOW_COUNTER_CENTERS`` a small debug dot marks each counter centre.
        Also advances the fire animation frame counter once per call.

        Args:
            screen: The pygame surface to draw on.
            counters: The counter states (``state["counters"]``) returned by the environment.
        """
        for counter in counters:
            if counter["occupied_by"]:
                # float dtype (as in draw_counter) so the fractional dispenser offset can be added in place.
                item_pos = np.array(counter["pos"], dtype=float)
                item_scale = 1.0

                counter_type = counter["type"]

                if counter_type.endswith("Dispenser") and "Plate" not in counter_type:
                    dispenser_config = self.config["Dispenser"]
                    if "item_offset" in dispenser_config:
                        offset_vec = pygame.math.Vector2(dispenser_config["item_offset"])
                        # Rotate the offset so it points opposite to the counter's orientation.
                        offset_vec.rotate_ip(
                            offset_vec.angle_to(
                                pygame.math.Vector2(counter["orientation"])
                            )
                            + 180
                        )
                        item_pos += offset_vec
                    if "item_scale" in dispenser_config:
                        item_scale = dispenser_config["item_scale"]

                self.draw_counter_occupier(
                    screen=screen,
                    occupied_by=counter["occupied_by"],
                    pos=item_pos,
                    item_scale=item_scale,
                )
            if counter["active_effects"]:
                for effect in counter["active_effects"]:
                    self.draw_item(
                        pos=np.array(counter["pos"]),
                        screen=screen,
                        item=effect,
                    )

            if self.SHOW_COUNTER_CENTERS:
                center = self.model_to_world_coords(counter["pos"])
                pygame.draw.circle(screen, colors["green1"], center, 1)

        self.fire_state = (self.fire_state + 1) % (3 * self.fire_time_steps)

    def draw_orders(
        self,
        screen: pygame.Surface,
        state: dict,
        grid_size: int,
        width: int,
        height: int,
        config: dict,
    ):
        """Visualization of the current orders.

        Fills ``screen`` with the configured background colour, then draws the orders
        in a single row. Each order is a red-framed ``grid_size`` square that contains
        a plate, the ordered meal, a progress bar for the remaining time and the
        order's score (bold, black) in its upper-left corner.

        Args:
            screen: pygame surface to draw the orders on, probably not the game screen itself.
            state: The game state returned by the environment. Uses ``state["orders"]``
                (entries with ``meal``, ``start_time``, ``max_duration`` and ``score``)
                and ``state["env_time"]``; both times are ISO-format strings.
            grid_size: Edge length of one drawn order, given in pixels.
            width: Width of the pygame window (not used by this method).
            height: Height of the pygame window; the row of orders is centred vertically in it.
            config: Visualization configuration (loaded from yaml file) given as a dict.

        """
        bg_color = colors[config["GameWindow"]["background_color"]]
        pygame.draw.rect(screen, bg_color, screen.get_rect())

        orders = state["orders"]
        if not orders:
            return

        # The same offset is used for x and y: the first order is as far from the
        # left edge as the row is from the top edge.
        order_rects_start = (height // 2) - (grid_size // 2)
        # Loop-invariant, so parse the current environment time only once.
        env_time = datetime.fromisoformat(state["env_time"])

        for idx, order in enumerate(orders):
            upper_left = np.array(
                [
                    order_rects_start + idx * grid_size * 1.2,
                    order_rects_start,
                ]
            )
            pygame.draw.rect(
                screen,
                colors["red"],
                pygame.Rect(
                    upper_left[0],
                    upper_left[1],
                    grid_size,
                    grid_size,
                ),
                width=2,
            )
            # A fresh middle-point array is passed to each draw call, so none of
            # them can see changes another call might make to its ``pos`` argument.
            self.draw_thing(
                pos=upper_left + (grid_size / 2),
                parts=config["Plate"]["parts"],
                screen=screen,
                scale=grid_size,
                absolute=True,
            )
            self.draw_item(
                pos=upper_left + (grid_size / 2),
                item={"type": order["meal"]},
                plate=True,
                scale=grid_size,
                screen=screen,
                absolute=True,
            )

            max_duration = order["max_duration"]
            deadline = datetime.fromisoformat(order["start_time"]) + timedelta(
                seconds=max_duration
            )
            remaining_seconds = (deadline - env_time).total_seconds()
            # Fraction of time left; negative once the order is overdue. A zero
            # duration would otherwise raise ZeroDivisionError, so it counts as expired.
            percentage = remaining_seconds / max_duration if max_duration else 0.0
            self.draw_progress_bar(
                pos=upper_left + (grid_size / 2),
                percent=percentage,
                screen=screen,
                attention=percentage < 0.25,
                size=grid_size,
                absolute=True,
            )

            self.font.set_bold(True)

            text_surface = self.font.render(str(order["score"]), True, (0, 0, 0))
            screen.blit(text_surface, upper_left)

    def get_state_image(
        self,
        state: dict,
        controlled_players: list[int] | None = None,
        grid_size: int | None = None,
        env_id_ref=None,
    ) -> npt.NDArray[np.uint8]:
        """Render ``state`` and return its pixels as an image array.

        Args:
            state: The game state returned by the environment.
            controlled_players: Player indices forwarded to ``draw_gamescreen``;
                defaults to ``[0]``.
            grid_size: Optional grid size (in pixels) used only for this render. The
                previous grid size is restored afterwards, even if drawing raises.
            env_id_ref: Forwarded unchanged to ``draw_gamescreen``.

        Returns:
            uint8 pixel array of shape ``(height, width, 3)``. It is a transposed view
            onto the drawn surface's pixels (``pygame.surfarray.pixels3d``), not a copy.
        """
        if controlled_players is None:
            controlled_players = [0]

        # Both paths hold the lock. Otherwise a default-size render could run while
        # another thread has temporarily swapped ``self.grid_size`` in the branch below.
        # NOTE(review): assumes draw_gamescreen / set_grid_size never acquire
        # grid_size_lock themselves. The original grid_size branch already called them
        # while holding it, so they cannot without deadlocking.
        with self.grid_size_lock:
            if grid_size is None:
                screen = self.draw_gamescreen(
                    state, controlled_players, env_id_ref=env_id_ref
                )
            else:
                pre_grid_size = self.grid_size
                try:
                    self.set_grid_size(grid_size)
                    screen = self.draw_gamescreen(
                        state, controlled_players, env_id_ref=env_id_ref
                    )
                finally:
                    self.set_grid_size(pre_grid_size)
            # surfarray is indexed (x, y); swap to row-major (y, x) image layout.
            return pygame.surfarray.pixels3d(screen).transpose((1, 0, 2))

    def get_state_image_by_size(self, state: dict,
                                max_size: int,
                                controlled_players: list[int] = None,
                                env_id_ref=None):
        grid_size = max_size / max(state["kitchen"]["width"], state["kitchen"]["height"])
        image = self.get_state_image(state, controlled_players, grid_size, env_id_ref=env_id_ref)
        if state["kitchen"]["width"] == state["kitchen"]["height"]:
            return image
        squared = np.zeros((max_size, max_size, 3), dtype=np.uint8)
        squared[:image.shape[0], :image.shape[1], :] = image
        return squared

    def draw_recipe_image(
        self, screen: pygame.Surface, graph_dict, width, height, node_size
    ) -> None:
        """Draw a recipe graph onto ``screen``.

        Each distinct x coordinate of the layout first gets its own evenly spaced
        column. All positions are then scaled into the ``width`` x ``height`` area so
        that every ``node_size`` node fits. Edges are drawn as black lines and nodes
        with the parts from ``self.config``:

        - final meals get a plate drawn beneath them;
        - nodes whose key contains ``"Soup"`` use the config entry ``key + "Plate"``;
        - all other nodes use the config entry ``key``.

        ``graph_dict`` is not modified.

        Args:
            screen: Surface to draw on. It is not cleared beforehand.
            graph_dict: Recipe graph with ``"layout"`` (node name -> (x, y)) and
                ``"edges"`` (pairs of node names). The part of a node name before the
                first ``_`` is its visualization config key.
            width: Width of the drawing area in pixels.
            height: Height of the drawing area in pixels.
            node_size: Size of each drawn node in pixels.
        """
        layout = graph_dict["layout"]
        if not layout:
            # Nothing to draw, and no positions to normalize.
            return

        pixel_positions = self._recipe_layout_to_pixels(
            np.asarray(list(layout.values()), dtype=float), width, height, node_size
        )
        node_positions = dict(zip(layout.keys(), pixel_positions))

        for start, end in graph_dict["edges"]:
            pygame.draw.line(
                screen,
                "black",
                node_positions[start],
                node_positions[end],
                width=5,
            )

        meals = {
            "Chips",
            "FriedFish",
            "Burger",
            "Salad",
            "TomatoSoup",
            "OnionSoup",
            "FishAndChips",
            "Pizza",
        }
        for name, pos in node_positions.items():
            key = name.split("_")[0]
            if key in meals:
                self.draw_thing(
                    screen,
                    np.array(pos),
                    self.config["Plate"]["parts"],
                    absolute_size=node_size,
                    absolute=True,
                )
            if "Soup" in key:
                # Same node size as every other node, including the plate beneath it.
                self.draw_thing(
                    screen,
                    np.array(pos),
                    self.config[key + "Plate"]["parts"],
                    absolute_size=node_size,
                    absolute=True,
                )
            else:
                self.draw_thing(
                    screen,
                    np.array(pos),
                    self.config[key]["parts"],
                    absolute_size=node_size,
                    absolute=True,
                )

    @staticmethod
    def _recipe_layout_to_pixels(positions, width, height, node_size):
        """Map raw (n, 2) layout positions to pixel centres inside width x height."""
        positions = np.array(positions, dtype=float)

        # Replace every x by an evenly spaced value between 0 and the original max x,
        # keeping the order of the distinct x values. np.unique returns them sorted,
        # so searchsorted gives each x's rank among them.
        unique_x_vals = np.unique(positions[:, 0])
        evenly_spaced = np.linspace(
            start=0,
            stop=np.max(positions[:, 0]),
            num=len(unique_x_vals),
        )
        positions[:, 0] = evenly_spaced[np.searchsorted(unique_x_vals, positions[:, 0])]

        positions = positions - positions.min(axis=0)
        # Avoids a zero maximum (and 0/0) when all nodes share a coordinate.
        positions[positions == 0] = 0.000001
        positions = (
            positions / positions.max(axis=0) * (np.array([width, height]) - node_size)
        )
        # Shift so node centres keep a half node distance from the border.
        positions += node_size / 2
        return positions


def save_screenshot(state: dict, config: dict, filename: str | Path) -> None:
    """Render a game state with a freshly built Visualizer and save it as an image.

    The visualizer gets one colour per player in ``state["players"]`` and a grid size
    of 40 pixels. The screen is drawn with player 0 as the controlled player.

    Args:
        state: The gamestate to visualize.
        config: Visualization config for the visualizer.
        filename: Filename to save the image to.

    """
    visualizer = Visualizer(config)
    player_count = len(state["players"])
    visualizer.create_player_colors(player_count)
    visualizer.set_grid_size(40)

    # Initialise pygame and its font module before anything is rendered.
    pygame.init()
    pygame.font.init()

    rendered = visualizer.draw_gamescreen(state, [0])
    pygame.image.save(rendered, filename)


def generate_recipe_images(config: dict, folder_path: str | Path):
    """Render one image per recipe graph of the basic environment.

    Builds an ``Environment`` from ``environment_config.yaml``, the ``basic.layout``
    layout and ``item_info.yaml`` under ``ROOT_DIR / "configs"``. Each graph returned
    by its recipe validation is saved as ``<meal>.png``.

    Args:
        config: Visualization config for the Visualizer.
        folder_path: Output directory. A relative path is resolved against
            ``ROOT_DIR``; an absolute path is used as given. The directory is created
            if it does not exist.
    """
    # Resolve the directory once and use it both for creating and for saving.
    # Mixing ROOT_DIR-relative creation with CWD-relative saving breaks for relative
    # folder paths, because the images' target directory may not exist.
    output_dir = ROOT_DIR / folder_path
    os.makedirs(output_dir, exist_ok=True)

    env = Environment(
        env_config=str(ROOT_DIR / "configs" / "environment_config.yaml"),
        layout_config=str(ROOT_DIR / "configs" / "layouts" / "basic.layout"),
        item_info=str(ROOT_DIR / "configs" / "item_info.yaml"),
        as_files=True,
        env_name="0",
    )

    viz = Visualizer(config)
    pygame.init()
    pygame.font.init()

    graph_dicts = env.recipe_validation.get_recipe_graphs()
    width, height, node_size = 700, 400, 80
    flags = pygame.HIDDEN
    for graph_dict in graph_dicts:
        # Kept per graph: the recipe drawing does not clear the screen itself.
        screen = pygame.display.set_mode((width, height), flags=flags)
        viz.draw_recipe_image(screen, graph_dict, width, height, node_size)
        pygame.image.save(screen, str(output_dir / f"{graph_dict['meal']}.png"))


def main(cli_args):
    """Runs the Cooperative Cuisine Image Generation process.

    Reads the visualization configuration (YAML) and the game state (JSON). It saves a
    screenshot of the state via ``save_screenshot`` and then renders the recipe graph
    images via ``generate_recipe_images`` into the directory containing the screenshot.

    Args:
        cli_args: Parsed command line arguments (options registered by
            ``create_screenshot_parser``; defaults are defined there) with the
            attributes:

            - ``state``: path of the JSON state file to visualize.
            - ``visualization_config``: path of the YAML visualization config file.
            - ``output_file``: path of the screenshot image to write. The recipe
              images are written to its parent directory.
    """
    # Explicit UTF-8 so the result does not depend on the platform's default encoding.
    with open(cli_args.visualization_config, "r", encoding="utf-8") as f:
        viz_config = yaml.safe_load(f)
    with open(cli_args.state, "r", encoding="utf-8") as f:
        state = json.load(f)
    save_screenshot(state, viz_config, cli_args.output_file)
    # Path() accepts str and Path alike, so ``.parent`` works for either.
    generate_recipe_images(viz_config, Path(cli_args.output_file).parent)


if __name__ == "__main__":
    # Script entry point: build the CLI, register the screenshot options (read by
    # main() as state, visualization_config and output_file) and run.
    cli_parser = argparse.ArgumentParser(
        prog="Cooperative Cuisine Image Generation",
        description="Generate images for a state in json.",
        epilog="For further information, see https://scs.pages.ub.uni-bielefeld.de/cocosy/cooperative-cuisine/overcooked_simulator.html",
    )
    create_screenshot_parser(cli_parser)
    main(cli_parser.parse_args())