From 0b6f5f3972a82b953a5832b6cba91ae9706a5ef7 Mon Sep 17 00:00:00 2001 From: fheinrich <fheinrich@techfak.uni-bielefeld.de> Date: Tue, 5 Mar 2024 09:36:32 +0100 Subject: [PATCH] WebsocketMessage validation in sending with base model --- cooperative_cuisine/action.py | 3 +- cooperative_cuisine/game_server.py | 3 +- cooperative_cuisine/pygame_2d_vis/gui.py | 40 ++++++++++++++---------- 3 files changed, 28 insertions(+), 18 deletions(-) diff --git a/cooperative_cuisine/action.py b/cooperative_cuisine/action.py index b35819ef..6362d08b 100644 --- a/cooperative_cuisine/action.py +++ b/cooperative_cuisine/action.py @@ -2,9 +2,10 @@ from __future__ import annotations import dataclasses from enum import Enum -from typing import Literal, TypedDict +from typing import Literal from numpy import typing as npt +from typing_extensions import TypedDict class ActionType(Enum): diff --git a/cooperative_cuisine/game_server.py b/cooperative_cuisine/game_server.py index 66922d2e..417ed245 100644 --- a/cooperative_cuisine/game_server.py +++ b/cooperative_cuisine/game_server.py @@ -591,7 +591,7 @@ class PlayerRequestType(Enum): """Indicates a request to pass an action of a player to the environment.""" -class WebsocketMessage(TypedDict): +class WebsocketMessage(BaseModel): type: str action: None | ActionDict player_hash: str @@ -609,6 +609,7 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul """ message_dict = json.loads(message) request_type = None + # ws_message = WebsocketMessage(type=message_dict["type"], player_hash=message_dict["player_hash"]) try: assert "type" in message_dict, "message needs a type" diff --git a/cooperative_cuisine/pygame_2d_vis/gui.py b/cooperative_cuisine/pygame_2d_vis/gui.py index dee2e037..a5b3bc8a 100644 --- a/cooperative_cuisine/pygame_2d_vis/gui.py +++ b/cooperative_cuisine/pygame_2d_vis/gui.py @@ -1370,11 +1370,13 @@ class PyGameGUI: if p < self.number_humans_to_be_added: # add player websockets websocket = connect(self.websocket_url + player_info["client_id"]) - ws_message = WebsocketMessage( - type=PlayerRequestType.READY.value, - player_hash=player_info["player_hash"], - ) - websocket.send(json.dumps(ws_message)) + message_dict = { + "type": PlayerRequestType.READY.value, + "action": None, + "player_hash": player_info["player_hash"], + } + ws_message = WebsocketMessage(**message_dict) + websocket.send(ws_message.json()) assert ( json.loads(websocket.recv())["status"] == 200 ), "not accepted player" @@ -1489,20 +1491,26 @@ class PyGameGUI: float(action.action_data[1]), ] - ws_message = WebsocketMessage( - type=PlayerRequestType.ACTION.value, - action=dataclasses.asdict(action, dict_factory=custom_asdict_factory), - player_hash=self.player_info[action.player]["player_hash"], - ) - self.websockets[action.player].send(json.dumps(ws_message)) + message_dict = { + "type": PlayerRequestType.ACTION.value, + "action": dataclasses.asdict( + action, dict_factory=custom_asdict_factory + ), + "player_hash": self.player_info[action.player]["player_hash"], + } + + ws_message = WebsocketMessage(**message_dict) + self.websockets[action.player].send(ws_message.json()) self.websockets[action.player].recv() def request_state(self): - ws_message = WebsocketMessage( - type=PlayerRequestType.GET_STATE.value, - player_hash=self.player_info[self.state_player_id]["player_hash"], - ) - self.websockets[self.state_player_id].send(json.dumps(ws_message)) + message_dict = { + "type": PlayerRequestType.GET_STATE.value, + "action": None, + "player_hash": self.player_info[self.state_player_id]["player_hash"], + } + ws_message = WebsocketMessage(**message_dict) + self.websockets[self.state_player_id].send(ws_message.json()) state = json.loads(self.websockets[self.state_player_id].recv()) return state -- GitLab