Skip to content
Snippets Groups Projects

Resolve "gym env"

Merged Fabian Heinrich requested to merge 86-gym-env into main
1 file
+ 198
0
Compare changes
  • Side-by-side
  • Inline
+ 198
0
import json
from datetime import timedelta
from enum import Enum
from pathlib import Path
import cv2
import numpy as np
import yaml
from overcooked_simulator import ROOT_DIR
from overcooked_simulator.gui_2d_vis.drawing import Visualizer
from overcooked_simulator.overcooked_environment import (
Environment,
Action,
ActionType,
InterActionData,
)
class SimpleActionSpace(Enum):
Up = "Up"
Down = "Down"
Left = "Left"
Right = "Right"
Interact = "Interact"
Put = "Put"
class EnvGymWrapper:
"""Should enable this:
observation, reward, terminated, truncated, info = env.step(action)
"""
def __init__(self):
environment_config_path: Path = (
ROOT_DIR / "game_content" / "environment_config.yaml"
)
layout_path: Path = ROOT_DIR / "game_content" / "layouts" / "basic.layout"
item_info_path: Path = ROOT_DIR / "game_content" / "item_info.yaml"
self.env: Environment = Environment(
env_config=environment_config_path,
layout_config=layout_path,
item_info=item_info_path,
)
with open(ROOT_DIR / "gui_2d_vis" / "visualization.yaml", "r") as file:
visualization_config = yaml.safe_load(file)
self.visualizer: Visualizer = Visualizer(config=visualization_config)
self.player_name = str(0)
self.env.add_player(self.player_name)
self.player_id = list(self.env.players.keys())[0]
self.visualizer.create_player_colors(1)
# self.action_space = {idx: value for idx, value in enumerate(SimpleActionSpace)}
self.action_space = {
0: SimpleActionSpace.Up,
1: SimpleActionSpace.Down,
2: SimpleActionSpace.Left,
3: SimpleActionSpace.Right,
4: SimpleActionSpace.Put,
}
print(self.action_space)
self.global_step_time = 0.05
self.in_between_steps = 10
def get_env_action(self, simple_action, duration):
match simple_action:
case SimpleActionSpace.Up:
return Action(
self.player_id,
ActionType.MOVEMENT,
np.array([0, -1]),
duration,
)
case SimpleActionSpace.Down:
return Action(
self.player_id,
ActionType.MOVEMENT,
np.array([0, 1]),
duration,
)
case SimpleActionSpace.Left:
return Action(
self.player_id,
ActionType.MOVEMENT,
np.array([-1, 0]),
duration,
)
case SimpleActionSpace.Right:
return Action(
self.player_id,
ActionType.MOVEMENT,
np.array([1, 0]),
duration,
)
case SimpleActionSpace.Put:
return Action(
self.player_id,
ActionType.PUT,
InterActionData.START,
duration,
)
case SimpleActionSpace.Put:
return Action(
self.player_id,
ActionType.INTERACT,
InterActionData.START,
duration,
)
# case SimpleActionSpace.Interact:
# pass
def gym_env_setup(self):
self.action_space
self.observation_space
self.reward_range
def render(self):
pass
def close(self):
pass
def sample_random_action(self):
return np.random.randint(len(self.action_space))
def step(self, simple_action) -> tuple:
simple_action = self.action_space[simple_action]
action = self.get_env_action(simple_action, self.global_step_time)
self.env.perform_action(action)
print(self.env.game_ended)
for i in range(self.in_between_steps):
self.env.step(
timedelta(seconds=self.global_step_time / self.in_between_steps)
)
state = self.env.get_json_state(player_id=self.player_id)
json_dict = json.loads(state)
observation = self.visualizer.get_state_image(
grid_size=30, state=json_dict
).transpose((1, 0, 2))
print(observation.shape)
cv2.imshow("Overcooked", observation[:, :, ::-1])
cv2.waitKey(1)
reward = -1
terminated = False
truncated = (False,)
info = "hey"
return observation, reward, terminated, truncated, info
def reset(self):
environment_config_path: Path = (
ROOT_DIR / "game_content" / "environment_config.yaml"
)
layout_path: Path = ROOT_DIR / "game_content" / "layouts" / "basic.layout"
item_info_path: Path = ROOT_DIR / "game_content" / "item_info.yaml"
self.env: Environment = Environment(
env_config=environment_config_path,
layout_config=layout_path,
item_info=item_info_path,
)
with open(ROOT_DIR / "gui_2d_vis" / "visualization.yaml", "r") as file:
visualization_config = yaml.safe_load(file)
self.visualizer: Visualizer = Visualizer(config=visualization_config)
self.player_name = str(0)
self.env.add_player(self.player_name)
self.player_id = list(self.env.players.keys())[0]
self.visualizer.create_player_colors(1)
def main():
env = EnvGymWrapper()
while True:
action = env.sample_random_action()
print(action)
env.step(action)
if __name__ == "__main__":
main()
Loading