Skip to content
Snippets Groups Projects
Commit dc8d7e4f authored by Florian Schröder's avatar Florian Schröder
Browse files

Update player speed attribute and handle Enum serialization

Renamed `move_dist` attribute to `player_speed_units_per_seconds` in player configuration. Also implemented `__post_init__` method in Action class to handle conversion from str to Enum type. Added a utility function `custom_asdict_factory` to handle Enum serialization for actions. Various code updates are made to accommodate these changes. Fixes merge bugs.
parent 4f4c2c15
No related branches found
No related tags found
1 merge request!26Resolve "api"
Pipeline #44753 failed
...@@ -589,9 +589,12 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul ...@@ -589,9 +589,12 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul
assert ( assert (
"action" in message_dict "action" in message_dict
), "'action' key not in message dictionary of 'action' request" ), "'action' key not in message dictionary of 'action' request"
if isinstance(message_dict["action"]["action"], list): assert (
message_dict["action"]["action"] = np.array( "action_data" in message_dict["action"]
message_dict["action"]["action"], dtype=float ), "'action_data' key not in message dictionary['action'] of 'action' request"
if isinstance(message_dict["action"]["action_data"], list):
message_dict["action"]["action_data"] = np.array(
message_dict["action"]["action_data"], dtype=float
) )
accepted = environment_handler.player_action( accepted = environment_handler.player_action(
message_dict["player_hash"], Action(**message_dict["action"]) message_dict["player_hash"], Action(**message_dict["action"])
...@@ -608,11 +611,11 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul ...@@ -608,11 +611,11 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul
"msg": "request not handled", "msg": "request not handled",
"player_hash": message_dict["player_hash"], "player_hash": message_dict["player_hash"],
} }
except ValueError: except ValueError as e:
return { return {
"request_type": message_dict["type"], "request_type": message_dict["type"],
"status": 400, "status": 400,
"msg": "Invalid request type", "msg": e.args[0],
"player_hash": None, "player_hash": None,
} }
except AssertionError as e: except AssertionError as e:
......
...@@ -21,6 +21,7 @@ from overcooked_simulator.overcooked_environment import ( ...@@ -21,6 +21,7 @@ from overcooked_simulator.overcooked_environment import (
ActionType, ActionType,
InterActionData, InterActionData,
) )
from overcooked_simulator.utils import custom_asdict_factory
class MenuStates(Enum): class MenuStates(Enum):
...@@ -174,7 +175,9 @@ class PyGameGUI: ...@@ -174,7 +175,9 @@ class PyGameGUI:
if np.linalg.norm(move_vec) != 0: if np.linalg.norm(move_vec) != 0:
move_vec = move_vec / np.linalg.norm(move_vec) move_vec = move_vec / np.linalg.norm(move_vec)
action = Action(key_set.name, ActionType.MOVEMENT, move_vec) action = Action(
key_set.name, ActionType.MOVEMENT, move_vec, duration=1 / self.FPS
)
self.send_action(action) self.send_action(action)
def handle_key_event(self, event): def handle_key_event(self, event):
...@@ -521,18 +524,24 @@ class PyGameGUI: ...@@ -521,18 +524,24 @@ class PyGameGUI:
Args: Args:
action: The action to be sent. Contains the player, action type and move direction if action is a movement. action: The action to be sent. Contains the player, action type and move direction if action is a movement.
""" """
if isinstance(action.action, np.ndarray): if isinstance(action.action_data, np.ndarray):
action.action = [float(action.action[0]), float(action.action[1])] action.action_data = [
float(action.action_data[0]),
float(action.action_data[1]),
]
self.websockets[action.player].send( self.websockets[action.player].send(
json.dumps( json.dumps(
{ {
"type": "action", "type": "action",
"action": dataclasses.asdict(action), "action": dataclasses.asdict(
action, dict_factory=custom_asdict_factory
),
"player_hash": self.player_info[action.player]["player_hash"], "player_hash": self.player_info[action.player]["player_hash"],
} }
) )
) )
self.websockets[action.player].recv() a = self.websockets[action.player].recv()
print(a)
def request_state(self): def request_state(self):
self.websockets[self.state_player_id].send( self.websockets[self.state_player_id].send(
......
from __future__ import annotations from __future__ import annotations
import dataclasses
import datetime import datetime
import json import json
import dataclasses
import logging import logging
import random import random
from datetime import timedelta from datetime import timedelta
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from threading import Lock from threading import Lock
from typing import Literal, Any from typing import Literal
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
...@@ -78,6 +78,12 @@ class Action: ...@@ -78,6 +78,12 @@ class Action:
def __repr__(self): def __repr__(self):
return f"Action({self.player},{self.action_type.value},{self.action_data},{self.duration})" return f"Action({self.player},{self.action_type.value},{self.action_data},{self.duration})"
def __post_init__(self):
if isinstance(self.action_type, str):
self.action_type = ActionType(self.action_type)
if isinstance(self.action_data, str) and self.action_data != "pickup":
self.action_data = InterActionData(self.action_data)
# TODO Abstract base class for different environments # TODO Abstract base class for different environments
...@@ -466,7 +472,7 @@ class Environment: ...@@ -466,7 +472,7 @@ class Environment:
move_vector = player.current_movement move_vector = player.current_movement
d_time = timedelta.total_seconds() d_time = timedelta.total_seconds()
step = move_vector * (player.move_speed * d_time) step = move_vector * (player.player_speed_units_per_seconds * d_time)
player.move(step) player.move(step)
if self.detect_collision(player): if self.detect_collision(player):
......
...@@ -28,7 +28,7 @@ class PlayerConfig: ...@@ -28,7 +28,7 @@ class PlayerConfig:
radius: float = 0.4 radius: float = 0.4
"""The size of the player. The size of a counter is 1""" """The size of the player. The size of a counter is 1"""
move_dist: float = 0.15 player_speed_units_per_seconds: float | int = 8
"""The move distance/speed of the player per action call.""" """The move distance/speed of the player per action call."""
interaction_range: float = 1.6 interaction_range: float = 1.6
"""How far player can interact with counters.""" """How far player can interact with counters."""
...@@ -49,8 +49,6 @@ class Player: ...@@ -49,8 +49,6 @@ class Player:
): ):
self.name: str = name self.name: str = name
"""Reference for the player""" """Reference for the player"""
self.player_config = player_config
"""Player configuration from the `environment.yml`"""
self.pos: npt.NDArray[float] | None = None self.pos: npt.NDArray[float] | None = None
"""The initial/suggested position of the player.""" """The initial/suggested position of the player."""
if pos is not None: if pos is not None:
...@@ -59,11 +57,13 @@ class Player: ...@@ -59,11 +57,13 @@ class Player:
self.holding: Optional[Item] = None self.holding: Optional[Item] = None
"""What item the player is holding.""" """What item the player is holding."""
self.radius: float = self.player_config["radius"] self.radius: float = player_config.radius
"""See `PlayerConfig.radius`.""" """See `PlayerConfig.radius`."""
self.move_speed: int = self.player_config["player_speed_units_per_seconds"] self.player_speed_units_per_seconds: float | int = (
player_config.player_speed_units_per_seconds
)
"""See `PlayerConfig.move_dist`.""" """See `PlayerConfig.move_dist`."""
self.interaction_range: int = self.player_config["interaction_range"] self.interaction_range: float = player_config.interaction_range
"""See `PlayerConfig.interaction_range`.""" """See `PlayerConfig.interaction_range`."""
self.facing_direction: npt.NDArray[float] = np.array([0, 1]) self.facing_direction: npt.NDArray[float] = np.array([0, 1])
"""Current direction the player looks.""" """Current direction the player looks."""
......
from datetime import datetime from datetime import datetime
from enum import Enum
def create_init_env_time(): def create_init_env_time():
...@@ -6,3 +7,12 @@ def create_init_env_time(): ...@@ -6,3 +7,12 @@ def create_init_env_time():
return datetime( return datetime(
year=2000, month=1, day=1, hour=0, minute=0, second=0, microsecond=0 year=2000, month=1, day=1, hour=0, minute=0, second=0, microsecond=0
) )
def custom_asdict_factory(data):
def convert_value(obj):
if isinstance(obj, Enum):
return obj.value
return obj
return dict((k, convert_value(v)) for k, v in data)
...@@ -100,7 +100,7 @@ def test_movement(): ...@@ -100,7 +100,7 @@ def test_movement():
sim.enter_action(move_action) sim.enter_action(move_action)
expected = start_pos + do_moves_number * ( expected = start_pos + do_moves_number * (
move_direction * sim.env.players[player_name].move_speed move_direction * sim.env.players[player_name].player_speed_units_per_seconds
) )
assert np.isclose( assert np.isclose(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment