diff --git a/overcooked_simulator/game_content/layouts/rl.layout b/overcooked_simulator/game_content/layouts/rl.layout
deleted file mode 100644
index 5115af280b02654e287bb9365ec0097a676c1986..0000000000000000000000000000000000000000
--- a/overcooked_simulator/game_content/layouts/rl.layout
+++ /dev/null
@@ -1,4 +0,0 @@
-#X##
-T__W
-U__P
-#C##
diff --git a/overcooked_simulator/gym_env.py b/overcooked_simulator/gym_env.py
deleted file mode 100644
index 47dbc108da7edc607936d3cfb175bf97e5a9cb48..0000000000000000000000000000000000000000
--- a/overcooked_simulator/gym_env.py
+++ /dev/null
@@ -1,399 +0,0 @@
-import json
-import random
-import time
-from datetime import timedelta
-from enum import Enum
-from pathlib import Path
-
-import cv2
-import numpy as np
-import wandb
-import yaml
-from gymnasium import spaces, Env
-from stable_baselines3 import A2C
-from stable_baselines3 import DQN
-from stable_baselines3 import PPO
-from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
-from stable_baselines3.common.env_checker import check_env
-from stable_baselines3.common.env_util import make_vec_env
-from stable_baselines3.common.vec_env import VecVideoRecorder
-from wandb.integration.sb3 import WandbCallback
-
-from overcooked_simulator import ROOT_DIR
-from overcooked_simulator.counters import Counter
-from overcooked_simulator.gui_2d_vis.drawing import Visualizer
-from overcooked_simulator.overcooked_environment import (
-    Environment,
-    Action,
-    ActionType,
-    InterActionData,
-)
-
-SimpleActionSpace = Enum(
-    "SimpleActionSpace",
-    [
-        "Up",
-        # "Up_Left",
-        "Left",
-        # "Down_Left",
-        "Down",
-        # "Down_Right",
-        "Right",
-        # "Right_Up",
-        "Interact",
-        "Put",
-    ],
-)
-
-
-def get_env_action(player_id, simple_action, duration):
-    match simple_action:
-        case SimpleActionSpace.Up:
-            return Action(
-                player_id,
-                ActionType.MOVEMENT,
-                np.array([0, -1]),
-                duration,
-            )
-        # case SimpleActionSpace.Up_Left:
-        #     return Action(
-        #         player_id,
-        #         ActionType.MOVEMENT,
-        #         np.array([-1, -1]),
-        #         duration,
-        #     )
-        case SimpleActionSpace.Left:
-            return Action(
-                player_id,
-                ActionType.MOVEMENT,
-                np.array([-1, 0]),
-                duration,
-            )
-        # case SimpleActionSpace.Down_Left:
-        #     return Action(
-        #         player_id,
-        #         ActionType.MOVEMENT,
-        #         np.array([-1, 1]),
-        #         duration,
-        #     )
-        case SimpleActionSpace.Down:
-            return Action(
-                player_id,
-                ActionType.MOVEMENT,
-                np.array([0, 1]),
-                duration,
-            )
-        # case SimpleActionSpace.Down_Right:
-        #     return Action(
-        #         player_id,
-        #         ActionType.MOVEMENT,
-        #         np.array([1, 1]),
-        #         duration,
-        #     )
-        case SimpleActionSpace.Right:
-            return Action(
-                player_id,
-                ActionType.MOVEMENT,
-                np.array([1, 0]),
-                duration,
-            )
-        # case SimpleActionSpace.Right_Up:
-        #     return Action(
-        #         player_id,
-        #         ActionType.MOVEMENT,
-        #         np.array([1, -1]),
-        #         duration,
-        #     )
-        case SimpleActionSpace.Put:
-            return Action(
-                player_id,
-                ActionType.PUT,
-                InterActionData.START,
-                duration,
-            )
-        case SimpleActionSpace.Interact:
-            return Action(
-                player_id,
-                ActionType.INTERACT,
-                InterActionData.START,
-                duration,
-            )
-        case other:
-            print("FAIL", simple_action)
-
-
-environment_config_path = ROOT_DIR / "game_content" / "environment_config_rl.yaml"
-layout_path: Path = ROOT_DIR / "game_content" / "layouts" / "rl.layout"
-item_info_path = ROOT_DIR / "game_content" / "item_info_rl.yaml"
-with open(item_info_path, "r") as file:
-    item_info = file.read()
-with open(layout_path, "r") as file:
-    layout = file.read()
-with open(environment_config_path, "r") as file:
-    environment_config = file.read()
-with open(ROOT_DIR / "gui_2d_vis" / "visualization_rl.yaml", "r") as file:
-    visualization_config = yaml.safe_load(file)
-
-
-def shuffle_counters(env):
-    sample_counter = []
-    sample_counter = []
-    for counter in env.counters:
-        if counter.__class__ != Counter:
-            sample_counter.append(counter)
-    new_counter_pos = [c.pos for c in sample_counter]
-    random.shuffle(new_counter_pos)
-    for counter, new_pos in zip(sample_counter, new_counter_pos):
-        counter.pos = new_pos
-    env.vector_state_generation = env.setup_vectorization()
-
-
-class EnvGymWrapper(Env):
-    """Should enable this:
-    observation, reward, terminated, truncated, info = env.step(action)
-    """
-
-    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}
-
-    def __init__(self):
-        super().__init__()
-
-        self.gridsize = 20
-
-        self.randomize_counter_placement = False
-        self.use_rgb_obs = False
-
-        self.env: Environment = Environment(
-            env_config=environment_config,
-            layout_config=layout,
-            item_info=item_info,
-            as_files=False,
-        )
-
-        if self.randomize_counter_placement:
-            shuffle_counters(self.env)
-
-        self.visualizer: Visualizer = Visualizer(config=visualization_config)
-        self.player_name = str(0)
-        self.env.add_player(self.player_name)
-        self.player_id = list(self.env.players.keys())[0]
-
-        self.env.setup_vectorization()
-
-        self.visualizer.create_player_colors(1)
-
-        # self.action_space = {idx: value for idx, value in enumerate(SimpleActionSpace)}
-        self.action_space_map = {}
-        for idx, item in enumerate(SimpleActionSpace):
-            self.action_space_map[idx] = item
-
-        self.global_step_time = 1
-        self.in_between_steps = 1
-
-        self.action_space = spaces.Discrete(len(self.action_space_map))
-        # Example for using image as input (channel-first; channel-last also works):
-
-        dummy_obs = self.get_observation()
-        self.observation_space = spaces.Box(
-            low=-1, high=8, shape=dummy_obs.shape, dtype=int
-        )
-
-        self.last_obs = dummy_obs
-
-        self.step_counter = 0
-        self.prev_score = 0
-
-    def step(self, action):
-        simple_action = self.action_space_map[action]
-        env_action = get_env_action(
-            self.player_id, simple_action, self.global_step_time
-        )
-        self.env.perform_action(env_action)
-
-        for i in range(self.in_between_steps):
-            self.env.step(
-                timedelta(seconds=self.global_step_time / self.in_between_steps)
-            )
-
-        observation = self.get_observation()
-
-        reward = self.env.score - self.prev_score
-        self.prev_score = self.env.score
-
-        terminated = self.env.game_ended
-        truncated = self.env.game_ended
-        info = {}
-
-        # self.render(self.gridsize)
-        return observation, reward, terminated, truncated, info
-
-    def reset(self, seed=None, options=None):
-        self.env: Environment = Environment(
-            env_config=environment_config,
-            layout_config=layout,
-            item_info=item_info,
-            as_files=False,
-        )
-
-        if self.randomize_counter_placement:
-            shuffle_counters(self.env)
-
-        self.player_name = str(0)
-        self.env.add_player(self.player_name)
-        self.player_id = list(self.env.players.keys())[0]
-
-        self.env.setup_vectorization()
-
-        info = {}
-        obs = self.get_observation()
-
-        self.prev_score = 0
-
-        return obs, info
-
-    def get_observation(self):
-        if self.use_rgb_obs:
-            obs = self.get_env_img(self.gridsize)
-        else:
-            obs = self.get_vector_state()
-        return obs
-
-    def render(self):
-        observation = self.get_env_img(self.gridsize)
-        img = (observation * 255.0).astype(np.uint8)
-        img = img.transpose((1, 2, 0))
-        img = cv2.resize(img, (img.shape[1], img.shape[0]))
-        return img
-
-    def close(self):
-        pass
-
-    def get_env_img(self, gridsize):
-        state = self.env.get_json_state(player_id=self.player_id)
-        json_dict = json.loads(state)
-        observation = self.visualizer.get_state_image(
-            grid_size=gridsize, state=json_dict
-        ).transpose((1, 0, 2))
-        return (observation.transpose((2, 0, 1)) / 255.0).astype(np.float32)
-
-    def get_vector_state(self):
-        # grid, player, env_time, orders = self.env.get_vectorized_state_full("0")
-        #
-        #
-        # obs = np.concatenate(
-        #     [grid.flatten(), player.flatten()], axis=0, dtype=np.float32
-        # )
-        # return obs
-
-        obs = self.env.get_vectorized_state_simple("0")
-        return obs
-
-    def sample_random_action(self):
-        act = self.action_space.sample()
-        return act
-
-
-def main():
-    rl_agent_checkpoints = Path("./rl_agent_checkpoints")
-    rl_agent_checkpoints.mkdir(exist_ok=True)
-
-    config = {
-        "policy_type": "MlpPolicy",
-        "total_timesteps": 30_000_000,  # hendric sagt eher so 300_000_000 schritte
-        "env_id": "overcooked",
-        "number_envs_parallel": 16,
-    }
-
-    debug = False
-    do_training = True
-    vec_env = True
-    number_envs_parallel = config["number_envs_parallel"]
-
-    model_classes = [A2C, DQN, PPO]
-    model_class = model_classes[2]
-
-    if vec_env:
-        env = make_vec_env(EnvGymWrapper, n_envs=number_envs_parallel)
-    else:
-        env = EnvGymWrapper()
-
-    env.render_mode = "rgb_array"
-
-    if not debug:
-        run = wandb.init(
-            project="overcooked",
-            config=config,
-            sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
-            monitor_gym=True
-            # save_code=True,  # optional
-        )
-
-        env = VecVideoRecorder(
-            env,
-            f"videos/{run.id}",
-            record_video_trigger=lambda x: x % 100_000 == 0,
-            video_length=300,
-        )
-
-    model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}"
-
-    if do_training:
-        model = model_class(
-            config["policy_type"],
-            env,
-            verbose=1,
-            tensorboard_log=f"runs/{0}",
-            # n_steps=2048,
-            # n_epochs=10,
-        )
-        if debug:
-            model.learn(
-                total_timesteps=config["total_timesteps"],
-                log_interval=1,
-                progress_bar=True,
-            )
-        else:
-            checkpoint_callback = CheckpointCallback(
-                save_freq=50_000,
-                save_path="./logs/",
-                name_prefix="rl_model",
-                save_replay_buffer=True,
-                save_vecnormalize=True,
-            )
-            wandb_callback = WandbCallback(
-                model_save_path=f"models/{run.id}",
-                verbose=0,
-            )
-
-            callback = CallbackList([checkpoint_callback, wandb_callback])
-            model.learn(
-                total_timesteps=config["total_timesteps"],
-                callback=callback,
-                log_interval=1,
-                progress_bar=True,
-            )
-            run.finish()
-        model.save(model_save_path)
-
-        del model
-    print("LEARNING DONE.")
-
-    model = model_class.load(model_save_path)
-    env = EnvGymWrapper()
-
-    check_env(env)
-    obs, info = env.reset()
-    while True:
-        time.sleep(1 / 30)
-        action, _states = model.predict(obs, deterministic=False)
-        obs, reward, terminated, truncated, info = env.step(int(action))
-        print(reward)
-        rgb_img = env.render()
-        cv2.imshow("env", rgb_img)
-        cv2.waitKey(0)
-        if terminated or truncated:
-            obs, info = env.reset()
-
-
-if __name__ == "__main__":
-    main()
diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py
index c394624a9a75a9ba1f969de43ffb4ee2ca5de4f3..1a2e238db7fd4e37609c55d82c480337940ba528 100644
--- a/overcooked_simulator/overcooked_environment.py
+++ b/overcooked_simulator/overcooked_environment.py
@@ -5,7 +5,6 @@ import inspect
 import json
 import logging
 import sys
-from collections import deque
 from datetime import timedelta, datetime
 from enum import Enum
 from pathlib import Path
@@ -21,15 +20,11 @@ from overcooked_simulator.counter_factory import CounterFactory
 from overcooked_simulator.counters import (
     Counter,
     PlateConfig,
-    CookingCounter,
-    Dispenser,
 )
 from overcooked_simulator.effect_manager import EffectManager
 from overcooked_simulator.game_items import (
     ItemInfo,
     ItemType,
-    CookingEquipment,
-    Item,
 )
 from overcooked_simulator.hooks import (
     ITEM_INFO_LOADED,
@@ -59,7 +54,6 @@ from overcooked_simulator.player import Player, PlayerConfig
 from overcooked_simulator.utils import (
     create_init_env_time,
     get_closest,
-    VectorStateGenerationData,
 )
 
 log = logging.getLogger(__name__)
@@ -291,8 +285,6 @@ class Environment:
             str, EffectManager
         ] = self.counter_factory.setup_effect_manger(self.counters)
 
-        self.vector_state_generation = self.setup_vectorization()
-
         self.hook(
             ENV_INITIALIZED,
             environment_config=env_config,
@@ -805,404 +797,6 @@ class Environment:
             return json_data
         raise ValueError(f"No valid {player_id=}")
 
-    def setup_vectorization(self) -> VectorStateGenerationData:
-        grid_base_array = np.zeros(
-            (
-                int(self.kitchen_width),
-                int(self.kitchen_height),
-                114 + 12 + 4,  # TODO calc based on item info
-            ),
-            dtype=np.float32,
-        )
-        counter_list = [
-            "Counter",
-            "CuttingBoard",
-            "ServingWindow",
-            "Trashcan",
-            "Sink",
-            "SinkAddon",
-            "Stove",
-            "DeepFryer",
-            "Oven",
-        ]
-        grid_idxs = [
-            (x, y)
-            for x in range(int(self.kitchen_width))
-            for y in range(int(self.kitchen_height))
-        ]
-        # counters do not move
-        for counter in self.counters:
-            grid_idx = np.floor(counter.pos).astype(int)
-            counter_name = (
-                counter.name
-                if isinstance(counter, CookingCounter)
-                else (
-                    repr(counter)
-                    if isinstance(Counter, Dispenser)
-                    else counter.__class__.__name__
-                )
-            )
-            assert counter_name in counter_list or counter_name.endswith(
-                "Dispenser"
-            ), f"Unknown Counter {counter}"
-            oh_idx = len(counter_list)
-            if counter_name in counter_list:
-                oh_idx = counter_list.index(counter_name)
-
-            one_hot = [0] * (len(counter_list) + 2)
-            one_hot[oh_idx] = 1
-            grid_base_array[
-                grid_idx[0], grid_idx[1], 4 : 4 + (len(counter_list) + 2)
-            ] = np.array(one_hot, dtype=np.float32)
-
-            grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))
-
-        for free_idx in grid_idxs:
-            one_hot = [0] * (len(counter_list) + 2)
-            one_hot[len(counter_list) + 1] = 1
-            grid_base_array[
-                free_idx[0], free_idx[1], 4 : 4 + (len(counter_list) + 2)
-            ] = np.array(one_hot, dtype=np.float32)
-
-        player_info_base_array = np.zeros(
-            (
-                4,
-                4 + 114,
-            ),
-            dtype=np.float32,
-        )
-        order_base_array = np.zeros((10 * (8 + 1)), dtype=np.float32)
-
-        return VectorStateGenerationData(
-            grid_base_array=grid_base_array,
-            oh_len=12,
-        )
-
-    def get_simple_vectorized_item(self, item: Item) -> npt.NDArray[float]:
-        name = item.name
-        array = np.zeros(21, dtype=np.float32)
-        if item.name.startswith("Burnt"):
-            name = name[len("Burnt") :]
-            array[0] = 1.0
-        if name.startswith("Chopped"):
-            array[1] = 1.0
-            name = name[len("Chopped") :]
-        if name in [
-            "PizzaBase",
-            "GratedCheese",
-            "RawChips",
-            "RawPatty",
-        ]:
-            array[1] = 1.0
-            name = {
-                "PizzaBase": "Dough",
-                "GratedCheese": "Cheese",
-                "RawChips": "Potato",
-                "RawPatty": "Meat",
-            }[name]
-        if name == "CookedPatty":
-            array[2] = 1.0
-            name = "Meat"
-
-        if name in self.vector_state_generation.meals:
-            idx = 3 + self.vector_state_generation.meals.index(name)
-        elif name in self.vector_state_generation.ingredients:
-            idx = (
-                3
-                + len(self.vector_state_generation.meals)
-                + self.vector_state_generation.ingredients.index(name)
-            )
-        else:
-            raise ValueError(f"Unknown item {name} - {item}")
-        array[idx] = 1.0
-        return array
-
-    def get_vectorized_item(self, item: Item) -> npt.NDArray[float]:
-        item_array = np.zeros(114, dtype=np.float32)
-
-        if isinstance(item, CookingEquipment) or item.item_info.type == ItemType.Tool:
-            assert (
-                item.name in self.vector_state_generation.equipments
-            ), f"unknown equipment {item}"
-            idx = self.vector_state_generation.equipments.index(item.name)
-            item_array[idx] = 1.0
-            if isinstance(item, CookingEquipment):
-                for s_idx, sub_item in enumerate(item.content_list):
-                    if s_idx > 3:
-                        print("Too much content in the content list, info dropped")
-                        break
-                    start_idx = len(self.vector_state_generation.equipments) + 21 + 2
-                    item_array[
-                        start_idx + (s_idx * (21)) : start_idx + ((s_idx + 1) * (21))
-                    ] = self.get_simple_vectorized_item(sub_item)
-
-        else:
-            item_array[
-                len(self.vector_state_generation.equipments) : len(
-                    self.vector_state_generation.equipments
-                )
-                + 21
-            ] = self.get_simple_vectorized_item(item)
-
-        item_array[
-            len(self.vector_state_generation.equipments) + 21 + 1
-        ] = item.progress_percentage
-
-        if item.active_effects:
-            item_array[
-                len(self.vector_state_generation.equipments) + 21 + 2
-            ] = 1.0  # TODO percentage of fire...
-
-        return item_array
-
-    def get_vectorized_state_full(
-        self, player_id: str
-    ) -> Tuple[
-        npt.NDArray[npt.NDArray[float]],
-        npt.NDArray[npt.NDArray[float]],
-        float,
-        npt.NDArray[float],
-    ]:
-        grid_array = self.vector_state_generation.grid_base_array.copy()
-        for counter in self.counters:
-            grid_idx = np.floor(counter.pos).astype(int)  # store in counter?
-            if counter.occupied_by:
-                if isinstance(counter.occupied_by, deque):
-                    ...
-                else:
-                    item = counter.occupied_by
-                    grid_array[
-                        grid_idx[0],
-                        grid_idx[1],
-                        4 + self.vector_state_generation.oh_len :,
-                    ] = self.get_vectorized_item(item)
-            if counter.active_effects:
-                grid_array[
-                    grid_idx[0],
-                    grid_idx[1],
-                    4 + self.vector_state_generation.oh_len - 1,
-                ] = 1.0  # TODO percentage of fire...
-
-        assert len(self.players) <= 4, "To many players for vector representation"
-        player_vec = np.zeros(
-            (
-                4,
-                4 + 114,
-            ),
-            dtype=np.float32,
-        )
-        player_pos = 1
-        for player in self.players.values():
-            if player.name == player_id:
-                idx = 0
-                player_vec[0, :4] = np.array(
-                    [
-                        player.pos[0],
-                        player.pos[1],
-                        player.facing_point[0],
-                        player.facing_point[1],
-                    ],
-                    dtype=np.float32,
-                )
-            else:
-                idx = player_pos
-
-            if not idx:
-                player_pos += 1
-            grid_idx = np.floor(player.pos).astype(int)  # store in counter?
-            player_vec[idx, :4] = np.array(
-                [
-                    player.pos[0] - grid_idx[0],
-                    player.pos[1] - grid_idx[1],
-                    player.facing_point[0] / np.linalg.norm(player.facing_point),
-                    player.facing_point[1] / np.linalg.norm(player.facing_point),
-                ],
-                dtype=np.float32,
-            )
-            grid_array[grid_idx[0], grid_idx[1], idx] = 1.0
-
-            if player.holding:
-                player_vec[idx, 4:] = self.get_vectorized_item(player.holding)
-
-        order_array = np.zeros((10 * (8 + 1)), dtype=np.float32)
-
-        for i, order in enumerate(self.order_manager.open_orders):
-            if i > 9:
-                print("some orders are not represented in the vectorized state")
-                break
-            assert (
-                order.meal.name in self.vector_state_generation.meals
-            ), "unknown meal in order"
-            idx = self.vector_state_generation.meals.index(order.meal.name)
-            order_array[(i * 9) + idx] = 1.0
-            order_array[(i * 9) + 8] = (
-                self.env_time - order.start_time
-            ).total_seconds() / order.max_duration.total_seconds()
-
-        return (
-            grid_array,
-            player_vec,
-            (self.env_time - self.start_time).total_seconds()
-            / (self.env_time_end - self.start_time).total_seconds(),
-            order_array,
-        )
-
-    # def setup_vectorization_simple(self) -> VectorStateGenerationDataSimple:
-    #     num_per_item = 114
-    #     num_per_counter = 12
-    #     num_players = 4
-    #     grid_base_array = np.zeros(
-    #         (
-    #             int(self.kitchen_width),
-    #             int(self.kitchen_height),
-    #             num_per_item
-    #             + num_per_counter
-    #             + num_players,  # TODO calc based on item info
-    #         ),
-    #         dtype=np.float32,
-    #     )
-    #     counter_list = [
-    #         "Counter",
-    #         "CuttingBoard",
-    #         "ServingWindow",
-    #         "Trashcan",
-    #         "Sink",
-    #         "SinkAddon",
-    #         "Stove",
-    #         "DeepFryer",
-    #         "Oven",
-    #     ]
-    #     grid_idxs = [
-    #         (x, y)
-    #         for x in range(int(self.kitchen_width))
-    #         for y in range(int(self.kitchen_height))
-    #     ]
-    #     # counters do not move
-    #     for counter in self.counters:
-    #         grid_idx = np.floor(counter.pos).astype(int)
-    #         counter_name = (
-    #             counter.name
-    #             if isinstance(counter, CookingCounter)
-    #             else (
-    #                 repr(counter)
-    #                 if isinstance(Counter, Dispenser)
-    #                 else counter.__class__.__name__
-    #             )
-    #         )
-    #         assert counter_name in counter_list or counter_name.endswith(
-    #             "Dispenser"
-    #         ), f"Unknown Counter {counter}"
-    #         oh_idx = len(counter_list)
-    #         if counter_name in counter_list:
-    #             oh_idx = counter_list.index(counter_name)
-    #
-    #         one_hot = [0] * (len(counter_list) + 2)
-    #         one_hot[oh_idx] = 1
-    #         grid_base_array[
-    #             grid_idx[0], grid_idx[1], 4 : 4 + (len(counter_list) + 2)
-    #         ] = np.array(one_hot, dtype=np.float32)
-    #
-    #         grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))
-    #
-    #     for free_idx in grid_idxs:
-    #         one_hot = [0] * (len(counter_list) + 2)
-    #         one_hot[len(counter_list) + 1] = 1
-    #         grid_base_array[
-    #             free_idx[0], free_idx[1], 4 : 4 + (len(counter_list) + 2)
-    #         ] = np.array(one_hot, dtype=np.float32)
-    #
-    #     player_info_base_array = np.zeros(
-    #         (
-    #             4,
-    #             4 + 114,
-    #         ),
-    #         dtype=np.float32,
-    #     )
-    #     order_base_array = np.zeros((10 * (8 + 1)), dtype=np.float32)
-    #
-    #     return VectorStateGenerationData(
-    #         grid_base_array=grid_base_array,
-    #         oh_len=12,
-    #     )
-
-    def get_vectorized_state_simple(self, player):
-        item_list = ["Pot", "Tomato", "ChoppedTomato", "Plate"]
-        counter_list = [
-            "Counter",
-            "PlateDispenser",
-            "TomatoDispenser",
-            "ServingWindow",
-            "PlateReturn",
-            "Trashcan",
-            "Stove",
-            "CuttingBoard",
-        ]
-        player_pos = self.players[player].pos
-        player_dir = self.players[player].facing_direction
-
-        grid_width, grid_height = int(self.kitchen_width), int(self.kitchen_height)
-
-        counter_one_hot_length = len(counter_list) + 1  # one for empty field
-        grid_base_array = np.zeros(
-            (
-                grid_width,
-                grid_height,
-            ),
-            dtype=int,
-        )
-
-        grid_idxs = [(x, y) for x in range(grid_width) for y in range(grid_height)]
-
-        # counters do not move
-        for counter in self.counters:
-            grid_idx = np.floor(counter.pos).astype(int)
-            counter_name = (
-                counter.name
-                if isinstance(counter, CookingCounter)
-                else (
-                    repr(counter)
-                    if isinstance(Counter, Dispenser)
-                    else counter.__class__.__name__
-                )
-            )
-            if counter_name == "Dispenser":
-                counter_name = f"{counter.occupied_by.name}Dispenser"
-            assert counter_name in counter_list, f"Unknown Counter {counter}"
-
-            counter_oh_idx = counter_one_hot_length
-            if counter_name in counter_list:
-                counter_oh_idx = counter_list.index(counter_name)
-
-            grid_base_array[grid_idx[0], grid_idx[1]] = counter_oh_idx
-            grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))
-
-        for free_idx in grid_idxs:
-            grid_base_array[free_idx[0], free_idx[1]] = counter_one_hot_length - 1
-
-        counter_grid_one_hot = np.zeros(
-            (grid_width, grid_height, counter_one_hot_length), dtype=int
-        )
-        for x in range(grid_width):
-            for y in range(grid_height):
-                counter_type_idx = grid_base_array[x, y]
-                counter_grid_one_hot[x, y, counter_type_idx] = 1
-
-        player_data = np.concatenate((player_pos, player_dir), axis=0)
-
-        items_one_hot_length = len(item_list) + 1
-        item_one_hot = np.zeros(items_one_hot_length, dtype=int)
-        player_item = self.players[player].holding
-        player_item_idx = items_one_hot_length - 1
-        if player_item:
-            if player_item.name in item_list:
-                player_item_idx = item_list.index(player_item.name)
-        item_one_hot[player_item_idx] = 1
-
-        final = np.concatenate(
-            (counter_grid_one_hot.flatten(), player_data, item_one_hot), axis=0
-        )
-        return final
-
     def reset_env_time(self):
         """Reset the env time to the initial time, defined by `create_init_env_time`."""
         self.hook(PRE_RESET_ENV_TIME)
diff --git a/overcooked_simulator/reinforcement_learning/__init__.py b/overcooked_simulator/reinforcement_learning/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/overcooked_simulator/game_content/environment_config_rl.yaml b/overcooked_simulator/reinforcement_learning/environment_config_rl.yaml
similarity index 98%
rename from overcooked_simulator/game_content/environment_config_rl.yaml
rename to overcooked_simulator/reinforcement_learning/environment_config_rl.yaml
index 0b3b4570f92b3ab8554a6183ae56d87b11889f51..b6c9d3bc212547eaffdcd197c29ebc793b2f5d5c 100644
--- a/overcooked_simulator/game_content/environment_config_rl.yaml
+++ b/overcooked_simulator/reinforcement_learning/environment_config_rl.yaml
@@ -1,5 +1,5 @@
 plates:
-  clean_plates: 1
+  clean_plates: 2
   dirty_plates: 0
   plate_delay: [ 2, 4 ]
   return_dirty: False
@@ -108,7 +108,7 @@ extra_setup_functions:
       hooks: [ trashcan_usage ]
       callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks ''
       callback_class_kwargs:
-        static_score: -0.15
+        static_score: -0.5
   item_cut:
     func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class ''
     kwargs:
@@ -122,14 +122,14 @@ extra_setup_functions:
       hooks: [ post_step ]
       callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks ''
       callback_class_kwargs:
-        static_score: -0.01
+        static_score: -0.05
   combine:
     func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class ''
     kwargs:
       hooks: [ drop_off_on_cooking_equipment ]
       callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks ''
       callback_class_kwargs:
-        static_score: 0.10
+        static_score: 0.15
   #  json_states:
   #    func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks ''
   #    kwargs:
diff --git a/overcooked_simulator/reinforcement_learning/gym_env.py b/overcooked_simulator/reinforcement_learning/gym_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b87ad085c3bb1938586378a97a4751f0c914441
--- /dev/null
+++ b/overcooked_simulator/reinforcement_learning/gym_env.py
@@ -0,0 +1,703 @@
+import json
+import random
+import time
+from collections import deque
+from datetime import timedelta
+from enum import Enum
+from pathlib import Path
+from typing import Tuple
+
+import cv2
+import numpy as np
+import numpy.typing as npt
+import wandb
+import yaml
+from gymnasium import spaces, Env
+from stable_baselines3 import A2C
+from stable_baselines3 import DQN
+from stable_baselines3 import PPO
+from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
+from stable_baselines3.common.env_checker import check_env
+from stable_baselines3.common.env_util import make_vec_env
+from stable_baselines3.common.vec_env import VecVideoRecorder
+from wandb.integration.sb3 import WandbCallback
+
+from overcooked_simulator import ROOT_DIR
+from overcooked_simulator.counters import Counter, CookingCounter, Dispenser
+from overcooked_simulator.game_items import Item, CookingEquipment, ItemType
+from overcooked_simulator.gui_2d_vis.drawing import Visualizer
+from overcooked_simulator.overcooked_environment import (
+    Environment,
+    Action,
+    ActionType,
+    InterActionData,
+)
+from overcooked_simulator.utils import VectorStateGenerationData
+
+SimpleActionSpace = Enum(
+    "SimpleActionSpace",
+    [
+        "Up",
+        "Left",
+        "Down",
+        "Right",
+        "Interact",
+        "Put",
+    ],
+)
+
+
+def get_env_action(player_id, simple_action, duration):
+    match simple_action:
+        case SimpleActionSpace.Up:
+            return Action(
+                player_id,
+                ActionType.MOVEMENT,
+                np.array([0, -1]),
+                duration,
+            )
+        case SimpleActionSpace.Left:
+            return Action(
+                player_id,
+                ActionType.MOVEMENT,
+                np.array([-1, 0]),
+                duration,
+            )
+        case SimpleActionSpace.Down:
+            return Action(
+                player_id,
+                ActionType.MOVEMENT,
+                np.array([0, 1]),
+                duration,
+            )
+        case SimpleActionSpace.Right:
+            return Action(
+                player_id,
+                ActionType.MOVEMENT,
+                np.array([1, 0]),
+                duration,
+            )
+        case SimpleActionSpace.Put:
+            return Action(
+                player_id,
+                ActionType.PUT,
+                InterActionData.START,
+                duration,
+            )
+        case SimpleActionSpace.Interact:
+            return Action(
+                player_id,
+                ActionType.INTERACT,
+                InterActionData.START,
+                duration,
+            )
+
+
+environment_config_path = (
+    ROOT_DIR / "reinforcement_learning" / "environment_config_rl.yaml"
+)
+layout_path: Path = ROOT_DIR / "reinforcement_learning" / "rl_small.layout"
+item_info_path = ROOT_DIR / "reinforcement_learning" / "item_info_rl.yaml"
+with open(item_info_path, "r") as file:
+    item_info = file.read()
+with open(layout_path, "r") as file:
+    layout = file.read()
+with open(environment_config_path, "r") as file:
+    environment_config = file.read()
+with open(ROOT_DIR / "gui_2d_vis" / "visualization_rl.yaml", "r") as file:
+    visualization_config = yaml.safe_load(file)
+
+
+def shuffle_counters(env):
+    sample_counter = []
+    for counter in env.counters:
+        if counter.__class__ != Counter:
+            sample_counter.append(counter)
+    new_counter_pos = [c.pos for c in sample_counter]
+    random.shuffle(new_counter_pos)
+    for counter, new_pos in zip(sample_counter, new_counter_pos):
+        counter.pos = new_pos
+
+
+class EnvGymWrapper(Env):
+    """Should enable this:
+    observation, reward, terminated, truncated, info = env.step(action)
+    """
+
+    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 20}
+
+    def __init__(self):
+        super().__init__()
+
+        self.gridsize = 20
+
+        self.randomize_counter_placement = True
+        self.use_rgb_obs = False  # if False uses simple vectorized state
+        self.use_onehot = False
+        self.full_vector_state = True
+
+        self.env: Environment = Environment(
+            env_config=environment_config,
+            layout_config=layout,
+            item_info=item_info,
+            as_files=False,
+        )
+
+        if self.randomize_counter_placement:
+            shuffle_counters(self.env)
+
+        if self.full_vector_state:
+            self.vector_state_generation = self.setup_vectorization()
+
+        self.visualizer: Visualizer = Visualizer(config=visualization_config)
+        self.visualizer.create_player_colors(1)
+
+        self.player_name = str(0)
+        self.env.add_player(self.player_name)
+        self.player_id = list(self.env.players.keys())[0]
+
+        # self.action_space = {idx: value for idx, value in enumerate(SimpleActionSpace)}
+        self.action_space_map = {}
+        for idx, item in enumerate(SimpleActionSpace):
+            self.action_space_map[idx] = item
+        self.global_step_time = 1
+        self.in_between_steps = 1
+
+        self.action_space = spaces.Discrete(len(self.action_space_map))
+        # Example for using image as input (channel-first; channel-last also works):
+
+        min_obs_val = -1 if not self.use_rgb_obs else 0
+        max_obs_val = 1 if self.use_onehot else 255 if self.use_rgb_obs else 8
+        dummy_obs = self.get_observation()
+        self.observation_space = spaces.Box(
+            low=min_obs_val,
+            high=max_obs_val,
+            shape=dummy_obs.shape,
+            dtype=np.uint8 if self.use_rgb_obs else np.float32,
+        )
+        print(self.observation_space)
+
+        self.last_obs = dummy_obs
+
+        self.step_counter = 0
+        self.prev_score = 0
+
+    def get_vectorized_state_simple(self, player, onehot=True):
+        item_list = ["Pot", "Tomato", "ChoppedTomato", "Plate"]
+        counter_list = [
+            "Counter",
+            "PlateDispenser",
+            "TomatoDispenser",
+            "ServingWindow",
+            "PlateReturn",
+            "Trashcan",
+            "Stove",
+            "CuttingBoard",
+        ]
+
+        grid_width, grid_height = int(self.env.kitchen_width), int(
+            self.env.kitchen_height
+        )
+
+        counter_one_hot_length = len(counter_list) + 1  # one for empty field
+        grid_base_array = np.zeros(
+            (
+                grid_width,
+                grid_height,
+            ),
+            dtype=int,
+        )
+
+        grid_idxs = [(x, y) for x in range(grid_width) for y in range(grid_height)]
+
+        # counters do not move
+        for counter in self.env.counters:
+            grid_idx = np.floor(counter.pos).astype(int)
+            counter_name = (
+                counter.name
+                if isinstance(counter, CookingCounter)
+                else (
+                    repr(counter)
+                    if isinstance(counter, Dispenser)
+                    else counter.__class__.__name__
+                )
+            )
+            if counter_name == "Dispenser":
+                counter_name = f"{counter.occupied_by.name}Dispenser"
+            assert counter_name in counter_list, f"Unknown Counter {counter}"
+
+            counter_oh_idx = counter_one_hot_length
+            if counter_name in counter_list:
+                counter_oh_idx = counter_list.index(counter_name)
+
+            grid_base_array[grid_idx[0], grid_idx[1]] = counter_oh_idx
+            grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))
+
+        for free_idx in grid_idxs:
+            grid_base_array[free_idx[0], free_idx[1]] = counter_one_hot_length - 1
+
+        counter_grid_one_hot = np.zeros(
+            (grid_width, grid_height, counter_one_hot_length), dtype=int
+        )
+        for x in range(grid_width):
+            for y in range(grid_height):
+                counter_type_idx = grid_base_array[x, y]
+                counter_grid_one_hot[x, y, counter_type_idx] = 1
+
+        player_pos = self.env.players[player].pos
+        if onehot:
+            player_pos[0] /= self.env.kitchen_width
+            player_pos[1] /= self.env.kitchen_height
+        else:
+            player_pos = player_pos.astype(int)
+
+        player_dir = self.env.players[player].facing_direction
+        player_data = np.concatenate((player_pos, player_dir), axis=0)
+
+        items_one_hot_length = len(item_list) + 1
+        item_one_hot = np.zeros(items_one_hot_length, dtype=int)
+        player_item = self.env.players[player].holding
+        player_item_idx = items_one_hot_length - 1
+        if player_item:
+            if player_item.name in item_list:
+                player_item_idx = item_list.index(player_item.name)
+        item_one_hot[player_item_idx] = 1
+
+        final_idxs = np.concatenate(
+            (grid_base_array.flatten(), player_data, item_one_hot), axis=0
+        )
+        final_one_hot = np.concatenate(
+            (counter_grid_one_hot.flatten(), player_data, item_one_hot), axis=0
+        )
+
+        return final_one_hot if onehot else final_idxs
+
+    def setup_vectorization(self) -> VectorStateGenerationData:
+        grid_base_array = np.zeros(
+            (
+                int(self.env.kitchen_width),
+                int(self.env.kitchen_height),
+                114 + 12 + 4,  # TODO calc based on item info
+            ),
+            dtype=np.float32,
+        )
+        counter_list = [
+            "Counter",
+            "CuttingBoard",
+            "ServingWindow",
+            "Trashcan",
+            "Sink",
+            "SinkAddon",
+            "Stove",
+            "DeepFryer",
+            "Oven",
+        ]
+        grid_idxs = [
+            (x, y)
+            for x in range(int(self.env.kitchen_width))
+            for y in range(int(self.env.kitchen_height))
+        ]
+        # counters do not move
+        for counter in self.env.counters:
+            grid_idx = np.floor(counter.pos).astype(int)
+            counter_name = (
+                counter.name
+                if isinstance(counter, CookingCounter)
+                else (
+                    repr(counter)
+                    if isinstance(counter, Dispenser)
+                    else counter.__class__.__name__
+                )
+            )
+            assert counter_name in counter_list or counter_name.endswith(
+                "Dispenser"
+            ), f"Unknown Counter {counter}"
+            oh_idx = len(counter_list)
+            if counter_name in counter_list:
+                oh_idx = counter_list.index(counter_name)
+
+            one_hot = [0] * (len(counter_list) + 2)
+            one_hot[oh_idx] = 1
+            grid_base_array[
+                grid_idx[0], grid_idx[1], 4 : 4 + (len(counter_list) + 2)
+            ] = np.array(one_hot, dtype=np.float32)
+
+            grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))
+
+        for free_idx in grid_idxs:
+            one_hot = [0] * (len(counter_list) + 2)
+            one_hot[len(counter_list) + 1] = 1
+            grid_base_array[
+                free_idx[0], free_idx[1], 4 : 4 + (len(counter_list) + 2)
+            ] = np.array(one_hot, dtype=np.float32)
+
+        player_info_base_array = np.zeros(
+            (
+                4,
+                4 + 114,
+            ),
+            dtype=np.float32,
+        )
+        order_base_array = np.zeros((10 * (8 + 1)), dtype=np.float32)
+
+        return VectorStateGenerationData(
+            grid_base_array=grid_base_array,
+            oh_len=12,
+        )
+
+    def get_simple_vectorized_item(self, item: Item) -> npt.NDArray[float]:
+        name = item.name
+        array = np.zeros(21, dtype=np.float32)
+        if item.name.startswith("Burnt"):
+            name = name[len("Burnt") :]
+            array[0] = 1.0
+        if name.startswith("Chopped"):
+            array[1] = 1.0
+            name = name[len("Chopped") :]
+        if name in [
+            "PizzaBase",
+            "GratedCheese",
+            "RawChips",
+            "RawPatty",
+        ]:
+            array[1] = 1.0
+            name = {
+                "PizzaBase": "Dough",
+                "GratedCheese": "Cheese",
+                "RawChips": "Potato",
+                "RawPatty": "Meat",
+            }[name]
+        if name == "CookedPatty":
+            array[2] = 1.0
+            name = "Meat"
+
+        if name in self.vector_state_generation.meals:
+            idx = 3 + self.vector_state_generation.meals.index(name)
+        elif name in self.vector_state_generation.ingredients:
+            idx = (
+                3
+                + len(self.vector_state_generation.meals)
+                + self.vector_state_generation.ingredients.index(name)
+            )
+        else:
+            raise ValueError(f"Unknown item {name} - {item}")
+        array[idx] = 1.0
+        return array
+
+    def get_vectorized_item(self, item: Item) -> npt.NDArray[float]:
+        item_array = np.zeros(114, dtype=np.float32)
+
+        if isinstance(item, CookingEquipment) or item.item_info.type == ItemType.Tool:
+            assert (
+                item.name in self.vector_state_generation.equipments
+            ), f"unknown equipment {item}"
+            idx = self.vector_state_generation.equipments.index(item.name)
+            item_array[idx] = 1.0
+            if isinstance(item, CookingEquipment):
+                for s_idx, sub_item in enumerate(item.content_list):
+                    if s_idx > 3:
+                        print("Too much content in the content list, info dropped")
+                        break
+                    start_idx = len(self.vector_state_generation.equipments) + 21 + 2
+                    item_array[
+                        start_idx + (s_idx * (21)) : start_idx + ((s_idx + 1) * (21))
+                    ] = self.get_simple_vectorized_item(sub_item)
+
+        else:
+            item_array[
+                len(self.vector_state_generation.equipments) : len(
+                    self.vector_state_generation.equipments
+                )
+                + 21
+            ] = self.get_simple_vectorized_item(item)
+
+        item_array[
+            len(self.vector_state_generation.equipments) + 21 + 1
+        ] = item.progress_percentage
+
+        if item.active_effects:
+            item_array[
+                len(self.vector_state_generation.equipments) + 21 + 2
+            ] = 1.0  # TODO percentage of fire...
+
+        return item_array
+
+    def get_vectorized_state_full(
+        self, player_id: str
+    ) -> Tuple[
+        npt.NDArray[npt.NDArray[float]],
+        npt.NDArray[npt.NDArray[float]],
+        float,
+        npt.NDArray[float],
+    ]:
+        grid_array = self.vector_state_generation.grid_base_array.copy()
+        for counter in self.env.counters:
+            grid_idx = np.floor(counter.pos).astype(int)  # store in counter?
+            if counter.occupied_by:
+                if isinstance(counter.occupied_by, deque):
+                    ...
+                else:
+                    item = counter.occupied_by
+                    grid_array[
+                        grid_idx[0],
+                        grid_idx[1],
+                        4 + self.vector_state_generation.oh_len :,
+                    ] = self.get_vectorized_item(item)
+            if counter.active_effects:
+                grid_array[
+                    grid_idx[0],
+                    grid_idx[1],
+                    4 + self.vector_state_generation.oh_len - 1,
+                ] = 1.0  # TODO percentage of fire...
+
+        assert len(self.env.players) <= 4, "Too many players for vector representation"
+        player_vec = np.zeros(
+            (
+                4,
+                4 + 114,
+            ),
+            dtype=np.float32,
+        )
+        player_pos = 1
+        for player in self.env.players.values():
+            if player.name == player_id:
+                idx = 0
+                player_vec[0, :4] = np.array(
+                    [
+                        player.pos[0],
+                        player.pos[1],
+                        player.facing_point[0],
+                        player.facing_point[1],
+                    ],
+                    dtype=np.float32,
+                )
+            else:
+                idx = player_pos
+
+            if not idx:
+                player_pos += 1
+            grid_idx = np.floor(player.pos).astype(int)  # store in counter?
+            player_vec[idx, :4] = np.array(
+                [
+                    player.pos[0] - grid_idx[0],
+                    player.pos[1] - grid_idx[1],
+                    player.facing_point[0] / np.linalg.norm(player.facing_point),
+                    player.facing_point[1] / np.linalg.norm(player.facing_point),
+                ],
+                dtype=np.float32,
+            )
+            grid_array[grid_idx[0], grid_idx[1], idx] = 1.0
+
+            if player.holding:
+                player_vec[idx, 4:] = self.get_vectorized_item(player.holding)
+
+        order_array = np.zeros((10 * (8 + 1)), dtype=np.float32)
+
+        for i, order in enumerate(self.env.order_manager.open_orders):
+            if i > 9:
+                print("some orders are not represented in the vectorized state")
+                break
+            assert (
+                order.meal.name in self.vector_state_generation.meals
+            ), "unknown meal in order"
+            idx = self.vector_state_generation.meals.index(order.meal.name)
+            order_array[(i * 9) + idx] = 1.0
+            order_array[(i * 9) + 8] = (
+                self.env.env_time - order.start_time
+            ).total_seconds() / order.max_duration.total_seconds()
+
+        return (
+            grid_array,
+            player_vec,
+            (self.env.env_time - self.env.start_time).total_seconds()
+            / (self.env.env_time_end - self.env.start_time).total_seconds(),
+            order_array,
+        )
+
+    def step(self, action):
+        simple_action = self.action_space_map[action]
+        env_action = get_env_action(
+            self.player_id, simple_action, self.global_step_time
+        )
+        self.env.perform_action(env_action)
+
+        for i in range(self.in_between_steps):
+            self.env.step(
+                timedelta(seconds=self.global_step_time / self.in_between_steps)
+            )
+
+        observation = self.get_observation()
+
+        reward = self.env.score - self.prev_score
+        self.prev_score = self.env.score
+
+        terminated = self.env.game_ended
+        truncated = self.env.game_ended
+        info = {}
+
+        return observation, reward, terminated, truncated, info
+
+    def reset(self, seed=None, options=None):
+        self.env: Environment = Environment(
+            env_config=environment_config,
+            layout_config=layout,
+            item_info=item_info,
+            as_files=False,
+        )
+
+        if self.randomize_counter_placement:
+            shuffle_counters(self.env)
+
+        self.player_name = str(0)
+        self.env.add_player(self.player_name)
+        self.player_id = list(self.env.players.keys())[0]
+
+        if self.full_vector_state:
+            self.vector_state_generation = self.setup_vectorization()
+
+        info = {}
+        obs = self.get_observation()
+
+        self.prev_score = 0
+
+        return obs, info
+
+    def get_observation(self):
+        if self.use_rgb_obs:
+            obs = self.get_env_img(self.gridsize)
+        else:
+            obs = self.get_vector_state()
+        return obs
+
+    def render(self):
+        observation = self.get_env_img(self.gridsize)
+        img = observation.astype(np.uint8)
+        img = img.transpose((1, 2, 0))
+        img = cv2.resize(img, (img.shape[1], img.shape[0]))
+        return img
+
+    def close(self):
+        pass
+
+    def get_env_img(self, gridsize):
+        state = self.env.get_json_state(player_id=self.player_id)
+        json_dict = json.loads(state)
+        observation = self.visualizer.get_state_image(
+            grid_size=gridsize, state=json_dict
+        ).transpose((1, 0, 2))
+        return (observation.transpose((2, 0, 1))).astype(np.uint8)
+
+    def get_vector_state(self):
+        obs = self.get_vectorized_state_simple("0", self.use_onehot)
+        return obs
+
+    def sample_random_action(self):
+        act = self.action_space.sample()
+        return act
+
+
+def main():
+    rl_agent_checkpoints = Path("rl_agent_checkpoints")
+    rl_agent_checkpoints.mkdir(exist_ok=True)
+
+    config = {
+        "policy_type": "MlpPolicy",
+        "total_timesteps": 30_000_000,  # hendric sagt eher so 300_000_000 schritte
+        "env_id": "overcooked",
+        "number_envs_parallel": 4,
+    }
+
+    debug = False
+    do_training = True
+    vec_env = True
+    number_envs_parallel = config["number_envs_parallel"]
+
+    model_classes = [A2C, DQN, PPO]
+    model_class = model_classes[2]
+
+    if vec_env:
+        env = make_vec_env(EnvGymWrapper, n_envs=number_envs_parallel)
+    else:
+        env = EnvGymWrapper()
+
+    env.render_mode = "rgb_array"
+
+    if not debug:
+        run = wandb.init(
+            project="overcooked",
+            config=config,
+            sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
+            monitor_gym=True,
+            # save_code=True,  # optional
+        )
+
+        env = VecVideoRecorder(
+            env,
+            f"videos/{run.id}",
+            record_video_trigger=lambda x: x % 200_000 == 0,
+            video_length=300,
+        )
+
+    model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}"
+
+    if do_training:
+        model = model_class(
+            config["policy_type"],
+            env,
+            verbose=1,
+            tensorboard_log=f"runs/{0}",
+            device="cpu"
+            # n_steps=2048,
+            # n_epochs=10,
+        )
+        if debug:
+            model.learn(
+                total_timesteps=config["total_timesteps"],
+                log_interval=1,
+                progress_bar=True,
+            )
+        else:
+            checkpoint_callback = CheckpointCallback(
+                save_freq=50_000,
+                save_path="logs",
+                name_prefix="rl_model",
+                save_replay_buffer=True,
+                save_vecnormalize=True,
+            )
+            wandb_callback = WandbCallback(
+                model_save_path=f"models/{run.id}",
+                verbose=0,
+            )
+
+            callback = CallbackList([checkpoint_callback, wandb_callback])
+            model.learn(
+                total_timesteps=config["total_timesteps"],
+                callback=callback,
+                log_interval=1,
+                progress_bar=True,
+            )
+            run.finish()
+        model.save(model_save_path)
+
+        del model
+    print("LEARNING DONE.")
+
+    model = model_class.load(model_save_path)
+    env = EnvGymWrapper()
+
+    check_env(env)
+    obs, info = env.reset()
+    while True:
+        action, _states = model.predict(obs, deterministic=False)
+        obs, reward, terminated, truncated, info = env.step(int(action))
+        print(reward)
+        rgb_img = env.render()
+        cv2.imshow("env", rgb_img)
+        cv2.waitKey(0)
+        if terminated or truncated:
+            obs, info = env.reset()
+        time.sleep(1 / env.metadata["render_fps"])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/overcooked_simulator/game_content/item_info_rl.yaml b/overcooked_simulator/reinforcement_learning/item_info_rl.yaml
similarity index 100%
rename from overcooked_simulator/game_content/item_info_rl.yaml
rename to overcooked_simulator/reinforcement_learning/item_info_rl.yaml
diff --git a/overcooked_simulator/reinforcement_learning/rl.layout b/overcooked_simulator/reinforcement_learning/rl.layout
new file mode 100644
index 0000000000000000000000000000000000000000..e1e8c0757badf2379f3a29728f99f0088ea76bca
--- /dev/null
+++ b/overcooked_simulator/reinforcement_learning/rl.layout
@@ -0,0 +1,5 @@
+##X##
+T___P
+#___#
+U___#
+#C#W#
diff --git a/overcooked_simulator/reinforcement_learning/rl_small.layout b/overcooked_simulator/reinforcement_learning/rl_small.layout
new file mode 100644
index 0000000000000000000000000000000000000000..a9eda0c9e9680e3c781c2f54579c0c6309a6bc47
--- /dev/null
+++ b/overcooked_simulator/reinforcement_learning/rl_small.layout
@@ -0,0 +1,4 @@
+#X##
+T__P
+U__#
+#CW#
diff --git a/overcooked_simulator/rl_agent_checkpoints/overcooked_PPO.zip b/overcooked_simulator/rl_agent_checkpoints/overcooked_PPO.zip
deleted file mode 100644
index b79af5605fa8773501888d028bca62b6845cb2d4..0000000000000000000000000000000000000000
Binary files a/overcooked_simulator/rl_agent_checkpoints/overcooked_PPO.zip and /dev/null differ