Skip to content
Snippets Groups Projects
Commit b4dd71bf authored by Fabian Heinrich's avatar Fabian Heinrich
Browse files

Added experimentation script for meta pearl in rl

parent 2f69d600
No related branches found
No related tags found
No related merge requests found
Pipeline #49265 failed
......@@ -129,12 +129,12 @@ class EnvGymWrapper(Env):
def __init__(self):
super().__init__()
self.gridsize = 30
self.gridsize = 40
self.randomize_counter_placement = True
self.randomize_counter_placement = False
self.use_rgb_obs = False # if False uses simple vectorized state
self.full_vector_state = True
self.onehot_state = False
self.onehot_state = True
self.env: Environment = Environment(
env_config=environment_config,
......
import cv2
from pearl.action_representation_modules.one_hot_action_representation_module import (
OneHotActionTensorRepresentationModule,
)
from pearl.pearl_agent import PearlAgent
from pearl.policy_learners.sequential_decision_making.deep_q_learning import (
DeepQLearning,
)
from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import (
FIFOOffPolicyReplayBuffer,
)
from pearl.utils.instantiations.environments.gym_environment import GymEnvironment
from cooperative_cuisine.reinforcement_learning.gym_env import EnvGymWrapper
custom = True
if custom:
env = GymEnvironment(EnvGymWrapper())
else:
env = GymEnvironment("LunarLander-v2", render_mode="rgb_array")
num_actions = env.action_space.n
agent = PearlAgent(
policy_learner=DeepQLearning(
state_dim=env.observation_space.shape[0],
action_space=env.action_space,
hidden_dims=[64, 64],
training_rounds=20,
action_representation_module=OneHotActionTensorRepresentationModule(
max_number_actions=num_actions
),
),
replay_buffer=FIFOOffPolicyReplayBuffer(10_000),
)
for i in range(40):
print(i)
observation, action_space = env.reset()
agent.reset(observation, action_space)
done = False
while not done:
action = agent.act(exploit=False)
action_result = env.step(action)
agent.observe(action_result)
agent.learn()
done = action_result.done
if custom:
env = GymEnvironment(EnvGymWrapper())
else:
env = GymEnvironment("LunarLander-v2", render_mode="human")
for i in range(40):
print(i)
observation, action_space = env.reset()
agent.reset(observation, action_space)
done = False
while not done:
action = agent.act(exploit=False)
action_result = env.step(action)
agent.observe(action_result)
agent.learn()
done = action_result.done
if custom:
img = env.env.render()
cv2.imshow("image", img[:, :, ::-1])
cv2.waitKey(1)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment