diff --git a/cooperative_cuisine/configs/agents/random_agent.py b/cooperative_cuisine/configs/agents/random_agent.py index 99f8a2384625318f6293b4ab21a0c706615b45a0..e1db06bfb20f925815b81016a2afb49a8a0377e6 100644 --- a/cooperative_cuisine/configs/agents/random_agent.py +++ b/cooperative_cuisine/configs/agents/random_agent.py @@ -175,7 +175,7 @@ async def agent(): try: path = networkx.astar_path( restrict_movement_graph( - graph=movement_graph.copy(), + graph=movement_graph, player_positions=[ p["pos"] for p in state["players"] @@ -207,6 +207,8 @@ async def agent(): do_movement = True else: # no paths available + print("NO PATHS") + # task_type = None # task_args = None do_movement = False @@ -251,6 +253,7 @@ async def agent(): await websocket.recv() else: # Target reached here. + print("TARGET REACHED") task_type = None task_args = None case "INTERACT": @@ -307,7 +310,8 @@ async def agent(): ... if not task_type: - task_type = random.choice(["GOTO", "PUT", "INTERACT"]) + # task_type = random.choice(["GOTO", "PUT", "INTERACT"]) + task_type = random.choice(["GOTO"]) threshold = datetime.now() + timedelta(seconds=TIME_TO_STOP_ACTION) if task_type == "GOTO": # counter_type = random.choice(list(counters.keys())) diff --git a/cooperative_cuisine/state_representation.py b/cooperative_cuisine/state_representation.py index 7887d0f8a7d542bded6fb91ddaf1c7632843c8ee..209bf7d98ad3157f093071b23a479fd89563a637 100644 --- a/cooperative_cuisine/state_representation.py +++ b/cooperative_cuisine/state_representation.py @@ -222,7 +222,7 @@ def create_movement_graph(state: StateRepresentation, diagonal=True) -> Graph: for dj in range(-1, 2): x, y = i + di, j + dj if ( - 0 < x < width + 0 <= x < width and 0 < y < height and free_space[x, y] and (di, dj) != (0, 0) @@ -244,7 +244,7 @@ def create_movement_graph(state: StateRepresentation, diagonal=True) -> Graph: ) else: for x, y in [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]: - if 0 < x < width and 0 < y < height and free_space[x, y]: + if 0 <= x < width and 0 <= y < height and free_space[x, y]: graph.add_edge( (i, j), (x, y), @@ -266,11 +266,12 @@ def restrict_movement_graph( Returns: The modified graph without nodes where players stand. """ + copied = graph.copy() for pos in player_positions: tup = tuple(np.array(pos).round().astype(int)) - if tup in graph.nodes.keys(): - graph.remove_node(tup) - return graph + if tup in copied.nodes.keys(): + copied.remove_node(tup) + return copied def create_json_schema() -> dict[str, Any]: