diff --git a/overcooked_simulator/game_content/layouts/rl.layout b/overcooked_simulator/game_content/layouts/rl.layout
new file mode 100644
index 0000000000000000000000000000000000000000..5c8982c7242ffd0a962d3cdf9a59dffb6a98c369
--- /dev/null
+++ b/overcooked_simulator/game_content/layouts/rl.layout
@@ -0,0 +1,6 @@
+#U#T##
+C____#
+#____+
+W____S
+#____#
+#X#P##
diff --git a/overcooked_simulator/gui_2d_vis/drawing.py b/overcooked_simulator/gui_2d_vis/drawing.py
index 62878e72a6dacb9fe63d3e3edc4c3c2eb0ff85f5..744c86a307c7bbc42c8b10cbc1ba038b299ca80f 100644
--- a/overcooked_simulator/gui_2d_vis/drawing.py
+++ b/overcooked_simulator/gui_2d_vis/drawing.py
@@ -80,6 +80,7 @@ class Visualizer:
         self.fire_state = 0
         self.fire_time_steps = 8
+        self.observation_screen = None
 
     def create_player_colors(self, n) -> None:
         """Create different colors for the players. The color hues are sampled uniformly in HSV-Space,
@@ -815,13 +816,15 @@ class Visualizer:
         height = int(np.ceil(state["kitchen"]["height"] * grid_size))
 
         flags = pygame.HIDDEN
-        screen = pygame.display.set_mode((width, height), flags=flags)
-        self.draw_gamescreen(screen, state, grid_size, [0 for _ in state["players"]])
+        if not self.observation_screen:
+            self.observation_screen = pygame.display.set_mode((width, height), flags=flags)
+
+        self.draw_gamescreen(self.observation_screen, state, grid_size, [0 for _ in state["players"]])
 
-        red = pygame.surfarray.array_red(screen)
-        green = pygame.surfarray.array_green(screen)
-        blue = pygame.surfarray.array_blue(screen)
+        red = pygame.surfarray.array_red(self.observation_screen)
+        green = pygame.surfarray.array_green(self.observation_screen)
+        blue = pygame.surfarray.array_blue(self.observation_screen)
         res = np.stack([red, green, blue], axis=2)
         return res
diff --git a/overcooked_simulator/gym_env.py b/overcooked_simulator/gym_env.py
index ddf29efe2c51d2a36efe5e00659ea373ebe1a08b..ee2ba2ddd15208b72a8a1eb2b65846bb4950a5a2 100644
--- a/overcooked_simulator/gym_env.py
+++ b/overcooked_simulator/gym_env.py
@@ -1,4 +1,5 @@
 import json
+import time
 from datetime import timedelta
 from enum import Enum
 from pathlib import Path
@@ -16,26 +17,114 @@ from overcooked_simulator.overcooked_environment import (
     InterActionData,
 )
 
+import gymnasium as gym
+import numpy as np
+from gymnasium import spaces, Env
+
+from stable_baselines3.common.env_checker import check_env
+
+
+SimpleActionSpace = Enum("SimpleActionSpace", ["Up",
+                                               # "Up_Left",
+                                               "Left",
+                                               # "Down_Left",
+                                               "Down",
+                                               # "Down_Right",
+                                               "Right",
+                                               # "Right_Up",
+                                               "Interact",
+                                               "Put"])
+
+
+def get_env_action(player_id, simple_action, duration):
+    match simple_action:
+        case SimpleActionSpace.Up:
+            return Action(
+                player_id,
+                ActionType.MOVEMENT,
+                np.array([0, -1]),
+                duration,
+            )
+        # case SimpleActionSpace.Up_Left:
+        #     return Action(
+        #         player_id,
+        #         ActionType.MOVEMENT,
+        #         np.array([-1, -1]),
+        #         duration,
+        #     )
+        case SimpleActionSpace.Left:
+            return Action(
+                player_id,
+                ActionType.MOVEMENT,
+                np.array([-1, 0]),
+                duration,
+            )
+        # case SimpleActionSpace.Down_Left:
+        #     return Action(
+        #         player_id,
+        #         ActionType.MOVEMENT,
+        #         np.array([-1, 1]),
+        #         duration,
+        #     )
+        case SimpleActionSpace.Down:
+            return Action(
+                player_id,
+                ActionType.MOVEMENT,
+                np.array([0, 1]),
+                duration,
+            )
+        # case SimpleActionSpace.Down_Right:
+        #     return Action(
+        #         player_id,
+        #         ActionType.MOVEMENT,
+        #         np.array([1, 1]),
+        #         duration,
+        #     )
+        case SimpleActionSpace.Right:
+            return Action(
+                player_id,
+                ActionType.MOVEMENT,
+                np.array([1, 0]),
+                duration,
+            )
+        # case SimpleActionSpace.Right_Up:
+        #     return Action(
+        #         player_id,
+        #         ActionType.MOVEMENT,
+        #         np.array([1, -1]),
+        #         duration,
+        #     )
+        case SimpleActionSpace.Put:
+            return Action(
+                player_id,
+                ActionType.PUT,
+                InterActionData.START,
+                duration,
+            )
+        case SimpleActionSpace.Interact:
+            return Action(
+                player_id,
+                ActionType.INTERACT,
+                InterActionData.START,
+                duration,
+            )
+        case other:
+            print("FAIL", simple_action)
 
-class SimpleActionSpace(Enum):
-    Up = "Up"
-    Down = "Down"
-    Left = "Left"
-    Right = "Right"
-    Interact = "Interact"
-    Put = "Put"
-
-
-class EnvGymWrapper:
+class EnvGymWrapper(Env):
     """Should enable this:
     observation, reward, terminated, truncated, info = env.step(action)
     """
+    metadata = {"render_modes": ["human"], "render_fps": 30}
 
     def __init__(self):
+        super().__init__()
+
+        self.gridsize = 20
         environment_config_path: Path = (
             ROOT_DIR / "game_content" / "environment_config.yaml"
         )
-        layout_path: Path = ROOT_DIR / "game_content" / "layouts" / "basic.layout"
+        layout_path: Path = ROOT_DIR / "game_content" / "layouts" / "rl.layout"
         item_info_path: Path = ROOT_DIR / "game_content" / "item_info.yaml"
 
         self.env: Environment = Environment(
@@ -48,7 +137,6 @@ class EnvGymWrapper:
             visualization_config = yaml.safe_load(file)
 
         self.visualizer: Visualizer = Visualizer(config=visualization_config)
-
         self.player_name = str(0)
         self.env.add_player(self.player_name)
         self.player_id = list(self.env.players.keys())[0]
@@ -56,115 +144,60 @@ class EnvGymWrapper:
         self.visualizer.create_player_colors(1)
 
         # self.action_space = {idx: value for idx, value in enumerate(SimpleActionSpace)}
-        self.action_space = {
-            0: SimpleActionSpace.Up,
-            1: SimpleActionSpace.Down,
-            2: SimpleActionSpace.Left,
-            3: SimpleActionSpace.Right,
-            4: SimpleActionSpace.Put,
-        }
-
-        print(self.action_space)
-
-        self.global_step_time = 0.05
-        self.in_between_steps = 10
-
-    def get_env_action(self, simple_action, duration):
-        match simple_action:
-            case SimpleActionSpace.Up:
-                return Action(
-                    self.player_id,
-                    ActionType.MOVEMENT,
-                    np.array([0, -1]),
-                    duration,
-                )
-            case SimpleActionSpace.Down:
-                return Action(
-                    self.player_id,
-                    ActionType.MOVEMENT,
-                    np.array([0, 1]),
-                    duration,
-                )
-
-            case SimpleActionSpace.Left:
-                return Action(
-                    self.player_id,
-                    ActionType.MOVEMENT,
-                    np.array([-1, 0]),
-                    duration,
-                )
-            case SimpleActionSpace.Right:
-                return Action(
-                    self.player_id,
-                    ActionType.MOVEMENT,
-                    np.array([1, 0]),
-                    duration,
-                )
-            case SimpleActionSpace.Put:
-                return Action(
-                    self.player_id,
-                    ActionType.PUT,
-                    InterActionData.START,
-                    duration,
-                )
-            case SimpleActionSpace.Put:
-                return Action(
-                    self.player_id,
-                    ActionType.INTERACT,
-                    InterActionData.START,
-                    duration,
-                )
-            # case SimpleActionSpace.Interact:
-            #     pass
-
-    def gym_env_setup(self):
-        self.action_space
-        self.observation_space
-        self.reward_range
+        self.action_space_map = {}
+        for idx, item in enumerate(SimpleActionSpace):
+            self.action_space_map[idx] = item
 
-    def render(self):
-        pass
-    def close(self):
-        pass
+        self.global_step_time = 0.5
+        self.in_between_steps = 1
 
-    def sample_random_action(self):
-        return np.random.randint(len(self.action_space))
+        self.action_space = spaces.Discrete(len(self.action_space_map))
+        # Example for using image as input (channel-first; channel-last also works):
+
+        dummy_obs = self.get_env_img(self.gridsize)
+        self.observation_space = spaces.Box(low=0, high=255,
+                                            shape=dummy_obs.shape, dtype=np.uint8)
 
-    def step(self, simple_action) -> tuple:
-        simple_action = self.action_space[simple_action]
-        action = self.get_env_action(simple_action, self.global_step_time)
+        self.last_obs = dummy_obs
 
-        self.env.perform_action(action)
+        self.step_counter = 0
+        self.prev_score = 0
+
+    def step(self, action):
+        simple_action = self.action_space_map[action]
+        env_action = get_env_action(self.player_id, simple_action, self.global_step_time)
+        self.env.perform_action(env_action)
 
-        print(self.env.game_ended)
         for i in range(self.in_between_steps):
            self.env.step(
                timedelta(seconds=self.global_step_time / self.in_between_steps)
            )
 
-        state = self.env.get_json_state(player_id=self.player_id)
-        json_dict = json.loads(state)
-        observation = self.visualizer.get_state_image(
-            grid_size=30, state=json_dict
-        ).transpose((1, 0, 2))
-
-        print(observation.shape)
-
-        cv2.imshow("Overcooked", observation[:, :, ::-1])
-        cv2.waitKey(1)
+        observation = self.get_env_img(self.gridsize)
 
         reward = -1
-        terminated = False
-        truncated = (False,)
-        info = "hey"
+        if self.env.order_and_score.score > self.prev_score and self.env.order_and_score.score != 0:
+            self.prev_score = self.env.order_and_score.score
+            reward = 100
+        elif self.env.order_and_score.score < self.prev_score:
+            self.prev_score = 0
+            reward = 0
+
+        terminated = self.env.game_ended
+        truncated = self.env.game_ended
+        info = {}
+
+        # self.render(self.gridsize)
 
         return observation, reward, terminated, truncated, info
 
-    def reset(self):
+    def reset(self, seed=None, options=None):
+
         environment_config_path: Path = (
-                ROOT_DIR / "game_content" / "environment_config.yaml"
+            ROOT_DIR / "game_content" / "environment_config.yaml"
         )
-        layout_path: Path = ROOT_DIR / "game_content" / "layouts" / "basic.layout"
+        layout_path: Path = ROOT_DIR / "game_content" / "layouts" / "rl.layout"
         item_info_path: Path = ROOT_DIR / "game_content" / "item_info.yaml"
 
         self.env: Environment = Environment(
@@ -184,14 +217,82 @@ class EnvGymWrapper:
 
         self.visualizer.create_player_colors(1)
 
+        info = {}
+        return self.get_env_img(self.gridsize), info
+
+    def render(self):
+        observation = self.get_env_img(self.gridsize)
+        img = observation.transpose((1, 2, 0))[:, :, ::-1]
+        img = cv2.resize(img, (img.shape[1] * 5, img.shape[0] * 5))
+        print(img.shape)
+        cv2.imshow("Overcooked", img)
+        cv2.waitKey(1)
+
+    def close(self):
+        pass
+
+    def get_env_img(self, gridsize):
+        state = self.env.get_json_state(player_id=self.player_id)
+        json_dict = json.loads(state)
+        observation = self.visualizer.get_state_image(
+            grid_size=gridsize, state=json_dict
+        ).transpose((1, 0, 2))
+        return observation.transpose((2, 0, 1))
+
+    def sample_random_action(self):
+        act = self.action_space.sample()
+        return act
+        # return np.random.randint(len(self.action_space_map))
+
 
 def main():
     env = EnvGymWrapper()
+    check_env(env)
+
+    from stable_baselines3.common.env_util import make_vec_env
+    vec_env = make_vec_env(EnvGymWrapper, n_envs=4)
+
+    # print("start")
+
+    # start_t = time.time()
+    # for i in range(10000):
+    #     if i % 100 == 0:
+    #         print(i)
+    #     env.step(env.action_space.sample())
+    # print("DURATION", time.time() - start_t)
+    # exit()
+
+    # from stable_baselines3 import A2C
+    # from stable_baselines3 import DQN
+    from stable_baselines3 import PPO
+
+    RL_CLASS = PPO
+
+    model = RL_CLASS("CnnPolicy", vec_env, verbose=1)
+
+    model.learn(total_timesteps=50000, log_interval=1, progress_bar=True)
+    model.save("oc")
+    del model  # remove to demonstrate saving and loading
+
+    print("LEARNING DONE.")
+
+    model = RL_CLASS.load("oc")
+
+    obs, info = env.reset()
     while True:
-        action = env.sample_random_action()
-        print(action)
-        env.step(action)
+        time.sleep(1 / 30)
+        action, _states = model.predict(obs, deterministic=False)
+        obs, reward, terminated, truncated, info = env.step(int(action))
+        env.render()
+        if terminated or truncated:
+            obs, info = env.reset()
 
 
 if __name__ == "__main__":
diff --git a/overcooked_simulator/order.py b/overcooked_simulator/order.py
index 6cedaaab7b1165a60c48c39a3ae54c408839d305..990be494c37bc59747e286c4cd31dca3d9e7c866 100644
--- a/overcooked_simulator/order.py
+++ b/overcooked_simulator/order.py
@@ -676,6 +676,9 @@ def serving_not_ordered_meals_with_zero_score(meal: Item) -> Tuple[bool, float |
     """Not ordered meals are accepted but do not affect the score."""
     return True, 0
 
+def serving_not_ordered_meals_with_five_score(meal: Item) -> Tuple[bool, float | int]:
+    """Not ordered meals are accepted and increase the score by 5."""
+    return True, 5
 
 def penalty_for_each_item(remove: Item | list[Item]) -> float:
     if isinstance(remove, list):
diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py
index 5302f55a72d1270649efbac5324a94db791c9f6a..d46cec89b65462156f3a2ba61e4463ff6c77f11e 100644
--- a/overcooked_simulator/overcooked_environment.py
+++ b/overcooked_simulator/overcooked_environment.py
@@ -787,7 +787,7 @@ class Environment:
             self.hook(STATE_DICT, state=state, player_id=player_id)
             json_data = json.dumps(state)
             self.hook(JSON_STATE, json_data=json_data, player_id=player_id)
-            assert StateRepresentation.model_validate_json(json_data=json_data)
+            # assert StateRepresentation.model_validate_json(json_data=json_data)
             return json_data
         raise ValueError(f"No valid {player_id=}")
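
For a quick local check of the new EnvGymWrapper, a minimal sketch follows (not part of the patch; it assumes the package is installed and importable as overcooked_simulator, with the gym_env module from this diff):

    from overcooked_simulator.gym_env import EnvGymWrapper

    env = EnvGymWrapper()
    obs, info = env.reset()
    assert env.observation_space.contains(obs)  # channel-first uint8 image

    for _ in range(20):
        action = env.action_space.sample()  # Discrete over Up/Left/Down/Right/Interact/Put
        obs, reward, terminated, truncated, info = env.step(int(action))
        if terminated or truncated:
            obs, info = env.reset()
    env.close()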