diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py index b42f2412f81550f45d0883e553e2145d0e415e85..03f79c0ace56d6d34c0894ac7638f123c9567f0b 100644 --- a/overcooked_simulator/overcooked_environment.py +++ b/overcooked_simulator/overcooked_environment.py @@ -14,7 +14,6 @@ from typing import Literal import numpy as np import numpy.typing as npt import yaml -from scipy.spatial import distance_matrix from overcooked_simulator import utils from overcooked_simulator.counter_factory import CounterFactory @@ -166,8 +165,12 @@ class Environment: self.counter_positions = np.array([c.pos for c in self.counters]) - self.world_borders_x = [-0.5, self.kitchen_width - 0.5] - self.world_borders_y = [-0.5, self.kitchen_height - 0.5] + # self.world_borders_x = [-0.5, self.kitchen_width - 0.5] + # self.world_borders_y = [-0.5, self.kitchen_height - 0.5] + self.world_borders = np.array( + [[-0.5, self.kitchen_width - 0.5], [-0.5, self.kitchen_height - 0.5]], + dtype=float, + ) progress_counter_classes = list( filter( @@ -397,8 +400,13 @@ class Environment: step = move_vector * (player.player_speed_units_per_seconds * d_time) player.move(step) - if self.detect_collision(player): - collided_players = self.get_collided_players(player) + + world_collision = self.detect_collision_world_bounds( + player + ) or self.detect_collision_counters(player) + collided_players = self.get_collided_players(player) + + if world_collision or len(collided_players) > 0: for collided_player in collided_players: pushing_vector = collided_player.pos - player.pos if np.linalg.norm(pushing_vector) != 0: @@ -411,6 +419,7 @@ class Environment: collided_player ) or self.detect_collision_world_bounds(collided_player): collided_player.move_abs(old_pos_other) + player.move_abs(old_pos) old_pos = player.pos.copy() @@ -448,7 +457,7 @@ class Environment: Returns: True if the player is intersecting with any object in the environment. """ return ( - len(self.get_collided_players(player)) != 0 + len(self.get_collided_players(player)) > 0 or self.detect_collision_counters(player) or self.detect_collision_world_bounds(player) ) @@ -486,13 +495,6 @@ class Environment: return any(self.get_collided_players(player)) - # other_players = filter(lambda p: p.name != player.name, self.players.values()) - # - # def collide(p): - # return np.linalg.norm(player.pos - p.pos) <= (player.radius + p.radius) - # - # return any(map(collide, other_players)) - def detect_collision_counters(self, player: Player): """Checks for collisions of the queried player with each counter. @@ -502,7 +504,23 @@ class Environment: Returns: True if the player collides with any counter, False if not. """ - return np.any(np.max((np.abs(self.counter_positions - player.pos)-0.5), axis=1) < player.radius) + return np.any( + np.max((np.abs(self.counter_positions - player.pos) - 0.5), axis=1) + < player.radius + ) + + def detect_collision_world_bounds(self, player: Player): + """Checks for detections of the player and the world bounds. + + Args: + player: The player which to not let escape the world. + + Returns: True if the player touches the world bounds, False if not. + """ + + return np.any(player.pos - player.radius < self.world_borders[:, 0]) or np.any( + player.pos + player.radius > self.world_borders[:, 1] + ) def add_player(self, player_name: str, pos: npt.NDArray = None): """Add a player to the environment. @@ -538,24 +556,6 @@ class Environment: log.debug("No free positions left in kitchens") player.update_facing_point() - def detect_collision_world_bounds(self, player: Player): - """Checks for detections of the player and the world bounds. - - Args: - player: The player which to not let escape the world. - - Returns: True if the player touches the world bounds, False if not. - """ - collisions_lower = any( - (player.pos - (player.radius)) - < [self.world_borders_x[0], self.world_borders_y[0]] - ) - collisions_upper = any( - (player.pos + (player.radius)) - > [self.world_borders_x[1], self.world_borders_y[1]] - ) - return collisions_lower or collisions_upper - def step(self, passed_time: timedelta): """Performs a step of the environment. Affects time based events such as cooking or cutting things, orders and time limits.