Skip to content
Snippets Groups Projects
Commit f10136c6 authored by fheinrich's avatar fheinrich
Browse files

Fix reference only usage of graph, needed copy.

parent 7e2b9783
No related branches found
No related tags found
1 merge request!89Resolve "simple pathfinding"
...@@ -22,6 +22,8 @@ from cooperative_cuisine.utils import custom_asdict_factory ...@@ -22,6 +22,8 @@ from cooperative_cuisine.utils import custom_asdict_factory
TIME_TO_STOP_ACTION = 3.0 TIME_TO_STOP_ACTION = 3.0
ADD_RANDOM_MOVEMENTS = True
def get_free_neighbours( def get_free_neighbours(
state: dict, counter_pos: list[float] | tuple[float, float] | npt.NDArray state: dict, counter_pos: list[float] | tuple[float, float] | npt.NDArray
...@@ -49,7 +51,7 @@ async def agent(): ...@@ -49,7 +51,7 @@ async def agent():
args = parser.parse_args() args = parser.parse_args()
async with connect(args.uri) as websocket: async with (connect(args.uri) as websocket):
await websocket.send( await websocket.send(
json.dumps({"type": "ready", "player_hash": args.player_hash}) json.dumps({"type": "ready", "player_hash": args.player_hash})
) )
...@@ -58,6 +60,7 @@ async def agent(): ...@@ -58,6 +60,7 @@ async def agent():
ended = False ended = False
counters = None counters = None
all_counters = None
movement_graph = None movement_graph = None
...@@ -89,10 +92,12 @@ async def agent(): ...@@ -89,10 +92,12 @@ async def agent():
if movement_graph is None: if movement_graph is None:
movement_graph = create_movement_graph(state, diagonal=True) movement_graph = create_movement_graph(state, diagonal=True)
if counters is None: if counters is None:
counters = defaultdict(list) counters = defaultdict(list)
for counter in state["counters"]: for counter in state["counters"]:
counters[counter["type"]].append(counter) counters[counter["type"]].append(counter)
all_counters = state["counters"]
for player in state["players"]: for player in state["players"]:
if player["id"] == args.player_id: if player["id"] == args.player_id:
...@@ -166,7 +171,7 @@ async def agent(): ...@@ -166,7 +171,7 @@ async def agent():
for free in target_free_spaces: for free in target_free_spaces:
try: try:
modified_graph = restrict_movement_graph( modified_graph = restrict_movement_graph(
graph=movement_graph, graph=movement_graph.copy(),
player_positions=[ player_positions=[
p["pos"] p["pos"]
for p in state["players"] for p in state["players"]
...@@ -195,29 +200,32 @@ async def agent(): ...@@ -195,29 +200,32 @@ async def agent():
movement = node_diff / node_dist movement = node_diff / node_dist
else: else:
movement = target_diff / target_dist movement = target_diff / target_dist
do_movement = True
else: else:
movement = np.array([0, 0]) # no paths available
task_type = None # task_type = None
task_args = None # task_args = None
do_movement = False
if target_dist > 1.2: if target_dist > 1.2 and do_movement:
if target_dist != 0: if target_dist != 0:
# random_small_rotation_angle = ( if ADD_RANDOM_MOVEMENTS:
# np.random.random() * np.pi * 0.1 random_small_rotation_angle = (
# ) np.random.random() * np.pi * 0.1
# rotation_matrix = np.array( )
# [ rotation_matrix = np.array(
# [ [
# np.cos(random_small_rotation_angle), [
# -np.sin(random_small_rotation_angle), np.cos(random_small_rotation_angle),
# ], -np.sin(random_small_rotation_angle),
# [ ],
# np.sin(random_small_rotation_angle), [
# np.cos(random_small_rotation_angle), np.sin(random_small_rotation_angle),
# ], np.cos(random_small_rotation_angle),
# ] ],
# ) ]
# movement = rotation_matrix @ movement )
movement = rotation_matrix @ movement
await websocket.send( await websocket.send(
json.dumps( json.dumps(
...@@ -238,6 +246,7 @@ async def agent(): ...@@ -238,6 +246,7 @@ async def agent():
) )
await websocket.recv() await websocket.recv()
else: else:
# Target reached here.
task_type = None task_type = None
task_args = None task_args = None
case "INTERACT": case "INTERACT":
...@@ -299,9 +308,13 @@ async def agent(): ...@@ -299,9 +308,13 @@ async def agent():
# task_type = random.choice(["GOTO", "PUT", "INTERACT"]) # task_type = random.choice(["GOTO", "PUT", "INTERACT"])
threshold = datetime.now() + timedelta(seconds=TIME_TO_STOP_ACTION) threshold = datetime.now() + timedelta(seconds=TIME_TO_STOP_ACTION)
if task_type == "GOTO": if task_type == "GOTO":
counter_type = random.choice(list(counters.keys())) # counter_type = random.choice(list(counters.keys()))
task_args = random.choice(counters[counter_type])["pos"] # task_args = random.choice(counters[counter_type])["pos"]
print(args.player_hash, args.player_id, task_type, counter_type)
random_counter = random.choice(all_counters)
counter_type = random_counter["type"]
task_args = random_counter["pos"]
print(args.player_hash, args.player_id, task_type, counter_type, task_args)
else: else:
print(args.player_hash, args.player_id, task_type) print(args.player_hash, args.player_id, task_type)
task_args = None task_args = None
......
...@@ -193,7 +193,8 @@ class StateRepresentation(BaseModel): ...@@ -193,7 +193,8 @@ class StateRepresentation(BaseModel):
def astar_heuristic(x, y): def astar_heuristic(x, y):
return np.linalg.norm(np.array(list(x)) - np.array((y))) """Heuristic distance function used in astart algorithm."""
return np.linalg.norm(np.array(x) - np.array(y))
def create_movement_graph(state: StateRepresentation, diagonal=True) -> Graph: def create_movement_graph(state: StateRepresentation, diagonal=True) -> Graph:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment