Skip to content
Snippets Groups Projects
Commit 75952d06 authored by Fabian Heinrich's avatar Fabian Heinrich
Browse files

Update to rl

parent ffc9571d
No related branches found
No related tags found
1 merge request!52Resolve "gym env"
Pipeline #45911 passed
plates:
clean_plates: 2
clean_plates: 1
dirty_plates: 0
plate_delay: [ 5, 10 ]
plate_delay: [ 2, 4 ]
return_dirty: False
# range of seconds until the dirty plate arrives.
game:
time_limit_seconds: 400
time_limit_seconds: 300
meals:
all: true
......@@ -93,7 +93,7 @@ extra_setup_functions:
hooks: [ completed_order ]
callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: 100
static_score: 1
serve_not_ordered_meals:
func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class ''
......@@ -101,35 +101,35 @@ extra_setup_functions:
hooks: [ serve_not_ordered_meal ]
callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: 100
static_score: 1
trashcan_usages:
func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class ''
kwargs:
hooks: [ trashcan_usage ]
callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: -10
static_score: -0.15
item_cut:
func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class ''
kwargs:
hooks: [ cutting_board_100 ]
callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: 10
static_score: 0.10
stepped:
func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class ''
kwargs:
hooks: [ post_step ]
callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: -1
static_score: -0.01
combine:
func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class ''
kwargs:
hooks: [ drop_off_on_cooking_equipment ]
callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: 1
static_score: 0.10
# json_states:
# func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks ''
# kwargs:
......
#X###
T___#
#___#
U___P
#C#W#
#X##
T__W
U__P
#C##
......@@ -12,6 +12,7 @@ import pygame
import pygame_gui
import requests
import yaml
from pygame._sdl2 import get_drivers
from websockets.sync.client import connect
from overcooked_simulator import ROOT_DIR
......@@ -30,6 +31,9 @@ from overcooked_simulator.utils import (
add_list_of_manager_ids_arguments,
)
for driver in get_drivers():
print(driver)
class MenuStates(Enum):
Start = "Start"
......@@ -970,8 +974,8 @@ class PyGameGUI:
clock = pygame.time.Clock()
self.reset_window_size()
self.init_ui_elements()
self.reset_window_size()
self.manage_button_visibility()
self.update_selection_elements()
......
import json
import random
import time
from copy import deepcopy
from datetime import timedelta
from enum import Enum
from pathlib import Path
......@@ -17,6 +16,7 @@ from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecVideoRecorder
from wandb.integration.sb3 import WandbCallback
from overcooked_simulator import ROOT_DIR
......@@ -134,12 +134,18 @@ with open(environment_config_path, "r") as file:
with open(ROOT_DIR / "gui_2d_vis" / "visualization_rl.yaml", "r") as file:
visualization_config = yaml.safe_load(file)
vanilla_env: Environment = Environment(
env_config=environment_config,
layout_config=layout,
item_info=item_info,
as_files=False,
)
def shuffle_counters(env):
sample_counter = []
sample_counter = []
for counter in env.counters:
if counter.__class__ != Counter:
sample_counter.append(counter)
new_counter_pos = [c.pos for c in sample_counter]
random.shuffle(new_counter_pos)
for counter, new_pos in zip(sample_counter, new_counter_pos):
counter.pos = new_pos
env.vector_state_generation = env.setup_vectorization()
class EnvGymWrapper(Env):
......@@ -147,29 +153,33 @@ class EnvGymWrapper(Env):
observation, reward, terminated, truncated, info = env.step(action)
"""
metadata = {"render_modes": ["human"], "render_fps": 30}
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}
def __init__(self):
super().__init__()
self.gridsize = 20
self.env = deepcopy(vanilla_env)
# sample_counter = []
# for counter in self.env.counters:
# if counter.__class__ != Counter:
# sample_counter.append(counter)
# new_counter_pos = [c.pos for c in sample_counter]
# random.shuffle(new_counter_pos)
# for counter, new_pos in zip(sample_counter, new_counter_pos):
# counter.pos = new_pos
# self.env.vector_state_generation = self.env.setup_vectorization()
self.randomize_counter_placement = False
self.use_rgb_obs = False
self.env: Environment = Environment(
env_config=environment_config,
layout_config=layout,
item_info=item_info,
as_files=False,
)
if self.randomize_counter_placement:
shuffle_counters(self.env)
self.visualizer: Visualizer = Visualizer(config=visualization_config)
self.player_name = str(0)
self.env.add_player(self.player_name)
self.player_id = list(self.env.players.keys())[0]
self.env.setup_vectorization()
self.visualizer.create_player_colors(1)
# self.action_space = {idx: value for idx, value in enumerate(SimpleActionSpace)}
......@@ -184,9 +194,8 @@ class EnvGymWrapper(Env):
# Example for using image as input (channel-first; channel-last also works):
dummy_obs = self.get_observation()
# dummy_obs = self.get_vector_state()
self.observation_space = spaces.Box(
low=0, high=1, shape=dummy_obs.shape, dtype=float
low=-1, high=8, shape=dummy_obs.shape, dtype=int
)
self.last_obs = dummy_obs
......@@ -219,27 +228,22 @@ class EnvGymWrapper(Env):
return observation, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
# self.env: Environment = Environment(
# env_config=environment_config,
# layout_config=layout,
# item_info=item_info,
# as_files=False
# )
self.env: Environment = Environment(
env_config=environment_config,
layout_config=layout,
item_info=item_info,
as_files=False,
)
self.env = deepcopy(vanilla_env)
# sample_counter = []
# for counter in self.env.counters:
# if counter.__class__ != Counter:
# sample_counter.append(counter)
# new_counter_pos = [c.pos for c in sample_counter]
# random.shuffle(new_counter_pos)
# for counter, new_pos in zip(sample_counter, new_counter_pos):
# counter.pos = new_pos
# self.env.vector_state_generation = self.env.setup_vectorization()
if self.randomize_counter_placement:
shuffle_counters(self.env)
self.player_name = str(0)
self.env.add_player(self.player_name)
self.player_id = list(self.env.players.keys())[0]
self.env.setup_vectorization()
info = {}
obs = self.get_observation()
......@@ -248,16 +252,18 @@ class EnvGymWrapper(Env):
return obs, info
def get_observation(self):
# obs = self.get_env_img(self.gridsize)
obs = self.get_vector_state()
if self.use_rgb_obs:
obs = self.get_env_img(self.gridsize)
else:
obs = self.get_vector_state()
return obs
def render(self):
observation = self.get_env_img(self.gridsize)
img = observation.transpose((1, 2, 0))[:, :, ::-1]
# print(img.shape)
img = cv2.resize(img, (img.shape[1] * 5, img.shape[0] * 5))
cv2.imshow("Overcooked", img)
cv2.waitKey(1)
img = (observation * 255.0).astype(np.uint8)
img = img.transpose((1, 2, 0))
img = cv2.resize(img, (img.shape[1], img.shape[0]))
return img
def close(self):
pass
......@@ -268,20 +274,23 @@ class EnvGymWrapper(Env):
observation = self.visualizer.get_state_image(
grid_size=gridsize, state=json_dict
).transpose((1, 0, 2))
return observation.transpose((2, 0, 1)) / 255.0
return (observation.transpose((2, 0, 1)) / 255.0).astype(np.float32)
def get_vector_state(self):
grid, player, env_time, orders = self.env.get_vectorized_state("0")
# grid, player, env_time, orders = self.env.get_vectorized_state_full("0")
#
#
# obs = np.concatenate(
# [grid.flatten(), player.flatten()], axis=0, dtype=np.float32
# )
# return obs
obs = np.concatenate([grid.flatten(), player.flatten()], axis=0)
obs = self.env.get_vectorized_state_simple("0")
return obs
# flatten: grid + player
# concatenate all (env_time to array)
def sample_random_action(self):
act = self.action_space.sample()
return act
# return np.random.randint(len(self.action_space_map))
def main():
......@@ -290,14 +299,15 @@ def main():
config = {
"policy_type": "MlpPolicy",
"total_timesteps": 100_000, # hendric sagt eher so 300_000_000 schritte
"total_timesteps": 30_000_000, # hendric sagt eher so 300_000_000 schritte
"env_id": "overcooked",
"number_envs_parallel": 16,
}
debug = True
debug = False
do_training = True
vec_env = True
number_envs_parallel = 8
number_envs_parallel = config["number_envs_parallel"]
model_classes = [A2C, DQN, PPO]
model_class = model_classes[2]
......@@ -307,12 +317,34 @@ def main():
else:
env = EnvGymWrapper()
model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}"
env.render_mode = "rgb_array"
if not debug:
run = wandb.init(
project="overcooked",
config=config,
sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
monitor_gym=True
# save_code=True, # optional
)
env = VecVideoRecorder(
env,
f"videos/{run.id}",
record_video_trigger=lambda x: x % 100_000 == 0,
video_length=300,
)
model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}"
if do_training:
model = model_class(
config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{0}"
config["policy_type"],
env,
verbose=1,
tensorboard_log=f"runs/{0}",
# n_steps=2048,
# n_epochs=10,
)
if debug:
model.learn(
......@@ -321,16 +353,8 @@ def main():
progress_bar=True,
)
else:
run = wandb.init(
project="overcooked",
config=config,
sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
monitor_gym=True
# save_code=True, # optional
)
checkpoint_callback = CheckpointCallback(
save_freq=1000,
save_freq=50_000,
save_path="./logs/",
name_prefix="rl_model",
save_replay_buffer=True,
......@@ -356,14 +380,17 @@ def main():
model = model_class.load(model_save_path)
env = EnvGymWrapper()
check_env(env)
obs, info = env.reset()
while True:
time.sleep(1 / 10)
time.sleep(1 / 30)
action, _states = model.predict(obs, deterministic=False)
obs, reward, terminated, truncated, info = env.step(int(action))
print(reward)
env.render()
rgb_img = env.render()
cv2.imshow("env", rgb_img)
cv2.waitKey(0)
if terminated or truncated:
obs, info = env.reset()
......
......@@ -48,7 +48,8 @@ from overcooked_simulator.hooks import (
ACTION_ON_NOT_REACHABLE_COUNTER,
ACTION_PUT,
ACTION_INTERACT_START,
ITEM_INFO_CONFIG, POST_STEP,
ITEM_INFO_CONFIG,
POST_STEP,
)
from overcooked_simulator.order import (
OrderManager,
......@@ -906,9 +907,11 @@ class Environment:
if name in self.vector_state_generation.meals:
idx = 3 + self.vector_state_generation.meals.index(name)
elif name in self.vector_state_generation.ingredients:
idx = 3 + len(
self.vector_state_generation.meals
) + self.vector_state_generation.ingredients.index(name)
idx = (
3
+ len(self.vector_state_generation.meals)
+ self.vector_state_generation.ingredients.index(name)
)
else:
raise ValueError(f"Unknown item {name} - {item}")
array[idx] = 1.0
......@@ -952,7 +955,7 @@ class Environment:
return item_array
def get_vectorized_state(
def get_vectorized_state_full(
self, player_id: str
) -> Tuple[
npt.NDArray[npt.NDArray[float]],
......@@ -1044,6 +1047,162 @@ class Environment:
order_array,
)
# def setup_vectorization_simple(self) -> VectorStateGenerationDataSimple:
# num_per_item = 114
# num_per_counter = 12
# num_players = 4
# grid_base_array = np.zeros(
# (
# int(self.kitchen_width),
# int(self.kitchen_height),
# num_per_item
# + num_per_counter
# + num_players, # TODO calc based on item info
# ),
# dtype=np.float32,
# )
# counter_list = [
# "Counter",
# "CuttingBoard",
# "ServingWindow",
# "Trashcan",
# "Sink",
# "SinkAddon",
# "Stove",
# "DeepFryer",
# "Oven",
# ]
# grid_idxs = [
# (x, y)
# for x in range(int(self.kitchen_width))
# for y in range(int(self.kitchen_height))
# ]
# # counters do not move
# for counter in self.counters:
# grid_idx = np.floor(counter.pos).astype(int)
# counter_name = (
# counter.name
# if isinstance(counter, CookingCounter)
# else (
# repr(counter)
# if isinstance(Counter, Dispenser)
# else counter.__class__.__name__
# )
# )
# assert counter_name in counter_list or counter_name.endswith(
# "Dispenser"
# ), f"Unknown Counter {counter}"
# oh_idx = len(counter_list)
# if counter_name in counter_list:
# oh_idx = counter_list.index(counter_name)
#
# one_hot = [0] * (len(counter_list) + 2)
# one_hot[oh_idx] = 1
# grid_base_array[
# grid_idx[0], grid_idx[1], 4 : 4 + (len(counter_list) + 2)
# ] = np.array(one_hot, dtype=np.float32)
#
# grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))
#
# for free_idx in grid_idxs:
# one_hot = [0] * (len(counter_list) + 2)
# one_hot[len(counter_list) + 1] = 1
# grid_base_array[
# free_idx[0], free_idx[1], 4 : 4 + (len(counter_list) + 2)
# ] = np.array(one_hot, dtype=np.float32)
#
# player_info_base_array = np.zeros(
# (
# 4,
# 4 + 114,
# ),
# dtype=np.float32,
# )
# order_base_array = np.zeros((10 * (8 + 1)), dtype=np.float32)
#
# return VectorStateGenerationData(
# grid_base_array=grid_base_array,
# oh_len=12,
# )
def get_vectorized_state_simple(self, player):
item_list = ["Pot", "Tomato", "ChoppedTomato", "Plate"]
counter_list = [
"Counter",
"PlateDispenser",
"TomatoDispenser",
"ServingWindow",
"PlateReturn",
"Trashcan",
"Stove",
"CuttingBoard",
]
player_pos = self.players[player].pos
player_dir = self.players[player].facing_direction
grid_width, grid_height = int(self.kitchen_width), int(self.kitchen_height)
counter_one_hot_length = len(counter_list) + 1 # one for empty field
grid_base_array = np.zeros(
(
grid_width,
grid_height,
),
dtype=int,
)
grid_idxs = [(x, y) for x in range(grid_width) for y in range(grid_height)]
# counters do not move
for counter in self.counters:
grid_idx = np.floor(counter.pos).astype(int)
counter_name = (
counter.name
if isinstance(counter, CookingCounter)
else (
repr(counter)
if isinstance(Counter, Dispenser)
else counter.__class__.__name__
)
)
if counter_name == "Dispenser":
counter_name = f"{counter.occupied_by.name}Dispenser"
assert counter_name in counter_list, f"Unknown Counter {counter}"
counter_oh_idx = counter_one_hot_length
if counter_name in counter_list:
counter_oh_idx = counter_list.index(counter_name)
grid_base_array[grid_idx[0], grid_idx[1]] = counter_oh_idx
grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))
for free_idx in grid_idxs:
grid_base_array[free_idx[0], free_idx[1]] = counter_one_hot_length - 1
counter_grid_one_hot = np.zeros(
(grid_width, grid_height, counter_one_hot_length), dtype=int
)
for x in range(grid_width):
for y in range(grid_height):
counter_type_idx = grid_base_array[x, y]
counter_grid_one_hot[x, y, counter_type_idx] = 1
player_data = np.concatenate((player_pos, player_dir), axis=0)
items_one_hot_length = len(item_list) + 1
item_one_hot = np.zeros(items_one_hot_length, dtype=int)
player_item = self.players[player].holding
player_item_idx = items_one_hot_length - 1
if player_item:
if player_item.name in item_list:
player_item_idx = item_list.index(player_item.name)
item_one_hot[player_item_idx] = 1
final = np.concatenate(
(counter_grid_one_hot.flatten(), player_data, item_one_hot), axis=0
)
return final
def reset_env_time(self):
"""Reset the env time to the initial time, defined by `create_init_env_time`."""
self.hook(PRE_RESET_ENV_TIME)
......
......@@ -65,6 +65,27 @@ class VectorStateGenerationData:
]
@dataclasses.dataclass
class VectorStateGenerationDataSimple:
grid_base_array: npt.NDArray[npt.NDArray[float]]
oh_len: int
number_normal_ingredients = 1
meals = [
"TomatoSoup",
]
equipments = [
"Pot",
"Plate",
"DirtyPlate",
"Extinguisher",
]
ingredients = [
"Tomato",
]
def create_init_env_time():
"""Init time of the environment time, because all environments should have the same internal time."""
return datetime(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment