Skip to content
Snippets Groups Projects
Commit 4d7c2edc authored by Christoph Kowalski's avatar Christoph Kowalski
Browse files

Implemented two basic StateConverters and integrated them into the hydra config management

parent 3f1cc9e8
No related branches found
No related tags found
3 merge requests!110V1.2.0 changes,!102Fixed caching of recipe layouts. Ids were used in hash, which are generated...,!98Resolve "Restructure Reinforcement Learning files"
Pipeline #57050 passed
defaults:
- order_generator: random_order_generator
# Here the filename of the converter should be given. The converter class needs to be called StateConverter and implement the abstract StateToObservationConverter class
state_converter: "base_converter_onehot"
\ No newline at end of file
...@@ -2,4 +2,4 @@ defaults: ...@@ -2,4 +2,4 @@ defaults:
- environment: environment_config_rl - environment: environment_config_rl
- item_info: item_info_rl - item_info: item_info_rl
- model: PPO - model: PPO
- order_generator: random_order_generator - additional_configs: additional_config_base
\ No newline at end of file \ No newline at end of file
import importlib
import json import json
import random import random
import time from abc import abstractmethod
from collections import deque
from copy import deepcopy from copy import deepcopy
from datetime import timedelta from datetime import timedelta
from enum import Enum from enum import Enum
...@@ -9,29 +9,16 @@ from pathlib import Path ...@@ -9,29 +9,16 @@ from pathlib import Path
import cv2 import cv2
import numpy as np import numpy as np
import wandb
import yaml import yaml
from abc import abstractmethod
from gymnasium import spaces, Env from gymnasium import spaces, Env
import hydra from omegaconf import OmegaConf
from omegaconf import DictConfig, OmegaConf
from stable_baselines3 import A2C
from stable_baselines3 import DQN
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecVideoRecorder
from wandb.integration.sb3 import WandbCallback
from cooperative_cuisine import ROOT_DIR from cooperative_cuisine import ROOT_DIR
from cooperative_cuisine.action import ActionType, InterActionData, Action from cooperative_cuisine.action import ActionType, InterActionData, Action
from cooperative_cuisine.counters import Counter, CookingCounter, Dispenser from cooperative_cuisine.counters import Counter
from cooperative_cuisine.environment import ( from cooperative_cuisine.environment import (
Environment, Environment,
) )
from cooperative_cuisine.items import CookingEquipment
from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer
...@@ -117,11 +104,8 @@ def shuffle_counters(env): ...@@ -117,11 +104,8 @@ def shuffle_counters(env):
class StateToObservationConverter: class StateToObservationConverter:
def __init__(self, config):
self.config = config
@abstractmethod @abstractmethod
def setup(self): def setup(self, env):
... ...
@abstractmethod @abstractmethod
...@@ -129,6 +113,13 @@ class StateToObservationConverter: ...@@ -129,6 +113,13 @@ class StateToObservationConverter:
... ...
def get_converter(converter_name):
module_path = f"cooperative_cuisine.reinforcement_learning.obs_converter.{converter_name}"
module = importlib.import_module(module_path)
converter_class = getattr(module, "StateConverter")
return converter_class()
class EnvGymWrapper(Env): class EnvGymWrapper(Env):
"""Should enable this: """Should enable this:
observation, reward, terminated, truncated, info = env.step(action) observation, reward, terminated, truncated, info = env.step(action)
...@@ -142,10 +133,9 @@ class EnvGymWrapper(Env): ...@@ -142,10 +133,9 @@ class EnvGymWrapper(Env):
self.randomize_counter_placement = False self.randomize_counter_placement = False
self.use_rgb_obs = False # if False uses simple vectorized state self.use_rgb_obs = False # if False uses simple vectorized state
self.full_vector_state = True self.full_vector_state = True
self.onehot_state = True
config_env = OmegaConf.to_container(config.environment, resolve=True) config_env = OmegaConf.to_container(config.environment, resolve=True)
config_item_info = OmegaConf.to_container(config.item_info, resolve=True) config_item_info = OmegaConf.to_container(config.item_info, resolve=True)
order_generator = config.order_generator.order_generator_type order_generator = config.additional_configs.order_generator.order_generator_type
order_file = order_generator + "_orders.yaml" order_file = order_generator + "_orders.yaml"
custom_config_path = ROOT_DIR / "reinforcement_learning" / "config" / order_file custom_config_path = ROOT_DIR / "reinforcement_learning" / "config" / order_file
with open(custom_config_path, "r") as file: with open(custom_config_path, "r") as file:
...@@ -184,7 +174,16 @@ class EnvGymWrapper(Env): ...@@ -184,7 +174,16 @@ class EnvGymWrapper(Env):
self.action_space = spaces.Discrete(len(self.action_space_map)) self.action_space = spaces.Discrete(len(self.action_space_map))
self.seen_items = [] self.seen_items = []
self.converter = get_converter(config.additional_configs.state_converter)
self.converter.setup(self.env)
try:
self.onehot_state = self.converter.onehot
except AttributeError:
if 'onehot' in config.additional_configs.state_converter.lower():
self.onehot_state = True
else:
self.onehot_state = False
dummy_obs = self.get_observation() dummy_obs = self.get_observation()
min_obs_val = -1 if not self.use_rgb_obs else 0 min_obs_val = -1 if not self.use_rgb_obs else 0
max_obs_val = 255 if self.use_rgb_obs else 1 if self.onehot_state else 20 max_obs_val = 255 if self.use_rgb_obs else 1 if self.onehot_state else 20
...@@ -196,179 +195,9 @@ class EnvGymWrapper(Env): ...@@ -196,179 +195,9 @@ class EnvGymWrapper(Env):
) )
self.last_obs = dummy_obs self.last_obs = dummy_obs
self.step_counter = 0 self.step_counter = 0
self.prev_score = 0 self.prev_score = 0
def vectorize_item(self, item, item_list):
item_one_hot = np.zeros(len(item_list))
if item is None:
item_name = "None"
elif isinstance(item, deque):
if len(item) > 0:
item_name = item[0].name
else:
item_name = "None"
else:
item_name = item.name
if isinstance(item, CookingEquipment):
if item.name == "Pot":
if len(item.content_list) > 0:
if item.content_list[0].name == "TomatoSoup":
item_name = "PotDone"
elif len(item.content_list) == 1:
item_name = "PotOne"
elif len(item.content_list) == 2:
item_name = "PotTwo"
elif len(item.content_list) == 3:
item_name = "PotThree"
if "Plate" in item.name:
content_list = [i.name for i in item.content_list]
match content_list:
case ["TomatoSoup"]:
item_name = "PlateTomatoSoup"
case ["ChoppedTomato"]:
item_name = "PlateChoppedTomato"
case ["ChoppedLettuce"]:
item_name = "PlateChoppedLettuce"
case []:
item_name = "Plate"
case ["ChoppedLettuce", "ChoppedTomato"]:
item_name = "PlateSalad"
case other:
assert False, f"Should not happen. {item}"
assert item_name in item_list, f"Unknown item {item_name}."
item_idx = item_list.index(item_name)
item_one_hot[item_idx] = 1
# if item_name not in self.seen_items:
# print(item, item_name)
# self.seen_items.append(item_name)
return item_one_hot, item_idx
@staticmethod
def vectorize_counter(counter, counter_list):
counter_name = (
counter.name
if isinstance(counter, CookingCounter)
else (
repr(counter)
if isinstance(Counter, Dispenser)
else counter.__class__.__name__
)
)
if counter_name == "Dispenser":
counter_name = f"{counter.occupied_by.name}Dispenser"
assert counter_name in counter_list, f"Unknown Counter {counter}"
counter_oh_idx = counter_list.index("Empty")
if counter_name in counter_list:
counter_oh_idx = counter_list.index(counter_name)
counter_one_hot = np.zeros(len(counter_list), dtype=int)
counter_one_hot[counter_oh_idx] = 1
return counter_one_hot, counter_oh_idx
def get_vectorized_state_simple(self, player, onehot=True):
counter_list = [
"Empty",
"Counter",
"PlateDispenser",
"TomatoDispenser",
"ServingWindow",
"PlateReturn",
"Trashcan",
"Stove",
"CuttingBoard",
"LettuceDispenser",
]
item_list = [
"None",
"Pot",
"PotOne",
"PotTwo",
"PotThree",
"PotDone",
"Tomato",
"ChoppedTomato",
"Plate",
"PlateTomatoSoup",
"PlateSalad",
"Lettuce",
"PlateChoppedTomato",
"PlateChoppedLettuce",
"ChoppedLettuce",
]
grid_width, grid_height = int(self.env.kitchen_width), int(
self.env.kitchen_height
)
grid_base_array = np.zeros(
(
grid_width,
grid_height,
),
dtype=int,
)
grid_idxs = [(x, y) for x in range(grid_width) for y in range(grid_height)]
if onehot:
item_one_hot_length = len(item_list)
counter_items = np.zeros(
(grid_width, grid_height, item_one_hot_length), dtype=int
)
counter_one_hot_length = len(counter_list)
counters = np.zeros(
(grid_width, grid_height, counter_one_hot_length), dtype=int
)
else:
counter_items = np.zeros((grid_width, grid_height), dtype=int)
counters = np.zeros((grid_width, grid_height), dtype=int)
for counter in self.env.counters:
grid_idx = np.floor(counter.pos).astype(int)
counter_one_hot, counter_oh_idx = self.vectorize_counter(
counter, counter_list
)
grid_base_array[grid_idx[0], grid_idx[1]] = counter_oh_idx
grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))
counter_item_one_hot, counter_item_oh_idx = self.vectorize_item(
counter.occupied_by, item_list
)
counter_items[grid_idx] = (
counter_item_one_hot if onehot else counter_item_oh_idx
)
counters[grid_idx] = counter_one_hot if onehot else counter_oh_idx
for free_idx in grid_idxs:
grid_base_array[free_idx[0], free_idx[1]] = counter_list.index("Empty")
player_pos = self.env.players[player].pos.astype(int)
player_dir = self.env.players[player].facing_direction.astype(int)
player_data = np.concatenate((player_pos, player_dir), axis=0)
player_item_one_hot, player_item_idx = self.vectorize_item(
self.env.players[player].holding, item_list
)
player_item = player_item_one_hot if onehot else [player_item_idx]
final = np.concatenate(
(
counters.flatten(),
counter_items.flatten(),
player_data.flatten(),
player_item,
),
axis=0,
)
return final
def step(self, action): def step(self, action):
simple_action = self.action_space_map[action] simple_action = self.action_space_map[action]
env_action = get_env_action( env_action = get_env_action(
...@@ -397,7 +226,7 @@ class EnvGymWrapper(Env): ...@@ -397,7 +226,7 @@ class EnvGymWrapper(Env):
def reset(self, seed=None, options=None): def reset(self, seed=None, options=None):
self.env: Environment = Environment( self.env: Environment = Environment(
env_config= deepcopy(self.config_env), env_config=deepcopy(self.config_env),
layout_config=layout, layout_config=layout,
item_info=deepcopy(self.config_item_info), item_info=deepcopy(self.config_item_info),
as_files=False, as_files=False,
...@@ -444,7 +273,7 @@ class EnvGymWrapper(Env): ...@@ -444,7 +273,7 @@ class EnvGymWrapper(Env):
return (observation.transpose((2, 0, 1))).astype(np.uint8) return (observation.transpose((2, 0, 1))).astype(np.uint8)
def get_vector_state(self): def get_vector_state(self):
obs = self.get_vectorized_state_simple("0", self.onehot_state) obs = self.converter.convert_state_to_observation(self.env)
return obs return obs
def sample_random_action(self): def sample_random_action(self):
......
from collections import deque
import numpy as np
from cooperative_cuisine.counters import CookingCounter, Counter, Dispenser
from cooperative_cuisine.items import CookingEquipment
from cooperative_cuisine.reinforcement_learning.gym_env import StateToObservationConverter
class StateConverter(StateToObservationConverter):
def __init__(self):
self.onehot = False
self.counter_list = [
"Empty",
"Counter",
"PlateDispenser",
"TomatoDispenser",
"ServingWindow",
"PlateReturn",
"Trashcan",
"Stove",
"CuttingBoard",
"LettuceDispenser",
]
self.item_list = [
"None",
"Pot",
"PotOne",
"PotTwo",
"PotThree",
"PotDone",
"Tomato",
"ChoppedTomato",
"Plate",
"PlateTomatoSoup",
"PlateSalad",
"Lettuce",
"PlateChoppedTomato",
"PlateChoppedLettuce",
"ChoppedLettuce",
]
self.player = "0"
def setup(self, env):
self.grid_width, self.grid_height = int(env.kitchen_width), int(
env.kitchen_height)
def convert_state_to_observation(self, env) -> np.ndarray:
grid_base_array = np.zeros(
(
self.grid_width,
self.grid_height,
),
dtype=int,
)
grid_idxs = [(x, y) for x in range(self.grid_width) for y in range(self.grid_height)]
counter_items = np.zeros((self.grid_width, self.grid_height), dtype=int)
counters = np.zeros((self.grid_width, self.grid_height), dtype=int)
for counter in env.counters:
grid_idx = np.floor(counter.pos).astype(int)
counter_oh_idx = self.vectorize_counter(
counter, self.counter_list
)
grid_base_array[grid_idx[0], grid_idx[1]] = counter_oh_idx
grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))
counter_item_oh_idx = self.vectorize_item(
counter.occupied_by, self.item_list
)
counter_items[grid_idx] = (
counter_item_oh_idx
)
counters[grid_idx] = counter_oh_idx
for free_idx in grid_idxs:
grid_base_array[free_idx[0], free_idx[1]] = self.counter_list.index("Empty")
player_pos = env.players[self.player].pos.astype(int)
player_dir = env.players[self.player].facing_direction.astype(int)
player_data = np.concatenate((player_pos, player_dir), axis=0)
player_item_idx = self.vectorize_item(
env.players[self.player].holding, self.item_list
)
player_item = [player_item_idx]
final = np.concatenate(
(
counters.flatten(),
counter_items.flatten(),
player_data.flatten(),
player_item,
),
axis=0,
)
return final
def vectorize_item(self, item, item_list):
if item is None:
item_name = "None"
elif isinstance(item, deque):
if len(item) > 0:
item_name = item[0].name
else:
item_name = "None"
else:
item_name = item.name
if isinstance(item, CookingEquipment):
if item.name == "Pot":
if len(item.content_list) > 0:
if item.content_list[0].name == "TomatoSoup":
item_name = "PotDone"
elif len(item.content_list) == 1:
item_name = "PotOne"
elif len(item.content_list) == 2:
item_name = "PotTwo"
elif len(item.content_list) == 3:
item_name = "PotThree"
if "Plate" in item.name:
content_list = [i.name for i in item.content_list]
match content_list:
case ["TomatoSoup"]:
item_name = "PlateTomatoSoup"
case ["ChoppedTomato"]:
item_name = "PlateChoppedTomato"
case ["ChoppedLettuce"]:
item_name = "PlateChoppedLettuce"
case []:
item_name = "Plate"
case ["ChoppedLettuce", "ChoppedTomato"]:
item_name = "PlateSalad"
case other:
assert False, f"Should not happen. {item}"
assert item_name in item_list, f"Unknown item {item_name}."
item_idx = item_list.index(item_name)
# if item_name not in self.seen_items:
# print(item, item_name)
# self.seen_items.append(item_name)
return item_idx
@staticmethod
def vectorize_counter(counter, counter_list):
counter_name = (
counter.name
if isinstance(counter, CookingCounter)
else (
repr(counter)
if isinstance(Counter, Dispenser)
else counter.__class__.__name__
)
)
if counter_name == "Dispenser":
counter_name = f"{counter.occupied_by.name}Dispenser"
assert counter_name in counter_list, f"Unknown Counter {counter}"
counter_oh_idx = counter_list.index("Empty")
if counter_name in counter_list:
counter_oh_idx = counter_list.index(counter_name)
return counter_oh_idx
from collections import deque
import numpy as np
from cooperative_cuisine.counters import CookingCounter, Counter, Dispenser
from cooperative_cuisine.items import CookingEquipment
from cooperative_cuisine.reinforcement_learning.gym_env import StateToObservationConverter
class StateConverter(StateToObservationConverter):
def __init__(self):
self.onehot = True
self.grid_height = None
self.grid_width = None
self.counter_list = [
"Empty",
"Counter",
"PlateDispenser",
"TomatoDispenser",
"ServingWindow",
"PlateReturn",
"Trashcan",
"Stove",
"CuttingBoard",
"LettuceDispenser",
]
self.item_list = [
"None",
"Pot",
"PotOne",
"PotTwo",
"PotThree",
"PotDone",
"Tomato",
"ChoppedTomato",
"Plate",
"PlateTomatoSoup",
"PlateSalad",
"Lettuce",
"PlateChoppedTomato",
"PlateChoppedLettuce",
"ChoppedLettuce",
]
self.player = "0"
def setup(self, env):
self.grid_width, self.grid_height = int(env.kitchen_width), int(
env.kitchen_height)
def convert_state_to_observation(self, env) -> np.ndarray:
grid_base_array = np.zeros(
(
self.grid_width,
self.grid_height,
),
dtype=int,
)
grid_idxs = [(x, y) for x in range(self.grid_width) for y in range(self.grid_height)]
item_one_hot_length = len(self.item_list)
counter_items = np.zeros(
(self.grid_width, self.grid_height, item_one_hot_length), dtype=int
)
counter_one_hot_length = len(self.counter_list)
counters = np.zeros(
(self.grid_width, self.grid_height, counter_one_hot_length), dtype=int
)
for counter in env.counters:
grid_idx = np.floor(counter.pos).astype(int)
counter_one_hot, counter_oh_idx = self.vectorize_counter(
counter, self.counter_list
)
grid_base_array[grid_idx[0], grid_idx[1]] = counter_oh_idx
grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))
counter_item_one_hot = self.vectorize_item(
counter.occupied_by, self.item_list
)
counter_items[grid_idx] = (
counter_item_one_hot
)
counters[grid_idx] = counter_one_hot
for free_idx in grid_idxs:
grid_base_array[free_idx[0], free_idx[1]] = self.counter_list.index("Empty")
player_pos = env.players[self.player].pos.astype(int)
player_dir = env.players[self.player].facing_direction.astype(int)
player_data = np.concatenate((player_pos, player_dir), axis=0)
player_item_one_hot = self.vectorize_item(
env.players[self.player].holding, self.item_list
)
final = np.concatenate(
(
counters.flatten(),
counter_items.flatten(),
player_data.flatten(),
player_item_one_hot,
),
axis=0,
)
return final
def vectorize_item(self, item, item_list):
item_one_hot = np.zeros(len(item_list))
if item is None:
item_name = "None"
elif isinstance(item, deque):
if len(item) > 0:
item_name = item[0].name
else:
item_name = "None"
else:
item_name = item.name
if isinstance(item, CookingEquipment):
if item.name == "Pot":
if len(item.content_list) > 0:
if item.content_list[0].name == "TomatoSoup":
item_name = "PotDone"
elif len(item.content_list) == 1:
item_name = "PotOne"
elif len(item.content_list) == 2:
item_name = "PotTwo"
elif len(item.content_list) == 3:
item_name = "PotThree"
if "Plate" in item.name:
content_list = [i.name for i in item.content_list]
match content_list:
case ["TomatoSoup"]:
item_name = "PlateTomatoSoup"
case ["ChoppedTomato"]:
item_name = "PlateChoppedTomato"
case ["ChoppedLettuce"]:
item_name = "PlateChoppedLettuce"
case []:
item_name = "Plate"
case ["ChoppedLettuce", "ChoppedTomato"]:
item_name = "PlateSalad"
case other:
assert False, f"Should not happen. {item}"
assert item_name in item_list, f"Unknown item {item_name}."
item_idx = item_list.index(item_name)
item_one_hot[item_idx] = 1
# if item_name not in self.seen_items:
# print(item, item_name)
# self.seen_items.append(item_name)
return item_one_hot
@staticmethod
def vectorize_counter(counter, counter_list):
counter_name = (
counter.name
if isinstance(counter, CookingCounter)
else (
repr(counter)
if isinstance(Counter, Dispenser)
else counter.__class__.__name__
)
)
if counter_name == "Dispenser":
counter_name = f"{counter.occupied_by.name}Dispenser"
assert counter_name in counter_list, f"Unknown Counter {counter}"
counter_oh_idx = counter_list.index("Empty")
if counter_name in counter_list:
counter_oh_idx = counter_list.index(counter_name)
counter_one_hot = np.zeros(len(counter_list), dtype=int)
counter_one_hot[counter_oh_idx] = 1
return counter_one_hot, counter_oh_idx
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment