Skip to content
Snippets Groups Projects

Resolve "simple pathfinding"

Merged Fabian Heinrich requested to merge 120-simple-pathfinding into dev
Files
4
@@ -7,21 +7,45 @@ import time
from collections import defaultdict
from datetime import datetime, timedelta
import networkx
import numpy as np
import numpy.typing as npt
from websockets import connect
from cooperative_cuisine.action import ActionType, InterActionData, Action
from cooperative_cuisine.state_representation import (
create_movement_graph,
astar_heuristic,
restrict_movement_graph,
)
from cooperative_cuisine.utils import custom_asdict_factory
TIME_TO_STOP_ACTION = 3.0
def get_free_neighbours(
state: dict, counter_pos: list[float] | tuple[float, float] | npt.NDArray
) -> list[tuple[float, float]]:
width, height = state["kitchen"]["width"], state["kitchen"]["height"]
free_space = np.ones((width, height), dtype=bool)
for counter in state["counters"]:
grid_idx = np.array(counter["pos"]).astype(int)
free_space[*grid_idx] = False
i, j = np.array(counter_pos).astype(int)
free = []
for x, y in [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]:
if 0 < x < width and 0 < y < height and free_space[x, y]:
free.append((x, y))
return free
async def agent():
parser = argparse.ArgumentParser("Random agent")
parser.add_argument("--uri", type=str)
parser.add_argument("--player_id", type=str)
parser.add_argument("--player_hash", type=str)
parser.add_argument("--step_time", type=float, default=0.5)
parser.add_argument("--step_time", type=float, default=0.1)
args = parser.parse_args()
@@ -35,6 +59,8 @@ async def agent():
counters = None
movement_graph = None
player_info = {}
current_agent_pos = None
interaction_counter = None
@@ -60,6 +86,9 @@ async def agent():
if not state["all_players_ready"]:
continue
if movement_graph is None:
movement_graph = create_movement_graph(state, diagonal=True)
if counters is None:
counters = defaultdict(list)
for counter in state["counters"]:
@@ -125,10 +154,71 @@ async def agent():
task_type = None
match task_type:
case "GOTO":
diff = np.array(task_args) - np.array(current_agent_pos)
dist = np.linalg.norm(diff)
if dist > 1.2:
if dist != 0:
target_diff = np.array(task_args) - np.array(current_agent_pos)
target_dist = np.linalg.norm(target_diff)
source = tuple(
np.round(np.array(current_agent_pos)).astype(int)
)
target = tuple(np.array(task_args).astype(int))
target_free_spaces = get_free_neighbours(state, target)
paths = []
for free in target_free_spaces:
try:
modified_graph = restrict_movement_graph(
graph=movement_graph,
player_positions=[
p["pos"]
for p in state["players"]
if p["id"] != args.player_id
],
)
path = networkx.astar_path(
modified_graph,
source,
free,
heuristic=astar_heuristic,
)
paths.append(path)
except networkx.exception.NetworkXNoPath:
pass
except networkx.exception.NodeNotFound:
pass
if paths:
shortest_path = paths[np.argmin([len(p) for p in paths])]
if len(shortest_path) > 1:
node_diff = shortest_path[1] - np.array(
current_agent_pos
)
node_dist = np.linalg.norm(node_diff)
movement = node_diff / node_dist
else:
movement = target_diff / target_dist
else:
movement = np.array([0, 0])
task_type = None
task_args = None
if target_dist > 1.2:
if target_dist != 0:
# random_small_rotation_angle = (
# np.random.random() * np.pi * 0.1
# )
# rotation_matrix = np.array(
# [
# [
# np.cos(random_small_rotation_angle),
# -np.sin(random_small_rotation_angle),
# ],
# [
# np.sin(random_small_rotation_angle),
# np.cos(random_small_rotation_angle),
# ],
# ]
# )
# movement = rotation_matrix @ movement
await websocket.send(
json.dumps(
{
@@ -137,7 +227,7 @@ async def agent():
Action(
args.player_id,
ActionType.MOVEMENT,
(diff / dist).tolist(),
movement.tolist(),
args.step_time + 0.01,
),
dict_factory=custom_asdict_factory,
@@ -204,7 +294,9 @@ async def agent():
...
if not task_type:
task_type = random.choice(["GOTO", "PUT", "INTERACT"])
# task_type = random.choice(["GOTO"])
task_type = random.choice(["GOTO", "PUT"])
# task_type = random.choice(["GOTO", "PUT", "INTERACT"])
threshold = datetime.now() + timedelta(seconds=TIME_TO_STOP_ACTION)
if task_type == "GOTO":
counter_type = random.choice(list(counters.keys()))
Loading