Skip to content
Snippets Groups Projects
Commit 9e58ecdf authored by Fabian Heinrich's avatar Fabian Heinrich
Browse files

Merge branch '81-gui-player-management' into 'main'

Resolve "GUI Player Management"

Closes #81

See merge request scs/cocosy/overcooked-simulator!44
parents a9290be8 b6b3e663
No related branches found
No related tags found
1 merge request!44Resolve "GUI Player Management"
Pipeline #45579 passed
Showing
with 990 additions and 265 deletions
concurrency: MultiProcessing
communication:
communication_prefs:
- !name:ipaacar_com_service.communications.ipaacar_com.IPAACARInfo
modules:
connection:
module_info: !name:cocosy_agent.modules.connection_module.ConnectionModule
mean_frequency_step: 2 # 2: every 0.5 seconds
working_memory:
module_info: !name:cocosy_agent.modules.working_memory_module.WorkingMemoryModule
subtask_selection:
module_info: !name:cocosy_agent.modules.random_subtask_module.RandomSubtaskModule
action_execution:
module_info: !name:cocosy_agent.modules.action_execution_module.ActionExecutionModule
mean_frequency_step: 10 # 2: every 0.5 seconds
# gui:
# module_info: !name:aaambos.std.guis.pysimplegui.pysimplegui_window.PySimpleGUIWindowModule
# window_title: Counting GUI
# topics_to_show: [["SubtaskDecision", "cocosy_agent.conventions.communication.SubtaskDecision", ["task_type"]], ["ActionControl", "cocosy_agent.conventions.communication.ActionControl", ["action_type"]]]
status_manager:
module_info: !name:aaambos.std.modules.module_status_manager.ModuleStatusManager
gui: false
\ No newline at end of file
import argparse
import asyncio
import dataclasses
import json
import random
import time
from collections import defaultdict
from datetime import datetime, timedelta
import numpy as np
from websockets import connect
from overcooked_simulator.overcooked_environment import (
ActionType,
Action,
InterActionData,
)
from overcooked_simulator.utils import custom_asdict_factory
async def agent():
parser = argparse.ArgumentParser("Random agent")
parser.add_argument("--uri", type=str)
parser.add_argument("--player_id", type=str)
parser.add_argument("--player_hash", type=str)
parser.add_argument("--step_time", type=float, default=0.5)
args = parser.parse_args()
async with connect(args.uri) as websocket:
await websocket.send(
json.dumps({"type": "ready", "player_hash": args.player_hash})
)
await websocket.recv()
ended = False
counters = None
player_info = {}
current_agent_pos = None
interaction_counter = None
last_interacting = False
last_interact_progress = None
threshold = datetime.max
task_type = None
task_args = None
started_interaction = False
still_interacting = False
current_nearest_counter_id = None
while not ended:
time.sleep(args.step_time)
await websocket.send(
json.dumps({"type": "get_state", "player_hash": args.player_hash})
)
state = json.loads(await websocket.recv())
if counters is None:
counters = defaultdict(list)
for counter in state["counters"]:
counters[counter["type"]].append(counter)
for player in state["players"]:
if player["id"] == args.player_id:
player_info = player
current_agent_pos = player["pos"]
if player["current_nearest_counter_id"]:
if (
current_nearest_counter_id
!= player["current_nearest_counter_id"]
):
for counter in state["counters"]:
if (
counter["id"]
== player["current_nearest_counter_id"]
):
interaction_counter = counter
current_nearest_counter_id = player[
"current_nearest_counter_id"
]
break
if last_interacting:
if (
not interaction_counter
or not interaction_counter["occupied_by"]
or isinstance(interaction_counter["occupied_by"], list)
or (
interaction_counter["occupied_by"][
"progress_percentage"
]
== 1.0
)
):
last_interacting = False
last_interact_progress = None
else:
if (
interaction_counter
and interaction_counter["occupied_by"]
and not isinstance(interaction_counter["occupied_by"], list)
):
if (
last_interact_progress
!= interaction_counter["occupied_by"][
"progress_percentage"
]
):
last_interact_progress = interaction_counter[
"occupied_by"
]["progress_percentage"]
last_interacting = True
break
if task_type:
if threshold < datetime.now():
print(
args.player_hash, args.player_id, "---Threshold---Too long---"
)
task_type = None
match task_type:
case "GOTO":
diff = np.array(task_args) - np.array(current_agent_pos)
dist = np.linalg.norm(diff)
if dist > 1.2:
if dist != 0:
await websocket.send(
json.dumps(
{
"type": "action",
"action": dataclasses.asdict(
Action(
args.player_id,
ActionType.MOVEMENT,
(diff / dist).tolist(),
args.step_time + 0.01,
),
dict_factory=custom_asdict_factory,
),
"player_hash": args.player_hash,
}
)
)
await websocket.recv()
else:
task_type = None
task_args = None
case "INTERACT":
if not started_interaction or (
still_interacting and interaction_counter
):
if not started_interaction:
started_interaction = True
still_interacting = True
await websocket.send(
json.dumps(
{
"type": "action",
"action": dataclasses.asdict(
Action(
args.player_id,
ActionType.INTERACT,
InterActionData.START,
),
dict_factory=custom_asdict_factory,
),
"player_hash": args.player_hash,
}
)
)
await websocket.recv()
else:
still_interacting = False
started_interaction = False
task_type = None
task_args = None
case "PUT":
await websocket.send(
json.dumps(
{
"type": "action",
"action": dataclasses.asdict(
Action(
args.player_id,
ActionType.PUT,
"pickup",
),
dict_factory=custom_asdict_factory,
),
"player_hash": args.player_hash,
}
)
)
await websocket.recv()
task_type = None
task_args = None
case None:
...
if not task_type:
task_type = random.choice(["GOTO", "PUT", "INTERACT"])
threshold = datetime.now() + timedelta(seconds=15.0)
if task_type == "GOTO":
counter_type = random.choice(list(counters.keys()))
task_args = random.choice(counters[counter_type])["pos"]
print(args.player_hash, args.player_id, task_type, counter_type)
else:
print(args.player_hash, args.player_id, task_type)
task_args = None
ended = state["ended"]
if __name__ == "__main__":
asyncio.run(agent())
general:
agent_name: cocosy_agent
instance: _dev
local_agent_directories: ~/aaambos_agents
plus:
agent_websocket: ws://localhost:8000:/ws/player/MY_CLIENT_ID
player_hash: abcdefghijklmnopqrstuvwxyz
agent_id: 1
logging:
log_level_command_line: INFO
supervisor:
run_time_manager_class: !name:aaambos.std.supervision.instruction_run_time_manager.instruction_run_time_manager.InstructionRunTimeManager
......@@ -85,7 +85,7 @@ orders:
player_config:
radius: 0.4
player_speed_units_per_seconds: 8
player_speed_units_per_seconds: 6
interaction_range: 1.6
......
______
______
______
______
______
______
_____P
\ No newline at end of file
_______
_______
_______
_______
__A____
_______
_______
______P
\ No newline at end of file
#QU#F###O#T#################N###L###B#
#____________________________________#
#____________________________________M
#____________________________________#
#____________________________________#
#____________________________________K
W____________________________________I
#____________________________________#
#____________________________________#
#__A_____A___________________________D
#____________________________________#
#____________________________________#
#____________________________________#
#____________________________________#
#____________________________________#
C____________________________________E
#____________________________________#
#____________________________________#
#____________________________________#
#____________________________________#
C____________________________________G
#____________________________________#
#P#####S+####X#####S+#################
\ No newline at end of file
......@@ -220,7 +220,9 @@ class EnvironmentHandler:
self.envs[env_id].last_step_time = time.time_ns()
self.envs[env_id].environment.reset_env_time()
def get_state(self, player_hash: str) -> str: # -> StateRepresentation as json
def get_state(
self, player_hash: str
) -> str | int: # -> StateRepresentation as json
"""Get the current state representation of the environment for a player.
Args:
......@@ -237,6 +239,10 @@ class EnvironmentHandler:
return self.envs[
self.player_data[player_hash].env_id
].environment.get_json_state()
if player_hash not in self.player_data:
return 1
if self.player_data[player_hash].env_id not in self.envs:
return 2
def pause_env(self, manager_id: str, env_id: str, reason: str):
"""Pause the specified environment.
......@@ -599,7 +605,17 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul
}
case PlayerRequestType.GET_STATE:
return environment_handler.get_state(message_dict["player_hash"])
state = environment_handler.get_state(message_dict["player_hash"])
if isinstance(state, int):
return {
"request_type": message_dict["type"],
"status": 400,
"msg": "env id of player not in running envs"
if state == 2
else "player hash unknown",
"player_hash": None,
}
return state
case PlayerRequestType.ACTION:
assert (
......
......@@ -108,6 +108,7 @@ class Visualizer:
screen: pygame.Surface,
state: dict,
grid_size: int,
controlled_player_idxs: list[int],
):
"""Draws the game state on the given surface.
......@@ -131,6 +132,14 @@ class Visualizer:
grid_size,
)
for idx, col in zip(controlled_player_idxs, [colors["blue"], colors["red"]]):
pygame.draw.circle(
screen,
col,
np.array(state["players"][idx]["pos"]) * grid_size + (grid_size // 2),
(grid_size / 2),
)
self.draw_players(
screen,
state["players"],
......@@ -148,7 +157,6 @@ class Visualizer:
height: The kitchen height.
grid_size: The gridsize to base the background shapes on.
"""
block_size = grid_size // 2 # Set the size of the grid block
surface.fill(colors[self.config["Kitchen"]["ground_tiles_color"]])
for x in range(0, width, block_size):
......@@ -230,7 +238,7 @@ class Visualizer:
if USE_PLAYER_COOK_SPRITES:
pygame.draw.circle(
screen,
self.player_colors[p_idx],
colors[self.player_colors[p_idx]],
pos - facing * grid_size * 0.25,
grid_size * 0.2,
)
......@@ -278,7 +286,7 @@ class Visualizer:
)
if player_dict["holding"] is not None:
holding_item_pos = pos + (20 * facing)
holding_item_pos = pos + (grid_size * 0.5 * facing)
self.draw_item(
pos=holding_item_pos,
grid_size=grid_size,
......
......@@ -6,7 +6,7 @@
"disabled_bg": "#25292e",
"selected_bg": "#193754",
"dark_bg": "#15191e",
"normal_text": "#c5cbd8",
"normal_text": "#000000",
"hovered_text": "#FFFFFF",
"selected_text": "#FFFFFF",
"disabled_text": "#6d736f",
......@@ -92,5 +92,70 @@
"normal_border": "#000000",
"normal_text": "#000000"
}
},
"#players": {
"colours": {
"dark_bg": "#fffacd",
"normal_border": "#fffacd"
}
},
"#players_players": {
"colours": {
"dark_bg": "#fffacd"
}
},
"#players_bots": {
"colours": {
"dark_bg": "#fffacd"
}
},
"#number_players_label": {
"colours": {
"dark_bg": "#fffacd",
"normal_text": "#000000"
},
"font": {
"size": 14,
"bold": 1
}
},
"#number_bots_label": {
"colours": {
"dark_bg": "#fffacd",
"normal_text": "#000000"
},
"font": {
"size": 14,
"bold": 1,
"colour": "#000000"
}
},
"#multiple_keysets_button": {
"font": {
"size": 12,
"bold": 1,
"colour": "#000000"
}
},
"#split_players_button": {
"font": {
"size": 12,
"bold": 1,
"colour": "#000000"
}
},
"#controller_button": {
"font": {
"size": 12,
"bold": 1,
"colour": "#000000"
}
},
"#quantity_button": {
"font": {
"size": 24,
"bold": 1,
"colour": "#000000"
}
}
}
\ No newline at end of file
This diff is collapsed.
......@@ -14,6 +14,7 @@ from typing import Literal, TypedDict, Callable, Tuple
import numpy as np
import numpy.typing as npt
import yaml
from scipy.spatial import distance_matrix
from overcooked_simulator.counter_factory import CounterFactory
from overcooked_simulator.counters import (
......@@ -55,6 +56,9 @@ from overcooked_simulator.utils import create_init_env_time, get_closest
log = logging.getLogger(__name__)
PREVENT_SQUEEZING_INTO_OTHER_PLAYERS = True
class ActionType(Enum):
"""The 3 different types of valid actions. They can be extended via the `Action.action_data` attribute."""
......@@ -222,8 +226,17 @@ class Environment:
) = self.parse_layout_file()
self.hook(LAYOUT_FILE_PARSED)
self.world_borders_x = [-0.5, self.kitchen_width - 0.5]
self.world_borders_y = [-0.5, self.kitchen_height - 0.5]
self.counter_positions = np.array([c.pos for c in self.counters])
self.world_borders = np.array(
[[-0.5, self.kitchen_width - 0.5], [-0.5, self.kitchen_height - 0.5]],
dtype=float,
)
self.player_movement_speed = self.environment_config["player_config"][
"player_speed_units_per_seconds"
]
self.player_radius = self.environment_config["player_config"]["radius"]
progress_counter_classes = list(
filter(
......@@ -261,7 +274,7 @@ class Environment:
environment_config=env_config,
layout_config=self.layout_config,
seed=seed,
env_start_time_worldtime=datetime.now()
env_start_time_worldtime=datetime.now(),
)
@property
......@@ -269,6 +282,15 @@ class Environment:
"""Whether the game is over or not based on the calculated `Environment.env_time_end`"""
return self.env_time >= self.env_time_end
def set_collision_arrays(self):
number_players = len(self.players)
self.world_borders_lower = self.world_borders[np.newaxis, :, 0].repeat(
number_players, axis=0
)
self.world_borders_upper = self.world_borders[np.newaxis, :, 1].repeat(
number_players, axis=0
)
def get_env_time(self):
"""the internal time of the environment. An environment starts always with the time from `create_init_env_time`.
......@@ -514,7 +536,7 @@ class Environment:
facing_counter = get_closest(player.facing_point, self.counters)
return facing_counter
def perform_movement(self, player: Player, duration: timedelta):
def perform_movement(self, duration: timedelta):
"""Moves a player in the direction specified in the action.action. If the player collides with a
counter or other player through this movement, then they are not moved.
(The extended code with the two ifs is for sliding movement at the counters, which feels a bit smoother.
......@@ -526,145 +548,112 @@ class Environment:
Detects collisions with other players and pushes them out of the way.
Args:
player: The player to move.
duration: The duration for how long the movement to perform.
"""
old_pos = player.pos.copy()
move_vector = player.current_movement
d_time = duration.total_seconds()
step = move_vector * (player.player_speed_units_per_seconds * d_time)
player.move(step)
if self.detect_collision(player):
collided_players = self.get_collided_players(player)
for collided_player in collided_players:
pushing_vector = collided_player.pos - player.pos
if np.linalg.norm(pushing_vector) != 0:
pushing_vector = pushing_vector / np.linalg.norm(pushing_vector)
old_pos_other = collided_player.pos.copy()
collided_player.current_movement = pushing_vector
self.perform_movement(collided_player, duration)
if self.detect_collision_counters(
collided_player
) or self.detect_collision_world_bounds(collided_player):
collided_player.move_abs(old_pos_other)
player.move_abs(old_pos)
old_pos = player.pos.copy()
step_sliding = step.copy()
step_sliding[0] = 0
player.move(step_sliding * 0.5)
player.turn(step)
if self.detect_collision(player):
player.move_abs(old_pos)
old_pos = player.pos.copy()
step_sliding = step.copy()
step_sliding[1] = 0
player.move(step_sliding * 0.5)
player.turn(step)
if self.detect_collision(player):
player.move_abs(old_pos)
if self.counters:
closest_counter = self.get_facing_counter(player)
player.current_nearest_counter = (
closest_counter if player.can_reach(closest_counter) else None
)
def detect_collision(self, player: Player):
"""Detect collisions between the player and other players or counters.
Args:
player: The player for which to check collisions.
Returns: True if the player is intersecting with any object in the environment.
"""
return (
len(self.get_collided_players(player)) != 0
or self.detect_collision_counters(player)
or self.detect_collision_world_bounds(player)
player_positions = np.array([p.pos for p in self.players.values()], dtype=float)
player_movement_vectors = np.array(
[
p.current_movement if self.env_time <= p.movement_until else [0, 0]
for p in self.players.values()
],
dtype=float,
)
number_players = len(player_positions)
def get_collided_players(self, player: Player) -> list[Player]:
"""Detects collisions between the queried player and other players. Returns the list of the collided players.
A player is modelled as a circle. Collision is detected if the distance between the players is smaller
than the sum of the radius's.
Args:
player: The player to check collisions with other players for.
Returns: The list of other players the player collides with.
"""
other_players = filter(lambda p: p.name != player.name, self.players.values())
def collide(p):
return np.linalg.norm(player.pos - p.pos) <= player.radius + p.radius
return list(filter(collide, other_players))
def detect_player_collision(self, player: Player):
"""Detects collisions between the queried player and other players.
A player is modelled as a circle. Collision is detected if the distance between the players is smaller
than the sum of the radius's.
Args:
player: The player to check collisions with other players for.
Returns: True if the player collides with other players, False if not.
targeted_positions = player_positions + (
player_movement_vectors * (self.player_movement_speed * d_time)
)
"""
other_players = filter(lambda p: p.name != player.name, self.players.values())
# Collisions player between player
distances_players_after_scipy = distance_matrix(
targeted_positions, targeted_positions
)
def collide(p):
return np.linalg.norm(player.pos - p.pos) <= (player.radius + p.radius)
player_diff_vecs = -(
player_positions[:, np.newaxis, :] - player_positions[np.newaxis, :, :]
)
collision_idxs = distances_players_after_scipy < (2 * self.player_radius)
eye_idxs = np.eye(number_players, number_players, dtype=bool)
collision_idxs[eye_idxs] = False
return any(map(collide, other_players))
# Player push players around
player_diff_vecs[collision_idxs == False] = 0
push_vectors = np.sum(player_diff_vecs, axis=0)
def detect_collision_counters(self, player: Player):
"""Checks for collisions of the queried player with each counter.
updated_movement = push_vectors + player_movement_vectors
new_positions = player_positions + (
updated_movement * (self.player_movement_speed * d_time)
)
Args:
player: The player to check collisions with counters for.
# Collisions players counters
counter_diff_vecs = (
new_positions[:, np.newaxis, :] - self.counter_positions[np.newaxis, :, :]
)
counter_distances = np.max((np.abs(counter_diff_vecs)), axis=2)
# counter_distances = np.linalg.norm(counter_diff_vecs, axis=2)
closest_counter_positions = self.counter_positions[
np.argmin(counter_distances, axis=1)
]
nearest_counter_to_player = closest_counter_positions - new_positions
collided = np.min(counter_distances, axis=1) < self.player_radius + 0.5
relevant_axes = np.abs(nearest_counter_to_player).argmax(axis=1)
for idx, player in enumerate(player_positions):
axis = relevant_axes[idx]
if collided[idx]:
# collide with counter left or top
if nearest_counter_to_player[idx][axis] < 0:
updated_movement[idx, axis] = max(updated_movement[idx, axis], 0)
# collide with counter right or bottom
if nearest_counter_to_player[idx][axis] > 0:
updated_movement[idx, axis] = min(updated_movement[idx, axis], 0)
new_positions = player_positions + (
updated_movement * (self.player_movement_speed * d_time)
)
Returns: True if the player collides with any counter, False if not.
# Check if pushed players collide with counters or second closest is to close
counter_diff_vecs = (
new_positions[:, np.newaxis, :] - self.counter_positions[np.newaxis, :, :]
)
counter_distances = np.max((np.abs(counter_diff_vecs)), axis=2)
collided2 = np.min(counter_distances, axis=1) < self.player_radius + 0.5
# player do not move if they collide after pushing/sliding
new_positions[collided2] = player_positions[collided2]
# Players that pushed the player that can not be pushed do also no movement
# in the future these players could slide around the player?
for idx, collides in enumerate(collided2):
if collides:
new_positions[collision_idxs[idx]] = player_positions[
collision_idxs[idx]
]
"""
return any(
map(
lambda counter: self.detect_collision_player_counter(player, counter),
self.counters,
# Check if two moving players collide into each other: No movement (Future: slide?)
if PREVENT_SQUEEZING_INTO_OTHER_PLAYERS:
distances_players_after_scipy = distance_matrix(
new_positions, new_positions
)
collision_idxs = distances_players_after_scipy < (2 * self.player_radius)
collision_idxs[eye_idxs] = False
collision_idxs = np.any(collision_idxs, axis=1)
new_positions[collision_idxs] = player_positions[collision_idxs]
# Collisions player world borders
new_positions = np.clip(
new_positions,
self.world_borders_lower + self.player_radius,
self.world_borders_upper - self.player_radius,
)
@staticmethod
def detect_collision_player_counter(player: Player, counter: Counter):
"""Checks if the player and counter collide (overlap).
A counter is modelled as a rectangle (square actually), a player is modelled as a circle.
The distance of the player position (circle center) and the counter rectangle is calculated, if it is
smaller than the player radius, a collision is detected.
Args:
player: The player to check the collision for.
counter: The counter to check the collision for.
Returns: True if player and counter overlap, False if not.
"""
cx, cy = player.pos
dx = max(np.abs(cx - counter.pos[0]) - 1 / 2, 0)
dy = max(np.abs(cy - counter.pos[1]) - 1 / 2, 0)
distance = np.linalg.norm([dx, dy])
# TODO: Efficiency improvement by checking only nearest counters? Quadtree...?
return distance < player.radius
for idx, p in enumerate(self.players.values()):
if not (new_positions[idx] == player_positions[idx]).all():
p.turn(player_movement_vectors[idx])
p.move_abs(new_positions[idx])
def add_player(self, player_name: str, pos: npt.NDArray = None):
"""Add a player to the environment.
......@@ -702,6 +691,7 @@ class Environment:
log.debug("No free positions left in kitchens")
player.update_facing_point()
self.set_collision_arrays()
self.hook(PLAYER_ADDED, player_name=player_name, pos=pos)
def detect_collision_world_bounds(self, player: Player):
......@@ -734,8 +724,8 @@ class Environment:
else:
for player in self.players.values():
player.progress(passed_time, self.env_time)
if self.env_time <= player.movement_until:
self.perform_movement(player, passed_time)
self.perform_movement(passed_time)
for counter in self.progressing_counters:
counter.progress(passed_time=passed_time, now=self.env_time)
......
......@@ -57,15 +57,9 @@ class Player:
self.holding: Optional[Item] = None
"""What item the player is holding."""
self.player_config = player_config
"""See `PlayerConfig`."""
self.radius: float = player_config.radius
"""See `PlayerConfig.radius`."""
self.player_speed_units_per_seconds: float | int = (
player_config.player_speed_units_per_seconds
)
"""See `PlayerConfig.move_dist`."""
self.interaction_range: float = player_config.interaction_range
"""See `PlayerConfig.interaction_range`."""
self.facing_direction: npt.NDArray[float] = np.array([0, 1])
"""Current direction the player looks."""
self.last_interacted_counter: Optional[
......@@ -127,7 +121,9 @@ class Player:
def update_facing_point(self):
"""Update facing point on the player border circle based on the radius."""
self.facing_point = self.pos + (self.facing_direction * self.radius * 0.5)
self.facing_point = self.pos + (
self.facing_direction * self.player_config.radius * 0.5
)
def can_reach(self, counter: Counter):
"""Checks whether the player can reach the counter in question. Simple check if the distance is not larger
......@@ -140,7 +136,10 @@ class Player:
True if the counter is in range of the player, False if not.
"""
return np.linalg.norm(counter.pos - self.facing_point) <= self.interaction_range
return (
np.linalg.norm(counter.pos - self.facing_point)
<= self.player_config.interaction_range
)
def put_action(self, counter: Counter):
"""Performs the pickup-action with the counter. Handles the logic of what the player is currently holding,
......
......@@ -22,6 +22,7 @@ from overcooked_simulator import ROOT_DIR
if TYPE_CHECKING:
from overcooked_simulator.counters import Counter
from overcooked_simulator.player import Player
def create_init_env_time():
......@@ -46,6 +47,18 @@ def get_closest(point: npt.NDArray[float], counters: list[Counter]):
]
def get_collided_players(
player_idx, players: list[Player], player_radius: float
) -> list[Player]:
player_positions = np.array([p.pos for p in players], dtype=float)
distances = distance_matrix(player_positions, player_positions)[player_idx]
player_radiuses = np.array([player_radius for p in players], dtype=float)
collisions = distances <= player_radiuses + player_radius
collisions[player_idx] = False
return [players[idx] for idx, val in enumerate(collisions) if val]
def get_touching_counters(target: Counter, counters: list[Counter]) -> list[Counter]:
return list(
filter(
......
......@@ -46,6 +46,7 @@ def layout_config():
with open(layout_path, "r") as file:
layout = file.read()
return layout
env.add_player("0")
@pytest.fixture
......@@ -80,7 +81,7 @@ def test_movement(env_config, layout_empty_config, item_info):
player_name = "1"
start_pos = np.array([3, 4])
env.add_player(player_name, start_pos)
env.players[player_name].player_speed_units_per_seconds = 1
env.player_movement_speed = 1
move_direction = np.array([1, 0])
move_action = Action(player_name, ActionType.MOVEMENT, move_direction, duration=0.1)
do_moves_number = 3
......@@ -89,22 +90,19 @@ def test_movement(env_config, layout_empty_config, item_info):
env.step(timedelta(seconds=0.1))
expected = start_pos + do_moves_number * (
move_direction
* env.players[player_name].player_speed_units_per_seconds
* move_action.duration
move_direction * env.player_movement_speed * move_action.duration
)
assert np.isclose(
np.linalg.norm(expected - env.players[player_name].pos), 0
), "Performed movement do not move the player as expected."
def test_player_speed_units_per_seconds(env_config, layout_empty_config, item_info):
def test_player_movement_speed(env_config, layout_empty_config, item_info):
env = Environment(env_config, layout_empty_config, item_info, as_files=False)
player_name = "1"
start_pos = np.array([3, 4])
env.add_player(player_name, start_pos)
env.players[player_name].player_speed_units_per_seconds = 2
env.player_movement_speed = 2
move_direction = np.array([1, 0])
move_action = Action(player_name, ActionType.MOVEMENT, move_direction, duration=0.1)
do_moves_number = 3
......@@ -113,9 +111,7 @@ def test_player_speed_units_per_seconds(env_config, layout_empty_config, item_in
env.step(timedelta(seconds=0.1))
expected = start_pos + do_moves_number * (
move_direction
* env.players[player_name].player_speed_units_per_seconds
* move_action.duration
move_direction * env.player_movement_speed * move_action.duration
)
assert np.isclose(
......@@ -123,36 +119,6 @@ def test_player_speed_units_per_seconds(env_config, layout_empty_config, item_in
), "Performed movement do not move the player as expected."
def test_collision_detection(env_config, layout_config, item_info):
env = Environment(env_config, layout_config, item_info, as_files=False)
counter_pos = np.array([1, 2])
counter = Counter(pos=counter_pos, hook=Hooks(env))
env.counters = [counter]
env.add_player("1", np.array([1, 1]))
env.add_player("2", np.array([1, 4]))
player1 = env.players["1"]
player2 = env.players["2"]
assert not env.detect_collision_counters(player1), "Should not collide"
assert not env.detect_player_collision(player1), "Should not collide yet."
assert not env.detect_collision(player1), "Does not collide yet."
player1.move_abs(counter_pos)
assert env.detect_collision_counters(
player1
), "Player and counter at same pos. Not detected."
player2.move_abs(counter_pos)
assert env.detect_player_collision(player1), "Players at same pos. Not detected."
player1.move_abs(np.array([-1, -1]))
assert env.detect_collision_world_bounds(
player1
), "Player collides with world bounds."
def test_player_reach(env_config, layout_empty_config, item_info):
env = Environment(env_config, layout_empty_config, item_info, as_files=False)
......@@ -160,7 +126,7 @@ def test_player_reach(env_config, layout_empty_config, item_info):
counter = Counter(pos=counter_pos, hook=Hooks(env))
env.counters = [counter]
env.add_player("1", np.array([2, 4]))
env.players["1"].player_speed_units_per_seconds = 1
env.player_movement_speed = 1
player = env.players["1"]
assert not player.can_reach(counter), "Player is too far away."
......@@ -182,7 +148,7 @@ def test_pickup(env_config, layout_config, item_info):
env.add_player("1", np.array([2, 3]))
player = env.players["1"]
player.player_speed_units_per_seconds = 1
env.player_movement_speed = 1
move_down = Action("1", ActionType.MOVEMENT, np.array([0, -1]), duration=1)
move_up = Action("1", ActionType.MOVEMENT, np.array([0, 1]), duration=1)
......@@ -244,7 +210,7 @@ def test_processing(env_config, layout_config, item_info):
tomato = Item(name="Tomato", item_info=None)
env.add_player("1", np.array([2, 3]))
player = env.players["1"]
player.player_speed_units_per_seconds = 1
env.player_movement_speed = 1
player.holding = tomato
move = Action("1", ActionType.MOVEMENT, np.array([0, -1]), duration=1)
......@@ -276,6 +242,7 @@ def test_time_passed():
layouts_folder / "empty.layout",
ROOT_DIR / "game_content" / "item_info.yaml",
)
env.add_player("0")
env.reset_env_time()
passed_time = timedelta(seconds=10)
env.step(passed_time)
......@@ -297,6 +264,8 @@ def test_time_limit():
layouts_folder / "empty.layout",
ROOT_DIR / "game_content" / "item_info.yaml",
)
env.add_player("0")
env.reset_env_time()
assert not env.game_ended, "Game has not ended yet"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment