From 5987536f51978558f17d89f021d2e6d8833c8f1f Mon Sep 17 00:00:00 2001
From: fheinrich <fheinrich@techfak.uni-bielefeld.de>
Date: Tue, 6 Feb 2024 17:43:27 +0100
Subject: [PATCH] Flag for disabling squeezing into other players.

---
 .../overcooked_environment.py                 | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py
index 5860fd0b..1132b52b 100644
--- a/overcooked_simulator/overcooked_environment.py
+++ b/overcooked_simulator/overcooked_environment.py
@@ -33,6 +33,9 @@ from overcooked_simulator.utils import create_init_env_time, get_closest
 log = logging.getLogger(__name__)
 
 
+PREVENT_SQUEEZING_INTO_OTHER_PLAYERS = False
+
+
 class ActionType(Enum):
     """The 3 different types of valid actions. They can be extended via the `Action.action_data` attribute."""
 
@@ -485,12 +488,16 @@ class Environment:
                 new_positions[collision_idxs[idx]] = player_positions[
                     collision_idxs[idx]
                 ]
-        # # Check if two moving players collide into each other: No movement (Future: slide?)
-        # distances_players_after_scipy = distance_matrix(new_positions, new_positions)
-        # collision_idxs = distances_players_after_scipy < (2 * self.player_radius)
-        # collision_idxs[eye_idxs] = False
-        # collision_idxs = np.any(collision_idxs, axis=1)
-        # new_positions[collision_idxs] = player_positions[collision_idxs]
+
+        # Check if two moving players collide into each other: No movement (Future: slide?)
+        if PREVENT_SQUEEZING_INTO_OTHER_PLAYERS:
+            distances_players_after_scipy = distance_matrix(
+                new_positions, new_positions
+            )
+            collision_idxs = distances_players_after_scipy < (2 * self.player_radius)
+            collision_idxs[eye_idxs] = False
+            collision_idxs = np.any(collision_idxs, axis=1)
+            new_positions[collision_idxs] = player_positions[collision_idxs]
 
         # Collisions player world borders
         new_positions = np.clip(
-- 
GitLab