From afc984aeb34b3c8f433ac7ee3e34b96e114e4e08 Mon Sep 17 00:00:00 2001
From: fheinrich <fheinrich@techfak.de>
Date: Fri, 2 Feb 2024 15:03:56 +0100
Subject: [PATCH] Collision optimization

---
 .../overcooked_environment.py                 | 91 +++++++++----------
 1 file changed, 41 insertions(+), 50 deletions(-)

diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py
index 03f79c0a..8fdb399d 100644
--- a/overcooked_simulator/overcooked_environment.py
+++ b/overcooked_simulator/overcooked_environment.py
@@ -14,6 +14,7 @@ from typing import Literal
 import numpy as np
 import numpy.typing as npt
 import yaml
+from scipy.spatial import distance_matrix
 
 from overcooked_simulator import utils
 from overcooked_simulator.counter_factory import CounterFactory
@@ -165,13 +166,15 @@ class Environment:
 
         self.counter_positions = np.array([c.pos for c in self.counters])
 
-        # self.world_borders_x = [-0.5, self.kitchen_width - 0.5]
-        # self.world_borders_y = [-0.5, self.kitchen_height - 0.5]
         self.world_borders = np.array(
             [[-0.5, self.kitchen_width - 0.5], [-0.5, self.kitchen_height - 0.5]],
             dtype=float,
         )
 
+
+        self.player_movement_speed = self.environment_config["player_config"]["player_speed_units_per_seconds"]
+        self.player_radius = self.environment_config["player_config"]["radius"]
+
         progress_counter_classes = list(
             filter(
                 lambda cl: hasattr(cl, "progress"),
@@ -207,6 +210,11 @@ class Environment:
         """Whether the game is over or not based on the calculated `Environment.env_time_end`"""
         return self.env_time >= self.env_time_end
 
+    def set_collision_arrays(self):
+        number_players = len(self.players)
+        self.world_borders_lower = self.world_borders[np.newaxis, :, 0].repeat(number_players, axis=0)
+        self.world_borders_upper = self.world_borders[np.newaxis, :, 1].repeat(number_players, axis=0)
+
     def get_env_time(self):
         """the internal time of the environment. An environment starts always with the time from `create_init_env_time`.
 
@@ -377,7 +385,7 @@ class Environment:
         facing_counter = get_closest(player.facing_point, self.counters)
         return facing_counter
 
-    def perform_movement(self, player: Player, duration: timedelta):
+    def perform_movement(self, duration: timedelta):
         """Moves a player in the direction specified in the action.action. If the player collides with a
         counter or other player through this movement, then they are not moved.
         (The extended code with the two ifs is for sliding movement at the counters, which feels a bit smoother.
@@ -392,61 +400,44 @@ class Environment:
             player: The player to move.
             duration: The duration for how long the movement to perform.
         """
-        old_pos = player.pos.copy()
-
-        move_vector = player.current_movement
-
         d_time = duration.total_seconds()
-        step = move_vector * (player.player_speed_units_per_seconds * d_time)
 
-        player.move(step)
+        player_positions = np.array([p.pos for p in self.players.values()], dtype=float)
+        player_movement_vectors = np.array([p.current_movement if self.env_time <= p.movement_until else [0, 0] for p in self.players.values()], dtype=float)
 
-        world_collision = self.detect_collision_world_bounds(
-            player
-        ) or self.detect_collision_counters(player)
-        collided_players = self.get_collided_players(player)
+        new_positions = player_positions + (player_movement_vectors * (self.player_movement_speed * d_time))
 
-        if world_collision or len(collided_players) > 0:
-            for collided_player in collided_players:
-                pushing_vector = collided_player.pos - player.pos
-                if np.linalg.norm(pushing_vector) != 0:
-                    pushing_vector = pushing_vector / np.linalg.norm(pushing_vector)
+        # Collisions player player
+        distances_players_after_scipy = distance_matrix(new_positions, new_positions)
 
-                old_pos_other = collided_player.pos.copy()
-                collided_player.current_movement = pushing_vector
-                self.perform_movement(collided_player, duration)
-                if self.detect_collision_counters(
-                    collided_player
-                ) or self.detect_collision_world_bounds(collided_player):
-                    collided_player.move_abs(old_pos_other)
+        player_diff_vecs = -(player_positions[:, np.newaxis, :] - player_positions[np.newaxis, :, :])
 
-            player.move_abs(old_pos)
+        collision_idxs = distances_players_after_scipy < (2*self.player_radius)
+        eye_idxs = np.eye(distances_players_after_scipy.shape[0], distances_players_after_scipy.shape[1], dtype=bool)
+        collision_idxs[eye_idxs] = False
+        # collisions_any = np.any(collision_idxs, axis=1)
 
-            old_pos = player.pos.copy()
+        # Player push players around
+        player_diff_vecs[collision_idxs==False] = 0
+        push_vectors = np.sum(player_diff_vecs, axis=0)
+        # new_positions[collisions_any] = player_positions[collisions_any]
+        new_positions += push_vectors * (self.player_movement_speed * d_time)
 
-            step_sliding = step.copy()
-            step_sliding[0] = 0
-            player.move(step_sliding * 0.5)
-            player.turn(step)
+        # Collisions player world borders
+        new_positions = np.max([new_positions, self.world_borders_lower+self.player_radius], axis=0)
+        new_positions = np.min([new_positions, self.world_borders_upper-self.player_radius], axis=0)
 
-            if self.detect_collision(player):
-                player.move_abs(old_pos)
+        # Collisions players counters
+        counter_diff_vecs = (new_positions[:, np.newaxis, :] - self.counter_positions[np.newaxis, :, :])
+        counter_distances = np.max((np.abs(counter_diff_vecs)), axis=2)
+        min_counter_distances = np.min(counter_distances, axis=1) - 0.5
+        new_positions[min_counter_distances < self.player_radius] = player_positions[min_counter_distances < self.player_radius]
 
-                old_pos = player.pos.copy()
+        counter_distances_axes = np.max((np.abs(counter_diff_vecs)), axis=1)
 
-                step_sliding = step.copy()
-                step_sliding[1] = 0
-                player.move(step_sliding * 0.5)
-                player.turn(step)
-
-                if self.detect_collision(player):
-                    player.move_abs(old_pos)
-
-        if self.counters:
-            closest_counter = self.get_facing_counter(player)
-            player.current_nearest_counter = (
-                closest_counter if player.can_reach(closest_counter) else None
-            )
+        for idx, p in enumerate(self.players.values()):
+            p.turn(player_movement_vectors[idx])
+            p.move_abs(new_positions[idx])
 
     def detect_collision(self, player: Player):
         """Detect collisions between the player and other players or counters.
@@ -556,6 +547,8 @@ class Environment:
                 log.debug("No free positions left in kitchens")
             player.update_facing_point()
 
+        self.set_collision_arrays()
+
     def step(self, passed_time: timedelta):
         """Performs a step of the environment. Affects time based events such as cooking or cutting things, orders
         and time limits.
@@ -563,9 +556,7 @@ class Environment:
         self.env_time += passed_time
 
         if not self.game_ended:
-            for player in self.players.values():
-                if self.env_time <= player.movement_until:
-                    self.perform_movement(player, passed_time)
+            self.perform_movement(passed_time)
 
             for counter in self.progressing_counters:
                 counter.progress(passed_time=passed_time, now=self.env_time)
-- 
GitLab