diff --git a/cooperative_cuisine/action.py b/cooperative_cuisine/action.py index b35819efdc2abf2c04f02413368c63ecc8176e3f..6362d08b8bd0e5afd768b49a517f80086963e77d 100644 --- a/cooperative_cuisine/action.py +++ b/cooperative_cuisine/action.py @@ -2,9 +2,10 @@ from __future__ import annotations import dataclasses from enum import Enum -from typing import Literal, TypedDict +from typing import Literal from numpy import typing as npt +from typing_extensions import TypedDict class ActionType(Enum): diff --git a/cooperative_cuisine/game_server.py b/cooperative_cuisine/game_server.py index 66922d2ec7ad5aed546efaa33af17dc61dffd02f..417ed245ef744fa50df5007b12e79925f66ef141 100644 --- a/cooperative_cuisine/game_server.py +++ b/cooperative_cuisine/game_server.py @@ -591,7 +591,7 @@ class PlayerRequestType(Enum): """Indicates a request to pass an action of a player to the environment.""" -class WebsocketMessage(TypedDict): +class WebsocketMessage(BaseModel): type: str action: None | ActionDict player_hash: str @@ -609,6 +609,7 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul """ message_dict = json.loads(message) request_type = None + # ws_message = WebsocketMessage(type=message_dict["type"], player_hash=message_dict["player_hash"]) try: assert "type" in message_dict, "message needs a type" diff --git a/cooperative_cuisine/pygame_2d_vis/gui.py b/cooperative_cuisine/pygame_2d_vis/gui.py index dee2e03776fb50e21f6c2d6149e7a0914d5d83e2..a5b3bc8a2b3ea9cacb58eea129cde1384a397605 100644 --- a/cooperative_cuisine/pygame_2d_vis/gui.py +++ b/cooperative_cuisine/pygame_2d_vis/gui.py @@ -1370,11 +1370,13 @@ class PyGameGUI: if p < self.number_humans_to_be_added: # add player websockets websocket = connect(self.websocket_url + player_info["client_id"]) - ws_message = WebsocketMessage( - type=PlayerRequestType.READY.value, - player_hash=player_info["player_hash"], - ) - websocket.send(json.dumps(ws_message)) + message_dict = { + "type": PlayerRequestType.READY.value, + "action": None, + "player_hash": player_info["player_hash"], + } + ws_message = WebsocketMessage(**message_dict) + websocket.send(ws_message.json()) assert ( json.loads(websocket.recv())["status"] == 200 ), "not accepted player" @@ -1489,20 +1491,26 @@ class PyGameGUI: float(action.action_data[1]), ] - ws_message = WebsocketMessage( - type=PlayerRequestType.ACTION.value, - action=dataclasses.asdict(action, dict_factory=custom_asdict_factory), - player_hash=self.player_info[action.player]["player_hash"], - ) - self.websockets[action.player].send(json.dumps(ws_message)) + message_dict = { + "type": PlayerRequestType.ACTION.value, + "action": dataclasses.asdict( + action, dict_factory=custom_asdict_factory + ), + "player_hash": self.player_info[action.player]["player_hash"], + } + + ws_message = WebsocketMessage(**message_dict) + self.websockets[action.player].send(ws_message.json()) self.websockets[action.player].recv() def request_state(self): - ws_message = WebsocketMessage( - type=PlayerRequestType.GET_STATE.value, - player_hash=self.player_info[self.state_player_id]["player_hash"], - ) - self.websockets[self.state_player_id].send(json.dumps(ws_message)) + message_dict = { + "type": PlayerRequestType.GET_STATE.value, + "action": None, + "player_hash": self.player_info[self.state_player_id]["player_hash"], + } + ws_message = WebsocketMessage(**message_dict) + self.websockets[self.state_player_id].send(ws_message.json()) state = json.loads(self.websockets[self.state_player_id].recv()) return state