diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py index 84d6312821e9cb74b1e488d850e2e13f73429c99..b42f2412f81550f45d0883e553e2145d0e415e85 100644 --- a/overcooked_simulator/overcooked_environment.py +++ b/overcooked_simulator/overcooked_environment.py @@ -14,7 +14,9 @@ from typing import Literal import numpy as np import numpy.typing as npt import yaml +from scipy.spatial import distance_matrix +from overcooked_simulator import utils from overcooked_simulator.counter_factory import CounterFactory from overcooked_simulator.counters import ( Counter, @@ -162,6 +164,8 @@ class Environment: self.free_positions, ) = self.parse_layout_file() + self.counter_positions = np.array([c.pos for c in self.counters]) + self.world_borders_x = [-0.5, self.kitchen_width - 0.5] self.world_borders_y = [-0.5, self.kitchen_height - 0.5] @@ -460,12 +464,13 @@ class Environment: Returns: The list of other players the player collides with. """ - other_players = filter(lambda p: p.name != player.name, self.players.values()) - def collide(p): - return np.linalg.norm(player.pos - p.pos) <= player.radius + p.radius + players_list = list(self.players.values()) - return list(filter(collide, other_players)) + if player in players_list: + player_idx = players_list.index(player) + return utils.get_collided_players(player_idx, list(self.players.values())) + return [] def detect_player_collision(self, player: Player): """Detects collisions between the queried player and other players. @@ -478,12 +483,15 @@ class Environment: Returns: True if the player collides with other players, False if not. """ - other_players = filter(lambda p: p.name != player.name, self.players.values()) - def collide(p): - return np.linalg.norm(player.pos - p.pos) <= (player.radius + p.radius) + return any(self.get_collided_players(player)) - return any(map(collide, other_players)) + # other_players = filter(lambda p: p.name != player.name, self.players.values()) + # + # def collide(p): + # return np.linalg.norm(player.pos - p.pos) <= (player.radius + p.radius) + # + # return any(map(collide, other_players)) def detect_collision_counters(self, player: Player): """Checks for collisions of the queried player with each counter. @@ -494,33 +502,7 @@ class Environment: Returns: True if the player collides with any counter, False if not. """ - return any( - map( - lambda counter: self.detect_collision_player_counter(player, counter), - self.counters, - ) - ) - - @staticmethod - def detect_collision_player_counter(player: Player, counter: Counter): - """Checks if the player and counter collide (overlap). - A counter is modelled as a rectangle (square actually), a player is modelled as a circle. - The distance of the player position (circle center) and the counter rectangle is calculated, if it is - smaller than the player radius, a collision is detected. - - Args: - player: The player to check the collision for. - counter: The counter to check the collision for. - - Returns: True if player and counter overlap, False if not. - - """ - cx, cy = player.pos - dx = max(np.abs(cx - counter.pos[0]) - 1 / 2, 0) - dy = max(np.abs(cy - counter.pos[1]) - 1 / 2, 0) - distance = np.linalg.norm([dx, dy]) - # TODO: Efficiency improvement by checking only nearest counters? Quadtree...? - return distance < player.radius + return np.any(np.max((np.abs(self.counter_positions - player.pos)-0.5), axis=1) < player.radius) def add_player(self, player_name: str, pos: npt.NDArray = None): """Add a player to the environment. diff --git a/overcooked_simulator/utils.py b/overcooked_simulator/utils.py index edf11d72e8b4a5cbaa8672d9da727d6e4d3c312f..587feb095ae96e861a6bb8f9c7d00dd43367d606 100644 --- a/overcooked_simulator/utils.py +++ b/overcooked_simulator/utils.py @@ -11,6 +11,7 @@ from scipy.spatial import distance_matrix from overcooked_simulator import ROOT_DIR from overcooked_simulator.counters import Counter +from overcooked_simulator.player import Player def create_init_env_time(): @@ -34,6 +35,15 @@ def get_closest(point: npt.NDArray[float], counters: list[Counter]): np.argmin(distance_matrix([point], [counter.pos for counter in counters])[0]) ] +def get_collided_players(player_idx, players: list[Player]): + + player_positions = np.array([p.pos for p in players], dtype=float) + distances = distance_matrix(player_positions, player_positions)[player_idx] + player_radiuses = np.array([p.radius for p in players], dtype=float) + collisions = distances <= player_radiuses + players[player_idx].radius + collisions[player_idx] = False + + return [players[idx] for idx, val in enumerate(collisions) if val] def custom_asdict_factory(data): def convert_value(obj):