Skip to content
Snippets Groups Projects
Commit a610ce6d authored by Fabian Heinrich's avatar Fabian Heinrich
Browse files

Merge branch '120-simple-pathfinding' into 'dev'

Resolve "simple pathfinding"

Closes #120

See merge request scs/cocosy/overcooked-simulator!89
parents d76a8d0a 8e81a464
No related branches found
No related tags found
1 merge request!89Resolve "simple pathfinding"
Pipeline #49050 passed
...@@ -7,25 +7,53 @@ import time ...@@ -7,25 +7,53 @@ import time
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
import networkx
import numpy as np import numpy as np
import numpy.typing as npt
from websockets import connect from websockets import connect
from cooperative_cuisine.action import ActionType, InterActionData, Action from cooperative_cuisine.action import ActionType, InterActionData, Action
from cooperative_cuisine.state_representation import (
create_movement_graph,
astar_heuristic,
restrict_movement_graph,
)
from cooperative_cuisine.utils import custom_asdict_factory from cooperative_cuisine.utils import custom_asdict_factory
TIME_TO_STOP_ACTION = 3.0 TIME_TO_STOP_ACTION = 3.0
ADD_RANDOM_MOVEMENTS = False
DIAGONAL_MOVEMENTS = True
AVOID_OTHER_PLAYERS = True
def get_free_neighbours(
state: dict, counter_pos: list[float] | tuple[float, float] | npt.NDArray
) -> list[tuple[float, float]]:
width, height = state["kitchen"]["width"], state["kitchen"]["height"]
free_space = np.ones((width, height), dtype=bool)
for counter in state["counters"]:
grid_idx = np.array(counter["pos"]).astype(int)
free_space[grid_idx[0], grid_idx[1]] = False
i, j = np.array(counter_pos).astype(int)
free = []
for x, y in [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]:
if 0 < x < width and 0 < y < height and free_space[x, y]:
free.append((x, y))
return free
async def agent(): async def agent():
parser = argparse.ArgumentParser("Random agent") parser = argparse.ArgumentParser("Random agent")
parser.add_argument("--uri", type=str) parser.add_argument("--uri", type=str)
parser.add_argument("--player_id", type=str) parser.add_argument("--player_id", type=str)
parser.add_argument("--player_hash", type=str) parser.add_argument("--player_hash", type=str)
parser.add_argument("--step_time", type=float, default=0.5) parser.add_argument("--step_time", type=float, default=0.1)
args = parser.parse_args() args = parser.parse_args()
async with connect(args.uri) as websocket: async with (connect(args.uri) as websocket):
await websocket.send( await websocket.send(
json.dumps({"type": "ready", "player_hash": args.player_hash}) json.dumps({"type": "ready", "player_hash": args.player_hash})
) )
...@@ -34,6 +62,9 @@ async def agent(): ...@@ -34,6 +62,9 @@ async def agent():
ended = False ended = False
counters = None counters = None
all_counters = None
movement_graph = None
player_info = {} player_info = {}
current_agent_pos = None current_agent_pos = None
...@@ -60,10 +91,16 @@ async def agent(): ...@@ -60,10 +91,16 @@ async def agent():
if not state["all_players_ready"]: if not state["all_players_ready"]:
continue continue
if movement_graph is None:
movement_graph = create_movement_graph(
state, diagonal=DIAGONAL_MOVEMENTS
)
if counters is None: if counters is None:
counters = defaultdict(list) counters = defaultdict(list)
for counter in state["counters"]: for counter in state["counters"]:
counters[counter["type"]].append(counter) counters[counter["type"]].append(counter)
all_counters = state["counters"]
for player in state["players"]: for player in state["players"]:
if player["id"] == args.player_id: if player["id"] == args.player_id:
...@@ -125,10 +162,77 @@ async def agent(): ...@@ -125,10 +162,77 @@ async def agent():
task_type = None task_type = None
match task_type: match task_type:
case "GOTO": case "GOTO":
diff = np.array(task_args) - np.array(current_agent_pos) target_diff = np.array(task_args) - np.array(current_agent_pos)
dist = np.linalg.norm(diff) target_dist = np.linalg.norm(target_diff)
if dist > 1.2:
if dist != 0: source = tuple(
np.round(np.array(current_agent_pos)).astype(int)
)
target = tuple(np.array(task_args).astype(int))
target_free_spaces = get_free_neighbours(state, target)
paths = []
for free in target_free_spaces:
try:
path = networkx.astar_path(
restrict_movement_graph(
graph=movement_graph,
player_positions=[
p["pos"]
for p in state["players"]
if p["id"] != args.player_id
],
)
if AVOID_OTHER_PLAYERS
else movement_graph,
source,
free,
heuristic=astar_heuristic,
)
paths.append(path)
except networkx.exception.NetworkXNoPath:
pass
except networkx.exception.NodeNotFound:
pass
if paths:
shortest_path = paths[np.argmin([len(p) for p in paths])]
if len(shortest_path) > 1:
node_diff = shortest_path[1] - np.array(
current_agent_pos
)
node_dist = np.linalg.norm(node_diff)
movement = node_diff / node_dist
else:
movement = target_diff / target_dist
do_movement = True
else:
# no paths available
print("NO PATHS")
# task_type = None
# task_args = None
do_movement = False
if target_dist > 1.2 and do_movement:
if target_dist != 0:
if ADD_RANDOM_MOVEMENTS:
random_small_rotation_angle = (
np.random.random() * np.pi * 0.1
)
rotation_matrix = np.array(
[
[
np.cos(random_small_rotation_angle),
-np.sin(random_small_rotation_angle),
],
[
np.sin(random_small_rotation_angle),
np.cos(random_small_rotation_angle),
],
]
)
movement = rotation_matrix @ movement
await websocket.send( await websocket.send(
json.dumps( json.dumps(
{ {
...@@ -137,7 +241,7 @@ async def agent(): ...@@ -137,7 +241,7 @@ async def agent():
Action( Action(
args.player_id, args.player_id,
ActionType.MOVEMENT, ActionType.MOVEMENT,
(diff / dist).tolist(), movement.tolist(),
args.step_time + 0.01, args.step_time + 0.01,
), ),
dict_factory=custom_asdict_factory, dict_factory=custom_asdict_factory,
...@@ -148,6 +252,8 @@ async def agent(): ...@@ -148,6 +252,8 @@ async def agent():
) )
await websocket.recv() await websocket.recv()
else: else:
# Target reached here.
print("TARGET REACHED")
task_type = None task_type = None
task_args = None task_args = None
case "INTERACT": case "INTERACT":
...@@ -204,12 +310,23 @@ async def agent(): ...@@ -204,12 +310,23 @@ async def agent():
... ...
if not task_type: if not task_type:
task_type = random.choice(["GOTO", "PUT", "INTERACT"]) # task_type = random.choice(["GOTO", "PUT", "INTERACT"])
task_type = random.choice(["GOTO"])
threshold = datetime.now() + timedelta(seconds=TIME_TO_STOP_ACTION) threshold = datetime.now() + timedelta(seconds=TIME_TO_STOP_ACTION)
if task_type == "GOTO": if task_type == "GOTO":
counter_type = random.choice(list(counters.keys())) # counter_type = random.choice(list(counters.keys()))
task_args = random.choice(counters[counter_type])["pos"] # task_args = random.choice(counters[counter_type])["pos"]
print(args.player_hash, args.player_id, task_type, counter_type)
random_counter = random.choice(all_counters)
counter_type = random_counter["type"]
task_args = random_counter["pos"]
print(
args.player_hash,
args.player_id,
task_type,
counter_type,
task_args,
)
else: else:
print(args.player_hash, args.player_id, task_type) print(args.player_hash, args.player_id, task_type)
task_args = None task_args = None
......
plates:
clean_plates: 2
dirty_plates: 1
plate_delay: [ 5, 10 ]
# range of seconds until the dirty plate arrives.
game:
time_limit_seconds: 300
undo_dispenser_pickup: true
validate_recipes: false
layout_chars:
_: Free
hash: Counter # #
A: Agent
pipe: Extinguisher
P: PlateDispenser
C: CuttingBoard
X: Trashcan
$: ServingWindow
S: Sink
+: SinkAddon
at: Plate # @ just a clean plate on a counter
U: Pot # with Stove
Q: Pan # with Stove
O: Peel # with Oven
F: Basket # with DeepFryer
T: Tomato
N: Onion # oNioN
L: Lettuce
K: Potato # Kartoffel
I: Fish # fIIIsh
D: Dough
E: Cheese # chEEEse
G: Sausage # sausaGe
B: Bun
M: Meat
question: Counter # ? mushroom
: Counter
^: Counter
right: Counter
left: Counter
wave: Free # ~ Water
minus: Free # - Ice
dquote: Counter # " wall/truck
p: Counter # second plate return ??
orders:
meals:
all: true
# if all: false -> only orders for these meals are generated
# TODO: what if this list is empty?
list:
# - TomatoSoup
# - OnionSoup
# - Salad
- FriedFish
order_gen_class: !!python/name:cooperative_cuisine.orders.RandomOrderGeneration ''
# the class to that receives the kwargs. Should be a child class of OrderGeneration in orders.py
order_gen_kwargs:
order_duration_random_func:
# how long should the orders be alive
# 'random' library call with getattr, kwargs are passed to the function
func: uniform
kwargs:
a: 40
b: 60
max_orders: 6
# maximum number of active orders at the same time
num_start_meals: 2
# number of orders generated at the start of the environment
sample_on_dur_random_func:
# 'random' library call with getattr, kwargs are passed to the function
func: uniform
kwargs:
a: 10
b: 20
sample_on_serving: false
# Sample the delay for the next order only after a meal was served.
serving_not_ordered_meals: true
# can meals that are not ordered be served / dropped on the serving window
player_config:
radius: 0.4
speed_units_per_seconds: 6
interaction_range: 1.6
restricted_view: False
view_angle: 70
view_range: 4 # in grid units, can be "null"
effect_manager:
FireManager:
class: !!python/name:cooperative_cuisine.effects.FireEffectManager ''
kwargs:
spreading_duration: [ 5, 10 ]
fire_burns_ingredients_and_meals: true
hook_callbacks:
# # --------------- Scoring ---------------
orders:
hooks: [ completed_order ]
callback_class: !!python/name:cooperative_cuisine.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: 20
score_on_specific_kwarg: meal_name
score_map:
Burger: 15
OnionSoup: 10
Salad: 5
TomatoSoup: 10
not_ordered_meals:
hooks: [ serve_not_ordered_meal ]
callback_class: !!python/name:cooperative_cuisine.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: 2
trashcan_usages:
hooks: [ trashcan_usage ]
callback_class: !!python/name:cooperative_cuisine.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: -5
expired_orders:
hooks: [ order_expired ]
callback_class: !!python/name:cooperative_cuisine.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: -10
# --------------- Recording ---------------
# json_states:
# hooks: [ json_state ]
# callback_class: !!python/name:cooperative_cuisine.recording.FileRecorder ''
# callback_class_kwargs:
# record_path: USER_LOG_DIR/ENV_NAME/json_states.jsonl
actions:
hooks: [ pre_perform_action ]
callback_class: !!python/name:cooperative_cuisine.recording.FileRecorder ''
callback_class_kwargs:
record_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl
random_env_events:
hooks: [ order_duration_sample, plate_out_of_kitchen_time ]
callback_class: !!python/name:cooperative_cuisine.recording.FileRecorder ''
callback_class_kwargs:
record_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl
add_hook_ref: true
env_configs:
hooks: [ env_initialized, item_info_config ]
callback_class: !!python/name:cooperative_cuisine.recording.FileRecorder ''
callback_class_kwargs:
record_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl
add_hook_ref: true
# Game event recording
game_events:
hooks:
- post_counter_pick_up
- post_counter_drop_off
- post_dispenser_pick_up
- cutting_board_100
- player_start_interaction
- player_end_interact
- post_serving
- no_serving
- dirty_plate_arrives
- trashcan_usage
- plate_cleaned
- added_plate_to_sink
- drop_on_sink_addon
- pick_up_from_sink_addon
- serve_not_ordered_meal
- serve_without_plate
- completed_order
- new_orders
- order_expired
- action_on_not_reachable_counter
- new_fire
- fire_spreading
- drop_off_on_cooking_equipment
- players_collide
- post_plate_dispenser_pick_up
- post_plate_dispenser_drop_off
- on_item_transition
- progress_started
- progress_finished
- content_ready
- dispenser_item_returned
callback_class: !!python/name:cooperative_cuisine.recording.FileRecorder ''
callback_class_kwargs:
record_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl
add_hook_ref: true
# info_msg:
# func: !!python/name:cooperative_cuisine.hooks.hooks_via_callback_class ''
# kwargs:
# hooks: [ cutting_board_100 ]
# callback_class: !!python/name:cooperative_cuisine.info_msg.InfoMsgManager ''
# callback_class_kwargs:
# msg: Glückwunsch du hast was geschnitten!
# fire_msg:
# func: !!python/name:cooperative_cuisine.hooks.hooks_via_callback_class ''
# kwargs:
# hooks: [ new_fire ]
# callback_class: !!python/name:cooperative_cuisine.info_msg.InfoMsgManager ''
# callback_class_kwargs:
# msg: Feuer, Feuer, Feuer
# level: Warning
...@@ -193,32 +193,37 @@ class Movement: ...@@ -193,32 +193,37 @@ class Movement:
updated_movement * (self.player_movement_speed * d_time) updated_movement * (self.player_movement_speed * d_time)
) )
# Check collisions with counters # check if players collided with counters through movement or through being pushed
( (
collided, collided,
relevant_axes, relevant_axes,
nearest_counter_to_player, nearest_counter_to_player,
) = self.get_counter_collisions(new_targeted_positions) ) = self.get_counter_collisions(new_targeted_positions)
# Check if sliding against counters is possible # If collided, check if the players could still move along the axis, starting with x
for idx, player in enumerate(player_positions): # This leads to players beeing able to slide along counters, which feels alot nicer.
axis = relevant_axes[idx] projected_x = updated_movement.copy()
if collided[idx]: projected_x[collided, 1] = 0
# collide with counter left or top new_targeted_positions[collided] = player_positions[collided] + (
if nearest_counter_to_player[idx][axis] > 0: projected_x[collided] * (self.player_movement_speed * d_time)
updated_movement[idx, axis] = np.max( )
[updated_movement[idx, axis], 0] # checking collisions again
) (
# collide with counter right or bottom collided,
if nearest_counter_to_player[idx][axis] < 0: relevant_axes,
updated_movement[idx, axis] = np.min( nearest_counter_to_player,
[updated_movement[idx, axis], 0] ) = self.get_counter_collisions(new_targeted_positions)
) new_targeted_positions[collided] = player_positions[collided]
new_positions = player_positions + ( # and now y axis collisions
updated_movement * (self.player_movement_speed * d_time) projected_y = updated_movement.copy()
projected_y[collided, 0] = 0
new_targeted_positions[collided] = player_positions[collided] + (
projected_y[collided] * (self.player_movement_speed * d_time)
) )
new_positions = new_targeted_positions
# Check collisions with counters again, now absolute with no sliding possible # Check collisions with counters a final time, now absolute with no sliding possible.
# Players should never be able to enter counters this way.
( (
collided, collided,
relevant_axes, relevant_axes,
...@@ -226,7 +231,7 @@ class Movement: ...@@ -226,7 +231,7 @@ class Movement:
) = self.get_counter_collisions(new_positions) ) = self.get_counter_collisions(new_positions)
new_positions[collided] = player_positions[collided] new_positions[collided] = player_positions[collided]
# Collisions player world borders # Collisions of players with world borders
new_positions = np.clip( new_positions = np.clip(
new_positions, new_positions,
self.world_borders_lower + self.player_radius, self.world_borders_lower + self.player_radius,
......
...@@ -134,7 +134,9 @@ class Player: ...@@ -134,7 +134,9 @@ class Player:
def update_facing_point(self): def update_facing_point(self):
"""Update facing point on the player border circle based on the radius.""" """Update facing point on the player border circle based on the radius."""
self.facing_point = self.pos + ( self.facing_point = self.pos + (
self.facing_direction * self.player_config.radius * 0.5 self.facing_direction
* self.player_config.radius
* self.player_config.interaction_range
) )
def can_reach(self, counter: Counter) -> bool: def can_reach(self, counter: Counter) -> bool:
......
...@@ -149,7 +149,7 @@ class Visualizer: ...@@ -149,7 +149,7 @@ class Visualizer:
grid_size, grid_size,
) )
for idx, col in zip(controlled_player_idxs, [colors["blue"], colors["red"]]): for idx, col in zip(controlled_player_idxs, [ colors["red"], colors["blue"]]):
pygame.draw.circle( pygame.draw.circle(
screen, screen,
col, col,
......
...@@ -6,6 +6,10 @@ from datetime import datetime ...@@ -6,6 +6,10 @@ from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any from typing import Any
import networkx
import numpy as np
import numpy.typing as npt
from networkx import Graph
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import Literal, TypedDict from typing_extensions import Literal, TypedDict
...@@ -186,6 +190,90 @@ class StateRepresentation(BaseModel): ...@@ -186,6 +190,90 @@ class StateRepresentation(BaseModel):
"""Added by the game server, indicate if all players are ready and actions are passed to the environment.""" """Added by the game server, indicate if all players are ready and actions are passed to the environment."""
def astar_heuristic(x, y):
"""Heuristic distance function used in astar algorithm."""
return np.linalg.norm(np.array(x) - np.array(y))
def create_movement_graph(state: StateRepresentation, diagonal=True) -> Graph:
"""
Creates a graph which represents the connections of empty kitchen tiles and such
possible coarse movements of an agent.
Args:
state: State representation to determine the graph to.
diagonal: if True use 8 way connection, i.e. diagonal connections between the spaces.
Returns: Graph representing the connections between empty kitchen tiles.
"""
width, height = state["kitchen"]["width"], state["kitchen"]["height"]
free_space = np.ones((width, height), dtype=bool)
for counter in state["counters"]:
grid_idx = np.array(counter["pos"]).round().astype(int)
free_space[grid_idx[0], grid_idx[1]] = False
graph = networkx.Graph()
for i in range(width):
for j in range(height):
if free_space[i, j]:
graph.add_node((i, j))
if diagonal:
for di in range(-1, 2):
for dj in range(-1, 2):
x, y = i + di, j + dj
if (
0 <= x < width
and 0 < y < height
and free_space[x, y]
and (di, dj) != (0, 0)
):
if np.sum(np.abs(np.array([di, dj]))) == 2:
if free_space[i + di, j] and free_space[i, j + dj]:
graph.add_edge(
(i, j),
(x, y),
weight=np.linalg.norm(
np.array([i - x, j - y])
),
)
else:
graph.add_edge(
(i, j),
(x, y),
weight=np.linalg.norm(np.array([i - x, j - y])),
)
else:
for x, y in [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]:
if 0 <= x < width and 0 <= y < height and free_space[x, y]:
graph.add_edge(
(i, j),
(x, y),
weight=1,
)
return graph
def restrict_movement_graph(
graph: Graph,
player_positions: list[tuple[float, float] | list[float]] | npt.NDArray[float],
) -> Graph:
"""Modifies a given movement graph. Removed the nodes of spaces on which players stand.
Args:
graph: The graph to modify.
player_positions: Positions of players.
Returns: The modified graph without nodes where players stand.
"""
copied = graph.copy()
for pos in player_positions:
tup = tuple(np.array(pos).round().astype(int))
if tup in copied.nodes.keys():
copied.remove_node(tup)
return copied
def create_json_schema() -> dict[str, Any]: def create_json_schema() -> dict[str, Any]:
"""Create a json scheme of the state representation of an environment.""" """Create a json scheme of the state representation of an environment."""
return StateRepresentation.model_json_schema() return StateRepresentation.model_json_schema()
......
...@@ -43,8 +43,11 @@ from cooperative_cuisine.utils import create_init_env_time, get_touching_counter ...@@ -43,8 +43,11 @@ from cooperative_cuisine.utils import create_init_env_time, get_touching_counter
layouts_folder = ROOT_DIR / "configs" / "layouts" layouts_folder = ROOT_DIR / "configs" / "layouts"
environment_config_path = ROOT_DIR / "configs" / "environment_config.yaml" environment_config_path = ROOT_DIR / "configs" / "environment_config.yaml"
environment_config_no_validation_path = (
ROOT_DIR / "configs" / "environment_config_no_validation.yaml"
)
layout_path = ROOT_DIR / "configs" / "layouts" / "basic.layout" layout_path = ROOT_DIR / "configs" / "layouts" / "basic.layout"
layout_empty_path = ROOT_DIR / "configs" / "layouts" / "basic.layout" layout_empty_path = ROOT_DIR / "configs" / "layouts" / "empty.layout"
item_info_path = ROOT_DIR / "configs" / "item_info.yaml" item_info_path = ROOT_DIR / "configs" / "item_info.yaml"
# TODO: TESTs are in absolute pixel coordinates still. # TODO: TESTs are in absolute pixel coordinates still.
...@@ -54,6 +57,9 @@ item_info_path = ROOT_DIR / "configs" / "item_info.yaml" ...@@ -54,6 +57,9 @@ item_info_path = ROOT_DIR / "configs" / "item_info.yaml"
def test_file_availability(): def test_file_availability():
assert layouts_folder.is_dir(), "layouts folder does not exists" assert layouts_folder.is_dir(), "layouts folder does not exists"
assert environment_config_path.is_file(), "environment config file does not exists" assert environment_config_path.is_file(), "environment config file does not exists"
assert (
environment_config_no_validation_path.is_file()
), "environment config file does not exists"
assert layout_path.is_file(), "layout config file does not exists" assert layout_path.is_file(), "layout config file does not exists"
assert layout_empty_path.is_file(), "layout empty config file does not exists" assert layout_empty_path.is_file(), "layout empty config file does not exists"
assert item_info_path.is_file(), "item info config file does not exists" assert item_info_path.is_file(), "item info config file does not exists"
...@@ -67,6 +73,13 @@ def env_config(): ...@@ -67,6 +73,13 @@ def env_config():
return env_config return env_config
@pytest.fixture
def env_config_no_validation():
with open(environment_config_no_validation_path, "r") as file:
env_config = file.read()
return env_config
@pytest.fixture @pytest.fixture
def layout_config(): def layout_config():
with open(layout_path, "r") as file: with open(layout_path, "r") as file:
...@@ -76,7 +89,7 @@ def layout_config(): ...@@ -76,7 +89,7 @@ def layout_config():
@pytest.fixture @pytest.fixture
def layout_empty_config(): def layout_empty_config():
with open(layout_path, "r") as file: with open(layout_empty_path, "r") as file:
layout = file.read() layout = file.read()
return layout return layout
...@@ -101,8 +114,10 @@ def test_player_registration(env_config, layout_config, item_info): ...@@ -101,8 +114,10 @@ def test_player_registration(env_config, layout_config, item_info):
env.add_player("2") env.add_player("2")
def test_movement(env_config, layout_empty_config, item_info): def test_movement(env_config_no_validation, layout_empty_config, item_info):
env = Environment(env_config, layout_empty_config, item_info, as_files=False) env = Environment(
env_config_no_validation, layout_empty_config, item_info, as_files=False
)
player_name = "1" player_name = "1"
start_pos = np.array([3, 4]) start_pos = np.array([3, 4])
env.add_player(player_name, start_pos) env.add_player(player_name, start_pos)
...@@ -122,8 +137,12 @@ def test_movement(env_config, layout_empty_config, item_info): ...@@ -122,8 +137,12 @@ def test_movement(env_config, layout_empty_config, item_info):
), "Performed movement do not move the player as expected." ), "Performed movement do not move the player as expected."
def test_player_movement_speed(env_config, layout_empty_config, item_info): def test_player_movement_speed(
env = Environment(env_config, layout_empty_config, item_info, as_files=False) env_config_no_validation, layout_empty_config, item_info
):
env = Environment(
env_config_no_validation, layout_empty_config, item_info, as_files=False
)
player_name = "1" player_name = "1"
start_pos = np.array([3, 4]) start_pos = np.array([3, 4])
env.add_player(player_name, start_pos) env.add_player(player_name, start_pos)
...@@ -148,8 +167,10 @@ def test_player_movement_speed(env_config, layout_empty_config, item_info): ...@@ -148,8 +167,10 @@ def test_player_movement_speed(env_config, layout_empty_config, item_info):
), "json state does not match expected StateRepresentation." ), "json state does not match expected StateRepresentation."
def test_player_reach(env_config, layout_empty_config, item_info): def test_player_reach(env_config_no_validation, layout_empty_config, item_info):
env = Environment(env_config, layout_empty_config, item_info, as_files=False) env = Environment(
env_config_no_validation, layout_empty_config, item_info, as_files=False
)
counter_pos = np.array([2, 2]) counter_pos = np.array([2, 2])
counter = Counter(pos=counter_pos, hook=Hooks(env)) counter = Counter(pos=counter_pos, hook=Hooks(env))
......
import json
from argparse import ArgumentParser from argparse import ArgumentParser
import networkx
import pytest
from cooperative_cuisine.environment import Environment
from cooperative_cuisine.state_representation import (
create_movement_graph,
restrict_movement_graph,
astar_heuristic,
)
from cooperative_cuisine.utils import ( from cooperative_cuisine.utils import (
url_and_port_arguments, url_and_port_arguments,
add_list_of_manager_ids_arguments, add_list_of_manager_ids_arguments,
...@@ -9,6 +19,8 @@ from cooperative_cuisine.utils import ( ...@@ -9,6 +19,8 @@ from cooperative_cuisine.utils import (
create_layout_with_counters, create_layout_with_counters,
setup_logging, setup_logging,
) )
from tests.test_start import env_config_no_validation
from tests.test_start import layout_empty_config, item_info
def test_parser_gen(): def test_parser_gen():
...@@ -44,3 +56,92 @@ def test_layout_creation(): ...@@ -44,3 +56,92 @@ def test_layout_creation():
def test_setup_logging(): def test_setup_logging():
setup_logging() setup_logging()
def test_movement_graph(env_config_no_validation, layout_empty_config, item_info):
env = Environment(
env_config_no_validation, layout_empty_config, item_info, as_files=False
)
player_name = "0"
env.add_player(player_name)
state_string = env.get_json_state(player_id=player_name)
state = json.loads(state_string)
graph_diag = create_movement_graph(state, diagonal=True)
graph = create_movement_graph(
json.loads(env.get_json_state(player_id=player_name)), diagonal=False
)
path = networkx.astar_path(
graph,
source=(0, 0),
target=(3, 3),
heuristic=astar_heuristic,
)
assert len(path) != 0, "No path found, but should have."
graph_restricted = restrict_movement_graph(graph_diag, [(1, 0), (0, 1), (1, 1)])
with pytest.raises(networkx.exception.NetworkXNoPath) as e_info:
path = networkx.astar_path(
graph_restricted,
source=(0, 0),
target=(3, 3),
heuristic=astar_heuristic,
)
with pytest.raises(networkx.exception.NodeNotFound) as e_info:
path = networkx.astar_path(
graph_restricted,
source=(20, 20),
target=(40, 40),
heuristic=astar_heuristic,
)
path = networkx.astar_path(
restrict_movement_graph(
graph=graph_diag,
player_positions=[],
),
source=(0, 0),
target=(5, 5),
heuristic=astar_heuristic,
)
assert len(path) != 0, "No path found, but should have."
# now with diagonal movement
graph = create_movement_graph(
json.loads(env.get_json_state(player_id=player_name)), diagonal=True
)
path = networkx.astar_path(
graph,
source=(0, 0),
target=(3, 3),
heuristic=astar_heuristic,
)
assert len(path) != 0, "No path found, but should have."
graph_restricted = restrict_movement_graph(graph_diag, [(1, 0), (0, 1), (1, 1)])
with pytest.raises(networkx.exception.NetworkXNoPath) as e_info:
path = networkx.astar_path(
graph_restricted,
source=(0, 0),
target=(3, 3),
heuristic=astar_heuristic,
)
with pytest.raises(networkx.exception.NodeNotFound) as e_info:
path = networkx.astar_path(
graph_restricted,
source=(20, 20),
target=(40, 40),
heuristic=astar_heuristic,
)
path = networkx.astar_path(
restrict_movement_graph(
graph=graph_diag,
player_positions=[],
),
source=(0, 0),
target=(5, 5),
heuristic=astar_heuristic,
)
assert len(path) != 0, "No path found, but should have."
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment