from abc import abstractmethod
import json
import random
import uuid
from copy import deepcopy
from datetime import timedelta
from enum import Enum
from pathlib import Path
from uuid import uuid4
import cv2
import numpy as np
import yaml
from gymnasium import spaces, Env

from hydra.utils import instantiate

from omegaconf import OmegaConf
from cooperative_cuisine import ROOT_DIR
from cooperative_cuisine.action import ActionType, InterActionData, Action

from cooperative_cuisine.counters import Counter
from cooperative_cuisine.environment import Environment
from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer, CacheFlags
"""Enumeration of actions for simple action spaces for an RL agent."""
Up = "Up"
Down = "Down"
Left = "Left"
Right = "Right"
Interact = "Interact"
Put = "Put"
def get_env_action(player_id, simple_action, duration):
    """Maps a simple action to a concrete environment action.

    Args:
        player_id: id of the player
        simple_action: an action in the form of a SimpleActionSpace
        duration: for how long an action should be conducted

    Returns: a concrete action
    """
match simple_action:
case SimpleActionSpace.Up:
return Action(
player_id,
ActionType.MOVEMENT,
np.array([0, -1]),
duration,
)
case SimpleActionSpace.Left:
return Action(
player_id,
ActionType.MOVEMENT,
np.array([-1, 0]),
duration,
)
case SimpleActionSpace.Down:
return Action(
player_id,
ActionType.MOVEMENT,
np.array([0, 1]),
duration,
)
case SimpleActionSpace.Right:
return Action(
player_id,
ActionType.MOVEMENT,
np.array([1, 0]),
duration,
)
        case SimpleActionSpace.Put:
            return Action(
                player_id,
                ActionType.PUT,
                InterActionData.START,
                duration,
            )
case SimpleActionSpace.Interact:
return Action(
player_id,
ActionType.INTERACT,
InterActionData.START,
duration,
)
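
# Usage sketch (hypothetical values): a one-second upward step for player "0".
# get_env_action("0", SimpleActionSpace.Up, 1)
# -> Action("0", ActionType.MOVEMENT, np.array([0, -1]), 1)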
with open(ROOT_DIR / "pygame_2d_vis" / "visualization.yaml", "r") as file:
    visualization_config = yaml.safe_load(file)
visualizer: Visualizer = Visualizer(config=visualization_config)
visualizer.create_player_colors(1)
visualizer.set_grid_size(40)

def shuffle_counters(env):
    """Shuffles the counters of an environment.

    Args:
        env: the environment object
    """
    sample_counter = []
    other_counters = []
for counter in env.counters:
if counter.__class__ != Counter:
sample_counter.append(counter)
else:
other_counters.append(counter)
new_counter_pos = [c.pos for c in sample_counter]
random.shuffle(new_counter_pos)
for counter, new_pos in zip(sample_counter, new_counter_pos):
counter.pos = new_pos
sample_counter.extend(other_counters)
env.overwrite_counters(sample_counter)
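
# Reproducibility note: shuffle_counters uses the global `random` module, so seeding
# it beforehand fixes the sampled layout (assuming no other consumer of `random`
# runs in between):
# random.seed(42)
# shuffle_counters(env)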
class StateToObservationConverter:
"""
Abstract definition of a class that gets and environment and outputs a state representation for rl
"""
@abstractmethod
def setup(self, env, item_info):
...
@abstractmethod
def convert_state_to_observation(self, state) -> np.ndarray:
...
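
# Minimal sketch of a concrete converter (hypothetical, not shipped with the package):
# it encodes only the first player's position and assumes player objects expose a
# numeric `pos` attribute. Real converters would also encode counters, items and orders.
class PlayerPositionConverter(StateToObservationConverter):
    """Toy converter that returns the first player's position as the observation."""

    def setup(self, env, item_info):
        # Nothing to precompute for this toy example.
        pass

    def convert_state_to_observation(self, state) -> np.ndarray:
        # `state` is the environment object, as passed in get_vector_state below.
        player = next(iter(state.players.values()))
        return np.array(player.pos, dtype=int)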
class EnvGymWrapper(Env):
"""Should enable this:
observation, reward, terminated, truncated, info = env.step(action)
"""
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 10}
def __init__(self, config):
"""
Initializes all necessary variables

Christoph Kowalski
committed
Args:
config:gets the rl and environment configuration from hydra
self.randomize_counter_placement = False

Christoph Kowalski
committed
self.use_rgb_obs = False # if False uses simple vectorized state
config_env = OmegaConf.to_container(config.environment, resolve=True)
config_item_info = OmegaConf.to_container(config.item_info, resolve=True)
        for val in config_env["hook_callbacks"]:
            config_env["hook_callbacks"][val]["callback_class"] = instantiate(
                config_env["hook_callbacks"][val]["callback_class"]
            )
        config_env["orders"]["order_gen_class"] = instantiate(
            config_env["orders"]["order_generator"]
        )
self.config_env = config_env
self.config_item_info = config_item_info
layout_file = config_env["layout_name"]

layout_path: Path = ROOT_DIR / layout_file
with open(layout_path, "r") as file:
self.layout = file.read()
        self.env: Environment = Environment(
            env_config=deepcopy(config_env),
layout_config=self.layout,
item_info=deepcopy(config_item_info),
env_name=uuid.uuid4().hex,
)
if self.randomize_counter_placement:
shuffle_counters(self.env)
self.player_name = str(0)
self.env.add_player(self.player_name)
self.player_id = list(self.env.players.keys())[0]
self.visualizer = visualizer
        self.action_space_map = {
            idx: action for idx, action in enumerate(SimpleActionSpace)
        }
self.global_step_time = 1
self.in_between_steps = 1
self.action_space = spaces.Discrete(len(self.action_space_map))
self.converter = instantiate(config.additional_configs.state_converter)
        # The converter also gets the item info config so it knows all possible items.
        self.converter.setup(self.env, self.config_item_info)
if hasattr(self.converter, "onehot"):

self.onehot_state = self.converter.onehot
else:
self.onehot_state = 'onehot' in config.additional_configs.state_converter.lower()
min_obs_val = -1 if not self.use_rgb_obs else 0
max_obs_val = 255 if self.use_rgb_obs else 1 if self.onehot_state else 20
        dummy_obs = self.get_observation()
        self.observation_space = spaces.Box(
low=min_obs_val,
high=max_obs_val,
shape=dummy_obs.shape,
dtype=np.uint8 if self.use_rgb_obs else int,
)
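        # The bounds above cover three regimes: RGB images in [0, 255], one-hot
        # vector states in [0, 1], and plain vector states in [-1, 20].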
self.last_obs = dummy_obs
self.step_counter = 0
self.prev_score = 0
def step(self, action):
"""
takes one step in the environment and returns the observation, reward, info whether terminated, truncated
and additional information
"""

        # Work-around to enable a no-op action, which play_gym.py needs.
if action == 8:
observation = self.get_observation()
reward = self.env.score - self.prev_score
terminated = self.env.game_ended
truncated = self.env.game_ended
info = {}
return observation, reward, terminated, truncated, info
simple_action = self.action_space_map[action]
env_action = get_env_action(
self.player_id, simple_action, self.global_step_time
)
self.env.perform_action(env_action)
for i in range(self.in_between_steps):
self.env.step(
timedelta(seconds=self.global_step_time / self.in_between_steps)
)
observation = self.get_observation()
reward = self.env.score - self.prev_score
self.prev_score = self.env.score
if reward > 0.6:
print("- - - - - - - - - - - - - - - - SCORED", reward)
terminated = self.env.game_ended
truncated = self.env.game_ended
info = {}
return observation, reward, terminated, truncated, info
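    # Note: the reward is the score gained since the last step (the env.score
    # delta), not the absolute score.
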
def reset(self, seed=None, options=None):
"""
Resets the environment according to the configs
"""

if self.env.env_name in visualizer.surface_cache_dict:
del visualizer.surface_cache_dict[self.env.env_name]

        self.env: Environment = Environment(
            env_config=deepcopy(self.config_env),
layout_config=self.layout,
item_info=deepcopy(self.config_item_info),
env_name=uuid.uuid4().hex,
)
if self.randomize_counter_placement:
shuffle_counters(self.env)
self.player_name = str(0)
self.env.add_player(self.player_name)
self.player_id = list(self.env.players.keys())[0]
info = {}
obs = self.get_observation()
self.prev_score = 0
return obs, info
    def get_observation(self):
        """Returns the current observation, either as an RGB image or a vector state."""
if self.use_rgb_obs:
obs = self.get_env_img()
else:
obs = self.get_vector_state()
return obs
    def render(self):
        """Renders the current environment state as an RGB image."""
        return self.get_env_img()
    def get_env_img(self):
        """Returns the current environment state rendered as an RGB image array."""
        state = self.env.get_json_state(player_id=self.player_id)
        json_dict = json.loads(state)
        observation = self.visualizer.get_state_image(
            state=json_dict, env_id_ref=self.env.env_name
        ).astype(np.uint8)
        return observation

    def get_vector_state(self):
        """Returns the vectorized state produced by the configured converter."""
        obs = self.converter.convert_state_to_observation(self.env)
        return obs
    def sample_random_action(self):
        """Samples a random action from the action space."""
act = self.action_space.sample()
return act
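

# Usage sketch (hypothetical): drive the wrapper with random actions. The config
# directory "config" and config name "rl_config" are placeholders; any hydra config
# providing the fields accessed in __init__ (environment, item_info,
# additional_configs) would work.
if __name__ == "__main__":
    from hydra import compose, initialize

    with initialize(version_base=None, config_path="config"):
        cfg = compose(config_name="rl_config")

    gym_env = EnvGymWrapper(cfg)
    obs, info = gym_env.reset()
    for _ in range(100):
        obs, reward, terminated, truncated, info = gym_env.step(
            gym_env.sample_random_action()
        )
        if terminated or truncated:
            break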