diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py
index c5bb5d2577dd0dde98a76f54474893fd0a838b56..1db9959862cff42cad081e16bce53ea2435ed6e2 100644
--- a/overcooked_simulator/game_server.py
+++ b/overcooked_simulator/game_server.py
@@ -432,6 +432,13 @@ class EnvironmentHandler:
                                 / 1_000_000_000
                             )
                         )
+
+                        (
+                            grid,
+                            player,
+                            env_time,
+                            orders,
+                        ) = env_data.environment.get_vectorized_state("0")
                         env_data.last_step_time = step_start
                         if env_data.environment.game_ended:
                             log.info(f"Env {env_id} ended. Set env to STOPPED.")
diff --git a/overcooked_simulator/gym_env.py b/overcooked_simulator/gym_env.py
index 9a0c8094e26ee236709315840b56213c1edc5bc9..26d5da99b0ccdf03e1e17bbf4c278bd9e9ee12c6 100644
--- a/overcooked_simulator/gym_env.py
+++ b/overcooked_simulator/gym_env.py
@@ -1,5 +1,4 @@
 import json
-import os.path
 import time
 from datetime import timedelta
 from enum import Enum
@@ -8,6 +7,12 @@ from pathlib import Path
 import cv2
 import numpy as np
 import yaml
+from gymnasium import spaces, Env
+from stable_baselines3 import A2C
+from stable_baselines3 import DQN
+from stable_baselines3 import PPO
+from stable_baselines3.common.env_checker import check_env
+from stable_baselines3.common.env_util import make_vec_env
 
 from overcooked_simulator import ROOT_DIR
 from overcooked_simulator.gui_2d_vis.drawing import Visualizer
@@ -17,29 +22,22 @@ from overcooked_simulator.overcooked_environment import (
     ActionType,
     InterActionData,
 )
-import wandb
-from wandb.integration.sb3 import WandbCallback
-
-import gymnasium as gym
-import numpy as np
-from gymnasium import spaces, Env
-
-from stable_baselines3.common.env_checker import check_env
-from stable_baselines3.common.env_util import make_vec_env
-from stable_baselines3 import A2C
-from stable_baselines3 import DQN
-from stable_baselines3 import PPO
 
-SimpleActionSpace = Enum("SimpleActionSpace", ["Up",
-                                               # "Up_Left",
-                                               "Left",
-                                               # "Down_Left",
-                                               "Down",
-                                               # "Down_Right",
-                                               "Right",
-                                               # "Right_Up",
-                                               "Interact",
-                                               "Put"])
+SimpleActionSpace = Enum(
+    "SimpleActionSpace",
+    [
+        "Up",
+        # "Up_Left",
+        "Left",
+        # "Down_Left",
+        "Down",
+        # "Down_Right",
+        "Right",
+        # "Right_Up",
+        "Interact",
+        "Put",
+    ],
+)
 
 
 def get_env_action(player_id, simple_action, duration):
@@ -118,9 +116,7 @@ def get_env_action(player_id, simple_action, duration):
             print("FAIL", simple_action)
 
 
-environment_config_path: Path = (
-    ROOT_DIR / "game_content" / "environment_config_rl.yaml"
-)
+environment_config_path: Path = ROOT_DIR / "game_content" / "environment_config_rl.yaml"
 item_info_path: Path = ROOT_DIR / "game_content" / "item_info_rl.yaml"
 layout_path: Path = ROOT_DIR / "game_content" / "layouts" / "rl.layout"
 with open(ROOT_DIR / "gui_2d_vis" / "visualization_rl.yaml", "r") as file:
@@ -131,6 +127,7 @@ class EnvGymWrapper(Env):
     """Should enable this:
         observation, reward, terminated, truncated, info = env.step(action)
     """
+
     metadata = {"render_modes": ["human"], "render_fps": 30}
 
     def __init__(self):
@@ -142,7 +139,7 @@ class EnvGymWrapper(Env):
             env_config=environment_config_path,
             layout_config=layout_path,
             item_info=item_info_path,
-            as_files=True
+            as_files=True,
         )
 
         self.visualizer: Visualizer = Visualizer(config=visualization_config)
@@ -153,11 +150,10 @@ class EnvGymWrapper(Env):
         self.visualizer.create_player_colors(1)
 
         # self.action_space = {idx: value for idx, value in enumerate(SimpleActionSpace)}
-        self.action_space_map ={}
+        self.action_space_map = {}
         for idx, item in enumerate(SimpleActionSpace):
             self.action_space_map[idx] = item
 
-
         self.global_step_time = 1
         self.in_between_steps = 1
 
@@ -165,8 +161,9 @@ class EnvGymWrapper(Env):
         # Example for using image as input (channel-first; channel-last also works):
 
         dummy_obs = self.get_env_img(self.gridsize)
-        self.observation_space = spaces.Box(low=0, high=255,
-                                            shape=dummy_obs.shape, dtype=np.uint8)
+        self.observation_space = spaces.Box(
+            low=0, high=255, shape=dummy_obs.shape, dtype=np.uint8
+        )
 
         self.last_obs = dummy_obs
 
@@ -175,7 +172,9 @@ class EnvGymWrapper(Env):
 
     def step(self, action):
         simple_action = self.action_space_map[action]
-        env_action = get_env_action(self.player_id, simple_action, self.global_step_time)
+        env_action = get_env_action(
+            self.player_id, simple_action, self.global_step_time
+        )
         self.env.perform_action(env_action)
 
         for i in range(self.in_between_steps):
@@ -186,10 +185,13 @@ class EnvGymWrapper(Env):
         observation = self.get_env_img(self.gridsize)
 
         reward = -1
-        if self.env.order_and_score.score > self.prev_score and self.env.order_and_score.score != 0:
-            self.prev_score = self.env.order_and_score.score
+        if (
+            self.env.order_manager.score > self.prev_score
+            and self.env.order_manager.score != 0
+        ):
+            self.prev_score = self.env.order_manager.score
             reward = 100
-        elif self.env.order_and_score.score < self.prev_score:
+        elif self.env.order_manager.score < self.prev_score:
             self.prev_score = 0
             reward = -1
 
@@ -200,14 +202,12 @@ class EnvGymWrapper(Env):
         # self.render(self.gridsize)
         return observation, reward, terminated, truncated, info
 
-
     def reset(self, seed=None, options=None):
-
         self.env: Environment = Environment(
             env_config=environment_config_path,
             layout_config=layout_path,
             item_info=item_info_path,
-            as_files=True
+            as_files=True,
         )
 
         self.player_name = str(0)
@@ -219,10 +219,10 @@ class EnvGymWrapper(Env):
 
     def render(self):
         observation = self.get_env_img(self.gridsize)
-        img = observation.transpose((1,2,0))[:,:,::-1]
+        img = observation.transpose((1, 2, 0))[:, :, ::-1]
         print(img.shape)
-        img = cv2.resize(img, (img.shape[1]*5, img.shape[0]*5))
-        cv2.imshow("Overcooked",img)
+        img = cv2.resize(img, (img.shape[1] * 5, img.shape[0] * 5))
+        cv2.imshow("Overcooked", img)
         cv2.waitKey(1)
 
     def close(self):
@@ -234,7 +234,13 @@ class EnvGymWrapper(Env):
         observation = self.visualizer.get_state_image(
             grid_size=gridsize, state=json_dict
         ).transpose((1, 0, 2))
-        return observation.transpose((2,0,1))
+        return observation.transpose((2, 0, 1))
+
+    def get_vector_state(self):
+        grid, player, env_time, orders = self.env.get_vectorized_state("0")
+
+        # flatten: grid + player
+        # concatenate all (env_time to array)
 
     def sample_random_action(self):
         act = self.action_space.sample()
@@ -246,7 +252,6 @@ def main():
     rl_agent_checkpoints = Path("./rl_agent_checkpoints")
     rl_agent_checkpoints.mkdir(exist_ok=True)
 
-
     config = {
         "policy_type": "CnnPolicy",
         "total_timesteps": 1000000,  # hendric sagt eher so 300_000_000 schritte
@@ -266,8 +271,9 @@ def main():
     model_class = model_classes[2]
 
     # model = model_class(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
-    model = model_class(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{0}")
-
+    model = model_class(
+        config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{0}"
+    )
 
     model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}"
     # if os.path.exists(model_save_path):
@@ -280,7 +286,7 @@ def main():
         #     verbose=0,
         # ),
         log_interval=1,
-        progress_bar=True
+        progress_bar=True,
     )
     # run.finish()
     model.save(model_save_path)
@@ -293,7 +299,7 @@ def main():
     check_env(env)
     obs, info = env.reset()
     while True:
-        time.sleep(1/30)
+        time.sleep(1 / 30)
         action, _states = model.predict(obs, deterministic=False)
         obs, reward, terminated, truncated, info = env.step(int(action))
         env.render()
diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py
index 6d8dd883f32c9c82da5134e119177bff926b179c..00a5794c3502ef44aa54663b245d922bb2d16164 100644
--- a/overcooked_simulator/overcooked_environment.py
+++ b/overcooked_simulator/overcooked_environment.py
@@ -5,6 +5,7 @@ import inspect
 import json
 import logging
 import sys
+from collections import deque
 from datetime import timedelta, datetime
 from enum import Enum
 from pathlib import Path
@@ -20,11 +21,15 @@ from overcooked_simulator.counter_factory import CounterFactory
 from overcooked_simulator.counters import (
     Counter,
     PlateConfig,
+    CookingCounter,
+    Dispenser,
 )
 from overcooked_simulator.effect_manager import EffectManager
 from overcooked_simulator.game_items import (
     ItemInfo,
     ItemType,
+    CookingEquipment,
+    Item,
 )
 from overcooked_simulator.hooks import (
     ITEM_INFO_LOADED,
@@ -50,7 +55,11 @@ from overcooked_simulator.order import (
     OrderConfig,
 )
 from overcooked_simulator.player import Player, PlayerConfig
-from overcooked_simulator.utils import create_init_env_time, get_closest
+from overcooked_simulator.utils import (
+    create_init_env_time,
+    get_closest,
+    VectorStateGenerationData,
+)
 
 log = logging.getLogger(__name__)
 
@@ -194,7 +203,7 @@ class Environment:
         """The allowed meals depend on the `environment_config.yml` configured behaviour. Either all meals that
         are possible or only a limited subset."""
 
-        self.order_and_score = OrderManager(
+        self.order_manager = OrderManager(
             order_config=self.environment_config["orders"],
             available_meals={
                 item: info
@@ -225,7 +234,7 @@ class Environment:
                     else {}
                 )
             ),
-            order_manager=self.order_and_score,
+            order_manager=self.order_manager,
             effect_manager_config=self.environment_config["effect_manager"],
             hook=self.hook,
             random=self.random,
@@ -268,7 +277,7 @@ class Environment:
         )
         """Counters that needs to be called in the step function via the `progress` method."""
 
-        self.order_and_score.create_init_orders(self.env_time)
+        self.order_manager.create_init_orders(self.env_time)
         self.start_time = self.env_time
         """The relative env time when it started."""
         self.env_time_end = self.env_time + timedelta(
@@ -281,6 +290,8 @@ class Environment:
             str, EffectManager
         ] = self.counter_factory.setup_effect_manger(self.counters)
 
+        self.vector_state_generation = self.setup_vectorization()
+
         self.hook(
             ENV_INITIALIZED,
             environment_config=env_config,
@@ -665,8 +676,8 @@ class Environment:
 
         for idx, p in enumerate(self.players.values()):
             if not (new_positions[idx] == player_positions[idx]).all():
-                p.turn(player_movement_vectors[idx])
                 p.move_abs(new_positions[idx])
+                p.turn(player_movement_vectors[idx])
 
     def add_player(self, player_name: str, pos: npt.NDArray = None):
         """Add a player to the environment.
@@ -742,7 +753,7 @@ class Environment:
 
             for counter in self.progressing_counters:
                 counter.progress(passed_time=passed_time, now=self.env_time)
-            self.order_and_score.progress(passed_time=passed_time, now=self.env_time)
+            self.order_manager.progress(passed_time=passed_time, now=self.env_time)
             for effect_manager in self.effect_manager.values():
                 effect_manager.progress(passed_time=passed_time, now=self.env_time)
         # self.hook(POST_STEP, passed_time=passed_time)
@@ -757,7 +768,7 @@ class Environment:
             "players": self.players,
             "counters": self.counters,
             "score": self.score,
-            "orders": self.order_and_score.open_orders,
+            "orders": self.order_manager.open_orders,
             "ended": self.game_ended,
             "env_time": self.env_time,
             "remaining_time": max(self.env_time_end - self.env_time, timedelta(0)),
@@ -771,7 +782,7 @@ class Environment:
             "counters": [c.to_dict() for c in self.counters],
             "kitchen": {"width": self.kitchen_width, "height": self.kitchen_height},
             "score": self.score,
-            "orders": self.order_and_score.order_state(),
+            "orders": self.order_manager.order_state(),
             "ended": self.game_ended,
             "env_time": self.env_time.isoformat(),
             "remaining_time": max(
@@ -793,6 +804,264 @@ class Environment:
             return json_data
         raise ValueError(f"No valid {player_id=}")
 
+    def setup_vectorization(self) -> VectorStateGenerationData:
+        """Precompute the static grid array: a one-hot counter type (or free-cell flag) per cell."""
+        grid_base_array = np.zeros(
+            (
+                int(self.kitchen_width),
+                int(self.kitchen_height),
+                114 + 12 + 4,  # TODO calc based on item info
+            ),
+            dtype=np.float32,
+        )
+        counter_list = [
+            "Counter",
+            "CuttingBoard",
+            "ServingWindow",
+            "Trashcan",
+            "Sink",
+            "SinkAddon",
+            "Stove",
+            "DeepFryer",
+            "Oven",
+        ]
+        grid_idxs = [
+            (x, y)
+            for x in range(int(self.kitchen_width))
+            for y in range(int(self.kitchen_height))
+        ]
+        # counters do not move
+        for counter in self.counters:
+            grid_idx = np.floor(counter.pos).astype(int)
+            counter_name = (
+                counter.name
+                if isinstance(counter, CookingCounter)
+                else (
+                    repr(counter)
+                    if isinstance(counter, Dispenser)
+                    else counter.__class__.__name__
+                )
+            )
+            assert counter_name in counter_list or counter_name.endswith(
+                "Dispenser"
+            ), f"Unknown Counter {counter}"
+            oh_idx = len(counter_list)
+            if counter_name in counter_list:
+                oh_idx = counter_list.index(counter_name)
+
+            one_hot = [0] * (len(counter_list) + 2)
+            one_hot[oh_idx] = 1
+            grid_base_array[
+                grid_idx[0], grid_idx[1], 4 : 4 + (len(counter_list) + 2)
+            ] = np.array(one_hot, dtype=np.float32)
+
+            grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))
+
+        for free_idx in grid_idxs:
+            one_hot = [0] * (len(counter_list) + 2)
+            one_hot[len(counter_list) + 1] = 1
+            grid_base_array[
+                free_idx[0], free_idx[1], 4 : 4 + (len(counter_list) + 2)
+            ] = np.array(one_hot, dtype=np.float32)
+
+        player_info_base_array = np.zeros(
+            (
+                4,
+                4 + 114,
+            ),
+            dtype=np.float32,
+        )
+        order_base_array = np.zeros((10 * (8 + 1)), dtype=np.float32)
+
+        return VectorStateGenerationData(
+            grid_base_array=grid_base_array,
+            oh_len=12,
+        )
+
+    def get_simple_vectorized_item(self, item: Item) -> npt.NDArray[float]:
+        """Encode a single item as a vector of 21 floats.
+
+        Layout: three state flags (burnt, chopped/processed, cooked) followed by
+        a one-hot over the 8 meals and the 10 ingredients.
+        """
+        name = item.name
+        array = np.zeros(21, dtype=np.float32)
+        if item.name.startswith("Burnt"):
+            name = name[len("Burnt") :]
+            array[0] = 1.0
+        if name.startswith("Chopped"):
+            array[1] = 1.0
+            name = name[len("Chopped") :]
+        if name in [
+            "PizzaBase",
+            "GratedCheese",
+            "RawChips",
+            "RawPatty",
+        ]:
+            array[1] = 1.0
+            name = {
+                "PizzaBase": "Dough",
+                "GratedCheese": "Cheese",
+                "RawChips": "Potato",
+                "RawPatty": "Meat",
+            }[name]
+        if name == "CookedPatty":
+            array[2] = 1.0
+            name = "Meat"
+
+        # The first three entries are the state flags; the one-hot part starts at 3.
+        if name in self.vector_state_generation.meals:
+            idx = 3 + self.vector_state_generation.meals.index(name)
+        elif name in self.vector_state_generation.ingredients:
+            idx = (3 + len(self.vector_state_generation.meals)) + (
+                self.vector_state_generation.ingredients.index(name)
+            )
+        else:
+            raise ValueError(f"Unknown item {name} - {item}")
+        array[idx] = 1.0
+        return array
+
+    def get_vectorized_item(self, item: Item) -> npt.NDArray[float]:
+        """Encode an item, including the content of cooking equipment, as 114 floats.
+
+        Layout: equipment one-hot, one 21-float item slot, progress, effect flag,
+        then four 21-float slots for the content of cooking equipment.
+        """
+        item_array = np.zeros(114, dtype=np.float32)
+
+        if isinstance(item, CookingEquipment) or item.item_info.type == ItemType.Tool:
+            assert (
+                item.name in self.vector_state_generation.equipments
+            ), f"unknown equipment {item}"
+            idx = self.vector_state_generation.equipments.index(item.name)
+            item_array[idx] = 1.0
+            if isinstance(item, CookingEquipment):
+                for s_idx, sub_item in enumerate(item.content_list):
+                    if s_idx > 3:
+                        print("Too much content in the content list, info dropped")
+                        break
+                    start_idx = len(self.vector_state_generation.equipments) + 21 + 2
+                    item_array[
+                        start_idx + (s_idx * (21)) : start_idx + ((s_idx + 1) * (21))
+                    ] = self.get_simple_vectorized_item(sub_item)
+
+        else:
+            item_array[
+                len(self.vector_state_generation.equipments) : len(
+                    self.vector_state_generation.equipments
+                )
+                + 21
+            ] = self.get_simple_vectorized_item(item)
+
+        item_array[
+            len(self.vector_state_generation.equipments) + 21
+        ] = item.progress_percentage
+
+        if item.active_effects:
+            item_array[
+                len(self.vector_state_generation.equipments) + 21 + 1
+            ] = 1.0  # TODO percentage of fire...
+
+        return item_array
+
+    def get_vectorized_state(
+        self, player_id: str
+    ) -> Tuple[
+        npt.NDArray[npt.NDArray[float]],
+        npt.NDArray[npt.NDArray[float]],
+        float,
+        npt.NDArray[float],
+    ]:
+        """Return the vectorized state from the view of `player_id`.
+
+        Returns:
+            The grid array, the player array (row 0 belongs to `player_id`), the
+            elapsed share of the game time and the flat order array.
+        """
+        grid_array = self.vector_state_generation.grid_base_array.copy()
+        for counter in self.counters:
+            grid_idx = np.floor(counter.pos).astype(int)  # store in counter?
+            if counter.occupied_by:
+                if isinstance(counter.occupied_by, deque):
+                    ...
+                else:
+                    item = counter.occupied_by
+                    grid_array[
+                        grid_idx[0],
+                        grid_idx[1],
+                        4 + self.vector_state_generation.oh_len :,
+                    ] = self.get_vectorized_item(item)
+            if counter.active_effects:
+                grid_array[
+                    grid_idx[0],
+                    grid_idx[1],
+                    4 + self.vector_state_generation.oh_len - 1,
+                ] = 1.0  # TODO percentage of fire...
+
+        assert len(self.players) <= 4, "To much players for vector representation"
+        player_vec = np.zeros(
+            (
+                4,
+                4 + 114,
+            ),
+            dtype=np.float32,
+        )
+        player_pos = 1
+        for player in self.players.values():
+            if player.name == player_id:
+                idx = 0
+                player_vec[0, :4] = np.array(
+                    [
+                        player.pos[0],
+                        player.pos[1],
+                        player.facing_point[0],
+                        player.facing_point[1],
+                    ],
+                    dtype=np.float32,
+                )
+            else:
+                idx = player_pos
+
+            if idx:  # every other player takes the next free row
+                player_pos += 1
+            grid_idx = np.floor(player.pos).astype(int)  # store in counter?
+            player_vec[idx, :4] = np.array(
+                [
+                    player.pos[0] - grid_idx[0],
+                    player.pos[1] - grid_idx[1],
+                    player.facing_point[0] / np.linalg.norm(player.facing_point),
+                    player.facing_point[1] / np.linalg.norm(player.facing_point),
+                ],
+                dtype=np.float32,
+            )
+            grid_array[grid_idx[0], grid_idx[1], idx] = 1.0
+
+            if player.holding:
+                player_vec[idx, 4:] = self.get_vectorized_item(player.holding)
+
+        order_array = np.zeros((10 * (8 + 1)), dtype=np.float32)
+
+        for i, order in enumerate(self.order_manager.open_orders):
+            if i > 9:
+                print("some orders are not represented in the vectorized state")
+                break
+            assert (
+                order.meal.name in self.vector_state_generation.meals
+            ), "unknown meal in order"
+            idx = self.vector_state_generation.meals.index(order.meal.name)
+            order_array[(i * 9) + idx] = 1.0
+            order_array[(i * 9) + 8] = (
+                self.env_time - order.start_time
+            ).total_seconds() / order.max_duration.total_seconds()
+
+        return (
+            grid_array,
+            player_vec,
+            (self.env_time - self.start_time).total_seconds()
+            / (self.env_time_end - self.start_time).total_seconds(),
+            order_array,
+        )
+
     def reset_env_time(self):
         """Reset the env time to the initial time, defined by `create_init_env_time`."""
         self.hook(PRE_RESET_ENV_TIME)
diff --git a/overcooked_simulator/utils.py b/overcooked_simulator/utils.py
index b78d44af869d9a53ec17f5f2cd8eda191e0b4e17..357d26795a14727eb511468f672e97a072745cd3 100644
--- a/overcooked_simulator/utils.py
+++ b/overcooked_simulator/utils.py
@@ -25,6 +25,47 @@ if TYPE_CHECKING:
     from overcooked_simulator.player import Player
 
 
+@dataclasses.dataclass
+class VectorStateGenerationData:
+    """Static data for the vectorized state: the base grid array and the name lists used for one-hot encoding."""
+    grid_base_array: npt.NDArray[npt.NDArray[float]]
+    oh_len: int
+
+    number_normal_ingredients = 10
+
+    meals = [
+        "Chips",
+        "FriedFish",
+        "Burger",
+        "Salad",
+        "TomatoSoup",
+        "OnionSoup",
+        "FishAndChips",
+        "Pizza",
+    ]
+    equipments = [
+        "Pot",
+        "Pan",
+        "Basket",
+        "Peel",
+        "Plate",
+        "DirtyPlate",
+        "Extinguisher",
+    ]
+    ingredients = [
+        "Tomato",
+        "Lettuce",
+        "Onion",
+        "Meat",
+        "Bun",
+        "Potato",
+        "Fish",
+        "Dough",
+        "Cheese",
+        "Sausage",
+    ]
+
+
 def create_init_env_time():
     """Init time of the environment time, because all environments should have the same internal time."""
     return datetime(