diff --git a/overcooked_simulator/game_content/layouts/rl.layout b/overcooked_simulator/game_content/layouts/rl.layout deleted file mode 100644 index 5115af280b02654e287bb9365ec0097a676c1986..0000000000000000000000000000000000000000 --- a/overcooked_simulator/game_content/layouts/rl.layout +++ /dev/null @@ -1,4 +0,0 @@ -#X## -T__W -U__P -#C## diff --git a/overcooked_simulator/gym_env.py b/overcooked_simulator/gym_env.py deleted file mode 100644 index 47dbc108da7edc607936d3cfb175bf97e5a9cb48..0000000000000000000000000000000000000000 --- a/overcooked_simulator/gym_env.py +++ /dev/null @@ -1,399 +0,0 @@ -import json -import random -import time -from datetime import timedelta -from enum import Enum -from pathlib import Path - -import cv2 -import numpy as np -import wandb -import yaml -from gymnasium import spaces, Env -from stable_baselines3 import A2C -from stable_baselines3 import DQN -from stable_baselines3 import PPO -from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback -from stable_baselines3.common.env_checker import check_env -from stable_baselines3.common.env_util import make_vec_env -from stable_baselines3.common.vec_env import VecVideoRecorder -from wandb.integration.sb3 import WandbCallback - -from overcooked_simulator import ROOT_DIR -from overcooked_simulator.counters import Counter -from overcooked_simulator.gui_2d_vis.drawing import Visualizer -from overcooked_simulator.overcooked_environment import ( - Environment, - Action, - ActionType, - InterActionData, -) - -SimpleActionSpace = Enum( - "SimpleActionSpace", - [ - "Up", - # "Up_Left", - "Left", - # "Down_Left", - "Down", - # "Down_Right", - "Right", - # "Right_Up", - "Interact", - "Put", - ], -) - - -def get_env_action(player_id, simple_action, duration): - match simple_action: - case SimpleActionSpace.Up: - return Action( - player_id, - ActionType.MOVEMENT, - np.array([0, -1]), - duration, - ) - # case SimpleActionSpace.Up_Left: - # return Action( - # player_id, - # ActionType.MOVEMENT, - # np.array([-1, -1]), - # duration, - # ) - case SimpleActionSpace.Left: - return Action( - player_id, - ActionType.MOVEMENT, - np.array([-1, 0]), - duration, - ) - # case SimpleActionSpace.Down_Left: - # return Action( - # player_id, - # ActionType.MOVEMENT, - # np.array([-1, 1]), - # duration, - # ) - case SimpleActionSpace.Down: - return Action( - player_id, - ActionType.MOVEMENT, - np.array([0, 1]), - duration, - ) - # case SimpleActionSpace.Down_Right: - # return Action( - # player_id, - # ActionType.MOVEMENT, - # np.array([1, 1]), - # duration, - # ) - case SimpleActionSpace.Right: - return Action( - player_id, - ActionType.MOVEMENT, - np.array([1, 0]), - duration, - ) - # case SimpleActionSpace.Right_Up: - # return Action( - # player_id, - # ActionType.MOVEMENT, - # np.array([1, -1]), - # duration, - # ) - case SimpleActionSpace.Put: - return Action( - player_id, - ActionType.PUT, - InterActionData.START, - duration, - ) - case SimpleActionSpace.Interact: - return Action( - player_id, - ActionType.INTERACT, - InterActionData.START, - duration, - ) - case other: - print("FAIL", simple_action) - - -environment_config_path = ROOT_DIR / "game_content" / "environment_config_rl.yaml" -layout_path: Path = ROOT_DIR / "game_content" / "layouts" / "rl.layout" -item_info_path = ROOT_DIR / "game_content" / "item_info_rl.yaml" -with open(item_info_path, "r") as file: - item_info = file.read() -with open(layout_path, "r") as file: - layout = file.read() -with open(environment_config_path, "r") as file: - 
environment_config = file.read() -with open(ROOT_DIR / "gui_2d_vis" / "visualization_rl.yaml", "r") as file: - visualization_config = yaml.safe_load(file) - - -def shuffle_counters(env): - sample_counter = [] - sample_counter = [] - for counter in env.counters: - if counter.__class__ != Counter: - sample_counter.append(counter) - new_counter_pos = [c.pos for c in sample_counter] - random.shuffle(new_counter_pos) - for counter, new_pos in zip(sample_counter, new_counter_pos): - counter.pos = new_pos - env.vector_state_generation = env.setup_vectorization() - - -class EnvGymWrapper(Env): - """Should enable this: - observation, reward, terminated, truncated, info = env.step(action) - """ - - metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30} - - def __init__(self): - super().__init__() - - self.gridsize = 20 - - self.randomize_counter_placement = False - self.use_rgb_obs = False - - self.env: Environment = Environment( - env_config=environment_config, - layout_config=layout, - item_info=item_info, - as_files=False, - ) - - if self.randomize_counter_placement: - shuffle_counters(self.env) - - self.visualizer: Visualizer = Visualizer(config=visualization_config) - self.player_name = str(0) - self.env.add_player(self.player_name) - self.player_id = list(self.env.players.keys())[0] - - self.env.setup_vectorization() - - self.visualizer.create_player_colors(1) - - # self.action_space = {idx: value for idx, value in enumerate(SimpleActionSpace)} - self.action_space_map = {} - for idx, item in enumerate(SimpleActionSpace): - self.action_space_map[idx] = item - - self.global_step_time = 1 - self.in_between_steps = 1 - - self.action_space = spaces.Discrete(len(self.action_space_map)) - # Example for using image as input (channel-first; channel-last also works): - - dummy_obs = self.get_observation() - self.observation_space = spaces.Box( - low=-1, high=8, shape=dummy_obs.shape, dtype=int - ) - - self.last_obs = dummy_obs - - self.step_counter = 0 - self.prev_score = 0 - - def step(self, action): - simple_action = self.action_space_map[action] - env_action = get_env_action( - self.player_id, simple_action, self.global_step_time - ) - self.env.perform_action(env_action) - - for i in range(self.in_between_steps): - self.env.step( - timedelta(seconds=self.global_step_time / self.in_between_steps) - ) - - observation = self.get_observation() - - reward = self.env.score - self.prev_score - self.prev_score = self.env.score - - terminated = self.env.game_ended - truncated = self.env.game_ended - info = {} - - # self.render(self.gridsize) - return observation, reward, terminated, truncated, info - - def reset(self, seed=None, options=None): - self.env: Environment = Environment( - env_config=environment_config, - layout_config=layout, - item_info=item_info, - as_files=False, - ) - - if self.randomize_counter_placement: - shuffle_counters(self.env) - - self.player_name = str(0) - self.env.add_player(self.player_name) - self.player_id = list(self.env.players.keys())[0] - - self.env.setup_vectorization() - - info = {} - obs = self.get_observation() - - self.prev_score = 0 - - return obs, info - - def get_observation(self): - if self.use_rgb_obs: - obs = self.get_env_img(self.gridsize) - else: - obs = self.get_vector_state() - return obs - - def render(self): - observation = self.get_env_img(self.gridsize) - img = (observation * 255.0).astype(np.uint8) - img = img.transpose((1, 2, 0)) - img = cv2.resize(img, (img.shape[1], img.shape[0])) - return img - - def close(self): - pass - - def 
get_env_img(self, gridsize): - state = self.env.get_json_state(player_id=self.player_id) - json_dict = json.loads(state) - observation = self.visualizer.get_state_image( - grid_size=gridsize, state=json_dict - ).transpose((1, 0, 2)) - return (observation.transpose((2, 0, 1)) / 255.0).astype(np.float32) - - def get_vector_state(self): - # grid, player, env_time, orders = self.env.get_vectorized_state_full("0") - # - # - # obs = np.concatenate( - # [grid.flatten(), player.flatten()], axis=0, dtype=np.float32 - # ) - # return obs - - obs = self.env.get_vectorized_state_simple("0") - return obs - - def sample_random_action(self): - act = self.action_space.sample() - return act - - -def main(): - rl_agent_checkpoints = Path("./rl_agent_checkpoints") - rl_agent_checkpoints.mkdir(exist_ok=True) - - config = { - "policy_type": "MlpPolicy", - "total_timesteps": 30_000_000, # hendric sagt eher so 300_000_000 schritte - "env_id": "overcooked", - "number_envs_parallel": 16, - } - - debug = False - do_training = True - vec_env = True - number_envs_parallel = config["number_envs_parallel"] - - model_classes = [A2C, DQN, PPO] - model_class = model_classes[2] - - if vec_env: - env = make_vec_env(EnvGymWrapper, n_envs=number_envs_parallel) - else: - env = EnvGymWrapper() - - env.render_mode = "rgb_array" - - if not debug: - run = wandb.init( - project="overcooked", - config=config, - sync_tensorboard=True, # auto-upload sb3's tensorboard metrics - monitor_gym=True - # save_code=True, # optional - ) - - env = VecVideoRecorder( - env, - f"videos/{run.id}", - record_video_trigger=lambda x: x % 100_000 == 0, - video_length=300, - ) - - model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}" - - if do_training: - model = model_class( - config["policy_type"], - env, - verbose=1, - tensorboard_log=f"runs/{0}", - # n_steps=2048, - # n_epochs=10, - ) - if debug: - model.learn( - total_timesteps=config["total_timesteps"], - log_interval=1, - progress_bar=True, - ) - else: - checkpoint_callback = CheckpointCallback( - save_freq=50_000, - save_path="./logs/", - name_prefix="rl_model", - save_replay_buffer=True, - save_vecnormalize=True, - ) - wandb_callback = WandbCallback( - model_save_path=f"models/{run.id}", - verbose=0, - ) - - callback = CallbackList([checkpoint_callback, wandb_callback]) - model.learn( - total_timesteps=config["total_timesteps"], - callback=callback, - log_interval=1, - progress_bar=True, - ) - run.finish() - model.save(model_save_path) - - del model - print("LEARNING DONE.") - - model = model_class.load(model_save_path) - env = EnvGymWrapper() - - check_env(env) - obs, info = env.reset() - while True: - time.sleep(1 / 30) - action, _states = model.predict(obs, deterministic=False) - obs, reward, terminated, truncated, info = env.step(int(action)) - print(reward) - rgb_img = env.render() - cv2.imshow("env", rgb_img) - cv2.waitKey(0) - if terminated or truncated: - obs, info = env.reset() - - -if __name__ == "__main__": - main() diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py index c394624a9a75a9ba1f969de43ffb4ee2ca5de4f3..1a2e238db7fd4e37609c55d82c480337940ba528 100644 --- a/overcooked_simulator/overcooked_environment.py +++ b/overcooked_simulator/overcooked_environment.py @@ -5,7 +5,6 @@ import inspect import json import logging import sys -from collections import deque from datetime import timedelta, datetime from enum import Enum from pathlib import Path @@ -21,15 +20,11 @@ from 
overcooked_simulator.counter_factory import CounterFactory from overcooked_simulator.counters import ( Counter, PlateConfig, - CookingCounter, - Dispenser, ) from overcooked_simulator.effect_manager import EffectManager from overcooked_simulator.game_items import ( ItemInfo, ItemType, - CookingEquipment, - Item, ) from overcooked_simulator.hooks import ( ITEM_INFO_LOADED, @@ -59,7 +54,6 @@ from overcooked_simulator.player import Player, PlayerConfig from overcooked_simulator.utils import ( create_init_env_time, get_closest, - VectorStateGenerationData, ) log = logging.getLogger(__name__) @@ -291,8 +285,6 @@ class Environment: str, EffectManager ] = self.counter_factory.setup_effect_manger(self.counters) - self.vector_state_generation = self.setup_vectorization() - self.hook( ENV_INITIALIZED, environment_config=env_config, @@ -805,404 +797,6 @@ class Environment: return json_data raise ValueError(f"No valid {player_id=}") - def setup_vectorization(self) -> VectorStateGenerationData: - grid_base_array = np.zeros( - ( - int(self.kitchen_width), - int(self.kitchen_height), - 114 + 12 + 4, # TODO calc based on item info - ), - dtype=np.float32, - ) - counter_list = [ - "Counter", - "CuttingBoard", - "ServingWindow", - "Trashcan", - "Sink", - "SinkAddon", - "Stove", - "DeepFryer", - "Oven", - ] - grid_idxs = [ - (x, y) - for x in range(int(self.kitchen_width)) - for y in range(int(self.kitchen_height)) - ] - # counters do not move - for counter in self.counters: - grid_idx = np.floor(counter.pos).astype(int) - counter_name = ( - counter.name - if isinstance(counter, CookingCounter) - else ( - repr(counter) - if isinstance(Counter, Dispenser) - else counter.__class__.__name__ - ) - ) - assert counter_name in counter_list or counter_name.endswith( - "Dispenser" - ), f"Unknown Counter {counter}" - oh_idx = len(counter_list) - if counter_name in counter_list: - oh_idx = counter_list.index(counter_name) - - one_hot = [0] * (len(counter_list) + 2) - one_hot[oh_idx] = 1 - grid_base_array[ - grid_idx[0], grid_idx[1], 4 : 4 + (len(counter_list) + 2) - ] = np.array(one_hot, dtype=np.float32) - - grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1]))) - - for free_idx in grid_idxs: - one_hot = [0] * (len(counter_list) + 2) - one_hot[len(counter_list) + 1] = 1 - grid_base_array[ - free_idx[0], free_idx[1], 4 : 4 + (len(counter_list) + 2) - ] = np.array(one_hot, dtype=np.float32) - - player_info_base_array = np.zeros( - ( - 4, - 4 + 114, - ), - dtype=np.float32, - ) - order_base_array = np.zeros((10 * (8 + 1)), dtype=np.float32) - - return VectorStateGenerationData( - grid_base_array=grid_base_array, - oh_len=12, - ) - - def get_simple_vectorized_item(self, item: Item) -> npt.NDArray[float]: - name = item.name - array = np.zeros(21, dtype=np.float32) - if item.name.startswith("Burnt"): - name = name[len("Burnt") :] - array[0] = 1.0 - if name.startswith("Chopped"): - array[1] = 1.0 - name = name[len("Chopped") :] - if name in [ - "PizzaBase", - "GratedCheese", - "RawChips", - "RawPatty", - ]: - array[1] = 1.0 - name = { - "PizzaBase": "Dough", - "GratedCheese": "Cheese", - "RawChips": "Potato", - "RawPatty": "Meat", - }[name] - if name == "CookedPatty": - array[2] = 1.0 - name = "Meat" - - if name in self.vector_state_generation.meals: - idx = 3 + self.vector_state_generation.meals.index(name) - elif name in self.vector_state_generation.ingredients: - idx = ( - 3 - + len(self.vector_state_generation.meals) - + self.vector_state_generation.ingredients.index(name) - ) - else: - raise ValueError(f"Unknown item 
{name} - {item}") - array[idx] = 1.0 - return array - - def get_vectorized_item(self, item: Item) -> npt.NDArray[float]: - item_array = np.zeros(114, dtype=np.float32) - - if isinstance(item, CookingEquipment) or item.item_info.type == ItemType.Tool: - assert ( - item.name in self.vector_state_generation.equipments - ), f"unknown equipment {item}" - idx = self.vector_state_generation.equipments.index(item.name) - item_array[idx] = 1.0 - if isinstance(item, CookingEquipment): - for s_idx, sub_item in enumerate(item.content_list): - if s_idx > 3: - print("Too much content in the content list, info dropped") - break - start_idx = len(self.vector_state_generation.equipments) + 21 + 2 - item_array[ - start_idx + (s_idx * (21)) : start_idx + ((s_idx + 1) * (21)) - ] = self.get_simple_vectorized_item(sub_item) - - else: - item_array[ - len(self.vector_state_generation.equipments) : len( - self.vector_state_generation.equipments - ) - + 21 - ] = self.get_simple_vectorized_item(item) - - item_array[ - len(self.vector_state_generation.equipments) + 21 + 1 - ] = item.progress_percentage - - if item.active_effects: - item_array[ - len(self.vector_state_generation.equipments) + 21 + 2 - ] = 1.0 # TODO percentage of fire... - - return item_array - - def get_vectorized_state_full( - self, player_id: str - ) -> Tuple[ - npt.NDArray[npt.NDArray[float]], - npt.NDArray[npt.NDArray[float]], - float, - npt.NDArray[float], - ]: - grid_array = self.vector_state_generation.grid_base_array.copy() - for counter in self.counters: - grid_idx = np.floor(counter.pos).astype(int) # store in counter? - if counter.occupied_by: - if isinstance(counter.occupied_by, deque): - ... - else: - item = counter.occupied_by - grid_array[ - grid_idx[0], - grid_idx[1], - 4 + self.vector_state_generation.oh_len :, - ] = self.get_vectorized_item(item) - if counter.active_effects: - grid_array[ - grid_idx[0], - grid_idx[1], - 4 + self.vector_state_generation.oh_len - 1, - ] = 1.0 # TODO percentage of fire... - - assert len(self.players) <= 4, "To many players for vector representation" - player_vec = np.zeros( - ( - 4, - 4 + 114, - ), - dtype=np.float32, - ) - player_pos = 1 - for player in self.players.values(): - if player.name == player_id: - idx = 0 - player_vec[0, :4] = np.array( - [ - player.pos[0], - player.pos[1], - player.facing_point[0], - player.facing_point[1], - ], - dtype=np.float32, - ) - else: - idx = player_pos - - if not idx: - player_pos += 1 - grid_idx = np.floor(player.pos).astype(int) # store in counter? 
- player_vec[idx, :4] = np.array( - [ - player.pos[0] - grid_idx[0], - player.pos[1] - grid_idx[1], - player.facing_point[0] / np.linalg.norm(player.facing_point), - player.facing_point[1] / np.linalg.norm(player.facing_point), - ], - dtype=np.float32, - ) - grid_array[grid_idx[0], grid_idx[1], idx] = 1.0 - - if player.holding: - player_vec[idx, 4:] = self.get_vectorized_item(player.holding) - - order_array = np.zeros((10 * (8 + 1)), dtype=np.float32) - - for i, order in enumerate(self.order_manager.open_orders): - if i > 9: - print("some orders are not represented in the vectorized state") - break - assert ( - order.meal.name in self.vector_state_generation.meals - ), "unknown meal in order" - idx = self.vector_state_generation.meals.index(order.meal.name) - order_array[(i * 9) + idx] = 1.0 - order_array[(i * 9) + 8] = ( - self.env_time - order.start_time - ).total_seconds() / order.max_duration.total_seconds() - - return ( - grid_array, - player_vec, - (self.env_time - self.start_time).total_seconds() - / (self.env_time_end - self.start_time).total_seconds(), - order_array, - ) - - # def setup_vectorization_simple(self) -> VectorStateGenerationDataSimple: - # num_per_item = 114 - # num_per_counter = 12 - # num_players = 4 - # grid_base_array = np.zeros( - # ( - # int(self.kitchen_width), - # int(self.kitchen_height), - # num_per_item - # + num_per_counter - # + num_players, # TODO calc based on item info - # ), - # dtype=np.float32, - # ) - # counter_list = [ - # "Counter", - # "CuttingBoard", - # "ServingWindow", - # "Trashcan", - # "Sink", - # "SinkAddon", - # "Stove", - # "DeepFryer", - # "Oven", - # ] - # grid_idxs = [ - # (x, y) - # for x in range(int(self.kitchen_width)) - # for y in range(int(self.kitchen_height)) - # ] - # # counters do not move - # for counter in self.counters: - # grid_idx = np.floor(counter.pos).astype(int) - # counter_name = ( - # counter.name - # if isinstance(counter, CookingCounter) - # else ( - # repr(counter) - # if isinstance(Counter, Dispenser) - # else counter.__class__.__name__ - # ) - # ) - # assert counter_name in counter_list or counter_name.endswith( - # "Dispenser" - # ), f"Unknown Counter {counter}" - # oh_idx = len(counter_list) - # if counter_name in counter_list: - # oh_idx = counter_list.index(counter_name) - # - # one_hot = [0] * (len(counter_list) + 2) - # one_hot[oh_idx] = 1 - # grid_base_array[ - # grid_idx[0], grid_idx[1], 4 : 4 + (len(counter_list) + 2) - # ] = np.array(one_hot, dtype=np.float32) - # - # grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1]))) - # - # for free_idx in grid_idxs: - # one_hot = [0] * (len(counter_list) + 2) - # one_hot[len(counter_list) + 1] = 1 - # grid_base_array[ - # free_idx[0], free_idx[1], 4 : 4 + (len(counter_list) + 2) - # ] = np.array(one_hot, dtype=np.float32) - # - # player_info_base_array = np.zeros( - # ( - # 4, - # 4 + 114, - # ), - # dtype=np.float32, - # ) - # order_base_array = np.zeros((10 * (8 + 1)), dtype=np.float32) - # - # return VectorStateGenerationData( - # grid_base_array=grid_base_array, - # oh_len=12, - # ) - - def get_vectorized_state_simple(self, player): - item_list = ["Pot", "Tomato", "ChoppedTomato", "Plate"] - counter_list = [ - "Counter", - "PlateDispenser", - "TomatoDispenser", - "ServingWindow", - "PlateReturn", - "Trashcan", - "Stove", - "CuttingBoard", - ] - player_pos = self.players[player].pos - player_dir = self.players[player].facing_direction - - grid_width, grid_height = int(self.kitchen_width), int(self.kitchen_height) - - counter_one_hot_length = 
len(counter_list) + 1 # one for empty field - grid_base_array = np.zeros( - ( - grid_width, - grid_height, - ), - dtype=int, - ) - - grid_idxs = [(x, y) for x in range(grid_width) for y in range(grid_height)] - - # counters do not move - for counter in self.counters: - grid_idx = np.floor(counter.pos).astype(int) - counter_name = ( - counter.name - if isinstance(counter, CookingCounter) - else ( - repr(counter) - if isinstance(Counter, Dispenser) - else counter.__class__.__name__ - ) - ) - if counter_name == "Dispenser": - counter_name = f"{counter.occupied_by.name}Dispenser" - assert counter_name in counter_list, f"Unknown Counter {counter}" - - counter_oh_idx = counter_one_hot_length - if counter_name in counter_list: - counter_oh_idx = counter_list.index(counter_name) - - grid_base_array[grid_idx[0], grid_idx[1]] = counter_oh_idx - grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1]))) - - for free_idx in grid_idxs: - grid_base_array[free_idx[0], free_idx[1]] = counter_one_hot_length - 1 - - counter_grid_one_hot = np.zeros( - (grid_width, grid_height, counter_one_hot_length), dtype=int - ) - for x in range(grid_width): - for y in range(grid_height): - counter_type_idx = grid_base_array[x, y] - counter_grid_one_hot[x, y, counter_type_idx] = 1 - - player_data = np.concatenate((player_pos, player_dir), axis=0) - - items_one_hot_length = len(item_list) + 1 - item_one_hot = np.zeros(items_one_hot_length, dtype=int) - player_item = self.players[player].holding - player_item_idx = items_one_hot_length - 1 - if player_item: - if player_item.name in item_list: - player_item_idx = item_list.index(player_item.name) - item_one_hot[player_item_idx] = 1 - - final = np.concatenate( - (counter_grid_one_hot.flatten(), player_data, item_one_hot), axis=0 - ) - return final - def reset_env_time(self): """Reset the env time to the initial time, defined by `create_init_env_time`.""" self.hook(PRE_RESET_ENV_TIME) diff --git a/overcooked_simulator/reinforcement_learning/__init__.py b/overcooked_simulator/reinforcement_learning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/overcooked_simulator/game_content/environment_config_rl.yaml b/overcooked_simulator/reinforcement_learning/environment_config_rl.yaml similarity index 98% rename from overcooked_simulator/game_content/environment_config_rl.yaml rename to overcooked_simulator/reinforcement_learning/environment_config_rl.yaml index 0b3b4570f92b3ab8554a6183ae56d87b11889f51..b6c9d3bc212547eaffdcd197c29ebc793b2f5d5c 100644 --- a/overcooked_simulator/game_content/environment_config_rl.yaml +++ b/overcooked_simulator/reinforcement_learning/environment_config_rl.yaml @@ -1,5 +1,5 @@ plates: - clean_plates: 1 + clean_plates: 2 dirty_plates: 0 plate_delay: [ 2, 4 ] return_dirty: False @@ -108,7 +108,7 @@ extra_setup_functions: hooks: [ trashcan_usage ] callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' callback_class_kwargs: - static_score: -0.15 + static_score: -0.5 item_cut: func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' kwargs: @@ -122,14 +122,14 @@ extra_setup_functions: hooks: [ post_step ] callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' callback_class_kwargs: - static_score: -0.01 + static_score: -0.05 combine: func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' kwargs: hooks: [ drop_off_on_cooking_equipment ] callback_class: 
!!python/name:overcooked_simulator.scores.ScoreViaHooks '' callback_class_kwargs: - static_score: 0.10 + static_score: 0.15 # json_states: # func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' # kwargs: diff --git a/overcooked_simulator/reinforcement_learning/gym_env.py b/overcooked_simulator/reinforcement_learning/gym_env.py new file mode 100644 index 0000000000000000000000000000000000000000..8b87ad085c3bb1938586378a97a4751f0c914441 --- /dev/null +++ b/overcooked_simulator/reinforcement_learning/gym_env.py @@ -0,0 +1,703 @@ +import json +import random +import time +from collections import deque +from datetime import timedelta +from enum import Enum +from pathlib import Path +from typing import Tuple + +import cv2 +import numpy as np +import numpy.typing as npt +import wandb +import yaml +from gymnasium import spaces, Env +from stable_baselines3 import A2C +from stable_baselines3 import DQN +from stable_baselines3 import PPO +from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback +from stable_baselines3.common.env_checker import check_env +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.vec_env import VecVideoRecorder +from wandb.integration.sb3 import WandbCallback + +from overcooked_simulator import ROOT_DIR +from overcooked_simulator.counters import Counter, CookingCounter, Dispenser +from overcooked_simulator.game_items import Item, CookingEquipment, ItemType +from overcooked_simulator.gui_2d_vis.drawing import Visualizer +from overcooked_simulator.overcooked_environment import ( + Environment, + Action, + ActionType, + InterActionData, +) +from overcooked_simulator.utils import VectorStateGenerationData + +SimpleActionSpace = Enum( + "SimpleActionSpace", + [ + "Up", + "Left", + "Down", + "Right", + "Interact", + "Put", + ], +) + + +def get_env_action(player_id, simple_action, duration): + match simple_action: + case SimpleActionSpace.Up: + return Action( + player_id, + ActionType.MOVEMENT, + np.array([0, -1]), + duration, + ) + case SimpleActionSpace.Left: + return Action( + player_id, + ActionType.MOVEMENT, + np.array([-1, 0]), + duration, + ) + case SimpleActionSpace.Down: + return Action( + player_id, + ActionType.MOVEMENT, + np.array([0, 1]), + duration, + ) + case SimpleActionSpace.Right: + return Action( + player_id, + ActionType.MOVEMENT, + np.array([1, 0]), + duration, + ) + case SimpleActionSpace.Put: + return Action( + player_id, + ActionType.PUT, + InterActionData.START, + duration, + ) + case SimpleActionSpace.Interact: + return Action( + player_id, + ActionType.INTERACT, + InterActionData.START, + duration, + ) + + +environment_config_path = ( + ROOT_DIR / "reinforcement_learning" / "environment_config_rl.yaml" +) +layout_path: Path = ROOT_DIR / "reinforcement_learning" / "rl_small.layout" +item_info_path = ROOT_DIR / "reinforcement_learning" / "item_info_rl.yaml" +with open(item_info_path, "r") as file: + item_info = file.read() +with open(layout_path, "r") as file: + layout = file.read() +with open(environment_config_path, "r") as file: + environment_config = file.read() +with open(ROOT_DIR / "gui_2d_vis" / "visualization_rl.yaml", "r") as file: + visualization_config = yaml.safe_load(file) + + +def shuffle_counters(env): + sample_counter = [] + for counter in env.counters: + if counter.__class__ != Counter: + sample_counter.append(counter) + new_counter_pos = [c.pos for c in sample_counter] + random.shuffle(new_counter_pos) + for counter, new_pos in zip(sample_counter, 
new_counter_pos): + counter.pos = new_pos + + +class EnvGymWrapper(Env): + """Should enable this: + observation, reward, terminated, truncated, info = env.step(action) + """ + + metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 20} + + def __init__(self): + super().__init__() + + self.gridsize = 20 + + self.randomize_counter_placement = True + self.use_rgb_obs = False # if False uses simple vectorized state + self.use_onehot = False + self.full_vector_state = True + + self.env: Environment = Environment( + env_config=environment_config, + layout_config=layout, + item_info=item_info, + as_files=False, + ) + + if self.randomize_counter_placement: + shuffle_counters(self.env) + + if self.full_vector_state: + self.vector_state_generation = self.setup_vectorization() + + self.visualizer: Visualizer = Visualizer(config=visualization_config) + self.visualizer.create_player_colors(1) + + self.player_name = str(0) + self.env.add_player(self.player_name) + self.player_id = list(self.env.players.keys())[0] + + # self.action_space = {idx: value for idx, value in enumerate(SimpleActionSpace)} + self.action_space_map = {} + for idx, item in enumerate(SimpleActionSpace): + self.action_space_map[idx] = item + self.global_step_time = 1 + self.in_between_steps = 1 + + self.action_space = spaces.Discrete(len(self.action_space_map)) + # Example for using image as input (channel-first; channel-last also works): + + min_obs_val = -1 if not self.use_rgb_obs else 0 + max_obs_val = 1 if self.use_onehot else 255 if self.use_rgb_obs else 8 + dummy_obs = self.get_observation() + self.observation_space = spaces.Box( + low=min_obs_val, + high=max_obs_val, + shape=dummy_obs.shape, + dtype=np.uint8 if self.use_rgb_obs else np.float32, + ) + print(self.observation_space) + + self.last_obs = dummy_obs + + self.step_counter = 0 + self.prev_score = 0 + + def get_vectorized_state_simple(self, player, onehot=True): + item_list = ["Pot", "Tomato", "ChoppedTomato", "Plate"] + counter_list = [ + "Counter", + "PlateDispenser", + "TomatoDispenser", + "ServingWindow", + "PlateReturn", + "Trashcan", + "Stove", + "CuttingBoard", + ] + + grid_width, grid_height = int(self.env.kitchen_width), int( + self.env.kitchen_height + ) + + counter_one_hot_length = len(counter_list) + 1 # one for empty field + grid_base_array = np.zeros( + ( + grid_width, + grid_height, + ), + dtype=int, + ) + + grid_idxs = [(x, y) for x in range(grid_width) for y in range(grid_height)] + + # counters do not move + for counter in self.env.counters: + grid_idx = np.floor(counter.pos).astype(int) + counter_name = ( + counter.name + if isinstance(counter, CookingCounter) + else ( + repr(counter) + if isinstance(Counter, Dispenser) + else counter.__class__.__name__ + ) + ) + if counter_name == "Dispenser": + counter_name = f"{counter.occupied_by.name}Dispenser" + assert counter_name in counter_list, f"Unknown Counter {counter}" + + counter_oh_idx = counter_one_hot_length + if counter_name in counter_list: + counter_oh_idx = counter_list.index(counter_name) + + grid_base_array[grid_idx[0], grid_idx[1]] = counter_oh_idx + grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1]))) + + for free_idx in grid_idxs: + grid_base_array[free_idx[0], free_idx[1]] = counter_one_hot_length - 1 + + counter_grid_one_hot = np.zeros( + (grid_width, grid_height, counter_one_hot_length), dtype=int + ) + for x in range(grid_width): + for y in range(grid_height): + counter_type_idx = grid_base_array[x, y] + counter_grid_one_hot[x, y, counter_type_idx] = 1 + + player_pos = 
self.env.players[player].pos + if onehot: + player_pos[0] /= self.env.kitchen_width + player_pos[1] /= self.env.kitchen_height + else: + player_pos = player_pos.astype(int) + + player_dir = self.env.players[player].facing_direction + player_data = np.concatenate((player_pos, player_dir), axis=0) + + items_one_hot_length = len(item_list) + 1 + item_one_hot = np.zeros(items_one_hot_length, dtype=int) + player_item = self.env.players[player].holding + player_item_idx = items_one_hot_length - 1 + if player_item: + if player_item.name in item_list: + player_item_idx = item_list.index(player_item.name) + item_one_hot[player_item_idx] = 1 + + final_idxs = np.concatenate( + (grid_base_array.flatten(), player_data, item_one_hot), axis=0 + ) + final_one_hot = np.concatenate( + (counter_grid_one_hot.flatten(), player_data, item_one_hot), axis=0 + ) + + return final_one_hot if onehot else final_idxs + + def setup_vectorization(self) -> VectorStateGenerationData: + grid_base_array = np.zeros( + ( + int(self.env.kitchen_width), + int(self.env.kitchen_height), + 114 + 12 + 4, # TODO calc based on item info + ), + dtype=np.float32, + ) + counter_list = [ + "Counter", + "CuttingBoard", + "ServingWindow", + "Trashcan", + "Sink", + "SinkAddon", + "Stove", + "DeepFryer", + "Oven", + ] + grid_idxs = [ + (x, y) + for x in range(int(self.env.kitchen_width)) + for y in range(int(self.env.kitchen_height)) + ] + # counters do not move + for counter in self.env.counters: + grid_idx = np.floor(counter.pos).astype(int) + counter_name = ( + counter.name + if isinstance(counter, CookingCounter) + else ( + repr(counter) + if isinstance(Counter, Dispenser) + else counter.__class__.__name__ + ) + ) + assert counter_name in counter_list or counter_name.endswith( + "Dispenser" + ), f"Unknown Counter {counter}" + oh_idx = len(counter_list) + if counter_name in counter_list: + oh_idx = counter_list.index(counter_name) + + one_hot = [0] * (len(counter_list) + 2) + one_hot[oh_idx] = 1 + grid_base_array[ + grid_idx[0], grid_idx[1], 4 : 4 + (len(counter_list) + 2) + ] = np.array(one_hot, dtype=np.float32) + + grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1]))) + + for free_idx in grid_idxs: + one_hot = [0] * (len(counter_list) + 2) + one_hot[len(counter_list) + 1] = 1 + grid_base_array[ + free_idx[0], free_idx[1], 4 : 4 + (len(counter_list) + 2) + ] = np.array(one_hot, dtype=np.float32) + + player_info_base_array = np.zeros( + ( + 4, + 4 + 114, + ), + dtype=np.float32, + ) + order_base_array = np.zeros((10 * (8 + 1)), dtype=np.float32) + + return VectorStateGenerationData( + grid_base_array=grid_base_array, + oh_len=12, + ) + + def get_simple_vectorized_item(self, item: Item) -> npt.NDArray[float]: + name = item.name + array = np.zeros(21, dtype=np.float32) + if item.name.startswith("Burnt"): + name = name[len("Burnt") :] + array[0] = 1.0 + if name.startswith("Chopped"): + array[1] = 1.0 + name = name[len("Chopped") :] + if name in [ + "PizzaBase", + "GratedCheese", + "RawChips", + "RawPatty", + ]: + array[1] = 1.0 + name = { + "PizzaBase": "Dough", + "GratedCheese": "Cheese", + "RawChips": "Potato", + "RawPatty": "Meat", + }[name] + if name == "CookedPatty": + array[2] = 1.0 + name = "Meat" + + if name in self.vector_state_generation.meals: + idx = 3 + self.vector_state_generation.meals.index(name) + elif name in self.vector_state_generation.ingredients: + idx = ( + 3 + + len(self.vector_state_generation.meals) + + self.vector_state_generation.ingredients.index(name) + ) + else: + raise ValueError(f"Unknown item {name} - 
{item}") + array[idx] = 1.0 + return array + + def get_vectorized_item(self, item: Item) -> npt.NDArray[float]: + item_array = np.zeros(114, dtype=np.float32) + + if isinstance(item, CookingEquipment) or item.item_info.type == ItemType.Tool: + assert ( + item.name in self.vector_state_generation.equipments + ), f"unknown equipment {item}" + idx = self.vector_state_generation.equipments.index(item.name) + item_array[idx] = 1.0 + if isinstance(item, CookingEquipment): + for s_idx, sub_item in enumerate(item.content_list): + if s_idx > 3: + print("Too much content in the content list, info dropped") + break + start_idx = len(self.vector_state_generation.equipments) + 21 + 2 + item_array[ + start_idx + (s_idx * (21)) : start_idx + ((s_idx + 1) * (21)) + ] = self.get_simple_vectorized_item(sub_item) + + else: + item_array[ + len(self.vector_state_generation.equipments) : len( + self.vector_state_generation.equipments + ) + + 21 + ] = self.get_simple_vectorized_item(item) + + item_array[ + len(self.vector_state_generation.equipments) + 21 + 1 + ] = item.progress_percentage + + if item.active_effects: + item_array[ + len(self.vector_state_generation.equipments) + 21 + 2 + ] = 1.0 # TODO percentage of fire... + + return item_array + + def get_vectorized_state_full( + self, player_id: str + ) -> Tuple[ + npt.NDArray[npt.NDArray[float]], + npt.NDArray[npt.NDArray[float]], + float, + npt.NDArray[float], + ]: + grid_array = self.vector_state_generation.grid_base_array.copy() + for counter in self.env.counters: + grid_idx = np.floor(counter.pos).astype(int) # store in counter? + if counter.occupied_by: + if isinstance(counter.occupied_by, deque): + ... + else: + item = counter.occupied_by + grid_array[ + grid_idx[0], + grid_idx[1], + 4 + self.vector_state_generation.oh_len :, + ] = self.get_vectorized_item(item) + if counter.active_effects: + grid_array[ + grid_idx[0], + grid_idx[1], + 4 + self.vector_state_generation.oh_len - 1, + ] = 1.0 # TODO percentage of fire... + + assert len(self.env.players) <= 4, "To many players for vector representation" + player_vec = np.zeros( + ( + 4, + 4 + 114, + ), + dtype=np.float32, + ) + player_pos = 1 + for player in self.env.players.values(): + if player.name == player_id: + idx = 0 + player_vec[0, :4] = np.array( + [ + player.pos[0], + player.pos[1], + player.facing_point[0], + player.facing_point[1], + ], + dtype=np.float32, + ) + else: + idx = player_pos + + if not idx: + player_pos += 1 + grid_idx = np.floor(player.pos).astype(int) # store in counter? 
+            player_vec[idx, :4] = np.array(
+                [
+                    player.pos[0] - grid_idx[0],
+                    player.pos[1] - grid_idx[1],
+                    player.facing_point[0] / np.linalg.norm(player.facing_point),
+                    player.facing_point[1] / np.linalg.norm(player.facing_point),
+                ],
+                dtype=np.float32,
+            )
+            grid_array[grid_idx[0], grid_idx[1], idx] = 1.0
+
+            if player.holding:
+                player_vec[idx, 4:] = self.get_vectorized_item(player.holding)
+
+        order_array = np.zeros((10 * (8 + 1)), dtype=np.float32)
+
+        for i, order in enumerate(self.env.order_manager.open_orders):
+            if i > 9:
+                print("some orders are not represented in the vectorized state")
+                break
+            assert (
+                order.meal.name in self.vector_state_generation.meals
+            ), "unknown meal in order"
+            idx = self.vector_state_generation.meals.index(order.meal.name)
+            order_array[(i * 9) + idx] = 1.0
+            order_array[(i * 9) + 8] = (
+                self.env.env_time - order.start_time
+            ).total_seconds() / order.max_duration.total_seconds()
+
+        return (
+            grid_array,
+            player_vec,
+            (self.env.env_time - self.env.start_time).total_seconds()
+            / (self.env.env_time_end - self.env.start_time).total_seconds(),
+            order_array,
+        )
+
+    def step(self, action):
+        simple_action = self.action_space_map[action]
+        env_action = get_env_action(
+            self.player_id, simple_action, self.global_step_time
+        )
+        self.env.perform_action(env_action)
+
+        for i in range(self.in_between_steps):
+            self.env.step(
+                timedelta(seconds=self.global_step_time / self.in_between_steps)
+            )
+
+        observation = self.get_observation()
+
+        reward = self.env.score - self.prev_score
+        self.prev_score = self.env.score
+
+        terminated = self.env.game_ended
+        truncated = self.env.game_ended
+        info = {}
+
+        return observation, reward, terminated, truncated, info
+
+    def reset(self, seed=None, options=None):
+        self.env: Environment = Environment(
+            env_config=environment_config,
+            layout_config=layout,
+            item_info=item_info,
+            as_files=False,
+        )
+
+        if self.randomize_counter_placement:
+            shuffle_counters(self.env)
+
+        self.player_name = str(0)
+        self.env.add_player(self.player_name)
+        self.player_id = list(self.env.players.keys())[0]
+
+        if self.full_vector_state:
+            self.vector_state_generation = self.setup_vectorization()
+
+        info = {}
+        obs = self.get_observation()
+
+        self.prev_score = 0
+
+        return obs, info
+
+    def get_observation(self):
+        if self.use_rgb_obs:
+            obs = self.get_env_img(self.gridsize)
+        else:
+            obs = self.get_vector_state()
+        return obs
+
+    def render(self):
+        observation = self.get_env_img(self.gridsize)
+        img = observation.astype(np.uint8)
+        img = img.transpose((1, 2, 0))
+        img = cv2.resize(img, (img.shape[1], img.shape[0]))
+        return img
+
+    def close(self):
+        pass
+
+    def get_env_img(self, gridsize):
+        state = self.env.get_json_state(player_id=self.player_id)
+        json_dict = json.loads(state)
+        observation = self.visualizer.get_state_image(
+            grid_size=gridsize, state=json_dict
+        ).transpose((1, 0, 2))
+        return (observation.transpose((2, 0, 1))).astype(np.uint8)
+
+    def get_vector_state(self):
+        obs = self.get_vectorized_state_simple("0", self.use_onehot)
+        return obs
+
+    def sample_random_action(self):
+        act = self.action_space.sample()
+        return act
+
+
+def main():
+    rl_agent_checkpoints = Path("rl_agent_checkpoints")
+    rl_agent_checkpoints.mkdir(exist_ok=True)
+
+    config = {
+        "policy_type": "MlpPolicy",
+        "total_timesteps": 30_000_000,  # hendric says more like 300_000_000 steps
+        "env_id": "overcooked",
+        "number_envs_parallel": 4,
+    }
+
+    debug = False
+    do_training = True
+    vec_env = True
+    number_envs_parallel = 
config["number_envs_parallel"] + + model_classes = [A2C, DQN, PPO] + model_class = model_classes[2] + + if vec_env: + env = make_vec_env(EnvGymWrapper, n_envs=number_envs_parallel) + else: + env = EnvGymWrapper() + + env.render_mode = "rgb_array" + + if not debug: + run = wandb.init( + project="overcooked", + config=config, + sync_tensorboard=True, # auto-upload sb3's tensorboard metrics + monitor_gym=True, + # save_code=True, # optional + ) + + env = VecVideoRecorder( + env, + f"videos/{run.id}", + record_video_trigger=lambda x: x % 200_000 == 0, + video_length=300, + ) + + model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}" + + if do_training: + model = model_class( + config["policy_type"], + env, + verbose=1, + tensorboard_log=f"runs/{0}", + device="cpu" + # n_steps=2048, + # n_epochs=10, + ) + if debug: + model.learn( + total_timesteps=config["total_timesteps"], + log_interval=1, + progress_bar=True, + ) + else: + checkpoint_callback = CheckpointCallback( + save_freq=50_000, + save_path="logs", + name_prefix="rl_model", + save_replay_buffer=True, + save_vecnormalize=True, + ) + wandb_callback = WandbCallback( + model_save_path=f"models/{run.id}", + verbose=0, + ) + + callback = CallbackList([checkpoint_callback, wandb_callback]) + model.learn( + total_timesteps=config["total_timesteps"], + callback=callback, + log_interval=1, + progress_bar=True, + ) + run.finish() + model.save(model_save_path) + + del model + print("LEARNING DONE.") + + model = model_class.load(model_save_path) + env = EnvGymWrapper() + + check_env(env) + obs, info = env.reset() + while True: + action, _states = model.predict(obs, deterministic=False) + obs, reward, terminated, truncated, info = env.step(int(action)) + print(reward) + rgb_img = env.render() + cv2.imshow("env", rgb_img) + cv2.waitKey(0) + if terminated or truncated: + obs, info = env.reset() + time.sleep(1 / env.metadata["render_fps"]) + + +if __name__ == "__main__": + main() diff --git a/overcooked_simulator/game_content/item_info_rl.yaml b/overcooked_simulator/reinforcement_learning/item_info_rl.yaml similarity index 100% rename from overcooked_simulator/game_content/item_info_rl.yaml rename to overcooked_simulator/reinforcement_learning/item_info_rl.yaml diff --git a/overcooked_simulator/reinforcement_learning/rl.layout b/overcooked_simulator/reinforcement_learning/rl.layout new file mode 100644 index 0000000000000000000000000000000000000000..e1e8c0757badf2379f3a29728f99f0088ea76bca --- /dev/null +++ b/overcooked_simulator/reinforcement_learning/rl.layout @@ -0,0 +1,5 @@ +##X## +T___P +#___# +U___# +#C#W# diff --git a/overcooked_simulator/reinforcement_learning/rl_small.layout b/overcooked_simulator/reinforcement_learning/rl_small.layout new file mode 100644 index 0000000000000000000000000000000000000000..a9eda0c9e9680e3c781c2f54579c0c6309a6bc47 --- /dev/null +++ b/overcooked_simulator/reinforcement_learning/rl_small.layout @@ -0,0 +1,4 @@ +#X## +T__P +U__# +#CW# diff --git a/overcooked_simulator/rl_agent_checkpoints/overcooked_PPO.zip b/overcooked_simulator/rl_agent_checkpoints/overcooked_PPO.zip deleted file mode 100644 index b79af5605fa8773501888d028bca62b6845cb2d4..0000000000000000000000000000000000000000 Binary files a/overcooked_simulator/rl_agent_checkpoints/overcooked_PPO.zip and /dev/null differ