Skip to content
Snippets Groups Projects
Commit 7d86dd70 authored by Christoph Kowalski's avatar Christoph Kowalski
Browse files

Adapted docstring

parent 3a5fa88e
No related branches found
No related tags found
2 merge requests!110V1.2.0 changes,!109SB3 RL with Hydra
Pipeline #64011 passed
......@@ -36,7 +36,6 @@ class SimpleActionSpace(Enum):
def get_env_action(player_id, simple_action, duration):
"""
Args:
......@@ -103,7 +102,6 @@ visualizer.set_grid_size(40)
def shuffle_counters(env):
"""
Shuffles the counters of an environment
Args:
......@@ -160,7 +158,8 @@ class EnvGymWrapper(Env):
config_env = OmegaConf.to_container(config.environment, resolve=True)
config_item_info = OmegaConf.to_container(config.item_info, resolve=True)
for val in config_env['hook_callbacks']:
config_env['hook_callbacks'][val]["callback_class"] = instantiate(config_env['hook_callbacks'][val]["callback_class"])
config_env['hook_callbacks'][val]["callback_class"] = instantiate(
config_env['hook_callbacks'][val]["callback_class"])
config_env["orders"]["order_gen_class"] = instantiate(config_env["orders"]["order_generator"])
self.config_env = config_env
self.config_item_info = config_item_info
......@@ -223,6 +222,7 @@ class EnvGymWrapper(Env):
and additional information
"""
# this is simply a work-around to enable no action which is necessary for the play_gym.py
# but not for the rl agent
if action == 8:
observation = self.get_observation()
reward = self.env.score - self.prev_score
......@@ -240,19 +240,14 @@ class EnvGymWrapper(Env):
self.env.step(
timedelta(seconds=self.global_step_time / self.in_between_steps)
)
observation = self.get_observation()
reward = self.env.score - self.prev_score
self.prev_score = self.env.score
if reward > 0:
print("- - - - - - - - - - - - - - - - SCORED", reward)
terminated = self.env.game_ended
truncated = self.env.game_ended
info = {}
return observation, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
......@@ -272,16 +267,12 @@ class EnvGymWrapper(Env):
if self.randomize_counter_placement:
shuffle_counters(self.env)
self.player_name = str(0)
self.env.add_player(self.player_name)
self.player_id = list(self.env.players.keys())[0]
info = {}
obs = self.get_observation()
self.prev_score = 0
return obs, info
def get_observation(self):
......
......@@ -52,7 +52,7 @@ class AdvancedStateConverterArray(StateToObservationConverter):
setup is chosen.
Returns: An encoding for the environment state that is not onehot
Returns: An encoding for the environment state
"""
if player is not None:
......
......@@ -27,7 +27,6 @@ def main(cfg: DictConfig):
config: dict[str, Any] = OmegaConf.to_container(cfg.model, resolve=True)
env_info: dict[str, Any] = OmegaConf.to_container(cfg.environment, resolve=True)
debug: bool = additional_configs["debug_mode"]
vec_env = additional_configs["vec_env"]
number_envs_parallel = config["number_envs_parallel"]
model_class = instantiate(cfg.model.model_type)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment