import json
import random
import time
import uuid
from collections import deque
from datetime import timedelta
from enum import Enum
from pathlib import Path
from uuid import uuid4
import cv2
import numpy as np
import wandb
import yaml
from gymnasium import spaces, Env
from stable_baselines3 import A2C
from stable_baselines3 import DQN
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecVideoRecorder
from wandb.integration.sb3 import WandbCallback
from cooperative_cuisine import ROOT_DIR
from cooperative_cuisine.action import ActionType, InterActionData, Action
from cooperative_cuisine.counters import Counter, CookingCounter, Dispenser
from cooperative_cuisine.environment import Environment
from cooperative_cuisine.items import CookingEquipment
from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer, CacheFlags
"""Enumeration of actions for simple action spaces for an RL agent."""
Up = "Up"
Down = "Down"
Left = "Left"
Right = "Right"
Interact = "Interact"
Put = "Put"
def get_env_action(player_id, simple_action, duration):
    """Map a SimpleActionSpace member to a concrete cooperative_cuisine Action."""
match simple_action:
case SimpleActionSpace.Up:
return Action(
player_id,
ActionType.MOVEMENT,
np.array([0, -1]),
duration,
)
case SimpleActionSpace.Left:
return Action(
player_id,
ActionType.MOVEMENT,
np.array([-1, 0]),
duration,
)
case SimpleActionSpace.Down:
return Action(
player_id,
ActionType.MOVEMENT,
np.array([0, 1]),
duration,
)
case SimpleActionSpace.Right:
return Action(
player_id,
ActionType.MOVEMENT,
np.array([1, 0]),
duration,
)
        case SimpleActionSpace.Put:
            return Action(
                player_id,
                ActionType.PUT,
                InterActionData.START,
                duration,
            )
case SimpleActionSpace.Interact:
return Action(
player_id,
ActionType.INTERACT,
InterActionData.START,
duration,
)
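# Usage sketch: the gym wrapper below maps a Discrete action index to a
# SimpleActionSpace member and then calls, e.g.,
#   get_env_action(player_id, SimpleActionSpace.Up, duration=1)
# which yields Action(player_id, ActionType.MOVEMENT, np.array([0, -1]), 1).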
environment_config_path = (
ROOT_DIR / "reinforcement_learning" / "environment_config_rl.yaml"
)
layout_path: Path = ROOT_DIR / "reinforcement_learning" / "rl_small.layout"
item_info_path = ROOT_DIR / "reinforcement_learning" / "item_info_rl.yaml"
with open(item_info_path, "r") as file:
item_info = file.read()
with open(layout_path, "r") as file:
layout = file.read()
with open(environment_config_path, "r") as file:
environment_config = file.read()
with open(ROOT_DIR / "pygame_2d_vis" / "visualization.yaml", "r") as file:
visualization_config = yaml.safe_load(file)
visualizer: Visualizer = Visualizer(config=visualization_config)
visualizer.create_player_colors(1)
visualizer.set_grid_size(40)
def shuffle_counters(env):
    """Randomly permute the positions of all special counters (dispensers, stoves,
    etc.), leaving plain Counter objects in place, to randomize the layout."""
sample_counter = []
other_counters = []
for counter in env.counters:
if counter.__class__ != Counter:
sample_counter.append(counter)
else:
other_counters.append(counter)
new_counter_pos = [c.pos for c in sample_counter]
random.shuffle(new_counter_pos)
for counter, new_pos in zip(sample_counter, new_counter_pos):
counter.pos = new_pos
sample_counter.extend(other_counters)
env.overwrite_counters(sample_counter)
class EnvGymWrapper(Env):
    """Gymnasium wrapper around the cooperative_cuisine Environment, enabling:

    observation, reward, terminated, truncated, info = env.step(action)
    """
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 10}
def __init__(self):
super().__init__()
self.randomize_counter_placement = False
self.use_rgb_obs = True # if False uses simple vectorized state
self.onehot_state = True
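        # use_rgb_obs=True yields the rendered kitchen image as observation;
        # otherwise the flat (optionally one-hot) vector built by
        # get_vectorized_state_simple below is used.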
self.env: Environment = Environment(
env_config=environment_config,
layout_config=layout,
item_info=item_info,
as_files=False,
env_name=uuid.uuid4().hex,
)
if self.randomize_counter_placement:
shuffle_counters(self.env)
self.player_name = str(0)
self.env.add_player(self.player_name)
self.player_id = list(self.env.players.keys())[0]
self.visualizer = visualizer
        self.action_space_map = dict(enumerate(SimpleActionSpace))
self.global_step_time = 1
self.in_between_steps = 1
self.action_space = spaces.Discrete(len(self.action_space_map))
        min_obs_val = -1 if not self.use_rgb_obs else 0
        max_obs_val = 255 if self.use_rgb_obs else 1 if self.onehot_state else 20
        # Probe one observation to derive the observation-space shape.
        dummy_obs = self.get_observation()
        self.observation_space = spaces.Box(
            low=min_obs_val,
            high=max_obs_val,
            shape=dummy_obs.shape,
            dtype=np.uint8 if self.use_rgb_obs else int,
        )
        self.last_obs = dummy_obs
self.step_counter = 0
self.prev_score = 0
    def vectorize_item(self, item, item_list):
        """One-hot encode an item (or the first element of an item deque) against
        item_list, returning (one_hot_vector, index)."""
item_one_hot = np.zeros(len(item_list))
if item is None:
item_name = "None"
elif isinstance(item, deque):
if len(item) > 0:
item_name = item[0].name
else:
item_name = "None"
else:
item_name = item.name
if isinstance(item, CookingEquipment):
if item.name == "Pot":
if len(item.content_list) > 0:
if item.content_list[0].name == "TomatoSoup":
item_name = "PotDone"
elif len(item.content_list) == 1:
item_name = "PotOne"
elif len(item.content_list) == 2:
item_name = "PotTwo"
elif len(item.content_list) == 3:
item_name = "PotThree"
if "Plate" in item.name:
content_list = [i.name for i in item.content_list]
match content_list:
case ["TomatoSoup"]:
item_name = "PlateTomatoSoup"
case ["ChoppedTomato"]:
item_name = "PlateChoppedTomato"
case ["ChoppedLettuce"]:
item_name = "PlateChoppedLettuce"
case []:
item_name = "Plate"
case ["ChoppedLettuce", "ChoppedTomato"]:
item_name = "PlateSalad"
                    case _:
assert False, f"Should not happen. {item}"
assert item_name in item_list, f"Unknown item {item_name}."
item_idx = item_list.index(item_name)
item_one_hot[item_idx] = 1
# if item_name not in self.seen_items:
# print(item, item_name)
# self.seen_items.append(item_name)
return item_one_hot, item_idx
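    # Worked example (sketch): a Pot holding two raw tomatoes is encoded as
    # "PotTwo", i.e. a one-hot vector with its single 1 at item_list.index("PotTwo").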
@staticmethod
def vectorize_counter(counter, counter_list):
        counter_name = (
            counter.name
            if isinstance(counter, CookingCounter)
            else (
                repr(counter)
                if isinstance(counter, Dispenser)
                else counter.__class__.__name__
            )
        )
if counter_name == "Dispenser":
counter_name = f"{counter.occupied_by.name}Dispenser"
        assert counter_name in counter_list, f"Unknown Counter {counter}"
        counter_oh_idx = counter_list.index(counter_name)
counter_one_hot = np.zeros(len(counter_list), dtype=int)
counter_one_hot[counter_oh_idx] = 1
return counter_one_hot, counter_oh_idx
    def get_vectorized_state_simple(self, player, onehot=True):
        """Build a flat vector state: counter grid, item grid, player pose, held item."""
        counter_list = [
            "Empty",
            "Counter",
            "PlateDispenser",
            "TomatoDispenser",
            "ServingWindow",
            "PlateReturn",
            "Trashcan",
            "Stove",
            "CuttingBoard",
        ]
        item_list = [
            "None",
            "Pot",
            "PotOne",
            "PotTwo",
            "PotThree",
            "PotDone",
            "Tomato",
            "ChoppedTomato",
            "Plate",
            "PlateTomatoSoup",
            "PlateSalad",
            "Lettuce",
            "PlateChoppedTomato",
            "PlateChoppedLettuce",
            "ChoppedLettuce",
        ]
grid_width, grid_height = int(self.env.kitchen_width), int(
self.env.kitchen_height
)
grid_base_array = np.zeros(
(
grid_width,
grid_height,
),
dtype=int,
)
grid_idxs = [(x, y) for x in range(grid_width) for y in range(grid_height)]
        if onehot:
            item_one_hot_length = len(item_list)
            counter_items = np.zeros(
                (grid_width, grid_height, item_one_hot_length), dtype=int
            )
            counter_one_hot_length = len(counter_list)
            counters = np.zeros(
                (grid_width, grid_height, counter_one_hot_length), dtype=int
            )
        else:
            counter_items = np.zeros((grid_width, grid_height), dtype=int)
            counters = np.zeros((grid_width, grid_height), dtype=int)
for counter in self.env.counters:
grid_idx = np.floor(counter.pos).astype(int)
counter_one_hot, counter_oh_idx = self.vectorize_counter(
counter, counter_list
)
grid_base_array[grid_idx[0], grid_idx[1]] = counter_oh_idx
grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))
            counter_item_one_hot, counter_item_oh_idx = self.vectorize_item(
                counter.occupied_by, item_list
            )
            counter_items[grid_idx[0], grid_idx[1]] = (
                counter_item_one_hot if onehot else counter_item_oh_idx
            )
            counters[grid_idx[0], grid_idx[1]] = (
                counter_one_hot if onehot else counter_oh_idx
            )
        # Mark all grid cells without a counter as "Empty".
        for free_idx in grid_idxs:
            grid_base_array[free_idx[0], free_idx[1]] = counter_list.index("Empty")
player_pos = self.env.players[player].pos.astype(int)
player_dir = self.env.players[player].facing_direction.astype(int)
player_data = np.concatenate((player_pos, player_dir), axis=0)
        player_item_one_hot, player_item_idx = self.vectorize_item(
            self.env.players[player].holding, item_list
        )
        player_item = player_item_one_hot if onehot else [player_item_idx]
        return np.concatenate(
            (
                counters.flatten(),
                counter_items.flatten(),
                player_data.flatten(),
                player_item,
            ),
            axis=0,
        )
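    # Layout of the returned flat vector (one-hot case): per-cell counter
    # encodings, per-cell item encodings, player position and facing direction,
    # then the held item. For a w x h kitchen with the 9-entry counter_list and
    # 15-entry item_list above, the length is w*h*(9+15) + 2 + 2 + 15.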
def step(self, action):
simple_action = self.action_space_map[action]
env_action = get_env_action(
self.player_id, simple_action, self.global_step_time
)
self.env.perform_action(env_action)
        for _ in range(self.in_between_steps):
self.env.step(
timedelta(seconds=self.global_step_time / self.in_between_steps)
)
observation = self.get_observation()
reward = self.env.score - self.prev_score
self.prev_score = self.env.score
if reward > 0.6:
print("- - - - - - - - - - - - - - - - SCORED", reward)
terminated = self.env.game_ended
truncated = self.env.game_ended
info = {}
return observation, reward, terminated, truncated, info
    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        # Drop cached render surfaces of the previous environment instance, if any.
        visualizer.surface_cache_dict.pop(self.env.env_name, None)
self.env: Environment = Environment(
env_config=environment_config,
layout_config=layout,
item_info=item_info,
as_files=False,
env_name=uuid.uuid4().hex,
)
if self.randomize_counter_placement:
shuffle_counters(self.env)
self.player_name = str(0)
self.env.add_player(self.player_name)
self.player_id = list(self.env.players.keys())[0]
info = {}
obs = self.get_observation()
self.prev_score = 0
return obs, info
def get_observation(self):
if self.use_rgb_obs:
obs = self.get_env_img()
else:
obs = self.get_vector_state()
return obs
def render(self):
return self.get_env_img()
    def get_env_img(self):
        state = self.env.get_json_state(player_id=self.player_id)
        json_dict = json.loads(state)
        observation = self.visualizer.get_state_image(
            state=json_dict, env_id_ref=self.env.env_name
        ).astype(np.uint8)
        return observation
    def get_vector_state(self):
        obs = self.get_vectorized_state_simple("0", self.onehot_state)
        return obs
def sample_random_action(self):
act = self.action_space.sample()
return act
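# Minimal usage sketch of the wrapper on its own (assumes the config files loaded
# above are present; check_env validates Gymnasium API compliance):
#
#   env = EnvGymWrapper()
#   check_env(env)
#   obs, info = env.reset()
#   obs, reward, terminated, truncated, info = env.step(env.sample_random_action())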
def main():
rl_agent_checkpoints = Path("rl_agent_checkpoints")
rl_agent_checkpoints.mkdir(exist_ok=True)
    config = {
        "policy_type": "MlpPolicy",
        "total_timesteps": 3_000_000,  # Hendric suggests more like 300_000_000 steps
        "number_envs_parallel": 4,  # assumed default; tune to your hardware
    }
    debug = False  # set True to skip wandb logging and checkpoint callbacks
    do_training = True
    vec_env = True
    number_envs_parallel = config["number_envs_parallel"]
    model_classes = [A2C, DQN, PPO]
    model_class = model_classes[1]  # DQN
if vec_env:
env = make_vec_env(EnvGymWrapper, n_envs=number_envs_parallel)
else:
env = EnvGymWrapper()
env.render_mode = "rgb_array"
if not debug:
run = wandb.init(
project="overcooked",
config=config,
sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
monitor_gym=True,
# save_code=True, # optional
)
env = VecVideoRecorder(
env,
f"videos/{run.id}",
record_video_trigger=lambda x: x % 200_000 == 0,
video_length=300,
)
model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}"
if do_training:
model = model_class(
config["policy_type"],
env,
verbose=1,
tensorboard_log=f"runs/{0}",
device="cpu"
# n_steps=2048,
# n_epochs=10,
)
if debug:
model.learn(
total_timesteps=config["total_timesteps"],
log_interval=1,
progress_bar=True,
)
else:
checkpoint_callback = CheckpointCallback(
save_freq=50_000,
save_path="logs",
name_prefix="rl_model",
save_replay_buffer=True,
save_vecnormalize=True,
)
wandb_callback = WandbCallback(
model_save_path=f"models/{run.id}",
verbose=0,
)
callback = CallbackList([checkpoint_callback, wandb_callback])
model.learn(
total_timesteps=config["total_timesteps"],
callback=callback,
log_interval=1,
progress_bar=True,
)
run.finish()
model.save(model_save_path)
del model
print("LEARNING DONE.")
model = model_class.load(model_save_path)
env = EnvGymWrapper()
check_env(env)
obs, info = env.reset()
while True:
action, _states = model.predict(obs, deterministic=False)
obs, reward, terminated, truncated, info = env.step(int(action))
print(reward)
rgb_img = env.render()
cv2.imshow("env", rgb_img)
        cv2.waitKey(1)  # brief wait so the window refreshes without blocking on input
if terminated or truncated:
obs, info = env.reset()
time.sleep(1 / env.metadata["render_fps"])
if __name__ == "__main__":
main()