diff --git a/overcooked_simulator/game_content/environment_config_rl.yaml b/overcooked_simulator/game_content/environment_config_rl.yaml index 6235a971e1fb53a0569e82fa63602b9a2e8427c9..a6dcf7996b8b885aa0fdae77bd73086797d580e2 100644 --- a/overcooked_simulator/game_content/environment_config_rl.yaml +++ b/overcooked_simulator/game_content/environment_config_rl.yaml @@ -103,27 +103,27 @@ extra_setup_functions: # log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' # log_class_kwargs: # log_path: USER_LOG_DIR/ENV_NAME/json_states.jsonl - actions: - func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' - kwargs: - hooks: [ pre_perform_action ] - log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' - log_class_kwargs: - log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl - random_env_events: - func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' - kwargs: - hooks: [ order_duration_sample, plate_out_of_kitchen_time ] - log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' - log_class_kwargs: - log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl - add_hook_ref: true - env_configs: - func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' - kwargs: - hooks: [ env_initialized, item_info_config ] - log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' - log_class_kwargs: - log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl - add_hook_ref: true +# actions: +# func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' +# kwargs: +# hooks: [ pre_perform_action ] +# log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' +# log_class_kwargs: +# log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl +# random_env_events: +# func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' +# kwargs: +# hooks: [ order_duration_sample, plate_out_of_kitchen_time ] +# log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' +# log_class_kwargs: +# log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl +# add_hook_ref: true +# env_configs: +# func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' +# kwargs: +# hooks: [ env_initialized, item_info_config ] +# log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' +# log_class_kwargs: +# log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl +# add_hook_ref: true diff --git a/overcooked_simulator/game_content/layouts/rl.layout b/overcooked_simulator/game_content/layouts/rl.layout index 5528ebcf1d59c8dab2d74d6590b5696ad9744d8f..4b91262e0e78525486821ca6c18edd99097138d8 100644 --- a/overcooked_simulator/game_content/layouts/rl.layout +++ b/overcooked_simulator/game_content/layouts/rl.layout @@ -1,5 +1,4 @@ #X## T__# -#__# U__P ##W# diff --git a/overcooked_simulator/gym_env.py b/overcooked_simulator/gym_env.py index a104ab1561b7f2961023dba8b3da539a62f8f145..9a0c8094e26ee236709315840b56213c1edc5bc9 100644 --- a/overcooked_simulator/gym_env.py +++ b/overcooked_simulator/gym_env.py @@ -210,29 +210,24 @@ class EnvGymWrapper(Env): as_files=True ) - self.visualizer: Visualizer = Visualizer(config=visualization_config) - self.player_name = str(0) self.env.add_player(self.player_name) self.player_id = list(self.env.players.keys())[0] - self.visualizer.create_player_colors(1) - info = {} return self.get_env_img(self.gridsize), info def render(self): observation = self.get_env_img(self.gridsize) img = observation.transpose((1,2,0))[:,:,::-1] - img = cv2.resize(img, (img.shape[1]*5, img.shape[0]*5)) print(img.shape) + img = cv2.resize(img, (img.shape[1]*5, img.shape[0]*5)) cv2.imshow("Overcooked",img) cv2.waitKey(1) def close(self): pass - def get_env_img(self, gridsize): state = self.env.get_json_state(player_id=self.player_id) json_dict = json.loads(state) @@ -241,8 +236,6 @@ class EnvGymWrapper(Env): ).transpose((1, 0, 2)) return observation.transpose((2,0,1)) - - def sample_random_action(self): act = self.action_space.sample() return act @@ -267,7 +260,7 @@ def main(): # # save_code=True, # optional # ) - env = make_vec_env(EnvGymWrapper, n_envs=32) + env = make_vec_env(EnvGymWrapper, n_envs=4) # env = EnvGymWrapper() model_classes = [A2C, DQN, PPO]