From 0b6f5f3972a82b953a5832b6cba91ae9706a5ef7 Mon Sep 17 00:00:00 2001
From: fheinrich <fheinrich@techfak.uni-bielefeld.de>
Date: Tue, 5 Mar 2024 09:36:32 +0100
Subject: [PATCH] WebsocketMessage validation in sending with base model

---
 cooperative_cuisine/action.py            |  3 +-
 cooperative_cuisine/game_server.py       |  3 +-
 cooperative_cuisine/pygame_2d_vis/gui.py | 40 ++++++++++++++----------
 3 files changed, 28 insertions(+), 18 deletions(-)

diff --git a/cooperative_cuisine/action.py b/cooperative_cuisine/action.py
index b35819ef..6362d08b 100644
--- a/cooperative_cuisine/action.py
+++ b/cooperative_cuisine/action.py
@@ -2,9 +2,10 @@ from __future__ import annotations
 
 import dataclasses
 from enum import Enum
-from typing import Literal, TypedDict
+from typing import Literal
 
 from numpy import typing as npt
+from typing_extensions import TypedDict
 
 
 class ActionType(Enum):
diff --git a/cooperative_cuisine/game_server.py b/cooperative_cuisine/game_server.py
index 66922d2e..417ed245 100644
--- a/cooperative_cuisine/game_server.py
+++ b/cooperative_cuisine/game_server.py
@@ -591,7 +591,7 @@ class PlayerRequestType(Enum):
     """Indicates a request to pass an action of a player to the environment."""
 
 
-class WebsocketMessage(TypedDict):
+class WebsocketMessage(BaseModel):
     type: str
     action: None | ActionDict
     player_hash: str
@@ -609,6 +609,7 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul
     """
     message_dict = json.loads(message)
     request_type = None
+    # ws_message = WebsocketMessage(type=message_dict["type"], player_hash=message_dict["player_hash"])
     try:
         assert "type" in message_dict, "message needs a type"
 
diff --git a/cooperative_cuisine/pygame_2d_vis/gui.py b/cooperative_cuisine/pygame_2d_vis/gui.py
index dee2e037..a5b3bc8a 100644
--- a/cooperative_cuisine/pygame_2d_vis/gui.py
+++ b/cooperative_cuisine/pygame_2d_vis/gui.py
@@ -1370,11 +1370,13 @@ class PyGameGUI:
             if p < self.number_humans_to_be_added:
                 # add player websockets
                 websocket = connect(self.websocket_url + player_info["client_id"])
-                ws_message = WebsocketMessage(
-                    type=PlayerRequestType.READY.value,
-                    player_hash=player_info["player_hash"],
-                )
-                websocket.send(json.dumps(ws_message))
+                message_dict = {
+                    "type": PlayerRequestType.READY.value,
+                    "action": None,
+                    "player_hash": player_info["player_hash"],
+                }
+                ws_message = WebsocketMessage(**message_dict)
+                websocket.send(ws_message.json())
                 assert (
                     json.loads(websocket.recv())["status"] == 200
                 ), "not accepted player"
@@ -1489,20 +1491,26 @@ class PyGameGUI:
                 float(action.action_data[1]),
             ]
 
-        ws_message = WebsocketMessage(
-            type=PlayerRequestType.ACTION.value,
-            action=dataclasses.asdict(action, dict_factory=custom_asdict_factory),
-            player_hash=self.player_info[action.player]["player_hash"],
-        )
-        self.websockets[action.player].send(json.dumps(ws_message))
+            message_dict = {
+                "type": PlayerRequestType.ACTION.value,
+                "action": dataclasses.asdict(
+                    action, dict_factory=custom_asdict_factory
+                ),
+                "player_hash": self.player_info[action.player]["player_hash"],
+            }
+
+            ws_message = WebsocketMessage(**message_dict)
+        self.websockets[action.player].send(ws_message.json())
         self.websockets[action.player].recv()
 
     def request_state(self):
-        ws_message = WebsocketMessage(
-            type=PlayerRequestType.GET_STATE.value,
-            player_hash=self.player_info[self.state_player_id]["player_hash"],
-        )
-        self.websockets[self.state_player_id].send(json.dumps(ws_message))
+        message_dict = {
+            "type": PlayerRequestType.GET_STATE.value,
+            "action": None,
+            "player_hash": self.player_info[self.state_player_id]["player_hash"],
+        }
+        ws_message = WebsocketMessage(**message_dict)
+        self.websockets[self.state_player_id].send(ws_message.json())
         state = json.loads(self.websockets[self.state_player_id].recv())
         return state
 
-- 
GitLab