Skip to content
Snippets Groups Projects
Commit 8c290c6c authored by fheinrich's avatar fheinrich
Browse files

RL update

parent a6709e07
No related branches found
No related tags found
1 merge request!52Resolve "gym env"
...@@ -103,27 +103,27 @@ extra_setup_functions: ...@@ -103,27 +103,27 @@ extra_setup_functions:
# log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' # log_class: !!python/name:overcooked_simulator.recording.LogRecorder ''
# log_class_kwargs: # log_class_kwargs:
# log_path: USER_LOG_DIR/ENV_NAME/json_states.jsonl # log_path: USER_LOG_DIR/ENV_NAME/json_states.jsonl
actions: # actions:
func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' # func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks ''
kwargs: # kwargs:
hooks: [ pre_perform_action ] # hooks: [ pre_perform_action ]
log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' # log_class: !!python/name:overcooked_simulator.recording.LogRecorder ''
log_class_kwargs: # log_class_kwargs:
log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl # log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl
random_env_events: # random_env_events:
func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' # func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks ''
kwargs: # kwargs:
hooks: [ order_duration_sample, plate_out_of_kitchen_time ] # hooks: [ order_duration_sample, plate_out_of_kitchen_time ]
log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' # log_class: !!python/name:overcooked_simulator.recording.LogRecorder ''
log_class_kwargs: # log_class_kwargs:
log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl # log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl
add_hook_ref: true # add_hook_ref: true
env_configs: # env_configs:
func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' # func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks ''
kwargs: # kwargs:
hooks: [ env_initialized, item_info_config ] # hooks: [ env_initialized, item_info_config ]
log_class: !!python/name:overcooked_simulator.recording.LogRecorder '' # log_class: !!python/name:overcooked_simulator.recording.LogRecorder ''
log_class_kwargs: # log_class_kwargs:
log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl # log_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl
add_hook_ref: true # add_hook_ref: true
#X## #X##
T__# T__#
#__#
U__P U__P
##W# ##W#
...@@ -210,29 +210,24 @@ class EnvGymWrapper(Env): ...@@ -210,29 +210,24 @@ class EnvGymWrapper(Env):
as_files=True as_files=True
) )
self.visualizer: Visualizer = Visualizer(config=visualization_config)
self.player_name = str(0) self.player_name = str(0)
self.env.add_player(self.player_name) self.env.add_player(self.player_name)
self.player_id = list(self.env.players.keys())[0] self.player_id = list(self.env.players.keys())[0]
self.visualizer.create_player_colors(1)
info = {} info = {}
return self.get_env_img(self.gridsize), info return self.get_env_img(self.gridsize), info
def render(self): def render(self):
observation = self.get_env_img(self.gridsize) observation = self.get_env_img(self.gridsize)
img = observation.transpose((1,2,0))[:,:,::-1] img = observation.transpose((1,2,0))[:,:,::-1]
img = cv2.resize(img, (img.shape[1]*5, img.shape[0]*5))
print(img.shape) print(img.shape)
img = cv2.resize(img, (img.shape[1]*5, img.shape[0]*5))
cv2.imshow("Overcooked",img) cv2.imshow("Overcooked",img)
cv2.waitKey(1) cv2.waitKey(1)
def close(self): def close(self):
pass pass
def get_env_img(self, gridsize): def get_env_img(self, gridsize):
state = self.env.get_json_state(player_id=self.player_id) state = self.env.get_json_state(player_id=self.player_id)
json_dict = json.loads(state) json_dict = json.loads(state)
...@@ -241,8 +236,6 @@ class EnvGymWrapper(Env): ...@@ -241,8 +236,6 @@ class EnvGymWrapper(Env):
).transpose((1, 0, 2)) ).transpose((1, 0, 2))
return observation.transpose((2,0,1)) return observation.transpose((2,0,1))
def sample_random_action(self): def sample_random_action(self):
act = self.action_space.sample() act = self.action_space.sample()
return act return act
...@@ -267,7 +260,7 @@ def main(): ...@@ -267,7 +260,7 @@ def main():
# # save_code=True, # optional # # save_code=True, # optional
# ) # )
env = make_vec_env(EnvGymWrapper, n_envs=32) env = make_vec_env(EnvGymWrapper, n_envs=4)
# env = EnvGymWrapper() # env = EnvGymWrapper()
model_classes = [A2C, DQN, PPO] model_classes = [A2C, DQN, PPO]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment