diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py
index 84d6312821e9cb74b1e488d850e2e13f73429c99..b42f2412f81550f45d0883e553e2145d0e415e85 100644
--- a/overcooked_simulator/overcooked_environment.py
+++ b/overcooked_simulator/overcooked_environment.py
@@ -14,7 +14,9 @@ from typing import Literal
 import numpy as np
 import numpy.typing as npt
 import yaml
+from scipy.spatial import distance_matrix
 
+from overcooked_simulator import utils
 from overcooked_simulator.counter_factory import CounterFactory
 from overcooked_simulator.counters import (
     Counter,
@@ -162,6 +164,8 @@ class Environment:
             self.free_positions,
         ) = self.parse_layout_file()
 
+        self.counter_positions = np.array([c.pos for c in self.counters])
+
         self.world_borders_x = [-0.5, self.kitchen_width - 0.5]
         self.world_borders_y = [-0.5, self.kitchen_height - 0.5]
 
@@ -460,12 +464,13 @@ class Environment:
         Returns: The list of other players the player collides with.
 
         """
-        other_players = filter(lambda p: p.name != player.name, self.players.values())
 
-        def collide(p):
-            return np.linalg.norm(player.pos - p.pos) <= player.radius + p.radius
+        players_list = list(self.players.values())
 
-        return list(filter(collide, other_players))
+        if player in players_list:
+            player_idx = players_list.index(player)
+            return utils.get_collided_players(player_idx, list(self.players.values()))
+        return []
 
     def detect_player_collision(self, player: Player):
         """Detects collisions between the queried player and other players.
@@ -478,12 +483,15 @@ class Environment:
         Returns: True if the player collides with other players, False if not.
 
         """
-        other_players = filter(lambda p: p.name != player.name, self.players.values())
 
-        def collide(p):
-            return np.linalg.norm(player.pos - p.pos) <= (player.radius + p.radius)
+        return any(self.get_collided_players(player))
 
-        return any(map(collide, other_players))
+        # other_players = filter(lambda p: p.name != player.name, self.players.values())
+        #
+        # def collide(p):
+        #     return np.linalg.norm(player.pos - p.pos) <= (player.radius + p.radius)
+        #
+        # return any(map(collide, other_players))
 
     def detect_collision_counters(self, player: Player):
         """Checks for collisions of the queried player with each counter.
@@ -494,33 +502,7 @@ class Environment:
         Returns: True if the player collides with any counter, False if not.
 
         """
-        return any(
-            map(
-                lambda counter: self.detect_collision_player_counter(player, counter),
-                self.counters,
-            )
-        )
-
-    @staticmethod
-    def detect_collision_player_counter(player: Player, counter: Counter):
-        """Checks if the player and counter collide (overlap).
-        A counter is modelled as a rectangle (square actually), a player is modelled as a circle.
-        The distance of the player position (circle center) and the counter rectangle is calculated, if it is
-        smaller than the player radius, a collision is detected.
-
-        Args:
-            player: The player to check the collision for.
-            counter: The counter to check the collision for.
-
-        Returns: True if player and counter overlap, False if not.
-
-        """
-        cx, cy = player.pos
-        dx = max(np.abs(cx - counter.pos[0]) - 1 / 2, 0)
-        dy = max(np.abs(cy - counter.pos[1]) - 1 / 2, 0)
-        distance = np.linalg.norm([dx, dy])
-        # TODO: Efficiency improvement by checking only nearest counters? Quadtree...?
-        return distance < player.radius
+        return np.any(np.max((np.abs(self.counter_positions - player.pos)-0.5), axis=1) < player.radius)
 
     def add_player(self, player_name: str, pos: npt.NDArray = None):
         """Add a player to the environment.
diff --git a/overcooked_simulator/utils.py b/overcooked_simulator/utils.py
index edf11d72e8b4a5cbaa8672d9da727d6e4d3c312f..587feb095ae96e861a6bb8f9c7d00dd43367d606 100644
--- a/overcooked_simulator/utils.py
+++ b/overcooked_simulator/utils.py
@@ -11,6 +11,7 @@ from scipy.spatial import distance_matrix
 
 from overcooked_simulator import ROOT_DIR
 from overcooked_simulator.counters import Counter
+from overcooked_simulator.player import Player
 
 
 def create_init_env_time():
@@ -34,6 +35,15 @@ def get_closest(point: npt.NDArray[float], counters: list[Counter]):
         np.argmin(distance_matrix([point], [counter.pos for counter in counters])[0])
     ]
 
+def get_collided_players(player_idx, players: list[Player]):
+
+    player_positions = np.array([p.pos for p in players], dtype=float)
+    distances = distance_matrix(player_positions, player_positions)[player_idx]
+    player_radiuses = np.array([p.radius for p in players], dtype=float)
+    collisions = distances <= player_radiuses + players[player_idx].radius
+    collisions[player_idx] = False
+
+    return [players[idx] for idx, val in enumerate(collisions) if val]
 
 def custom_asdict_factory(data):
     def convert_value(obj):