From 1b72bbdf8bfb05db54260857dc7aa05c4fd9854a Mon Sep 17 00:00:00 2001 From: fheinrich <fheinrich@techfak.uni-bielefeld.de> Date: Tue, 5 Mar 2024 11:36:30 +0100 Subject: [PATCH] Validation of incoming and outgoing websocket messages --- cooperative_cuisine/action.py | 2 +- cooperative_cuisine/game_server.py | 37 +++++++++--------------- cooperative_cuisine/pygame_2d_vis/gui.py | 4 +-- 3 files changed, 16 insertions(+), 27 deletions(-) diff --git a/cooperative_cuisine/action.py b/cooperative_cuisine/action.py index a080070b..bb94a936 100644 --- a/cooperative_cuisine/action.py +++ b/cooperative_cuisine/action.py @@ -36,7 +36,7 @@ class Action: """Id of the player.""" action_type: ActionType """Type of the action to perform. Defines what action data is valid.""" - action_data: npt.NDArray[float] | InterActionData | Literal["pickup"] + action_data: npt.NDArray[float] | list[float] | InterActionData | Literal["pickup"] """Data for the action, e.g., movement vector or start and stop interaction.""" duration: float | int = 0 """Duration of the action (relevant for movement)""" diff --git a/cooperative_cuisine/game_server.py b/cooperative_cuisine/game_server.py index 2d126294..53205d1f 100644 --- a/cooperative_cuisine/game_server.py +++ b/cooperative_cuisine/game_server.py @@ -613,18 +613,14 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul message_dict = json.loads(message) request_type = None try: - assert "type" in message_dict, "message needs a type" - - request_type = PlayerRequestType(message_dict["type"]) - assert ( - "player_hash" in message_dict - ), "'player_hash' key not in message dictionary'" + ws_message = WebsocketMessage(**message_dict) + request_type = PlayerRequestType(ws_message.type) match request_type: case PlayerRequestType.GET_STATE: - state = environment_handler.get_state(message_dict["player_hash"]) + state = environment_handler.get_state(ws_message.player_hash) if isinstance(state, int): return { - "request_type": message_dict["type"], + "request_type": ws_message.type, "status": 400, "msg": "env id of player not in running envs" if state == 2 @@ -633,40 +629,33 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul } return state case PlayerRequestType.READY: - accepted = environment_handler.set_player_ready( - message_dict["player_hash"] - ) + accepted = environment_handler.set_player_ready(ws_message.player_hash) return { "request_type": request_type.value, "msg": f"ready{' ' if accepted else ' not '}accepted", "status": 200 if accepted else 400, - "player_hash": message_dict["player_hash"], + "player_hash": ws_message.player_hash, } case PlayerRequestType.ACTION: - assert ( - "action" in message_dict - ), "'action' key not in message dictionary of 'action' request" - assert ( - "action_data" in message_dict["action"] - ), "'action_data' key not in message dictionary['action'] of 'action' request" - if isinstance(message_dict["action"]["action_data"], list): - message_dict["action"]["action_data"] = np.array( - message_dict["action"]["action_data"], dtype=float + assert ws_message.action is not None + if isinstance(ws_message.action.action_data, list): + ws_message.action.action_data = np.array( + ws_message.action.action_data, dtype=float ) accepted = environment_handler.player_action( - message_dict["player_hash"], Action(**message_dict["action"]) + ws_message.player_hash, ws_message.action ) return { "request_type": request_type.value, "status": 200 if accepted else 400, "msg": f"action{' ' if accepted else ' not '}accepted", - "player_hash": message_dict["player_hash"], + "player_hash": ws_message.player_hash, } return { "request_type": request_type.value, "status": 400, "msg": "request not handled", - "player_hash": message_dict["player_hash"], + "player_hash": ws_message.player_hash, } except ValueError as e: return { diff --git a/cooperative_cuisine/pygame_2d_vis/gui.py b/cooperative_cuisine/pygame_2d_vis/gui.py index 7d5d3b5b..e3819454 100644 --- a/cooperative_cuisine/pygame_2d_vis/gui.py +++ b/cooperative_cuisine/pygame_2d_vis/gui.py @@ -253,7 +253,7 @@ class PyGameGUI: action = Action( current_player_name, - ActionType.MOVEMENT, + ActionType.MOVEMENT.value, move_vec, duration=self.time_delta, ) @@ -294,7 +294,7 @@ class PyGameGUI: action = Action( current_player_name, - ActionType.MOVEMENT, + ActionType.MOVEMENT.value, move_vec, duration=self.time_delta, ) -- GitLab