# environment.py - authored by Fabian Heinrich (repository page header; listed size 18.78 KiB)
from __future__ import annotations
import inspect
import json
import logging
import sys
from collections import defaultdict
from datetime import timedelta, datetime
from pathlib import Path
from random import Random
from typing import Literal, TypedDict, Callable
import numpy as np
import numpy.typing as npt
import yaml
from cooperative_cuisine.action import ActionType, InterActionData, Action
from cooperative_cuisine.counter_factory import (
CounterFactory,
)
from cooperative_cuisine.counters import (
PlateConfig,
)
from cooperative_cuisine.effects import EffectManager
from cooperative_cuisine.hooks import (
ITEM_INFO_LOADED,
LAYOUT_FILE_PARSED,
ENV_INITIALIZED,
PRE_PERFORM_ACTION,
POST_PERFORM_ACTION,
PLAYER_ADDED,
GAME_ENDED_STEP,
PRE_STATE,
STATE_DICT,
JSON_STATE,
PRE_RESET_ENV_TIME,
POST_RESET_ENV_TIME,
Hooks,
ACTION_ON_NOT_REACHABLE_COUNTER,
ACTION_PUT,
ACTION_INTERACT_START,
ITEM_INFO_CONFIG,
POST_STEP,
)
from cooperative_cuisine.items import (
ItemInfo,
ItemType,
)
from cooperative_cuisine.movement import Movement
from cooperative_cuisine.orders import (
OrderManager,
OrderConfig,
)
from cooperative_cuisine.player import Player, PlayerConfig
from cooperative_cuisine.state_representation import InfoMsg
from cooperative_cuisine.utils import (
create_init_env_time,
get_closest,
)
from cooperative_cuisine.validation import Validation
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Module-level switch; it is not read in the visible part of this file.
# NOTE(review): presumably consumed by the movement/collision code - confirm.
PREVENT_SQUEEZING_INTO_OTHER_PLAYERS = True
class EnvironmentConfig(TypedDict):
plates: PlateConfig
game: dict[
Literal["time_limit_seconds"] | Literal["undo_dispenser_pickup"],
int | bool,
bool,
]
meals: dict[Literal["all"] | Literal["list"], bool | list[str]]
orders: OrderConfig
player_config: PlayerConfig
layout_chars: dict[str, str]
extra_setup_functions: dict[str, dict]
effect_manager: dict
class Environment:
"""Core game environment of the overcooked-inspired cooking simulation.

Handles the game logic: player movement, collision detection, counters, cooking processes, recipes,
incoming orders and the progression of the environment time.
"""
# Class-level attribute, initialised to None; it is not read in the visible part of this file.
# NOTE(review): presumably a pause marker used by callers - confirm.
PAUSED = None
def __init__(
    self,
    env_config: Path | str,
    layout_config: Path | str,
    item_info: Path | str,
    as_files: bool = True,
    env_name: str = "overcooked_sim",
    seed: int = 56789223842348,
):
    """Build a complete environment from the three configuration sources.

    Args:
        env_config: The environment config, either the path to the YAML file or its content.
        layout_config: The kitchen layout, either the path to the layout file or its content.
        item_info: The item info, either the path to the YAML file or its content.
        as_files: If true, the three config arguments are paths whose content is read first.
        env_name: Reference to the run, e.g. the env id.
        seed: Seed of the random instance of the environment.
    """
    self.env_name = env_name
    """Name of the run this environment belongs to, e.g. the env id."""
    self.env_time: datetime = create_init_env_time()
    """Internal clock of the environment; it always starts at the value of `create_init_env_time`."""
    self.random: Random = Random(seed)
    """Seeded random instance of the environment."""
    self.hook: Hooks = Hooks(self)
    """Hook manager used to register callbacks and to trigger hook points with additional kwargs."""
    self.score: float = 0.0
    """Score reached so far."""
    self.players: dict[str, Player] = {}
    """All players of the environment, keyed by their id/name."""
    self.as_files = as_files
    """Whether the config arguments were given as file paths."""

    if self.as_files:
        # Read in the same order as the arguments: environment config, layout, item info.
        env_config = Path(env_config).read_text()
        layout_config = Path(layout_config).read_text()
        item_info = Path(item_info).read_text()

    # NOTE(review): the full (unsafe) YAML loader can construct arbitrary Python objects.
    # `extra_setup_functions` calls `function_def["func"]`, so the config presumably relies on
    # Python object tags - only load trusted configs.
    self.environment_config: EnvironmentConfig = yaml.load(
        env_config, Loader=yaml.Loader
    )
    """Parsed environment config; every environment specific setting is taken from here."""

    view_config = self.environment_config["player_config"]
    self.player_view_restricted = view_config["restricted_view"]
    if self.player_view_restricted:
        # Angle and range are only stored when the view restriction is active.
        self.player_view_angle = view_config["view_angle"]
        self.player_view_range = view_config["view_range"]

    self.extra_setup_functions()

    self.layout_config = layout_config
    """Layout config of the environment."""
    self.item_info: dict[str, ItemInfo] = self.load_item_info(item_info)
    """Loaded item infos, keyed by item name."""
    self.hook(ITEM_INFO_LOADED, item_info=item_info)

    config = self.environment_config
    meal_config = config["meals"]
    if not meal_config["all"]:
        self.allowed_meal_names = set(meal_config["list"])
    else:
        self.allowed_meal_names = {
            name
            for name, info in self.item_info.items()
            if info.type == ItemType.Meal
        }
    """Names of the allowed meals: either every meal of the item info or the configured subset."""

    self.order_manager = OrderManager(
        order_config=config["orders"],
        hook=self.hook,
        random=self.random,
    )
    """Manager of the orders and the score updates."""

    # The dispenser undo option defaults to False when the config does not define it.
    undo_dispenser_pickup = False
    if "game" in config and "undo_dispenser_pickup" in config["game"]:
        undo_dispenser_pickup = config["game"]["undo_dispenser_pickup"]

    self.counter_factory = CounterFactory(
        layout_chars_config=config["layout_chars"],
        item_info=self.item_info,
        serving_window_additional_kwargs={
            "meals": self.allowed_meal_names,
            "env_time_func": self.get_env_time,
        },
        plate_config=PlateConfig(**config.get("plates", {})),
        order_manager=self.order_manager,
        effect_manager_config=config["effect_manager"],
        undo_dispenser_pickup=undo_dispenser_pickup,
        hook=self.hook,
        random=self.random,
    )

    parsed_layout = self.counter_factory.parse_layout_file(self.layout_config)
    (
        self.counters,
        self.designated_player_positions,
        self.free_positions,
        self.kitchen_width,
        self.kitchen_height,
    ) = parsed_layout
    self.hook(LAYOUT_FILE_PARSED)

    # One [min, max] pair per axis, built from -0.5 and the kitchen width/height minus 0.5.
    world_borders = np.array(
        [[-0.5, self.kitchen_width - 0.5], [-0.5, self.kitchen_height - 0.5]],
        dtype=float,
    )
    self.movement = Movement(
        counter_positions=np.array([c.pos for c in self.counters]),
        player_config=config["player_config"],
        world_borders=world_borders,
    )

    self.progressing_counters = []
    """Counters whose `progress` method has to be called in every step."""
    self.overwrite_counters(self.counters)

    # Recipe validation is on unless the game section explicitly configures it.
    do_validation = config["game"].get("validate_recipes", True)
    if meal_config["all"]:
        meals_to_validate = [
            info for info in self.item_info.values() if info.type == ItemType.Meal
        ]
    else:
        listed_infos = (self.item_info[name] for name in meal_config["list"])
        meals_to_validate = [
            info for info in listed_infos if info.type == ItemType.Meal
        ]
    self.recipe_validation = Validation(
        meals=meals_to_validate,
        item_info=self.item_info,
        order_manager=self.order_manager,
        do_validation=do_validation,
    )

    meals_to_be_ordered = self.recipe_validation.validate_environment(self.counters)
    assert meals_to_be_ordered, "Need possible meals for order generation."

    self.order_manager.set_available_meals(
        {meal: self.item_info[meal] for meal in meals_to_be_ordered}
    )
    self.order_manager.create_init_orders(self.env_time)

    self.start_time = self.env_time
    """Environment time at which the game started."""
    time_limit = timedelta(seconds=config["game"]["time_limit_seconds"])
    self.env_time_end = self.env_time + time_limit
    """Environment time at which the game ends."""
    log.debug(f"End time: {self.env_time_end}")

    self.effect_manager: dict[
        str, EffectManager
    ] = self.counter_factory.setup_effect_manger(self.counters)
    self.info_msgs_per_player: dict[str, list[InfoMsg]] = defaultdict(list)

    self.hook(
        ENV_INITIALIZED,
        environment_config=env_config,
        layout_config=self.layout_config,
        seed=seed,
        env_start_time_worldtime=datetime.now(),
    )
def overwrite_counters(self, counters):
    """Replace the counters of the environment and refresh everything derived from them.

    Updates the counter positions of the movement and recomputes `progressing_counters`. A counter is
    selected when its exact class is a member of the `cooperative_cuisine.counters` module and that
    class has a `progress` attribute.

    Args:
        counters: The new counters of the environment.
    """
    self.counters = counters
    self.movement.counter_positions = np.array([c.pos for c in self.counters])

    counters_module = sys.modules["cooperative_cuisine.counters"]
    classes_with_progress = [
        cls
        for _name, cls in inspect.getmembers(counters_module, inspect.isclass)
        if hasattr(cls, "progress")
    ]
    self.progressing_counters = [
        counter
        for counter in self.counters
        if counter.__class__ in classes_with_progress
    ]
@property
def game_ended(self) -> bool:
    """True as soon as the environment time has reached `Environment.env_time_end`."""
    time_is_up = self.env_time >= self.env_time_end
    return time_is_up
def get_env_time(self):
    """Return the current internal time of the environment.

    The environment always starts with the time from `create_init_env_time`. This method exists so that
    a reference to the time can be passed to the serving window.
    """
    current_time = self.env_time
    return current_time
def load_item_info(self, item_info) -> dict[str, ItemInfo]:
    """Parse the content of `item_info.yml` into `ItemInfo` objects.

    The `ITEM_INFO_CONFIG` hook is called with the raw content before parsing.

    Args:
        item_info: The YAML content of the item info.

    Returns: The item infos keyed by item name; equipment names are replaced by the referenced item info.
    """
    self.hook(ITEM_INFO_CONFIG, item_info_config=item_info)
    raw_items = yaml.safe_load(item_info)
    item_lookup = {
        name: ItemInfo(name=name, **raw_items[name]) for name in raw_items
    }
    # Second pass: every item info exists now, so equipment names can be resolved to objects.
    for info in item_lookup.values():
        if info.equipment:
            info.equipment = item_lookup[info.equipment]
    return item_lookup
def perform_action(self, action: Action):
    """Execute the action of a player in the environment.

    Movement actions set the movement of the player. Every other action targets the counter closest to
    the facing point of the player: put actions and started interactions are only executed when the
    player can reach that counter, a stopped interaction is executed either way. The
    `PRE_PERFORM_ACTION` and `POST_PERFORM_ACTION` hooks enclose the execution.

    Args:
        action: The action to be performed
    """
    assert action.player in self.players.keys(), "Unknown player."
    self.hook(PRE_PERFORM_ACTION, action=action)
    acting_player = self.players[action.player]

    if action.action_type == ActionType.MOVEMENT:
        move_until = self.env_time + timedelta(seconds=action.duration)
        acting_player.set_movement(action.action_data, move_until)
    else:
        target = get_closest(acting_player.facing_point, self.counters)
        if not acting_player.can_reach(target):
            self.hook(
                ACTION_ON_NOT_REACHABLE_COUNTER, action=action, counter=target
            )
        elif action.action_type == ActionType.PUT:
            acting_player.put_action(target)
            self.hook(ACTION_PUT, action=action, counter=target)
        elif (
            action.action_type == ActionType.INTERACT
            and action.action_data == InterActionData.START
        ):
            acting_player.perform_interact_start(target)
            self.hook(ACTION_INTERACT_START, action=action, counter=target)

        # Stopping an interaction does not depend on whether the counter is reachable.
        if action.action_data == InterActionData.STOP:
            acting_player.perform_interact_stop()

    self.hook(POST_PERFORM_ACTION, action=action)
def add_player(self, player_name: str, pos: npt.NDArray = None):
    """Create a new player and register it in the environment.

    A player without a position is moved to a randomly chosen designated player position. When none is
    left, a randomly chosen free position is used instead. The chosen position is removed from its list.

    Args:
        player_name: The id/name of the player to reference actions and in the state.
        pos: The optional init position of the player.

    Raises:
        ValueError: If a player with the given name already exists.
    """
    if player_name in self.players:
        raise ValueError(f"Player {player_name} already exists.")
    log.debug(f"Add player {player_name} to the game")

    player_settings = self.environment_config.get("player_config", {})
    player = Player(
        player_name,
        player_config=PlayerConfig(**player_settings),
        pos=pos,
    )
    self.players[player.name] = player

    if player.pos is None:
        # Designated player positions take precedence over the remaining free positions.
        spawn_pool = None
        for candidates in (self.designated_player_positions, self.free_positions):
            if len(candidates) > 0:
                spawn_pool = candidates
                break

        if spawn_pool is None:
            log.debug("No free positions left in kitchens")
        else:
            chosen_idx = self.random.randint(0, len(spawn_pool) - 1)
            player.move_abs(spawn_pool[chosen_idx])
            del spawn_pool[chosen_idx]
        # NOTE(review): the facing point is refreshed only for players placed by the environment -
        # confirm that explicitly positioned players do not need it.
        player.update_facing_point()

    self.movement.set_collision_arrays(len(self.players))
    self.hook(PLAYER_ADDED, player_name=player_name, pos=pos)
def step(self, passed_time: timedelta):
    """Advance the environment by `passed_time`.

    The environment time is increased first. When the game has ended afterwards, only the
    `GAME_ENDED_STEP` hook is called. Otherwise the players, the movement, the progressing counters, the
    order manager and the effect managers are progressed, in that order. The `POST_STEP` hook is called
    in both cases.

    Args:
        passed_time: The amount of time by which the environment time is increased.
    """
    self.env_time += passed_time

    if not self.game_ended:
        for active_player in self.players.values():
            active_player.progress(passed_time, self.env_time)

        self.movement.perform_movement(
            passed_time, self.env_time, self.players, self.counters
        )

        for progressing_counter in self.progressing_counters:
            progressing_counter.progress(passed_time=passed_time, now=self.env_time)

        self.order_manager.progress(passed_time=passed_time, now=self.env_time)

        for manager in self.effect_manager.values():
            manager.progress(passed_time=passed_time, now=self.env_time)
    else:
        self.hook(GAME_ENDED_STEP)

    self.hook(POST_STEP, passed_time=passed_time)
def get_state(self, player_id: str = None, additional_key_values: dict = None):
    """Collect the current state of the game environment as a dict.

    The `PRE_STATE` hook is called before the state is built and the `STATE_DICT` hook afterwards.

    Args:
        player_id: The player for which to get the state.
        additional_key_values: Additional dict that is added to the state

    Returns: The state of the game as a dict.

    Raises:
        ValueError: If `player_id` is not a player of the environment.
    """
    if player_id not in self.players:
        raise ValueError(f"No valid {player_id=}")

    self.hook(PRE_STATE, player_id=player_id)

    state = {
        "players": [p.to_dict() for p in self.players.values()],
        "counters": [c.to_dict() for c in self.counters],
        "kitchen": {"width": self.kitchen_width, "height": self.kitchen_height},
        "score": self.score,
        "orders": self.order_manager.order_state(),
        "ended": self.game_ended,
        "env_time": self.env_time.isoformat(),
    }

    # The remaining time never becomes negative.
    time_left = max(self.env_time_end - self.env_time, timedelta(0))
    state["remaining_time"] = time_left.total_seconds()

    if self.player_view_restricted:
        state["view_restrictions"] = [
            {
                "direction": viewer.facing_direction.tolist(),
                "position": viewer.pos.tolist(),
                "angle": self.player_view_angle,
                "counter_mask": None,
                "range": self.player_view_range,
            }
            for viewer in self.players.values()
        ]
    else:
        state["view_restrictions"] = None

    state["served_meals"] = [
        (served_by, str(meal))
        for (meal, _served_at, served_by) in self.order_manager.served_meals
    ]

    # Only the messages whose time window contains the current environment time are included.
    state["info_msg"] = [
        (msg["msg"], msg["level"])
        for msg in self.info_msgs_per_player[player_id]
        if msg["start_time"] < self.env_time < msg["end_time"]
    ]

    # Additional values come last, so they override state entries with the same key.
    state = {**state, **(additional_key_values or {})}

    self.hook(STATE_DICT, state=state, player_id=player_id)
    return state
def get_json_state(
    self, player_id: str = None, additional_key_values: dict = None
) -> str:
    """Serialize the current state of the game to a json string.

    The `JSON_STATE` hook is called with the serialized state before it is returned.

    Args:
        player_id: The player for which to get the state.
        additional_key_values: Additional dict that is added to the state

    Returns: The state of the game formatted as a json-string
    """
    json_data = json.dumps(self.get_state(player_id, additional_key_values))
    self.hook(JSON_STATE, json_data=json_data, player_id=player_id)
    return json_data
def reset_env_time(self):
    """Set the env time back to the value of `create_init_env_time`.

    The `PRE_RESET_ENV_TIME` hook is called before the reset and the `POST_RESET_ENV_TIME` hook after it.
    """
    self.hook(PRE_RESET_ENV_TIME)
    initial_time = create_init_env_time()
    self.env_time = initial_time
    self.hook(POST_RESET_ENV_TIME)
    log.debug(f"Reset env time to {self.env_time}")
def register_callback_for_hook(self, hook_ref: str | list[str], callback: Callable):
    """Register a callback at the hook manager of the environment.

    Args:
        hook_ref: The hook reference or a list of hook references, passed on to the hook manager.
        callback: The callable to register.
    """
    register = self.hook.register_callback
    register(hook_ref, callback)
def extra_setup_functions(self):
    """Run the optional setup functions declared in the environment config.

    Every entry of the `extra_setup_functions` section maps a function name to a definition with the
    entries `func` (the callable) and `kwargs` (its keyword arguments). Each callable is called with the
    function name as `name`, this environment as `env` and the configured keyword arguments.

    The section is optional: a missing, `None` or empty section means there is nothing to set up, in line
    with the other optional sections of the config. A definition without `kwargs` is called without
    additional keyword arguments.
    """
    setup_functions = self.environment_config.get("extra_setup_functions")
    if not setup_functions:
        return

    for function_name, function_def in setup_functions.items():
        log.info(f"Setup function {function_name}")
        extra_kwargs = function_def.get("kwargs") or {}
        function_def["func"](name=function_name, env=self, **extra_kwargs)
def increment_score(self, score: int | float, info: str = ""):
    """Increase the current score and write the change to the debug log.

    Args:
        score: The value that is added to the current score.
        info: Additional text that is appended to the log message.
    """
    self.score += score
    log.debug(f"Score: {self.score} ({score}) - {info}")