Skip to content
Snippets Groups Projects
Commit 343242fd authored by fheinrich's avatar fheinrich
Browse files

Using state vector in env

parent 1829348f
No related branches found
No related tags found
1 merge request!52Resolve "gym env"
Pipeline #45780 passed
......@@ -70,6 +70,7 @@ Cheese:
Sausage:
type: Ingredient
# Chopped things
ChoppedTomato:
type: Ingredient
needs: [ Tomato ]
......
#X##
T__#
U__P
##W#
#X###
T___#
#___#
U___P
##W##
import json
import time
from copy import copy, deepcopy
from datetime import timedelta
from enum import Enum
from pathlib import Path
......@@ -116,12 +117,25 @@ def get_env_action(player_id, simple_action, duration):
print("FAIL", simple_action)
# --- Module-level configuration -------------------------------------------
# Paths to the RL-specific game content used by every wrapper instance.
environment_config_path: Path = ROOT_DIR / "game_content" / "environment_config_rl.yaml"
layout_path: Path = ROOT_DIR / "game_content" / "layouts" / "rl.layout"
item_info_path: Path = ROOT_DIR / "game_content" / "item_info_rl.yaml"

# Read each config file once at import time so environments can be created
# from in-memory strings (as_files=False) instead of re-reading the files
# for every new environment instance.
with open(item_info_path, "r") as file:
    item_info = file.read()
with open(layout_path, "r") as file:
    layout = file.read()
with open(environment_config_path, "r") as file:
    environment_config = file.read()
with open(ROOT_DIR / "gui_2d_vis" / "visualization_rl.yaml", "r") as file:
    visualization_config = yaml.safe_load(file)

# Template environment: reset() takes a deepcopy of this instead of
# re-parsing the config files, which is much faster for vectorized training.
vanilla_env: Environment = Environment(
    env_config=environment_config,
    layout_config=layout,
    item_info=item_info,
    as_files=False,
)
class EnvGymWrapper(Env):
"""Should enable this:
......@@ -135,12 +149,7 @@ class EnvGymWrapper(Env):
self.gridsize = 20
self.env: Environment = Environment(
env_config=environment_config_path,
layout_config=layout_path,
item_info=item_info_path,
as_files=True,
)
self.env = deepcopy(vanilla_env)
self.visualizer: Visualizer = Visualizer(config=visualization_config)
self.player_name = str(0)
......@@ -160,9 +169,10 @@ class EnvGymWrapper(Env):
self.action_space = spaces.Discrete(len(self.action_space_map))
# Example for using image as input (channel-first; channel-last also works):
dummy_obs = self.get_env_img(self.gridsize)
# dummy_obs = self.get_env_img(self.gridsize)
dummy_obs = self.get_vector_state()
self.observation_space = spaces.Box(
low=0, high=255, shape=dummy_obs.shape, dtype=np.uint8
low=0, high=1, shape=dummy_obs.shape, dtype=float
)
self.last_obs = dummy_obs
......@@ -182,16 +192,18 @@ class EnvGymWrapper(Env):
timedelta(seconds=self.global_step_time / self.in_between_steps)
)
observation = self.get_env_img(self.gridsize)
# observation = self.get_env_img(self.gridsize)
observation = self.get_vector_state()
reward = -1
if (
self.env.order_manager.score > self.prev_score
and self.env.order_manager.score != 0
self.env.score > self.prev_score
and self.env.score != 0
):
self.prev_score = self.env.order_manager.score
self.prev_score = self.env.score
reward = 100
elif self.env.order_manager.score < self.prev_score:
elif self.env.score < self.prev_score:
self.prev_score = 0
reward = -1
......@@ -203,24 +215,27 @@ class EnvGymWrapper(Env):
return observation, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
    """Start a fresh episode and return the initial observation.

    Instead of re-parsing the config files (the old, slower path), a deep
    copy of the module-level template environment ``vanilla_env`` is taken.

    Args:
        seed: Unused; accepted for Gym API compatibility.
        options: Unused; accepted for Gym API compatibility.

    Returns:
        tuple: ``(observation, info)`` where ``observation`` is the
        flattened state vector and ``info`` is an empty dict.
    """
    self.env = deepcopy(vanilla_env)
    self.player_name = str(0)
    self.env.add_player(self.player_name)
    # Exactly one player was added, so its id is the single key.
    self.player_id = list(self.env.players.keys())[0]
    info = {}
    obs = self.get_vector_state()
    return obs, info
def render(self):
    """Display the current state in an upscaled OpenCV debug window."""
    observation = self.get_env_img(self.gridsize)
    # Channel-first (C, H, W) -> (H, W, C), and RGB -> BGR for OpenCV.
    img = observation.transpose((1, 2, 0))[:, :, ::-1]
    # Upscale 5x so the tiny grid image is actually visible.
    img = cv2.resize(img, (img.shape[1] * 5, img.shape[0] * 5))
    cv2.imshow("Overcooked", img)
    cv2.waitKey(1)
......@@ -239,6 +254,8 @@ class EnvGymWrapper(Env):
def get_vector_state(self):
    """Return the flattened state vector for player "0".

    Only the grid and player components are currently used; ``env_time``
    and ``orders`` are ignored.
    # TODO(review): fold env_time (as an array) and orders into the
    # observation as the inline notes in the original suggested.

    Returns:
        np.ndarray: 1-D concatenation of the flattened grid and player
        vectors returned by ``Environment.get_vectorized_state``.
    """
    grid, player, env_time, orders = self.env.get_vectorized_state("0")
    obs = np.concatenate([grid.flatten(), player.flatten()], axis=0)
    return obs
......@@ -248,12 +265,13 @@ class EnvGymWrapper(Env):
# return np.random.randint(len(self.action_space_map))
def main():
rl_agent_checkpoints = Path("./rl_agent_checkpoints")
rl_agent_checkpoints.mkdir(exist_ok=True)
config = {
"policy_type": "CnnPolicy",
"policy_type": "MlpPolicy",
"total_timesteps": 1000000, # hendric sagt eher so 300_000_000 schritte
"env_id": "overcooked",
}
......@@ -265,7 +283,7 @@ def main():
# # save_code=True, # optional
# )
env = make_vec_env(EnvGymWrapper, n_envs=4)
env = make_vec_env(EnvGymWrapper, n_envs=64)
# env = EnvGymWrapper()
model_classes = [A2C, DQN, PPO]
......
......@@ -37,6 +37,9 @@ GAME_ENDED_STEP = "game_ended_step"
PRE_STATE = "pre_state"
PRE_STEP = "pre_step"
POST_STEP = "post_step"
STATE_DICT = "state_dict"
JSON_STATE = "json_state"
......
Loading — 0% complete, or the content failed to load.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment