# tryout_pearl.py: train a Pearl deep-Q-learning agent, then run it again
# while rendering the environment.
import cv2
from pearl.action_representation_modules.one_hot_action_representation_module import (
    OneHotActionTensorRepresentationModule,
)
from pearl.pearl_agent import PearlAgent
from pearl.policy_learners.sequential_decision_making.deep_q_learning import (
    DeepQLearning,
)
from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import (
    FIFOOffPolicyReplayBuffer,
)
from pearl.utils.instantiations.environments.gym_environment import GymEnvironment

from cooperative_cuisine.reinforcement_learning import EnvGymWrapper

# Choose the environment: the project's EnvGymWrapper when ``custom`` is set,
# otherwise the stock Gym "LunarLander-v2" task rendered off-screen.
custom = True
env = (
    GymEnvironment(EnvGymWrapper())
    if custom
    else GymEnvironment("LunarLander-v2", render_mode="rgb_array")
)

num_actions = env.action_space.n


def _build_dqn_learner(environment, n_actions):
    """Return a DeepQLearning policy learner configured for ``environment``.

    The state dimension is the first entry of the environment's observation
    shape. The network has two hidden layers of 64 units, with
    ``training_rounds=20``. Actions are one-hot encoded over ``n_actions``
    discrete actions.
    """
    one_hot_actions = OneHotActionTensorRepresentationModule(
        max_number_actions=n_actions
    ) if False else None  # placeholder never used; see below
    del one_hot_actions
    return DeepQLearning(
        state_dim=environment.observation_space.shape[0],
        action_space=environment.action_space,
        hidden_dims=[64, 64],
        training_rounds=20,
        action_representation_module=OneHotActionTensorRepresentationModule(
            max_number_actions=n_actions
        ),
    )


# The replay buffer is constructed with 10_000 (presumably its capacity —
# confirm against Pearl's FIFOOffPolicyReplayBuffer signature).
agent = PearlAgent(
    policy_learner=_build_dqn_learner(env, num_actions),
    replay_buffer=FIFOOffPolicyReplayBuffer(10_000),
)

NUM_TRAIN_EPISODES = 40  # episodes in the initial training phase
NUM_VIEW_EPISODES = 40  # episodes in the rendered (viewing) phase


def _run_episodes(
    environment,
    pearl_agent,
    num_episodes,
    *,
    show_frames=False,
    exploit=False,
    learn=True,
):
    """Run ``num_episodes`` complete episodes of ``pearl_agent`` on ``environment``.

    The episode index is printed at the start of every episode. On each step
    the agent acts, then observes the result of ``environment.step``, and
    finally (if ``learn`` is true) calls ``learn()``. An episode ends when the
    step result reports ``done``.

    Args:
        environment: Pearl ``GymEnvironment`` to interact with.
        pearl_agent: ``PearlAgent`` to drive.
        num_episodes: Number of episodes to run.
        show_frames: If true, after every step call
            ``environment.env.render()`` and show the frame in an OpenCV
            window named "image". This assumes ``render()`` returns an RGB
            image array (TODO confirm for EnvGymWrapper).
        exploit: Passed through to ``pearl_agent.act``.
        learn: Whether to call ``pearl_agent.learn()`` after each step.
    """
    for episode in range(num_episodes):
        print(episode)
        observation, action_space = environment.reset()
        pearl_agent.reset(observation, action_space)
        done = False
        while not done:
            action = pearl_agent.act(exploit=exploit)
            action_result = environment.step(action)
            pearl_agent.observe(action_result)
            if learn:
                pearl_agent.learn()
            done = action_result.done

            if show_frames:
                frame = environment.env.render()
                # cv2.imshow expects BGR, so reverse the (assumed RGB) channel axis.
                cv2.imshow("image", frame[:, :, ::-1])
                cv2.waitKey(1)


# Phase 1: train without rendering.
_run_episodes(env, agent, NUM_TRAIN_EPISODES)

# Release the training environment before replacing it. This assumes the
# wrapped env exposes gym's close(); gymnasium.Env provides a default.
env.env.close()

# Phase 2: run again while showing the environment. LunarLander draws itself
# in "human" mode; the custom env is drawn manually through OpenCV.
if custom:
    env = GymEnvironment(EnvGymWrapper())
else:
    env = GymEnvironment("LunarLander-v2", render_mode="human")

# NOTE(review): this phase still explores (exploit=False) and keeps calling
# learn(). If the goal is to watch the learned policy, presumably pass
# exploit=True, learn=False instead.
_run_episodes(env, agent, NUM_VIEW_EPISODES, show_frames=custom)

env.env.close()
if custom:
    # Only the custom branch opened an OpenCV window.
    cv2.destroyAllWindows()