from __future__ import annotations import dataclasses import inspect import json import logging import random import sys from datetime import timedelta, datetime from enum import Enum from pathlib import Path from typing import Literal import numpy as np import numpy.typing as npt import yaml from scipy.spatial import distance_matrix from overcooked_simulator import utils from overcooked_simulator.counter_factory import CounterFactory from overcooked_simulator.counters import ( Counter, PlateConfig, ) from overcooked_simulator.game_items import ( ItemInfo, ItemType, ) from overcooked_simulator.order import OrderAndScoreManager from overcooked_simulator.player import Player, PlayerConfig from overcooked_simulator.state_representation import StateRepresentation from overcooked_simulator.utils import create_init_env_time, get_closest log = logging.getLogger(__name__) class ActionType(Enum): """The 3 different types of valid actions. They can be extended via the `Action.action_data` attribute.""" MOVEMENT = "movement" """move the agent.""" PUT = "pickup" """interaction type 1, e.g., for pickup or drop off. Maybe other words: transplace?""" # TODO change value to put INTERACT = "interact" """interaction type 2, e.g., for progressing. Start and stop interaction via `keydown` and `keyup` actions.""" class InterActionData(Enum): """The data for the interaction action: `ActionType.MOVEMENT`.""" START = "keydown" "start an interaction." STOP = "keyup" "stop an interaction without moving away." @dataclasses.dataclass class Action: """Action class, specifies player, action type and action itself.""" player: str """Id of the player.""" action_type: ActionType """Type of the action to perform. Defines what action data is valid.""" action_data: npt.NDArray[float] | InterActionData | Literal["pickup"] """Data for the action, e.g., movement vector or start and stop interaction.""" duration: float | int = 0 """Duration of the action (relevant for movement)""" def __repr__(self): return f"Action({self.player},{self.action_type.value},{self.action_data},{self.duration})" def __post_init__(self): if isinstance(self.action_type, str): self.action_type = ActionType(self.action_type) if isinstance(self.action_data, str) and self.action_data != "pickup": self.action_data = InterActionData(self.action_data) # TODO Abstract base class for different environments class Environment: """Environment class which handles the game logic for the overcooked-inspired environment. Handles player movement, collision-detection, counters, cooking processes, recipes, incoming orders, time. """ PAUSED = None def __init__( self, env_config: Path | str, layout_config: Path | str, item_info: Path | str, as_files: bool = True, ): self.players: dict[str, Player] = {} """the player, keyed by their id/name.""" self.as_files = as_files """Are the configs just the path to the files.""" if self.as_files: with open(env_config, "r") as file: self.environment_config = yaml.load(file, Loader=yaml.Loader) else: self.environment_config = yaml.load(env_config, Loader=yaml.Loader) self.layout_config = layout_config """The layout config for the environment""" # self.counter_side_length = 1 # -> this changed! is 1 now self.item_info: dict[str, ItemInfo] = self.load_item_info(item_info) """The loaded item info dict. Keys are the item names.""" # self.validate_item_info() if self.environment_config["meals"]["all"]: self.allowed_meal_names = set( [ item for item, info in self.item_info.items() if info.type == ItemType.Meal ] ) else: self.allowed_meal_names = set(self.environment_config["meals"]["list"]) """The allowed meals depend on the `environment_config.yml` configured behaviour. Either all meals that are possible or only a limited subset.""" self.order_and_score = OrderAndScoreManager( order_config=self.environment_config["orders"], available_meals={ item: info for item, info in self.item_info.items() if info.type == ItemType.Meal and item in self.allowed_meal_names }, ) """The manager for the orders and score update.""" self.kitchen_height: int = 0 """The height of the kitchen, is set by the `Environment.parse_layout_file` method""" self.kitchen_width: int = 0 """The width of the kitchen, is set by the `Environment.parse_layout_file` method""" self.counter_factory = CounterFactory( layout_chars_config=self.environment_config["layout_chars"], item_info=self.item_info, serving_window_additional_kwargs={ "meals": self.allowed_meal_names, "env_time_func": self.get_env_time, }, plate_config=PlateConfig( **( self.environment_config["plates"] if "plates" in self.environment_config else {} ) ), order_and_score=self.order_and_score, ) ( self.counters, self.designated_player_positions, self.free_positions, ) = self.parse_layout_file() self.counter_positions = np.array([c.pos for c in self.counters]) self.world_borders = np.array( [[-0.5, self.kitchen_width - 0.5], [-0.5, self.kitchen_height - 0.5]], dtype=float, ) self.player_movement_speed = self.environment_config["player_config"][ "player_speed_units_per_seconds" ] self.player_radius = self.environment_config["player_config"]["radius"] progress_counter_classes = list( filter( lambda cl: hasattr(cl, "progress"), dict( inspect.getmembers( sys.modules["overcooked_simulator.counters"], inspect.isclass ) ).values(), ) ) self.progressing_counters = list( filter( lambda c: c.__class__ in progress_counter_classes, self.counters, ) ) """Counters that needs to be called in the step function via the `progress` method.""" self.env_time: datetime = create_init_env_time() """the internal time of the environment. An environment starts always with the time from `create_init_env_time`.""" self.order_and_score.create_init_orders(self.env_time) self.start_time = self.env_time """The relative env time when it started.""" self.env_time_end = self.env_time + timedelta( seconds=self.environment_config["game"]["time_limit_seconds"] ) """The relative env time when it will stop/end""" log.debug(f"End time: {self.env_time_end}") @property def game_ended(self) -> bool: """Whether the game is over or not based on the calculated `Environment.env_time_end`""" return self.env_time >= self.env_time_end def set_collision_arrays(self): number_players = len(self.players) self.world_borders_lower = self.world_borders[np.newaxis, :, 0].repeat( number_players, axis=0 ) self.world_borders_upper = self.world_borders[np.newaxis, :, 1].repeat( number_players, axis=0 ) def get_env_time(self): """the internal time of the environment. An environment starts always with the time from `create_init_env_time`. Utility method to pass a reference to the serving window.""" return self.env_time def load_item_info(self, data) -> dict[str, ItemInfo]: """Load `item_info.yml`, create ItemInfo classes and replace equipment strings with item infos.""" if self.as_files: with open(data, "r") as file: item_lookup = yaml.safe_load(file) else: item_lookup = yaml.safe_load(data) for item_name in item_lookup: item_lookup[item_name] = ItemInfo(name=item_name, **item_lookup[item_name]) for item_name, item_info in item_lookup.items(): if item_info.equipment: item_info.equipment = item_lookup[item_info.equipment] return item_lookup def validate_item_info(self): """TODO""" raise NotImplementedError # infos = {t: [] for t in ItemType} # graph = nx.DiGraph() # for info in self.item_info.values(): # infos[info.type].append(info) # graph.add_node(info.name) # match info.type: # case ItemType.Ingredient: # if info.is_cuttable: # graph.add_edge( # info.name, info.finished_progress_name[:-1] + info.name # ) # case ItemType.Equipment: # ... # case ItemType.Meal: # if info.equipment is not None: # graph.add_edge(info.equipment.name, info.name) # for ingredient in info.needs: # graph.add_edge(ingredient, info.name) # graph = nx.DiGraph() # for item_name, item_info in self.item_info.items(): # graph.add_node(item_name, type=item_info.type.name) # if len(item_info.equipment) == 0: # for item in item_info.needs: # graph.add_edge(item, item_name) # else: # for item in item_info.needs: # for equipment in item_info.equipment: # graph.add_edge(item, equipment) # graph.add_edge(equipment, item_name) # plt.figure(figsize=(10, 10)) # pos = nx.nx_agraph.graphviz_layout(graph, prog="twopi", args="") # nx.draw(graph, pos=pos, with_labels=True, node_color="white", node_size=500) # print(nx.multipartite_layout(graph, subset_key="type", align="vertical")) # pos = { # node: ( # len(nx.ancestors(graph, node)) - len(nx.descendants(graph, node)), # y, # ) # for y, node in enumerate(graph) # } # nx.draw( # graph, # pos=pos, # with_labels=True, # node_shape="s", # node_size=500, # node_color="white", # ) # TODO add colors for ingredients, equipment and meals # plt.show() def parse_layout_file(self): """Creates layout of kitchen counters in the environment based on layout file. Counters are arranged in a fixed size grid starting at [0,0]. The center of the first counter is at [counter_size/2, counter_size/2], counters are directly next to each other (of no empty space is specified in layout). """ starting_at: float = 0.0 current_y: float = starting_at counters: list[Counter] = [] designated_player_positions: list[npt.NDArray] = [] free_positions: list[npt.NDArray] = [] if self.as_files: with open(self.layout_config, "r") as layout_file: lines = layout_file.readlines() else: lines = self.layout_config.split("\n") for line in lines: line = line.replace("\n", "").replace(" ", "") # remove newline char current_x: float = starting_at for character in line: character = character.capitalize() pos = np.array([current_x, current_y]) assert self.counter_factory.can_map( character ), f"{character=} in layout file can not be mapped" if self.counter_factory.is_counter(character): counters.append( self.counter_factory.get_counter_object(character, pos) ) else: match self.counter_factory.map_not_counter(character): case "Agent": designated_player_positions.append(pos) case "Free": free_positions.append(np.array([current_x, current_y])) current_x += 1 current_y += 1 self.kitchen_width: float = len(lines[0]) + starting_at self.kitchen_height = len(lines) + starting_at self.counter_factory.post_counter_setup(counters) return counters, designated_player_positions, free_positions def perform_action(self, action: Action): """Performs an action of a player in the environment. Maps different types of action inputs to the correct execution of the players. Possible action types are movement, pickup and interact actions. Args: action: The action to be performed """ assert action.player in self.players.keys(), "Unknown player." player = self.players[action.player] if action.action_type == ActionType.MOVEMENT: player.set_movement( action.action_data, self.env_time + timedelta(seconds=action.duration), ) else: counter = self.get_facing_counter(player) if player.can_reach(counter): if action.action_type == ActionType.PUT: player.pick_action(counter) elif action.action_type == ActionType.INTERACT: if action.action_data == InterActionData.START: player.perform_interact_hold_start(counter) player.last_interacted_counter = counter if action.action_data == InterActionData.STOP: if player.last_interacted_counter: player.perform_interact_hold_stop(player.last_interacted_counter) def get_facing_counter(self, player: Player): """Determines the counter which the player is looking at. Adds a multiple of the player facing direction onto the player position and finds the closest counter for that point. Args: player: The player for which to find the facing counter. Returns: """ facing_counter = get_closest(player.facing_point, self.counters) return facing_counter def perform_movement(self, duration: timedelta): """Moves a player in the direction specified in the action.action. If the player collides with a counter or other player through this movement, then they are not moved. (The extended code with the two ifs is for sliding movement at the counters, which feels a bit smoother. This happens, when the player moves diagonally against the counters or world boundary. This just checks if the single axis party of the movement could move the player and does so at a lower rate.) The movement action is a unit 2d vector. Detects collisions with other players and pushes them out of the way. Args: player: The player to move. duration: The duration for how long the movement to perform. """ d_time = duration.total_seconds() player_positions = np.array([p.pos for p in self.players.values()], dtype=float) player_movement_vectors = np.array( [ p.current_movement if self.env_time <= p.movement_until else [0, 0] for p in self.players.values() ], dtype=float, ) targeted_positions = player_positions + ( player_movement_vectors * (self.player_movement_speed * d_time) ) # Collisions player player distances_players_after_scipy = distance_matrix( targeted_positions, targeted_positions ) player_diff_vecs = -( player_positions[:, np.newaxis, :] - player_positions[np.newaxis, :, :] ) collision_idxs = distances_players_after_scipy < (2 * self.player_radius) eye_idxs = np.eye( distances_players_after_scipy.shape[0], distances_players_after_scipy.shape[1], dtype=bool, ) collision_idxs[eye_idxs] = False # Player push players around player_diff_vecs[collision_idxs == False] = 0 push_vectors = np.sum(player_diff_vecs, axis=0) updated_movement = push_vectors + player_movement_vectors new_positions = player_positions + ( updated_movement * (self.player_movement_speed * d_time) ) # Collisions players counters counter_diff_vecs = ( new_positions[:, np.newaxis, :] - self.counter_positions[np.newaxis, :, :] ) counter_distances = np.max((np.abs(counter_diff_vecs)), axis=2) # counter_distances = np.linalg.norm(counter_diff_vecs, axis=2) closest_counter_positions = self.counter_positions[ np.argmin(counter_distances, axis=1) ] nearest_counter_to_player = closest_counter_positions - new_positions print(nearest_counter_to_player) collided = np.min(counter_distances, axis=1) - 0.5 < self.player_radius # print(" COLLIDED", collided) # print("CLOSEST_COUNTER", closest_counter_positions) relevant_axes = nearest_counter_to_player.argmax(axis=1) relevant_values = nearest_counter_to_player.max(axis=1) new_positions = player_positions + ( updated_movement * (self.player_movement_speed * d_time) ) for idx, player in enumerate(player_positions): axis = relevant_axes[idx] if collided[idx]: # print("before", updated_movement) if relevant_values[idx] - 0.5 > 0: # print("settings more") new_positions[idx, axis] = np.min( [ player_positions[idx, axis], closest_counter_positions[idx, axis], ] ) # updated_movement[idx, axis] = np.max(updated_movement[idx, axis], 0) if relevant_values[idx] + 0.5 < 0: # print("settings less") new_positions[idx, axis] = np.max( [ player_positions[idx, axis], closest_counter_positions[idx, axis], ] ) # updated_movement[idx, axis] = np.min(updated_movement[idx, axis], 0) # print("after", updated_movement) # new_positions[collided] = player_positions[collided] # new_positions[min_counter_distances < self.player_radius] = player_positions[min_counter_distances < self.player_radius] # counter_distances_axes = np.max((np.abs(counter_diff_vecs)), axis=1) # Collisions player world borders new_positions = np.max( [new_positions, self.world_borders_lower + self.player_radius], axis=0 ) new_positions = np.min( [new_positions, self.world_borders_upper - self.player_radius], axis=0 ) for idx, p in enumerate(self.players.values()): p.turn(player_movement_vectors[idx]) p.move_abs(new_positions[idx]) def detect_collision(self, player: Player): """Detect collisions between the player and other players or counters. Args: player: The player for which to check collisions. Returns: True if the player is intersecting with any object in the environment. """ return ( len(self.get_collided_players(player)) > 0 or self.detect_collision_counters(player) or self.detect_collision_world_bounds(player) ) def get_collided_players(self, player: Player) -> list[Player]: """Detects collisions between the queried player and other players. Returns the list of the collided players. A player is modelled as a circle. Collision is detected if the distance between the players is smaller than the sum of the radius's. Args: player: The player to check collisions with other players for. Returns: The list of other players the player collides with. """ players_list = list(self.players.values()) if player in players_list: player_idx = players_list.index(player) return utils.get_collided_players(player_idx, list(self.players.values())) return [] def detect_player_collision(self, player: Player): """Detects collisions between the queried player and other players. A player is modelled as a circle. Collision is detected if the distance between the players is smaller than the sum of the radius's. Args: player: The player to check collisions with other players for. Returns: True if the player collides with other players, False if not. """ return any(self.get_collided_players(player)) def detect_collision_counters(self, player: Player): """Checks for collisions of the queried player with each counter. Args: player: The player to check collisions with counters for. Returns: True if the player collides with any counter, False if not. """ return np.any( np.max((np.abs(self.counter_positions - player.pos) - 0.5), axis=1) < player.radius ) def detect_collision_world_bounds(self, player: Player): """Checks for detections of the player and the world bounds. Args: player: The player which to not let escape the world. Returns: True if the player touches the world bounds, False if not. """ return np.any(player.pos - player.radius < self.world_borders[:, 0]) or np.any( player.pos + player.radius > self.world_borders[:, 1] ) def add_player(self, player_name: str, pos: npt.NDArray = None): """Add a player to the environment. Args: player_name: The id/name of the player to reference actions and in the state. pos: The optional init position of the player. """ # TODO check if the player name already exists in the environment and do not overwrite player. log.debug(f"Add player {player_name} to the game") player = Player( player_name, player_config=PlayerConfig( **( self.environment_config["player_config"] if "player_config" in self.environment_config else {} ) ), pos=pos, ) self.players[player.name] = player if player.pos is None: if len(self.designated_player_positions) > 0: free_idx = random.randint(0, len(self.designated_player_positions) - 1) player.move_abs(self.designated_player_positions[free_idx]) del self.designated_player_positions[free_idx] elif len(self.free_positions) > 0: free_idx = random.randint(0, len(self.free_positions) - 1) player.move_abs(self.free_positions[free_idx]) del self.free_positions[free_idx] else: log.debug("No free positions left in kitchens") player.update_facing_point() self.set_collision_arrays() def step(self, passed_time: timedelta): """Performs a step of the environment. Affects time based events such as cooking or cutting things, orders and time limits. """ self.env_time += passed_time if not self.game_ended: self.perform_movement(passed_time) for counter in self.progressing_counters: counter.progress(passed_time=passed_time, now=self.env_time) self.order_and_score.progress(passed_time=passed_time, now=self.env_time) def get_state(self): """Get the current state of the game environment. The state here is accessible by the current python objects. Returns: Dict of lists of the current relevant game objects. """ return { "players": self.players, "counters": self.counters, "score": self.order_and_score.score, "orders": self.order_and_score.open_orders, "ended": self.game_ended, "env_time": self.env_time, "remaining_time": max(self.env_time_end - self.env_time, timedelta(0)), } def get_json_state(self, player_id: str = None): state = { "players": [p.to_dict() for p in self.players.values()], "counters": [c.to_dict() for c in self.counters], "kitchen": {"width": self.kitchen_width, "height": self.kitchen_height}, "score": self.order_and_score.score, "orders": self.order_and_score.order_state(), "ended": self.game_ended, "env_time": self.env_time.isoformat(), "remaining_time": max( self.env_time_end - self.env_time, timedelta(0) ).total_seconds(), } json_data = json.dumps(state) assert StateRepresentation.model_validate_json(json_data=json_data) return json_data def reset_env_time(self): """Reset the env time to the initial time, defined by `create_init_env_time`.""" self.env_time = create_init_env_time() log.debug(f"Reset env time to {self.env_time}")