From 4569ab54f079481f186919034d915b164b3250e1 Mon Sep 17 00:00:00 2001
From: fheinrich <fheinrich@techfak.uni-bielefeld.de>
Date: Wed, 7 Feb 2024 11:04:12 +0100
Subject: [PATCH] View restriction as BaseModel, can be switched with players

---
 overcooked_simulator/gui_2d_vis/drawing.py        |  6 +++---
 overcooked_simulator/gui_2d_vis/overcooked_gui.py |  3 +--
 overcooked_simulator/overcooked_environment.py    | 11 ++++++-----
 overcooked_simulator/state_representation.py      |  9 ++++++++-
 4 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/overcooked_simulator/gui_2d_vis/drawing.py b/overcooked_simulator/gui_2d_vis/drawing.py
index 5472397e..7730b075 100644
--- a/overcooked_simulator/gui_2d_vis/drawing.py
+++ b/overcooked_simulator/gui_2d_vis/drawing.py
@@ -150,9 +150,9 @@ class Visualizer:
             # rotate direction vector in both direction with the angel
             # draw 2 large rect which are rotated so that one edge is the viewing border
 
-            direction = pygame.math.Vector2(state["view_restriction"][0])
-            pos = pygame.math.Vector2(state["view_restriction"][1])
-            angle = state["view_restriction"][2]
+            direction = pygame.math.Vector2(state["view_restriction"]["direction"])
+            pos = pygame.math.Vector2(state["view_restriction"]["position"])
+            angle = state["view_restriction"]["angle"]
 
             pos = pos * grid_size + pygame.math.Vector2([grid_size / 2, grid_size / 2])
 
diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
index 9f663325..a4d28c5c 100644
--- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py
+++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
@@ -28,7 +28,6 @@ from overcooked_simulator.utils import (
     url_and_port_arguments,
     disable_websocket_logging_arguments,
     add_list_of_manager_ids_arguments,
-    setup_logging,
 )
 
 
@@ -936,7 +935,7 @@ class PyGameGUI:
             json.dumps(
                 {
                     "type": "get_state",
-                    "player_hash": self.player_info[self.state_player_id][
+                    "player_hash": self.player_info[str(self.key_sets[0].current_idx)][
                         "player_hash"
                     ],
                 }
diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py
index 444163d5..40aae7a9 100644
--- a/overcooked_simulator/overcooked_environment.py
+++ b/overcooked_simulator/overcooked_environment.py
@@ -766,11 +766,12 @@ class Environment:
                 "remaining_time": max(
                     self.env_time_end - self.env_time, timedelta(0)
                 ).total_seconds(),
-                "view_restriction": [
-                    self.players[player_id].facing_direction.tolist(),
-                    self.players[player_id].pos.tolist(),
-                    35.0,
-                ]
+                "view_restriction": {
+                    "direction": self.players[player_id].facing_direction.tolist(),
+                    "position": self.players[player_id].pos.tolist(),
+                    "angle": 35.0,
+                    "counter_mask": None,
+                }
                 if FOG_OF_WAR
                 else None,
             }
diff --git a/overcooked_simulator/state_representation.py b/overcooked_simulator/state_representation.py
index c7315af9..9d4e7427 100644
--- a/overcooked_simulator/state_representation.py
+++ b/overcooked_simulator/state_representation.py
@@ -66,6 +66,13 @@ class KitchenInfo(BaseModel):
     height: float
 
 
+class ViewRestriction(BaseModel):
+    direction: list[float]
+    position: list[float]
+    angle: int  # degrees
+    counter_mask: None | list[bool]
+
+
 class StateRepresentation(BaseModel):
     """The format of the returned state representation."""
 
@@ -77,7 +84,7 @@ class StateRepresentation(BaseModel):
     ended: bool
     env_time: datetime  # isoformat str
     remaining_time: float
-    view_restriction: None | list[list[float] | float]
+    view_restriction: None | ViewRestriction
 
 
 def create_json_schema():
-- 
GitLab