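# NOTE: the words "Newer" and "Older" that stood here were page-navigation labels captured together with this file; they are not part of the program.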
import argparse
import asyncio
import dataclasses
import json
import random
import time
from collections import defaultdict
from datetime import datetime, timedelta

import networkx
import numpy as np
import numpy.typing as npt
from websockets import connect

from cooperative_cuisine.action import ActionType, InterActionData, Action
from cooperative_cuisine.state_representation import (
    create_movement_graph,
    astar_heuristic,
    restrict_movement_graph,
)
from cooperative_cuisine.utils import custom_asdict_factory
# Module-level switches and tuning constants read by agent().

# Number of seconds a task may run before agent() abandons it and picks a new one.
# NOTE(review): this name was referenced but never defined in the file; the
# value 3 is an assumption -- confirm against the intended configuration.
TIME_TO_STOP_ACTION = 3

# When True, agent() rotates each movement vector by a small random angle.
ADD_RANDOM_MOVEMENTS = False

# Passed as ``diagonal=`` to create_movement_graph().
# NOTE(review): this name was referenced but never defined in the file; the
# value True is an assumption -- confirm.
DIAGONAL_MOVEMENTS = True

# When True, agent() plans paths on a graph restricted by the other players' positions.
AVOID_OTHER_PLAYERS = True
def get_free_neighbours(
state: dict, counter_pos: list[float] | tuple[float, float] | npt.NDArray
) -> list[tuple[float, float]]:
width, height = state["kitchen"]["width"], state["kitchen"]["height"]
free_space = np.ones((width, height), dtype=bool)
for counter in state["counters"]:
grid_idx = np.array(counter["pos"]).astype(int)
i, j = np.array(counter_pos).astype(int)
free = []
for x, y in [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]:
if 0 < x < width and 0 < y < height and free_space[x, y]:
free.append((x, y))
return free
async def agent():
parser = argparse.ArgumentParser("Random agent")
parser.add_argument("--uri", type=str)
parser.add_argument("--player_id", type=str)
parser.add_argument("--player_hash", type=str)
parser.add_argument("--step_time", type=float, default=0.1)
args = parser.parse_args()
async with (connect(args.uri) as websocket):
await websocket.send(
json.dumps({"type": "ready", "player_hash": args.player_hash})
)
await websocket.recv()
ended = False
counters = None
player_info = {}
current_agent_pos = None
interaction_counter = None
last_interacting = False
last_interact_progress = None
threshold = datetime.max
task_type = None
task_args = None
started_interaction = False
still_interacting = False
current_nearest_counter_id = None
while not ended:
time.sleep(args.step_time)
await websocket.send(
json.dumps({"type": "get_state", "player_hash": args.player_hash})
)
state = json.loads(await websocket.recv())
if not state["all_players_ready"]:
continue
if movement_graph is None:
movement_graph = create_movement_graph(
state, diagonal=DIAGONAL_MOVEMENTS
)
if counters is None:
counters = defaultdict(list)
for counter in state["counters"]:
counters[counter["type"]].append(counter)
all_counters = state["counters"]
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
for player in state["players"]:
if player["id"] == args.player_id:
player_info = player
current_agent_pos = player["pos"]
if player["current_nearest_counter_id"]:
if (
current_nearest_counter_id
!= player["current_nearest_counter_id"]
):
for counter in state["counters"]:
if (
counter["id"]
== player["current_nearest_counter_id"]
):
interaction_counter = counter
current_nearest_counter_id = player[
"current_nearest_counter_id"
]
break
if last_interacting:
if (
not interaction_counter
or not interaction_counter["occupied_by"]
or isinstance(interaction_counter["occupied_by"], list)
or (
interaction_counter["occupied_by"][
"progress_percentage"
]
== 1.0
)
):
last_interacting = False
last_interact_progress = None
else:
interaction_counter
and interaction_counter["occupied_by"]
and not isinstance(interaction_counter["occupied_by"], list)
if (
last_interact_progress
!= interaction_counter["occupied_by"][
"progress_percentage"
]
):
last_interact_progress = interaction_counter[
"occupied_by"
]["progress_percentage"]
last_interacting = True
break
if task_type:
if threshold < datetime.now():
print(
args.player_hash, args.player_id, "---Threshold---Too long---"
)
task_type = None
match task_type:
case "GOTO":
target_diff = np.array(task_args) - np.array(current_agent_pos)
target_dist = np.linalg.norm(target_diff)
source = tuple(
np.round(np.array(current_agent_pos)).astype(int)
)
target = tuple(np.array(task_args).astype(int))
target_free_spaces = get_free_neighbours(state, target)
paths = []
for free in target_free_spaces:
try:
path = networkx.astar_path(
restrict_movement_graph(
player_positions=[
p["pos"]
for p in state["players"]
if p["id"] != args.player_id
],
)
if AVOID_OTHER_PLAYERS
else movement_graph,
source,
free,
heuristic=astar_heuristic,
)
paths.append(path)
except networkx.exception.NetworkXNoPath:
pass
except networkx.exception.NodeNotFound:
pass
if paths:
shortest_path = paths[np.argmin([len(p) for p in paths])]
if len(shortest_path) > 1:
node_diff = shortest_path[1] - np.array(
current_agent_pos
)
node_dist = np.linalg.norm(node_diff)
movement = node_diff / node_dist
else:
movement = target_diff / target_dist
# task_type = None
# task_args = None
do_movement = False
if target_dist > 1.2 and do_movement:
if ADD_RANDOM_MOVEMENTS:
random_small_rotation_angle = (
np.random.random() * np.pi * 0.1
)
rotation_matrix = np.array(
[
[
np.cos(random_small_rotation_angle),
-np.sin(random_small_rotation_angle),
],
[
np.sin(random_small_rotation_angle),
np.cos(random_small_rotation_angle),
],
]
)
movement = rotation_matrix @ movement
await websocket.send(
json.dumps(
{
"type": "action",
"action": dataclasses.asdict(
Action(
args.player_id,
ActionType.MOVEMENT,
args.step_time + 0.01,
),
dict_factory=custom_asdict_factory,
),
"player_hash": args.player_hash,
}
)
)
await websocket.recv()
else:
task_type = None
task_args = None
case "INTERACT":
if not started_interaction or (
still_interacting and interaction_counter
):
if not started_interaction:
started_interaction = True
still_interacting = True
await websocket.send(
json.dumps(
{
"type": "action",
"action": dataclasses.asdict(
Action(
args.player_id,
ActionType.INTERACT,
InterActionData.START,
),
dict_factory=custom_asdict_factory,
),
"player_hash": args.player_hash,
}
)
)
await websocket.recv()
else:
still_interacting = False
started_interaction = False
task_type = None
task_args = None
case "PUT":
await websocket.send(
json.dumps(
{
"type": "action",
"action": dataclasses.asdict(
Action(
args.player_id,
),
dict_factory=custom_asdict_factory,
),
"player_hash": args.player_hash,
}
)
)
await websocket.recv()
task_type = None
task_args = None
case None:
...
if not task_type:
# task_type = random.choice(["GOTO", "PUT", "INTERACT"])
task_type = random.choice(["GOTO"])
threshold = datetime.now() + timedelta(seconds=TIME_TO_STOP_ACTION)
if task_type == "GOTO":
# counter_type = random.choice(list(counters.keys()))
# task_args = random.choice(counters[counter_type])["pos"]
random_counter = random.choice(all_counters)
counter_type = random_counter["type"]
task_args = random_counter["pos"]
print(
args.player_hash,
args.player_id,
task_type,
counter_type,
task_args,
)
else:
print(args.player_hash, args.player_id, task_type)
task_args = None
ended = state["ended"]
if __name__ == "__main__":
    # Script entry point: create the agent coroutine and run it to completion
    # on a fresh asyncio event loop.
    main_coroutine = agent()
    asyncio.run(main_coroutine)