diff --git a/overcooked_simulator/game_content/item_info.yaml b/overcooked_simulator/game_content/item_info.yaml
index 1266f61ebd611cd5c2a9097b1be9dd7eff65b7f7..b06b6296f391e37bff755207de3d9029d398f1a0 100644
--- a/overcooked_simulator/game_content/item_info.yaml
+++ b/overcooked_simulator/game_content/item_info.yaml
@@ -70,6 +70,7 @@ Cheese:
 Sausage:
   type: Ingredient
 
+# Chopped things
 ChoppedTomato:
   type: Ingredient
   needs: [ Tomato ]
diff --git a/overcooked_simulator/game_content/layouts/rl.layout b/overcooked_simulator/game_content/layouts/rl.layout
index 4b91262e0e78525486821ca6c18edd99097138d8..624af90b4e174a801bfd027e0c9c864450489eef 100644
--- a/overcooked_simulator/game_content/layouts/rl.layout
+++ b/overcooked_simulator/game_content/layouts/rl.layout
@@ -1,4 +1,5 @@
-#X##
-T__#
-U__P
-##W#
+#X###
+T___#
+#___#
+U___P
+##W##
diff --git a/overcooked_simulator/gym_env.py b/overcooked_simulator/gym_env.py
index 26d5da99b0ccdf03e1e17bbf4c278bd9e9ee12c6..2e3f55d7c8cf504d108824beda2e02ac18313310 100644
--- a/overcooked_simulator/gym_env.py
+++ b/overcooked_simulator/gym_env.py
@@ -1,5 +1,6 @@
 import json
 import time
+from copy import copy, deepcopy
 from datetime import timedelta
 from enum import Enum
 from pathlib import Path
@@ -116,12 +117,25 @@ def get_env_action(player_id, simple_action, duration):
             print("FAIL", simple_action)
 
 
-environment_config_path: Path = ROOT_DIR / "game_content" / "environment_config_rl.yaml"
-item_info_path: Path = ROOT_DIR / "game_content" / "item_info_rl.yaml"
+
+environment_config_path = ROOT_DIR / "game_content" / "environment_config_rl.yaml"
 layout_path: Path = ROOT_DIR / "game_content" / "layouts" / "rl.layout"
+item_info_path = ROOT_DIR / "game_content" / "item_info_rl.yaml"
+with open(item_info_path, "r") as file:
+    item_info = file.read()
+with open(layout_path, "r") as file:
+    layout = file.read()
+with open(environment_config_path, "r") as file:
+    environment_config = file.read()
 with open(ROOT_DIR / "gui_2d_vis" / "visualization_rl.yaml", "r") as file:
     visualization_config = yaml.safe_load(file)
+vanilla_env: Environment = Environment(
+    env_config=environment_config,
+    layout_config=layout,
+    item_info=item_info,
+    as_files=False,
+)
 
 
 class EnvGymWrapper(Env):
     """Should enable this:
@@ -135,12 +149,7 @@ class EnvGymWrapper(Env):
 
         self.gridsize = 20
 
-        self.env: Environment = Environment(
-            env_config=environment_config_path,
-            layout_config=layout_path,
-            item_info=item_info_path,
-            as_files=True,
-        )
+        self.env = deepcopy(vanilla_env)
 
         self.visualizer: Visualizer = Visualizer(config=visualization_config)
         self.player_name = str(0)
@@ -160,9 +169,10 @@ class EnvGymWrapper(Env):
         self.action_space = spaces.Discrete(len(self.action_space_map))
 
         # Example for using image as input (channel-first; channel-last also works):
-        dummy_obs = self.get_env_img(self.gridsize)
+        # dummy_obs = self.get_env_img(self.gridsize)
+        dummy_obs = self.get_vector_state()
         self.observation_space = spaces.Box(
-            low=0, high=255, shape=dummy_obs.shape, dtype=np.uint8
+            low=0, high=1, shape=dummy_obs.shape, dtype=float
         )
 
         self.last_obs = dummy_obs
@@ -182,14 +192,16 @@ class EnvGymWrapper(Env):
             timedelta(seconds=self.global_step_time / self.in_between_steps)
         )
 
-        observation = self.get_env_img(self.gridsize)
+        # observation = self.get_env_img(self.gridsize)
+        observation = self.get_vector_state()
+
         reward = -1
         if (
-            self.env.order_manager.score > self.prev_score
-            and self.env.order_manager.score != 0
+            self.env.score > self.prev_score
+            and self.env.score != 0
         ):
-            self.prev_score = self.env.order_manager.score
+            self.prev_score = self.env.score
             reward = 100
-        elif self.env.order_manager.score < self.prev_score:
+        elif self.env.score < self.prev_score:
             self.prev_score = 0
             reward = -1
@@ -203,24 +215,27 @@
         return observation, reward, terminated, truncated, info
 
     def reset(self, seed=None, options=None):
-        self.env: Environment = Environment(
-            env_config=environment_config_path,
-            layout_config=layout_path,
-            item_info=item_info_path,
-            as_files=True,
-        )
+        # self.env: Environment = Environment(
+        #     env_config=environment_config,
+        #     layout_config=layout,
+        #     item_info=item_info,
+        #     as_files=False,
+        # )
+
+        self.env = deepcopy(vanilla_env)
 
         self.player_name = str(0)
         self.env.add_player(self.player_name)
         self.player_id = list(self.env.players.keys())[0]
 
-        info = {}
-        return self.get_env_img(self.gridsize), info
+        # obs = self.get_env_img(self.gridsize)
+        obs = self.get_vector_state()
+        return obs, {}
 
     def render(self):
         observation = self.get_env_img(self.gridsize)
         img = observation.transpose((1, 2, 0))[:, :, ::-1]
-        print(img.shape)
+        # print(img.shape)
         img = cv2.resize(img, (img.shape[1] * 5, img.shape[0] * 5))
         cv2.imshow("Overcooked", img)
         cv2.waitKey(1)
@@ -239,6 +254,8 @@
 
     def get_vector_state(self):
         grid, player, env_time, orders = self.env.get_vectorized_state("0")
+        obs = np.concatenate([grid.flatten(), player.flatten()], axis=0)
+        return obs
 
         # flatten: grid + player
         # concatenate all (env_time to array)
@@ -248,12 +265,13 @@
 
         # return np.random.randint(len(self.action_space_map))
 
+
 def main():
     rl_agent_checkpoints = Path("./rl_agent_checkpoints")
     rl_agent_checkpoints.mkdir(exist_ok=True)
 
     config = {
-        "policy_type": "CnnPolicy",
+        "policy_type": "MlpPolicy",
         "total_timesteps": 1000000,  # Hendric says more like 300_000_000 steps
         "env_id": "overcooked",
     }
@@ -265,7 +283,7 @@ def main():
     #     # save_code=True,  # optional
     # )
 
-    env = make_vec_env(EnvGymWrapper, n_envs=4)
+    env = make_vec_env(EnvGymWrapper, n_envs=64)
     # env = EnvGymWrapper()
 
     model_classes = [A2C, DQN, PPO]
diff --git a/overcooked_simulator/hooks.py b/overcooked_simulator/hooks.py
index 9d907fcd93f2c4ba6d08cbdbd034c118496ac6c1..fc79f8e31ab90ae2d15f6b2a96bc599fb3c2fae7 100644
--- a/overcooked_simulator/hooks.py
+++ b/overcooked_simulator/hooks.py
@@ -37,6 +37,9 @@ GAME_ENDED_STEP = "game_ended_step"
 
 PRE_STATE = "pre_state"
 
+PRE_STEP = "pre_step"
+POST_STEP = "post_step"
+
 STATE_DICT = "state_dict"
 
 JSON_STATE = "json_state"
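
Note on the reset rework: the patch builds one template Environment at import time and deep-copies it in reset(), on the assumption that copying a constructed environment is cheaper than re-parsing the three config files every episode. A toy sketch of that pattern follows; the World class and its names are illustrative, not part of the patch:

    from copy import deepcopy

    class World:
        """Stand-in for Environment: costly to construct, cheap to copy."""

        def __init__(self, config: str):
            # Pretend this parsing step is the expensive part.
            self.cells = [int(c) for c in config]

    TEMPLATE = World("0123456789")  # constructed once at import time

    def reset_world() -> World:
        # Each episode gets an independent copy; mutating the copy
        # never touches the template.
        return deepcopy(TEMPLATE)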
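Usage sketch for the reworked wrapper, assuming stable-baselines3 is installed; make_vec_env, PPO, and MlpPolicy all appear in gym_env.py, while the environment count and timestep count below are only illustrative:

    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env

    from overcooked_simulator.gym_env import EnvGymWrapper

    # MlpPolicy matches the new flat float observation vector;
    # the old image observation would call for CnnPolicy instead.
    env = make_vec_env(EnvGymWrapper, n_envs=4)
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10_000)  # illustrative; the config uses 1_000_000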