import argparse
import dataclasses
import json
import logging
import random
import sys
from enum import Enum
from subprocess import Popen

import numpy as np
import pygame
import pygame_gui
import requests
import yaml
from pygame._sdl2 import get_drivers
from websockets.sync.client import connect

from overcooked_simulator import ROOT_DIR
from overcooked_simulator.game_server import CreateEnvironmentConfig
from overcooked_simulator.gui_2d_vis.drawing import Visualizer
from overcooked_simulator.gui_2d_vis.game_colors import colors
from overcooked_simulator.overcooked_environment import (
    Action,
    ActionType,
    InterActionData,
)
from overcooked_simulator.utils import (
    custom_asdict_factory,
    url_and_port_arguments,
    disable_websocket_logging_arguments,
    add_list_of_manager_ids_arguments,
)

# Report the SDL drivers known to pygame at debug level instead of printing
# them to stdout unconditionally every time this module is imported.
for driver in get_drivers():
    logging.getLogger(__name__).debug("SDL driver: %s", driver)


class MenuStates(Enum):
    """Screens the GUI can show; every member's value equals its name."""

    Start = "Start"  # start menu: layout and player selection
    Game = "Game"  # an environment is running
    End = "End"  # round finished, final score is shown


log = logging.getLogger(name=__name__)  # module-level logger for GUI events


class PlayerKeySet:
    """Set of keyboard keys for controlling one or more players.

    The four movement keys are given in the order Left, Right, Up, Down and map
    to the movement vectors ``[-1, 0]``, ``[1, 0]``, ``[0, -1]`` and ``[0, 1]``.
    The interact key interacts with counters, the pickup key picks things up or
    drops them, and the switch key cycles through the players this key set
    controls.
    """

    def __init__(
        self,
        move_keys: list[int],
        interact_key: int,
        pickup_key: int,
        switch_key: int,
        players: list[int],
    ):
        """Create a key set describing which keyboard keys control which players.

        Args:
            move_keys: Four pygame key codes for moving Left, Right, Up, Down.
            interact_key: Key code for interacting with counters.
            pickup_key: Key code for picking up / dropping items.
            switch_key: Key code for switching to the next controlled player.
            players: Player numbers this key set can control. The first one is
                selected initially (player 0 if the list is empty).
        """
        self.move_vectors: list[list[int]] = [[-1, 0], [1, 0], [0, -1], [0, 1]]
        self.key_to_movement: dict[int, list[int]] = dict(
            zip(move_keys, self.move_vectors)
        )
        self.move_keys: list[int] = move_keys
        self.interact_key: int = interact_key
        self.pickup_key: int = pickup_key
        self.switch_key: int = switch_key
        self.controlled_players: list[int] = players
        self.current_player: int = players[0] if players else 0
        self.current_idx = 0
        # Key sets sharing the same players; their selected index is skipped
        # in next_player so two key sets never steer the same player.
        self.other_keyset: list[PlayerKeySet] = []

    def set_controlled_players(self, controlled_players: list[int]) -> None:
        """Replace the controlled players and select the first of them.

        An empty list selects player 0, mirroring ``__init__``.
        """
        self.controlled_players = controlled_players
        self.current_player = controlled_players[0] if controlled_players else 0
        self.current_idx = 0

    def next_player(self) -> None:
        """Select the next controlled player that no other key set has selected.

        Indices currently selected by key sets in ``other_keyset`` are skipped.
        If every other index is taken, the selection wraps back to where it
        started. Does nothing when no players are controlled.
        """
        count = len(self.controlled_players)
        if count == 0:
            return
        taken = {other.current_idx for other in self.other_keyset}
        # At most `count` steps: after a full cycle we are back at the start,
        # which avoids the unbounded recursion when all indices are taken.
        for _ in range(count):
            self.current_idx = (self.current_idx + 1) % count
            if self.current_idx not in taken:
                break
        self.current_player = self.controlled_players[self.current_idx]


class PyGameGUI:
    """Visualisation of the overcooked environment and reading keyboard inputs using pygame."""

    def __init__(
        self,
        url: str,
        port: int,
        manager_ids: list[str],
    ):
        """Set up GUI state, load the visualization config and open the window.

        Args:
            url: Host of the game server.
            port: Port of the game server, used for both the HTTP and the
                websocket URL.
            manager_ids: Manager ids to choose from; one is picked at random.
        """
        self.game_screen: pygame.Surface | None = None
        self.FPS = 60
        self.running = True

        self.reset_gui_values()
        self.key_sets: list[PlayerKeySet] = []

        self.websocket_url = f"ws://{url}:{port}/ws/player/"
        self.websockets = {}

        self.request_url = f"http://{url}:{port}"
        self.manager_id = random.choice(manager_ids)

        with open(
            ROOT_DIR / "gui_2d_vis" / "visualization.yaml", "r", encoding="utf-8"
        ) as file:
            self.visualization_config = yaml.safe_load(file)
        window_cfg = self.visualization_config["GameWindow"]

        self.screen_margin = window_cfg["screen_margin"]
        self.min_width = window_cfg["min_width"]
        self.min_height = window_cfg["min_height"]

        self.buttons_width = window_cfg["buttons_width"]
        self.buttons_height = window_cfg["buttons_height"]

        self.order_bar_height = window_cfg["order_bar_height"]

        self.window_width = self.min_width
        self.window_height = self.min_height
        # No environment exists yet, so the game area is empty. These are read
        # by init_ui_elements/draw before setup_environment sets real values.
        self.game_width, self.game_height = 0, 0
        self.grid_size = 0

        self.main_window = pygame.display.set_mode(
            (self.window_width, self.window_height)
        )

        # NOTE(review): other assets live under "gui_2d_vis"; confirm this
        # directory exists.
        self.images_path = ROOT_DIR / "pygame_gui" / "images"

        self.menu_state = MenuStates.Start
        self.manager: pygame_gui.UIManager

        self.vis = Visualizer(self.visualization_config)

        # Bot agent processes started by setup_environment.
        self.sub_processes = []

    def get_window_sizes(self, state: dict):
        kitchen_width = state["kitchen"]["width"]
        kitchen_height = state["kitchen"]["height"]
        if self.visualization_config["GameWindow"]["WhatIsFixed"] == "window_width":
            game_width = self.visualization_config["GameWindow"]["size"]
            kitchen_aspect_ratio = kitchen_height / kitchen_width
            game_height = int(game_width * kitchen_aspect_ratio)
            grid_size = int(game_width / (kitchen_width - 0.1))

        elif self.visualization_config["GameWindow"]["WhatIsFixed"] == "window_height":
            game_height = self.visualization_config["GameWindow"]["size"]
            kitchen_aspect_ratio = kitchen_width / kitchen_height
            game_width = int(game_height * kitchen_aspect_ratio)
            grid_size = int(game_width / (kitchen_width - 0.1))

        elif self.visualization_config["GameWindow"]["WhatIsFixed"] == "grid":
            grid_size = self.visualization_config["GameWindow"]["size"]
            game_width, game_height = (
                kitchen_width * grid_size,
                kitchen_height * grid_size,
            )

        else:
            game_width, game_height = 0, 0
            grid_size = 0

        window_width, window_height = (
            game_width + (2 * self.screen_margin),
            game_height + (2 * self.screen_margin),  # bar with orders
        )

        window_width = max(window_width, self.min_width)
        window_height = max(window_height, self.min_height)
        return (
            int(window_width),
            int(window_height),
            int(game_width),
            int(game_height),
            grid_size,
        )

    def setup_player_keys(self, n=1, disjunct=False):
        """Create the keyboard key sets for the human players.

        Key set 1 is WASD (interact F, pickup E, switch SPACE). Key set 2 is the
        arrow keys (interact I, pickup O, switch P). Both can initially control
        all ``self.number_humans_to_be_added`` players.

        Args:
            n: Number of key sets to return (at most 2 are created).
            disjunct: With more than one key set, give key set 1 the even and
                key set 2 the odd player indices. Ignored for a single key set,
                which must be able to reach every player.

        Returns:
            The first ``n`` key sets; an empty list when ``n`` is 0.
        """
        if not n:
            return []

        players = list(range(self.number_humans_to_be_added))
        key_set1 = PlayerKeySet(
            move_keys=[pygame.K_a, pygame.K_d, pygame.K_w, pygame.K_s],
            interact_key=pygame.K_f,
            pickup_key=pygame.K_e,
            switch_key=pygame.K_SPACE,
            players=players,
        )
        key_set2 = PlayerKeySet(
            move_keys=[pygame.K_LEFT, pygame.K_RIGHT, pygame.K_UP, pygame.K_DOWN],
            interact_key=pygame.K_i,
            pickup_key=pygame.K_o,
            switch_key=pygame.K_p,
            players=players,
        )

        if n > 1:
            if disjunct:
                key_set1.set_controlled_players(players[::2])
                key_set2.set_controlled_players(players[1::2])
            else:
                # Both key sets share all players but never select the same
                # one; key set 2 starts on the second player.
                key_set1.set_controlled_players(players)
                key_set2.set_controlled_players(players)
                key_set1.other_keyset = [key_set2]
                key_set2.other_keyset = [key_set1]
                key_set2.next_player()
        return [key_set1, key_set2][:n]

    def handle_keys(self):
        """Send a movement action for every key set whose movement keys are held.

        Runs once per frame. The vectors of all held movement keys are summed
        and normalised (unless they cancel out to zero), then sent as a
        MOVEMENT action lasting one frame (``1 / FPS`` seconds).
        """
        pressed_keys = pygame.key.get_pressed()
        for key_set in self.key_sets:
            player_name = str(key_set.current_player)
            held = [pressed_keys[key] for key in key_set.move_keys]
            if not any(held):
                continue

            direction = np.zeros(2)
            for idx, is_held in enumerate(held):
                if is_held:
                    direction += key_set.move_vectors[idx]
            length = np.linalg.norm(direction)
            if length != 0:
                direction = direction / length

            self.send_action(
                Action(
                    player_name,
                    ActionType.MOVEMENT,
                    direction,
                    duration=1 / self.FPS,
                )
            )

    def handle_key_event(self, event):
        """Translate pickup, interact and switch key events into actions.

        For every key set:

        - pressing the pickup key sends a single PUT action;
        - the interact key sends INTERACT START on key-down and INTERACT STOP on
          key-up, so an interaction can be held;
        - pressing the switch key selects the key set's next player.

        Args:
            event: Pygame key event; ``event.type`` and ``event.key`` are read.
        """
        for key_set in self.key_sets:
            player_name = str(key_set.current_player)
            key_down = event.type == pygame.KEYDOWN

            if key_down and event.key == key_set.pickup_key:
                self.send_action(Action(player_name, ActionType.PUT, "pickup"))

            if event.key == key_set.interact_key:
                if key_down:
                    self.send_action(
                        Action(player_name, ActionType.INTERACT, InterActionData.START)
                    )
                elif event.type == pygame.KEYUP:
                    self.send_action(
                        Action(player_name, ActionType.INTERACT, InterActionData.STOP)
                    )

            if key_down and event.key == key_set.switch_key:
                key_set.next_player()

    def init_ui_elements(self):
        """(Re)create the pygame_gui manager and all UI elements for the window size.

        A new ``UIManager`` of ``window_width`` x ``window_height`` is built and
        the GUI theme is loaded. The elements are then created in a fixed order:
        1. the menu/game buttons;
        2. the labels and the layout drop-down;
        3. the player-selection panel with its option buttons;
        4. the human/bot counter panels.

        Visibility per menu state is handled by ``manage_button_visibility``.

        Raises:
            FileNotFoundError: If ``game_content/layouts`` has no ``*.layout`` file.
        """
        self.manager = pygame_gui.UIManager((self.window_width, self.window_height))
        self.manager.get_theme().load_theme(ROOT_DIR / "gui_2d_vis" / "gui_theme.json")

        self._init_buttons()
        self._init_labels_and_layout_selection()
        self._init_player_selection_panel()
        self._init_player_count_panels()

    def _init_buttons(self):
        """Create the start, quit, reset, finish-round and back-to-menu buttons."""
        self.start_button = pygame_gui.elements.UIButton(
            relative_rect=pygame.Rect(
                (0, 0), (self.buttons_width, self.buttons_height)
            ),
            text="Start Game",
            manager=self.manager,
            anchors={"center": "center"},
        )

        quit_rect = pygame.Rect((0, 0), (self.buttons_width, self.buttons_height))
        # Offset relative to the top-right anchor.
        quit_rect.topright = (0, 0)
        self.quit_button = pygame_gui.elements.UIButton(
            relative_rect=quit_rect,
            text="Quit Game",
            manager=self.manager,
            object_id="#quit_button",
            anchors={"right": "right", "top": "top"},
        )

        # Placed inside the right-hand screen margin.
        self.reset_button = pygame_gui.elements.UIButton(
            relative_rect=pygame.Rect(
                (
                    self.window_width - (self.screen_margin * 3 // 4),
                    self.screen_margin,
                ),
                (self.screen_margin - (self.screen_margin // 4), 50),
            ),
            text="RESET",
            manager=self.manager,
            object_id="#reset_button",
        )

        self.finished_button = pygame_gui.elements.UIButton(
            relative_rect=pygame.Rect(
                (
                    self.window_width - self.buttons_width,
                    self.window_height - self.buttons_height,
                ),
                (self.buttons_width, self.buttons_height),
            ),
            text="Finish round",
            manager=self.manager,
        )

        self.back_button = pygame_gui.elements.UIButton(
            relative_rect=pygame.Rect(
                (0, self.window_height - self.buttons_height),
                (self.buttons_width, self.buttons_height),
            ),
            text="Back to menu",
            manager=self.manager,
        )

    def _init_labels_and_layout_selection(self):
        """Create the score, timer, orders and conclusion labels and the layout drop-down."""
        self.score_label = pygame_gui.elements.UILabel(
            text="Score: _",
            relative_rect=pygame.Rect(
                (0, self.window_height - self.screen_margin),
                (self.screen_margin * 2, self.screen_margin),
            ),
            manager=self.manager,
            object_id="#score_label",
        )

        # Map "<name>.layout" -> path for every layout file in game_content.
        self.layout_file_paths = {
            str(p.name): p
            for p in (ROOT_DIR / "game_content" / "layouts").glob("*.layout")
        }
        if not self.layout_file_paths:
            raise FileNotFoundError("No layout files.")
        layout_names = list(self.layout_file_paths)
        dropdown_width, dropdown_height = 200, 40
        self.layout_selection = pygame_gui.elements.UIDropDownMenu(
            relative_rect=pygame.Rect((0, 0), (dropdown_width, dropdown_height)),
            manager=self.manager,
            options_list=layout_names,
            # Prefer the basic layout; otherwise start on a random one.
            starting_option="basic.layout"
            if "basic.layout" in self.layout_file_paths
            else random.choice(layout_names),
        )

        self.timer_label = pygame_gui.elements.UILabel(
            text="GAMETIME",
            relative_rect=pygame.Rect(
                (self.screen_margin, self.window_height - self.screen_margin),
                (self.game_width, self.screen_margin),
            ),
            manager=self.manager,
            object_id="#timer_label",
        )

        self.orders_label = pygame_gui.elements.UILabel(
            text="Orders:",
            relative_rect=pygame.Rect(0, 0, self.screen_margin, self.screen_margin),
            manager=self.manager,
            object_id="#orders_label",
        )

        # Covers the whole window; reuses the score label's theme.
        self.conclusion_label = pygame_gui.elements.UILabel(
            text="Your final score was _",
            relative_rect=pygame.Rect(0, 0, self.window_width, self.window_height),
            manager=self.manager,
            object_id="#score_label",
        )

    def _init_player_selection_panel(self):
        """Create the bottom panel with the key-set, split-players and controller buttons."""
        player_selection_rect = pygame.Rect(
            (0, 0),
            (self.window_width * 0.9, self.window_height // 3),
        )
        # 10 px above the bottom edge of the window.
        player_selection_rect.bottom = -10
        self.player_selection_container = pygame_gui.elements.UIPanel(
            player_selection_rect,
            manager=self.manager,
            object_id="#players",
            anchors={"bottom": "bottom", "centerx": "centerx"},
        )

        self.multiple_keysets_button = pygame_gui.elements.UIButton(
            relative_rect=pygame.Rect((0, 0), (190, 50)),
            manager=self.manager,
            container=self.player_selection_container,
            text="not set",
            anchors={"left": "left", "centery": "centery"},
            object_id="#multiple_keysets_button",
        )

        self.split_players_button = pygame_gui.elements.UIButton(
            relative_rect=pygame.Rect((0, 0), (190, 50)),
            manager=self.manager,
            container=self.player_selection_container,
            text="not set",
            anchors={"centerx": "centerx", "centery": "centery"},
            object_id="#split_players_button",
        )
        # Splitting players between key sets is only offered with two key sets.
        if self.multiple_keysets:
            self.split_players_button.show()
        else:
            self.split_players_button.hide()

        xbox_controller_button_rect = pygame.Rect((0, 0), (190, 50))
        xbox_controller_button_rect.right = 0
        self.xbox_controller_button = pygame_gui.elements.UIButton(
            relative_rect=xbox_controller_button_rect,
            manager=self.manager,
            container=self.player_selection_container,
            text="Controller?",
            anchors={"right": "right", "centery": "centery"},
            object_id="#controller_button",
        )

    def _init_player_count_panels(self):
        """Create the human (top) and bot (bottom) counters with labels and +/- buttons."""
        self.player_number_container = self._make_counter_panel(
            "#players_players", "top"
        )
        self.bot_number_container = self._make_counter_panel("#players_bots", "bottom")

        self.added_players_label = pygame_gui.elements.UILabel(
            pygame.Rect((0, 0), (200, 200)),
            manager=self.manager,
            object_id="#number_players_label",
            container=self.player_number_container,
            text=f"Humans to be added: {self.number_humans_to_be_added}",
            anchors={"center": "center"},
        )
        self.added_bots_label = pygame_gui.elements.UILabel(
            pygame.Rect((0, 0), (200, 200)),
            manager=self.manager,
            object_id="#number_bots_label",
            container=self.bot_number_container,
            text=f"Bots to be added: {self.number_bots_to_be_added}",
            anchors={"center": "center"},
        )

        self.add_human_player_button = self._make_quantity_button(
            "+", self.player_number_container, "right"
        )
        self.remove_human_button = self._make_quantity_button(
            "-", self.player_number_container, "left"
        )
        self.add_bot_button = self._make_quantity_button(
            "+", self.bot_number_container, "right"
        )
        self.remove_bot_button = self._make_quantity_button(
            "-", self.bot_number_container, "left"
        )

    def _make_counter_panel(self, object_id, vertical_anchor):
        """Create a counter panel in the player-selection panel, anchored top or bottom."""
        rect = pygame.Rect(
            (0, 0),
            (
                self.window_width * 0.6,
                self.player_selection_container.get_abs_rect().height // 3,
            ),
        )
        # Align the anchored edge with the container edge ("top" is a no-op).
        setattr(rect, vertical_anchor, 0)
        return pygame_gui.elements.UIPanel(
            relative_rect=rect,
            manager=self.manager,
            object_id=object_id,
            container=self.player_selection_container,
            anchors={vertical_anchor: vertical_anchor, "centerx": "centerx"},
        )

    def _make_quantity_button(self, text, container, side):
        """Create a 50x50 "+"/"-" button anchored to the left or right of ``container``."""
        size = 50
        rect = pygame.Rect((0, 0), (size, size))
        # Align the anchored edge with the container edge ("left" is a no-op).
        setattr(rect, side, 0)
        return pygame_gui.elements.UIButton(
            relative_rect=rect,
            text=text,
            manager=self.manager,
            object_id="#quantity_button",
            container=container,
            anchors={side: side, "centery": "centery"},
        )

    def draw(self, state):
        """Draw one frame of the running game.

        The steps are, in order:
        1. render the kitchen onto ``self.game_screen``;
        2. refresh the timer label;
        3. draw the orders onto the main window;
        4. outline the window-centred game area with a border;
        5. refresh the score label.

        Args:
            state: Game state from the environment. ``remaining_time`` and
                ``score`` are read here; the whole dict is passed on to the
                visualizer.
        """
        window_cfg = self.visualization_config["GameWindow"]
        selected_players = [key_set.current_player for key_set in self.key_sets]
        self.vis.draw_gamescreen(
            self.game_screen, state, self.grid_size, selected_players
        )

        self.update_remaining_time(state["remaining_time"])

        self.vis.draw_orders(
            screen=self.main_window,
            state=state,
            grid_size=self.grid_size,
            width=self.game_width,
            height=self.game_height,
            screen_margin=self.screen_margin,
            config=self.visualization_config,
        )

        # Border of `thickness` px drawn around the centred game area.
        thickness = window_cfg["game_border_size"]
        outline = pygame.Rect(
            self.window_width // 2 - self.game_width // 2 - thickness,
            self.window_height // 2 - self.game_height // 2 - thickness,
            self.game_width + 2 * thickness,
            self.game_height + 2 * thickness,
        )
        border_color = colors[window_cfg["game_border_color"]]
        pygame.draw.rect(self.main_window, border_color, outline, width=thickness)

        self.update_score_label(state)

    def set_window_size(self):
        """Allocate the game surface and resize the main window to the stored sizes."""
        game_size = (self.game_width, self.game_height)
        window_size = (self.window_width, self.window_height)
        self.game_screen = pygame.Surface(game_size)
        self.main_window = pygame.display.set_mode(window_size)

    def reset_window_size(self):
        self.window_width = self.min_width
        self.window_height = self.min_height
        self.game_width = 0
        self.game_height = 0
        self.set_window_size()
        self.init_ui_elements()

    def manage_button_visibility(self):
        match self.menu_state:
            case MenuStates.Start:
                self.back_button.hide()
                self.quit_button.show()
                self.start_button.show()
                self.reset_button.hide()
                self.score_label.hide()
                self.finished_button.hide()
                self.layout_selection.show()
                self.timer_label.hide()
                self.orders_label.hide()
                self.conclusion_label.hide()

                self.player_selection_container.show()
            case MenuStates.Game:
                self.start_button.hide()
                self.back_button.hide()
                self.score_label.show()
                self.reset_button.show()
                self.score_label.show()
                self.finished_button.show()
                self.layout_selection.hide()
                self.timer_label.show()
                self.orders_label.show()
                self.conclusion_label.hide()

                self.player_selection_container.hide()

            case MenuStates.End:
                self.start_button.hide()
                self.back_button.show()
                self.score_label.hide()
                self.reset_button.hide()
                self.finished_button.hide()
                self.layout_selection.hide()
                self.timer_label.hide()
                self.orders_label.hide()
                self.conclusion_label.show()

                self.player_selection_container.hide()

    def update_score_label(self, state):
        score = state["score"]
        self.score_label.set_text(f"Score {score}")

    def update_conclusion_label(self, state):
        score = state["score"]
        self.conclusion_label.set_text(f"Your final score is {score}. Hurray!")

    def update_remaining_time(self, remaining_time: float):
        hours, rem = divmod(int(remaining_time), 3600)
        minutes, seconds = divmod(rem, 60)
        display_time = f"{minutes}:{'%02d' % seconds}"
        self.timer_label.set_text(f"Time remaining: {display_time}")

    def setup_environment(self):
        """Create a new environment on the game server and connect its players.

        Sends the item info, the selected layout and the environment config to
        ``/manage/create_env/``. The first ``number_humans_to_be_added``
        players get a websocket owned by this GUI. Every other player is handed
        to a ``random_agent.py`` subprocess. The last human's websocket (or,
        without humans, a websocket for player ``"0"``) is used to fetch an
        initial state, from which the window, game and grid sizes are derived.

        Raises:
            ValueError: If the server rejects the request with status 403.
            requests.HTTPError: If the server answers with another error status.
        """
        item_info, layout, environment_config = self._read_environment_files()

        seed = 161616161616  # fixed seed passed to the server
        creation_json = CreateEnvironmentConfig(
            manager_id=self.manager_id,
            number_players=self.number_players,
            environment_settings={"all_player_can_pause_game": False},
            item_info_config=item_info,
            environment_config=environment_config,
            layout_config=layout,
            seed=seed,
        ).model_dump(mode="json")

        response = requests.post(
            f"{self.request_url}/manage/create_env/",
            json=creation_json,
        )
        if response.status_code == 403:
            raise ValueError(f"Forbidden Request: {response.json()['detail']}")
        # Fail with a clear HTTP error instead of a confusing KeyError/JSON error.
        response.raise_for_status()
        env_info = response.json()
        assert isinstance(env_info, dict), "Env info must be a dictionary"
        self.current_env_id = env_info["env_id"]
        self.player_info = env_info["player_info"]

        state = None
        for p, (player_id, player_info) in enumerate(env_info["player_info"].items()):
            if p >= self.number_humans_to_be_added:
                self._start_bot(player_id, player_info)
                continue
            websocket = self._connect_player(player_id, player_info)
            if p + 1 == self.number_humans_to_be_added:
                # The GUI reads the game state through the last human's socket.
                self.state_player_id = player_id
                state = self._request_state(websocket, player_info)

        if not self.number_humans_to_be_added:
            # Without humans the GUI still needs a connection to read the state.
            player_id = "0"
            player_info = env_info["player_info"][player_id]
            websocket = self._connect_player(player_id, player_info)
            self.state_player_id = player_id
            state = self._request_state(websocket, player_info)

        (
            self.window_width,
            self.window_height,
            self.game_width,
            self.game_height,
            self.grid_size,
        ) = self.get_window_sizes(state)

    def _read_environment_files(self) -> tuple[str, str, str]:
        """Return the raw text of the item info, the selected layout and the environment config."""
        content_dir = ROOT_DIR / "game_content"
        layout_path = self.layout_file_paths[self.layout_selection.selected_option]
        item_info = (content_dir / "item_info.yaml").read_text(encoding="utf-8")
        layout = layout_path.read_text(encoding="utf-8")
        environment_config = (content_dir / "environment_config.yaml").read_text(
            encoding="utf-8"
        )
        return item_info, layout, environment_config

    def _connect_player(self, player_id, player_info: dict):
        """Open ``player_id``'s websocket, announce it as ready and store it in ``self.websockets``."""
        websocket = connect(self.websocket_url + player_info["client_id"])
        websocket.send(
            json.dumps({"type": "ready", "player_hash": player_info["player_hash"]})
        )
        assert json.loads(websocket.recv())["status"] == 200, "not accepted player"
        self.websockets[player_id] = websocket
        return websocket

    @staticmethod
    def _request_state(websocket, player_info: dict) -> dict:
        """Ask the server for the current game state over ``websocket`` and return it."""
        websocket.send(
            json.dumps({"type": "get_state", "player_hash": player_info["player_hash"]})
        )
        return json.loads(websocket.recv())

    def _start_bot(self, player_id, player_info: dict) -> None:
        """Start a random-agent subprocess that plays as ``player_id``."""
        uri = self.websocket_url + player_info["client_id"]
        player_hash = player_info["player_hash"]
        # Arguments for starting an external (aaambos) agent for this player by hand.
        print(
            f'--general_plus="agent_websocket:{uri};player_hash:{player_hash};agent_id:{player_id}"'
        )
        # Argument list without a shell: server-supplied values cannot be
        # interpreted by the shell, and the agent runs on this interpreter.
        sub = Popen(
            [
                sys.executable,
                str(ROOT_DIR / "game_content" / "agents" / "random_agent.py"),
                "--uri",
                uri,
                "--player_hash",
                str(player_hash),
                "--player_id",
                str(player_id),
            ]
        )
        self.sub_processes.append(sub)

    def start_button_press(self):
        """Start a new game with the players and key sets chosen in the menu.

        Sets up the key sets, creates the environment on the server, resizes
        the window and rebuilds the UI. The menu switches to
        ``MenuStates.Game`` only once all of this has succeeded.

        Raises:
            ValueError: If players should be split across two key sets but
                fewer than two humans were selected.
        """
        self.number_players = (
            self.number_humans_to_be_added + self.number_bots_to_be_added
        )
        self.vis.create_player_colors(self.number_players)

        # Splitting only makes sense with two key sets; split_players may
        # still be True after the second key set has been switched off.
        split = self.split_players and self.multiple_keysets
        if split and self.number_humans_to_be_added <= 1:
            raise ValueError("Not enough players for key configuration.")
        num_key_set = 2 if self.multiple_keysets else 1
        self.key_sets = self.setup_player_keys(
            min(self.number_humans_to_be_added, num_key_set), split
        )

        self.setup_environment()
        self.set_window_size()
        self.init_ui_elements()

        self.menu_state = MenuStates.Game
        log.debug("Pressed start button")

    def back_button_press(self):
        """Return from the end screen to the start menu at the minimum window size."""
        self.menu_state = MenuStates.Start
        self.reset_window_size()
        # The rebuilt widgets need the current selection values again.
        self.update_selection_elements()
        log.debug("Pressed back button")

    def quit_button_press(self):
        """Clear ``running`` (presumably ending the main loop) and reset the menu state to Start."""
        self.running, self.menu_state = False, MenuStates.Start
        log.debug("Pressed quit button")

    def reset_button_press(self):
        """Ask the game server to stop the current environment.

        Only the stop request is sent; GUI values and the menu state are left
        unchanged here.
        """
        requests.post(
            f"{self.request_url}/manage/stop_env",
            json={
                "manager_id": self.manager_id,
                "env_id": self.current_env_id,
                "reason": "reset button pressed",
            },
        )
        log.debug("Pressed reset button")

    def finished_button_press(self):
        """Stop the environment on the server and switch to the end screen."""
        stop_request = {
            "manager_id": self.manager_id,
            "env_id": self.current_env_id,
            "reason": "finish button pressed",
        }
        requests.post(f"{self.request_url}/manage/stop_env", json=stop_request)
        self.menu_state = MenuStates.End
        self.reset_window_size()
        log.debug("Pressed finished button")

    def reset_gui_values(self):
        """Restore the start-menu selection values to their defaults."""
        defaults = {
            "currently_controlled_player_idx": 0,
            "number_humans_to_be_added": 1,
            "number_bots_to_be_added": 0,
            "split_players": False,
            "multiple_keysets": False,
            "player_minimum": 1,
        }
        for attribute, value in defaults.items():
            setattr(self, attribute, value)

    def update_selection_elements(self):
        """Synchronise the start-menu widgets with the current selection values.

        Clamps the number of humans to at least ``self.player_minimum``. At that
        floor the "remove human" button is disabled; above it, the button is
        enabled. Also refreshes the key-set, player-count and split-player
        texts. The split-players button is only shown while multiple key sets
        are selected.
        """
        if self.number_humans_to_be_added > self.player_minimum:
            self.remove_human_button.enable()
        else:
            self.remove_human_button.disable()
        self.number_humans_to_be_added = max(
            self.player_minimum, self.number_humans_to_be_added
        )

        keyset_text = "WASD+ARROW" if self.multiple_keysets else "WASD"
        self.multiple_keysets_button.set_text(keyset_text)
        self.added_players_label.set_text(
            f"Humans to be added: {self.number_humans_to_be_added}"
        )
        self.added_bots_label.set_text(
            f"Bots to be added: {self.number_bots_to_be_added}"
        )
        split_text = "Yes" if self.split_players else "No"
        self.split_players_button.set_text(f"Split players: {split_text}")

        toggle_visibility = (
            self.split_players_button.show
            if self.multiple_keysets
            else self.split_players_button.hide
        )
        toggle_visibility()

    def send_action(self, action: Action):
        """Sends an action to the game environment.

        The action goes out over the websocket of ``action.player`` and is
        authenticated with that player's hash. The server's reply is read and
        discarded.

        Args:
            action: The action to be sent. Contains the player, action type and move direction if action is a movement.
                If ``action.action_data`` is a numpy array, it is replaced in place by a list of its
                first two elements as Python floats so the action can be JSON-encoded.
        """
        connection = self.websockets[action.player]

        vector = action.action_data
        if isinstance(vector, np.ndarray):
            # numpy arrays are not JSON-serialisable.
            action.action_data = [float(vector[0]), float(vector[1])]

        message = {
            "type": "action",
            "action": dataclasses.asdict(action, dict_factory=custom_asdict_factory),
            "player_hash": self.player_info[action.player]["player_hash"],
        }
        connection.send(json.dumps(message))
        connection.recv()

    def request_state(self):
        """Fetch the current environment state from the game server.

        The request is sent over the websocket stored under
        ``self.state_player_id``. The player hash comes from the player
        currently selected by the first key set.
        NOTE(review): socket id and hash source differ — confirm this is intended.

        Returns:
            The JSON-decoded reply of the server.
        """
        channel = self.websockets[self.state_player_id]
        controlled_player = str(self.key_sets[0].current_idx)
        request = {
            "type": "get_state",
            "player_hash": self.player_info[controlled_player]["player_hash"],
        }
        channel.send(json.dumps(request))
        return json.loads(channel.recv())

    def disconnect_websockets(self):
        """Kill all spawned sub-processes, forget them, and close every websocket.

        A sub-process that has already disappeared (``ProcessLookupError``) is
        skipped silently.
        """
        for process in self.sub_processes:
            try:
                process.kill()
            except ProcessLookupError:
                pass  # already gone; nothing to clean up
        self.sub_processes = []

        for connection in self.websockets.values():
            connection.close()

    def start_pygame(self):
        """Starts pygame and the gui loop. Each frame the game state is visualized and keyboard inputs are read."""
        log.debug(f"Starting pygame gui at {self.FPS} fps")
        pygame.init()
        pygame.font.init()

        pygame.display.set_caption("Simple Overcooked Simulator")

        clock = pygame.time.Clock()

        self.init_ui_elements()
        self.reset_window_size()
        self.manage_button_visibility()

        self.update_selection_elements()

        # Game loop
        self.running = True
        while self.running:
            try:
                time_delta = clock.tick(self.FPS) / 1000.0

                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        self.running = False

                        # UI Buttons:
                    if event.type == pygame_gui.UI_BUTTON_PRESSED:
                        match event.ui_element:
                            case self.start_button:
                                if not (
                                    self.number_humans_to_be_added
                                    + self.number_bots_to_be_added
                                ):
                                    continue
                                self.start_button_press()
                            case self.back_button:
                                self.back_button_press()
                                self.disconnect_websockets()

                            case self.finished_button:
                                self.finished_button_press()
                                self.disconnect_websockets()
                            case self.quit_button:
                                self.quit_button_press()
                                self.disconnect_websockets()
                            case self.reset_button:
                                self.reset_button_press()
                                self.disconnect_websockets()
                                self.start_button_press()

                            case self.add_human_player_button:
                                self.number_humans_to_be_added += 1
                            case self.remove_human_button:
                                self.number_humans_to_be_added = max(
                                    0, self.number_humans_to_be_added - 1
                                )
                            case self.add_bot_button:
                                self.number_bots_to_be_added += 1
                            case self.remove_bot_button:
                                self.number_bots_to_be_added = max(
                                    0, self.number_bots_to_be_added - 1
                                )
                            case self.multiple_keysets_button:
                                self.multiple_keysets = not self.multiple_keysets
                                self.split_players = False
                            case self.split_players_button:
                                self.split_players = not self.split_players
                                if self.split_players:
                                    self.player_minimum = 2
                                else:
                                    self.player_minimum = 1

                            case self.xbox_controller_button:
                                print("xbox_controller_button pressed.")

                        self.update_selection_elements()

                        self.manage_button_visibility()

                    if (
                        event.type in [pygame.KEYDOWN, pygame.KEYUP]
                        and self.menu_state == MenuStates.Game
                    ):
                        pass
                        self.handle_key_event(event)

                    self.manager.process_events(event)

                    # drawing:
                self.main_window.fill(
                    colors[self.visualization_config["GameWindow"]["background_color"]]
                )
                self.manager.draw_ui(self.main_window)

                match self.menu_state:
                    case MenuStates.Start:
                        pass

                    case MenuStates.Game:
                        state = self.request_state()

                        self.handle_keys()

                        if state["ended"]:
                            self.finished_button_press()
                            self.disconnect_websockets()
                            self.manage_button_visibility()
                        else:
                            self.draw(state)

                            game_screen_rect = self.game_screen.get_rect()

                            game_screen_rect.center = [
                                self.window_width // 2,
                                self.window_height // 2,
                            ]

                            self.main_window.blit(self.game_screen, game_screen_rect)

                    case MenuStates.End:
                        self.update_conclusion_label(state)

                self.manager.update(time_delta)
                pygame.display.flip()

            except (KeyboardInterrupt, SystemExit):
                self.running = False

        self.disconnect_websockets()
        pygame.quit()
        sys.exit()


def main(
    url: str,
    port: int,
    manager_ids: list[str],
    enable_websocket_logging: bool = False,
):
    """Create the 2D PyGame GUI and run its main loop.

    Args:
        url: Server URL/host, passed through to ``PyGameGUI``.
        port: Server port, passed through to ``PyGameGUI``.
        manager_ids: Manager ids, passed through to ``PyGameGUI``.
        enable_websocket_logging: If False (default), the ``websockets.client``
            logger is raised to ERROR so that only errors from the websocket
            client are logged.
    """
    # The script entry point passes this flag as a 4th argument. Accept it
    # (with a default, so 3-argument callers keep working) instead of failing
    # with a TypeError.
    if not enable_websocket_logging:
        # The websockets client can be very verbose at DEBUG level, and the
        # GUI talks to the server every frame.
        logging.getLogger("websockets.client").setLevel(logging.ERROR)

    gui = PyGameGUI(
        url=url,
        port=port,
        manager_ids=manager_ids,
    )
    gui.start_pygame()


if __name__ == "__main__":
    # Command-line entry point: parse server location, logging and manager
    # options, then hand them to main().
    arg_parser = argparse.ArgumentParser(
        prog="Overcooked Simulator 2D PyGame Visualization",
        description="PyGameGUI: a PyGame 2D Visualization window.",
        epilog="For further information, see https://scs.pages.ub.uni-bielefeld.de/cocosy/overcooked-simulator/overcooked_simulator.html",
    )

    for register_arguments in (
        url_and_port_arguments,
        disable_websocket_logging_arguments,
        add_list_of_manager_ids_arguments,
    ):
        register_arguments(arg_parser)

    cli_args = arg_parser.parse_args()
    main(
        cli_args.url,
        cli_args.port,
        cli_args.manager_ids,
        cli_args.enable_websocket_logging,
    )