diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py index c38ad3871c7793e39b9edc5307d011e7f0672b12..624121caeef30e0426c1546b290a08ae86eba162 100644 --- a/overcooked_simulator/overcooked_environment.py +++ b/overcooked_simulator/overcooked_environment.py @@ -239,8 +239,6 @@ class Environment: ) = self.parse_layout_file() self.hook(LAYOUT_FILE_PARSED) - self.counter_positions = np.array([c.pos for c in self.counters]) - self.world_borders = np.array( [[-0.5, self.kitchen_width - 0.5], [-0.5, self.kitchen_height - 0.5]], dtype=float, @@ -250,6 +248,9 @@ class Environment: "player_speed_units_per_seconds" ] self.player_radius = self.environment_config["player_config"]["radius"] + self.player_interaction_range = self.environment_config["player_config"][ + "interaction_range" + ] progress_counter_classes = list( filter( @@ -269,6 +270,8 @@ class Environment: ) """Counters that needs to be called in the step function via the `progress` method.""" + self.counter_positions = np.array([c.pos for c in self.counters]) + self.order_and_score.create_init_orders(self.env_time) self.start_time = self.env_time """The relative env time when it started.""" @@ -290,6 +293,27 @@ class Environment: env_start_time_worldtime=datetime.now(), ) + def overwrite_counters(self, counters): + self.counters = counters + self.counter_positions = np.array([c.pos for c in self.counters]) + + progress_counter_classes = list( + filter( + lambda cl: hasattr(cl, "progress"), + dict( + inspect.getmembers( + sys.modules["overcooked_simulator.counters"], inspect.isclass + ) + ).values(), + ) + ) + self.progressing_counters = list( + filter( + lambda c: c.__class__ in progress_counter_classes, + self.counters, + ) + ) + @property def game_ended(self) -> bool: """Whether the game is over or not based on the calculated `Environment.env_time_end`""" @@ -666,9 +690,21 @@ class Environment: for idx, p in enumerate(self.players.values()): if not (new_positions[idx] == player_positions[idx]).all(): - p.move_abs(new_positions[idx]) + p.pos = new_positions[idx] + p.perform_interact_stop() + p.turn(player_movement_vectors[idx]) + facing_distances = np.linalg.norm( + p.facing_point - self.counter_positions, axis=1 + ) + closest_counter = self.counters[facing_distances.argmin()] + p.current_nearest_counter = ( + closest_counter + if facing_distances.min() <= self.player_interaction_range + else None + ) + def add_player(self, player_name: str, pos: npt.NDArray = None): """Add a player to the environment. diff --git a/overcooked_simulator/player.py b/overcooked_simulator/player.py index 82c2c8cdcb570fe37a609e50715f63387a60eed3..638bc4cd5c0d537aa5823af96543a008b43c3300 100644 --- a/overcooked_simulator/player.py +++ b/overcooked_simulator/player.py @@ -91,19 +91,7 @@ class Player: function of the environment""" self.current_movement = move_vector self.movement_until = move_until - - def move(self, movement: npt.NDArray[float]): - """Moves the player position by the given movement vector. - A unit direction vector multiplied by move_dist is added to the player position. - - Args: - movement: 2D-Vector of length 1 - """ - if self.interacting and np.any(movement): - self.perform_interact_stop() - self.pos += movement - if np.linalg.norm(movement) != 0: - self.turn(movement) + self.perform_interact_stop() def move_abs(self, new_pos: npt.NDArray[float]): """Overwrites the player location by the new_pos 2d-vector. Absolute movement. diff --git a/tests/test_start.py b/tests/test_start.py index 8efe8da9ecf1c67974f88a18b1bcdb0b1cd3786a..8531f9d81dc52cb4327c98e641406ad2d45bdec2 100644 --- a/tests/test_start.py +++ b/tests/test_start.py @@ -124,7 +124,7 @@ def test_player_reach(env_config, layout_empty_config, item_info): counter_pos = np.array([2, 2]) counter = Counter(pos=counter_pos, hook=Hooks(env)) - env.counters = [counter] + env.overwrite_counters([counter]) env.add_player("1", np.array([2, 4])) env.player_movement_speed = 1 player = env.players["1"] @@ -144,7 +144,7 @@ def test_pickup(env_config, layout_config, item_info): counter_pos = np.array([2, 2]) counter = Counter(pos=counter_pos, hook=Hooks(env)) counter.occupied_by = Item(name="Tomato", item_info=None) - env.counters = [counter] + env.overwrite_counters([counter]) env.add_player("1", np.array([2, 3])) player = env.players["1"]