# overcooked_environment.py
from __future__ import annotations
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from overcooked_simulator.player import Player
from pathlib import Path
import numpy as np
from scipy.spatial import distance_matrix
from overcooked_simulator.counters import Counter


class Action:
    """A single player input: who acts, what kind of action it is, and its payload.

    Attributes:
        player: Identifier of the acting player (the environment looks it up in its ``players`` dict).
        act_type: Kind of action; must be one of "movement", "pickup" or "interact".
        action: Action-specific payload, e.g. the movement vector for a "movement" action.
    """
    def __init__(self, player, act_type, action):
        known_types = ("movement", "pickup", "interact")
        self.player, self.act_type, self.action = player, act_type, action
        # NOTE: validated with assert, so the check disappears when Python runs with -O.
        assert act_type in known_types, "Unknown action type"


class Environment:
    """Environment class which handles the game logic for the overcooked-inspired environment.

    Handles player movement, collision-detection, counters, cooking processes, recipes, incoming orders, time.
    Counters are placed on a square grid read from a layout file. For collision purposes players are modelled as
    circles and counters as axis-aligned squares.

    TODO: Abstract base class for different environments.
    """

    def __init__(
        self,
        layout_path,
        *,
        counter_side_length: float = 40,
        world_width: int = 800,
        world_height: int = 600,
    ):
        """Creates the environment and builds its counters from the given layout file.

        Args:
            layout_path: Path to the layout file describing the counter grid.
            counter_side_length: Side length of a single (square) counter in world units.
            world_width: Width of the world; players collide with the bounds at x=0 and x=world_width.
            world_height: Height of the world; players collide with the bounds at y=0 and y=world_height.
        """
        self.players: dict[str, Player] = {}
        # Must be set before create_counters(), which uses it to place the counters on the grid.
        self.counter_side_length: float = counter_side_length
        self.layout_path: Path = layout_path
        self.counters: list[Counter] = self.create_counters(self.layout_path)
        self.score: int = 0
        self.world_width: int = world_width
        self.world_height: int = world_height

    def create_counters(self, layout_file: Path) -> list[Counter]:
        """Creates layout of kitchen counters in the environment based on layout file.

        Counters are arranged in a fixed size grid starting at [0,0]. The center of the first counter is at
        [counter_side_length/2, counter_side_length/2]. Counters are directly next to each other (if no empty
        space is specified in the layout).

        Each line of the file is one grid row; every line, including an empty one, advances one row.
        Within a row, "C"/"c" places a counter and "E"/"e" leaves an empty cell. All other characters,
        including whitespace and newlines, are ignored and do not advance the column.

        Args:
            layout_file: Path to the layout file.

        Returns:
            The list of created counters, in reading order (row by row, left to right).
        """
        with open(layout_file, "r", encoding="utf-8") as file:
            rows = file.readlines()

        counters: list[Counter] = []
        current_y = self.counter_side_length / 2
        for row in rows:
            current_x = self.counter_side_length / 2
            for character in row:
                character = character.upper()
                if character == "C":
                    counters.append(Counter(np.array([current_x, current_y])))
                    current_x += self.counter_side_length
                elif character == "E":
                    current_x += self.counter_side_length
            current_y += self.counter_side_length
        return counters

    def perform_action(self, action: Action):
        """Performs an action of a player in the environment.

        Maps the different types of action inputs to the correct execution for the player.
        Possible action types are movement, pickup and interact actions.

        Args:
            action: The action to be performed. ``action.player`` must be a key of ``self.players``.
        """
        # NOTE: validated with assert, so the check disappears when Python runs with -O.
        assert action.player in self.players, "Unknown player."

        player = self.players[action.player]
        if action.act_type == "movement":
            self.perform_movement(player, action.action)
        elif action.act_type == "pickup":
            self.perform_pickup(player)
        elif action.act_type == "interact":
            self.perform_interact(player)

    def get_closest_counter(self, point: np.ndarray) -> Counter:
        """Determines the closest counter for a given 2d-coordinate point in the env.

        Distances are Euclidean distances between the point and the counter centers.

        Args:
            point: The point in the env for which to find the closest counter.

        Returns:
            The closest counter for the given point. On ties, the first counter in ``self.counters`` wins.

        Raises:
            ValueError: If the environment has no counters.
        """
        if not self.counters:
            # distance_matrix would otherwise fail with an opaque unpacking error on the empty input.
            raise ValueError("The environment has no counters.")
        counter_positions = [counter.pos for counter in self.counters]
        counter_distances = distance_matrix([point], counter_positions)[0]
        return self.counters[int(np.argmin(counter_distances))]

    def get_facing_counter(self, player: Player) -> Counter:
        """Determines the counter which the player is looking at.

        Adds the player's facing direction, scaled by ``player.interaction_range``, to the player position
        and finds the closest counter to that point.

        Args:
            player: The player for which to find the facing counter.

        Returns:
            The counter closest to the point in front of the player.

        Raises:
            ValueError: If the environment has no counters.
        """
        facing_point = player.pos + (player.facing_direction * player.interaction_range)
        return self.get_closest_counter(facing_point)

    def perform_pickup(self, player: Player):
        """Performs the game action corresponding to picking up an item.

        Not implemented yet; currently does nothing.

        Args:
            player: The player which performs the pickup action.
        """
        pass

    def perform_interact(self, player: Player):
        """Performs the game action corresponding to interacting with a counter or other object.

        Not implemented yet; currently does nothing.

        Args:
            player: The player which performs the interaction.
        """
        pass

    def perform_movement(self, player: Player, move_vector: np.ndarray):
        """Moves a player in the direction specified by a movement action.

        The full step (``move_vector * player.move_dist``) is tried first. If the player then collides with a
        counter, another player or the world bounds, the player is moved back with ``move_abs``. Two "sliding"
        fallbacks are then tried, each at half the step length. The first keeps only the y-component of the
        step, the second only the x-component. Sliding makes movement smoother when walking diagonally against
        counters or the world boundary. If every attempt collides, the player is moved back to where it started.

        Args:
            player: The player to move.
            move_vector: The movement direction, expected to be a unit 2d vector.
        """
        old_pos = player.pos.copy()

        step = move_vector * player.move_dist
        player.move(step)
        if self.detect_collision(player):
            player.move_abs(old_pos)

            old_pos = player.pos.copy()

            # Sliding fallback 1: only the y-component of the step, at half speed.
            step_sliding = step.copy()
            step_sliding[0] = 0
            player.move(step_sliding * 0.5)
            # NOTE(review): turns towards the full intended step rather than the slide direction,
            # presumably so the player keeps facing the input direction; confirm against Player.turn.
            player.turn(step)

            if self.detect_collision(player):
                player.move_abs(old_pos)

                old_pos = player.pos.copy()

                # Sliding fallback 2: only the x-component of the step, at half speed.
                step_sliding = step.copy()
                step_sliding[1] = 0
                player.move(step_sliding * 0.5)
                player.turn(step)

                if self.detect_collision(player):
                    player.move_abs(old_pos)

    def detect_collision(self, player: Player) -> bool:
        """Detect collisions between the player and other players, counters or the world bounds.

        Args:
            player: The player for which to check collisions.

        Returns:
            True if the player is intersecting with any object in the environment or the world bounds.
        """
        return (self.detect_player_collision(player) or self.detect_collision_counters(player) or
                self.detect_collision_world_bounds(player))

    def detect_player_collision(self, player: Player) -> bool:
        """Detects collisions between the queried player and other players.

        A player is modelled as a circle. A collision is detected if the distance between two players is less
        than or equal to the sum of their radii. Players are told apart by their ``name``.

        Args:
            player: The player to check collisions with other players for.

        Returns:
            True if the player collides with another player, False if not.
        """
        return any(
            np.linalg.norm(player.pos - other.pos) <= (player.radius + other.radius)
            for other in self.players.values()
            if other.name != player.name
        )

    def detect_collision_counters(self, player: Player) -> bool:
        """Checks for collisions of the queried player with each counter.

        Args:
            player: The player to check collisions with counters for.

        Returns:
            True if the player collides with any counter, False if not.
        """
        return any(self.detect_collision_player_counter(player, counter) for counter in self.counters)

    def detect_collision_player_counter(self, player: Player, counter: Counter) -> bool:
        """Checks if the player and counter collide (overlap).

        A counter is modelled as an axis-aligned square of side ``counter_side_length`` centered at
        ``counter.pos``; a player is modelled as a circle. The distance from the player position (circle
        center) to the square is calculated, and a collision is detected if it is strictly smaller than the
        player radius.

        TODO: Efficiency improvement by checking only nearest counters? Quadtree...?

        Args:
            player: The player to check the collision for.
            counter: The counter to check the collision for.

        Returns:
            True if player and counter overlap, False if not.
        """
        size = self.counter_side_length
        cx, cy = player.pos
        # Per-axis distance from the circle center to the square; 0 when the center lies within the square's span.
        dx = max(np.abs(cx - counter.pos[0]) - size / 2, 0)
        dy = max(np.abs(cy - counter.pos[1]) - size / 2, 0)
        distance = np.linalg.norm([dx, dy])
        return distance < player.radius

    def detect_collision_world_bounds(self, player: Player) -> bool:
        """Checks for collisions of the player with the world bounds.

        The world spans [0, world_width] x [0, world_height].

        Args:
            player: The player which to not let escape the world.

        Returns:
            True if the player's circle extends beyond the world bounds, False if not.
        """
        collisions_lower = any((player.pos - player.radius) < 0)
        collisions_upper = any((player.pos + player.radius) > [self.world_width, self.world_height])
        return collisions_lower or collisions_upper

    def step(self):
        """Performs a step of the environment.

        Intended to advance time-based events such as cooking or cutting things, orders and time limits.
        Not implemented yet; currently does nothing.
        """
        pass

    def get_state(self):
        """Get the current state of the game environment as the live Python objects.

        Returns:
            Dict with the keys "players" (the players dict), "counters" (the list of counters) and "score".
        """
        return {"players": self.players,
                "counters": self.counters,
                "score": self.score}

    def get_state_json(self):
        """Get the current state of the game environment as a json-like nested dictionary.

        TODO: Not implemented yet; currently returns None.
        """
        pass