Skip to content
Snippets Groups Projects

Resolve "simple pathfinding"

Merged Fabian Heinrich requested to merge 120-simple-pathfinding into dev
Files
10
@@ -7,25 +7,53 @@ import time
from collections import defaultdict
from datetime import datetime, timedelta
import networkx
import numpy as np
import numpy.typing as npt
from websockets import connect
from cooperative_cuisine.action import ActionType, InterActionData, Action
from cooperative_cuisine.state_representation import (
create_movement_graph,
astar_heuristic,
restrict_movement_graph,
)
from cooperative_cuisine.utils import custom_asdict_factory
TIME_TO_STOP_ACTION = 3.0
ADD_RANDOM_MOVEMENTS = False
DIAGONAL_MOVEMENTS = True
AVOID_OTHER_PLAYERS = True
def get_free_neighbours(
state: dict, counter_pos: list[float] | tuple[float, float] | npt.NDArray
) -> list[tuple[float, float]]:
width, height = state["kitchen"]["width"], state["kitchen"]["height"]
free_space = np.ones((width, height), dtype=bool)
for counter in state["counters"]:
grid_idx = np.array(counter["pos"]).astype(int)
free_space[grid_idx[0], grid_idx[1]] = False
i, j = np.array(counter_pos).astype(int)
free = []
for x, y in [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]:
if 0 < x < width and 0 < y < height and free_space[x, y]:
free.append((x, y))
return free
async def agent():
parser = argparse.ArgumentParser("Random agent")
parser.add_argument("--uri", type=str)
parser.add_argument("--player_id", type=str)
parser.add_argument("--player_hash", type=str)
parser.add_argument("--step_time", type=float, default=0.5)
parser.add_argument("--step_time", type=float, default=0.1)
args = parser.parse_args()
async with connect(args.uri) as websocket:
async with (connect(args.uri) as websocket):
await websocket.send(
json.dumps({"type": "ready", "player_hash": args.player_hash})
)
@@ -34,6 +62,9 @@ async def agent():
ended = False
counters = None
all_counters = None
movement_graph = None
player_info = {}
current_agent_pos = None
@@ -60,10 +91,16 @@ async def agent():
if not state["all_players_ready"]:
continue
if movement_graph is None:
movement_graph = create_movement_graph(
state, diagonal=DIAGONAL_MOVEMENTS
)
if counters is None:
counters = defaultdict(list)
for counter in state["counters"]:
counters[counter["type"]].append(counter)
all_counters = state["counters"]
for player in state["players"]:
if player["id"] == args.player_id:
@@ -125,10 +162,77 @@ async def agent():
task_type = None
match task_type:
case "GOTO":
diff = np.array(task_args) - np.array(current_agent_pos)
dist = np.linalg.norm(diff)
if dist > 1.2:
if dist != 0:
target_diff = np.array(task_args) - np.array(current_agent_pos)
target_dist = np.linalg.norm(target_diff)
source = tuple(
np.round(np.array(current_agent_pos)).astype(int)
)
target = tuple(np.array(task_args).astype(int))
target_free_spaces = get_free_neighbours(state, target)
paths = []
for free in target_free_spaces:
try:
path = networkx.astar_path(
restrict_movement_graph(
graph=movement_graph,
player_positions=[
p["pos"]
for p in state["players"]
if p["id"] != args.player_id
],
)
if AVOID_OTHER_PLAYERS
else movement_graph,
source,
free,
heuristic=astar_heuristic,
)
paths.append(path)
except networkx.exception.NetworkXNoPath:
pass
except networkx.exception.NodeNotFound:
pass
if paths:
shortest_path = paths[np.argmin([len(p) for p in paths])]
if len(shortest_path) > 1:
node_diff = shortest_path[1] - np.array(
current_agent_pos
)
node_dist = np.linalg.norm(node_diff)
movement = node_diff / node_dist
else:
movement = target_diff / target_dist
do_movement = True
else:
# no paths available
print("NO PATHS")
# task_type = None
# task_args = None
do_movement = False
if target_dist > 1.2 and do_movement:
if target_dist != 0:
if ADD_RANDOM_MOVEMENTS:
random_small_rotation_angle = (
np.random.random() * np.pi * 0.1
)
rotation_matrix = np.array(
[
[
np.cos(random_small_rotation_angle),
-np.sin(random_small_rotation_angle),
],
[
np.sin(random_small_rotation_angle),
np.cos(random_small_rotation_angle),
],
]
)
movement = rotation_matrix @ movement
await websocket.send(
json.dumps(
{
@@ -137,7 +241,7 @@ async def agent():
Action(
args.player_id,
ActionType.MOVEMENT,
(diff / dist).tolist(),
movement.tolist(),
args.step_time + 0.01,
),
dict_factory=custom_asdict_factory,
@@ -148,6 +252,8 @@ async def agent():
)
await websocket.recv()
else:
# Target reached here.
print("TARGET REACHED")
task_type = None
task_args = None
case "INTERACT":
@@ -204,12 +310,23 @@ async def agent():
...
if not task_type:
task_type = random.choice(["GOTO", "PUT", "INTERACT"])
# task_type = random.choice(["GOTO", "PUT", "INTERACT"])
task_type = random.choice(["GOTO"])
threshold = datetime.now() + timedelta(seconds=TIME_TO_STOP_ACTION)
if task_type == "GOTO":
counter_type = random.choice(list(counters.keys()))
task_args = random.choice(counters[counter_type])["pos"]
print(args.player_hash, args.player_id, task_type, counter_type)
# counter_type = random.choice(list(counters.keys()))
# task_args = random.choice(counters[counter_type])["pos"]
random_counter = random.choice(all_counters)
counter_type = random_counter["type"]
task_args = random_counter["pos"]
print(
args.player_hash,
args.player_id,
task_type,
counter_type,
task_args,
)
else:
print(args.player_hash, args.player_id, task_type)
task_args = None
Loading