from __future__ import annotations import dataclasses import inspect import json import logging import sys from collections import defaultdict from datetime import timedelta, datetime from enum import Enum from pathlib import Path from random import Random from typing import Literal, TypedDict, Callable, Tuple import numpy as np import numpy.typing as npt import yaml from scipy.spatial import distance_matrix from overcooked_simulator.counter_factory import CounterFactory from overcooked_simulator.counters import ( Counter, PlateConfig, ) from overcooked_simulator.effect_manager import EffectManager from overcooked_simulator.game_items import ( ItemInfo, ItemType, ) from overcooked_simulator.hooks import ( ITEM_INFO_LOADED, LAYOUT_FILE_PARSED, ENV_INITIALIZED, PRE_PERFORM_ACTION, POST_PERFORM_ACTION, PLAYER_ADDED, GAME_ENDED_STEP, PRE_STATE, STATE_DICT, JSON_STATE, PRE_RESET_ENV_TIME, POST_RESET_ENV_TIME, Hooks, ACTION_ON_NOT_REACHABLE_COUNTER, ACTION_PUT, ACTION_INTERACT_START, ITEM_INFO_CONFIG, ) from overcooked_simulator.order import ( OrderAndScoreManager, OrderConfig, ) from overcooked_simulator.player import Player, PlayerConfig from overcooked_simulator.state_representation import StateRepresentation, InfoMsg from overcooked_simulator.utils import create_init_env_time, get_closest log = logging.getLogger(__name__) FOG_OF_WAR = True PREVENT_SQUEEZING_INTO_OTHER_PLAYERS = False class ActionType(Enum): """The 3 different types of valid actions. They can be extended via the `Action.action_data` attribute.""" MOVEMENT = "movement" """move the agent.""" PUT = "pickup" """interaction type 1, e.g., for pickup or drop off. Maybe other words: transplace?""" # TODO change value to put INTERACT = "interact" """interaction type 2, e.g., for progressing. Start and stop interaction via `keydown` and `keyup` actions.""" class InterActionData(Enum): """The data for the interaction action: `ActionType.MOVEMENT`.""" START = "keydown" "start an interaction." STOP = "keyup" "stop an interaction without moving away." @dataclasses.dataclass class Action: """Action class, specifies player, action type and action itself.""" player: str """Id of the player.""" action_type: ActionType """Type of the action to perform. Defines what action data is valid.""" action_data: npt.NDArray[float] | InterActionData | Literal["pickup"] """Data for the action, e.g., movement vector or start and stop interaction.""" duration: float | int = 0 """Duration of the action (relevant for movement)""" def __repr__(self): return f"Action({self.player},{self.action_type.value},{self.action_data},{self.duration})" def __post_init__(self): if isinstance(self.action_type, str): self.action_type = ActionType(self.action_type) if isinstance(self.action_data, str) and self.action_data != "pickup": self.action_data = InterActionData(self.action_data) # TODO Abstract base class for different environments class EnvironmentConfig(TypedDict): plates: PlateConfig game: dict[Literal["time_limit_seconds"], int] meals: dict[Literal["all"] | Literal["list"], bool | list[str]] orders: OrderConfig player_config: PlayerConfig layout_chars: dict[str, str] extra_setup_functions: dict[str, dict] effect_manager: dict class Environment: """Environment class which handles the game logic for the overcooked-inspired environment. Handles player movement, collision-detection, counters, cooking processes, recipes, incoming orders, time. """ PAUSED = None def __init__( self, env_config: Path | str, layout_config: Path | str, item_info: Path | str, as_files: bool = True, env_name: str = "overcooked_sim", seed: int = 56789223842348, ): self.env_name = env_name """Reference to the run. E.g, the env id.""" self.env_time: datetime = create_init_env_time() """the internal time of the environment. An environment starts always with the time from `create_init_env_time`.""" self.random: Random = Random(seed) """Random instance.""" self.hook: Hooks = Hooks(self) """Hook manager. Register callbacks and create hook points with additional kwargs.""" self.players: dict[str, Player] = {} """the player, keyed by their id/name.""" self.as_files = as_files """Are the configs just the path to the files.""" if self.as_files: with open(env_config, "r") as file: env_config = file.read() self.environment_config: EnvironmentConfig = yaml.load( env_config, Loader=yaml.Loader ) """The config of the environment. All environment specific attributes is configured here.""" self.player_view_restricted = self.environment_config["player_config"][ "restricted_view" ] if self.player_view_restricted: self.player_view_angle = self.environment_config["player_config"][ "view_angle" ] self.extra_setup_functions() self.layout_config = layout_config """The layout config for the environment""" # self.counter_side_length = 1 # -> this changed! is 1 now self.item_info: dict[str, ItemInfo] = self.load_item_info(item_info) """The loaded item info dict. Keys are the item names.""" self.hook(ITEM_INFO_LOADED, item_info=item_info, as_files=as_files) # self.validate_item_info() if self.environment_config["meals"]["all"]: self.allowed_meal_names = set( [ item for item, info in self.item_info.items() if info.type == ItemType.Meal ] ) else: self.allowed_meal_names = set(self.environment_config["meals"]["list"]) """The allowed meals depend on the `environment_config.yml` configured behaviour. Either all meals that are possible or only a limited subset.""" self.order_and_score = OrderAndScoreManager( order_config=self.environment_config["orders"], available_meals={ item: info for item, info in self.item_info.items() if info.type == ItemType.Meal and item in self.allowed_meal_names }, hook=self.hook, random=self.random, ) """The manager for the orders and score update.""" self.kitchen_height: int = 0 """The height of the kitchen, is set by the `Environment.parse_layout_file` method""" self.kitchen_width: int = 0 """The width of the kitchen, is set by the `Environment.parse_layout_file` method""" self.counter_factory = CounterFactory( layout_chars_config=self.environment_config["layout_chars"], item_info=self.item_info, serving_window_additional_kwargs={ "meals": self.allowed_meal_names, "env_time_func": self.get_env_time, }, plate_config=PlateConfig( **( self.environment_config["plates"] if "plates" in self.environment_config else {} ) ), order_and_score=self.order_and_score, effect_manager_config=self.environment_config["effect_manager"], hook=self.hook, random=self.random, ) ( self.counters, self.designated_player_positions, self.free_positions, ) = self.parse_layout_file() self.hook(LAYOUT_FILE_PARSED) self.counter_positions = np.array([c.pos for c in self.counters]) self.world_borders = np.array( [[-0.5, self.kitchen_width - 0.5], [-0.5, self.kitchen_height - 0.5]], dtype=float, ) self.player_movement_speed = self.environment_config["player_config"][ "player_speed_units_per_seconds" ] self.player_radius = self.environment_config["player_config"]["radius"] progress_counter_classes = list( filter( lambda cl: hasattr(cl, "progress"), dict( inspect.getmembers( sys.modules["overcooked_simulator.counters"], inspect.isclass ) ).values(), ) ) self.progressing_counters = list( filter( lambda c: c.__class__ in progress_counter_classes, self.counters, ) ) """Counters that needs to be called in the step function via the `progress` method.""" self.order_and_score.create_init_orders(self.env_time) self.start_time = self.env_time """The relative env time when it started.""" self.env_time_end = self.env_time + timedelta( seconds=self.environment_config["game"]["time_limit_seconds"] ) """The relative env time when it will stop/end""" log.debug(f"End time: {self.env_time_end}") self.effect_manager: dict[ str, EffectManager ] = self.counter_factory.setup_effect_manger(self.counters) self.info_msgs_per_player: dict[str, list[InfoMsg]] = defaultdict(list) self.hook( ENV_INITIALIZED, environment_config=env_config, layout_config=self.layout_config, seed=seed, env_start_time_worldtime=datetime.now(), ) @property def game_ended(self) -> bool: """Whether the game is over or not based on the calculated `Environment.env_time_end`""" return self.env_time >= self.env_time_end def set_collision_arrays(self): number_players = len(self.players) self.world_borders_lower = self.world_borders[np.newaxis, :, 0].repeat( number_players, axis=0 ) self.world_borders_upper = self.world_borders[np.newaxis, :, 1].repeat( number_players, axis=0 ) def get_env_time(self): """the internal time of the environment. An environment starts always with the time from `create_init_env_time`. Utility method to pass a reference to the serving window.""" return self.env_time def load_item_info(self, data) -> dict[str, ItemInfo]: """Load `item_info.yml`, create ItemInfo classes and replace equipment strings with item infos.""" if self.as_files: with open(data, "r") as file: data = file.read() self.hook(ITEM_INFO_CONFIG, item_info_config=data) item_lookup = yaml.safe_load(data) for item_name in item_lookup: item_lookup[item_name] = ItemInfo(name=item_name, **item_lookup[item_name]) for item_name, item_info in item_lookup.items(): if item_info.equipment: item_info.equipment = item_lookup[item_info.equipment] return item_lookup def validate_item_info(self): """TODO""" raise NotImplementedError # infos = {t: [] for t in ItemType} # graph = nx.DiGraph() # for info in self.item_info.values(): # infos[info.type].append(info) # graph.add_node(info.name) # match info.type: # case ItemType.Ingredient: # if info.is_cuttable: # graph.add_edge( # info.name, info.finished_progress_name[:-1] + info.name # ) # case ItemType.Equipment: # ... # case ItemType.Meal: # if info.equipment is not None: # graph.add_edge(info.equipment.name, info.name) # for ingredient in info.needs: # graph.add_edge(ingredient, info.name) # graph = nx.DiGraph() # for item_name, item_info in self.item_info.items(): # graph.add_node(item_name, type=item_info.type.name) # if len(item_info.equipment) == 0: # for item in item_info.needs: # graph.add_edge(item, item_name) # else: # for item in item_info.needs: # for equipment in item_info.equipment: # graph.add_edge(item, equipment) # graph.add_edge(equipment, item_name) # plt.figure(figsize=(10, 10)) # pos = nx.nx_agraph.graphviz_layout(graph, prog="twopi", args="") # nx.draw(graph, pos=pos, with_labels=True, node_color="white", node_size=500) # print(nx.multipartite_layout(graph, subset_key="type", align="vertical")) # pos = { # node: ( # len(nx.ancestors(graph, node)) - len(nx.descendants(graph, node)), # y, # ) # for y, node in enumerate(graph) # } # nx.draw( # graph, # pos=pos, # with_labels=True, # node_shape="s", # node_size=500, # node_color="white", # ) # TODO add colors for ingredients, equipment and meals # plt.show() def parse_layout_file( self, ) -> Tuple[list[Counter], list[npt.NDArray], list[npt.NDArray]]: """Creates layout of kitchen counters in the environment based on layout file. Counters are arranged in a fixed size grid starting at [0,0]. The center of the first counter is at [counter_size/2, counter_size/2], counters are directly next to each other (of no empty space is specified in layout). """ starting_at: float = 0.0 current_y: float = starting_at counters: list[Counter] = [] designated_player_positions: list[npt.NDArray] = [] free_positions: list[npt.NDArray] = [] if self.as_files: with open(self.layout_config, "r") as layout_file: self.layout_config = layout_file.read() lines = self.layout_config.split("\n") grid = [] lines = list(filter(lambda l: l != "", lines)) for line in lines: line = line.replace("\n", "").replace(" ", "") # remove newline char current_x: float = starting_at grid_line = [] for character in line: character = character.capitalize() pos = np.array([current_x, current_y]) assert self.counter_factory.can_map( character ), f"{character=} in layout file can not be mapped" if self.counter_factory.is_counter(character): counters.append( self.counter_factory.get_counter_object(character, pos) ) grid_line.append(1) else: grid_line.append(0) match self.counter_factory.map_not_counter(character): case "Agent": designated_player_positions.append(pos) case "Free": free_positions.append(np.array([current_x, current_y])) current_x += 1 grid.append(grid_line) current_y += 1 self.kitchen_width: float = len(lines[0]) + starting_at self.kitchen_height = len(lines) + starting_at self.determine_counter_orientations( counters, grid, np.array([self.kitchen_width / 2, self.kitchen_height / 2]) ) self.counter_factory.post_counter_setup(counters) return counters, designated_player_positions, free_positions def determine_counter_orientations(self, counters, grid, kitchen_center): grid = np.array(grid).T grid_width = grid.shape[0] grid_height = grid.shape[1] last_counter = None fst_counter_in_row = None for c in counters: grid_idx = np.floor(c.pos).astype(int) neighbour_offsets = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]], dtype=int) neighbours_free = [] for offset in neighbour_offsets: neighbour_pos = grid_idx + offset if ( neighbour_pos[0] > (grid_width - 1) or neighbour_pos[0] < 0 or neighbour_pos[1] > (grid_height - 1) or neighbour_pos[1] < 0 ): pass else: if grid[neighbour_pos[0]][neighbour_pos[1]] == 0: neighbours_free.append(offset) if len(neighbours_free) > 0: vector_to_center = c.pos - kitchen_center vector_to_center /= np.linalg.norm(vector_to_center) n_idx = np.argmin( np.linalg.norm(vector_to_center - n) for n in neighbours_free ) nearest_vec = neighbours_free[n_idx] # print(nearest_vec, type(nearest_vec)) c.set_orientation(nearest_vec) elif grid_idx[0] == 0: if grid_idx[1] == 0: # counter top left c.set_orientation(np.array([1, 0])) else: c.set_orientation(fst_counter_in_row.orientation) fst_counter_in_row = c else: c.set_orientation(last_counter.orientation) last_counter = c # for c in counters: # near_counters = [ # other # for other in counters # if np.isclose(np.linalg.norm(c.pos - other.pos), 1) # ] # # print(c.pos, len(near_counters)) def perform_action(self, action: Action): """Performs an action of a player in the environment. Maps different types of action inputs to the correct execution of the players. Possible action types are movement, pickup and interact actions. Args: action: The action to be performed """ assert action.player in self.players.keys(), "Unknown player." self.hook(PRE_PERFORM_ACTION, action=action) player = self.players[action.player] if action.action_type == ActionType.MOVEMENT: player.set_movement( action.action_data, self.env_time + timedelta(seconds=action.duration), ) else: counter = self.get_facing_counter(player) if player.can_reach(counter): if action.action_type == ActionType.PUT: player.put_action(counter) self.hook(ACTION_PUT, action=action, counter=counter) elif action.action_type == ActionType.INTERACT: if action.action_data == InterActionData.START: player.perform_interact_start(counter) self.hook(ACTION_INTERACT_START, action=action, counter=counter) else: self.hook( ACTION_ON_NOT_REACHABLE_COUNTER, action=action, counter=counter ) if action.action_data == InterActionData.STOP: player.perform_interact_stop() self.hook(POST_PERFORM_ACTION, action=action) def get_facing_counter(self, player: Player): """Determines the counter which the player is looking at. Adds a multiple of the player facing direction onto the player position and finds the closest counter for that point. Args: player: The player for which to find the facing counter. Returns: """ facing_counter = get_closest(player.facing_point, self.counters) return facing_counter def perform_movement(self, duration: timedelta): """Moves a player in the direction specified in the action.action. If the player collides with a counter or other player through this movement, then they are not moved. (The extended code with the two ifs is for sliding movement at the counters, which feels a bit smoother. This happens, when the player moves diagonally against the counters or world boundary. This just checks if the single axis party of the movement could move the player and does so at a lower rate.) The movement action is a unit 2d vector. Detects collisions with other players and pushes them out of the way. Args: duration: The duration for how long the movement to perform. """ d_time = duration.total_seconds() player_positions = np.array([p.pos for p in self.players.values()], dtype=float) player_movement_vectors = np.array( [ p.current_movement if self.env_time <= p.movement_until else [0, 0] for p in self.players.values() ], dtype=float, ) number_players = len(player_positions) targeted_positions = player_positions + ( player_movement_vectors * (self.player_movement_speed * d_time) ) # Collisions player between player distances_players_after_scipy = distance_matrix( targeted_positions, targeted_positions ) player_diff_vecs = -( player_positions[:, np.newaxis, :] - player_positions[np.newaxis, :, :] ) collision_idxs = distances_players_after_scipy < (2 * self.player_radius) eye_idxs = np.eye(number_players, number_players, dtype=bool) collision_idxs[eye_idxs] = False # Player push players around player_diff_vecs[collision_idxs == False] = 0 push_vectors = np.sum(player_diff_vecs, axis=0) updated_movement = push_vectors + player_movement_vectors new_positions = player_positions + ( updated_movement * (self.player_movement_speed * d_time) ) # Collisions players counters counter_diff_vecs = ( new_positions[:, np.newaxis, :] - self.counter_positions[np.newaxis, :, :] ) counter_distances = np.max((np.abs(counter_diff_vecs)), axis=2) # counter_distances = np.linalg.norm(counter_diff_vecs, axis=2) closest_counter_positions = self.counter_positions[ np.argmin(counter_distances, axis=1) ] nearest_counter_to_player = closest_counter_positions - new_positions collided = np.min(counter_distances, axis=1) < self.player_radius + 0.5 relevant_axes = np.abs(nearest_counter_to_player).argmax(axis=1) for idx, player in enumerate(player_positions): axis = relevant_axes[idx] if collided[idx]: # collide with counter left or top if nearest_counter_to_player[idx][axis] < 0: updated_movement[idx, axis] = max(updated_movement[idx, axis], 0) # collide with counter right or bottom if nearest_counter_to_player[idx][axis] > 0: updated_movement[idx, axis] = min(updated_movement[idx, axis], 0) new_positions = player_positions + ( updated_movement * (self.player_movement_speed * d_time) ) # Check if pushed players collide with counters or second closest is to close counter_diff_vecs = ( new_positions[:, np.newaxis, :] - self.counter_positions[np.newaxis, :, :] ) counter_distances = np.max((np.abs(counter_diff_vecs)), axis=2) collided2 = np.min(counter_distances, axis=1) < self.player_radius + 0.5 # player do not move if they collide after pushing/sliding new_positions[collided2] = player_positions[collided2] # Players that pushed the player that can not be pushed do also no movement # in the future these players could slide around the player? for idx, collides in enumerate(collided2): if collides: new_positions[collision_idxs[idx]] = player_positions[ collision_idxs[idx] ] # Check if two moving players collide into each other: No movement (Future: slide?) if PREVENT_SQUEEZING_INTO_OTHER_PLAYERS: distances_players_after_scipy = distance_matrix( new_positions, new_positions ) collision_idxs = distances_players_after_scipy < (2 * self.player_radius) collision_idxs[eye_idxs] = False collision_idxs = np.any(collision_idxs, axis=1) new_positions[collision_idxs] = player_positions[collision_idxs] # Collisions player world borders new_positions = np.clip( new_positions, self.world_borders_lower + self.player_radius, self.world_borders_upper - self.player_radius, ) for idx, p in enumerate(self.players.values()): if not (new_positions[idx] == player_positions[idx]).all(): p.turn(player_movement_vectors[idx]) p.move_abs(new_positions[idx]) def add_player(self, player_name: str, pos: npt.NDArray = None): """Add a player to the environment. Args: player_name: The id/name of the player to reference actions and in the state. pos: The optional init position of the player. """ # TODO check if the player name already exists in the environment and do not overwrite player. log.debug(f"Add player {player_name} to the game") player = Player( player_name, player_config=PlayerConfig( **( self.environment_config["player_config"] if "player_config" in self.environment_config else {} ) ), pos=pos, ) self.players[player.name] = player if player.pos is None: if len(self.designated_player_positions) > 0: free_idx = self.random.randint( 0, len(self.designated_player_positions) - 1 ) player.move_abs(self.designated_player_positions[free_idx]) del self.designated_player_positions[free_idx] elif len(self.free_positions) > 0: free_idx = self.random.randint(0, len(self.free_positions) - 1) player.move_abs(self.free_positions[free_idx]) del self.free_positions[free_idx] else: log.debug("No free positions left in kitchens") player.update_facing_point() self.set_collision_arrays() self.hook(PLAYER_ADDED, player_name=player_name, pos=pos) def detect_collision_world_bounds(self, player: Player): """Checks for detections of the player and the world bounds. Args: player: The player which to not let escape the world. Returns: True if the player touches the world bounds, False if not. """ collisions_lower = any( (player.pos - (player.radius)) < [self.world_borders_x[0], self.world_borders_y[0]] ) collisions_upper = any( (player.pos + (player.radius)) > [self.world_borders_x[1], self.world_borders_y[1]] ) return collisions_lower or collisions_upper def step(self, passed_time: timedelta): """Performs a step of the environment. Affects time based events such as cooking or cutting things, orders and time limits. """ # self.hook(PRE_STEP, passed_time=passed_time) self.env_time += passed_time if self.game_ended: self.hook(GAME_ENDED_STEP) else: for player in self.players.values(): player.progress(passed_time, self.env_time) self.perform_movement(passed_time) for counter in self.progressing_counters: counter.progress(passed_time=passed_time, now=self.env_time) self.order_and_score.progress(passed_time=passed_time, now=self.env_time) for effect_manager in self.effect_manager.values(): effect_manager.progress(passed_time=passed_time, now=self.env_time) # self.hook(POST_STEP, passed_time=passed_time) def get_state(self): """Get the current state of the game environment. The state here is accessible by the current python objects. Returns: Dict of lists of the current relevant game objects. """ return { "players": self.players, "counters": self.counters, "score": self.order_and_score.score, "orders": self.order_and_score.open_orders, "ended": self.game_ended, "env_time": self.env_time, "remaining_time": max(self.env_time_end - self.env_time, timedelta(0)), } def get_json_state(self, player_id: str = None) -> str: if player_id in self.players: self.hook(PRE_STATE, player_id=player_id) state = { "players": [p.to_dict() for p in self.players.values()], "counters": [c.to_dict() for c in self.counters], "kitchen": {"width": self.kitchen_width, "height": self.kitchen_height}, "score": self.order_and_score.score, "orders": self.order_and_score.order_state(), "ended": self.game_ended, "env_time": self.env_time.isoformat(), "remaining_time": max( self.env_time_end - self.env_time, timedelta(0) ).total_seconds(), "view_restriction": { "direction": self.players[player_id].facing_direction.tolist(), "position": self.players[player_id].pos.tolist(), "angle": self.player_view_angle, "counter_mask": None, } if self.player_view_restricted else None, "info_msg": [ (msg["msg"], msg["level"]) for msg in self.info_msgs_per_player[player_id] if msg["start_time"] < self.env_time and msg["end_time"] > self.env_time ], } self.hook(STATE_DICT, state=state, player_id=player_id) json_data = json.dumps(state) self.hook(JSON_STATE, json_data=json_data, player_id=player_id) assert StateRepresentation.model_validate_json(json_data=json_data) return json_data raise ValueError(f"No valid {player_id=}") def reset_env_time(self): """Reset the env time to the initial time, defined by `create_init_env_time`.""" self.hook(PRE_RESET_ENV_TIME) self.env_time = create_init_env_time() self.hook(POST_RESET_ENV_TIME) log.debug(f"Reset env time to {self.env_time}") def register_callback_for_hook(self, hook_ref: str | list[str], callback: Callable): self.hook.register_callback(hook_ref, callback) def extra_setup_functions(self): if self.environment_config["extra_setup_functions"]: for function_name, function_def in self.environment_config[ "extra_setup_functions" ].items(): log.info(f"Setup function {function_name}") function_def["func"]( name=function_name, env=self, **function_def["kwargs"] )