# overcooked_environment.py
from __future__ import annotations

import dataclasses
import inspect
import json
import logging
import random
import sys
from datetime import timedelta, datetime
from enum import Enum
from pathlib import Path
from typing import Literal

import numpy as np
import numpy.typing as npt
import yaml
from scipy.spatial import distance_matrix

from overcooked_simulator import utils
from overcooked_simulator.counter_factory import CounterFactory
from overcooked_simulator.counters import (
    Counter,
    PlateConfig,
)
from overcooked_simulator.game_items import (
    ItemInfo,
    ItemType,
)
from overcooked_simulator.order import OrderAndScoreManager
from overcooked_simulator.player import Player, PlayerConfig
from overcooked_simulator.state_representation import StateRepresentation
from overcooked_simulator.utils import create_init_env_time, get_closest

# Module-level logger; used below for debug messages (end time, added players, env time resets).
log = logging.getLogger(__name__)


class ActionType(Enum):
    """Kinds of actions a player can send to the environment.

    There are exactly three kinds; any additional payload is carried in `Action.action_data`.
    """

    # Move the agent.
    MOVEMENT = "movement"
    # First interaction kind, e.g. picking an item up or dropping it off.
    # TODO change value to put
    PUT = "pickup"
    # Second interaction kind, e.g. progressing something. It is started and stopped via the
    # `keydown` / `keyup` action data.
    INTERACT = "interact"


class InterActionData(Enum):
    """The data for the interaction action: `ActionType.INTERACT`.

    A hold interaction is started with `START` (`"keydown"`) and ended with `STOP` (`"keyup"`).
    """

    START = "keydown"
    "start an interaction."
    STOP = "keyup"
    "stop an interaction without moving away."


@dataclasses.dataclass
class Action:
    """One command for a single player: who acts, which kind of action, and its payload."""

    # Id of the player that performs the action.
    player: str
    # Kind of the action; it determines which `action_data` is valid.
    # A plain string is converted to an `ActionType` in `__post_init__`.
    action_type: ActionType
    # Payload of the action: a movement vector, an interaction start/stop, or the literal "pickup".
    action_data: npt.NDArray[float] | InterActionData | Literal["pickup"]
    # How long the action lasts (relevant for movement).
    duration: float | int = 0

    def __repr__(self):
        who = self.player
        kind = self.action_type.value
        payload = self.action_data
        length = self.duration
        return f"Action({who},{kind},{payload},{length})"

    def __post_init__(self):
        # Normalise string inputs into their enum counterparts; "pickup" stays a plain string.
        if isinstance(self.action_type, str):
            self.action_type = ActionType(self.action_type)
        payload = self.action_data
        if not isinstance(payload, str):
            return
        if payload == "pickup":
            return
        self.action_data = InterActionData(payload)


# TODO Abstract base class for different environments


class Environment:
    """Environment class which handles the game logic for the overcooked-inspired environment.

    Handles player movement, collision-detection, counters, cooking processes, recipes, incoming orders, time.
    """

    PAUSED = None

    def __init__(
        self,
        env_config: Path | str,
        layout_config: Path | str,
        item_info: Path | str,
        as_files: bool = True,
    ):
        self.players: dict[str, Player] = {}
        """the player, keyed by their id/name."""

        self.as_files = as_files
        """Are the configs just the path to the files."""
        if self.as_files:
            with open(env_config, "r") as file:
                self.environment_config = yaml.load(file, Loader=yaml.Loader)
        else:
            self.environment_config = yaml.load(env_config, Loader=yaml.Loader)
        self.layout_config = layout_config
        """The layout config for the environment"""
        # self.counter_side_length = 1  # -> this changed! is 1 now

        self.item_info: dict[str, ItemInfo] = self.load_item_info(item_info)
        """The loaded item info dict. Keys are the item names."""
        # self.validate_item_info()
        if self.environment_config["meals"]["all"]:
            self.allowed_meal_names = set(
                [
                    item
                    for item, info in self.item_info.items()
                    if info.type == ItemType.Meal
                ]
            )
        else:
            self.allowed_meal_names = set(self.environment_config["meals"]["list"])
            """The allowed meals depend on the `environment_config.yml` configured behaviour. Either all meals that 
            are possible or only a limited subset."""

        self.order_and_score = OrderAndScoreManager(
            order_config=self.environment_config["orders"],
            available_meals={
                item: info
                for item, info in self.item_info.items()
                if info.type == ItemType.Meal and item in self.allowed_meal_names
            },
        )
        """The manager for the orders and score update."""

        self.kitchen_height: int = 0
        """The height of the kitchen, is set by the `Environment.parse_layout_file` method"""
        self.kitchen_width: int = 0
        """The width of the kitchen, is set by the `Environment.parse_layout_file` method"""

        self.counter_factory = CounterFactory(
            layout_chars_config=self.environment_config["layout_chars"],
            item_info=self.item_info,
            serving_window_additional_kwargs={
                "meals": self.allowed_meal_names,
                "env_time_func": self.get_env_time,
            },
            plate_config=PlateConfig(
                **(
                    self.environment_config["plates"]
                    if "plates" in self.environment_config
                    else {}
                )
            ),
            order_and_score=self.order_and_score,
        )

        (
            self.counters,
            self.designated_player_positions,
            self.free_positions,
        ) = self.parse_layout_file()

        self.counter_positions = np.array([c.pos for c in self.counters])

        self.world_borders = np.array(
            [[-0.5, self.kitchen_width - 0.5], [-0.5, self.kitchen_height - 0.5]],
            dtype=float,
        )

        self.player_movement_speed = self.environment_config["player_config"][
            "player_speed_units_per_seconds"
        ]
        self.player_radius = self.environment_config["player_config"]["radius"]

        progress_counter_classes = list(
            filter(
                lambda cl: hasattr(cl, "progress"),
                dict(
                    inspect.getmembers(
                        sys.modules["overcooked_simulator.counters"], inspect.isclass
                    )
                ).values(),
            )
        )
        self.progressing_counters = list(
            filter(
                lambda c: c.__class__ in progress_counter_classes,
                self.counters,
            )
        )
        """Counters that needs to be called in the step function via the `progress` method."""

        self.env_time: datetime = create_init_env_time()
        """the internal time of the environment. An environment starts always with the time from 
        `create_init_env_time`."""
        self.order_and_score.create_init_orders(self.env_time)
        self.start_time = self.env_time
        """The relative env time when it started."""
        self.env_time_end = self.env_time + timedelta(
            seconds=self.environment_config["game"]["time_limit_seconds"]
        )
        """The relative env time when it will stop/end"""
        log.debug(f"End time: {self.env_time_end}")

    @property
    def game_ended(self) -> bool:
        """Whether the game is over or not based on the calculated `Environment.env_time_end`"""
        return self.env_time >= self.env_time_end

    def set_collision_arrays(self):
        number_players = len(self.players)
        self.world_borders_lower = self.world_borders[np.newaxis, :, 0].repeat(
            number_players, axis=0
        )
        self.world_borders_upper = self.world_borders[np.newaxis, :, 1].repeat(
            number_players, axis=0
        )

    def get_env_time(self):
        """the internal time of the environment. An environment starts always with the time from `create_init_env_time`.

        Utility method to pass a reference to the serving window."""
        return self.env_time

    def load_item_info(self, data) -> dict[str, ItemInfo]:
        """Load `item_info.yml`, create ItemInfo classes and replace equipment strings with item infos."""
        if self.as_files:
            with open(data, "r") as file:
                item_lookup = yaml.safe_load(file)
        else:
            item_lookup = yaml.safe_load(data)
        for item_name in item_lookup:
            item_lookup[item_name] = ItemInfo(name=item_name, **item_lookup[item_name])

        for item_name, item_info in item_lookup.items():
            if item_info.equipment:
                item_info.equipment = item_lookup[item_info.equipment]
        return item_lookup

    def validate_item_info(self):
        """TODO"""
        raise NotImplementedError
        # infos = {t: [] for t in ItemType}
        # graph = nx.DiGraph()
        # for info in self.item_info.values():
        #     infos[info.type].append(info)
        #     graph.add_node(info.name)
        #     match info.type:
        #         case ItemType.Ingredient:
        #             if info.is_cuttable:
        #                 graph.add_edge(
        #                     info.name, info.finished_progress_name[:-1] + info.name
        #                 )
        #         case ItemType.Equipment:
        #             ...
        #         case ItemType.Meal:
        #             if info.equipment is not None:
        #                 graph.add_edge(info.equipment.name, info.name)
        #             for ingredient in info.needs:
        #                 graph.add_edge(ingredient, info.name)

        # graph = nx.DiGraph()
        # for item_name, item_info in self.item_info.items():
        #     graph.add_node(item_name, type=item_info.type.name)
        #     if len(item_info.equipment) == 0:
        #         for item in item_info.needs:
        #             graph.add_edge(item, item_name)
        #     else:
        #         for item in item_info.needs:
        #             for equipment in item_info.equipment:
        #                 graph.add_edge(item, equipment)
        #                 graph.add_edge(equipment, item_name)

        # plt.figure(figsize=(10, 10))
        # pos = nx.nx_agraph.graphviz_layout(graph, prog="twopi", args="")
        # nx.draw(graph, pos=pos, with_labels=True, node_color="white", node_size=500)
        # print(nx.multipartite_layout(graph, subset_key="type", align="vertical"))

        # pos = {
        #     node: (
        #         len(nx.ancestors(graph, node)) - len(nx.descendants(graph, node)),
        #         y,
        #     )
        #     for y, node in enumerate(graph)
        # }
        # nx.draw(
        #     graph,
        #     pos=pos,
        #     with_labels=True,
        #     node_shape="s",
        #     node_size=500,
        #     node_color="white",
        # )
        # TODO add colors for ingredients, equipment and meals
        # plt.show()

    def parse_layout_file(self):
        """Creates layout of kitchen counters in the environment based on layout file.
        Counters are arranged in a fixed size grid starting at [0,0]. The center of the first counter is at
        [counter_size/2, counter_size/2], counters are directly next to each other (of no empty space is specified
        in layout).
        """

        starting_at: float = 0.0
        current_y: float = starting_at
        counters: list[Counter] = []
        designated_player_positions: list[npt.NDArray] = []
        free_positions: list[npt.NDArray] = []

        if self.as_files:
            with open(self.layout_config, "r") as layout_file:
                lines = layout_file.readlines()
        else:
            lines = self.layout_config.split("\n")

        for line in lines:
            line = line.replace("\n", "").replace(" ", "")  # remove newline char
            current_x: float = starting_at
            for character in line:
                character = character.capitalize()
                pos = np.array([current_x, current_y])
                assert self.counter_factory.can_map(
                    character
                ), f"{character=} in layout file can not be mapped"
                if self.counter_factory.is_counter(character):
                    counters.append(
                        self.counter_factory.get_counter_object(character, pos)
                    )
                else:
                    match self.counter_factory.map_not_counter(character):
                        case "Agent":
                            designated_player_positions.append(pos)
                        case "Free":
                            free_positions.append(np.array([current_x, current_y]))

                current_x += 1
            current_y += 1

        self.kitchen_width: float = len(lines[0]) + starting_at
        self.kitchen_height = len(lines) + starting_at
        self.counter_factory.post_counter_setup(counters)

        return counters, designated_player_positions, free_positions

    def perform_action(self, action: Action):
        """Performs an action of a player in the environment. Maps different types of action inputs to the
        correct execution of the players.
        Possible action types are movement, pickup and interact actions.

        Args:
            action: The action to be performed
        """
        assert action.player in self.players.keys(), "Unknown player."
        player = self.players[action.player]

        if action.action_type == ActionType.MOVEMENT:
            player.set_movement(
                action.action_data,
                self.env_time + timedelta(seconds=action.duration),
            )
        else:
            counter = self.get_facing_counter(player)
            if player.can_reach(counter):
                if action.action_type == ActionType.PUT:
                    player.pick_action(counter)

                elif action.action_type == ActionType.INTERACT:
                    if action.action_data == InterActionData.START:
                        player.perform_interact_hold_start(counter)
                        player.last_interacted_counter = counter
            if action.action_data == InterActionData.STOP:
                if player.last_interacted_counter:
                    player.perform_interact_hold_stop(player.last_interacted_counter)

    def get_facing_counter(self, player: Player):
        """Determines the counter which the player is looking at.
        Adds a multiple of the player facing direction onto the player position and finds the closest
        counter for that point.

        Args:
            player: The player for which to find the facing counter.

        Returns:

        """
        facing_counter = get_closest(player.facing_point, self.counters)
        return facing_counter

    def perform_movement(self, duration: timedelta):
        """Moves a player in the direction specified in the action.action. If the player collides with a
        counter or other player through this movement, then they are not moved.
        (The extended code with the two ifs is for sliding movement at the counters, which feels a bit smoother.
        This happens, when the player moves diagonally against the counters or world boundary.
        This just checks if the single axis party of the movement could move the player and does so at a lower rate.)

        The movement action is a unit 2d vector.

        Detects collisions with other players and pushes them out of the way.

        Args:
            player: The player to move.
            duration: The duration for how long the movement to perform.
        """
        d_time = duration.total_seconds()

        player_positions = np.array([p.pos for p in self.players.values()], dtype=float)
        player_movement_vectors = np.array(
            [
                p.current_movement if self.env_time <= p.movement_until else [0, 0]
                for p in self.players.values()
            ],
            dtype=float,
        )

        targeted_positions = player_positions + (
            player_movement_vectors * (self.player_movement_speed * d_time)
        )

        # Collisions player player
        distances_players_after_scipy = distance_matrix(
            targeted_positions, targeted_positions
        )

        player_diff_vecs = -(
            player_positions[:, np.newaxis, :] - player_positions[np.newaxis, :, :]
        )
        collision_idxs = distances_players_after_scipy < (2 * self.player_radius)
        eye_idxs = np.eye(
            distances_players_after_scipy.shape[0],
            distances_players_after_scipy.shape[1],
            dtype=bool,
        )
        collision_idxs[eye_idxs] = False

        # Player push players around
        player_diff_vecs[collision_idxs == False] = 0
        push_vectors = np.sum(player_diff_vecs, axis=0)

        updated_movement = push_vectors + player_movement_vectors
        new_positions = player_positions + (
            updated_movement * (self.player_movement_speed * d_time)
        )

        # Collisions players counters
        counter_diff_vecs = (
            new_positions[:, np.newaxis, :] - self.counter_positions[np.newaxis, :, :]
        )
        counter_distances = np.max((np.abs(counter_diff_vecs)), axis=2)
        # counter_distances = np.linalg.norm(counter_diff_vecs, axis=2)
        closest_counter_positions = self.counter_positions[
            np.argmin(counter_distances, axis=1)
        ]

        nearest_counter_to_player = closest_counter_positions - new_positions
        print(nearest_counter_to_player)

        collided = np.min(counter_distances, axis=1) - 0.5 < self.player_radius
        # print("                     COLLIDED", collided)

        # print("CLOSEST_COUNTER", closest_counter_positions)
        relevant_axes = nearest_counter_to_player.argmax(axis=1)
        relevant_values = nearest_counter_to_player.max(axis=1)

        new_positions = player_positions + (
            updated_movement * (self.player_movement_speed * d_time)
        )

        for idx, player in enumerate(player_positions):
            axis = relevant_axes[idx]

            if collided[idx]:
                # print("before", updated_movement)
                if relevant_values[idx] - 0.5 > 0:
                    # print("settings more")
                    new_positions[idx, axis] = np.min(
                        [
                            player_positions[idx, axis],
                            closest_counter_positions[idx, axis],
                        ]
                    )

                    # updated_movement[idx, axis] = np.max(updated_movement[idx, axis], 0)
                if relevant_values[idx] + 0.5 < 0:
                    # print("settings less")
                    new_positions[idx, axis] = np.max(
                        [
                            player_positions[idx, axis],
                            closest_counter_positions[idx, axis],
                        ]
                    )

                    # updated_movement[idx, axis] = np.min(updated_movement[idx, axis], 0)
                # print("after", updated_movement)

        # new_positions[collided] = player_positions[collided]

        # new_positions[min_counter_distances < self.player_radius] = player_positions[min_counter_distances < self.player_radius]

        # counter_distances_axes = np.max((np.abs(counter_diff_vecs)), axis=1)

        # Collisions player world borders
        new_positions = np.max(
            [new_positions, self.world_borders_lower + self.player_radius], axis=0
        )
        new_positions = np.min(
            [new_positions, self.world_borders_upper - self.player_radius], axis=0
        )

        for idx, p in enumerate(self.players.values()):
            p.turn(player_movement_vectors[idx])
            p.move_abs(new_positions[idx])

    def detect_collision(self, player: Player):
        """Detect collisions between the player and other players or counters.

        Args:
            player: The player for which to check collisions.

        Returns: True if the player is intersecting with any object in the environment.
        """
        return (
            len(self.get_collided_players(player)) > 0
            or self.detect_collision_counters(player)
            or self.detect_collision_world_bounds(player)
        )

    def get_collided_players(self, player: Player) -> list[Player]:
        """Detects collisions between the queried player and other players. Returns the list of the collided players.
        A player is modelled as a circle. Collision is detected if the distance between the players is smaller
        than the sum of the radius's.

        Args:
            player: The player to check collisions with other players for.

        Returns: The list of other players the player collides with.

        """

        players_list = list(self.players.values())

        if player in players_list:
            player_idx = players_list.index(player)
            return utils.get_collided_players(player_idx, list(self.players.values()))
        return []

    def detect_player_collision(self, player: Player):
        """Detects collisions between the queried player and other players.
        A player is modelled as a circle. Collision is detected if the distance between the players is smaller
        than the sum of the radius's.

        Args:
            player: The player to check collisions with other players for.

        Returns: True if the player collides with other players, False if not.

        """

        return any(self.get_collided_players(player))

    def detect_collision_counters(self, player: Player):
        """Checks for collisions of the queried player with each counter.

        Args:
            player:  The player to check collisions with counters for.

        Returns: True if the player collides with any counter, False if not.

        """
        return np.any(
            np.max((np.abs(self.counter_positions - player.pos) - 0.5), axis=1)
            < player.radius
        )

    def detect_collision_world_bounds(self, player: Player):
        """Checks for detections of the player and the world bounds.

        Args:
            player: The player which to not let escape the world.

        Returns: True if the player touches the world bounds, False if not.
        """

        return np.any(player.pos - player.radius < self.world_borders[:, 0]) or np.any(
            player.pos + player.radius > self.world_borders[:, 1]
        )

    def add_player(self, player_name: str, pos: npt.NDArray = None):
        """Add a player to the environment.

        Args:
            player_name: The id/name of the player to reference actions and in the state.
            pos: The optional init position of the player.
        """
        # TODO check if the player name already exists in the environment and do not overwrite player.
        log.debug(f"Add player {player_name} to the game")
        player = Player(
            player_name,
            player_config=PlayerConfig(
                **(
                    self.environment_config["player_config"]
                    if "player_config" in self.environment_config
                    else {}
                )
            ),
            pos=pos,
        )
        self.players[player.name] = player
        if player.pos is None:
            if len(self.designated_player_positions) > 0:
                free_idx = random.randint(0, len(self.designated_player_positions) - 1)
                player.move_abs(self.designated_player_positions[free_idx])
                del self.designated_player_positions[free_idx]
            elif len(self.free_positions) > 0:
                free_idx = random.randint(0, len(self.free_positions) - 1)
                player.move_abs(self.free_positions[free_idx])
                del self.free_positions[free_idx]
            else:
                log.debug("No free positions left in kitchens")
            player.update_facing_point()

        self.set_collision_arrays()

    def step(self, passed_time: timedelta):
        """Performs a step of the environment. Affects time based events such as cooking or cutting things, orders
        and time limits.
        """
        self.env_time += passed_time

        if not self.game_ended:
            self.perform_movement(passed_time)

            for counter in self.progressing_counters:
                counter.progress(passed_time=passed_time, now=self.env_time)
            self.order_and_score.progress(passed_time=passed_time, now=self.env_time)

    def get_state(self):
        """Get the current state of the game environment. The state here is accessible by the current python objects.

        Returns: Dict of lists of the current relevant game objects.

        """
        return {
            "players": self.players,
            "counters": self.counters,
            "score": self.order_and_score.score,
            "orders": self.order_and_score.open_orders,
            "ended": self.game_ended,
            "env_time": self.env_time,
            "remaining_time": max(self.env_time_end - self.env_time, timedelta(0)),
        }

    def get_json_state(self, player_id: str = None):
        state = {
            "players": [p.to_dict() for p in self.players.values()],
            "counters": [c.to_dict() for c in self.counters],
            "kitchen": {"width": self.kitchen_width, "height": self.kitchen_height},
            "score": self.order_and_score.score,
            "orders": self.order_and_score.order_state(),
            "ended": self.game_ended,
            "env_time": self.env_time.isoformat(),
            "remaining_time": max(
                self.env_time_end - self.env_time, timedelta(0)
            ).total_seconds(),
        }
        json_data = json.dumps(state)
        assert StateRepresentation.model_validate_json(json_data=json_data)
        return json_data

    def reset_env_time(self):
        """Reset the env time to the initial time, defined by `create_init_env_time`."""
        self.env_time = create_init_env_time()
        log.debug(f"Reset env time to {self.env_time}")