diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py index cb118cb2e11b50fa129397cab974c8f172f77063..eee0de2adf5e018306a5f90865be4a9840353a14 100644 --- a/overcooked_simulator/game_server.py +++ b/overcooked_simulator/game_server.py @@ -589,9 +589,12 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul assert ( "action" in message_dict ), "'action' key not in message dictionary of 'action' request" - if isinstance(message_dict["action"]["action"], list): - message_dict["action"]["action"] = np.array( - message_dict["action"]["action"], dtype=float + assert ( + "action_data" in message_dict["action"] + ), "'action_data' key not in message dictionary['action'] of 'action' request" + if isinstance(message_dict["action"]["action_data"], list): + message_dict["action"]["action_data"] = np.array( + message_dict["action"]["action_data"], dtype=float ) accepted = environment_handler.player_action( message_dict["player_hash"], Action(**message_dict["action"]) @@ -608,11 +611,11 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul "msg": "request not handled", "player_hash": message_dict["player_hash"], } - except ValueError: + except ValueError as e: return { "request_type": message_dict["type"], "status": 400, - "msg": "Invalid request type", + "msg": e.args[0], "player_hash": None, } except AssertionError as e: diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py index b47875f5598276c30e2621c049f41f42d9ba4177..b88d294e5eae744606c763accae35ac50afab440 100644 --- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py +++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py @@ -21,6 +21,7 @@ from overcooked_simulator.overcooked_environment import ( ActionType, InterActionData, ) +from overcooked_simulator.utils import custom_asdict_factory class MenuStates(Enum): @@ -174,7 +175,9 @@ class PyGameGUI: if np.linalg.norm(move_vec) != 0: move_vec = move_vec / np.linalg.norm(move_vec) - action = Action(key_set.name, ActionType.MOVEMENT, move_vec) + action = Action( + key_set.name, ActionType.MOVEMENT, move_vec, duration=1 / self.FPS + ) self.send_action(action) def handle_key_event(self, event): @@ -521,18 +524,24 @@ class PyGameGUI: Args: action: The action to be sent. Contains the player, action type and move direction if action is a movement. """ - if isinstance(action.action, np.ndarray): - action.action = [float(action.action[0]), float(action.action[1])] + if isinstance(action.action_data, np.ndarray): + action.action_data = [ + float(action.action_data[0]), + float(action.action_data[1]), + ] self.websockets[action.player].send( json.dumps( { "type": "action", - "action": dataclasses.asdict(action), + "action": dataclasses.asdict( + action, dict_factory=custom_asdict_factory + ), "player_hash": self.player_info[action.player]["player_hash"], } ) ) - self.websockets[action.player].recv() + a = self.websockets[action.player].recv() + print(a) def request_state(self): self.websockets[self.state_player_id].send( diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py index 9154646dba08e8775fd07a4256ca92a4728ed5e1..15e8751645f4346add67f4fd287f07b1eabae595 100644 --- a/overcooked_simulator/overcooked_environment.py +++ b/overcooked_simulator/overcooked_environment.py @@ -1,15 +1,15 @@ from __future__ import annotations +import dataclasses import datetime import json -import dataclasses import logging import random from datetime import timedelta from enum import Enum from pathlib import Path from threading import Lock -from typing import Literal, Any +from typing import Literal import numpy as np import numpy.typing as npt @@ -78,6 +78,12 @@ class Action: def __repr__(self): return f"Action({self.player},{self.action_type.value},{self.action_data},{self.duration})" + def __post_init__(self): + if isinstance(self.action_type, str): + self.action_type = ActionType(self.action_type) + if isinstance(self.action_data, str) and self.action_data != "pickup": + self.action_data = InterActionData(self.action_data) + # TODO Abstract base class for different environments @@ -466,7 +472,7 @@ class Environment: move_vector = player.current_movement d_time = timedelta.total_seconds() - step = move_vector * (player.move_speed * d_time) + step = move_vector * (player.player_speed_units_per_seconds * d_time) player.move(step) if self.detect_collision(player): diff --git a/overcooked_simulator/player.py b/overcooked_simulator/player.py index da4c9a8c87bd300f911db048efa9045801181903..0f5f25424210d595d8290c61231d641b16d934a7 100644 --- a/overcooked_simulator/player.py +++ b/overcooked_simulator/player.py @@ -28,7 +28,7 @@ class PlayerConfig: radius: float = 0.4 """The size of the player. The size of a counter is 1""" - move_dist: float = 0.15 + player_speed_units_per_seconds: float | int = 8 """The move distance/speed of the player per action call.""" interaction_range: float = 1.6 """How far player can interact with counters.""" @@ -49,8 +49,6 @@ class Player: ): self.name: str = name """Reference for the player""" - self.player_config = player_config - """Player configuration from the `environment.yml`""" self.pos: npt.NDArray[float] | None = None """The initial/suggested position of the player.""" if pos is not None: @@ -59,11 +57,13 @@ class Player: self.holding: Optional[Item] = None """What item the player is holding.""" - self.radius: float = self.player_config["radius"] + self.radius: float = player_config.radius """See `PlayerConfig.radius`.""" - self.move_speed: int = self.player_config["player_speed_units_per_seconds"] + self.player_speed_units_per_seconds: float | int = ( + player_config.player_speed_units_per_seconds + ) """See `PlayerConfig.move_dist`.""" - self.interaction_range: int = self.player_config["interaction_range"] + self.interaction_range: float = player_config.interaction_range """See `PlayerConfig.interaction_range`.""" self.facing_direction: npt.NDArray[float] = np.array([0, 1]) """Current direction the player looks.""" diff --git a/overcooked_simulator/utils.py b/overcooked_simulator/utils.py index dfb5da0068a533134ad72e85bf14289d295cc585..168d46b8b52f7f623a0147844039ffdcdb77cecd 100644 --- a/overcooked_simulator/utils.py +++ b/overcooked_simulator/utils.py @@ -1,4 +1,5 @@ from datetime import datetime +from enum import Enum def create_init_env_time(): @@ -6,3 +7,12 @@ def create_init_env_time(): return datetime( year=2000, month=1, day=1, hour=0, minute=0, second=0, microsecond=0 ) + + +def custom_asdict_factory(data): + def convert_value(obj): + if isinstance(obj, Enum): + return obj.value + return obj + + return dict((k, convert_value(v)) for k, v in data) diff --git a/tests/test_start.py b/tests/test_start.py index 059fc4a9dddefba6c982ba14119bc12c6820c32b..a9e0606797bc04ea8d01beb75aa76ec2f35217f3 100644 --- a/tests/test_start.py +++ b/tests/test_start.py @@ -100,7 +100,7 @@ def test_movement(): sim.enter_action(move_action) expected = start_pos + do_moves_number * ( - move_direction * sim.env.players[player_name].move_speed + move_direction * sim.env.players[player_name].player_speed_units_per_seconds ) assert np.isclose(