diff --git a/overcooked_simulator/gym_env.py b/overcooked_simulator/gym_env.py new file mode 100644 index 0000000000000000000000000000000000000000..ddf29efe2c51d2a36efe5e00659ea373ebe1a08b --- /dev/null +++ b/overcooked_simulator/gym_env.py @@ -0,0 +1,198 @@ +import json +from datetime import timedelta +from enum import Enum +from pathlib import Path + +import cv2 +import numpy as np +import yaml + +from overcooked_simulator import ROOT_DIR +from overcooked_simulator.gui_2d_vis.drawing import Visualizer +from overcooked_simulator.overcooked_environment import ( + Environment, + Action, + ActionType, + InterActionData, +) + + +class SimpleActionSpace(Enum): + Up = "Up" + Down = "Down" + Left = "Left" + Right = "Right" + Interact = "Interact" + Put = "Put" + + +class EnvGymWrapper: + """Should enable this: + observation, reward, terminated, truncated, info = env.step(action) + """ + + def __init__(self): + environment_config_path: Path = ( + ROOT_DIR / "game_content" / "environment_config.yaml" + ) + layout_path: Path = ROOT_DIR / "game_content" / "layouts" / "basic.layout" + item_info_path: Path = ROOT_DIR / "game_content" / "item_info.yaml" + + self.env: Environment = Environment( + env_config=environment_config_path, + layout_config=layout_path, + item_info=item_info_path, + ) + + with open(ROOT_DIR / "gui_2d_vis" / "visualization.yaml", "r") as file: + visualization_config = yaml.safe_load(file) + + self.visualizer: Visualizer = Visualizer(config=visualization_config) + + self.player_name = str(0) + self.env.add_player(self.player_name) + self.player_id = list(self.env.players.keys())[0] + + self.visualizer.create_player_colors(1) + + # self.action_space = {idx: value for idx, value in enumerate(SimpleActionSpace)} + self.action_space = { + 0: SimpleActionSpace.Up, + 1: SimpleActionSpace.Down, + 2: SimpleActionSpace.Left, + 3: SimpleActionSpace.Right, + 4: SimpleActionSpace.Put, + } + + print(self.action_space) + + self.global_step_time = 0.05 + self.in_between_steps = 10 + + def get_env_action(self, simple_action, duration): + match simple_action: + case SimpleActionSpace.Up: + return Action( + self.player_id, + ActionType.MOVEMENT, + np.array([0, -1]), + duration, + ) + case SimpleActionSpace.Down: + return Action( + self.player_id, + ActionType.MOVEMENT, + np.array([0, 1]), + duration, + ) + + case SimpleActionSpace.Left: + return Action( + self.player_id, + ActionType.MOVEMENT, + np.array([-1, 0]), + duration, + ) + case SimpleActionSpace.Right: + return Action( + self.player_id, + ActionType.MOVEMENT, + np.array([1, 0]), + duration, + ) + case SimpleActionSpace.Put: + return Action( + self.player_id, + ActionType.PUT, + InterActionData.START, + duration, + ) + case SimpleActionSpace.Put: + return Action( + self.player_id, + ActionType.INTERACT, + InterActionData.START, + duration, + ) + # case SimpleActionSpace.Interact: + # pass + + def gym_env_setup(self): + self.action_space + self.observation_space + self.reward_range + + def render(self): + pass + + def close(self): + pass + + def sample_random_action(self): + return np.random.randint(len(self.action_space)) + + def step(self, simple_action) -> tuple: + simple_action = self.action_space[simple_action] + action = self.get_env_action(simple_action, self.global_step_time) + + self.env.perform_action(action) + + print(self.env.game_ended) + for i in range(self.in_between_steps): + self.env.step( + timedelta(seconds=self.global_step_time / self.in_between_steps) + ) + + state = self.env.get_json_state(player_id=self.player_id) + json_dict = json.loads(state) + observation = self.visualizer.get_state_image( + grid_size=30, state=json_dict + ).transpose((1, 0, 2)) + + print(observation.shape) + + cv2.imshow("Overcooked", observation[:, :, ::-1]) + cv2.waitKey(1) + + reward = -1 + terminated = False + truncated = (False,) + info = "hey" + return observation, reward, terminated, truncated, info + + def reset(self): + environment_config_path: Path = ( + ROOT_DIR / "game_content" / "environment_config.yaml" + ) + layout_path: Path = ROOT_DIR / "game_content" / "layouts" / "basic.layout" + item_info_path: Path = ROOT_DIR / "game_content" / "item_info.yaml" + + self.env: Environment = Environment( + env_config=environment_config_path, + layout_config=layout_path, + item_info=item_info_path, + ) + + with open(ROOT_DIR / "gui_2d_vis" / "visualization.yaml", "r") as file: + visualization_config = yaml.safe_load(file) + + self.visualizer: Visualizer = Visualizer(config=visualization_config) + + self.player_name = str(0) + self.env.add_player(self.player_name) + self.player_id = list(self.env.players.keys())[0] + + self.visualizer.create_player_colors(1) + + +def main(): + env = EnvGymWrapper() + + while True: + action = env.sample_random_action() + print(action) + env.step(action) + + +if __name__ == "__main__": + main()