Skip to content
Snippets Groups Projects
Commit 343242fd authored by fheinrich
Browse files

Using state vector in env

parent 1829348f
No related branches found
No related tags found
1 merge request: !52 Resolve "gym env"
Pipeline #45780 passed
...@@ -70,6 +70,7 @@ Cheese: ...@@ -70,6 +70,7 @@ Cheese:
Sausage: Sausage:
type: Ingredient type: Ingredient
# Chopped things
ChoppedTomato: ChoppedTomato:
type: Ingredient type: Ingredient
needs: [ Tomato ] needs: [ Tomato ]
......
#X## #X###
T__# T___#
U__P #___#
##W# U___P
##W##
import json import json
import time import time
from copy import copy, deepcopy
from datetime import timedelta from datetime import timedelta
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
...@@ -116,12 +117,25 @@ def get_env_action(player_id, simple_action, duration): ...@@ -116,12 +117,25 @@ def get_env_action(player_id, simple_action, duration):
print("FAIL", simple_action) print("FAIL", simple_action)
environment_config_path: Path = ROOT_DIR / "game_content" / "environment_config_rl.yaml"
item_info_path: Path = ROOT_DIR / "game_content" / "item_info_rl.yaml" environment_config_path = ROOT_DIR / "game_content" / "environment_config_rl.yaml"
layout_path: Path = ROOT_DIR / "game_content" / "layouts" / "rl.layout" layout_path: Path = ROOT_DIR / "game_content" / "layouts" / "rl.layout"
item_info_path = ROOT_DIR / "game_content" / "item_info_rl.yaml"
with open(item_info_path, "r") as file:
item_info = file.read()
with open(layout_path, "r") as file:
layout = file.read()
with open(environment_config_path, "r") as file:
environment_config = file.read()
with open(ROOT_DIR / "gui_2d_vis" / "visualization_rl.yaml", "r") as file: with open(ROOT_DIR / "gui_2d_vis" / "visualization_rl.yaml", "r") as file:
visualization_config = yaml.safe_load(file) visualization_config = yaml.safe_load(file)
vanilla_env: Environment = Environment(
env_config=environment_config,
layout_config=layout,
item_info=item_info,
as_files=False
)
class EnvGymWrapper(Env): class EnvGymWrapper(Env):
"""Should enable this: """Should enable this:
...@@ -135,12 +149,7 @@ class EnvGymWrapper(Env): ...@@ -135,12 +149,7 @@ class EnvGymWrapper(Env):
self.gridsize = 20 self.gridsize = 20
self.env: Environment = Environment( self.env = deepcopy(vanilla_env)
env_config=environment_config_path,
layout_config=layout_path,
item_info=item_info_path,
as_files=True,
)
self.visualizer: Visualizer = Visualizer(config=visualization_config) self.visualizer: Visualizer = Visualizer(config=visualization_config)
self.player_name = str(0) self.player_name = str(0)
...@@ -160,9 +169,10 @@ class EnvGymWrapper(Env): ...@@ -160,9 +169,10 @@ class EnvGymWrapper(Env):
self.action_space = spaces.Discrete(len(self.action_space_map)) self.action_space = spaces.Discrete(len(self.action_space_map))
# Example for using image as input (channel-first; channel-last also works): # Example for using image as input (channel-first; channel-last also works):
dummy_obs = self.get_env_img(self.gridsize) # dummy_obs = self.get_env_img(self.gridsize)
dummy_obs = self.get_vector_state()
self.observation_space = spaces.Box( self.observation_space = spaces.Box(
low=0, high=255, shape=dummy_obs.shape, dtype=np.uint8 low=0, high=1, shape=dummy_obs.shape, dtype=float
) )
self.last_obs = dummy_obs self.last_obs = dummy_obs
...@@ -182,16 +192,18 @@ class EnvGymWrapper(Env): ...@@ -182,16 +192,18 @@ class EnvGymWrapper(Env):
timedelta(seconds=self.global_step_time / self.in_between_steps) timedelta(seconds=self.global_step_time / self.in_between_steps)
) )
observation = self.get_env_img(self.gridsize) # observation = self.get_env_img(self.gridsize)
observation = self.get_vector_state()
reward = -1 reward = -1
if ( if (
self.env.order_manager.score > self.prev_score self.env.score > self.prev_score
and self.env.order_manager.score != 0 and self.env.score != 0
): ):
self.prev_score = self.env.order_manager.score self.prev_score = self.env.score
reward = 100 reward = 100
elif self.env.order_manager.score < self.prev_score: elif self.env.score < self.prev_score:
self.prev_score = 0 self.prev_score = 0
reward = -1 reward = -1
...@@ -203,24 +215,27 @@ class EnvGymWrapper(Env): ...@@ -203,24 +215,27 @@ class EnvGymWrapper(Env):
return observation, reward, terminated, truncated, info return observation, reward, terminated, truncated, info
def reset(self, seed=None, options=None): def reset(self, seed=None, options=None):
self.env: Environment = Environment( # self.env: Environment = Environment(
env_config=environment_config_path, # env_config=environment_config,
layout_config=layout_path, # layout_config=layout,
item_info=item_info_path, # item_info=item_info,
as_files=True, # as_files=False
) # )
self.env = deepcopy(vanilla_env)
self.player_name = str(0) self.player_name = str(0)
self.env.add_player(self.player_name) self.env.add_player(self.player_name)
self.player_id = list(self.env.players.keys())[0] self.player_id = list(self.env.players.keys())[0]
info = {} info = {}
return self.get_env_img(self.gridsize), info # obs = self.get_env_img(self.gridsize)
obs = self.get_vector_state()
return obs, info
def render(self): def render(self):
observation = self.get_env_img(self.gridsize) observation = self.get_env_img(self.gridsize)
img = observation.transpose((1, 2, 0))[:, :, ::-1] img = observation.transpose((1, 2, 0))[:, :, ::-1]
print(img.shape) # print(img.shape)
img = cv2.resize(img, (img.shape[1] * 5, img.shape[0] * 5)) img = cv2.resize(img, (img.shape[1] * 5, img.shape[0] * 5))
cv2.imshow("Overcooked", img) cv2.imshow("Overcooked", img)
cv2.waitKey(1) cv2.waitKey(1)
...@@ -239,6 +254,8 @@ class EnvGymWrapper(Env): ...@@ -239,6 +254,8 @@ class EnvGymWrapper(Env):
def get_vector_state(self): def get_vector_state(self):
grid, player, env_time, orders = self.env.get_vectorized_state("0") grid, player, env_time, orders = self.env.get_vectorized_state("0")
obs = np.concatenate([grid.flatten(), player.flatten()], axis=0)
return obs
# flatten: grid + player # flatten: grid + player
# concatenate all (env_time to array) # concatenate all (env_time to array)
...@@ -248,12 +265,13 @@ class EnvGymWrapper(Env): ...@@ -248,12 +265,13 @@ class EnvGymWrapper(Env):
# return np.random.randint(len(self.action_space_map)) # return np.random.randint(len(self.action_space_map))
def main(): def main():
rl_agent_checkpoints = Path("./rl_agent_checkpoints") rl_agent_checkpoints = Path("./rl_agent_checkpoints")
rl_agent_checkpoints.mkdir(exist_ok=True) rl_agent_checkpoints.mkdir(exist_ok=True)
config = { config = {
"policy_type": "CnnPolicy", "policy_type": "MlpPolicy",
"total_timesteps": 1000000, # hendric sagt eher so 300_000_000 schritte "total_timesteps": 1000000, # hendric sagt eher so 300_000_000 schritte
"env_id": "overcooked", "env_id": "overcooked",
} }
...@@ -265,7 +283,7 @@ def main(): ...@@ -265,7 +283,7 @@ def main():
# # save_code=True, # optional # # save_code=True, # optional
# ) # )
env = make_vec_env(EnvGymWrapper, n_envs=4) env = make_vec_env(EnvGymWrapper, n_envs=64)
# env = EnvGymWrapper() # env = EnvGymWrapper()
model_classes = [A2C, DQN, PPO] model_classes = [A2C, DQN, PPO]
......
...@@ -37,6 +37,9 @@ GAME_ENDED_STEP = "game_ended_step" ...@@ -37,6 +37,9 @@ GAME_ENDED_STEP = "game_ended_step"
PRE_STATE = "pre_state" PRE_STATE = "pre_state"
PRE_STEP = "pre_step"
POST_STEP = "post_step"
STATE_DICT = "state_dict" STATE_DICT = "state_dict"
JSON_STATE = "json_state" JSON_STATE = "json_state"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment