From 88375c0b296560f204f32de70efdaabd6e87822e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florian=20Schr=C3=B6der?=
 <fschroeder@techfak.uni-bielefeld.de>
Date: Thu, 7 Dec 2023 11:51:29 +0100
Subject: [PATCH] style fixes and numpy array annotation

---
 overcooked_simulator/counters.py              |  7 +--
 overcooked_simulator/main.py                  |  5 +--
 .../overcooked_environment.py                 | 14 +++---
 overcooked_simulator/player.py                | 15 ++++---
 overcooked_simulator/pygame_gui/pygame_gui.py | 45 ++++++++++---------
 overcooked_simulator/simulation_runner.py     | 23 +++++-----
 tests/test_start.py                           |  2 +-
 7 files changed, 58 insertions(+), 53 deletions(-)

diff --git a/overcooked_simulator/counters.py b/overcooked_simulator/counters.py
index aa372ed9..2298c843 100644
--- a/overcooked_simulator/counters.py
+++ b/overcooked_simulator/counters.py
@@ -1,6 +1,7 @@
 from typing import Optional
 
 import numpy as np
+import numpy.typing as npt
 
 from overcooked_simulator.game_items import (
     CuttableItem,
@@ -14,8 +15,8 @@ from overcooked_simulator.game_items import (
 class Counter:
     """Simple class for a counter at a specified position (center of counter). Can hold things on top."""
 
-    def __init__(self, pos: np.ndarray):
-        self.pos: np.ndarray = pos
+    def __init__(self, pos: npt.NDArray[float]):
+        self.pos: npt.NDArray[float] = pos
         self.occupied_by: Optional[HoldableItem] = None
 
     def pick_up(self):
@@ -47,7 +48,7 @@ class Counter:
         Args:
             item: The item to be placed on the counter.
 
-        Returns: TODO Return information, wether the score is affected (Serving Window?)
+        Returns: TODO Return information, whether the score is affected (Serving Window?)
 
         """
         if self.occupied_by is None:
diff --git a/overcooked_simulator/main.py b/overcooked_simulator/main.py
index b3fa9d42..2433815d 100644
--- a/overcooked_simulator/main.py
+++ b/overcooked_simulator/main.py
@@ -5,7 +5,6 @@ import numpy as np
 import pygame
 
 from overcooked_simulator import ROOT_DIR
-from overcooked_simulator.game_items import Tomato, Plate
 from overcooked_simulator.player import Player
 from overcooked_simulator.pygame_gui.pygame_gui import PyGameGUI
 from overcooked_simulator.simulation_runner import Simulator
@@ -15,8 +14,8 @@ def main():
     simulator = Simulator(Path(ROOT_DIR, "layouts", "basic.layout"), 600)
     player_one_name = "p1"
     player_two_name = "p2"
-    simulator.register_player(Player(player_one_name, np.array([200, 200])))
-    simulator.register_player(Player(player_two_name, np.array([100, 200])))
+    simulator.register_player(Player(player_one_name, np.array([200.0, 200.0])))
+    simulator.register_player(Player(player_two_name, np.array([100.0, 200.0])))
 
     # TODO maybe read the player names and keyboard keys from config file?
     keys1 = [
diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py
index 8245fafa..12a22460 100644
--- a/overcooked_simulator/overcooked_environment.py
+++ b/overcooked_simulator/overcooked_environment.py
@@ -6,6 +6,8 @@ if TYPE_CHECKING:
     from overcooked_simulator.player import Player
 from pathlib import Path
 import numpy as np
+import numpy.typing as npt
+
 from scipy.spatial import distance_matrix
 from overcooked_simulator.counters import (
     Counter,
@@ -53,7 +55,7 @@ class Environment:
 
     def __init__(self, layout_path):
         self.players: dict[str, Player] = {}
-        self.counter_side_length: float = 40
+        self.counter_side_length: int = 40
         self.layout_path: Path = layout_path
         self.counters: list[Counter] = self.create_counters(self.layout_path)
         self.score: int = 0
@@ -118,7 +120,7 @@ class Environment:
                     elif action.action == "keyup":
                         player.perform_interact_hold_stop(counter)
 
-    def get_closest_counter(self, point: np.ndarray):
+    def get_closest_counter(self, point: npt.NDArray):
         """Determines the closest counter for a given 2d-coordinate point in the env.
 
         Args:
@@ -147,7 +149,7 @@ class Environment:
         facing_counter = self.get_closest_counter(facing_point)
         return facing_counter
 
-    def perform_movement(self, player: Player, move_vector: np.array):
+    def perform_movement(self, player: Player, move_vector: npt.NDArray[int]):
         """Moves a player in the direction specified in the action.action. If the player collides with a
         counter or other player through this movement, then they are not moved.
         (The extended code with the two ifs is for sliding movement at the counters, which feels a bit smoother.
@@ -196,9 +198,9 @@ class Environment:
         Returns: True if the player is intersecting with any object in the environment.
         """
         return (
-                self.detect_player_collision(player)
-                or self.detect_collision_counters(player)
-                or self.detect_collision_world_bounds(player)
+            self.detect_player_collision(player)
+            or self.detect_collision_counters(player)
+            or self.detect_collision_world_bounds(player)
         )
 
     def detect_player_collision(self, player: Player):
diff --git a/overcooked_simulator/player.py b/overcooked_simulator/player.py
index 971f9024..69c65e5c 100644
--- a/overcooked_simulator/player.py
+++ b/overcooked_simulator/player.py
@@ -1,6 +1,7 @@
 from typing import Optional
 
 import numpy as np
+import numpy.typing as npt
 
 from overcooked_simulator.counters import Counter, Trash
 from overcooked_simulator.game_items import HoldableItem, Plate
@@ -13,17 +14,17 @@ class Player:
 
     """
 
-    def __init__(self, name: str, pos: np.ndarray):
+    def __init__(self, name: str, pos: npt.NDArray[float]):
         self.name: str = name
-        self.pos: np.ndarray = np.array(pos, dtype=float)
+        self.pos: npt.NDArray[float] = np.array(pos, dtype=float)
         self.holding: Optional[HoldableItem] = None
 
         self.radius: int = 18
         self.move_dist: int = 5
         self.interaction_range: int = 60
-        self.facing_direction: np.ndarray = np.array([0, 1])
+        self.facing_direction: npt.NDArray[float] = np.array([0, 1])
 
-    def move(self, movement: np.ndarray):
+    def move(self, movement: npt.NDArray[float]):
         """Moves the player position by the given movement vector.
         A unit direction vector multiplied by move_dist is added to the player position.
 
@@ -34,7 +35,7 @@ class Player:
         if np.linalg.norm(movement) != 0:
             self.turn(movement)
 
-    def move_abs(self, new_pos: np.ndarray):
+    def move_abs(self, new_pos: npt.NDArray[float]):
         """Overwrites the player location by the new_pos 2d-vector. Absolute movement.
         Mostly needed for resetting the player after a collision.
 
@@ -43,7 +44,7 @@ class Player:
         """
         self.pos = new_pos
 
-    def turn(self, direction: np.ndarray):
+    def turn(self, direction: npt.NDArray[float]):
         """Turns the player in the given direction. Overwrites the facing_direction by a given 2d-vector.
         facing_direction is normalized to length 1.
 
@@ -54,7 +55,7 @@ class Player:
             self.facing_direction = direction / np.linalg.norm(direction)
 
     def can_reach(self, counter: Counter):
-        """Checks wether the player can reach the counter in question. Simple check if the distance is not larger
+        """Checks whether the player can reach the counter in question. Simple check if the distance is not larger
         than the player interaction range.
 
         Args:
diff --git a/overcooked_simulator/pygame_gui/pygame_gui.py b/overcooked_simulator/pygame_gui/pygame_gui.py
index 6151e635..be8b4238 100644
--- a/overcooked_simulator/pygame_gui/pygame_gui.py
+++ b/overcooked_simulator/pygame_gui/pygame_gui.py
@@ -11,7 +11,7 @@ from overcooked_simulator.counters import (
     PlateReturn,
     ServingWindow,
 )
-from overcooked_simulator.game_items import ProgressibleItem, Plate
+from overcooked_simulator.game_items import ProgressibleItem, Plate, HoldableItem
 from overcooked_simulator.game_items import Tomato
 from overcooked_simulator.overcooked_environment import Action
 from overcooked_simulator.simulation_runner import Simulator
@@ -19,7 +19,7 @@ from overcooked_simulator.simulation_runner import Simulator
 WHITE = (255, 255, 255)
 GREY = (190, 190, 190)
 BLACK = (0, 0, 0)
-COUNTERCOLOR = (240, 240, 240)
+COUNTER_COLOR = (240, 240, 240)
 LIGHTGREY = (220, 220, 220)
 GREEN = (0, 255, 0)
 RED = (255, 0, 0)
@@ -32,7 +32,7 @@ PLATE_RETURN_COLOR = (170, 170, 240)
 BOARD_COLOR = (239, 193, 151)
 
 
-class PlayerKeyset:
+class PlayerKeySet:
     """Set of keyboard keys for controlling a player.
     First four keys are for movement. Order: Down, Up, Left, Right.
     5th key is for interacting with counters.
@@ -41,7 +41,7 @@ class PlayerKeyset:
     """
 
     def __init__(self, player_name: str, keys: list[pygame.key]):
-        """Creates a player keyset which contains information about which keyboard keys control the player.
+        """Creates a player key set which contains information about which keyboard keys control the player.
         Movement keys in the following order: Down, Up, Left, Right
 
         Args:
@@ -59,7 +59,7 @@ class PlayerKeyset:
 
 
 class PyGameGUI:
-    """Visualisation of the overcooked environmnent and reading keyboard inputs using pygame."""
+    """Visualisation of the overcooked environment and reading keyboard inputs using pygame."""
 
     def __init__(
         self,
@@ -67,6 +67,7 @@ class PyGameGUI:
         player_names: list[str],
         player_keys: list[pygame.key],
     ):
+        self.screen = None
         self.FPS = 60
         self.simulator = simulator
         self.counter_size = self.simulator.env.counter_side_length
@@ -79,10 +80,10 @@ class PyGameGUI:
         self.player_keys = player_keys
         assert len(self.player_names) == len(
             self.player_keys
-        ), "Number of players and keysets should match."
+        ), "Number of players and key sets should match."
 
-        self.player_keysets: list[PlayerKeyset] = [
-            PlayerKeyset(player_name, keys)
+        self.player_key_sets: list[PlayerKeySet] = [
+            PlayerKeySet(player_name, keys)
             for player_name, keys in zip(self.player_names, self.player_keys)
         ]
 
@@ -101,13 +102,13 @@ class PyGameGUI:
         an action is sent in this function.
         """
         keys = pygame.key.get_pressed()
-        for player_idx, keyset in enumerate(self.player_keysets):
-            relevant_keys = [keys[k] for k in keyset.player_keys]
+        for player_idx, key_set in enumerate(self.player_key_sets):
+            relevant_keys = [keys[k] for k in key_set.player_keys]
             if any(relevant_keys[:-2]):
                 move_vec = np.zeros(2)
                 for idx, pressed in enumerate(relevant_keys[:-2]):
                     if pressed:
-                        move_vec += keyset.move_vectors[idx]
+                        move_vec += key_set.move_vectors[idx]
                 if np.linalg.norm(move_vec) != 0:
                     move_vec = move_vec / np.linalg.norm(move_vec)
 
@@ -122,17 +123,17 @@ class PyGameGUI:
         Args:
             event: Pygame event for extracting the key action.
         """
-        for keyset in self.player_keysets:
-            if event.key == keyset.pickup_key and event.type == pygame.KEYDOWN:
-                action = Action(keyset.name, "pickup", "pickup")
+        for key_set in self.player_key_sets:
+            if event.key == key_set.pickup_key and event.type == pygame.KEYDOWN:
+                action = Action(key_set.name, "pickup", "pickup")
                 self.send_action(action)
 
-            if event.key == keyset.interact_key:
+            if event.key == key_set.interact_key:
                 if event.type == pygame.KEYDOWN:
-                    action = Action(keyset.name, "interact", "keydown")
+                    action = Action(key_set.name, "interact", "keydown")
                     self.send_action(action)
                 elif event.type == pygame.KEYUP:
-                    action = Action(keyset.name, "interact", "keyup")
+                    action = Action(key_set.name, "interact", "keyup")
                     self.send_action(action)
 
     def draw_background(self):
@@ -144,7 +145,7 @@ class PyGameGUI:
                 pygame.draw.rect(self.screen, BACKGROUND_LINES_COLOR, rect, 1)
 
     def draw_players(self, state):
-        """Visualizes the players as circles with a triangle for the facing diretion.
+        """Visualizes the players as circles with a triangle for the facing direction.
         If the player holds something in their hands, it is displayed
 
         Args:
@@ -175,7 +176,7 @@ class PyGameGUI:
                 holding_item_pos = player.pos + (20 * player.facing_direction)
                 self.draw_item(holding_item_pos, player.holding)
 
-    def draw_item(self, pos, item):
+    def draw_item(self, pos, item: HoldableItem):
         """Visualisation of an item at the specified position. On a counter or in the hands of the player."""
         if isinstance(item, Tomato):
             if item.finished:
@@ -192,7 +193,7 @@ class PyGameGUI:
 
         if isinstance(item, Plate):
             image = pygame.image.load(
-                "overcooked_simulator/pygame_gui/images/plate.png"
+                self.images_path / "plate.png"
             ).convert_alpha()  # or .convert_alpha()
             rect = image.get_rect()
             rect.center = pos
@@ -229,7 +230,7 @@ class PyGameGUI:
             self.counter_size,
             self.counter_size,
         )
-        pygame.draw.rect(self.screen, COUNTERCOLOR, counter_rect_outline)
+        pygame.draw.rect(self.screen, COUNTER_COLOR, counter_rect_outline)
 
         if isinstance(counter, CuttingBoard):
             board_size = 30
@@ -309,7 +310,7 @@ class PyGameGUI:
         pygame.display.flip()
 
     def start_pygame(self):
-        """Starts pygame and the gui loop. Each frame the gamestate is visualized and keyboard inputs are read."""
+        """Starts pygame and the gui loop. Each frame the game state is visualized and keyboard inputs are read."""
         pygame.init()
         pygame.font.init()
 
diff --git a/overcooked_simulator/simulation_runner.py b/overcooked_simulator/simulation_runner.py
index 3738a239..31d9d377 100644
--- a/overcooked_simulator/simulation_runner.py
+++ b/overcooked_simulator/simulation_runner.py
@@ -6,23 +6,24 @@ from overcooked_simulator.player import Player
 
 
 class Simulator(Thread):
-    """Simulator main class which runs manages the environment and player inputs and gamestate outputs.
+    """Simulator main class which runs manages the environment and player inputs and game state outputs.
 
     Main Simulator class which runs the game environment. Players can be registered in the game.
     The simulator is run as its own thread.
 
     Typical usage example:
-
-      sim = Simulator()
-      sim.register_player(Player("p1", [x,y]))
-      sim.start()
+    ```python
+    sim = Simulator()
+    sim.register_player(Player("p1", [x,y]))
+    sim.start()
+    ```
     """
 
     def __init__(self, env_layout_path, frequency: int):
         self.finished: bool = False
 
         self.step_frequency: int = frequency
-        self.prefered_sleeptime_ns: float = 1e9 / self.step_frequency
+        self.preferred_sleep_time_ns: float = 1e9 / self.step_frequency
         self.env: Environment = Environment(env_layout_path)
 
         super().__init__()
@@ -40,19 +41,19 @@ class Simulator(Thread):
         self.env.perform_action(action)
 
     def get_state(self):
-        """Get the current gamestate as python objects.
+        """Get the current game state as python objects.
 
         Returns:
-            The current state of the game. Currently as dict with lists of environment objects.
+            The current state of the game. Currently, as dict with lists of environment objects.
         """
 
         return self.env.get_state()
 
     def get_state_json(self):
-        """Get the current gamestate in json-like dict.
+        """Get the current game state in json-like dict.
 
         Returns:
-            The gamestate encoded in a json style nested dict.
+            The gamest ate encoded in a json style nested dict.
         """
 
         return self.env.get_state_json()
@@ -86,7 +87,7 @@ class Simulator(Thread):
             self.step()
             step_duration = time.time_ns() - step_start
 
-            time_to_sleep_ns = self.prefered_sleeptime_ns - (
+            time_to_sleep_ns = self.preferred_sleep_time_ns - (
                 step_duration + overslept_in_ns
             )
 
diff --git a/tests/test_start.py b/tests/test_start.py
index 4a68443a..c618fa85 100644
--- a/tests/test_start.py
+++ b/tests/test_start.py
@@ -36,7 +36,7 @@ def test_player_registration():
     assert len(sim.env.players) == 2, "Wrong number of players"
 
     p3 = Player("player2", np.array([100, 100]))
-    sim.register_player(p2)  # same player name
+    sim.register_player(p3)  # same player name
     assert len(sim.env.players) == 2, "Wrong number of players"
 
     sim.start()
-- 
GitLab