Skip to content
Snippets Groups Projects
Commit dc8d7e4f authored by Florian Schröder's avatar Florian Schröder
Browse files

Update player speed attribute and handle Enum serialization

Renamed `move_dist` attribute to `player_speed_units_per_seconds` in player configuration. Also implemented `__post_init__` method in Action class to handle conversion from str to Enum type. Added a utility function `custom_asdict_factory` to handle Enum serialization for actions. Various code updates are made to accommodate these changes. Fixes merge bugs.
parent 4f4c2c15
No related branches found
No related tags found
1 merge request!26Resolve "api"
Pipeline #44753 failed
......@@ -589,9 +589,12 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul
assert (
"action" in message_dict
), "'action' key not in message dictionary of 'action' request"
if isinstance(message_dict["action"]["action"], list):
message_dict["action"]["action"] = np.array(
message_dict["action"]["action"], dtype=float
assert (
"action_data" in message_dict["action"]
), "'action_data' key not in message dictionary['action'] of 'action' request"
if isinstance(message_dict["action"]["action_data"], list):
message_dict["action"]["action_data"] = np.array(
message_dict["action"]["action_data"], dtype=float
)
accepted = environment_handler.player_action(
message_dict["player_hash"], Action(**message_dict["action"])
......@@ -608,11 +611,11 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul
"msg": "request not handled",
"player_hash": message_dict["player_hash"],
}
except ValueError:
except ValueError as e:
return {
"request_type": message_dict["type"],
"status": 400,
"msg": "Invalid request type",
"msg": e.args[0],
"player_hash": None,
}
except AssertionError as e:
......
......@@ -21,6 +21,7 @@ from overcooked_simulator.overcooked_environment import (
ActionType,
InterActionData,
)
from overcooked_simulator.utils import custom_asdict_factory
class MenuStates(Enum):
......@@ -174,7 +175,9 @@ class PyGameGUI:
if np.linalg.norm(move_vec) != 0:
move_vec = move_vec / np.linalg.norm(move_vec)
action = Action(key_set.name, ActionType.MOVEMENT, move_vec)
action = Action(
key_set.name, ActionType.MOVEMENT, move_vec, duration=1 / self.FPS
)
self.send_action(action)
def handle_key_event(self, event):
......@@ -521,18 +524,24 @@ class PyGameGUI:
Args:
action: The action to be sent. Contains the player, action type and move direction if action is a movement.
"""
if isinstance(action.action, np.ndarray):
action.action = [float(action.action[0]), float(action.action[1])]
if isinstance(action.action_data, np.ndarray):
action.action_data = [
float(action.action_data[0]),
float(action.action_data[1]),
]
self.websockets[action.player].send(
json.dumps(
{
"type": "action",
"action": dataclasses.asdict(action),
"action": dataclasses.asdict(
action, dict_factory=custom_asdict_factory
),
"player_hash": self.player_info[action.player]["player_hash"],
}
)
)
self.websockets[action.player].recv()
a = self.websockets[action.player].recv()
print(a)
def request_state(self):
self.websockets[self.state_player_id].send(
......
from __future__ import annotations
import dataclasses
import datetime
import json
import dataclasses
import logging
import random
from datetime import timedelta
from enum import Enum
from pathlib import Path
from threading import Lock
from typing import Literal, Any
from typing import Literal
import numpy as np
import numpy.typing as npt
......@@ -78,6 +78,12 @@ class Action:
def __repr__(self):
return f"Action({self.player},{self.action_type.value},{self.action_data},{self.duration})"
def __post_init__(self):
if isinstance(self.action_type, str):
self.action_type = ActionType(self.action_type)
if isinstance(self.action_data, str) and self.action_data != "pickup":
self.action_data = InterActionData(self.action_data)
# TODO Abstract base class for different environments
......@@ -466,7 +472,7 @@ class Environment:
move_vector = player.current_movement
d_time = timedelta.total_seconds()
step = move_vector * (player.move_speed * d_time)
step = move_vector * (player.player_speed_units_per_seconds * d_time)
player.move(step)
if self.detect_collision(player):
......
......@@ -28,7 +28,7 @@ class PlayerConfig:
radius: float = 0.4
"""The size of the player. The size of a counter is 1"""
move_dist: float = 0.15
player_speed_units_per_seconds: float | int = 8
"""The move distance/speed of the player per action call."""
interaction_range: float = 1.6
"""How far player can interact with counters."""
......@@ -49,8 +49,6 @@ class Player:
):
self.name: str = name
"""Reference for the player"""
self.player_config = player_config
"""Player configuration from the `environment.yml`"""
self.pos: npt.NDArray[float] | None = None
"""The initial/suggested position of the player."""
if pos is not None:
......@@ -59,11 +57,13 @@ class Player:
self.holding: Optional[Item] = None
"""What item the player is holding."""
self.radius: float = self.player_config["radius"]
self.radius: float = player_config.radius
"""See `PlayerConfig.radius`."""
self.move_speed: int = self.player_config["player_speed_units_per_seconds"]
self.player_speed_units_per_seconds: float | int = (
player_config.player_speed_units_per_seconds
)
"""See `PlayerConfig.move_dist`."""
self.interaction_range: int = self.player_config["interaction_range"]
self.interaction_range: float = player_config.interaction_range
"""See `PlayerConfig.interaction_range`."""
self.facing_direction: npt.NDArray[float] = np.array([0, 1])
"""Current direction the player looks."""
......
from datetime import datetime
from enum import Enum
def create_init_env_time():
......@@ -6,3 +7,12 @@ def create_init_env_time():
return datetime(
year=2000, month=1, day=1, hour=0, minute=0, second=0, microsecond=0
)
def custom_asdict_factory(data):
def convert_value(obj):
if isinstance(obj, Enum):
return obj.value
return obj
return dict((k, convert_value(v)) for k, v in data)
......@@ -100,7 +100,7 @@ def test_movement():
sim.enter_action(move_action)
expected = start_pos + do_moves_number * (
move_direction * sim.env.players[player_name].move_speed
move_direction * sim.env.players[player_name].player_speed_units_per_seconds
)
assert np.isclose(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment