# tryout_pearl.py
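# Trains a Pearl DQN agent on the cooperative_cuisine EnvGymWrapper (or on
# LunarLander-v2 when custom is False), then replays episodes while rendering
# the custom environment with OpenCV.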
import cv2
from pearl.action_representation_modules.one_hot_action_representation_module import (
    OneHotActionTensorRepresentationModule,
)
from pearl.pearl_agent import PearlAgent
from pearl.policy_learners.sequential_decision_making.deep_q_learning import (
    DeepQLearning,
)
from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import (
    FIFOOffPolicyReplayBuffer,
)
from pearl.utils.instantiations.environments.gym_environment import GymEnvironment
from cooperative_cuisine.reinforcement_learning import EnvGymWrapper
# Toggle between the cooperative_cuisine environment and the LunarLander benchmark.
custom = True
if custom:
    env = GymEnvironment(EnvGymWrapper())
else:
    env = GymEnvironment("LunarLander-v2", render_mode="rgb_array")
num_actions = env.action_space.n

# DQN agent: two-layer MLP Q-network, one-hot action encoding, and a FIFO
# replay buffer holding the most recent 10,000 transitions.
agent = PearlAgent(
    policy_learner=DeepQLearning(
        state_dim=env.observation_space.shape[0],
        action_space=env.action_space,
        hidden_dims=[64, 64],
        training_rounds=20,
        action_representation_module=OneHotActionTensorRepresentationModule(
            max_number_actions=num_actions
        ),
    ),
    replay_buffer=FIFOOffPolicyReplayBuffer(10_000),
)
# Training: 40 episodes with exploration, learning after every environment step.
for i in range(40):
    print(i)
    observation, action_space = env.reset()
    agent.reset(observation, action_space)
    done = False
    while not done:
        action = agent.act(exploit=False)
        action_result = env.step(action)
        agent.observe(action_result)
        agent.learn()
        done = action_result.done
# Recreate the environment for a rendered second run.
if custom:
    env = GymEnvironment(EnvGymWrapper())
else:
    env = GymEnvironment("LunarLander-v2", render_mode="human")
# Second run: same loop as training, but additionally render each step of the
# custom environment (the wrapper yields an RGB array; flip to BGR for OpenCV).
for i in range(40):
    print(i)
    observation, action_space = env.reset()
    agent.reset(observation, action_space)
    done = False
    while not done:
        action = agent.act(exploit=False)
        action_result = env.step(action)
        agent.observe(action_result)
        agent.learn()
        done = action_result.done
        if custom:
            img = env.env.render()
            cv2.imshow("image", img[:, :, ::-1])
            cv2.waitKey(1)
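# Optional greedy-evaluation sketch (not part of the original script; it assumes
# the agent trained above): exploit=True makes Pearl act greedily with respect
# to the learned Q-network instead of exploring, and the OpenCV window is
# closed afterwards.
observation, action_space = env.reset()
agent.reset(observation, action_space)
done = False
while not done:
    action = agent.act(exploit=True)  # greedy action, no exploration
    action_result = env.step(action)
    agent.observe(action_result)  # keeps the agent's internal state in sync
    done = action_result.done
    if custom:
        img = env.env.render()
        cv2.imshow("image", img[:, :, ::-1])
        cv2.waitKey(1)
cv2.destroyAllWindows()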