Skip to content
Snippets Groups Projects
Commit 46699dbb authored by Fabian Heinrich
Browse files

Experimentation with reinforcement learning

parent 34c71c41
No related branches found
No related tags found
No related merge requests found
Pipeline #48230 passed
......@@ -881,9 +881,7 @@ class Visualizer:
self.draw_gamescreen(screen, state, grid_size, [0 for _ in state["players"]])
pygame.image.save(screen, filename)
def get_state_image(
self, grid_size: int, save_folder: dict
) -> npt.NDArray[np.uint8]:
def get_state_image(self, grid_size: int, state: dict) -> npt.NDArray[np.uint8]:
width = int(np.ceil(state["kitchen"]["width"] * grid_size))
height = int(np.ceil(state["kitchen"]["height"] * grid_size))
......
......@@ -7,15 +7,10 @@ plates:
game:
time_limit_seconds: 300
undo_dispenser_pickup: true
validate_recipes: false
meals:
all: true
# if all: false -> only orders for these meals are generated
# TODO: what if this list is empty?
list:
- TomatoSoup
- OnionSoup
- Salad
layout_chars:
_: Free
......@@ -55,7 +50,15 @@ layout_chars:
orders:
order_gen_class: !!python/name:cooperative_cuisine.order.RandomOrderGeneration ''
meals:
all: true
# if all: false -> only orders for these meals are generated
# TODO: what if this list is empty?
list:
- TomatoSoup
- OnionSoup
- Salad
order_gen_class: !!python/name:cooperative_cuisine.orders.RandomOrderGeneration ''
# the class that receives the kwargs. Should be a child class of OrderGeneration in orders.py
order_gen_kwargs:
order_duration_random_func:
......@@ -103,7 +106,7 @@ extra_setup_functions:
hooks: [ completed_order ]
callback_class: !!python/name:cooperative_cuisine.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: 1
static_score: 0.95
serve_not_ordered_meals:
func: !!python/name:cooperative_cuisine.hooks.hooks_via_callback_class ''
......@@ -111,7 +114,7 @@ extra_setup_functions:
hooks: [ serve_not_ordered_meal ]
callback_class: !!python/name:cooperative_cuisine.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: 1
static_score: 0.95
trashcan_usages:
func: !!python/name:cooperative_cuisine.hooks.hooks_via_callback_class ''
kwargs:
......@@ -125,7 +128,7 @@ extra_setup_functions:
hooks: [ cutting_board_100 ]
callback_class: !!python/name:cooperative_cuisine.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: 0.01
static_score: 0.1
stepped:
func: !!python/name:cooperative_cuisine.hooks.hooks_via_callback_class ''
kwargs:
......@@ -140,6 +143,13 @@ extra_setup_functions:
callback_class: !!python/name:cooperative_cuisine.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: 0.01
start_interact:
func: !!python/name:cooperative_cuisine.hooks.hooks_via_callback_class ''
kwargs:
hooks: [ player_start_interaction ]
callback_class: !!python/name:cooperative_cuisine.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: 0.01
# json_states:
# func: !!python/name:cooperative_cuisine.hooks.hooks_via_callback_class ''
# kwargs:
......
......@@ -108,7 +108,7 @@ def shuffle_counters(env):
if counter.__class__ != Counter:
sample_counter.append(counter)
else:
other_counters.append()
other_counters.append(counter)
new_counter_pos = [c.pos for c in sample_counter]
random.shuffle(new_counter_pos)
for counter, new_pos in zip(sample_counter, new_counter_pos):
......@@ -127,7 +127,7 @@ class EnvGymWrapper(Env):
def __init__(self):
super().__init__()
self.gridsize = 20
self.gridsize = 30
self.randomize_counter_placement = True
self.use_rgb_obs = False # if False uses simple vectorized state
......@@ -160,16 +160,17 @@ class EnvGymWrapper(Env):
self.action_space = spaces.Discrete(len(self.action_space_map))
min_obs_val = -1 if not self.use_rgb_obs else 0
max_obs_val = 255 if self.use_rgb_obs else 1 if self.onehot_state else 9
self.seen_items = []
dummy_obs = self.get_observation()
min_obs_val = -1 if not self.use_rgb_obs else 0
max_obs_val = 255 if self.use_rgb_obs else 1 if self.onehot_state else 20
self.observation_space = spaces.Box(
low=min_obs_val,
high=max_obs_val,
shape=dummy_obs.shape,
dtype=np.uint8 if self.use_rgb_obs else int,
)
print(self.observation_space)
self.last_obs = dummy_obs
......@@ -199,15 +200,29 @@ class EnvGymWrapper(Env):
item_name = "PotTwo"
elif len(item.content_list) == 3:
item_name = "PotThree"
elif item.name == "Plate":
if len(item.content_list) == 0:
item_name = "Plate"
else:
item_name = "PlateTomatoSoup"
if "Plate" in item.name:
content_list = [i.name for i in item.content_list]
match content_list:
case ["TomatoSoup"]:
item_name = "PlateTomatoSoup"
case ["ChoppedTomato"]:
item_name = "PlateChoppedTomato"
case ["ChoppedLettuce"]:
item_name = "PlateChoppedLettuce"
case []:
item_name = "Plate"
case ["ChoppedLettuce", "ChoppedTomato"]:
item_name = "PlateSalad"
case other:
assert False, f"Should not happen. {item}"
assert item_name in item_list, f"Unknown item {item_name}."
item_idx = item_list.index(item_name)
item_one_hot[item_idx] = 1
# if item_name not in self.seen_items:
# print(item, item_name)
# self.seen_items.append(item_name)
return item_one_hot, item_idx
@staticmethod
......@@ -244,6 +259,7 @@ class EnvGymWrapper(Env):
"Trashcan",
"Stove",
"CuttingBoard",
"LettuceDispenser",
]
item_list = [
......@@ -257,6 +273,11 @@ class EnvGymWrapper(Env):
"ChoppedTomato",
"Plate",
"PlateTomatoSoup",
"PlateSalad",
"Lettuce",
"PlateChoppedTomato",
"PlateChoppedLettuce",
"ChoppedLettuce",
]
grid_width, grid_height = int(self.env.kitchen_width), int(
......@@ -413,9 +434,9 @@ def main():
config = {
"policy_type": "MlpPolicy",
"total_timesteps": 30_000_000, # hendric sagt eher so 300_000_000 schritte
"total_timesteps": 3_000_000, # hendric sagt eher so 300_000_000 schritte
"env_id": "overcooked",
"number_envs_parallel": 4,
"number_envs_parallel": 64,
}
debug = False
......@@ -424,7 +445,7 @@ def main():
number_envs_parallel = config["number_envs_parallel"]
model_classes = [A2C, DQN, PPO]
model_class = model_classes[2]
model_class = model_classes[1]
if vec_env:
env = make_vec_env(EnvGymWrapper, n_envs=number_envs_parallel)
......
##X#
T__#
T__L
U__P
#C$#
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment