Skip to content
Snippets Groups Projects
Commit 4614074f authored by Fabian Heinrich's avatar Fabian Heinrich
Browse files

Moved rl into subfolder

parent 75952d06
No related branches found
No related tags found
1 merge request!52Resolve "gym env"
Pipeline #46075 passed
#X##
T__W
U__P
#C##
......@@ -5,7 +5,6 @@ import inspect
import json
import logging
import sys
from collections import deque
from datetime import timedelta, datetime
from enum import Enum
from pathlib import Path
......@@ -21,15 +20,11 @@ from overcooked_simulator.counter_factory import CounterFactory
from overcooked_simulator.counters import (
Counter,
PlateConfig,
CookingCounter,
Dispenser,
)
from overcooked_simulator.effect_manager import EffectManager
from overcooked_simulator.game_items import (
ItemInfo,
ItemType,
CookingEquipment,
Item,
)
from overcooked_simulator.hooks import (
ITEM_INFO_LOADED,
......@@ -59,7 +54,6 @@ from overcooked_simulator.player import Player, PlayerConfig
from overcooked_simulator.utils import (
create_init_env_time,
get_closest,
VectorStateGenerationData,
)
log = logging.getLogger(__name__)
......@@ -291,8 +285,6 @@ class Environment:
str, EffectManager
] = self.counter_factory.setup_effect_manger(self.counters)
self.vector_state_generation = self.setup_vectorization()
self.hook(
ENV_INITIALIZED,
environment_config=env_config,
......@@ -805,404 +797,6 @@ class Environment:
return json_data
raise ValueError(f"No valid {player_id=}")
def setup_vectorization(self) -> VectorStateGenerationData:
grid_base_array = np.zeros(
(
int(self.kitchen_width),
int(self.kitchen_height),
114 + 12 + 4, # TODO calc based on item info
),
dtype=np.float32,
)
counter_list = [
"Counter",
"CuttingBoard",
"ServingWindow",
"Trashcan",
"Sink",
"SinkAddon",
"Stove",
"DeepFryer",
"Oven",
]
grid_idxs = [
(x, y)
for x in range(int(self.kitchen_width))
for y in range(int(self.kitchen_height))
]
# counters do not move
for counter in self.counters:
grid_idx = np.floor(counter.pos).astype(int)
counter_name = (
counter.name
if isinstance(counter, CookingCounter)
else (
repr(counter)
if isinstance(Counter, Dispenser)
else counter.__class__.__name__
)
)
assert counter_name in counter_list or counter_name.endswith(
"Dispenser"
), f"Unknown Counter {counter}"
oh_idx = len(counter_list)
if counter_name in counter_list:
oh_idx = counter_list.index(counter_name)
one_hot = [0] * (len(counter_list) + 2)
one_hot[oh_idx] = 1
grid_base_array[
grid_idx[0], grid_idx[1], 4 : 4 + (len(counter_list) + 2)
] = np.array(one_hot, dtype=np.float32)
grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))
for free_idx in grid_idxs:
one_hot = [0] * (len(counter_list) + 2)
one_hot[len(counter_list) + 1] = 1
grid_base_array[
free_idx[0], free_idx[1], 4 : 4 + (len(counter_list) + 2)
] = np.array(one_hot, dtype=np.float32)
player_info_base_array = np.zeros(
(
4,
4 + 114,
),
dtype=np.float32,
)
order_base_array = np.zeros((10 * (8 + 1)), dtype=np.float32)
return VectorStateGenerationData(
grid_base_array=grid_base_array,
oh_len=12,
)
def get_simple_vectorized_item(self, item: Item) -> npt.NDArray[float]:
name = item.name
array = np.zeros(21, dtype=np.float32)
if item.name.startswith("Burnt"):
name = name[len("Burnt") :]
array[0] = 1.0
if name.startswith("Chopped"):
array[1] = 1.0
name = name[len("Chopped") :]
if name in [
"PizzaBase",
"GratedCheese",
"RawChips",
"RawPatty",
]:
array[1] = 1.0
name = {
"PizzaBase": "Dough",
"GratedCheese": "Cheese",
"RawChips": "Potato",
"RawPatty": "Meat",
}[name]
if name == "CookedPatty":
array[2] = 1.0
name = "Meat"
if name in self.vector_state_generation.meals:
idx = 3 + self.vector_state_generation.meals.index(name)
elif name in self.vector_state_generation.ingredients:
idx = (
3
+ len(self.vector_state_generation.meals)
+ self.vector_state_generation.ingredients.index(name)
)
else:
raise ValueError(f"Unknown item {name} - {item}")
array[idx] = 1.0
return array
def get_vectorized_item(self, item: Item) -> npt.NDArray[float]:
item_array = np.zeros(114, dtype=np.float32)
if isinstance(item, CookingEquipment) or item.item_info.type == ItemType.Tool:
assert (
item.name in self.vector_state_generation.equipments
), f"unknown equipment {item}"
idx = self.vector_state_generation.equipments.index(item.name)
item_array[idx] = 1.0
if isinstance(item, CookingEquipment):
for s_idx, sub_item in enumerate(item.content_list):
if s_idx > 3:
print("Too much content in the content list, info dropped")
break
start_idx = len(self.vector_state_generation.equipments) + 21 + 2
item_array[
start_idx + (s_idx * (21)) : start_idx + ((s_idx + 1) * (21))
] = self.get_simple_vectorized_item(sub_item)
else:
item_array[
len(self.vector_state_generation.equipments) : len(
self.vector_state_generation.equipments
)
+ 21
] = self.get_simple_vectorized_item(item)
item_array[
len(self.vector_state_generation.equipments) + 21 + 1
] = item.progress_percentage
if item.active_effects:
item_array[
len(self.vector_state_generation.equipments) + 21 + 2
] = 1.0 # TODO percentage of fire...
return item_array
def get_vectorized_state_full(
self, player_id: str
) -> Tuple[
npt.NDArray[npt.NDArray[float]],
npt.NDArray[npt.NDArray[float]],
float,
npt.NDArray[float],
]:
grid_array = self.vector_state_generation.grid_base_array.copy()
for counter in self.counters:
grid_idx = np.floor(counter.pos).astype(int) # store in counter?
if counter.occupied_by:
if isinstance(counter.occupied_by, deque):
...
else:
item = counter.occupied_by
grid_array[
grid_idx[0],
grid_idx[1],
4 + self.vector_state_generation.oh_len :,
] = self.get_vectorized_item(item)
if counter.active_effects:
grid_array[
grid_idx[0],
grid_idx[1],
4 + self.vector_state_generation.oh_len - 1,
] = 1.0 # TODO percentage of fire...
assert len(self.players) <= 4, "To many players for vector representation"
player_vec = np.zeros(
(
4,
4 + 114,
),
dtype=np.float32,
)
player_pos = 1
for player in self.players.values():
if player.name == player_id:
idx = 0
player_vec[0, :4] = np.array(
[
player.pos[0],
player.pos[1],
player.facing_point[0],
player.facing_point[1],
],
dtype=np.float32,
)
else:
idx = player_pos
if not idx:
player_pos += 1
grid_idx = np.floor(player.pos).astype(int) # store in counter?
player_vec[idx, :4] = np.array(
[
player.pos[0] - grid_idx[0],
player.pos[1] - grid_idx[1],
player.facing_point[0] / np.linalg.norm(player.facing_point),
player.facing_point[1] / np.linalg.norm(player.facing_point),
],
dtype=np.float32,
)
grid_array[grid_idx[0], grid_idx[1], idx] = 1.0
if player.holding:
player_vec[idx, 4:] = self.get_vectorized_item(player.holding)
order_array = np.zeros((10 * (8 + 1)), dtype=np.float32)
for i, order in enumerate(self.order_manager.open_orders):
if i > 9:
print("some orders are not represented in the vectorized state")
break
assert (
order.meal.name in self.vector_state_generation.meals
), "unknown meal in order"
idx = self.vector_state_generation.meals.index(order.meal.name)
order_array[(i * 9) + idx] = 1.0
order_array[(i * 9) + 8] = (
self.env_time - order.start_time
).total_seconds() / order.max_duration.total_seconds()
return (
grid_array,
player_vec,
(self.env_time - self.start_time).total_seconds()
/ (self.env_time_end - self.start_time).total_seconds(),
order_array,
)
# def setup_vectorization_simple(self) -> VectorStateGenerationDataSimple:
# num_per_item = 114
# num_per_counter = 12
# num_players = 4
# grid_base_array = np.zeros(
# (
# int(self.kitchen_width),
# int(self.kitchen_height),
# num_per_item
# + num_per_counter
# + num_players, # TODO calc based on item info
# ),
# dtype=np.float32,
# )
# counter_list = [
# "Counter",
# "CuttingBoard",
# "ServingWindow",
# "Trashcan",
# "Sink",
# "SinkAddon",
# "Stove",
# "DeepFryer",
# "Oven",
# ]
# grid_idxs = [
# (x, y)
# for x in range(int(self.kitchen_width))
# for y in range(int(self.kitchen_height))
# ]
# # counters do not move
# for counter in self.counters:
# grid_idx = np.floor(counter.pos).astype(int)
# counter_name = (
# counter.name
# if isinstance(counter, CookingCounter)
# else (
# repr(counter)
# if isinstance(Counter, Dispenser)
# else counter.__class__.__name__
# )
# )
# assert counter_name in counter_list or counter_name.endswith(
# "Dispenser"
# ), f"Unknown Counter {counter}"
# oh_idx = len(counter_list)
# if counter_name in counter_list:
# oh_idx = counter_list.index(counter_name)
#
# one_hot = [0] * (len(counter_list) + 2)
# one_hot[oh_idx] = 1
# grid_base_array[
# grid_idx[0], grid_idx[1], 4 : 4 + (len(counter_list) + 2)
# ] = np.array(one_hot, dtype=np.float32)
#
# grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))
#
# for free_idx in grid_idxs:
# one_hot = [0] * (len(counter_list) + 2)
# one_hot[len(counter_list) + 1] = 1
# grid_base_array[
# free_idx[0], free_idx[1], 4 : 4 + (len(counter_list) + 2)
# ] = np.array(one_hot, dtype=np.float32)
#
# player_info_base_array = np.zeros(
# (
# 4,
# 4 + 114,
# ),
# dtype=np.float32,
# )
# order_base_array = np.zeros((10 * (8 + 1)), dtype=np.float32)
#
# return VectorStateGenerationData(
# grid_base_array=grid_base_array,
# oh_len=12,
# )
def get_vectorized_state_simple(self, player):
item_list = ["Pot", "Tomato", "ChoppedTomato", "Plate"]
counter_list = [
"Counter",
"PlateDispenser",
"TomatoDispenser",
"ServingWindow",
"PlateReturn",
"Trashcan",
"Stove",
"CuttingBoard",
]
player_pos = self.players[player].pos
player_dir = self.players[player].facing_direction
grid_width, grid_height = int(self.kitchen_width), int(self.kitchen_height)
counter_one_hot_length = len(counter_list) + 1 # one for empty field
grid_base_array = np.zeros(
(
grid_width,
grid_height,
),
dtype=int,
)
grid_idxs = [(x, y) for x in range(grid_width) for y in range(grid_height)]
# counters do not move
for counter in self.counters:
grid_idx = np.floor(counter.pos).astype(int)
counter_name = (
counter.name
if isinstance(counter, CookingCounter)
else (
repr(counter)
if isinstance(Counter, Dispenser)
else counter.__class__.__name__
)
)
if counter_name == "Dispenser":
counter_name = f"{counter.occupied_by.name}Dispenser"
assert counter_name in counter_list, f"Unknown Counter {counter}"
counter_oh_idx = counter_one_hot_length
if counter_name in counter_list:
counter_oh_idx = counter_list.index(counter_name)
grid_base_array[grid_idx[0], grid_idx[1]] = counter_oh_idx
grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))
for free_idx in grid_idxs:
grid_base_array[free_idx[0], free_idx[1]] = counter_one_hot_length - 1
counter_grid_one_hot = np.zeros(
(grid_width, grid_height, counter_one_hot_length), dtype=int
)
for x in range(grid_width):
for y in range(grid_height):
counter_type_idx = grid_base_array[x, y]
counter_grid_one_hot[x, y, counter_type_idx] = 1
player_data = np.concatenate((player_pos, player_dir), axis=0)
items_one_hot_length = len(item_list) + 1
item_one_hot = np.zeros(items_one_hot_length, dtype=int)
player_item = self.players[player].holding
player_item_idx = items_one_hot_length - 1
if player_item:
if player_item.name in item_list:
player_item_idx = item_list.index(player_item.name)
item_one_hot[player_item_idx] = 1
final = np.concatenate(
(counter_grid_one_hot.flatten(), player_data, item_one_hot), axis=0
)
return final
def reset_env_time(self):
"""Reset the env time to the initial time, defined by `create_init_env_time`."""
self.hook(PRE_RESET_ENV_TIME)
......
plates:
clean_plates: 1
clean_plates: 2
dirty_plates: 0
plate_delay: [ 2, 4 ]
return_dirty: False
......@@ -108,7 +108,7 @@ extra_setup_functions:
hooks: [ trashcan_usage ]
callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: -0.15
static_score: -0.5
item_cut:
func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class ''
kwargs:
......@@ -122,14 +122,14 @@ extra_setup_functions:
hooks: [ post_step ]
callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: -0.01
static_score: -0.05
combine:
func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class ''
kwargs:
hooks: [ drop_off_on_cooking_equipment ]
callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: 0.10
static_score: 0.15
# json_states:
# func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks ''
# kwargs:
......
##X##
T___P
#___#
U___#
#C#W#
#X##
T__P
U__#
#CW#
File deleted
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment