diff --git a/cooperative_cuisine/action.py b/cooperative_cuisine/action.py index a080070b23a716556ca33841cf0226928c52334a..bb94a9369ddb0293464715fbb19f7edcbbdebb1f 100644 --- a/cooperative_cuisine/action.py +++ b/cooperative_cuisine/action.py @@ -36,7 +36,7 @@ class Action: """Id of the player.""" action_type: ActionType """Type of the action to perform. Defines what action data is valid.""" - action_data: npt.NDArray[float] | InterActionData | Literal["pickup"] + action_data: npt.NDArray[float] | list[float] | InterActionData | Literal["pickup"] """Data for the action, e.g., movement vector or start and stop interaction.""" duration: float | int = 0 """Duration of the action (relevant for movement)""" diff --git a/cooperative_cuisine/game_server.py b/cooperative_cuisine/game_server.py index 2d1262944b82542fc90831169b74dbaeca32a6ef..53205d1fd284a003086e4905ba0f8a30b999ad45 100644 --- a/cooperative_cuisine/game_server.py +++ b/cooperative_cuisine/game_server.py @@ -613,18 +613,14 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul message_dict = json.loads(message) request_type = None try: - assert "type" in message_dict, "message needs a type" - - request_type = PlayerRequestType(message_dict["type"]) - assert ( - "player_hash" in message_dict - ), "'player_hash' key not in message dictionary'" + ws_message = WebsocketMessage(**message_dict) + request_type = PlayerRequestType(ws_message.type) match request_type: case PlayerRequestType.GET_STATE: - state = environment_handler.get_state(message_dict["player_hash"]) + state = environment_handler.get_state(ws_message.player_hash) if isinstance(state, int): return { - "request_type": message_dict["type"], + "request_type": ws_message.type, "status": 400, "msg": "env id of player not in running envs" if state == 2 @@ -633,40 +629,33 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul } return state case PlayerRequestType.READY: - accepted = environment_handler.set_player_ready( - message_dict["player_hash"] - ) + accepted = environment_handler.set_player_ready(ws_message.player_hash) return { "request_type": request_type.value, "msg": f"ready{' ' if accepted else ' not '}accepted", "status": 200 if accepted else 400, - "player_hash": message_dict["player_hash"], + "player_hash": ws_message.player_hash, } case PlayerRequestType.ACTION: - assert ( - "action" in message_dict - ), "'action' key not in message dictionary of 'action' request" - assert ( - "action_data" in message_dict["action"] - ), "'action_data' key not in message dictionary['action'] of 'action' request" - if isinstance(message_dict["action"]["action_data"], list): - message_dict["action"]["action_data"] = np.array( - message_dict["action"]["action_data"], dtype=float + assert ws_message.action is not None + if isinstance(ws_message.action.action_data, list): + ws_message.action.action_data = np.array( + ws_message.action.action_data, dtype=float ) accepted = environment_handler.player_action( - message_dict["player_hash"], Action(**message_dict["action"]) + ws_message.player_hash, ws_message.action ) return { "request_type": request_type.value, "status": 200 if accepted else 400, "msg": f"action{' ' if accepted else ' not '}accepted", - "player_hash": message_dict["player_hash"], + "player_hash": ws_message.player_hash, } return { "request_type": request_type.value, "status": 400, "msg": "request not handled", - "player_hash": message_dict["player_hash"], + "player_hash": ws_message.player_hash, } except ValueError as e: return { diff --git a/cooperative_cuisine/pygame_2d_vis/gui.py b/cooperative_cuisine/pygame_2d_vis/gui.py index 7d5d3b5bb152cb57b20036cba8c4416a129c6b87..e3819454cd133759191ccc447098227f40d9f932 100644 --- a/cooperative_cuisine/pygame_2d_vis/gui.py +++ b/cooperative_cuisine/pygame_2d_vis/gui.py @@ -253,7 +253,7 @@ class PyGameGUI: action = Action( current_player_name, - ActionType.MOVEMENT, + ActionType.MOVEMENT.value, move_vec, duration=self.time_delta, ) @@ -294,7 +294,7 @@ class PyGameGUI: action = Action( current_player_name, - ActionType.MOVEMENT, + ActionType.MOVEMENT.value, move_vec, duration=self.time_delta, )