# overcooked_environment.py (25.07 KiB)
# Authored by Fabian Heinrich.
from __future__ import annotations
import datetime
import json
import logging
import random
from datetime import timedelta
from pathlib import Path
from threading import Lock
import numpy as np
import numpy.typing as npt
import yaml
from scipy.spatial import distance_matrix
from overcooked_simulator.counters import (
Counter,
CuttingBoard,
Trash,
Dispenser,
ServingWindow,
Stove,
Sink,
PlateDispenser,
SinkAddon,
)
from overcooked_simulator.game_items import (
ItemInfo,
ItemType,
CookingEquipment,
)
from overcooked_simulator.order import OrderAndScoreManager
from overcooked_simulator.player import Player
from overcooked_simulator.utils import create_init_env_time
log = logging.getLogger(__name__)
class Action:
    """A single player command: the acting player, the kind of action and its payload."""

    def __init__(self, player, act_type, action):
        """Store the command and check that its type is known.

        Args:
            player: Key of the acting player (the environment looks it up in its players dict).
            act_type: Either "movement", "pickup" or "interact".
            action: Payload of the command, stored as given.

        Raises:
            AssertionError: If `act_type` is not one of the known action types.
        """
        known_types = ("movement", "pickup", "interact")
        self.player = player
        self.act_type = act_type
        assert self.act_type in known_types, "Unknown action type"
        self.action = action

    def __repr__(self):
        """Return a compact `Action(player,act_type,action)` representation."""
        return f"Action({self.player},{self.act_type},{self.action})"
class Environment:
    """Environment class which handles the game logic for the overcooked-inspired environment.

    Handles player movement, collision detection, counters, cooking processes, recipes, incoming orders
    and the game time.

    TODO: Abstract base class for different environments
    """
def __init__(self, env_config_path: Path, layout_path, item_info_path: Path):
    """Build the environment from its configuration, layout and item files.

    Loads the environment config and the item info, creates the order and score manager,
    defines the mapping from layout symbols to counter factories, parses the layout and
    initialises the environment time and the time limit.

    Args:
        env_config_path: Path to the YAML environment config. The keys read here are
            "meals" ("all" / "list"), "orders", "game" ("time_limit_seconds") and the
            optional "plates".
        layout_path: Path to the text file describing the kitchen layout.
        item_info_path: Path to the YAML file describing the items.
    """
    self.lock = Lock()
    self.players: dict[str, Player] = {}

    with open(env_config_path, "r") as file:
        # NOTE(security): yaml.Loader can construct arbitrary Python objects, so the
        # environment config must come from a trusted source. It is deliberately not
        # replaced by yaml.safe_load here, because configs may rely on Python-specific
        # tags -- confirm before switching.
        self.environment_config = yaml.load(file, Loader=yaml.Loader)

    self.layout_path: Path = layout_path
    # Counters have a side length of 1.
    self.item_info_path: Path = item_info_path
    self.item_info = self.load_item_info()
    self.validate_item_info()

    if self.environment_config["meals"]["all"]:
        self.allowed_meal_names = {
            item for item, info in self.item_info.items() if info.type == ItemType.Meal
        }
    else:
        self.allowed_meal_names = set(self.environment_config["meals"]["list"])

    self.order_and_score = OrderAndScoreManager(
        order_config=self.environment_config["orders"],
        available_meals={
            item: info
            for item, info in self.item_info.items()
            if info.type == ItemType.Meal and item in self.allowed_meal_names
        },
    )

    plate_transitions = {
        item: {
            "seconds": info.seconds,
            "needs": info.needs,
            "info": info,
        }
        for item, info in self.item_info.items()
        if info.type == ItemType.Meal
    }

    def is_made_with(info, equipment_name):
        # True if the item is produced with the equipment of the given name.
        return info.equipment is not None and info.equipment.name == equipment_name

    def single_ingredient_transitions(equipment_name):
        # Transitions keyed by the first needed item (used for cutting board and sink).
        return {
            info.needs[0]: {"seconds": info.seconds, "result": item}
            for item, info in self.item_info.items()
            if is_made_with(info, equipment_name)
        }

    def cooking_transitions(equipment_name):
        # Transitions keyed by the resulting item (used for pot and pan).
        return {
            item: {
                "seconds": info.seconds,
                "needs": info.needs,
                "info": info,
            }
            for item, info in self.item_info.items()
            if is_made_with(info, equipment_name)
        }

    def dispenser_of(item_name):
        # The item info is looked up lazily, so items missing from the item info only
        # raise an error if the corresponding symbol is actually used in the layout.
        return lambda pos: Dispenser(pos, self.item_info[item_name])

    def stove_with(equipment_name):
        # Stove that is initially occupied by the given cooking equipment.
        return lambda pos: Stove(
            pos,
            occupied_by=CookingEquipment(
                name=equipment_name,
                item_info=self.item_info[equipment_name],
                transitions=cooking_transitions(equipment_name),
            ),
        )

    # Maps a layout symbol either to a factory that is called with the cell position, or to
    # one of the marker strings "Free" / "Agent" for cells without a counter.
    self.SYMBOL_TO_CHARACTER_MAP = {
        "#": Counter,
        "C": lambda pos: CuttingBoard(
            pos, single_ingredient_transitions("CuttingBoard")
        ),
        "X": Trash,
        "W": lambda pos: ServingWindow(
            pos,
            self.order_and_score,
            meals=self.allowed_meal_names,
            env_time_func=self.get_env_time,
        ),
        "T": dispenser_of("Tomato"),
        "L": dispenser_of("Lettuce"),
        "P": lambda pos: PlateDispenser(
            plate_transitions=plate_transitions,
            pos=pos,
            dispensing=self.item_info["Plate"],
            plate_config=self.environment_config.get("plates", {}),
        ),
        "N": dispenser_of("Onion"),  # N for oNioN
        "_": "Free",
        "A": "Agent",
        "U": stove_with("Pot"),  # Stove with pot: U because it looks like a pot
        "Q": stove_with("Pan"),  # Stove with pan: Q because it looks like a pan
        "B": dispenser_of("Bun"),
        "M": dispenser_of("Meat"),
        "S": lambda pos: Sink(pos, transitions=single_ingredient_transitions("Sink")),
        "+": SinkAddon,
    }

    self.kitchen_height: int = 0
    self.kitchen_width: int = 0
    (
        self.counters,
        self.designated_player_positions,
        self.free_positions,
    ) = self.parse_layout_file(self.layout_path)
    self.init_counters()

    self.env_time: datetime.datetime = create_init_env_time()
    self.order_and_score.create_init_orders(self.env_time)
    self.beginning_time = self.env_time
    self.env_time_end = self.env_time + timedelta(
        seconds=self.environment_config["game"]["time_limit_seconds"]
    )
    log.debug(f"End time: {self.env_time_end}")
def get_env_time(self):
    """Return the current environment time.

    The bound method is handed to the serving window as its `env_time_func` callback.
    """
    current_time = self.env_time
    return current_time
@property
def game_ended(self) -> bool:
    """Whether the environment time has reached (or passed) the end time."""
    time_is_up = self.env_time >= self.env_time_end
    return time_is_up
def load_item_info(self) -> dict[str, ItemInfo]:
    """Load the item info file and link every item to its equipment.

    The YAML file at `self.item_info_path` maps item names to the keyword arguments of
    `ItemInfo`. After all `ItemInfo` objects are created, the `equipment` entry of each item
    (initially the name of another item) is replaced by the corresponding `ItemInfo` object
    and the item is registered as a start meal of that equipment.

    Returns:
        Mapping from item name to its `ItemInfo`.
    """
    with open(self.item_info_path, "r") as file:
        raw_items = yaml.safe_load(file)

    item_lookup = {
        name: ItemInfo(name=name, **properties)
        for name, properties in raw_items.items()
    }

    # Resolve the equipment names only after every ItemInfo exists.
    for info in item_lookup.values():
        if info.equipment:
            info.equipment = item_lookup[info.equipment]
            info.equipment.add_start_meal_to_equipment(info)

    # Sorting has to wait until all start meals were added in the loop above.
    for info in item_lookup.values():
        if info.type == ItemType.Equipment:
            # first select meals with smaller needs / ingredients
            info.sort_start_meals()

    return item_lookup
def validate_item_info(self):
    """Placeholder for validating the loaded item info; currently does nothing.

    The commented-out drafts below sketch how a dependency graph between ingredients,
    equipment and meals could be built and plotted (`nx` / `plt` are presumably networkx and
    matplotlib.pyplot; neither is imported in this file).
    """
    return None
    # Draft 1: group the item infos by type and add edges per item type.
    # infos = {t: [] for t in ItemType}
    # graph = nx.DiGraph()
    # for info in self.item_info.values():
    #     infos[info.type].append(info)
    #     graph.add_node(info.name)
    #     match info.type:
    #         case ItemType.Ingredient:
    #             if info.is_cuttable:
    #                 graph.add_edge(
    #                     info.name, info.finished_progress_name[:-1] + info.name
    #                 )
    #         case ItemType.Equipment:
    #             ...
    #         case ItemType.Meal:
    #             if info.equipment is not None:
    #                 graph.add_edge(info.equipment.name, info.name)
    #             for ingredient in info.needs:
    #                 graph.add_edge(ingredient, info.name)

    # Draft 2: connect needed items via the equipment to the resulting item.
    # graph = nx.DiGraph()
    # for item_name, item_info in self.item_info.items():
    #     graph.add_node(item_name, type=item_info.type.name)
    #     if len(item_info.equipment) == 0:
    #         for item in item_info.needs:
    #             graph.add_edge(item, item_name)
    #     else:
    #         for item in item_info.needs:
    #             for equipment in item_info.equipment:
    #                 graph.add_edge(item, equipment)
    #                 graph.add_edge(equipment, item_name)

    # Plotting of the graph.
    # plt.figure(figsize=(10, 10))
    # pos = nx.nx_agraph.graphviz_layout(graph, prog="twopi", args="")
    # nx.draw(graph, pos=pos, with_labels=True, node_color="white", node_size=500)
    # print(nx.multipartite_layout(graph, subset_key="type", align="vertical"))

    # pos = {
    #     node: (
    #         len(nx.ancestors(graph, node)) - len(nx.descendants(graph, node)),
    #         y,
    #     )
    #     for y, node in enumerate(graph)
    # }
    # nx.draw(
    #     graph,
    #     pos=pos,
    #     with_labels=True,
    #     node_shape="s",
    #     node_size=500,
    #     node_color="white",
    # )
    # TODO add colors for ingredients, equipment and meals
    # plt.show()
def parse_layout_file(self, layout_file: Path):
    """Create the kitchen counters and player start positions from a layout file.

    Every non-whitespace character of the file is one grid cell with side length 1. The cell
    in (zero-based) row `r` and column `c` is centred at `[c + 0.5, r + 0.5]`. Each character
    is upper-cased and looked up in `self.SYMBOL_TO_CHARACTER_MAP`: a non-string entry is
    called with the cell centre to create a counter, the marker "Agent" records a designated
    player position and the marker "Free" records a free position.

    As a side effect `self.kitchen_height` is set to the number of lines of the file and
    `self.kitchen_width` to the length of the longest row (as a float, whitespace not counted).

    Args:
        layout_file: Path to the layout file.

    Returns:
        Tuple `(counters, designated_player_positions, free_positions)`.

    Raises:
        KeyError: If the layout contains a symbol that is not in the symbol map.
    """
    counters: list[Counter] = []
    designated_player_positions: list[npt.NDArray] = []
    free_positions: list[npt.NDArray] = []

    with open(layout_file, "r") as file:
        lines = file.readlines()

    self.kitchen_height = len(lines)
    longest_row = 0

    for row, raw_line in enumerate(lines):
        # Drop the newline and every other whitespace (blanks, tabs, carriage returns), so
        # that visually aligned or Windows-style layout files parse as well.
        symbols = "".join(raw_line.split())
        longest_row = max(longest_row, len(symbols))
        current_y = row + 0.5

        for column, character in enumerate(symbols):
            character = character.capitalize()
            pos = np.array([column + 0.5, current_y])
            try:
                counter_class = self.SYMBOL_TO_CHARACTER_MAP[character]
            except KeyError:
                raise KeyError(
                    f"Unknown layout symbol {character!r} in row {row}, column {column}"
                ) from None

            if not isinstance(counter_class, str):
                counters.append(counter_class(pos))
            elif counter_class == "Agent":
                designated_player_positions.append(pos)
            elif counter_class == "Free":
                free_positions.append(pos)

    # Kept as a float, as before; an empty layout now yields 0.0 instead of -0.5.
    self.kitchen_width = float(longest_row)

    return counters, designated_player_positions, free_positions
def perform_action(self, action: Action):
    """Performs an action of a player in the environment. Maps different types of action inputs to the
    correct execution of the players.

    Possible action types are movement, pickup and interact actions:

    - "movement": the player is moved by `action.action` (see `perform_movement`).
    - "pickup": if the counter the player is facing is reachable, the player performs its pick action on it.
    - "interact": `action.action` "keydown" starts a hold interaction with the reachable facing counter and
      remembers that counter; "keyup" stops the hold interaction on the remembered counter.

    NOTE(review): the indentation of this block was lost in this copy of the file. The nesting below,
    in particular handling "keyup" independently of whether the facing counter is reachable, is
    reconstructed and should be confirmed.

    Args:
        action: The action to be performed

    Raises:
        AssertionError: If the player of the action is not part of the environment.
    """
    assert action.player in self.players.keys(), "Unknown player."
    player = self.players[action.player]

    if action.act_type == "movement":
        with self.lock:
            self.perform_movement(player, action.action)

    else:
        # Pickup and interact actions address the counter the player is looking at.
        counter = self.get_facing_counter(player)
        if player.can_reach(counter):
            if action.act_type == "pickup":
                with self.lock:
                    player.pick_action(counter)

            elif action.act_type == "interact":
                if action.action == "keydown":
                    player.perform_interact_hold_start(counter)
                    # Remembered so that the later "keyup" stops the interaction on this counter.
                    player.last_interacted_counter = counter
        if action.action == "keyup":
            if player.last_interacted_counter:
                with self.lock:
                    player.perform_interact_hold_stop(
                        player.last_interacted_counter
                    )
def get_closest_counter(self, point: np.ndarray):
    """Find the counter whose position is nearest to a 2d point.

    Args:
        point: The point in the env for which to find the closest counter.

    Returns:
        The counter with the smallest distance to the point (the first one in
        `self.counters` if several are equally close).
    """
    counter_positions = [counter.pos for counter in self.counters]
    # distance_matrix yields one row per query point; there is only the single point.
    distances_to_point = distance_matrix([point], counter_positions)[0]
    return self.counters[np.argmin(distances_to_point)]
def get_facing_counter(self, player: Player):
    """Determine the counter the player is looking at.

    The counter closest to the facing point of the player is used.

    Args:
        player: The player for which to find the facing counter.

    Returns:
        The counter closest to `player.facing_point`.
    """
    return self.get_closest_counter(player.facing_point)
def perform_movement(self, player: Player, move_vector: np.array):
    """Moves a player in the direction specified in the action.action. If the player collides with a
    counter or other player through this movement, then they are not moved.
    (The extended code with the two ifs is for sliding movement at the counters, which feels a bit smoother.
    This happens, when the player moves diagonally against the counters or world boundary.
    This just checks if the single axis part of the movement could move the player and does so at a lower rate.)

    The movement action is a unit 2d vector.

    Detects collisions with other players and pushes them out of the way.

    After the movement `player.current_nearest_counter` is set to the facing counter if the player can
    reach it, otherwise to None.

    NOTE(review): the indentation of this block was lost in this copy of the file; the nesting below is
    reconstructed from the logic and should be confirmed.

    Args:
        player: The player to move.
        move_vector: The movement vector which is a unit-2d-vector of the movement direction
    """
    # Remember the start position to be able to undo the movement.
    old_pos = player.pos.copy()

    step = move_vector * player.move_dist
    player.move(step)
    if self.detect_collision(player):
        collided_players = self.get_collided_players(player)
        for collided_player in collided_players:
            # Direction from the moving player to the collided player (normalised unless zero).
            pushing_vector = collided_player.pos - player.pos
            if np.linalg.norm(pushing_vector) != 0:
                pushing_vector = pushing_vector / np.linalg.norm(pushing_vector)

            old_pos_other = collided_player.pos.copy()
            # Push the other player away by moving them like a normal (recursive) movement.
            self.perform_movement(collided_player, pushing_vector)
            # A pushed player must not end up inside a counter or outside the world.
            if self.detect_collision_counters(
                collided_player
            ) or self.detect_collision_world_bounds(collided_player):
                collided_player.move_abs(old_pos_other)
        player.move_abs(old_pos)

        old_pos = player.pos.copy()

        # First sliding attempt: drop the x component and move at half rate.
        step_sliding = step.copy()
        step_sliding[0] = 0
        player.move(step_sliding * 0.5)
        # Turn with the full step instead of the reduced sliding step.
        player.turn(step)

        if self.detect_collision(player):
            player.move_abs(old_pos)

            old_pos = player.pos.copy()

            # Second sliding attempt: drop the y component instead.
            step_sliding = step.copy()
            step_sliding[1] = 0
            player.move(step_sliding * 0.5)
            player.turn(step)

            if self.detect_collision(player):
                # Neither axis is free: the player stays where they were.
                player.move_abs(old_pos)

    if self.counters:
        closest_counter = self.get_facing_counter(player)
        player.current_nearest_counter = (
            closest_counter if player.can_reach(closest_counter) else None
        )
def detect_collision(self, player: Player):
    """Check whether the player collides with other players, counters or the world bounds.

    Args:
        player: The player for which to check collisions.

    Returns:
        True if the player is intersecting with any object in the environment.
    """
    overlapping_players = self.get_collided_players(player)
    if len(overlapping_players) != 0:
        return True
    if self.detect_collision_counters(player):
        return self.detect_collision_counters(player) if False else True if self.detect_collision_counters is None else _truthy_result(self, player)
    return self.detect_collision_world_bounds(player)


def _truthy_result(self, player):
    """Return the (truthy) counter-collision result exactly as the counter check reports it."""
    return self.detect_collision_counters(player)
def get_collided_players(self, player: Player) -> list[Player]:
    """Collect all other players the queried player collides with.

    Players are modelled as circles. Two players collide if the distance between their
    positions is smaller than or equal to the sum of their radii. Players with the same name as
    the queried player are skipped.

    Args:
        player: The player to check collisions with other players for.

    Returns:
        The list of other players the player collides with.
    """
    return [
        other
        for other in self.players.values()
        if other.name != player.name
        and np.linalg.norm(player.pos - other.pos) <= (player.radius) + (other.radius)
    ]
def detect_player_collision(self, player: Player):
    """Check whether the queried player collides with any other player.

    Players are modelled as circles. Two players collide if the distance between their
    positions is smaller than or equal to the sum of their radii. Players with the same name as
    the queried player are skipped.

    Args:
        player: The player to check collisions with other players for.

    Returns:
        True if the player collides with other players, False if not.
    """
    return any(
        np.linalg.norm(player.pos - other.pos) <= ((player.radius) + (other.radius))
        for other in self.players.values()
        if other.name != player.name
    )
def detect_collision_counters(self, player: Player):
    """Check the queried player against every counter of the environment.

    Args:
        player: The player to check collisions with counters for.

    Returns:
        True if the player collides with any counter, False if not.
    """
    return any(
        self.detect_collision_player_counter(player, counter)
        for counter in self.counters
    )
def detect_collision_player_counter(self, player: Player, counter: Counter):
    """Check if the player and the counter overlap.

    The counter is treated as a square with side length 1 centred at `counter.pos`, the player
    as a circle around `player.pos`. The distance between the circle centre and the square is
    computed; a collision is detected if it is strictly smaller than the player radius.

    TODO: Efficiency improvement by checking only nearest counters? Quadtree...?

    Args:
        player: The player to check the collision for.
        counter: The counter to check the collision for.

    Returns:
        True if player and counter overlap, False if not.
    """
    player_x, player_y = player.pos
    counter_x, counter_y = counter.pos[0], counter.pos[1]
    # Per axis: how far the circle centre lies outside of the square (0 if inside its extent).
    gap_x = max(np.abs(player_x - counter_x) - 1 / 2, 0)
    gap_y = max(np.abs(player_y - counter_y) - 1 / 2, 0)
    return np.linalg.norm([gap_x, gap_y]) < (player.radius)
def add_player(self, player_name: str, pos: npt.NDArray = None):
    """Create a player and add it to the environment.

    If the new player has no position, it is placed on a randomly chosen designated player
    position, or, if none is left, on a randomly chosen free position. A used position is
    removed from its list, so it is not handed out again.

    NOTE(review): the indentation of this block was lost in this copy of the file; updating the
    facing point only for players that were placed here is reconstructed -- confirm.

    Args:
        player_name: Name of the player; the player is stored under `player.name`.
        pos: Optional start position passed on to the player.
    """
    log.debug(f"Add player {player_name} to the game")
    player = Player(
        player_name, player_config=self.environment_config["player_config"], pos=pos
    )
    self.players[player.name] = player

    if player.pos is not None:
        return

    # Designated positions are preferred over generic free positions.
    for spawn_pool in (self.designated_player_positions, self.free_positions):
        if len(spawn_pool) > 0:
            free_idx = random.randint(0, len(spawn_pool) - 1)
            player.move_abs(spawn_pool[free_idx])
            del spawn_pool[free_idx]
            break
    else:
        log.debug("No free positions left in kitchens")
    player.update_facing_point()
def detect_collision_world_bounds(self, player: Player):
    """Check whether the player circle reaches beyond the world bounds.

    The world spans from 0 to `self.kitchen_width` / `self.kitchen_height`.

    Args:
        player: The player which to not let escape the world.

    Returns:
        True if the player exceeds the world bounds, False if not.
    """
    world_limits = [self.kitchen_width, self.kitchen_height]
    below_lower_bound = (player.pos - (player.radius)) < 0
    above_upper_bound = (player.pos + (player.radius)) > world_limits
    return any(below_lower_bound) or any(above_upper_bound)
def step(self, passed_time: timedelta):
    """Advance the environment by the passed time.

    The environment time is increased first. Then, while holding the environment lock, the
    time-based counters (cutting boards, stoves, sinks, plate dispensers) and the order and
    score manager are progressed.

    Args:
        passed_time: The time that passed since the last step.
    """
    self.env_time += passed_time

    with self.lock:
        for counter in self.counters:
            if not isinstance(counter, (CuttingBoard, Stove, Sink, PlateDispenser)):
                # Only these counter types are progressed over time.
                continue
            counter.progress(passed_time=passed_time, now=self.env_time)
        self.order_and_score.progress(passed_time=passed_time, now=self.env_time)
def get_state(self):
    """Get the current state of the game environment as python objects.

    The players, counters and open orders are the objects of the environment themselves,
    not copies.

    Returns:
        Dict with the keys "players", "counters", "score", "orders", "ended", "env_time" and
        "remaining_time" (a timedelta that is never negative).
    """
    time_left = self.env_time_end - self.env_time
    state = {
        "players": self.players,
        "counters": self.counters,
        "score": self.order_and_score.score,
        "orders": self.order_and_score.open_orders,
        "ended": self.game_ended,
        "env_time": self.env_time,
        "remaining_time": max(time_left, timedelta(0)),
    }
    return state
def get_state_simple_json(self):
    """Get the current state of the game environment as a JSON string.

    The JSON object contains
    - "players": list with position, facing direction and held item of every player,
    - "counters": list with position and type name of every counter; for dispensers the name
      of the dispensed item is prepended to the class name,
    - "score", "ended",
    - "env_time": the seconds component of the environment time,
    - "remaining_time": the full seconds until the end of the game, never negative.

    Returns:
        JSON string of the current game state.
    """
    players = [
        {
            "pos": [float(p.pos[0]), float(p.pos[1])],
            "facing": [float(p.facing_direction[0]), float(p.facing_direction[1])],
            # NOTE(review): json.dumps raises a TypeError if `holding` is not JSON
            # serializable -- confirm what players hold.
            "holding": p.holding,
        }
        for p in self.players.values()
    ]

    counters = []
    for counter in self.counters:
        if isinstance(counter, Dispenser):
            counter_type = f"{counter.dispensing.name}{counter.__class__.__name__}"
        else:
            counter_type = counter.__class__.__name__
        counters.append(
            {
                "pos": [float(counter.pos[0]), float(counter.pos[1])],
                "type": counter_type,
            }
        )

    remaining_time = max(self.env_time_end - self.env_time, timedelta(0))

    gamestate_dict = {
        "players": players,
        "counters": counters,
        "score": self.order_and_score.score,
        # "orders": self.order_and_score.open_orders,
        "ended": self.game_ended,
        # NOTE(review): `.second` is only the seconds component (0-59) of the datetime and
        # wraps every minute -- confirm whether elapsed seconds were intended.
        "env_time": self.env_time.second,
        # `timedelta.seconds` ignores whole days; the floor division counts all full seconds
        # and still yields an int.
        "remaining_time": remaining_time // timedelta(seconds=1),
    }
    return json.dumps(gamestate_dict)
def init_counters(self):
    """Connect the counters that depend on each other after the layout was parsed.

    Every serving window gets the first plate dispenser of the environment, every sink gets
    the sink addon that is closest to it.

    Raises:
        AssertionError: If there is no plate dispenser, if there is a sink but no sink addon,
            or if the distance check between a sink and its closest addon fails.
    """
    plate_dispenser = self.get_counter_of_type(PlateDispenser)
    assert len(plate_dispenser) > 0, "No Plate Return in the environment"

    sink_addons = self.get_counter_of_type(SinkAddon)

    for counter in self.counters:
        if isinstance(counter, ServingWindow):
            counter.add_plate_dispenser(plate_dispenser[0])
        elif isinstance(counter, Sink):
            pos = counter.pos
            assert len(sink_addons) > 0, "No SinkAddon but normal Sink"
            closest_addon = self.get_closest(pos, sink_addons)
            # NOTE(review): only a lower bound (0.95) on the distance is checked, so an addon
            # that is further away than one cell passes as well -- confirm whether an upper
            # bound was intended.
            assert 1 - (1 * 0.05) <= np.linalg.norm(
                closest_addon.pos - pos
            ), f"No SinkAddon connected to Sink at pos {pos}"
            counter.set_addon(closest_addon)
@staticmethod
def get_closest(pos: npt.NDArray[float], counter: list[Counter]):
return min(counter, key=lambda c: np.linalg.norm(c.pos - pos))
def get_counter_of_type(self, counter_type) -> list[Counter]:
    """Return all counters of the environment that are instances of `counter_type`.

    Args:
        counter_type: Class (or tuple of classes) as accepted by `isinstance`.

    Returns:
        The matching counters in the order of `self.counters`.
    """
    return [counter for counter in self.counters if isinstance(counter, counter_type)]
def reset_env_time(self):
    """Set the environment time back to the initial environment time.

    Only `self.env_time` is replaced; the beginning time and the end time are left untouched.
    """
    initial_time = create_init_env_time()
    self.env_time = initial_time
    log.debug(f"Reset env time to {self.env_time}")