Skip to content
Snippets Groups Projects
overcooked_gui.py 59.01 KiB
import argparse
import dataclasses
import json
import logging
import os
import random
import signal
import subprocess
import sys
import uuid
from enum import Enum
from subprocess import Popen

import numpy as np
import pygame
import pygame_gui
import requests
import yaml
from pygame import mixer
from websockets.sync.client import connect

from overcooked_simulator import ROOT_DIR
from overcooked_simulator.game_server import CreateEnvironmentConfig
from overcooked_simulator.gui_2d_vis.drawing import Visualizer
from overcooked_simulator.gui_2d_vis.game_colors import colors
from overcooked_simulator.overcooked_environment import (
    Action,
    ActionType,
    InterActionData,
)
from overcooked_simulator.utils import (
    custom_asdict_factory,
    url_and_port_arguments,
    disable_websocket_logging_arguments,
    add_list_of_manager_ids_arguments,
)


class MenuStates(Enum):
    """Screens of the GUI's menu flow, stored in ``PyGameGUI.menu_state``.

    ``PyGameGUI.__init__`` starts in ``Start``; ``update_screen_elements`` shows a
    different set of widgets for each member. Each value equals its member name.
    """

    Start = "Start"
    ControllerTutorial = "ControllerTutorial"
    PreGame = "PreGame"
    Game = "Game"
    PostGame = "PostGame"
    End = "End"


# Module-level logger for this GUI module.
log = logging.getLogger(__name__)


class PlayerKeySet:
    """Set of keyboard keys (plus one controller) for controlling players.

    The four movement keys map, in order, to the movement vectors
    ``[-1, 0]`` (left), ``[1, 0]`` (right), ``[0, -1]`` (up) and ``[0, 1]``
    (down). A 5th key interacts with counters, a 6th picks things up or puts
    them down, and a 7th switches to the next controllable player.
    """

    def __init__(
        self,
        move_keys: list[int],
        interact_key: int,
        pickup_key: int,
        switch_key: int,
        players: list[int],
        joystick: int,
    ):
        """Creates a player key set describing which keys control the player.

        Args:
            move_keys: pygame key codes for moving left, right, up and down, in
                that order (they are zipped with ``move_vectors``).
            interact_key: The key to interact with objects in the game.
            pickup_key: The key to pick items up or put them down.
            switch_key: The key for switching through controllable players.
            players: The player indices this key set can control. The first one
                is selected initially; if the list is empty, player ``0`` is used.
            joystick: Number of the controller belonging to this key set (its
                availability is checked later by the GUI).
        """
        # Unit steps in screen coordinates: x grows to the right, y downwards.
        self.move_vectors: list[list[int]] = [[-1, 0], [1, 0], [0, -1], [0, 1]]
        self.key_to_movement: dict[int, list[int]] = {
            key: vec for (key, vec) in zip(move_keys, self.move_vectors)
        }
        self.move_keys: list[int] = move_keys
        self.interact_key: int = interact_key
        self.pickup_key: int = pickup_key
        self.switch_key: int = switch_key
        self.controlled_players: list[int] = players
        self.current_player: int = players[0] if players else 0
        self.current_idx = 0
        # Other key sets sharing the same players; next_player avoids the indices
        # they currently have selected.
        self.other_keyset: list[PlayerKeySet] = []
        self.joystick = joystick

    def set_controlled_players(self, controlled_players: list[int]) -> None:
        """Replace the controllable players and select the first of them.

        An empty list falls back to player ``0`` (as ``__init__`` does) instead
        of raising ``IndexError``.
        """
        self.controlled_players = controlled_players
        self.current_player = controlled_players[0] if controlled_players else 0
        self.current_idx = 0

    def next_player(self) -> None:
        """Advance to the next controllable player not selected by another key set.

        Indices currently selected by the key sets in ``other_keyset`` are
        skipped, wrapping around the list. If every index is taken (e.g. two key
        sets sharing a single player) the current selection is kept rather than
        searching forever. Does nothing when there are no controllable players.
        """
        count = len(self.controlled_players)
        if count == 0:
            return
        taken = {other.current_idx for other in self.other_keyset}
        # At most one full lap: step == count lands back on the current index.
        for step in range(1, count + 1):
            candidate = (self.current_idx + step) % count
            if candidate not in taken:
                self.current_idx = candidate
                break
        self.current_player = self.controlled_players[self.current_idx]


class PyGameGUI:
    """Visualisation of the overcooked environment and reading keyboard inputs using pygame."""

    def __init__(
        self,
        url: str,
        port: int,
        manager_ids: list[str],
        CONNECT_WITH_STUDY_SERVER: bool,
        USE_AAAMBOS_AGENT: bool,
    ):
        """Initialise pygame, the server endpoints and all window/menu state.

        Args:
            url: Host of the game server, used for both the websocket URL and
                the HTTP request URL.
            port: Port of the game server.
            manager_ids: Candidate manager ids; one of them is picked at random.
            CONNECT_WITH_STUDY_SERVER: Stored on the instance. When true, the
                player-selection panel is hidden on the start screen and the
                switch-player key/button is ignored.
            USE_AAAMBOS_AGENT: Stored on the instance.
        """
        self.CONNECT_WITH_STUDY_SERVER = CONNECT_WITH_STUDY_SERVER
        self.USE_AAAMBOS_AGENT = USE_AAAMBOS_AGENT

        pygame.init()
        icon = pygame.image.load(ROOT_DIR / "gui_2d_vis" / "images" / "fish3.png")
        pygame.display.set_icon(icon)

        # Random identifier for this GUI session / participant.
        self.participant_id = uuid.uuid4().hex

        self.game_screen: pygame.Surface = None
        self.FPS = 60
        self.running = True

        self.key_sets: list[PlayerKeySet] = []

        # Game-server endpoints.
        self.websocket_url = f"ws://{url}:{port}/ws/player/"
        self.websockets = {}
        self.request_url = f"http://{url}:{port}"
        self.manager_id = random.choice(manager_ids)

        config_path = ROOT_DIR / "gui_2d_vis" / "visualization.yaml"
        with open(config_path, "r") as config_file:
            self.visualization_config = yaml.safe_load(config_file)

        # Window/layout settings from the "GameWindow" section of the config.
        window_cfg = self.visualization_config["GameWindow"]
        self.screen_margin = window_cfg["screen_margin"]
        self.min_width = window_cfg["min_width"]
        self.min_height = window_cfg["min_height"]
        self.buttons_width = window_cfg["buttons_width"]
        self.buttons_height = window_cfg["buttons_height"]
        self.order_bar_height = window_cfg["order_bar_height"]

        # Fullscreen uses the size of the first desktop; windowed mode starts at
        # the configured minimum size.
        desktop_width, desktop_height = pygame.display.get_desktop_sizes()[0]
        self.window_width_fullscreen = desktop_width
        self.window_height_fullscreen = desktop_height
        self.window_width_windowed = self.min_width
        self.window_height_windowed = self.min_height

        # Placeholder 1x1 kitchen; presumably replaced once a level is loaded.
        self.kitchen_width = 1
        self.kitchen_height = 1
        self.kitchen_aspect_ratio = 1
        # NOTE(review): the icon above comes from "gui_2d_vis/images", this path
        # points at "pygame_gui/images" — confirm which directory is intended.
        self.images_path = ROOT_DIR / "pygame_gui" / "images"
        self.vis = Visualizer(self.visualization_config)

        self.fullscreen = False

        self.menu_state = MenuStates.Start
        self.manager: pygame_gui.UIManager

        self.sub_processes = []

        layouts_dir = ROOT_DIR / "game_content" / "layouts"
        self.layout_file_paths = sorted(layouts_dir.rglob("*.layout"))
        self.current_layout_idx = 0

        self.last_level = False

        self.beeped_once = False

    def setup_player_keys(self, number_players, number_key_sets=1, disjunct=False):
        """Create the key sets used by the local human players.

        Key set 1: A/D/W/S move left/right/up/down, F interacts, E picks up /
        puts down, SPACE switches player, controller 0.
        Key set 2: arrow keys move, I interacts, O picks up / puts down,
        P switches player, controller 1.

        Args:
            number_players: Number of players (indices ``0 .. number_players-1``)
                the key sets may control.
            number_key_sets: How many key sets to return (at most two exist).
                A falsy value returns an empty list.
            disjunct: With two key sets and at least two players, give key set 1
                the even player indices and key set 2 the odd ones. Otherwise two
                key sets share all players and start on different ones. With a
                single key set, or fewer than two players, every key set simply
                controls all players.

        Returns:
            The first ``number_key_sets`` key sets.
        """
        if not number_key_sets:
            return []

        players = list(range(number_players))
        # Movement keys are listed in the order of PlayerKeySet.move_vectors:
        # left, right, up, down.
        key_set1 = PlayerKeySet(
            move_keys=[pygame.K_a, pygame.K_d, pygame.K_w, pygame.K_s],
            interact_key=pygame.K_f,
            pickup_key=pygame.K_e,
            switch_key=pygame.K_SPACE,
            players=players,
            joystick=0,
        )
        key_set2 = PlayerKeySet(
            move_keys=[pygame.K_LEFT, pygame.K_RIGHT, pygame.K_UP, pygame.K_DOWN],
            interact_key=pygame.K_i,
            pickup_key=pygame.K_o,
            switch_key=pygame.K_p,
            players=players,
            joystick=1,
        )
        key_sets = [key_set1, key_set2]

        # Splitting/sharing only matters when two key sets compete for players.
        # With fewer than two players there is nothing to split, and an empty
        # split or a shared single player would break next_player.
        if number_key_sets > 1 and number_players > 1:
            if disjunct:
                key_set1.set_controlled_players(players[::2])
                key_set2.set_controlled_players(players[1::2])
            else:
                key_set1.set_controlled_players(players)
                key_set2.set_controlled_players(players)
                key_set1.other_keyset = [key_set2]
                key_set2.other_keyset = [key_set1]
                # Move key set 2 off the player key set 1 starts on.
                key_set2.next_player()
        return key_sets[:number_key_sets]

    def handle_keys(self):
        """Send a movement action for every key set whose movement keys are held.

        Called every frame. The held keys' movement vectors are summed and,
        if the sum is non-zero, scaled to unit length. The action lasts
        ``self.time_delta``.
        """
        pressed = pygame.key.get_pressed()
        for key_set in self.key_sets:
            held = [pressed[key] for key in key_set.move_keys]
            if not any(held):
                continue

            direction = np.zeros(2)
            for idx, is_down in enumerate(held):
                if is_down:
                    direction += key_set.move_vectors[idx]

            # Opposite keys cancel out; only normalise a non-zero vector.
            length = np.linalg.norm(direction)
            if length != 0:
                direction = direction / length

            self.send_action(
                Action(
                    str(key_set.current_player),
                    ActionType.MOVEMENT,
                    direction,
                    duration=self.time_delta,
                )
            )

    def handle_joy_stick_input(self, joysticks):
        """Send a movement action every frame from each key set's left stick.

        Axis 0 is the left stick's horizontal axis (-1 left, 1 right) and axis 1
        its vertical axis (-1 up, 1 down); see the end of
        https://www.pygame.org/docs/ref/joystick.html. An axis whose magnitude is
        at most 0.2 counts as 0, absorbing jitter and drift. The vector is not
        normalised, so partial deflection gives a shorter vector.

        Args:
            joysticks: Container checked with ``key_set.joystick in joysticks``
                and indexed with ``joysticks[key_set.joystick]``; presumably a
                dict of pygame ``Joystick`` objects — confirm with the caller.
        """
        dead_zone = 0.2
        for key_set in self.key_sets:
            if key_set.joystick not in joysticks:
                continue  # no controller connected for this key set

            stick = joysticks[key_set.joystick]
            horizontal = stick.get_axis(0)
            vertical = stick.get_axis(1)
            use_horizontal = abs(horizontal) > dead_zone
            use_vertical = abs(vertical) > dead_zone
            if not (use_horizontal or use_vertical):
                continue

            move_vec = np.zeros(2)
            if use_horizontal:
                move_vec[0] += horizontal
            if use_vertical:
                move_vec[1] += vertical

            self.send_action(
                Action(
                    str(key_set.current_player),
                    ActionType.MOVEMENT,
                    move_vec,
                    duration=self.time_delta,
                )
            )

    def handle_key_event(self, event):
        """Handle key events for the pickup, interaction and switch keys.

        Pickup is a single action sent on key-down. Interaction sends START on
        key-down and STOP on key-up, because the player has to be able to hold
        the key down. The switch key moves the key set to its next controllable
        player, unless the GUI is connected to the study server.

        Every action is sent for the key set's *current* player; pickup and STOP
        previously used ``self.player_id`` instead, which did not match the
        player that received START.

        Args:
            event: Pygame KEYDOWN/KEYUP event to extract the key action from.
        """
        for key_set in self.key_sets:
            current_player_name = str(key_set.current_player)
            if event.key == key_set.pickup_key and event.type == pygame.KEYDOWN:
                action = Action(current_player_name, ActionType.PUT, "pickup")
                self.send_action(action)

            if event.key == key_set.interact_key:
                if event.type == pygame.KEYDOWN:
                    action = Action(
                        current_player_name, ActionType.INTERACT, InterActionData.START
                    )
                    self.send_action(action)
                elif event.type == pygame.KEYUP:
                    action = Action(
                        current_player_name, ActionType.INTERACT, InterActionData.STOP
                    )
                    self.send_action(action)
            if event.key == key_set.switch_key and not self.CONNECT_WITH_STUDY_SERVER:
                if event.type == pygame.KEYDOWN:
                    key_set.next_player()

    def handle_joy_stick_event(self, event, joysticks):
        """Handle controller button events for pickup, interaction and switching.

        A key set only reacts to events coming from its own controller
        (``joysticks[key_set.joystick]``), matched via the event's
        ``instance_id``. The pressed/released button is taken from the event
        itself, not by polling which buttons are currently held.

        Button mapping:

        * A (button 0) down: pick up / put down (single action).
        * X (button 2) down: start interacting; X up: stop interacting, so the
          interaction lasts while the button is held.
        * Y (button 3) down: switch to the next controllable player (ignored
          when connected to the study server).

        Args:
            event: Pygame event; anything other than JOYBUTTONDOWN/JOYBUTTONUP
                is ignored.
            joysticks: Container checked with ``key_set.joystick in joysticks``
                and indexed with ``joysticks[key_set.joystick]``; presumably a
                dict of pygame ``Joystick`` objects — confirm with the caller.
        """
        if event.type not in (pygame.JOYBUTTONDOWN, pygame.JOYBUTTONUP):
            return

        for key_set in self.key_sets:
            # Skip key sets whose controller is not connected.
            if key_set.joystick not in joysticks:
                continue
            # Polling held buttons would re-fire actions for buttons held on
            # this controller whenever another controller is pressed, and an
            # X release anywhere would stop everyone's interaction.
            controller = joysticks[key_set.joystick]
            if getattr(event, "instance_id", None) != controller.get_instance_id():
                continue

            current_player_name = str(key_set.current_player)
            if event.type == pygame.JOYBUTTONDOWN:
                if event.button == 0:
                    action = Action(current_player_name, ActionType.PUT, "pickup")
                    self.send_action(action)
                elif event.button == 2:
                    action = Action(
                        current_player_name, ActionType.INTERACT, InterActionData.START
                    )
                    self.send_action(action)
                elif event.button == 3 and not self.CONNECT_WITH_STUDY_SERVER:
                    key_set.next_player()
            elif event.button == 2:
                # Releasing X ends the interaction started on button-down.
                action = Action(
                    current_player_name, ActionType.INTERACT, InterActionData.STOP
                )
                self.send_action(action)

    def set_window_size(self):
        """(Re)create the main window for the current fullscreen/windowed mode.

        Updates ``self.window_width`` / ``self.window_height`` from the matching
        fullscreen or windowed size and stores the new display surface in
        ``self.main_window``.
        """
        if self.fullscreen:
            display_flags = pygame.FULLSCREEN
            size = (self.window_width_fullscreen, self.window_height_fullscreen)
        else:
            display_flags = 0
            size = (self.window_width_windowed, self.window_height_windowed)

        self.window_width, self.window_height = size
        self.main_window = pygame.display.set_mode(
            (self.window_width, self.window_height),
            flags=display_flags,
        )

    def reset_window_size(self):
        """Zero the stored game surface size, then rebuild the main window."""
        self.game_width, self.game_height = 0, 0
        self.set_window_size()

    def set_game_size(self, max_width=None, max_height=None):
        """Fit the kitchen into the available area and create the game surface.

        The kitchen keeps its aspect ratio. A landscape kitchen fills the width
        first and a portrait/square one fills the height first, each shrinking
        if the other dimension would overflow. ``grid_size`` is the integer
        pixel size of one kitchen cell (at least 1). The game size is then
        trimmed to an exact multiple of it.

        Args:
            max_width: Available width; defaults to the window width minus the
                screen margin on both sides.
            max_height: Available height; defaults to the window height minus
                the screen margin on both sides.
        """
        if max_width is None:
            max_width = self.window_width - (2 * self.screen_margin)
        if max_height is None:
            max_height = self.window_height - (2 * self.screen_margin)

        ratio = self.kitchen_height / self.kitchen_width
        self.kitchen_aspect_ratio = ratio
        if self.kitchen_width > self.kitchen_height:
            # Landscape: use the full width, shrink if it gets too tall.
            width = max_width
            height = width * ratio
            if height > max_height:
                height = max_height
                width = height / ratio
        else:
            # Portrait or square: use the full height, shrink if too wide.
            height = max_height
            width = height / ratio
            if width > max_width:
                width = max_width
                height = width * ratio

        cell = int(width / self.kitchen_width)

        # NOTE(review): the trim step below always makes the final size
        # kitchen_* * grid_size, so these 100 px floors have no lasting effect.
        width = max(width, 100)
        height = max(height, 100)
        self.grid_size = max(cell, 1)

        # Trim to an exact multiple of the cell size.
        width -= width - (self.kitchen_width * self.grid_size)
        height -= height - (self.kitchen_height * self.grid_size)
        self.game_width = width
        self.game_height = height

        self.game_screen = pygame.Surface((self.game_width, self.game_height))

    def init_ui_elements(self) -> None:
        """Create all pygame_gui widgets for every menu screen.

        A new ``UIManager`` is created with the current ``self.window_width`` /
        ``self.window_height`` (assigned by ``set_window_size``) and the theme
        from ``gui_2d_vis/gui_theme.json``. Widgets for the Start, tutorial,
        PreGame, Game, PostGame and End screens are created. They are then
        grouped into per-screen lists, which ``show_screen_elements`` uses to
        toggle visibility.
        """
        self.manager = pygame_gui.UIManager((self.window_width, self.window_height))
        self.manager.get_theme().load_theme(ROOT_DIR / "gui_2d_vis" / "gui_theme.json")

        ########################################################################
        # Start screen
        ########################################################################
        self.start_button = pygame_gui.elements.UIButton(
            relative_rect=pygame.Rect(
                (0, 0), (self.buttons_width, self.buttons_height)
            ),
            text="Start Game",
            manager=self.manager,
            anchors={"center": "center"},
            object_id="#start_button",
        )

        # "Press A" hint image 60 px below the centre, scaled to the button width
        # while keeping its aspect ratio.
        img = pygame.image.load(
            ROOT_DIR / "gui_2d_vis" / "press_a.drawio.png"
        ).convert_alpha()

        image_rect = img.get_rect()
        image_rect.centery += 60
        self.press_a_image = pygame_gui.elements.UIImage(
            image_rect,
            img,
            manager=self.manager,
            anchors={"centerx": "centerx", "centery": "centery"},
        )
        img_width = self.buttons_width
        img_height = img_width * (image_rect.height / image_rect.width)
        new_dims = (img_width, img_height)
        self.press_a_image.set_dimensions(new_dims)

        # Quit button pinned to the top-right corner.
        rect = pygame.Rect((0, 0), (self.buttons_width, self.buttons_height))
        rect.topright = (0, 0)
        self.quit_button = pygame_gui.elements.UIButton(
            relative_rect=rect,
            text="Quit Game",
            manager=self.manager,
            object_id="#quit_button",
            anchors={"right": "right", "top": "top"},
        )

        # Player-selection panel along the bottom edge (90% of the window width,
        # a third of its height): key-set/split toggles plus human and bot counters.
        player_selection_rect = pygame.Rect(
            (0, 0),
            (
                self.window_width * 0.9,
                (self.window_height // 3),
            ),
        )
        player_selection_rect.bottom = -10
        self.player_selection_container = pygame_gui.elements.UIPanel(
            player_selection_rect,
            manager=self.manager,
            object_id="#players",
            anchors={"bottom": "bottom", "centerx": "centerx"},
        )

        # Toggle buttons; their real text is set elsewhere ("not set" placeholder).
        multiple_keysets_button_rect = pygame.Rect((0, 0), (190, 50))
        self.multiple_keysets_button = pygame_gui.elements.UIButton(
            relative_rect=multiple_keysets_button_rect,
            manager=self.manager,
            container=self.player_selection_container,
            text="not set",
            anchors={"left": "left", "centery": "centery"},
            object_id="#multiple_keysets_button",
        )

        split_players_button_rect = pygame.Rect((0, 0), (190, 50))
        self.split_players_button = pygame_gui.elements.UIButton(
            relative_rect=split_players_button_rect,
            manager=self.manager,
            container=self.player_selection_container,
            text="not set",
            anchors={"centerx": "centerx", "centery": "centery"},
            object_id="#split_players_button",
        )

        # Human-player counter row at the top of the selection panel.
        players_container_rect = pygame.Rect(
            (0, 0),
            (
                self.window_width * 0.6,
                self.player_selection_container.get_abs_rect().height // 3,
            ),
        )
        self.player_number_container = pygame_gui.elements.UIPanel(
            relative_rect=players_container_rect,
            manager=self.manager,
            object_id="#players_players",
            container=self.player_selection_container,
            anchors={"top": "top", "centerx": "centerx"},
        )

        # Bot counter row at the bottom of the selection panel.
        bot_container_rect = pygame.Rect(
            (0, 0),
            (
                self.window_width * 0.6,
                self.player_selection_container.get_abs_rect().height // 3,
            ),
        )
        bot_container_rect.bottom = 0
        self.bot_number_container = pygame_gui.elements.UIPanel(
            relative_rect=bot_container_rect,
            manager=self.manager,
            object_id="#players_bots",
            container=self.player_selection_container,
            anchors={"bottom": "bottom", "centerx": "centerx"},
        )

        number_players_rect = pygame.Rect((0, 0), (200, 200))
        self.added_players_label = pygame_gui.elements.UILabel(
            number_players_rect,
            manager=self.manager,
            object_id="#number_players_label",
            container=self.player_number_container,
            text=f"Humans to be added: -",
            anchors={"center": "center"},
        )

        number_bots_rect = pygame.Rect((0, 0), (200, 200))
        self.added_bots_label = pygame_gui.elements.UILabel(
            number_bots_rect,
            manager=self.manager,
            object_id="#number_bots_label",
            container=self.bot_number_container,
            text=f"Bots to be added: -",
            anchors={"center": "center"},
        )

        # Square "+" (right) and "-" (left) buttons for both counter rows.
        size = 50
        add_player_button_rect = pygame.Rect((0, 0), (size, size))
        add_player_button_rect.right = 0
        self.add_human_player_button = pygame_gui.elements.UIButton(
            relative_rect=add_player_button_rect,
            text="+",
            manager=self.manager,
            object_id="#quantity_button",
            container=self.player_number_container,
            anchors={"right": "right", "centery": "centery"},
        )

        remove_player_button_rect = pygame.Rect((0, 0), (size, size))
        remove_player_button_rect.left = 0
        self.remove_human_button = pygame_gui.elements.UIButton(
            relative_rect=remove_player_button_rect,
            text="-",
            manager=self.manager,
            object_id="#quantity_button",
            container=self.player_number_container,
            anchors={"left": "left", "centery": "centery"},
        )

        add_bot_button_rect = pygame.Rect((0, 0), (size, size))
        add_bot_button_rect.right = 0
        self.add_bot_button = pygame_gui.elements.UIButton(
            relative_rect=add_bot_button_rect,
            text="+",
            manager=self.manager,
            object_id="#quantity_button",
            container=self.bot_number_container,
            anchors={"right": "right", "centery": "centery"},
        )

        remove_bot_button_rect = pygame.Rect((0, 0), (size, size))
        remove_bot_button_rect.left = 0
        self.remove_bot_button = pygame_gui.elements.UIButton(
            relative_rect=remove_bot_button_rect,
            text="-",
            manager=self.manager,
            object_id="#quantity_button",
            container=self.bot_number_container,
            anchors={"left": "left", "centery": "centery"},
        )

        ########################################################################
        # Tutorial screen
        ########################################################################

        # Controller tutorial image at the top-left, scaled to 80% of the window
        # width while keeping its aspect ratio.
        image = pygame.image.load(
            ROOT_DIR / "gui_2d_vis" / "tutorial_files" / "tutorial.drawio.png"
        ).convert_alpha()
        image_rect = image.get_rect()
        image_rect.topleft = (20, self.buttons_height)
        self.tutorial_image = pygame_gui.elements.UIImage(
            image_rect,
            image,
            manager=self.manager,
            anchors={"top": "top", "left": "left"},
        )
        img_width = self.window_width * 0.8
        # img_width = img_height * (image_rect.width / image_rect.height)
        img_height = img_width * (image_rect.height / image_rect.width)
        new_dims = (img_width, img_height)
        self.tutorial_image.set_dimensions(new_dims)

        # Continue button at the bottom centre (also used on the PreGame screen).
        button_rect = pygame.Rect((0, 0), (220, 80))
        button_rect.bottom = -20
        self.continue_button = pygame_gui.elements.UIButton(
            relative_rect=button_rect,
            text="Continue",
            manager=self.manager,
            anchors={"centerx": "centerx", "bottom": "bottom"},
        )

        # Fullscreen toggle placed directly left of the quit button.
        fullscreen_button_rect = pygame.Rect(
            (0, 0), (self.buttons_width * 0.7, self.buttons_height)
        )
        fullscreen_button_rect.topright = (-self.buttons_width, 0)
        self.fullscreen_button = pygame_gui.elements.UIButton(
            relative_rect=fullscreen_button_rect,
            text="Fullscreen",
            manager=self.manager,
            object_id="#fullscreen_button",
            anchors={"right": "right", "top": "top"},
        )

        ########################################################################
        # PreGame screen
        ########################################################################

        # Recipe overview scaled to 70% of the window height, aspect ratio kept.
        image = pygame.image.load(
            ROOT_DIR / "gui_2d_vis" / "tutorial_files" / "recipe_mock.png"
        ).convert_alpha()
        image_rect = image.get_rect()
        image_rect.top = 50
        self.recipe_image = pygame_gui.elements.UIImage(
            image_rect,
            image,
            manager=self.manager,
            anchors={"centerx": "centerx", "top": "top"},
        )
        img_height = self.window_height * 0.7
        img_width = img_height * (image_rect.width / image_rect.height)
        new_dims = (img_width, img_height)
        self.recipe_image.set_dimensions(new_dims)

        # NOTE(review): raises IndexError if no *.layout files were found.
        self.level_name = pygame_gui.elements.UILabel(
            text=f"Next level: {self.layout_file_paths[self.current_layout_idx].stem}",
            relative_rect=pygame.Rect(
                (0, 0),
                (self.window_width * 0.7, self.window_height * 0.2),
            ),
            manager=self.manager,
            object_id="#score_label",
            anchors={"centerx": "centerx", "top": "top"},
        )

        ########################################################################
        # Game screen
        ########################################################################

        # Bottom-right corner, positioned absolutely (no anchors).
        self.finished_button = pygame_gui.elements.UIButton(
            relative_rect=pygame.Rect(
                (
                    (self.window_width - self.buttons_width),
                    (self.window_height - self.buttons_height),
                ),
                (self.buttons_width, self.buttons_height),
            ),
            text="Finish round",
            manager=self.manager,
        )

        self.orders_label = pygame_gui.elements.UILabel(
            text="Orders:",
            relative_rect=pygame.Rect(0, 0, self.screen_margin, self.screen_margin),
            manager=self.manager,
            object_id="#orders_label",
        )

        # Score at the bottom-left, timer at the bottom centre.
        rect = pygame.Rect(
            (0, 0),
            (self.window_width * 0.2, self.buttons_height),
        )
        rect.bottomleft = (0, 0)
        self.score_label = pygame_gui.elements.UILabel(
            text=f"Score not set",
            relative_rect=rect,
            manager=self.manager,
            object_id="#score_label",
            anchors={"bottom": "bottom", "left": "left"},
        )

        rect = pygame.Rect(
            (0, 0),
            (self.window_width * 0.4, self.buttons_height),
        )
        rect.bottom = 0
        self.timer_label = pygame_gui.elements.UILabel(
            text="GAMETIME not set",
            relative_rect=rect,
            manager=self.manager,
            object_id="#timer_label",
            anchors={"bottom": "bottom", "centerx": "centerx"},
        )

        rect = pygame.Rect(
            (0, 0),
            (self.window_width, self.screen_margin),
        )
        rect.right = 20
        self.wait_players_label = pygame_gui.elements.UILabel(
            text="WAITING FOR OTHER PLAYERS",
            relative_rect=rect,
            manager=self.manager,
            object_id="#wait_players_label",
            anchors={"centery": "centery", "right": "right"},
        )

        ########################################################################
        # PostGame screen
        ########################################################################

        conclusion_rect = pygame.Rect(0, 0, self.window_width, self.window_height * 0.4)
        conclusion_rect.top = 50
        self.conclusion_label = pygame_gui.elements.UILabel(
            text="not set",
            relative_rect=conclusion_rect,
            manager=self.manager,
            object_id="#score_label",
            anchors={"centerx": "centerx", "top": "top"},
        )

        # NOTE(review): only next_game_button is in postgame_screen_elements;
        # retry_button and finish_study_button are kept in self.rest and must be
        # shown explicitly (update_screen_elements toggles finish_study_button).
        next_game_button_rect = pygame.Rect((0, 0), (190, 50))
        next_game_button_rect.center = (self.buttons_width // 2, 200)
        self.next_game_button = pygame_gui.elements.UIButton(
            relative_rect=next_game_button_rect,
            manager=self.manager,
            text="Next game",
            anchors={"centerx": "centerx", "centery": "centery"},
            object_id="#split_players_button",
        )

        retry_button_rect = pygame.Rect((0, 0), (190, 50))
        retry_button_rect.center = (self.buttons_width // 2 - 200, 200)
        self.retry_button = pygame_gui.elements.UIButton(
            relative_rect=retry_button_rect,
            manager=self.manager,
            text="Retry last game",
            anchors={"center": "center"},
            object_id="#split_players_button",
        )

        finish_study_rect = pygame.Rect((0, 0), (190, 50))
        finish_study_rect.center = (self.buttons_width // 2 + 200, 200)
        self.finish_study_button = pygame_gui.elements.UIButton(
            relative_rect=finish_study_rect,
            manager=self.manager,
            text="Finish study",
            anchors={"center": "center"},
            object_id="#split_players_button",
        )

        ########################################################################
        # End screen
        ########################################################################

        conclusion_rect = pygame.Rect(
            0, 0, self.window_width * 0.6, self.window_height * 0.4
        )
        self.thank_you_label = pygame_gui.elements.UILabel(
            text="Thank you!",
            relative_rect=conclusion_rect,
            manager=self.manager,
            object_id="#score_label",
            anchors={"center": "center"},
        )

        ########################################################################

        # Widgets belonging to each screen. show_screen_elements hides every list
        # below (including self.rest) and then shows the requested one.
        self.start_screen_elements = [
            self.start_button,
            self.quit_button,
            self.fullscreen_button,
            self.player_selection_container,
            self.press_a_image,
        ]

        self.tutorial_screen_elements = [
            self.tutorial_image,
            self.continue_button,
        ]

        self.pregame_screen_elements = [
            self.recipe_image,
            self.level_name,
            self.press_a_image,
            self.continue_button,
        ]

        self.game_screen_elements = [
            self.orders_label,
            self.score_label,
            self.timer_label,
            self.wait_players_label,
        ]

        self.postgame_screen_elements = [
            self.conclusion_label,
            self.next_game_button,
        ]

        self.end_screen_elements = [
            self.fullscreen_button,
            self.quit_button,
            self.thank_you_label,
        ]

        # Widgets not owned by a single screen list; listed so they get hidden too.
        self.rest = [
            self.fullscreen_button,
            self.quit_button,
            self.retry_button,
            self.finish_study_button,
            self.finished_button,
        ]

    def show_screen_elements(self, elements: list):
        """Hide every known UI element, then reveal only the given ones.

        Args:
            elements: The UI elements that should be visible afterwards.
        """
        element_groups = (
            self.start_screen_elements,
            self.tutorial_screen_elements,
            self.pregame_screen_elements,
            self.game_screen_elements,
            self.postgame_screen_elements,
            self.end_screen_elements,
            self.rest,
        )
        for group in element_groups:
            for ui_element in group:
                ui_element.hide()
        for ui_element in elements:
            ui_element.show()

    def update_screen_elements(self):
        """Show the UI elements of the current menu state and run its per-state setup.

        Entering ``ControllerTutorial`` also sets up the tutorial game and places it in the
        lower-right corner of the window. ``PostGame`` shows either the "next game" or the
        "finish study" button depending on ``self.last_level``.
        """
        current = self.menu_state
        if current == MenuStates.Start:
            self.show_screen_elements(self.start_screen_elements)
            # Player selection is not offered when the study server decides the setup.
            if self.CONNECT_WITH_STUDY_SERVER:
                self.player_selection_container.hide()
            self.update_selection_elements()
        elif current == MenuStates.ControllerTutorial:
            self.show_screen_elements(self.tutorial_screen_elements)
            self.setup_game(tutorial=True)
            self.set_game_size(
                max_height=self.window_height * 0.3,
                max_width=self.window_width * 0.3,
            )
            self.set_window_size()
            corner_margin = 20
            self.game_center = (
                self.window_width - self.game_width / 2 - corner_margin,
                self.window_height - self.game_height / 2 - corner_margin,
            )
        elif current == MenuStates.PreGame:
            self.show_screen_elements(self.pregame_screen_elements)
        elif current == MenuStates.Game:
            self.show_screen_elements(self.game_screen_elements)
        elif current == MenuStates.PostGame:
            self.show_screen_elements(self.postgame_screen_elements)
            if not self.last_level:
                self.next_game_button.show()
                self.finish_study_button.hide()
            else:
                self.next_game_button.hide()
                self.finish_study_button.show()
        elif current == MenuStates.End:
            self.show_screen_elements(self.end_screen_elements)

    def draw_main_window(self):
        """Render one frame: background, state-specific content, the UI layer, then flip."""
        background_key = self.visualization_config["GameWindow"]["background_color"]
        self.main_window.fill(colors[background_key])

        current = self.menu_state
        if current == MenuStates.ControllerTutorial:
            self.draw_tutorial_screen_frame()
        elif current == MenuStates.Game:
            self.draw_game_screen_frame()
        elif current == MenuStates.PostGame:
            self.update_conclusion_label(self.last_state)
        # Every other state (Start, PreGame, End) only needs the UI manager below.

        self.manager.draw_ui(self.main_window)
        self.manager.update(self.time_delta)
        pygame.display.flip()

    def draw_tutorial_screen_frame(self):
        """Handle input for the tutorial game and blit its current state at ``self.game_center``."""
        self.handle_keys()
        self.handle_joy_stick_input(joysticks=self.joysticks)

        tutorial_state = self.request_state()
        controlled_players = [key_set.current_player for key_set in self.key_sets]
        self.vis.draw_gamescreen(
            self.game_screen, tutorial_state, self.grid_size, controlled_players
        )

        target_rect = self.game_screen.get_rect()
        target_rect.center = self.game_center
        self.main_window.blit(self.game_screen, target_rect)

    def draw_game_screen_frame(self):
        """Advance one frame of the running game.

        Fetches the state and forwards input. It rings the bell once after
        ``all_players_ready`` becomes true. When the server reports ``ended`` it switches to
        the post-game screen and disconnects; otherwise it renders the game centred in the
        window.
        """
        self.last_state = self.request_state()

        self.handle_keys()
        self.handle_joy_stick_input(joysticks=self.joysticks)

        if not self.beeped_once and self.last_state["all_players_ready"]:
            self.beeped_once = True
            self.play_bell_sound()

        if not self.last_state["ended"]:
            self.draw_game(self.last_state)

            centred_rect = self.game_screen.get_rect()
            centred_rect.center = [self.window_width // 2, self.window_height // 2]
            self.main_window.blit(self.game_screen, centred_rect)

            # Keep the "waiting for players" hint up until everyone is ready.
            if self.last_state["all_players_ready"]:
                self.wait_players_label.hide()
            else:
                self.wait_players_label.show()
            return

        # The server reports the game as finished: leave the game screen.
        self.menu_state = MenuStates.PostGame
        self.finished_button_press()
        self.disconnect_websockets()

        if self.CONNECT_WITH_STUDY_SERVER:
            self.send_level_done()

        self.update_screen_elements()
        self.beeped_once = False

    def draw_game(self, state):
        """Draw one frame of the running game onto the main window.

        Renders the kitchen onto ``self.game_screen``, draws the orders, and draws a border
        around the centred game area. It then refreshes the score and timer labels and
        prints the state's info messages, in reverse order, along the bottom margin.

        Args:
            state: The game state returned by the environment.
        """
        controlled_players = [key_set.current_player for key_set in self.key_sets]
        self.vis.draw_gamescreen(
            self.game_screen, state, self.grid_size, controlled_players
        )

        orders_width = self.window_width - self.buttons_width - (self.buttons_width * 0.7)
        self.vis.draw_orders(
            screen=self.main_window,
            state=state,
            grid_size=self.buttons_height,
            width=orders_width,
            height=self.game_height,
            screen_margin=self.screen_margin,
            config=self.visualization_config,
        )

        window_config = self.visualization_config["GameWindow"]
        border = window_config["game_border_size"]
        left = self.window_width // 2 - (self.game_width // 2) - border
        top = self.window_height // 2 - (self.game_height // 2) - border
        frame = pygame.Rect(
            left, top, self.game_width + 2 * border, self.game_height + 2 * border
        )
        pygame.draw.rect(
            self.main_window,
            colors[window_config["game_border_color"]],
            frame,
            width=border,
        )

        self.update_score_label(state)
        self.update_remaining_time(state["remaining_time"])

        messages = state["info_msg"]
        if not messages:
            return
        for row, msg in enumerate(reversed(messages)):
            # msg[1] is the severity: "Normal" -> black, "Warning" -> red, else green.
            severity = msg[1]
            if severity == "Normal":
                text_color = (0, 0, 0)
            elif severity == "Warning":
                text_color = (255, 0, 0)
            else:
                text_color = (0, 255, 0)
            rendered = self.comic_sans.render(msg[0], antialias=True, color=text_color)
            position = (
                self.window_width / 4,
                self.window_height - self.screen_margin + 5 + (20 * row),
            )
            self.main_window.blit(rendered, position)

    def update_score_label(self, state):
        """Show ``state["score"]`` in the in-game score label."""
        self.score_label.set_text("Score {}".format(state["score"]))

    def update_conclusion_label(self, state):
        """Show the final score from ``state`` on the post-game screen."""
        final_score = state["score"]
        self.conclusion_label.set_text(
            "Your final score is {}. Hurray!".format(final_score)
        )

    def update_remaining_time(self, remaining_time: float):
        """Render the remaining game time into the timer label.

        The time is shown as ``M:SS``. Once an hour or more remains it is shown as
        ``H:MM:SS``, so the minutes no longer wrap around (previously 3700 s showed "1:40").
        Negative values are clamped to ``0:00`` instead of showing e.g. "59:55".

        Args:
            remaining_time: Remaining time in seconds; fractions are truncated.
        """
        total_seconds = max(0, int(remaining_time))
        hours, rem = divmod(total_seconds, 3600)
        minutes, seconds = divmod(rem, 60)
        if hours:
            display_time = f"{hours}:{minutes:02d}:{seconds:02d}"
        else:
            display_time = f"{minutes}:{seconds:02d}"
        self.timer_label.set_text(f"Time remaining: {display_time}")

    def create_env_on_game_server(self, tutorial):
        """Create a new environment on the game server and store its ids.

        The tutorial uses the bundled tutorial layout and config with a single player.
        Otherwise the layout at ``self.current_layout_idx`` is used with
        ``self.number_players`` players.

        Args:
            tutorial: Whether to create the single-player tutorial environment.

        Raises:
            ValueError: If the server forbids the request (HTTP 403) or does not answer
                with a JSON object.
            requests.HTTPError: For any other 4xx/5xx response.
        """
        content_dir = ROOT_DIR / "game_content"
        if tutorial:
            layout_path = content_dir / "tutorial" / "tutorial.layout"
            environment_config_path = content_dir / "tutorial" / "tutorial_env_config.yaml"
        else:
            environment_config_path = content_dir / "environment_config.yaml"
            layout_path = self.layout_file_paths[self.current_layout_idx]

        # Explicit encoding so reading the config files does not depend on the platform locale.
        with open(content_dir / "item_info.yaml", "r", encoding="utf-8") as file:
            item_info = file.read()
        with open(layout_path, "r", encoding="utf-8") as file:
            layout = file.read()
        with open(environment_config_path, "r", encoding="utf-8") as file:
            environment_config = file.read()

        num_players = 1 if tutorial else self.number_players
        seed = 161616161616  # fixed seed passed to the environment
        creation_json = CreateEnvironmentConfig(
            manager_id=self.manager_id,
            number_players=num_players,
            environment_settings={"all_player_can_pause_game": False},
            item_info_config=item_info,
            environment_config=environment_config,
            layout_config=layout,
            seed=seed,
        ).model_dump(mode="json")

        response = requests.post(
            f"{self.request_url}/manage/create_env/",
            json=creation_json,
        )
        if response.status_code == 403:
            raise ValueError(f"Forbidden Request: {response.json()['detail']}")
        # Surface every other HTTP error here instead of failing later on a missing key.
        response.raise_for_status()

        env_info = response.json()
        # Explicit check rather than ``assert``, which is stripped under ``python -O``.
        if not isinstance(env_info, dict):
            raise ValueError("Env info must be a dictionary")
        self.current_env_id = env_info["env_id"]
        self.player_info = env_info["player_info"]
        if tutorial:
            # The tutorial has exactly one player; remember its id.
            self.player_id = str(list(self.player_info.keys())[0])

    def get_game_connection(self):
        """Ask the study server which game to join and store the assigned player.

        In the controller tutorial the participant is connected to the tutorial game.
        Otherwise it is connected to the next study level, and ``self.last_level`` is
        updated from the answer. Afterwards ``self.player_info`` maps the assigned player id
        to its connection info, and the first key set controls that player.
        """
        in_tutorial = self.menu_state == MenuStates.ControllerTutorial
        if in_tutorial:
            self.player_info = requests.post(
                f"http://localhost:8080/connect_to_tutorial/{self.participant_id}"
            ).json()
        else:
            connection = requests.post(
                f"http://localhost:8080/get_game_connection/{self.participant_id}"
            ).json()
            self.player_info = connection["player_info"]
            self.last_level = connection["last_level"]

            print("LAST LEVEL", self.last_level)

        # Shared by both branches: bind the assigned player to the first key set.
        assigned = self.player_info
        self.key_sets[0].current_player = int(assigned["player_id"])
        self.player_id = assigned["player_id"]
        self.player_info = {assigned["player_id"]: assigned}

    def create_and_connect_bot(self, player_id, player_info):
        """Spawn a bot process that joins the game as ``player_id``.

        Depending on ``self.USE_AAAMBOS_AGENT``, this starts either an aaambos agent run or
        the bundled random agent script. The process handle is appended to
        ``self.sub_processes`` so that ``disconnect_websockets`` can terminate it later.

        Args:
            player_id: Id of the player slot the bot controls.
            player_info: Connection info for that slot; ``client_id`` and ``player_hash``
                are read.
        """
        player_hash = player_info["player_hash"]
        websocket_uri = self.websocket_url + player_info["client_id"]
        agents_dir = ROOT_DIR / "game_content" / "agents"

        # Arguments are passed as a list without a shell. Paths containing spaces then
        # still work, and values received from the server are never interpreted by a shell.
        if self.USE_AAAMBOS_AGENT:
            general_plus = (
                f"agent_websocket:{websocket_uri};"
                f"player_hash:{player_hash};"
                f"agent_id:{player_id}"
            )
            command = [
                "aaambos",
                "run",
                "--arch_config",
                str(agents_dir / "arch_config.yml"),
                "--run_config",
                str(agents_dir / "run_config.yml"),
                f"--general_plus={general_plus}",
                f"--instance={player_hash}",
            ]
        else:
            command = [
                # Use the running interpreter so the agent sees the same environment.
                sys.executable,
                str(agents_dir / "random_agent.py"),
                "--uri",
                websocket_uri,
                "--player_hash",
                str(player_hash),
                "--player_id",
                str(player_id),
            ]
        log.debug(f"Starting bot process: {command}")
        self.sub_processes.append(Popen(command))

    def connect_websockets(self):
        """Open websockets for the human players and start bots for the remaining slots.

        The first ``self.number_humans_to_be_added`` entries of ``self.player_info`` each
        get a websocket that sends a "ready" message for that player. All further entries
        are handed to ``create_and_connect_bot``. The first entry's id becomes
        ``self.state_player_id``, whose websocket ``request_state`` uses.

        Raises:
            ConnectionError: If the server does not answer a player's ready message with
                status 200.
        """
        for p, (player_id, player_info) in enumerate(self.player_info.items()):
            if p < self.number_humans_to_be_added:
                # add player websockets
                websocket = connect(self.websocket_url + player_info["client_id"])
                websocket.send(
                    json.dumps(
                        {"type": "ready", "player_hash": player_info["player_hash"]}
                    )
                )
                # Explicit check instead of ``assert``, which is stripped under ``python -O``;
                # a rejected socket is closed rather than leaked.
                if json.loads(websocket.recv())["status"] != 200:
                    websocket.close()
                    raise ConnectionError("not accepted player")
                self.websockets[player_id] = websocket

            else:
                # create bots and add bot websockets
                self.create_and_connect_bot(player_id, player_info)

            if p == 0:
                self.state_player_id = player_id

    def setup_game(self, tutorial=False):
        """Prepare key bindings, join or create the game, connect, and read the kitchen size.

        Args:
            tutorial: If True, set up the single-player controller tutorial instead of a
                regular level.
        """
        if not tutorial:
            self.number_players = (
                self.number_humans_to_be_added + self.number_bots_to_be_added
            )
            if self.split_players:
                assert (
                    self.number_humans_to_be_added > 1
                ), "Not enough players for key configuration."
            wanted_key_sets = 2 if self.multiple_keysets else 1
            self.key_sets = self.setup_player_keys(
                self.number_humans_to_be_added,
                min(self.number_humans_to_be_added, wanted_key_sets),
                self.split_players,
            )
        else:
            self.key_sets = self.setup_player_keys(1, 1, False)
            self.vis.create_player_colors(1)

        if not self.CONNECT_WITH_STUDY_SERVER:
            self.create_env_on_game_server(tutorial)
        else:
            self.get_game_connection()

        self.connect_websockets()

        initial_state = self.request_state()
        self.vis.create_player_colors(len(initial_state["players"]))

        kitchen = initial_state["kitchen"]
        self.kitchen_width = kitchen["width"]
        self.kitchen_height = kitchen["height"]

    def stop_game(self, reason: str) -> None:
        """Ask the game server to stop the current environment.

        The reason is always logged at debug level. No request is sent when connected to
        the study server.

        Args:
            reason: Human-readable reason forwarded to the server.
        """
        log.debug(f"Stopping game: {reason}")
        if self.CONNECT_WITH_STUDY_SERVER:
            return
        stop_url = f"{self.request_url}/manage/stop_env/"
        payload = {
            "manager_id": self.manager_id,
            "env_id": self.current_env_id,
            "reason": reason,
        }
        requests.post(stop_url, json=payload)

    def send_tutorial_finished(self):
        """Tell the study server that this participant has left the tutorial."""
        endpoint = f"http://localhost:8080/disconnect_from_tutorial/{self.participant_id}"
        requests.post(endpoint)

    def finished_button_press(self):
        """End the current game and switch to the post-game screen.

        Outside study mode the environment is also stopped on the game server.
        """
        study_mode = self.CONNECT_WITH_STUDY_SERVER
        if not study_mode:
            self.stop_game("finished_button_pressed")
        self.menu_state = MenuStates.PostGame
        self.reset_window_size()
        log.debug("Pressed finished button")
        self.update_screen_elements()

    def fullscreen_button_press(self):
        """Toggle fullscreen, then rebuild the window, the UI elements and the game size."""
        self.fullscreen = False if self.fullscreen else True
        # Window first: the UI elements and game surface are sized from it.
        self.set_window_size()
        self.init_ui_elements()
        self.set_game_size()

    def reset_gui_values(self):
        """Restore the start-screen player configuration: one human, no bots, one key set."""
        self.currently_controlled_player_idx = 0
        humans, bots = 1, 0
        self.number_humans_to_be_added = humans
        self.number_bots_to_be_added = bots
        self.number_players = humans + bots
        self.split_players = self.multiple_keysets = False
        self.player_minimum = 1

    def update_selection_elements(self):
        """Synchronise the player-selection widgets with the current configuration.

        This method:

        - clamps the human count to ``self.player_minimum`` and recomputes
          ``self.number_players``;
        - refreshes all button and label texts;
        - shows the split-players toggle only when two key sets are used;
        - disables the start button when no player would be added.
        """
        if self.number_humans_to_be_added > self.player_minimum:
            self.remove_human_button.enable()
        else:
            self.remove_human_button.disable()
        self.number_humans_to_be_added = max(
            self.player_minimum, self.number_humans_to_be_added
        )

        self.number_players = (
            self.number_humans_to_be_added + self.number_bots_to_be_added
        )

        keyset_text = "WASD+ARROW" if self.multiple_keysets else "WASD"
        self.multiple_keysets_button.set_text(keyset_text)
        self.added_players_label.set_text(
            f"Humans to be added: {self.number_humans_to_be_added}"
        )
        self.added_bots_label.set_text(
            f"Bots to be added: {self.number_bots_to_be_added}"
        )
        split_text = "Yes" if self.split_players else "No"
        self.split_players_button.set_text(f"Split players: {split_text}")

        if not self.multiple_keysets:
            self.split_players_button.hide()
        else:
            self.split_players_button.show()

        if self.number_players != 0:
            self.start_button.enable()
        else:
            self.start_button.disable()

    def send_action(self, action: Action):
        """Send an action to the game environment and wait for the server's reply.

        A numpy ``action_data`` is converted in place to a two-element list of floats so
        the message can be JSON-encoded.

        Args:
            action: The action to be sent. It contains the player, the action type and, for
                movements, the move direction.
        """
        payload = action.action_data
        if isinstance(payload, np.ndarray):
            action.action_data = [float(payload[0]), float(payload[1])]

        message = {
            "type": "action",
            "action": dataclasses.asdict(action, dict_factory=custom_asdict_factory),
            "player_hash": self.player_info[action.player]["player_hash"],
        }
        channel = self.websockets[self.player_id]
        channel.send(json.dumps(message))
        # The reply is read to keep request/response pairs in sync; its content is unused.
        channel.recv()

    def request_state(self):
        """Fetch the current game state over the state player's websocket.

        Returns:
            The JSON-decoded reply of the server.
        """
        channel = self.websockets[self.state_player_id]
        own_hash = self.player_info[self.state_player_id]["player_hash"]
        channel.send(json.dumps({"type": "get_state", "player_hash": own_hash}))
        return json.loads(channel.recv())

    def disconnect_websockets(self):
        """Terminate all spawned bot processes and close every open player websocket.

        aaambos agents are interrupted with SIGINT, and any remaining ``aaambos`` processes
        are then killed via ``ps``/``kill``. Other bots are killed directly. Processes that
        have already exited are ignored.
        """
        for sub in self.sub_processes:
            try:
                if self.USE_AAAMBOS_AGENT:
                    pgrp = os.getpgid(sub.pid)
                    if pgrp != os.getpgrp():
                        # The agent has its own process group: interrupt the whole group.
                        os.killpg(pgrp, signal.SIGINT)
                    else:
                        # Popen children inherit our process group by default. Signalling
                        # the group would also deliver SIGINT to this GUI itself, so only
                        # the child is signalled.
                        sub.send_signal(signal.SIGINT)
                    subprocess.run(
                        "kill $(ps aux | grep 'aaambos' | awk '{print $2}')", shell=True
                    )
                else:
                    sub.kill()

            except ProcessLookupError:
                # The process is already gone; nothing left to stop.
                pass

        self.sub_processes = []
        for websocket in self.websockets.values():
            websocket.close()

    def play_bell_sound(self):
        """Play the bundled ``sync_bell.wav`` once through the pygame mixer at 90% volume."""
        mixer.init()
        mixer.music.load(str(ROOT_DIR / "gui_2d_vis" / "sync_bell.wav"))
        mixer.music.set_volume(0.9)
        mixer.music.play()
        log.info("Started game, played bell sound")

    def start_study(self):
        """Register the start of the study with the study server and reset the level flag.

        The JSON answer of the server is stored in ``self.player_info``.
        """
        response = requests.post(
            f"http://localhost:8080/start_study/{self.participant_id}"
        )
        self.player_info = response.json()
        self.last_level = False

    def send_level_done(self):
        """Tell the study server that the participant finished the current level."""
        # The reply is decoded (so a non-JSON answer still raises), but its content is
        # not used here.
        requests.post(f"http://localhost:8080/level_done/{self.participant_id}").json()

    def manage_button_event(self, event):
        if event.ui_element == self.quit_button:
            self.running = False
            self.disconnect_websockets()
            self.stop_game("Quit button")
            self.menu_state = MenuStates.Start
            log.debug("Pressed quit button")
            return

        elif event.ui_element == self.fullscreen_button:
            self.fullscreen_button_press()
            log.debug("Pressed fullscreen button")
            return

        # Filter by shown screen page
        match self.menu_state:
            ############################################
            case MenuStates.Start:
                match event.ui_element:
                    case self.start_button:
                        if not (
                            self.number_humans_to_be_added
                            + self.number_bots_to_be_added
                        ):
                            pass
                        else:
                            self.menu_state = MenuStates.ControllerTutorial

                    case self.add_human_player_button:
                        self.number_humans_to_be_added += 1
                    case self.add_bot_button:
                        self.number_bots_to_be_added += 1

                    case self.remove_human_button:
                        self.number_humans_to_be_added = max(
                            self.player_minimum, self.number_humans_to_be_added - 1
                        )
                    case self.remove_bot_button:
                        self.number_bots_to_be_added = max(
                            0, self.number_bots_to_be_added - 1
                        )
                    case self.multiple_keysets_button:
                        self.multiple_keysets = not self.multiple_keysets
                        self.split_players = False
                    case self.split_players_button:
                        self.split_players = not self.split_players
                        if self.split_players:
                            self.player_minimum = 2
                        else:
                            self.player_minimum = 0

            ############################################

            case MenuStates.ControllerTutorial:
                match event.ui_element:
                    case self.continue_button:
                        self.menu_state = MenuStates.PreGame
                        if self.CONNECT_WITH_STUDY_SERVER:
                            self.send_tutorial_finished()
                            self.start_study()
                        else:
                            self.stop_game("tutorial_finished")
                        self.disconnect_websockets()

            ############################################

            case MenuStates.PreGame:
                match event.ui_element:
                    case self.continue_button:
                        self.setup_game()
                        self.set_game_size()
                        self.menu_state = MenuStates.Game

            ############################################

            case MenuStates.Game:
                match event.ui_element:
                    case self.finished_button:
                        self.menu_state = MenuStates.PostGame
                        self.disconnect_websockets()
                        self.finished_button_press()
                        self.handle_joy_stick_input(joysticks=self.joysticks)

                        if self.CONNECT_WITH_STUDY_SERVER:
                            self.send_level_done()

            ############################################

            case MenuStates.PostGame:
                match event.ui_element:
                    case self.retry_button:
                        if not self.CONNECT_WITH_STUDY_SERVER:
                            self.stop_game("Retry button")
                        self.menu_state = MenuStates.PreGame

                    case self.next_game_button:
                        if not self.CONNECT_WITH_STUDY_SERVER:
                            self.current_layout_idx += 1
                            if self.current_layout_idx == len(self.layout_file_paths) - 1:
                                self.last_level = True
                            else:
                                log.debug(
                                    f"LEVEL: {self.layout_file_paths[self.current_layout_idx]}"
                                )
                        self.menu_state = MenuStates.PreGame

                    case self.finish_study_button:
                        self.menu_state = MenuStates.End

            ############################################

            case MenuStates.End:
                match event.ui_element:
                    case other:
                        pass

    def start_pygame(self):
        """Start pygame and run the GUI loop until the window is closed.

        Each frame, pending events are processed and the main window is redrawn. The
        events handled are:

        - quit and joystick hot-plugging;
        - "press any key" on the start and pre-game screens;
        - UI button presses;
        - keyboard and joystick input while a game or the tutorial is running.

        On exit, bots and websockets are shut down, the game is stopped and the process
        exits.
        """
        log.debug(f"Starting pygame gui at {self.FPS} fps")
        pygame.font.init()
        self.comic_sans = pygame.font.SysFont("Comic Sans MS", 30)

        pygame.display.set_caption("Simple Overcooked Simulator")

        clock = pygame.time.Clock()

        self.reset_window_size()
        self.init_ui_elements()
        self.reset_gui_values()
        self.update_screen_elements()

        # Game loop
        self.running = True
        # This dict can be left as-is, since pygame will generate a
        # pygame.JOYDEVICEADDED event for every joystick connected
        # at the start of the program.
        self.joysticks = {}

        while self.running:
            try:
                self.time_delta = clock.tick(self.FPS) / 1000

                # PROCESSING EVENTS
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        self.disconnect_websockets()
                        self.running = False

                    # connect joystick
                    if (
                        pygame.joystick.get_count() > 0
                        and event.type == pygame.JOYDEVICEADDED
                    ):
                        # This event will be generated when the program starts for every
                        # joystick, filling up the list without needing to create them manually.
                        joy = pygame.joystick.Joystick(event.device_index)
                        self.joysticks[joy.get_instance_id()] = joy
                        print(f"Joystick {joy.get_instance_id()} connected")

                    # disconnect joystick
                    if event.type == pygame.JOYDEVICEREMOVED:
                        # pop() instead of del: the device may never have been registered
                        # above (the add branch is skipped when get_count() is 0).
                        self.joysticks.pop(event.instance_id, None)
                        print(f"Joystick {event.instance_id} disconnected")
                        print("Number of joysticks:", pygame.joystick.get_count())

                    # Press key instead of mouse button press
                    if self.menu_state == MenuStates.Start:
                        if event.type in [pygame.JOYBUTTONDOWN, pygame.KEYDOWN]:
                            self.menu_state = MenuStates.ControllerTutorial
                            self.update_screen_elements()
                    elif self.menu_state == MenuStates.PreGame:
                        if event.type in [pygame.JOYBUTTONDOWN, pygame.KEYDOWN]:
                            self.setup_game()
                            self.set_game_size()
                            self.menu_state = MenuStates.Game
                            self.update_screen_elements()

                    if event.type == pygame_gui.UI_BUTTON_PRESSED:
                        self.manage_button_event(event)
                        self.update_screen_elements()

                    if event.type in [
                        pygame.KEYDOWN,
                        pygame.KEYUP,
                    ] and self.menu_state in [
                        MenuStates.Game,
                        MenuStates.ControllerTutorial,
                    ]:
                        self.handle_key_event(event)

                    if event.type in [
                        pygame.JOYBUTTONDOWN,
                        pygame.JOYBUTTONUP,
                    ] and self.menu_state in [
                        MenuStates.Game,
                        MenuStates.ControllerTutorial,
                    ]:
                        self.handle_joy_stick_event(event, joysticks=self.joysticks)

                    self.manager.process_events(event)

                # DRAWING
                self.draw_main_window()

            except (KeyboardInterrupt, SystemExit):
                self.running = False
                self.disconnect_websockets()
                self.stop_game("Program exited.")

        self.disconnect_websockets()
        self.stop_game("Program exited")
        pygame.quit()
        sys.exit()


def main(
    url: str,
    port: int,
    manager_ids: list[str],
    CONNECT_WITH_STUDY_SERVER=False,
    USE_AAAMBOS_AGENT=False,
):
    """Create the PyGame GUI and run its main loop.

    Args:
        url: Host of the game server.
        port: Port of the game server.
        manager_ids: Manager ids for the game server (currently overridden, see below).
        CONNECT_WITH_STUDY_SERVER: Let the study server decide which games are played.
        USE_AAAMBOS_AGENT: Spawn aaambos agents instead of the random agent for bots.
    """
    # NOTE(review): the manager ids passed in (e.g. from the command line) are replaced
    # by a fixed id here; confirm this matches how the game server is started.
    manager_ids = ["1234"]

    gui = PyGameGUI(
        url=url,
        port=port,
        manager_ids=manager_ids,
        CONNECT_WITH_STUDY_SERVER=CONNECT_WITH_STUDY_SERVER,
        # Previously CONNECT_WITH_STUDY_SERVER was passed here by mistake.
        USE_AAAMBOS_AGENT=USE_AAAMBOS_AGENT,
    )
    gui.start_pygame()


if __name__ == "__main__":
    # Command-line entry point: parse server/manager options and start the GUI.
    parser = argparse.ArgumentParser(
        prog="Overcooked Simulator 2D PyGame Visualization",
        description="PyGameGUI: a PyGame 2D Visualization window.",
        epilog="For further information, see https://scs.pages.ub.uni-bielefeld.de/cocosy/overcooked-simulator/overcooked_simulator.html",
    )

    for register_arguments in (
        url_and_port_arguments,
        disable_websocket_logging_arguments,
        add_list_of_manager_ids_arguments,
    ):
        register_arguments(parser)
    args = parser.parse_args()
    main(args.url, args.port, args.manager_ids)