# overcooked_environment.py -- authored by Fabian Heinrich
from __future__ import annotations
import dataclasses
import inspect
import json
import logging
import sys
from datetime import timedelta, datetime
from enum import Enum
from pathlib import Path
from random import Random
from typing import Literal, TypedDict, Callable, Tuple
import numpy as np
import numpy.typing as npt
import yaml
from scipy.spatial import distance_matrix
from overcooked_simulator.counter_factory import CounterFactory
from overcooked_simulator.counters import (
Counter,
PlateConfig,
)
from overcooked_simulator.effect_manager import EffectManager
from overcooked_simulator.game_items import (
ItemInfo,
ItemType,
)
from overcooked_simulator.hooks import (
ITEM_INFO_LOADED,
LAYOUT_FILE_PARSED,
ENV_INITIALIZED,
PRE_PERFORM_ACTION,
POST_PERFORM_ACTION,
PLAYER_ADDED,
GAME_ENDED_STEP,
PRE_STATE,
STATE_DICT,
JSON_STATE,
PRE_RESET_ENV_TIME,
POST_RESET_ENV_TIME,
Hooks,
ACTION_ON_NOT_REACHABLE_COUNTER,
ACTION_PUT,
ACTION_INTERACT_START,
ITEM_INFO_CONFIG,
)
from overcooked_simulator.order import (
OrderAndScoreManager,
OrderConfig,
)
from overcooked_simulator.player import Player, PlayerConfig
from overcooked_simulator.state_representation import StateRepresentation
from overcooked_simulator.utils import create_init_env_time, get_closest
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# NOTE(review): not referenced anywhere in the code visible in this file -- confirm it is still used.
FOG_OF_WAR = True
# When True, `Environment.perform_movement` reverts the step of every player that would
# still overlap another player after pushing and counter collisions were resolved.
PREVENT_SQUEEZING_INTO_OTHER_PLAYERS = False
class ActionType(Enum):
    """The 3 different types of valid actions. They can be extended via the `Action.action_data` attribute.

    The member values are the strings accepted by `Action.__post_init__` when an action type is
    given as a plain string.
    """

    MOVEMENT = "movement"
    """move the agent."""
    PUT = "pickup"
    """interaction type 1, e.g., for pickup or drop off. Maybe other words: transplace?"""
    # TODO change value to put
    INTERACT = "interact"
    """interaction type 2, e.g., for progressing. Start and stop interaction via `keydown` and `keyup` actions."""
class InterActionData(Enum):
    """The data for the interaction action: `ActionType.INTERACT`."""

    START = "keydown"
    "start an interaction."
    STOP = "keyup"
    "stop an interaction without moving away."
@dataclasses.dataclass
class Action:
    """One command of one player: who acts, which kind of action, and its payload."""

    player: str
    """Id of the player."""
    action_type: ActionType
    """Type of the action to perform. Defines what action data is valid."""
    action_data: npt.NDArray[float] | InterActionData | Literal["pickup"]
    """Data for the action, e.g., movement vector or start and stop interaction."""
    duration: float | int = 0
    """Duration of the action (relevant for movement)"""

    def __repr__(self):
        """Compact representation: player, action type value, payload and duration."""
        kind = self.action_type.value
        return f"Action({self.player},{kind},{self.action_data},{self.duration})"

    def __post_init__(self):
        """Convert plain strings into their enum members.

        A string `action_type` becomes an `ActionType`; a string `action_data` other than
        `"pickup"` becomes an `InterActionData`. Unknown strings raise `ValueError` from
        the enum lookup.
        """
        kind, payload = self.action_type, self.action_data
        if isinstance(kind, str):
            self.action_type = ActionType(kind)
        if isinstance(payload, str) and payload != "pickup":
            self.action_data = InterActionData(payload)
# TODO Abstract base class for different environments
class EnvironmentConfig(TypedDict):
    """Shape of the parsed environment config that `Environment.__init__` reads."""

    # Keyword arguments for `PlateConfig`; `Environment.__init__` treats this section as optional.
    plates: PlateConfig
    # `time_limit_seconds` is added to the start time to get `Environment.env_time_end`.
    game: dict[Literal["time_limit_seconds"], int]
    # If `all` is truthy every item of type Meal is allowed, otherwise only the names in `list`.
    meals: dict[Literal["all"] | Literal["list"], bool | list[str]]
    # Passed to `OrderAndScoreManager` as `order_config`.
    orders: OrderConfig
    # Keys read in this file: `restricted_view`, `view_angle`, `view_range`,
    # `player_speed_units_per_seconds`, `radius`; the section is also unpacked into `PlayerConfig`.
    player_config: PlayerConfig
    # Passed to `CounterFactory` as `layout_chars_config`.
    layout_chars: dict[str, str]
    # Function name -> dict with the callable under `func` and its keyword arguments under `kwargs`.
    extra_setup_functions: dict[str, dict]
    # Passed to `CounterFactory` as `effect_manager_config`.
    effect_manager: dict
class Environment:
"""Environment class which handles the game logic for the overcooked-inspired environment.
Handles player movement, collision-detection, counters, cooking processes, recipes, incoming orders, time.
"""
PAUSED = None
def __init__(
    self,
    env_config: Path | str,
    layout_config: Path | str,
    item_info: Path | str,
    as_files: bool = True,
    env_name: str = "overcooked_sim",
    seed: int = 56789223842348,
):
    """Build the environment from the environment config, the layout and the item info.

    Args:
        env_config: The environment config (YAML). A path if `as_files`, else the content.
        layout_config: The kitchen layout. A path if `as_files`, else the content.
        item_info: The item info (YAML). A path if `as_files`, else the content.
        as_files: Whether the three configs are paths to files instead of their content.
        env_name: Reference to the run, e.g., the env id.
        seed: Seed of the `Random` instance used by the environment.
    """
    self.env_name = env_name
    """Reference to the run. E.g, the env id."""
    self.env_time: datetime = create_init_env_time()
    """the internal time of the environment. An environment starts always with the time from
    `create_init_env_time`."""
    self.random: Random = Random(seed)
    """Random instance."""
    self.hook: Hooks = Hooks(self)
    """Hook manager. Register callbacks and create hook points with additional kwargs."""
    self.players: dict[str, Player] = {}
    """the player, keyed by their id/name."""
    self.as_files = as_files
    """Are the configs just the path to the files."""

    if self.as_files:
        with open(env_config, "r") as file:
            env_config = file.read()
    # NOTE(review): `yaml.Loader` can construct arbitrary Python objects, so only trusted
    # configs may be loaded. Presumably required because `extra_setup_functions` stores
    # callables in the config -- confirm before switching to a safe loader.
    self.environment_config: EnvironmentConfig = yaml.load(
        env_config, Loader=yaml.Loader
    )
    """The config of the environment. All environment specific attributes is configured here."""

    view_config = self.environment_config["player_config"]
    self.player_view_restricted = view_config["restricted_view"]
    if self.player_view_restricted:
        # Only defined when the view is restricted; `get_json_state` reads them only then.
        self.player_view_angle = view_config["view_angle"]
        self.player_view_range = view_config["view_range"]

    self.extra_setup_functions()

    self.layout_config = layout_config
    """The layout config for the environment"""
    # self.counter_side_length = 1  # -> this changed! is 1 now

    self.item_info: dict[str, ItemInfo] = self.load_item_info(item_info)
    """The loaded item info dict. Keys are the item names."""
    self.hook(ITEM_INFO_LOADED, item_info=item_info, as_files=as_files)
    # self.validate_item_info()

    if self.environment_config["meals"]["all"]:
        self.allowed_meal_names = {
            item
            for item, info in self.item_info.items()
            if info.type == ItemType.Meal
        }
    else:
        self.allowed_meal_names = set(self.environment_config["meals"]["list"])
    """The allowed meals depend on the `environment_config.yml` configured behaviour. Either all meals that
    are possible or only a limited subset."""

    self.order_and_score = OrderAndScoreManager(
        order_config=self.environment_config["orders"],
        available_meals={
            item: info
            for item, info in self.item_info.items()
            if info.type == ItemType.Meal and item in self.allowed_meal_names
        },
        hook=self.hook,
        random=self.random,
    )
    """The manager for the orders and score update."""

    self.kitchen_height: int = 0
    """The height of the kitchen, is set by the `Environment.parse_layout_file` method"""
    self.kitchen_width: int = 0
    """The width of the kitchen, is set by the `Environment.parse_layout_file` method"""

    self.counter_factory = CounterFactory(
        layout_chars_config=self.environment_config["layout_chars"],
        item_info=self.item_info,
        serving_window_additional_kwargs={
            "meals": self.allowed_meal_names,
            "env_time_func": self.get_env_time,
        },
        plate_config=PlateConfig(**self.environment_config.get("plates", {})),
        order_and_score=self.order_and_score,
        effect_manager_config=self.environment_config["effect_manager"],
        hook=self.hook,
        random=self.random,
    )

    (
        self.counters,
        self.designated_player_positions,
        self.free_positions,
    ) = self.parse_layout_file()
    self.hook(LAYOUT_FILE_PARSED)

    self.counter_positions = np.array([c.pos for c in self.counters])
    self.world_borders = np.array(
        [[-0.5, self.kitchen_width - 0.5], [-0.5, self.kitchen_height - 0.5]],
        dtype=float,
    )

    self.player_movement_speed = self.environment_config["player_config"][
        "player_speed_units_per_seconds"
    ]
    self.player_radius = self.environment_config["player_config"]["radius"]

    # Every class of the counters module that offers a `progress` attribute.
    counters_module = sys.modules["overcooked_simulator.counters"]
    progress_counter_classes = [
        cl
        for cl in dict(inspect.getmembers(counters_module, inspect.isclass)).values()
        if hasattr(cl, "progress")
    ]
    self.progressing_counters = [
        c for c in self.counters if c.__class__ in progress_counter_classes
    ]
    """Counters that needs to be called in the step function via the `progress` method."""

    self.order_and_score.create_init_orders(self.env_time)
    self.start_time = self.env_time
    """The relative env time when it started."""
    self.env_time_end = self.env_time + timedelta(
        seconds=self.environment_config["game"]["time_limit_seconds"]
    )
    """The relative env time when it will stop/end"""
    log.debug(f"End time: {self.env_time_end}")

    self.effect_manager: dict[
        str, EffectManager
    ] = self.counter_factory.setup_effect_manger(self.counters)

    self.hook(
        ENV_INITIALIZED,
        environment_config=env_config,
        layout_config=self.layout_config,
        seed=seed,
        env_start_time_worldtime=datetime.now(),
    )
@property
def game_ended(self) -> bool:
    """True once the env time has reached or passed `Environment.env_time_end`."""
    time_is_up = self.env_time_end <= self.env_time
    return time_is_up
def set_collision_arrays(self):
    """Store the world borders once per player for the clipping in `perform_movement`.

    Sets `world_borders_lower` and `world_borders_upper`, each with one row
    `[x_border, y_border]` per player currently in `self.players`.
    """
    number_players = len(self.players)
    lower_row = self.world_borders[:, 0][np.newaxis, :]
    upper_row = self.world_borders[:, 1][np.newaxis, :]
    self.world_borders_lower = np.repeat(lower_row, number_players, axis=0)
    self.world_borders_upper = np.repeat(upper_row, number_players, axis=0)
def get_env_time(self):
    """Return the current internal time of the environment.

    The env time always starts at the value of `create_init_env_time`. This accessor exists
    so that a reference to it can be handed to the serving window (`env_time_func`).
    """
    current_time = self.env_time
    return current_time
def load_item_info(self, data) -> dict[str, ItemInfo]:
    """Load `item_info.yml`, create ItemInfo classes and replace equipment strings with item infos.

    Args:
        data: Path to the item info file if `self.as_files`, else the YAML content itself.

    Returns:
        The `ItemInfo` objects keyed by item name.
    """
    if self.as_files:
        with open(data, "r") as file:
            data = file.read()
    self.hook(ITEM_INFO_CONFIG, item_info_config=data)

    raw_items = yaml.safe_load(data)
    item_lookup = {
        item_name: ItemInfo(name=item_name, **item_kwargs)
        for item_name, item_kwargs in raw_items.items()
    }
    # Second pass: every item exists now, so equipment names can be resolved to objects.
    for info in item_lookup.values():
        if info.equipment:
            info.equipment = item_lookup[info.equipment]
    return item_lookup
def validate_item_info(self):
    """TODO: validate the loaded item info.

    Not implemented yet: calling this method always raises `NotImplementedError`. The only
    call visible in this file (in `__init__`) is commented out, and the commented-out lines
    following the `raise` are an unfinished draft.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
# infos = {t: [] for t in ItemType}
# graph = nx.DiGraph()
# for info in self.item_info.values():
# infos[info.type].append(info)
# graph.add_node(info.name)
# match info.type:
# case ItemType.Ingredient:
# if info.is_cuttable:
# graph.add_edge(
# info.name, info.finished_progress_name[:-1] + info.name
# )
# case ItemType.Equipment:
# ...
# case ItemType.Meal:
# if info.equipment is not None:
# graph.add_edge(info.equipment.name, info.name)
# for ingredient in info.needs:
# graph.add_edge(ingredient, info.name)
# graph = nx.DiGraph()
# for item_name, item_info in self.item_info.items():
# graph.add_node(item_name, type=item_info.type.name)
# if len(item_info.equipment) == 0:
# for item in item_info.needs:
# graph.add_edge(item, item_name)
# else:
# for item in item_info.needs:
# for equipment in item_info.equipment:
# graph.add_edge(item, equipment)
# graph.add_edge(equipment, item_name)
# plt.figure(figsize=(10, 10))
# pos = nx.nx_agraph.graphviz_layout(graph, prog="twopi", args="")
# nx.draw(graph, pos=pos, with_labels=True, node_color="white", node_size=500)
# print(nx.multipartite_layout(graph, subset_key="type", align="vertical"))
# pos = {
# node: (
# len(nx.ancestors(graph, node)) - len(nx.descendants(graph, node)),
# y,
# )
# for y, node in enumerate(graph)
# }
# nx.draw(
# graph,
# pos=pos,
# with_labels=True,
# node_shape="s",
# node_size=500,
# node_color="white",
# )
# TODO add colors for ingredients, equipment and meals
# plt.show()
def parse_layout_file(
    self,
) -> Tuple[list[Counter], list[npt.NDArray], list[npt.NDArray]]:
    """Creates layout of kitchen counters in the environment based on layout file.

    Counters are arranged in a fixed size grid starting at [0,0]: the character in column
    `x` of row `y` is placed at position `[x, y]`. Spaces in the layout are ignored and
    lines that are empty after removing the spaces are skipped.

    Sets `self.kitchen_width` (taken from the first layout row) and `self.kitchen_height`
    and, if `self.as_files`, replaces `self.layout_config` with the content of the file.

    Returns:
        The created counters, the designated player positions ("Agent" characters) and
        the free positions ("Free" characters).

    Raises:
        ValueError: If the layout contains no non-empty line.
        AssertionError: If a character can not be mapped by the counter factory.
    """
    starting_at: float = 0.0
    counters: list[Counter] = []
    designated_player_positions: list[npt.NDArray] = []
    free_positions: list[npt.NDArray] = []

    if self.as_files:
        with open(self.layout_config, "r") as layout_file:
            self.layout_config = layout_file.read()

    # Remove the ignored spaces BEFORE measuring and filtering the rows. Otherwise the
    # kitchen width would count the spaces and whitespace-only lines would become empty rows.
    rows = [
        raw_line.replace(" ", "") for raw_line in self.layout_config.split("\n")
    ]
    rows = [row for row in rows if row]
    if not rows:
        raise ValueError("The layout does not contain any non-empty line.")

    grid = []
    for row_idx, row in enumerate(rows):
        current_y = starting_at + row_idx
        grid_line = []
        for col_idx, character in enumerate(row):
            current_x = starting_at + col_idx
            character = character.capitalize()
            pos = np.array([current_x, current_y])
            assert self.counter_factory.can_map(
                character
            ), f"{character=} in layout file can not be mapped"
            if self.counter_factory.is_counter(character):
                counters.append(
                    self.counter_factory.get_counter_object(character, pos)
                )
                grid_line.append(1)
            else:
                grid_line.append(0)
                kind = self.counter_factory.map_not_counter(character)
                if kind == "Agent":
                    designated_player_positions.append(pos)
                elif kind == "Free":
                    free_positions.append(np.array([current_x, current_y]))
        grid.append(grid_line)

    self.kitchen_width: float = len(rows[0]) + starting_at
    self.kitchen_height = len(rows) + starting_at

    self.determine_counter_orientations(
        counters, grid, np.array([self.kitchen_width / 2, self.kitchen_height / 2])
    )
    self.counter_factory.post_counter_setup(counters)
    return counters, designated_player_positions, free_positions
def determine_counter_orientations(self, counters, grid, kitchen_center):
    """Set the orientation of every counter based on its free neighbour cells.

    For a counter with at least one free (non-counter) neighbour cell, the orientation is
    the free neighbour offset that is closest to the normalised vector `c.pos - kitchen_center`.
    NOTE(review): that vector points from the kitchen centre towards the counter although
    the variable is called `vector_to_center` -- confirm the intended direction.

    A counter without free neighbours copies the orientation of a previously handled
    counter: in the first grid column the one remembered as first of the row, elsewhere the
    previously processed counter. `[1, 0]` is used for the top-left counter and whenever no
    such reference counter exists yet.

    Args:
        counters: The counters; each needs `pos`, `orientation` and `set_orientation`.
        grid: Rows of the layout, 1 for a counter cell and 0 otherwise.
        kitchen_center: The center point of the kitchen.
    """
    grid = np.array(grid).T  # index as grid[x][y]
    grid_width, grid_height = grid.shape[0], grid.shape[1]
    neighbour_offsets = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]], dtype=int)

    def orientation_of(reference):
        # Fall back to the default instead of failing on a missing reference counter.
        if reference is None:
            return np.array([1, 0])
        return reference.orientation

    last_counter = None
    fst_counter_in_row = None
    for c in counters:
        grid_idx = np.floor(c.pos).astype(int)

        neighbours_free = []
        for offset in neighbour_offsets:
            n_x, n_y = grid_idx + offset
            inside_grid = 0 <= n_x < grid_width and 0 <= n_y < grid_height
            if inside_grid and grid[n_x][n_y] == 0:
                neighbours_free.append(offset)

        if neighbours_free:
            vector_to_center = c.pos - kitchen_center
            norm = np.linalg.norm(vector_to_center)
            if norm > 0:
                # A counter exactly on the centre keeps the zero vector (no division by zero).
                vector_to_center = vector_to_center / norm
            # argmin needs a real sequence: a generator is wrapped as a single object and
            # would always yield index 0.
            n_idx = int(
                np.argmin(
                    [np.linalg.norm(vector_to_center - n) for n in neighbours_free]
                )
            )
            c.set_orientation(neighbours_free[n_idx])
        elif grid_idx[0] == 0:
            if grid_idx[1] == 0:
                # counter top left
                c.set_orientation(np.array([1, 0]))
            else:
                c.set_orientation(orientation_of(fst_counter_in_row))
            fst_counter_in_row = c
        else:
            c.set_orientation(orientation_of(last_counter))

        last_counter = c
def perform_action(self, action: Action):
    """Performs an action of a player in the environment. Maps different types of action inputs to the
    correct execution of the players.

    Possible action types are movement, pickup and interact actions. Movement actions only
    set the movement of the player; the other actions target the counter the player faces.

    Args:
        action: The action to be performed
    """
    assert action.player in self.players.keys(), "Unknown player."
    self.hook(PRE_PERFORM_ACTION, action=action)

    player = self.players[action.player]
    if action.action_type == ActionType.MOVEMENT:
        movement_until = self.env_time + timedelta(seconds=action.duration)
        player.set_movement(action.action_data, movement_until)
    else:
        counter = self.get_facing_counter(player)
        if not player.can_reach(counter):
            self.hook(
                ACTION_ON_NOT_REACHABLE_COUNTER, action=action, counter=counter
            )
        elif action.action_type == ActionType.PUT:
            player.put_action(counter)
            self.hook(ACTION_PUT, action=action, counter=counter)
        elif (
            action.action_type == ActionType.INTERACT
            and action.action_data == InterActionData.START
        ):
            player.perform_interact_start(counter)
            self.hook(ACTION_INTERACT_START, action=action, counter=counter)
        # Stopping an interaction does not depend on the reachability of the counter.
        if action.action_data == InterActionData.STOP:
            player.perform_interact_stop()

    self.hook(POST_PERFORM_ACTION, action=action)
def get_facing_counter(self, player: Player):
    """Determines the counter which the player is looking at.

    Looks up the counter closest to the facing point of the player (`player.facing_point`)
    via `get_closest`.

    Args:
        player: The player for which to find the facing counter.

    Returns:
        The counter returned by `get_closest` for the facing point of the player.
    """
    return get_closest(player.facing_point, self.counters)
def perform_movement(self, duration: timedelta):
    """Moves all players according to their current movement for the given duration.

    A player only moves while the env time has not passed its `movement_until`. The
    movement is resolved in these steps:

    1. Players whose targeted positions overlap push each other.
    2. Players colliding with a counter lose the movement component towards that counter
       on the axis of the larger distance (sliding along counters).
    3. Players that still collide with a counter do not move, and neither do the players
       that pushed them.
    4. If `PREVENT_SQUEEZING_INTO_OTHER_PLAYERS` is set, players that would still overlap
       do not move.
    5. Positions are clipped to the world borders.

    Every player is turned towards its movement vector. Does nothing if there are no players.

    Args:
        duration: The duration for how long the movement to perform.
    """
    if not self.players:
        # Nothing to move; the distance matrix below needs at least one player row.
        return

    d_time = duration.total_seconds()
    step_length = self.player_movement_speed * d_time
    collision_distance = 2 * self.player_radius
    # Counters are compared via the max-axis distance between the centers.
    counter_collision_distance = self.player_radius + 0.5

    player_positions = np.array([p.pos for p in self.players.values()], dtype=float)
    player_movement_vectors = np.array(
        [
            p.current_movement if self.env_time <= p.movement_until else [0, 0]
            for p in self.players.values()
        ],
        dtype=float,
    )
    number_players = len(player_positions)

    targeted_positions = player_positions + (player_movement_vectors * step_length)

    # Collisions player between player
    distances_players_after_scipy = distance_matrix(
        targeted_positions, targeted_positions
    )
    player_diff_vecs = -(
        player_positions[:, np.newaxis, :] - player_positions[np.newaxis, :, :]
    )
    collision_idxs = distances_players_after_scipy < collision_distance
    eye_idxs = np.eye(number_players, number_players, dtype=bool)
    collision_idxs[eye_idxs] = False  # a player does not collide with itself

    # Player push players around
    player_diff_vecs[~collision_idxs] = 0
    push_vectors = np.sum(player_diff_vecs, axis=0)

    updated_movement = push_vectors + player_movement_vectors
    new_positions = player_positions + (updated_movement * step_length)

    # Collisions players counters
    counter_diff_vecs = (
        new_positions[:, np.newaxis, :] - self.counter_positions[np.newaxis, :, :]
    )
    counter_distances = np.max((np.abs(counter_diff_vecs)), axis=2)
    closest_counter_positions = self.counter_positions[
        np.argmin(counter_distances, axis=1)
    ]
    nearest_counter_to_player = closest_counter_positions - new_positions
    collided = np.min(counter_distances, axis=1) < counter_collision_distance
    relevant_axes = np.abs(nearest_counter_to_player).argmax(axis=1)

    for idx in np.flatnonzero(collided):
        axis = relevant_axes[idx]
        offset_to_counter = nearest_counter_to_player[idx][axis]
        # collide with counter left or top
        if offset_to_counter < 0:
            updated_movement[idx, axis] = max(updated_movement[idx, axis], 0)
        # collide with counter right or bottom
        if offset_to_counter > 0:
            updated_movement[idx, axis] = min(updated_movement[idx, axis], 0)
    new_positions = player_positions + (updated_movement * step_length)

    # Check if pushed players collide with counters or second closest is to close
    counter_diff_vecs = (
        new_positions[:, np.newaxis, :] - self.counter_positions[np.newaxis, :, :]
    )
    counter_distances = np.max((np.abs(counter_diff_vecs)), axis=2)
    collided2 = np.min(counter_distances, axis=1) < counter_collision_distance
    # player do not move if they collide after pushing/sliding
    new_positions[collided2] = player_positions[collided2]
    # Players that pushed the player that can not be pushed do also no movement
    # in the future these players could slide around the player?
    for idx in np.flatnonzero(collided2):
        new_positions[collision_idxs[idx]] = player_positions[collision_idxs[idx]]

    # Check if two moving players collide into each other: No movement (Future: slide?)
    if PREVENT_SQUEEZING_INTO_OTHER_PLAYERS:
        distances_players_after_scipy = distance_matrix(
            new_positions, new_positions
        )
        collision_idxs = distances_players_after_scipy < collision_distance
        collision_idxs[eye_idxs] = False
        collision_idxs = np.any(collision_idxs, axis=1)
        new_positions[collision_idxs] = player_positions[collision_idxs]

    # Collisions player world borders
    new_positions = np.clip(
        new_positions,
        self.world_borders_lower + self.player_radius,
        self.world_borders_upper - self.player_radius,
    )

    for idx, p in enumerate(self.players.values()):
        if not (new_positions[idx] == player_positions[idx]).all():
            p.move_abs(new_positions[idx])
        p.turn(player_movement_vectors[idx])
def add_player(self, player_name: str, pos: npt.NDArray = None):
    """Add a player to the environment.

    A player without a position is placed on a random designated player position or, if
    none is left, on a random free position; the chosen position is removed from its list.

    Args:
        player_name: The id/name of the player to reference actions and in the state.
        pos: The optional init position of the player.
    """
    # TODO check if the player name already exists in the environment and do not overwrite player.
    log.debug(f"Add player {player_name} to the game")

    config_kwargs = self.environment_config.get("player_config", {})
    player = Player(
        player_name,
        player_config=PlayerConfig(**config_kwargs),
        pos=pos,
    )
    self.players[player.name] = player

    if player.pos is None:
        # Designated positions are preferred over the other free positions.
        for candidates in (self.designated_player_positions, self.free_positions):
            if candidates:
                free_idx = self.random.randint(0, len(candidates) - 1)
                player.move_abs(candidates[free_idx])
                del candidates[free_idx]
                break
        else:
            log.debug("No free positions left in kitchens")
        player.update_facing_point()

    self.set_collision_arrays()
    self.hook(PLAYER_ADDED, player_name=player_name, pos=pos)
def detect_collision_world_bounds(self, player: Player):
    """Checks for collisions of the player and the world bounds.

    The bounds are read from `self.world_borders`, whose rows are the x and y axis and
    whose columns are the lower and upper border.

    Args:
        player: The player which to not let escape the world.

    Returns: True if the player (a circle with `player.radius` around `player.pos`)
        reaches beyond a world border on any axis, False if not.
    """
    lower_borders = self.world_borders[:, 0]
    upper_borders = self.world_borders[:, 1]
    collisions_lower = np.any((player.pos - player.radius) < lower_borders)
    collisions_upper = np.any((player.pos + player.radius) > upper_borders)
    return bool(collisions_lower or collisions_upper)
def step(self, passed_time: timedelta):
    """Performs a step of the environment. Affects time based events such as cooking or cutting things, orders
    and time limits.

    The env time is always advanced. Once the game has ended only the `GAME_ENDED_STEP`
    hook is triggered; otherwise players, movement, progressing counters, orders and
    effect managers are progressed in that order.

    Args:
        passed_time: The time that passed since the last step.
    """
    # self.hook(PRE_STEP, passed_time=passed_time)
    self.env_time += passed_time

    if self.game_ended:
        self.hook(GAME_ENDED_STEP)
        return

    for player in self.players.values():
        player.progress(passed_time, self.env_time)

    self.perform_movement(passed_time)

    for counter in self.progressing_counters:
        counter.progress(passed_time=passed_time, now=self.env_time)
    self.order_and_score.progress(passed_time=passed_time, now=self.env_time)
    for effect_manager in self.effect_manager.values():
        effect_manager.progress(passed_time=passed_time, now=self.env_time)
    # self.hook(POST_STEP, passed_time=passed_time)
def get_state(self):
    """Get the current state of the game environment. The state here is accessible by the current python objects.

    Returns: Dict with the players, counters, score, open orders, whether the game ended,
        the env time and the remaining time (never below zero).
    """
    time_left = max(self.env_time_end - self.env_time, timedelta(0))
    manager = self.order_and_score
    state = {
        "players": self.players,
        "counters": self.counters,
        "score": manager.score,
        "orders": manager.open_orders,
        "ended": self.game_ended,
        "env_time": self.env_time,
        "remaining_time": time_left,
    }
    return state
def get_json_state(self, player_id: str = None) -> str:
    """Get the current state of the game as JSON string for the given player.

    The resulting JSON is checked with `StateRepresentation.model_validate_json`.

    Args:
        player_id: The id of the player for which the state is requested.

    Returns:
        The JSON encoded state.

    Raises:
        ValueError: If `player_id` is not the id of a player in the environment.
    """
    if player_id not in self.players:
        raise ValueError(f"No valid {player_id=}")

    self.hook(PRE_STATE, player_id=player_id)
    time_left = max(self.env_time_end - self.env_time, timedelta(0))
    state = {
        "players": [p.to_dict() for p in self.players.values()],
        "counters": [c.to_dict() for c in self.counters],
        "kitchen": {"width": self.kitchen_width, "height": self.kitchen_height},
        "score": self.order_and_score.score,
        "orders": self.order_and_score.order_state(),
        "ended": self.game_ended,
        "env_time": self.env_time.isoformat(),
        "remaining_time": time_left.total_seconds(),
    }
    # The view restriction stays the last key of the state.
    if self.player_view_restricted:
        viewer = self.players[player_id]
        state["view_restriction"] = {
            "direction": viewer.facing_direction.tolist(),
            "position": viewer.pos.tolist(),
            "angle": self.player_view_angle,
            "counter_mask": None,
            "range": self.player_view_range,
        }
    else:
        state["view_restriction"] = None

    self.hook(STATE_DICT, state=state, player_id=player_id)
    json_data = json.dumps(state)
    self.hook(JSON_STATE, json_data=json_data, player_id=player_id)
    assert StateRepresentation.model_validate_json(json_data=json_data)
    return json_data
def reset_env_time(self):
    """Reset the env time to the initial time, defined by `create_init_env_time`.

    The `PRE_RESET_ENV_TIME` hook is triggered before and the `POST_RESET_ENV_TIME` hook
    after the env time was replaced.
    """
    self.hook(PRE_RESET_ENV_TIME)
    initial_time = create_init_env_time()
    self.env_time = initial_time
    self.hook(POST_RESET_ENV_TIME)
    log.debug(f"Reset env time to {self.env_time}")
def register_callback_for_hook(self, hook_ref: str | list[str], callback: Callable):
    """Register a callback at the hook manager of the environment.

    Args:
        hook_ref: The hook reference, or a list of hook references, handed to
            `Hooks.register_callback`.
        callback: The callable handed to `Hooks.register_callback`.
    """
    hook_manager = self.hook
    hook_manager.register_callback(hook_ref, callback)
def extra_setup_functions(self):
    """Call the extra setup functions that are defined in the environment config.

    Each entry of the `extra_setup_functions` section maps a function name to a dict with
    the callable under `func` and its keyword arguments under `kwargs`. Every callable is
    called with `name`, `env` (this environment) and its keyword arguments.

    The section is optional: a missing or empty section does nothing, and a missing or
    empty `kwargs` entry is treated as no additional keyword arguments.
    """
    setup_functions = self.environment_config.get("extra_setup_functions")
    if not setup_functions:
        return
    for function_name, function_def in setup_functions.items():
        log.info(f"Setup function {function_name}")
        function_kwargs = function_def.get("kwargs") or {}
        function_def["func"](name=function_name, env=self, **function_kwargs)