From 1ced66ccefe7105581cae9bc0485ec523db7f963 Mon Sep 17 00:00:00 2001
From: fheinrich <fheinrich@techfak.uni-bielefeld.de>
Date: Mon, 18 Mar 2024 11:10:43 +0100
Subject: [PATCH] Adjusted copying of graph in restriction

---
 cooperative_cuisine/configs/agents/random_agent.py |  8 ++++++--
 cooperative_cuisine/state_representation.py        | 11 ++++++-----
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/cooperative_cuisine/configs/agents/random_agent.py b/cooperative_cuisine/configs/agents/random_agent.py
index 99f8a238..e1db06bf 100644
--- a/cooperative_cuisine/configs/agents/random_agent.py
+++ b/cooperative_cuisine/configs/agents/random_agent.py
@@ -175,7 +175,7 @@ async def agent():
                             try:
                                 path = networkx.astar_path(
                                     restrict_movement_graph(
-                                        graph=movement_graph.copy(),
+                                        graph=movement_graph,
                                         player_positions=[
                                             p["pos"]
                                             for p in state["players"]
@@ -207,6 +207,8 @@ async def agent():
                             do_movement = True
                         else:
                             # no paths available
+                            print("NO PATHS")
+
                             # task_type = None
                             # task_args = None
                             do_movement = False
@@ -251,6 +253,7 @@ async def agent():
                                 await websocket.recv()
                         else:
                             # Target reached here.
+                            print("TARGET REACHED")
                             task_type = None
                             task_args = None
                     case "INTERACT":
@@ -307,7 +310,8 @@ async def agent():
                         ...
 
             if not task_type:
-                task_type = random.choice(["GOTO", "PUT", "INTERACT"])
+                # task_type = random.choice(["GOTO", "PUT", "INTERACT"])
+                task_type = random.choice(["GOTO"])
                 threshold = datetime.now() + timedelta(seconds=TIME_TO_STOP_ACTION)
                 if task_type == "GOTO":
                     # counter_type = random.choice(list(counters.keys()))
diff --git a/cooperative_cuisine/state_representation.py b/cooperative_cuisine/state_representation.py
index 7887d0f8..209bf7d9 100644
--- a/cooperative_cuisine/state_representation.py
+++ b/cooperative_cuisine/state_representation.py
@@ -222,7 +222,7 @@ def create_movement_graph(state: StateRepresentation, diagonal=True) -> Graph:
                         for dj in range(-1, 2):
                             x, y = i + di, j + dj
                             if (
-                                0 < x < width
+                                0 <= x < width
                                 and 0 < y < height
                                 and free_space[x, y]
                                 and (di, dj) != (0, 0)
@@ -244,7 +244,7 @@ def create_movement_graph(state: StateRepresentation, diagonal=True) -> Graph:
                                     )
                 else:
                     for x, y in [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]:
-                        if 0 < x < width and 0 < y < height and free_space[x, y]:
+                        if 0 <= x < width and 0 <= y < height and free_space[x, y]:
                             graph.add_edge(
                                 (i, j),
                                 (x, y),
@@ -266,11 +266,12 @@ def restrict_movement_graph(
     Returns: The modified graph without nodes where players stand.
 
     """
+    copied = graph.copy()
     for pos in player_positions:
         tup = tuple(np.array(pos).round().astype(int))
-        if tup in graph.nodes.keys():
-            graph.remove_node(tup)
-    return graph
+        if tup in copied.nodes.keys():
+            copied.remove_node(tup)
+    return copied
 
 
 def create_json_schema() -> dict[str, Any]:
-- 
GitLab