From e22ce756a74db4521b7f9386478f5dd2ef8db2d8 Mon Sep 17 00:00:00 2001
From: fheinrich <fheinrich@techfak.uni-bielefeld.de>
Date: Thu, 25 Jan 2024 16:38:56 +0100
Subject: [PATCH] Moved actual pygame drawing to other file

---
 overcooked_simulator/gui_2d_vis/drawing.py    | 459 +++++++++++++++
 .../gui_2d_vis/overcooked_gui.py              | 542 ++----------------
 2 files changed, 522 insertions(+), 479 deletions(-)
 create mode 100644 overcooked_simulator/gui_2d_vis/drawing.py

diff --git a/overcooked_simulator/gui_2d_vis/drawing.py b/overcooked_simulator/gui_2d_vis/drawing.py
new file mode 100644
index 00000000..26b19a85
--- /dev/null
+++ b/overcooked_simulator/gui_2d_vis/drawing.py
@@ -0,0 +1,459 @@
+import colorsys
+import math
+from pathlib import Path
+
+import numpy as np
+import numpy.typing as npt
+import pygame
+from scipy.spatial import KDTree
+
+from overcooked_simulator import ROOT_DIR
+from overcooked_simulator.gui_2d_vis.game_colors import colors
+from overcooked_simulator.order import Order
+
+
+def create_polygon(n, length):
+    if n == 1:
+        return np.array([0, 0])
+
+    vector = np.array([length, 0])
+    angle = (2 * np.pi) / n
+
+    rot_matrix = np.array(
+        [[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]
+    )
+
+    vecs = [vector]
+    for i in range(n - 1):
+        vector = np.dot(rot_matrix, vector)
+        vecs.append(vector)
+
+    return vecs
+
+
+class Visualizer:
+    def __init__(self, config):
+        self.image_cache_dict = {}
+        self.player_colors = []
+        self.config = config
+
+    def create_player_colors(self, n) -> None:
+        hue_values = np.linspace(0, 1, n + 1)
+
+        colors_vec = np.array([col for col in colors.values()])
+
+        tree = KDTree(colors_vec)
+
+        color_names = list(colors.keys())
+
+        self.player_colors = []
+        for hue in hue_values:
+            rgb = colorsys.hsv_to_rgb(hue, 1, 1)
+            query_color = np.array([int(c * 255) for c in rgb])
+            _, index = tree.query(query_color, k=1)
+            self.player_colors.append(color_names[index])
+
+    def draw_gamescreen(
+        self,
+        screen,
+        state,
+        width,
+        height,
+        grid_size,
+        SHOW_COUNTER_CENTERS=False,
+        USE_PLAYER_COOK_SPRITES=False,
+        SHOW_INTERACTION_RANGE=False,
+    ):
+        self.draw_background(
+            surface=screen,
+            width=width,
+            height=height,
+            grid_size=grid_size,
+        )
+        self.draw_counters(
+            screen,
+            state,
+            grid_size,
+            SHOW_COUNTER_CENTERS,
+        )
+
+        self.draw_players(
+            screen,
+            state,
+            grid_size,
+            USE_PLAYER_COOK_SPRITES=USE_PLAYER_COOK_SPRITES,
+            SHOW_INTERACTION_RANGE=SHOW_INTERACTION_RANGE,
+        )
+
+    def draw_background(self, surface, width, height, grid_size):
+        """Visualizes a game background."""
+        block_size = grid_size // 2  # Set the size of the grid block
+        surface.fill(colors[self.config["Kitchen"]["ground_tiles_color"]])
+        for x in range(0, width, block_size):
+            for y in range(0, height, block_size):
+                rect = pygame.Rect(x, y, block_size, block_size)
+                pygame.draw.rect(
+                    surface,
+                    self.config["Kitchen"]["background_lines"],
+                    rect,
+                    1,
+                )
+
+    def draw_image(
+        self,
+        screen: pygame.Surface,
+        img_path: Path | str,
+        size: float,
+        pos: npt.NDArray,
+        rot_angle=0,
+    ):
+        cache_entry = f"{img_path}"
+        if cache_entry in self.image_cache_dict.keys():
+            image = self.image_cache_dict[cache_entry]
+        else:
+            image = pygame.image.load(
+                ROOT_DIR / "gui_2d_vis" / img_path
+            ).convert_alpha()
+            self.image_cache_dict[cache_entry] = image
+
+        image = pygame.transform.scale(image, (size, size))
+        if rot_angle != 0:
+            image = pygame.transform.rotate(image, rot_angle)
+        rect = image.get_rect()
+        rect.center = pos
+
+        screen.blit(image, rect)
+
+    def draw_players(
+        self,
+        screen: pygame.Surface,
+        state_dict: dict,
+        grid_size: float,
+        USE_PLAYER_COOK_SPRITES: bool = True,
+        SHOW_INTERACTION_RANGE: bool = False,
+    ):
+        """Visualizes the players as circles with a triangle for the facing direction.
+        If the player holds something in their hands, it is displayed
+        Args:            state: The game state returned by the environment.
+        """
+        for p_idx, player_dict in enumerate(state_dict["players"]):
+            pos = np.array(player_dict["pos"]) * grid_size
+
+            facing = np.array(player_dict["facing"])
+
+            if USE_PLAYER_COOK_SPRITES:
+                img_path = self.config["Cook"]["parts"][0]["path"]
+                rel_x, rel_y = facing
+                angle = -np.rad2deg(math.atan2(rel_y, rel_x)) + 90
+                size = self.config["Cook"]["parts"][0]["size"] * grid_size
+                self.draw_image(screen, img_path, size, pos, angle)
+
+            else:
+                size = 0.4 * grid_size
+                color1 = self.player_colors[p_idx]
+                color2 = colors["white"]
+
+                pygame.draw.circle(screen, color2, pos, size)
+                pygame.draw.circle(screen, colors["blue"], pos, size, width=1)
+                pygame.draw.circle(screen, colors[color1], pos, size // 2)
+
+                pygame.draw.polygon(
+                    screen,
+                    colors["blue"],
+                    (
+                        (
+                            pos[0] + (facing[1] * 0.1 * grid_size),
+                            pos[1] - (facing[0] * 0.1 * grid_size),
+                        ),
+                        (
+                            pos[0] - (facing[1] * 0.1 * grid_size),
+                            pos[1] + (facing[0] * 0.1 * grid_size),
+                        ),
+                        pos + (facing * 0.5 * grid_size),
+                    ),
+                )
+
+            if SHOW_INTERACTION_RANGE:
+                facing_point = np.array(player_dict["facing"])
+
+                pygame.draw.circle(
+                    screen,
+                    colors["blue"],
+                    facing_point * grid_size,
+                    1.6 * grid_size,
+                    width=1,
+                )
+                pygame.draw.circle(
+                    screen,
+                    colors["red1"],
+                    facing * grid_size,
+                    4,
+                )
+                pygame.draw.circle(screen, colors["red1"], facing, 4)
+
+            if player_dict["holding"] is not None:
+                holding_item_pos = pos + (20 * facing)
+
+                self.draw_thing(
+                    holding_item_pos,
+                    self.config[player_dict["holding"]]["parts"],
+                )
+                
+                # TODO MAKE THIS WORK
+                # if player.current_nearest_counter:
+                #     counter: Counter = player.current_nearest_counter
+                #     pos = counter.pos * self.grid_size
+                #     pygame.draw.rect(
+                #         self.game_screen,
+                #         colors[self.player_colors[p_idx]],
+                #         rect=pygame.Rect(
+                #             pos[0] - (self.grid_size // 2),
+                #             pos[1] - (self.grid_size // 2),
+                #             self.grid_size,
+                #             self.grid_size,
+                #         ),
+                #         width=2,
+                #     )
+
+    def draw_thing(
+        self,
+        screen: pygame.Surface,
+        pos: npt.NDArray[float],
+        grid_size: float,
+        parts: list[dict[str]],
+        scale: float = 1.0,
+    ):
+        """Draws an item, based on its visual parts specified in the visualization config.
+
+        Args:
+            pos: Where to draw the item parts.
+            parts: The visual parts to draw.
+            scale: Rescale the item by this factor.
+        """
+        for part in parts:
+            part_type = part["type"]
+            match part_type:
+                case "image":
+                    if "center_offset" in part:
+                        d = np.array(part["center_offset"]) * grid_size
+                        pos += d
+
+                    self.draw_image(
+                        screen,
+                        part["path"],
+                        part["size"] * scale * grid_size,
+                        pos,
+                    )
+                case "rect":
+                    height = part["height"] * grid_size
+                    width = part["width"] * grid_size
+                    color = part["color"]
+                    if "center_offset" in part:
+                        dx, dy = np.array(part["center_offset"]) * grid_size
+                        rect = pygame.Rect(pos[0] + dx, pos[1] + dy, height, width)
+                        pygame.draw.rect(screen, color, rect)
+                    else:
+                        rect = pygame.Rect(
+                            pos[0] - (height / 2),
+                            pos[1] - (width / 2),
+                            height,
+                            width,
+                        )
+                    pygame.draw.rect(screen, color, rect)
+                case "circle":
+                    radius = part["radius"] * grid_size
+                    color = colors[part["color"]]
+                    if "center_offset" in part:
+                        pygame.draw.circle(
+                            screen,
+                            color,
+                            pos + (np.array(part["center_offset"]) * grid_size),
+                            radius,
+                        )
+                    else:
+                        pygame.draw.circle(screen, color, pos, radius)
+
+    # TODO MAKE THIS WORK
+    def draw_item(
+        self,
+        pos: npt.NDArray[float],
+        grid_size,
+        config,
+        item,
+        scale: float = 1.0,
+        plate=False,
+        screen=None,
+    ):
+        """Visualization of an item at the specified position. On a counter or in the hands of the player.
+        The visual composition of the item is read in from visualization.yaml file, where it is specified as
+        different parts to be drawn.
+
+        Args:
+            pos: The position of the item to draw.
+            item: The item do be drawn in the game.
+            scale: Rescale the item by this factor.
+            screen: the pygame screen to draw on.
+            plate: item is on a plate (soup are is different on a plate and pot)
+        """
+
+        if not isinstance(item, list):
+            if item.name in config:
+                item_key = item.name
+                if "Soup" in item.name and plate:
+                    item_key += "Plate"
+                self.draw_thing(
+                    pos,
+                    config[item_key]["parts"],
+                    scale=scale,
+                    screen=screen,
+                )
+                #
+        if isinstance(item, (Item, Plate)) and item.progress_percentage > 0.0:
+            self.draw_progress_bar(screen, pos, item.progress_percentage)
+
+        if isinstance(item, CookingEquipment) and item.content_list:
+            if item.content_ready and item.content_ready.name in config:
+                self.draw_thing(
+                    pos,
+                    config[item.content_ready.name]["parts"],
+                    screen=screen,
+                )
+            else:
+                triangle_offsets = create_polygon(len(item.content_list), length=10)
+                scale = 1 if len(item.content_list) == 1 else 0.6
+                for idx, o in enumerate(item.content_list):
+                    self.draw_item(
+                        pos + triangle_offsets[idx],
+                        o,
+                        scale=scale,
+                        plate=isinstance(item, Plate),
+                        screen=screen,
+                    )
+
+    # TODO MAKE THIS WORK
+    def draw_progress_bar(
+        self,
+        screen: pygame.Surface,
+        pos: npt.NDArray[float],
+        percent: float,
+        grid_size: float,
+    ):
+        """Visualize progress of progressing item as a green bar under the item."""
+        bar_height = grid_size * 0.2
+        progress_width = percent * grid_size
+        progress_bar = pygame.Rect(
+            pos[0] - (grid_size / 2),
+            pos[1] - (grid_size / 2) + grid_size - bar_height,
+            progress_width,
+            bar_height,
+        )
+        pygame.draw.rect(screen, colors["green1"], progress_bar)
+
+    def draw_counter(
+        self, screen: pygame.Surface, counter_dict: dict, grid_size: float
+    ):
+        """Visualization of a counter at its position. If it is occupied by an item, it is also shown.
+        The visual composition of the counter is read in from visualization.yaml file, where it is specified as
+        different parts to be drawn.
+        Args:            counter: The counter to visualize.
+        """
+        pos = np.array(counter_dict["pos"]) * grid_size
+        counter_type = counter_dict["type"]
+        self.draw_thing(screen, pos, grid_size, self.config["Counter"]["parts"])
+        if counter_type in self.config:
+            self.draw_thing(screen, pos, grid_size, self.config[counter_type]["parts"])
+        else:
+            self.draw_thing(
+                screen,
+                pos,
+                self.config[counter_type]["parts"],
+            )
+
+        occupied_by = counter_dict["occupied_by"]
+        if occupied_by is not None:
+            # Multiple plates on plate return:
+            # if isinstance(occupied_by, (list, deque)):
+            #     with self.simulator.env.lock:
+            #
+            #     for i, o in enumerate(occupied_by):
+            #             self.draw_item(np.abs([pos[0], pos[1] - (i * 3)]), o)
+            # # All other items:
+            # else:
+            self.draw_thing(
+                screen,
+                pos,
+                grid_size,
+                self.config[occupied_by]["parts"],
+            )
+
+    def draw_counters(
+        self, screen: pygame, state, grid_size, SHOW_COUNTER_CENTERS=False
+    ):
+        """Visualizes the counters in the environment.
+
+        Args:            state: The game state returned by the environment.
+        """
+        for counter in state["counters"]:
+            self.draw_counter(screen, counter, grid_size)
+            if SHOW_COUNTER_CENTERS:
+                pygame.draw.circle(screen, colors["green1"], counter.pos, 3)
+
+    # TODO MAKE THIS WORK
+    def draw_orders(
+        self, screen, state, grid_size, width, height, screen_margin, config
+    ):
+        orders_width = width - 100
+        orders_height = screen_margin
+
+        order_screen = pygame.Surface(
+            (orders_width, orders_height),
+        )
+
+        bg_color = colors[config["GameWindow"]["background_color"]]
+        pygame.draw.rect(order_screen, bg_color, order_screen.get_rect())
+
+        order_rects_start = (orders_height // 2) - (grid_size // 2)
+        for idx, order in enumerate(state["orders"]):
+            order: Order
+            order_upper_left = [
+                order_rects_start + idx * self.grid_size * 1.2,
+                order_rects_start,
+            ]
+            pygame.draw.rect(
+                order_screen,
+                colors["red"],
+                pygame.Rect(
+                    order_upper_left[0],
+                    order_upper_left[1],
+                    self.grid_size,
+                    self.grid_size,
+                ),
+                width=2,
+            )
+            center = np.array(order_upper_left) + np.array(
+                [self.grid_size / 2, self.grid_size / 2]
+            )
+            self.draw_thing(
+                center,
+                config["Plate"]["parts"],
+                screen=order_screen,
+            )
+            self.draw_item(
+                center,
+                order.meal,
+                plate=True,
+                screen=order_screen,
+            )
+            order_done_seconds = (
+                (order.start_time + order.max_duration) - state["env_time"]
+            ).total_seconds()
+
+            percentage = order_done_seconds / order.max_duration.total_seconds()
+            self.draw_progress_bar(center, percentage, screen=order_screen)
+
+        orders_rect = order_screen.get_rect()
+        orders_rect.center = [
+            screen_margin + (orders_width // 2),
+            orders_height // 2,
+        ]
+        screen.blit(order_screen, orders_rect)
diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
index c8cb0fd7..da414675 100644
--- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py
+++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
@@ -1,31 +1,21 @@
-import colorsys
 import dataclasses
 import json
 import logging
-import math
 import sys
 from datetime import timedelta
 from enum import Enum
 
 import numpy as np
-import numpy.typing as npt
 import pygame
 import pygame_gui
 import requests
 import yaml
-from scipy.spatial import KDTree
 from websockets.sync.client import connect
 
 from overcooked_simulator import ROOT_DIR
-from overcooked_simulator.game_items import (
-    Item,
-    CookingEquipment,
-    Plate,
-)
 from overcooked_simulator.game_server import CreateEnvironmentConfig
-from overcooked_simulator.gui_2d_vis.game_colors import BLUE
-from overcooked_simulator.gui_2d_vis.game_colors import colors, Color
-from overcooked_simulator.order import Order
+from overcooked_simulator.gui_2d_vis.drawing import Visualizer
+from overcooked_simulator.gui_2d_vis.game_colors import colors
 from overcooked_simulator.overcooked_environment import Action
 
 USE_PLAYER_COOK_SPRITES = True
@@ -42,43 +32,19 @@ class MenuStates(Enum):
 MANAGER_ID = "1233245425"
 
 
-def create_polygon(n, length):
-    if n == 1:
-        return np.array([0, 0])
-
-    vector = np.array([length, 0])
-    angle = (2 * np.pi) / n
-
-    rot_matrix = np.array(
-        [[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]
-    )
-
-    vecs = [vector]
-    for i in range(n - 1):
-        vector = np.dot(rot_matrix, vector)
-        vecs.append(vector)
-
-    return vecs
-
-
 log = logging.getLogger(__name__)
 
 
 class PlayerKeySet:
     """Set of keyboard keys for controlling a player.
-    First four keys are for movement. Order: Down, Up, Left, Right.
-    5th key is for interacting with counters.
-    6th key ist for picking up things or dropping them.
-
+    First four keys are for movement. Order: Down, Up, Left, Right.    5th key is for interacting with counters.    6th key ist for picking up things or dropping them.
     """
 
     def __init__(self, player_name: str | int, keys: list[pygame.key]):
         """Creates a player key set which contains information about which keyboard keys control the player.
         Movement keys in the following order: Down, Up, Left, Right
-
-        Args:
-            player_name: The name of the player to control.
-            keys: The keys which control this player in the following order: Down, Up, Left, Right, Interact, Pickup.
+        Args:            player_name: The name of the player to control.
+        keys: The keys which control this player in the following order: Down, Up, Left, Right, Interact, Pickup.
         """
         self.name = player_name
         self.player_keys = keys
@@ -142,11 +108,12 @@ class PyGameGUI:
 
         self.images_path = ROOT_DIR / "pygame_gui" / "images"
 
-        self.image_cache_dict = {}
-
         self.menu_state = MenuStates.Start
         self.manager: pygame_gui.UIManager
 
+        self.vis = Visualizer(self.visualization_config)
+        self.vis.create_player_colors(len(self.player_names))
+
     def get_window_sizes(self, state: dict):
         counter_positions = np.array([c["pos"] for c in state["counters"]])
         kitchen_width = counter_positions[:, 0].max() + 0.5
@@ -184,30 +151,11 @@ class PyGameGUI:
         return (
             int(window_width),
             int(window_height),
-            game_width,
-            game_height,
+            int(game_width),
+            int(game_height),
             grid_size,
         )
 
-    def create_player_colors(self) -> list[Color]:
-        number_player = len(self.simulator.env.players)
-        hue_values = np.linspace(0, 1, number_player + 1)
-
-        colors_vec = np.array([col for col in colors.values()])
-
-        tree = KDTree(colors_vec)
-
-        color_names = list(colors.keys())
-
-        player_colors = []
-        for hue in hue_values:
-            rgb = colorsys.hsv_to_rgb(hue, 1, 1)
-            query_color = np.array([int(c * 255) for c in rgb])
-            _, index = tree.query(query_color, k=1)
-            player_colors.append(color_names[index])
-
-        return player_colors
-
     def handle_keys(self):
         """Handles keyboard inputs. Sends action for the respective players. When a key is held down, every frame
         an action is sent in this function.
@@ -232,7 +180,6 @@ class PyGameGUI:
         """Handles key events for the pickup and interaction keys. Pickup is a single action,
         for interaction keydown and keyup is necessary, because the player has to be able to hold
         the key down.
-
         Args:
             event: Pygame event for extracting the key action.
         """
@@ -249,401 +196,6 @@ class PyGameGUI:
                     action = Action(key_set.name, "interact", "keyup")
                     self.send_action(action)
 
-    def draw_background(self):
-        """Visualizes a game background."""
-        block_size = self.grid_size // 2  # Set the size of the grid block
-        self.game_screen.fill(
-            colors[self.visualization_config["Kitchen"]["ground_tiles_color"]]
-        )
-        for x in range(0, self.window_width, block_size):
-            for y in range(0, self.window_height, block_size):
-                rect = pygame.Rect(x, y, block_size, block_size)
-                pygame.draw.rect(
-                    self.game_screen,
-                    self.visualization_config["Kitchen"]["background_lines"],
-                    rect,
-                    1,
-                )
-
-    def draw_image(
-        self, img_path, size, pos, rot_angle=0, screen: pygame.Surface = None
-    ):
-        cache_entry = f"{img_path}"
-        if cache_entry in self.image_cache_dict.keys():
-            image = self.image_cache_dict[cache_entry]
-        else:
-            image = pygame.image.load(
-                ROOT_DIR / "gui_2d_vis" / img_path
-            ).convert_alpha()
-            self.image_cache_dict[cache_entry] = image
-
-        image = pygame.transform.scale(image, (size, size))
-        if rot_angle != 0:
-            image = pygame.transform.rotate(image, rot_angle)
-        rect = image.get_rect()
-        rect.center = pos
-
-        if screen is None:
-            self.game_screen.blit(image, rect)
-        else:
-            screen.blit(image, rect)
-
-    def draw_players(self, state_dict):
-        """Visualizes the players as circles with a triangle for the facing direction.
-        If the player holds something in their hands, it is displayed
-
-        Args:
-            state: The game state returned by the environment.
-        """
-        for p_idx, player_dict in enumerate(state_dict["players"]):
-            # pos = player.pos * self.grid_size
-            pos = np.array(player_dict["pos"]) * self.grid_size
-
-            facing = np.array(player_dict["facing"])
-
-            if USE_PLAYER_COOK_SPRITES:
-                img_path = self.visualization_config["Cook"]["parts"][0]["path"]
-                rel_x, rel_y = facing
-                angle = -np.rad2deg(math.atan2(rel_y, rel_x)) + 90
-                size = (
-                    self.visualization_config["Cook"]["parts"][0]["size"]
-                    * self.grid_size
-                )
-                self.draw_image(img_path, size, pos, angle)
-
-            else:
-                size = 0.4 * self.grid_size
-                color1 = self.player_colors[p_idx]
-                color2 = colors["white"]
-
-                pygame.draw.circle(self.game_screen, color2, pos, size)
-                pygame.draw.circle(self.game_screen, BLUE, pos, size, width=1)
-                pygame.draw.circle(self.game_screen, colors[color1], pos, size // 2)
-
-                pygame.draw.polygon(
-                    self.game_screen,
-                    BLUE,
-                    (
-                        (
-                            pos[0] + (facing[1] * 0.1 * self.grid_size),
-                            pos[1] - (facing[0] * 0.1 * self.grid_size),
-                        ),
-                        (
-                            pos[0] - (facing[1] * 0.1 * self.grid_size),
-                            pos[1] + (facing[0] * 0.1 * self.grid_size),
-                        ),
-                        pos + (facing * 0.5 * self.grid_size),
-                    ),
-                )
-
-            if SHOW_INTERACTION_RANGE:
-                facing_point = np.array(player_dict["facing"])
-
-                pygame.draw.circle(
-                    self.game_screen,
-                    BLUE,
-                    facing_point * self.grid_size,
-                    1.6 * self.grid_size,
-                    width=1,
-                )
-                pygame.draw.circle(
-                    self.game_screen,
-                    colors["red1"],
-                    facing * self.grid_size,
-                    4,
-                )
-                pygame.draw.circle(self.game_screen, colors["red1"], facing, 4)
-
-            if player_dict["holding"] is not None:
-                holding_item_pos = pos + (20 * facing)
-
-                self.draw_thing(
-                    holding_item_pos,
-                    self.visualization_config[player_dict["holding"]]["parts"],
-                )
-
-            # if player.current_nearest_counter:
-            #     counter: Counter = player.current_nearest_counter
-            #     pos = counter.pos * self.grid_size
-            #     pygame.draw.rect(
-            #         self.game_screen,
-            #         colors[self.player_colors[p_idx]],
-            #         rect=pygame.Rect(
-            #             pos[0] - (self.grid_size // 2),
-            #             pos[1] - (self.grid_size // 2),
-            #             self.grid_size,
-            #             self.grid_size,
-            #         ),
-            #         width=2,
-            #     )
-
-    def draw_thing(
-        self,
-        pos: npt.NDArray[float],
-        parts: list[dict[str]],
-        scale: float = 1.0,
-        screen: pygame.Surface = None,
-    ):
-        """Draws an item, based on its visual parts specified in the visualization config.
-
-        Args:
-            pos: Where to draw the item parts.
-            parts: The visual parts to draw.
-            scale: Rescale the item by this factor.
-        """
-
-        if screen is None:
-            screen = self.game_screen
-
-        for part in parts:
-            part_type = part["type"]
-            match part_type:
-                case "image":
-                    if "center_offset" in part:
-                        d = np.array(part["center_offset"]) * self.grid_size
-                        pos += d
-
-                    self.draw_image(
-                        part["path"],
-                        part["size"] * scale * self.grid_size,
-                        pos,
-                        screen=screen,
-                    )
-                case "rect":
-                    height = part["height"] * self.grid_size
-                    width = part["width"] * self.grid_size
-                    color = part["color"]
-                    if "center_offset" in part:
-                        dx, dy = np.array(part["center_offset"]) * self.grid_size
-                        rect = pygame.Rect(pos[0] + dx, pos[1] + dy, height, width)
-                        pygame.draw.rect(screen, color, rect)
-                    else:
-                        rect = pygame.Rect(
-                            pos[0] - (height / 2),
-                            pos[1] - (width / 2),
-                            height,
-                            width,
-                        )
-                    pygame.draw.rect(screen, color, rect)
-                case "circle":
-                    radius = part["radius"] * self.grid_size
-                    color = colors[part["color"]]
-                    if "center_offset" in part:
-                        pygame.draw.circle(
-                            self.game_screen,
-                            color,
-                            pos + (np.array(part["center_offset"]) * self.grid_size),
-                            radius,
-                        )
-                    else:
-                        pygame.draw.circle(screen, color, pos, radius)
-
-    def draw_item(
-        self,
-        pos: npt.NDArray[float],
-        item: Item,
-        scale: float = 1.0,
-        plate=False,
-        screen=None,
-    ):
-        """Visualization of an item at the specified position. On a counter or in the hands of the player.
-        The visual composition of the item is read in from visualization.yaml file, where it is specified as
-        different parts to be drawn.
-
-        Args:
-            pos: The position of the item to draw.
-            item: The item do be drawn in the game.
-            scale: Rescale the item by this factor.
-            screen: the pygame screen to draw on.
-            plate: item is on a plate (soup are is different on a plate and pot)
-        """
-
-        if not isinstance(item, list):
-            if item.name in self.visualization_config:
-                item_key = item.name
-                if "Soup" in item.name and plate:
-                    item_key += "Plate"
-                self.draw_thing(
-                    pos,
-                    self.visualization_config[item_key]["parts"],
-                    scale=scale,
-                    screen=screen,
-                )
-
-        if isinstance(item, (Item, Plate)) and item.progress_percentage > 0.0:
-            self.draw_progress_bar(pos, item.progress_percentage)
-
-        if isinstance(item, CookingEquipment) and item.content_list:
-            if (
-                item.content_ready
-                and item.content_ready.name in self.visualization_config
-            ):
-                self.draw_thing(
-                    pos,
-                    self.visualization_config[item.content_ready.name]["parts"],
-                    screen=screen,
-                )
-            else:
-                triangle_offsets = create_polygon(len(item.content_list), length=10)
-                scale = 1 if len(item.content_list) == 1 else 0.6
-                for idx, o in enumerate(item.content_list):
-                    self.draw_item(
-                        pos + triangle_offsets[idx],
-                        o,
-                        scale=scale,
-                        plate=isinstance(item, Plate),
-                        screen=screen,
-                    )
-
-    def draw_progress_bar(self, pos, percent, screen=None):
-        """Visualize progress of progressing item as a green bar under the item."""
-        bar_height = self.grid_size * 0.2
-        progress_width = percent * self.grid_size
-        progress_bar = pygame.Rect(
-            pos[0] - (self.grid_size / 2),
-            pos[1] - (self.grid_size / 2) + self.grid_size - bar_height,
-            progress_width,
-            bar_height,
-        )
-        if screen is None:
-            pygame.draw.rect(self.game_screen, colors["green1"], progress_bar)
-        else:
-            pygame.draw.rect(screen, colors["green1"], progress_bar)
-
-    def draw_counter(self, counter_dict):
-        """Visualization of a counter at its position. If it is occupied by an item, it is also shown.
-        The visual composition of the counter is read in from visualization.yaml file, where it is specified as
-        different parts to be drawn.
-
-        Args:
-            counter: The counter to visualize.
-        """
-
-        pos = np.array(counter_dict["pos"]) * self.grid_size
-        counter_type = counter_dict["type"]
-        self.draw_thing(pos, self.visualization_config["Counter"]["parts"])
-        if counter_type in self.visualization_config:
-            self.draw_thing(pos, self.visualization_config[counter_type]["parts"])
-        else:
-            self.draw_thing(
-                pos,
-                self.visualization_config[counter_type]["parts"],
-            )
-
-        occupied_by = counter_dict["occupied_by"]
-        if occupied_by is not None:
-            # Multiple plates on plate return:
-            # if isinstance(occupied_by, (list, deque)):
-            #     with self.simulator.env.lock:
-            #         for i, o in enumerate(occupied_by):
-            #             self.draw_item(np.abs([pos[0], pos[1] - (i * 3)]), o)
-            # # All other items:
-            # else:
-            self.draw_thing(
-                pos,
-                self.visualization_config[occupied_by]["parts"],
-            )
-
-    def draw_counters(self, state):
-        """Visualizes the counters in the environment.
-
-        Args:
-            state: The game state returned by the environment.
-        """
-        for counter in state["counters"]:
-            self.draw_counter(counter)
-            if SHOW_COUNTER_CENTERS:
-                pygame.draw.circle(self.game_screen, colors["green1"], counter.pos, 3)
-
-    def update_score_label(self, state):
-        score = state["score"]
-        self.score_label.set_text(f"Score {score}")
-
-    def update_conclusion_label(self, state):
-        score = state["score"]
-        self.conclusion_label.set_text(f"Your final score is {score}. Hurray!")
-
-    def update_remaining_time(self, remaining_time: timedelta):
-        hours, rem = divmod(remaining_time, 3600)
-        minutes, seconds = divmod(rem, 60)
-        display_time = f"{minutes}:{'%02d' % seconds}"
-        self.timer_label.set_text(f"Time remaining: {display_time}")
-
-    def draw_orders(self, state):
-        orders_width = self.game_width - 100
-        orders_height = self.screen_margin
-
-        order_screen = pygame.Surface(
-            (orders_width, orders_height),
-        )
-
-        bg_color = colors[self.visualization_config["GameWindow"]["background_color"]]
-        pygame.draw.rect(order_screen, bg_color, order_screen.get_rect())
-
-        order_rects_start = (orders_height // 2) - (self.grid_size // 2)
-        for idx, order in enumerate(state["orders"]):
-            order: Order
-            order_upper_left = [
-                order_rects_start + idx * self.grid_size * 1.2,
-                order_rects_start,
-            ]
-            pygame.draw.rect(
-                order_screen,
-                colors["red"],
-                pygame.Rect(
-                    order_upper_left[0],
-                    order_upper_left[1],
-                    self.grid_size,
-                    self.grid_size,
-                ),
-                width=2,
-            )
-            center = np.array(order_upper_left) + np.array(
-                [self.grid_size / 2, self.grid_size / 2]
-            )
-            self.draw_thing(
-                center,
-                self.visualization_config["Plate"]["parts"],
-                screen=order_screen,
-            )
-            self.draw_item(
-                center,
-                order.meal,
-                plate=True,
-                screen=order_screen,
-            )
-            order_done_seconds = (
-                (order.start_time + order.max_duration) - state["env_time"]
-            ).total_seconds()
-
-            percentage = order_done_seconds / order.max_duration.total_seconds()
-            self.draw_progress_bar(center, percentage, screen=order_screen)
-
-        orders_rect = order_screen.get_rect()
-        orders_rect.center = [
-            self.screen_margin + (orders_width // 2),
-            orders_height // 2,
-        ]
-        self.main_window.blit(order_screen, orders_rect)
-
-    def draw(self, state):
-        """Main visualization function.
-
-        Args:
-            state: The game state returned by the environment.
-        """
-
-        self.draw_background()
-
-        self.draw_counters(state)
-        self.draw_players(state)
-
-        self.manager.draw_ui(self.main_window)
-        self.update_remaining_time(state["remaining_time"])
-
-        # self.draw_orders(state)
-        self.update_score_label(state)
-
     def init_ui_elements(self):
         self.manager = pygame_gui.UIManager((self.window_width, self.window_height))
         self.manager.get_theme().load_theme(ROOT_DIR / "gui_2d_vis" / "gui_theme.json")
@@ -770,6 +322,27 @@ class PyGameGUI:
             object_id="#score_label",
         )
 
+    def draw(self, state):
+        """Main visualization function.
+
+        Args:            state: The game state returned by the environment."""
+        self.vis.draw_gamescreen(
+            self.game_screen,
+            state,
+            self.game_width,
+            self.game_height,
+            self.grid_size,
+            SHOW_COUNTER_CENTERS,
+            USE_PLAYER_COOK_SPRITES,
+            SHOW_INTERACTION_RANGE,
+        )
+
+        # self.manager.draw_ui(self.main_window)
+        self.update_remaining_time(state["remaining_time"])
+
+        # self.draw_orders(state)
+        self.update_score_label(state)
+
     def set_window_size(self):
         self.game_screen = pygame.Surface(
             (
@@ -824,6 +397,20 @@ class PyGameGUI:
                 self.orders_label.hide()
                 self.conclusion_label.show()
 
+    def update_score_label(self, state):
+        score = state["score"]
+        self.score_label.set_text(f"Score {score}")
+
+    def update_conclusion_label(self, state):
+        score = state["score"]
+        self.conclusion_label.set_text(f"Your final score is {score}. Hurray!")
+
+    def update_remaining_time(self, remaining_time: timedelta):
+        hours, rem = divmod(remaining_time, 3600)
+        minutes, seconds = divmod(rem, 60)
+        display_time = f"{minutes}:{'%02d' % seconds}"
+        self.timer_label.set_text(f"Time remaining: {display_time}")
+
     def start_button_press(self):
         self.menu_state = MenuStates.Game
 
@@ -883,10 +470,10 @@ class PyGameGUI:
 
         # self.api.set_sim(self.simulator)
 
-    # def back_button_press(self):
-    #     self.menu_state = MenuStates.Game
-    #     # self.reset_window_size()
-    #     log.debug("Pressed back button")
+    def back_button_press(self):
+        self.menu_state = MenuStates.Start
+        self.reset_window_size()
+        log.debug("Pressed back button")
 
     def quit_button_press(self):
         self.running = False
@@ -903,8 +490,7 @@ class PyGameGUI:
         )
 
         # self.websocket.send(json.dumps("reset_game"))
-        # answer = self.websocket.recv()
-        log.debug("Pressed reset button")
+        # answer = self.websocket.recv()        log.debug("Pressed reset button")
 
     def finished_button_press(self):
         requests.post(
@@ -954,6 +540,10 @@ class PyGameGUI:
         state = json.loads(self.websockets[self.state_player_id].recv())
         return state
 
+    def disconnect_websockets(self):
+        for websocket in self.websockets.values():
+            websocket.close()
+
     def start_pygame(self):
         """Starts pygame and the gui loop. Each frame the game state is visualized and keyboard inputs are read."""
         log.debug(f"Starting pygame gui at {self.FPS} fps")
@@ -978,20 +568,24 @@ class PyGameGUI:
                     if event.type == pygame.QUIT:
                         self.running = False
 
-                    # UI Buttons:
+                        # UI Buttons:
                     if event.type == pygame_gui.UI_BUTTON_PRESSED:
                         match event.ui_element:
                             case self.start_button:
                                 self.start_button_press()
                             case self.back_button:
-                                self.start_button_press()
+                                self.back_button_press()
+                                self.disconnect_websockets()
+
                             case self.finished_button:
                                 self.finished_button_press()
+                                self.disconnect_websockets()
                             case self.quit_button:
                                 self.quit_button_press()
+                                self.disconnect_websockets()
                             case self.reset_button:
                                 self.reset_button_press()
-                                self.start_button_press()
+                                self.disconnect_websockets()
 
                         self.manage_button_visibility()
 
@@ -1004,10 +598,7 @@ class PyGameGUI:
 
                     self.manager.process_events(event)
 
-                # drawing:
-
-                # state = self.simulator.get_state()
-
+                    # drawing:
                 self.main_window.fill(
                     colors[self.visualization_config["GameWindow"]["background_color"]]
                 )
@@ -1020,13 +611,8 @@ class PyGameGUI:
                     case MenuStates.Game:
                         state = self.request_state()
 
-                        self.draw_background()
-
                         self.handle_keys()
 
-                        # state = self.simulator.get_state()
-                        self.draw(state)
-
                         if state["ended"]:
                             self.finished_button_press()
                             self.manage_button_visibility()
@@ -1038,7 +624,6 @@ class PyGameGUI:
                                 self.window_width // 2,
                                 self.window_height // 2,
                             ]
-
                             self.main_window.blit(self.game_screen, game_screen_rect)
 
                     case MenuStates.End:
@@ -1048,8 +633,7 @@ class PyGameGUI:
                 pygame.display.flip()
 
             except KeyboardInterrupt:
-                pygame.quit()
-                sys.exit()
+                self.running = False
 
         pygame.quit()
         sys.exit()
-- 
GitLab