diff --git a/cooperative_cuisine/reinforcement_learning/gym_env.py b/cooperative_cuisine/reinforcement_learning/gym_env.py
index 4bb079e3b45372e8d208df59b49da035b419a67a..db6357bac6a62b1c73c731a0c25dcd55f5dac13b 100644
--- a/cooperative_cuisine/reinforcement_learning/gym_env.py
+++ b/cooperative_cuisine/reinforcement_learning/gym_env.py
@@ -129,12 +129,12 @@ class EnvGymWrapper(Env):
     def __init__(self):
         super().__init__()
-        self.gridsize = 30
+        self.gridsize = 40
 
-        self.randomize_counter_placement = True
+        self.randomize_counter_placement = False
 
         self.use_rgb_obs = False  # if False uses simple vectorized state
         self.full_vector_state = True
-        self.onehot_state = False
+        self.onehot_state = True
 
         self.env: Environment = Environment(
             env_config=environment_config,
diff --git a/cooperative_cuisine/reinforcement_learning/pearl_test.py b/cooperative_cuisine/reinforcement_learning/pearl_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4cded0bfed470c2bcf60ea7e306ba21de4053ab
--- /dev/null
+++ b/cooperative_cuisine/reinforcement_learning/pearl_test.py
@@ -0,0 +1,68 @@
+import cv2
+from pearl.action_representation_modules.one_hot_action_representation_module import (
+    OneHotActionTensorRepresentationModule,
+)
+from pearl.pearl_agent import PearlAgent
+from pearl.policy_learners.sequential_decision_making.deep_q_learning import (
+    DeepQLearning,
+)
+from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import (
+    FIFOOffPolicyReplayBuffer,
+)
+from pearl.utils.instantiations.environments.gym_environment import GymEnvironment
+
+from cooperative_cuisine.reinforcement_learning.gym_env import EnvGymWrapper
+
+custom = True
+if custom:
+    env = GymEnvironment(EnvGymWrapper())
+else:
+    env = GymEnvironment("LunarLander-v2", render_mode="rgb_array")
+
+num_actions = env.action_space.n
+agent = PearlAgent(
+    policy_learner=DeepQLearning(
+        state_dim=env.observation_space.shape[0],
+        action_space=env.action_space,
+        hidden_dims=[64, 64],
+        training_rounds=20,
+        action_representation_module=OneHotActionTensorRepresentationModule(
+            max_number_actions=num_actions
+        ),
+    ),
+    replay_buffer=FIFOOffPolicyReplayBuffer(10_000),
+)
+
+for i in range(40):
+    print(i)
+    observation, action_space = env.reset()
+    agent.reset(observation, action_space)
+    done = False
+    while not done:
+        action = agent.act(exploit=False)
+        action_result = env.step(action)
+        agent.observe(action_result)
+        agent.learn()
+        done = action_result.done
+
+if custom:
+    env = GymEnvironment(EnvGymWrapper())
+else:
+    env = GymEnvironment("LunarLander-v2", render_mode="human")
+
+for i in range(40):
+    print(i)
+    observation, action_space = env.reset()
+    agent.reset(observation, action_space)
+    done = False
+    while not done:
+        action = agent.act(exploit=False)
+        action_result = env.step(action)
+        agent.observe(action_result)
+        agent.learn()
+        done = action_result.done
+
+        if custom:
+            img = env.env.render()
+            cv2.imshow("image", img[:, :, ::-1])
+            cv2.waitKey(1)