From 47cc026f5f513f9aff8a3dc97cab15ff6762b709 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florian=20Schr=C3=B6der?=
 <fschroeder@techfak.uni-bielefeld.de>
Date: Thu, 8 Feb 2024 19:05:35 +0100
Subject: [PATCH] Implement vector state representation in game environment

The game environment now supports vectorized states for reinforcement learning agents. This includes updates to various classes such as OrderManager, Counter, and Player. Several new utility functions are added to facilitate the conversion of the game state to vector form. The new vector state includes players, counters, orders, and game status.
---
 overcooked_simulator/game_server.py           |   7 +
 overcooked_simulator/gym_env.py               | 100 ++++---
 .../overcooked_environment.py                 | 267 +++++++++++++++++-
 overcooked_simulator/utils.py                 |  40 +++
 4 files changed, 359 insertions(+), 55 deletions(-)

diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py
index c5bb5d25..1db99598 100644
--- a/overcooked_simulator/game_server.py
+++ b/overcooked_simulator/game_server.py
@@ -432,6 +432,13 @@ class EnvironmentHandler:
                             / 1_000_000_000
                         )
                     )
+
+                    (
+                        grid,
+                        player,
+                        env_time,
+                        orders,
+                    ) = env_data.environment.get_vectorized_state("0")
                     env_data.last_step_time = step_start
                     if env_data.environment.game_ended:
                         log.info(f"Env {env_id} ended. Set env to STOPPED.")
diff --git a/overcooked_simulator/gym_env.py b/overcooked_simulator/gym_env.py
index 9a0c8094..26d5da99 100644
--- a/overcooked_simulator/gym_env.py
+++ b/overcooked_simulator/gym_env.py
@@ -1,5 +1,4 @@
 import json
-import os.path
 import time
 from datetime import timedelta
 from enum import Enum
@@ -8,6 +7,12 @@ from pathlib import Path
 import cv2
 import numpy as np
 import yaml
+from gymnasium import spaces, Env
+from stable_baselines3 import A2C
+from stable_baselines3 import DQN
+from stable_baselines3 import PPO
+from stable_baselines3.common.env_checker import check_env
+from stable_baselines3.common.env_util import make_vec_env
 
 from overcooked_simulator import ROOT_DIR
 from overcooked_simulator.gui_2d_vis.drawing import Visualizer
@@ -17,29 +22,22 @@ from overcooked_simulator.overcooked_environment import (
     ActionType,
     InterActionData,
 )
-import wandb
-from wandb.integration.sb3 import WandbCallback
-
-import gymnasium as gym
-import numpy as np
-from gymnasium import spaces, Env
-
-from stable_baselines3.common.env_checker import check_env
-from stable_baselines3.common.env_util import make_vec_env
-from stable_baselines3 import A2C
-from stable_baselines3 import DQN
-from stable_baselines3 import PPO
 
-SimpleActionSpace = Enum("SimpleActionSpace", ["Up",
-                                               # "Up_Left",
-                                               "Left",
-                                               # "Down_Left",
-                                               "Down",
-                                               # "Down_Right",
-                                               "Right",
-                                               # "Right_Up",
-                                               "Interact",
-                                               "Put"])
+SimpleActionSpace = Enum(
+    "SimpleActionSpace",
+    [
+        "Up",
+        # "Up_Left",
+        "Left",
+        # "Down_Left",
+        "Down",
+        # "Down_Right",
+        "Right",
+        # "Right_Up",
+        "Interact",
+        "Put",
+    ],
+)
 
 
 def get_env_action(player_id, simple_action, duration):
@@ -118,9 +116,7 @@ def get_env_action(player_id, simple_action, duration):
             print("FAIL", simple_action)
 
 
-environment_config_path: Path = (
-    ROOT_DIR / "game_content" / "environment_config_rl.yaml"
-)
+environment_config_path: Path = ROOT_DIR / "game_content" / "environment_config_rl.yaml"
 item_info_path: Path = ROOT_DIR / "game_content" / "item_info_rl.yaml"
 layout_path: Path = ROOT_DIR / "game_content" / "layouts" / "rl.layout"
 with open(ROOT_DIR / "gui_2d_vis" / "visualization_rl.yaml", "r") as file:
@@ -131,6 +127,7 @@ class EnvGymWrapper(Env):
     """Should enable this:
     observation, reward, terminated, truncated, info = env.step(action)
     """
+
     metadata = {"render_modes": ["human"], "render_fps": 30}
 
     def __init__(self):
@@ -142,7 +139,7 @@ class EnvGymWrapper(Env):
             env_config=environment_config_path,
             layout_config=layout_path,
             item_info=item_info_path,
-            as_files=True
+            as_files=True,
         )
 
         self.visualizer: Visualizer = Visualizer(config=visualization_config)
@@ -153,11 +150,10 @@ class EnvGymWrapper(Env):
         self.visualizer.create_player_colors(1)
 
         # self.action_space = {idx: value for idx, value in enumerate(SimpleActionSpace)}
-        self.action_space_map ={}
+        self.action_space_map = {}
         for idx, item in enumerate(SimpleActionSpace):
             self.action_space_map[idx] = item
 
-
         self.global_step_time = 1
         self.in_between_steps = 1
 
@@ -165,8 +161,9 @@ class EnvGymWrapper(Env):
         # Example for using image as input (channel-first; channel-last also works):
 
         dummy_obs = self.get_env_img(self.gridsize)
-        self.observation_space = spaces.Box(low=0, high=255,
-                                            shape=dummy_obs.shape, dtype=np.uint8)
+        self.observation_space = spaces.Box(
+            low=0, high=255, shape=dummy_obs.shape, dtype=np.uint8
+        )
 
         self.last_obs = dummy_obs
 
@@ -175,7 +172,9 @@ class EnvGymWrapper(Env):
 
     def step(self, action):
         simple_action = self.action_space_map[action]
-        env_action = get_env_action(self.player_id, simple_action, self.global_step_time)
+        env_action = get_env_action(
+            self.player_id, simple_action, self.global_step_time
+        )
         self.env.perform_action(env_action)
 
         for i in range(self.in_between_steps):
@@ -186,10 +185,13 @@ class EnvGymWrapper(Env):
         observation = self.get_env_img(self.gridsize)
 
         reward = -1
-        if self.env.order_and_score.score > self.prev_score and self.env.order_and_score.score != 0:
-            self.prev_score = self.env.order_and_score.score
+        if (
+            self.env.order_manager.score > self.prev_score
+            and self.env.order_manager.score != 0
+        ):
+            self.prev_score = self.env.order_manager.score
             reward = 100
-        elif self.env.order_and_score.score < self.prev_score:
+        elif self.env.order_manager.score < self.prev_score:
             self.prev_score = 0
             reward = -1
 
@@ -200,14 +202,12 @@ class EnvGymWrapper(Env):
         # self.render(self.gridsize)
         return observation, reward, terminated, truncated, info
 
-
     def reset(self, seed=None, options=None):
-
         self.env: Environment = Environment(
             env_config=environment_config_path,
             layout_config=layout_path,
             item_info=item_info_path,
-            as_files=True
+            as_files=True,
         )
 
         self.player_name = str(0)
@@ -219,10 +219,10 @@ class EnvGymWrapper(Env):
 
     def render(self):
         observation = self.get_env_img(self.gridsize)
-        img = observation.transpose((1,2,0))[:,:,::-1]
+        img = observation.transpose((1, 2, 0))[:, :, ::-1]
         print(img.shape)
-        img = cv2.resize(img, (img.shape[1]*5, img.shape[0]*5))
-        cv2.imshow("Overcooked",img)
+        img = cv2.resize(img, (img.shape[1] * 5, img.shape[0] * 5))
+        cv2.imshow("Overcooked", img)
         cv2.waitKey(1)
 
     def close(self):
@@ -234,7 +234,13 @@ class EnvGymWrapper(Env):
         observation = self.visualizer.get_state_image(
             grid_size=gridsize, state=json_dict
         ).transpose((1, 0, 2))
-        return observation.transpose((2,0,1))
+        return observation.transpose((2, 0, 1))
+
+    def get_vector_state(self):
+        grid, player, env_time, orders = self.env.get_vectorized_state("0")
+
+        # flatten: grid + player
+        # concatenate all (env_time to array)
 
     def sample_random_action(self):
         act = self.action_space.sample()
@@ -246,7 +252,6 @@ def main():
     rl_agent_checkpoints = Path("./rl_agent_checkpoints")
     rl_agent_checkpoints.mkdir(exist_ok=True)
 
-
     config = {
         "policy_type": "CnnPolicy",
         "total_timesteps": 1000000,  # hendric sagt eher so 300_000_000 schritte
@@ -266,8 +271,9 @@ def main():
 
     model_class = model_classes[2]
     # model = model_class(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
-    model = model_class(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{0}")
-
+    model = model_class(
+        config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{0}"
+    )
 
     model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}"
     # if os.path.exists(model_save_path):
@@ -280,7 +286,7 @@ def main():
         #     verbose=0,
         # ),
         log_interval=1,
-        progress_bar=True
+        progress_bar=True,
     )
     # run.finish()
     model.save(model_save_path)
@@ -293,7 +299,7 @@ def main():
     check_env(env)
     obs, info = env.reset()
     while True:
-        time.sleep(1/30)
+        time.sleep(1 / 30)
         action, _states = model.predict(obs, deterministic=False)
         obs, reward, terminated, truncated, info = env.step(int(action))
         env.render()
diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py
index 6d8dd883..00a5794c 100644
--- a/overcooked_simulator/overcooked_environment.py
+++ b/overcooked_simulator/overcooked_environment.py
@@ -5,6 +5,7 @@ import inspect
 import json
 import logging
 import sys
+from collections import deque
 from datetime import timedelta, datetime
 from enum import Enum
 from pathlib import Path
@@ -20,11 +21,15 @@ from overcooked_simulator.counter_factory import CounterFactory
 from overcooked_simulator.counters import (
     Counter,
     PlateConfig,
+    CookingCounter,
+    Dispenser,
 )
 from overcooked_simulator.effect_manager import EffectManager
 from overcooked_simulator.game_items import (
     ItemInfo,
     ItemType,
+    CookingEquipment,
+    Item,
 )
 from overcooked_simulator.hooks import (
     ITEM_INFO_LOADED,
@@ -50,7 +55,11 @@ from overcooked_simulator.order import (
     OrderConfig,
 )
 from overcooked_simulator.player import Player, PlayerConfig
-from overcooked_simulator.utils import create_init_env_time, get_closest
+from overcooked_simulator.utils import (
+    create_init_env_time,
+    get_closest,
+    VectorStateGenerationData,
+)
 
 log = logging.getLogger(__name__)
 
@@ -194,7 +203,7 @@ class Environment:
             """The allowed meals depend on the `environment_config.yml` configured behaviour. Either all meals that 
             are possible or only a limited subset."""
 
-        self.order_and_score = OrderManager(
+        self.order_manager = OrderManager(
             order_config=self.environment_config["orders"],
             available_meals={
                 item: info
@@ -225,7 +234,7 @@ class Environment:
                     else {}
                 )
             ),
-            order_manager=self.order_and_score,
+            order_manager=self.order_manager,
             effect_manager_config=self.environment_config["effect_manager"],
             hook=self.hook,
             random=self.random,
@@ -268,7 +277,7 @@ class Environment:
         )
         """Counters that needs to be called in the step function via the `progress` method."""
 
-        self.order_and_score.create_init_orders(self.env_time)
+        self.order_manager.create_init_orders(self.env_time)
         self.start_time = self.env_time
         """The relative env time when it started."""
         self.env_time_end = self.env_time + timedelta(
@@ -281,6 +290,8 @@ class Environment:
             str, EffectManager
         ] = self.counter_factory.setup_effect_manger(self.counters)
 
+        self.vector_state_generation = self.setup_vectorization()
+
         self.hook(
             ENV_INITIALIZED,
             environment_config=env_config,
@@ -665,8 +676,8 @@ class Environment:
 
         for idx, p in enumerate(self.players.values()):
             if not (new_positions[idx] == player_positions[idx]).all():
-                p.turn(player_movement_vectors[idx])
                 p.move_abs(new_positions[idx])
+            p.turn(player_movement_vectors[idx])
 
     def add_player(self, player_name: str, pos: npt.NDArray = None):
         """Add a player to the environment.
@@ -742,7 +753,7 @@ class Environment:
 
             for counter in self.progressing_counters:
                 counter.progress(passed_time=passed_time, now=self.env_time)
-            self.order_and_score.progress(passed_time=passed_time, now=self.env_time)
+            self.order_manager.progress(passed_time=passed_time, now=self.env_time)
             for effect_manager in self.effect_manager.values():
                 effect_manager.progress(passed_time=passed_time, now=self.env_time)
         # self.hook(POST_STEP, passed_time=passed_time)
@@ -757,7 +768,7 @@ class Environment:
             "players": self.players,
             "counters": self.counters,
             "score": self.score,
-            "orders": self.order_and_score.open_orders,
+            "orders": self.order_manager.open_orders,
             "ended": self.game_ended,
             "env_time": self.env_time,
             "remaining_time": max(self.env_time_end - self.env_time, timedelta(0)),
@@ -771,7 +782,7 @@ class Environment:
                 "counters": [c.to_dict() for c in self.counters],
                 "kitchen": {"width": self.kitchen_width, "height": self.kitchen_height},
                 "score": self.score,
-                "orders": self.order_and_score.order_state(),
+                "orders": self.order_manager.order_state(),
                 "ended": self.game_ended,
                 "env_time": self.env_time.isoformat(),
                 "remaining_time": max(
@@ -793,6 +804,246 @@ class Environment:
             return json_data
         raise ValueError(f"No valid {player_id=}")
 
+    def setup_vectorization(self) -> VectorStateGenerationData:
+        grid_base_array = np.zeros(
+            (
+                int(self.kitchen_width),
+                int(self.kitchen_height),
+                114 + 12 + 4,  # TODO calc based on item info
+            ),
+            dtype=np.float32,
+        )
+        counter_list = [
+            "Counter",
+            "CuttingBoard",
+            "ServingWindow",
+            "Trashcan",
+            "Sink",
+            "SinkAddon",
+            "Stove",
+            "DeepFryer",
+            "Oven",
+        ]
+        grid_idxs = [
+            (x, y)
+            for x in range(int(self.kitchen_width))
+            for y in range(int(self.kitchen_height))
+        ]
+        # counters do not move
+        for counter in self.counters:
+            grid_idx = np.floor(counter.pos).astype(int)
+            counter_name = (
+                counter.name
+                if isinstance(counter, CookingCounter)
+                else (
+                    repr(counter)
+                    if isinstance(counter, Dispenser)
+                    else counter.__class__.__name__
+                )
+            )
+            assert counter_name in counter_list or counter_name.endswith(
+                "Dispenser"
+            ), f"Unknown Counter {counter}"
+            oh_idx = len(counter_list)
+            if counter_name in counter_list:
+                oh_idx = counter_list.index(counter_name)
+
+            one_hot = [0] * (len(counter_list) + 2)
+            one_hot[oh_idx] = 1
+            grid_base_array[
+                grid_idx[0], grid_idx[1], 4 : 4 + (len(counter_list) + 2)
+            ] = np.array(one_hot, dtype=np.float32)
+
+            grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))
+
+        for free_idx in grid_idxs:
+            one_hot = [0] * (len(counter_list) + 2)
+            one_hot[len(counter_list) + 1] = 1
+            grid_base_array[
+                free_idx[0], free_idx[1], 4 : 4 + (len(counter_list) + 2)
+            ] = np.array(one_hot, dtype=np.float32)
+
+        player_info_base_array = np.zeros(
+            (
+                4,
+                4 + 114,
+            ),
+            dtype=np.float32,
+        )
+        order_base_array = np.zeros((10 * (8 + 1)), dtype=np.float32)
+
+        return VectorStateGenerationData(
+            grid_base_array=grid_base_array,
+            oh_len=12,
+        )
+
+    def get_simple_vectorized_item(self, item: Item) -> npt.NDArray[float]:
+        name = item.name
+        array = np.zeros(21, dtype=np.float32)
+        if item.name.startswith("Burnt"):
+            name = name[len("Burnt") :]
+            array[0] = 1.0
+        if name.startswith("Chopped"):
+            array[1] = 1.0
+            name = name[len("Chopped") :]
+        if name in [
+            "PizzaBase",
+            "GratedCheese",
+            "RawChips",
+            "RawPatty",
+        ]:
+            array[1] = 1.0
+            name = {
+                "PizzaBase": "Dough",
+                "GratedCheese": "Cheese",
+                "RawChips": "Potato",
+                "RawPatty": "Meat",
+            }[name]
+        if name == "CookedPatty":
+            array[2] = 1.0
+            name = "Meat"
+
+        if name in self.vector_state_generation.meals:
+            idx = self.vector_state_generation.meals.index(name)
+        elif name in self.vector_state_generation.ingredients:
+            idx = len(
+                self.vector_state_generation.meals
+            ) + self.vector_state_generation.ingredients.index(name)
+        else:
+            raise ValueError(f"Unknown item {name} - {item}")
+        array[idx] = 1.0
+        return array
+
+    def get_vectorized_item(self, item: Item) -> npt.NDArray[float]:
+        item_array = np.zeros(114, dtype=np.float32)
+
+        if isinstance(item, CookingEquipment) or item.item_info.type == ItemType.Tool:
+            assert (
+                item.name in self.vector_state_generation.equipments
+            ), f"unknown equipment {item}"
+            idx = self.vector_state_generation.equipments.index(item.name)
+            item_array[idx] = 1.0
+            if isinstance(item, CookingEquipment):
+                for s_idx, sub_item in enumerate(item.content_list):
+                    if s_idx > 3:
+                        print("Too much content in the content list, info dropped")
+                        break
+                    start_idx = len(self.vector_state_generation.equipments) + 21 + 2
+                    item_array[
+                        start_idx + (s_idx * (21)) : start_idx + ((s_idx + 1) * (21))
+                    ] = self.get_simple_vectorized_item(sub_item)
+
+        else:
+            item_array[
+                len(self.vector_state_generation.equipments) : len(
+                    self.vector_state_generation.equipments
+                )
+                + 21
+            ] = self.get_simple_vectorized_item(item)
+
+        item_array[
+            len(self.vector_state_generation.equipments) + 21 + 1
+        ] = item.progress_percentage
+
+        if item.active_effects:
+            item_array[
+                len(self.vector_state_generation.equipments) + 21 + 2
+            ] = 1.0  # TODO percentage of fire...
+
+        return item_array
+
+    def get_vectorized_state(
+        self, player_id: str
+    ) -> Tuple[
+        npt.NDArray[npt.NDArray[float]],
+        npt.NDArray[npt.NDArray[float]],
+        float,
+        npt.NDArray[float],
+    ]:
+        grid_array = self.vector_state_generation.grid_base_array.copy()
+        for counter in self.counters:
+            grid_idx = np.floor(counter.pos).astype(int)  # store in counter?
+            if counter.occupied_by:
+                if isinstance(counter.occupied_by, deque):
+                    ...
+                else:
+                    item = counter.occupied_by
+                    grid_array[
+                        grid_idx[0],
+                        grid_idx[1],
+                        4 + self.vector_state_generation.oh_len :,
+                    ] = self.get_vectorized_item(item)
+            if counter.active_effects:
+                grid_array[
+                    grid_idx[0],
+                    grid_idx[1],
+                    4 + self.vector_state_generation.oh_len - 1,
+                ] = 1.0  # TODO percentage of fire...
+
+        assert len(self.players) <= 4, "Too many players for vector representation"
+        player_vec = np.zeros(
+            (
+                4,
+                4 + 114,
+            ),
+            dtype=np.float32,
+        )
+        player_pos = 1
+        for player in self.players.values():
+            if player.name == player_id:
+                idx = 0
+                player_vec[0, :4] = np.array(
+                    [
+                        player.pos[0],
+                        player.pos[1],
+                        player.facing_point[0],
+                        player.facing_point[1],
+                    ],
+                    dtype=np.float32,
+                )
+            else:
+                idx = player_pos
+
+            if idx:
+                player_pos += 1
+            grid_idx = np.floor(player.pos).astype(int)  # store in counter?
+            player_vec[idx, :4] = np.array(
+                [
+                    player.pos[0] - grid_idx[0],
+                    player.pos[1] - grid_idx[1],
+                    player.facing_point[0] / np.linalg.norm(player.facing_point),
+                    player.facing_point[1] / np.linalg.norm(player.facing_point),
+                ],
+                dtype=np.float32,
+            )
+            grid_array[grid_idx[0], grid_idx[1], idx] = 1.0
+
+            if player.holding:
+                player_vec[idx, 4:] = self.get_vectorized_item(player.holding)
+
+        order_array = np.zeros((10 * (8 + 1)), dtype=np.float32)
+
+        for i, order in enumerate(self.order_manager.open_orders):
+            if i > 9:
+                print("some orders are not represented in the vectorized state")
+                break
+            assert (
+                order.meal.name in self.vector_state_generation.meals
+            ), "unknown meal in order"
+            idx = self.vector_state_generation.meals.index(order.meal.name)
+            order_array[(i * 9) + idx] = 1.0
+            order_array[(i * 9) + 8] = (
+                self.env_time - order.start_time
+            ).total_seconds() / order.max_duration.total_seconds()
+
+        return (
+            grid_array,
+            player_vec,
+            (self.env_time - self.start_time).total_seconds()
+            / (self.env_time_end - self.start_time).total_seconds(),
+            order_array,
+        )
+
     def reset_env_time(self):
         """Reset the env time to the initial time, defined by `create_init_env_time`."""
         self.hook(PRE_RESET_ENV_TIME)
diff --git a/overcooked_simulator/utils.py b/overcooked_simulator/utils.py
index b78d44af..357d2679 100644
--- a/overcooked_simulator/utils.py
+++ b/overcooked_simulator/utils.py
@@ -25,6 +25,46 @@ if TYPE_CHECKING:
 from overcooked_simulator.player import Player
 
 
+@dataclasses.dataclass
+class VectorStateGenerationData:
+    grid_base_array: npt.NDArray[npt.NDArray[float]]
+    oh_len: int
+
+    number_normal_ingredients = 10
+
+    meals = [
+        "Chips",
+        "FriedFish",
+        "Burger",
+        "Salad",
+        "TomatoSoup",
+        "OnionSoup",
+        "FishAndChips",
+        "Pizza",
+    ]
+    equipments = [
+        "Pot",
+        "Pan",
+        "Basket",
+        "Peel",
+        "Plate",
+        "DirtyPlate",
+        "Extinguisher",
+    ]
+    ingredients = [
+        "Tomato",
+        "Lettuce",
+        "Onion",
+        "Meat",
+        "Bun",
+        "Potato",
+        "Fish",
+        "Dough",
+        "Cheese",
+        "Sausage",
+    ]
+
+
 def create_init_env_time():
     """Init time of the environment time, because all environments should have the same internal time."""
     return datetime(
-- 
GitLab