import pygame
import numpy as np
from overcooked_simulator.overcooked_environment import Action
from overcooked_simulator.simulation_runner import Simulator


# RGB colour constants used by the GUI.
# Achromatic colours.
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
GREY = (190, 190, 190)
LIGHTGREY = (220, 220, 220)
COUNTERCOLOR = (240, 240, 240)
# Primary / signal colours.
RED = (255, 0, 0)
GREEN = (0, 255, 0)
BLUE = (0, 0, 255)
YELLOW = (255, 255, 0)
# Fill colour of the window behind the grid and counters.
BACKGROUND_COLOR = GREY


class PlayerKeyset:
    """Set of keyboard keys for controlling a player.

    The keys are pygame key codes (e.g. ``pygame.K_LEFT``), given in this order:
    the first four keys are for movement (left, right, up, down),
    the 5th key is for interacting with counters,
    the 6th key is for picking up things or dropping them.
    """

    # 4 movement keys + interact key + pickup key.
    NUM_KEYS = 6

    def __init__(self, keys: list[int]):
        """Builds the key-to-action mappings for one player.

        Args:
            keys: Exactly six pygame key codes in the order described in the class docstring.

        Raises:
            ValueError: If ``keys`` does not contain exactly six entries. Without this check a
                wrong-length list silently mis-maps movement/interact/pickup keys.
        """
        if len(keys) != self.NUM_KEYS:
            raise ValueError(
                f"PlayerKeyset expects exactly {self.NUM_KEYS} keys "
                f"(4 movement, interact, pickup), got {len(keys)}"
            )
        self.player_keys = keys
        # Movement vectors in screen coordinates (y grows downwards), matching the
        # left/right/up/down order of the first four keys.
        self.move_vectors = [[-1, 0], [1, 0], [0, -1], [0, 1]]
        self.key_to_movement = dict(zip(self.player_keys[:-2], self.move_vectors))
        self.interact_key = self.player_keys[-2]
        self.pickup_key = self.player_keys[-1]


class PyGameGUI:
    """Visualisation of the overcooked environment and reading keyboard inputs using pygame.
    """
    def __init__(self, simulator: Simulator):
        """Stores the simulator and reads the world/counter dimensions from its environment.

        Args:
            simulator: The simulator whose state is drawn and which receives the player actions.
        """
        self.FPS = 60
        self.simulator = simulator
        self.counter_size = self.simulator.env.counter_side_length
        self.window_width, self.window_height = simulator.env.world_width, simulator.env.world_height

        # Created in start_pygame(); None until the window exists.
        self.screen = None

        # When True, interact/pickup actions are sent every frame while their key is held
        # (see handle_keys); when False, only once per key press (see handle_interact_single_send).
        self.GET_CONTINOUS_INTERACT_AND_PICKUP = False

        keys1 = [pygame.K_LEFT, pygame.K_RIGHT, pygame.K_UP, pygame.K_DOWN, pygame.K_SPACE, pygame.K_i]
        keys2 = [pygame.K_a, pygame.K_d, pygame.K_w, pygame.K_s, pygame.K_f, pygame.K_e]
        self.player_keysets: list[PlayerKeyset] = [PlayerKeyset(keys1), PlayerKeyset(keys2)]

    def send_action(self, action: Action):
        """Sends an action to the game environment.

        Args:
            action: The action to be sent. Contains the player, action type and move direction if action is a movement.
        """
        self.simulator.enter_action(action)

    def handle_keys(self):
        """Handles keyboard inputs. Sends action for the respective players. When a key is held down, every frame
        an action is sent in this function.

        Movement keys pressed together are summed and normalized to unit length, so diagonal
        movement is not faster than straight movement.
        """
        keys = pygame.key.get_pressed()
        for player_idx, keyset in enumerate(self.player_keysets):
            player_name = f"p{player_idx + 1}"
            relevant_keys = [keys[k] for k in keyset.player_keys]
            movement_pressed = relevant_keys[:-2]
            if any(movement_pressed):
                move_vec = np.zeros(2)
                for pressed, vec in zip(movement_pressed, keyset.move_vectors):
                    if pressed:
                        move_vec += vec
                norm = np.linalg.norm(move_vec)
                # Opposite keys (e.g. left + right) cancel out to a zero vector; don't divide by 0.
                if norm != 0:
                    move_vec = move_vec / norm

                self.send_action(Action(player_name, "movement", move_vec))
            if self.GET_CONTINOUS_INTERACT_AND_PICKUP:
                if relevant_keys[-2]:
                    self.send_action(Action(player_name, "interact", "interact"))
                if relevant_keys[-1]:
                    self.send_action(Action(player_name, "pickup", "pickup"))

    def handle_interact_single_send(self, event):
        """Handles key events. Here when a key is held down, only one action is sent. (This can be
        switched by the GET_CONTINOUS_INTERACT_AND_PICKUP flag)

        Args:
            event: Pygame event for extracting the key.
        """
        for player_idx, keyset in enumerate(self.player_keysets):
            player_name = f"p{player_idx + 1}"
            if event.key == keyset.pickup_key:
                self.send_action(Action(player_name, "pickup", "pickup"))
            elif event.key == keyset.interact_key:
                self.send_action(Action(player_name, "interact", "interact"))

    def draw_background(self):
        """Visualizes a game background: a grid of 1px outlined cells, each half a counter wide.
        """
        BACKGROUND_LINES = (200, 200, 200)
        # range() needs integers and a non-zero step; guard against float sizes and tiny counters.
        block_size = max(1, int(self.counter_size) // 2)  # Set the size of the grid block
        for x in range(0, int(self.window_width), block_size):
            for y in range(0, int(self.window_height), block_size):
                rect = pygame.Rect(x, y, block_size, block_size)
                pygame.draw.rect(self.screen, BACKGROUND_LINES, rect, 1)

    def draw_players(self, state):
        """Visualizes the players: a white circle, a coloured square (red for "p1", green otherwise)
        and a blue triangle pointing in the facing direction.

        Args:
            state: The game state returned by the environment.
        """
        for player in state["players"].values():
            pos = player.pos
            size = player.radius
            color1 = RED if player.name == "p1" else GREEN
            color2 = WHITE

            rect = pygame.Rect(pos[0] - (size / 2), pos[1] - (size / 2), size, size)
            pygame.draw.circle(self.screen, color2, pos, size)
            pygame.draw.rect(self.screen, color1, rect)

            # NOTE(review): pos/facing are presumably numpy arrays (the tip uses vector arithmetic).
            facing = player.facing_direction

            # Triangle base: two points 5px either side of pos, perpendicular to facing;
            # tip: 20px ahead of pos along facing.
            pygame.draw.polygon(self.screen, BLUE,
                                ((pos[0] + (facing[1] * 5), pos[1] - (facing[0] * 5)),
                                 (pos[0] - (facing[1] * 5), pos[1] + (facing[0] * 5)),
                                 pos + (facing * 20)))

    def draw_counters(self, state):
        """Visualizes the counters in the environment as filled squares centred on their positions.

        Args:
            state: The game state returned by the environment.
        """
        half = self.counter_size / 2
        for counter in state["counters"]:
            counter_rect_outline = pygame.Rect(counter.pos[0] - half,
                                               counter.pos[1] - half,
                                               self.counter_size,
                                               self.counter_size)

            pygame.draw.rect(self.screen, COUNTERCOLOR, counter_rect_outline)

    def draw(self, state):
        """Main visualization function.

        Args:
            state: The game state returned by the environment.
        """
        self.screen.fill(BACKGROUND_COLOR)
        self.draw_background()

        self.draw_counters(state)
        self.draw_players(state)

        pygame.display.flip()

    def start_pygame(self):
        """Starts pygame and the gui loop. Each frame the gamestate is visualized and keyboard inputs are read.

        pygame is shut down when the loop ends, also if an exception escapes the loop.
        """
        pygame.init()
        pygame.font.init()
        try:
            self.screen = pygame.display.set_mode((self.window_width, self.window_height))
            pygame.display.set_caption("Simple Overcooked Simulator")
            self.screen.fill(BACKGROUND_COLOR)

            clock = pygame.time.Clock()

            # Game loop
            running = True
            while running:
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        running = False
                    elif event.type == pygame.KEYDOWN and not self.GET_CONTINOUS_INTERACT_AND_PICKUP:
                        self.handle_interact_single_send(event)

                self.handle_keys()
                clock.tick(self.FPS)
                state = self.simulator.get_state()
                self.draw(state)
        finally:
            # Always release the display/subsystems, even when the loop raises.
            pygame.quit()