Skip to content
Snippets Groups Projects
Commit 2834b4e7 authored by Florian Schröder's avatar Florian Schröder
Browse files

Refactor random_agent to use asyncio and correct counter handling

The `random_agent.py` has been refactored to use asyncio. This modification allows for more efficient handling of asynchronous operations. Additionally, a fix has been implemented to correctly identify and handle the "current nearest counter ID", making sure that the agent continues performing its tasks in a correct manner.
parent 450fcefb
No related branches found
No related tags found
1 merge request!44Resolve "GUI Player Management"
Pipeline #45537 failed
import argparse
import asyncio
import dataclasses
import json
import random
......@@ -7,7 +8,7 @@ from collections import defaultdict
from datetime import datetime, timedelta
import numpy as np
from websockets.sync.client import connect
from websockets import connect
from overcooked_simulator.overcooked_environment import (
ActionType,
......@@ -16,109 +17,156 @@ from overcooked_simulator.overcooked_environment import (
)
from overcooked_simulator.utils import custom_asdict_factory
parser = argparse.ArgumentParser("Random agent")
parser.add_argument("--uri", type=str)
parser.add_argument("--player_id", type=str)
parser.add_argument("--player_hash", type=str)
parser.add_argument("--step_time", type=float, default=0.1)
args = parser.parse_args()
async def agent():
parser = argparse.ArgumentParser("Random agent")
parser.add_argument("--uri", type=str)
parser.add_argument("--player_id", type=str)
parser.add_argument("--player_hash", type=str)
parser.add_argument("--step_time", type=float, default=0.1)
websocket = connect(args.uri)
try:
websocket.send(json.dumps({"type": "ready", "player_hash": args.player_hash}))
websocket.recv()
args = parser.parse_args()
ended = False
async with connect(args.uri) as websocket:
await websocket.send(
json.dumps({"type": "ready", "player_hash": args.player_hash})
)
await websocket.recv()
counters = None
ended = False
player_info = {}
current_agent_pos = None
interaction_counter = None
counters = None
last_interacting = False
last_interact_progress = None
player_info = {}
current_agent_pos = None
interaction_counter = None
threshold = datetime.max
last_interacting = False
last_interact_progress = None
task_type = None
task_args = None
threshold = datetime.max
started_interaction = False
still_interacting = False
task_type = None
task_args = None
while not ended:
time.sleep(args.step_time)
websocket.send(
json.dumps({"type": "get_state", "player_hash": args.player_hash})
)
state = json.loads(websocket.recv())
started_interaction = False
still_interacting = False
current_nearest_counter_id = None
if counters is None:
counters = defaultdict(list)
for counter in state["counters"]:
counters[counter["type"]].append(counter)
while not ended:
time.sleep(args.step_time)
await websocket.send(
json.dumps({"type": "get_state", "player_hash": args.player_hash})
)
state = json.loads(await websocket.recv())
interaction_counter = None
for player in state["players"]:
if player["id"] == args.player_id:
player_info = player
current_agent_pos = player["pos"]
if player["current_nearest_counter_id"]:
for counter in state["counters"]:
if counter["id"] == player["current_nearest_counter_id"]:
interaction_counter = counter
break
if last_interacting:
if (
not interaction_counter
or not interaction_counter["occupied_by"]
or isinstance(interaction_counter["occupied_by"], list)
or (
interaction_counter["occupied_by"]["progress_percentage"]
== 1.0
)
):
last_interacting = False
last_interact_progress = None
else:
if (
interaction_counter
and interaction_counter["occupied_by"]
and not isinstance(interaction_counter["occupied_by"], list)
):
if counters is None:
counters = defaultdict(list)
for counter in state["counters"]:
counters[counter["type"]].append(counter)
for player in state["players"]:
if player["id"] == args.player_id:
player_info = player
current_agent_pos = player["pos"]
if player["current_nearest_counter_id"]:
if (
current_nearest_counter_id
!= player["current_nearest_counter_id"]
):
for counter in state["counters"]:
if (
counter["id"]
== player["current_nearest_counter_id"]
):
interaction_counter = counter
current_nearest_counter_id = player[
"current_nearest_counter_id"
]
break
if last_interacting:
if (
not interaction_counter
or not interaction_counter["occupied_by"]
or isinstance(interaction_counter["occupied_by"], list)
or (
interaction_counter["occupied_by"][
"progress_percentage"
]
== 1.0
)
):
last_interacting = False
last_interact_progress = None
else:
if (
last_interact_progress
!= interaction_counter["occupied_by"]["progress_percentage"]
interaction_counter
and interaction_counter["occupied_by"]
and not isinstance(interaction_counter["occupied_by"], list)
):
last_interact_progress = interaction_counter["occupied_by"][
"progress_percentage"
]
last_interacting = True
break
if task_type:
if threshold < datetime.now():
print(args.player_hash, args.player_id, "---Threshold---Too long---")
task_type = None
match task_type:
case "GOTO":
diff = np.array(task_args) - np.array(current_agent_pos)
dist = np.linalg.norm(diff)
if dist > 1.2:
if dist != 0:
websocket.send(
if (
last_interact_progress
!= interaction_counter["occupied_by"][
"progress_percentage"
]
):
last_interact_progress = interaction_counter[
"occupied_by"
]["progress_percentage"]
last_interacting = True
break
if task_type:
if threshold < datetime.now():
print(
args.player_hash, args.player_id, "---Threshold---Too long---"
)
task_type = None
match task_type:
case "GOTO":
diff = np.array(task_args) - np.array(current_agent_pos)
dist = np.linalg.norm(diff)
if dist > 1.2:
if dist != 0:
await websocket.send(
json.dumps(
{
"type": "action",
"action": dataclasses.asdict(
Action(
args.player_id,
ActionType.MOVEMENT,
(diff / dist).tolist(),
args.step_time + 0.01,
),
dict_factory=custom_asdict_factory,
),
"player_hash": args.player_hash,
}
)
)
await websocket.recv()
else:
task_type = None
task_args = None
case "INTERACT":
if not started_interaction or (
still_interacting and interaction_counter
):
if not started_interaction:
started_interaction = True
still_interacting = True
await websocket.send(
json.dumps(
{
"type": "action",
"action": dataclasses.asdict(
Action(
args.player_id,
ActionType.MOVEMENT,
(diff / dist).tolist(),
args.step_time + 0.01,
ActionType.INTERACT,
InterActionData.START,
),
dict_factory=custom_asdict_factory,
),
......@@ -126,27 +174,22 @@ try:
}
)
)
websocket.recv()
else:
task_type = None
task_args = None
case "INTERACT":
if not started_interaction or (
still_interacting and interaction_counter
):
if not started_interaction:
started_interaction = True
still_interacting = True
websocket.send(
await websocket.recv()
else:
still_interacting = False
started_interaction = False
task_type = None
task_args = None
case "PUT":
await websocket.send(
json.dumps(
{
"type": "action",
"action": dataclasses.asdict(
Action(
args.player_id,
ActionType.INTERACT,
InterActionData.START,
ActionType.PUT,
"pickup",
),
dict_factory=custom_asdict_factory,
),
......@@ -154,46 +197,25 @@ try:
}
)
)
websocket.recv()
else:
still_interacting = False
started_interaction = False
await websocket.recv()
task_type = None
task_args = None
case "PUT":
websocket.send(
json.dumps(
{
"type": "action",
"action": dataclasses.asdict(
Action(
args.player_id,
ActionType.PUT,
"pickup",
),
dict_factory=custom_asdict_factory,
),
"player_hash": args.player_hash,
}
)
)
websocket.recv()
task_type = None
case None:
...
if not task_type:
task_type = random.choice(["GOTO", "PUT", "INTERACT"])
threshold = datetime.now() + timedelta(seconds=15.0)
if task_type == "GOTO":
counter_type = random.choice(list(counters.keys()))
task_args = random.choice(counters[counter_type])["pos"]
print(args.player_hash, args.player_id, task_type, counter_type)
else:
print(args.player_hash, args.player_id, task_type)
task_args = None
case None:
...
if not task_type:
task_type = random.choice(["GOTO", "PUT", "INTERACT"])
threshold = datetime.now() + timedelta(seconds=15.0)
if task_type == "GOTO":
counter_type = random.choice(list(counters.keys()))
task_args = random.choice(counters[counter_type])["pos"]
print(args.player_hash, args.player_id, task_type, counter_type)
else:
print(args.player_hash, args.player_id, task_type)
task_args = None
ended = state["ended"]
finally:
websocket.close()
ended = state["ended"]
if __name__ == "__main__":
asyncio.run(agent())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment