Skip to content
Snippets Groups Projects
full_vectorization.py 8.11 KiB
Newer Older
# def setup_vectorization(self) -> VectorStateGenerationData:
#     grid_base_array = np.zeros(
#         (
#             int(self.env.kitchen_width),
#             int(self.env.kitchen_height),
#             114 + 12 + 4,  # TODO calc based on item info
#         ),
#         dtype=np.float32,
#     )
#     counter_list = [
#         "Counter",
#         "CuttingBoard",
#         "ServingWindow",
#         "Trashcan",
#         "Sink",
#         "SinkAddon",
#         "Stove",
#         "DeepFryer",
#         "Oven",
#     ]
#     grid_idxs = [
#         (x, y)
#         for x in range(int(self.env.kitchen_width))
#         for y in range(int(self.env.kitchen_height))
#     ]
#     # counters do not move
#     for counter in self.env.counters:
#         grid_idx = np.floor(counter.pos).astype(int)
#         counter_name = (
#             counter.name
#             if isinstance(counter, CookingCounter)
#             else (
#                 repr(counter)
#                 if isinstance(counter, Dispenser)
#                 else counter.__class__.__name__
#             )
#         )
#         assert counter_name in counter_list or counter_name.endswith(
#             "Dispenser"
#         ), f"Unknown Counter {counter}"
#         oh_idx = len(counter_list)
#         if counter_name in counter_list:
#             oh_idx = counter_list.index(counter_name)
#
#         one_hot = [0] * (len(counter_list) + 2)
#         one_hot[oh_idx] = 1
#         grid_base_array[
#             grid_idx[0], grid_idx[1], 4 : 4 + (len(counter_list) + 2)
#         ] = np.array(one_hot, dtype=np.float32)
#
#         grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))
#
#     for free_idx in grid_idxs:
#         one_hot = [0] * (len(counter_list) + 2)
#         one_hot[len(counter_list) + 1] = 1
#         grid_base_array[
#             free_idx[0], free_idx[1], 4 : 4 + (len(counter_list) + 2)
#         ] = np.array(one_hot, dtype=np.float32)
#
#     player_info_base_array = np.zeros(
#         (
#             4,
#             4 + 114,
#         ),
#         dtype=np.float32,
#     )
#     order_base_array = np.zeros((10 * (8 + 1)), dtype=np.float32)
#
#     return VectorStateGenerationData(
#         grid_base_array=grid_base_array,
#         oh_len=12,
#     )
#
#
# def get_simple_vectorized_item(self, item: Item) -> npt.NDArray[float]:
#     name = item.name
#     array = np.zeros(21, dtype=np.float32)
#     if item.name.startswith("Burnt"):
#         name = name[len("Burnt") :]
#         array[0] = 1.0
#     if name.startswith("Chopped"):
#         array[1] = 1.0
#         name = name[len("Chopped") :]
#     if name in [
#         "PizzaBase",
#         "GratedCheese",
#         "RawChips",
#         "RawPatty",
#     ]:
#         array[1] = 1.0
#         name = {
#             "PizzaBase": "Dough",
#             "GratedCheese": "Cheese",
#             "RawChips": "Potato",
#             "RawPatty": "Meat",
#         }[name]
#     if name == "CookedPatty":
#         array[2] = 1.0
#         name = "Meat"
#
#     if name in self.vector_state_generation.meals:
#         idx = 3 + self.vector_state_generation.meals.index(name)
#     elif name in self.vector_state_generation.ingredients:
#         idx = (
#             3
#             + len(self.vector_state_generation.meals)
#             + self.vector_state_generation.ingredients.index(name)
#         )
#     else:
#         raise ValueError(f"Unknown item {name} - {item}")
#     array[idx] = 1.0
#     return array
#
#
# def get_vectorized_item(self, item: Item) -> npt.NDArray[float]:
#     item_array = np.zeros(114, dtype=np.float32)
#
#     if isinstance(item, CookingEquipment) or item.item_info.type == ItemType.Tool:
#         assert (
#             item.name in self.vector_state_generation.equipments
#         ), f"unknown equipment {item}"
#         idx = self.vector_state_generation.equipments.index(item.name)
#         item_array[idx] = 1.0
#         if isinstance(item, CookingEquipment):
#             for s_idx, sub_item in enumerate(item.content_list):
#                 if s_idx > 3:
#                     print("Too much content in the content list, info dropped")
#                     break
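#                 # NOTE(review): start_idx (equipments + 21 + 2) is the same slot the
#                 # active-effects flag is written to further below, so the first
#                 # sub-item's leading entry collides with that flag - confirm layout.
#                 start_idx = len(self.vector_state_generation.equipments) + 21 + 2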
#                 item_array[
#                     start_idx + (s_idx * (21)) : start_idx + ((s_idx + 1) * (21))
#                 ] = self.get_simple_vectorized_item(sub_item)
#
#     else:
#         item_array[
#             len(self.vector_state_generation.equipments) : len(
#                 self.vector_state_generation.equipments
#             )
#             + 21
#         ] = self.get_simple_vectorized_item(item)
#
#     item_array[
#         len(self.vector_state_generation.equipments) + 21 + 1
#     ] = item.progress_percentage
#
#     if item.active_effects:
#         item_array[
#             len(self.vector_state_generation.equipments) + 21 + 2
#         ] = 1.0  # TODO percentage of fire...
#
#     return item_array
#
#
# def get_vectorized_state_full(
#     self, player_id: str
# ) -> Tuple[
#     npt.NDArray[npt.NDArray[float]],
#     npt.NDArray[npt.NDArray[float]],
#     float,
#     npt.NDArray[float],
# ]:
#     grid_array = self.vector_state_generation.grid_base_array.copy()
#     for counter in self.env.counters:
#         grid_idx = np.floor(counter.pos).astype(int)  # store in counter?
#         if counter.occupied_by:
#             if isinstance(counter.occupied_by, deque):
#                 ...
#             else:
#                 item = counter.occupied_by
#                 grid_array[
#                     grid_idx[0],
#                     grid_idx[1],
#                     4 + self.vector_state_generation.oh_len :,
#                 ] = self.get_vectorized_item(item)
#         if counter.active_effects:
#             grid_array[
#                 grid_idx[0],
#                 grid_idx[1],
#                 4 + self.vector_state_generation.oh_len - 1,
#             ] = 1.0  # TODO percentage of fire...
#
#     assert len(self.env.players) <= 4, "Too many players for vector representation"
#     player_vec = np.zeros(
#         (
#             4,
#             4 + 114,
#         ),
#         dtype=np.float32,
#     )
#     player_pos = 1
#     for player in self.env.players.values():
#         if player.name == player_id:
#             idx = 0
#             player_vec[0, :4] = np.array(
#                 [
#                     player.pos[0],
#                     player.pos[1],
#                     player.facing_point[0],
#                     player.facing_point[1],
#                 ],
#                 dtype=np.float32,
#             )
#         else:
#             idx = player_pos
#
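#         # NOTE(review): this advances player_pos only for the observing player
#         # (idx == 0), so all other players end up sharing the same row of
#         # player_vec; the condition looks inverted (`if idx:`) - confirm.
#         if not idx:
#             player_pos += 1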
#         grid_idx = np.floor(player.pos).astype(int)  # store in counter?
#         player_vec[idx, :4] = np.array(
#             [
#                 player.pos[0] - grid_idx[0],
#                 player.pos[1] - grid_idx[1],
#                 player.facing_point[0] / np.linalg.norm(player.facing_point),
#                 player.facing_point[1] / np.linalg.norm(player.facing_point),
#             ],
#             dtype=np.float32,
#         )
#         grid_array[grid_idx[0], grid_idx[1], idx] = 1.0
#
#         if player.holding:
#             player_vec[idx, 4:] = self.get_vectorized_item(player.holding)
#
#     order_array = np.zeros((10 * (8 + 1)), dtype=np.float32)
#
#     for i, order in enumerate(self.env.order_manager.open_orders):
#         if i > 9:
#             print("some orders are not represented in the vectorized state")
#             break
#         assert (
#             order.meal.name in self.vector_state_generation.meals
#         ), "unknown meal in order"
#         idx = self.vector_state_generation.meals.index(order.meal.name)
#         order_array[(i * 9) + idx] = 1.0
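#         # NOTE(review): the return value below reads the clock as
#         # self.env.env_time; self.env_time here may be a typo - confirm.
#         order_array[(i * 9) + 8] = (
#             self.env_time - order.start_time
#         ).total_seconds() / order.max_duration.total_seconds()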
#
#     return (
#         grid_array,
#         player_vec,
#         (self.env.env_time - self.env.start_time).total_seconds()
#         / (self.env.env_time_end - self.env.start_time).total_seconds(),
#         order_array,
#     )