Skip to content
Snippets Groups Projects
Commit d5de0850 authored by fheinrich's avatar fheinrich
Browse files

RL agent has learned something

parent 343242fd
No related branches found
No related tags found
1 merge request!52Resolve "gym env"
Pipeline #45792 passed
...@@ -6,7 +6,7 @@ plates: ...@@ -6,7 +6,7 @@ plates:
# range of seconds until the dirty plate arrives. # range of seconds until the dirty plate arrives.
game: game:
time_limit_seconds: 300 time_limit_seconds: 400
meals: meals:
all: true all: true
...@@ -67,18 +67,8 @@ orders: ...@@ -67,18 +67,8 @@ orders:
b: 20 b: 20
sample_on_serving: false sample_on_serving: false
# Sample the delay for the next order only after a meal was served. # Sample the delay for the next order only after a meal was served.
score_calc_gen_func: !!python/name:overcooked_simulator.order.simple_score_calc_gen_func '' serving_not_ordered_meals: true
score_calc_gen_kwargs: # can meals that are not ordered be served / dropped on the serving window
# the kwargs for the score_calc_gen_func
other: 0
scores: [ ]
expired_penalty_func: !!python/name:overcooked_simulator.order.simple_expired_penalty ''
expired_penalty_kwargs:
default: 0
serving_not_ordered_meals: !!python/name:overcooked_simulator.order.serving_not_ordered_meals_with_five_score ''
# a func that calcs a store for not ordered but served meals. Input: meal
penalty_for_trash: !!python/name:overcooked_simulator.order.penalty_for_each_item ''
# a func that calcs the penalty for items that the player puts into the trashcan.
player_config: player_config:
radius: 0.4 radius: 0.4
...@@ -103,20 +93,15 @@ extra_setup_functions: ...@@ -103,20 +93,15 @@ extra_setup_functions:
hooks: [ completed_order ] hooks: [ completed_order ]
callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks ''
callback_class_kwargs: callback_class_kwargs:
static_score: 20 static_score: 100
score_on_specific_kwarg: meal_name
score_map: serve_not_ordered_meals:
Burger: 15
OnionSoup: 10
Salad: 5
TomatoSoup: 10
not_ordered_meals:
func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class ''
kwargs: kwargs:
hooks: [ serve_not_ordered_meal ] hooks: [ serve_not_ordered_meal ]
callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks ''
callback_class_kwargs: callback_class_kwargs:
static_score: 2 static_score: 100
trashcan_usages: trashcan_usages:
func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class ''
kwargs: kwargs:
...@@ -124,13 +109,21 @@ extra_setup_functions: ...@@ -124,13 +109,21 @@ extra_setup_functions:
callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks ''
callback_class_kwargs: callback_class_kwargs:
static_score: -5 static_score: -5
expired_orders: item_cut:
func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class ''
kwargs:
hooks: [ cutting_board_100 ]
callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks ''
callback_class_kwargs:
static_score: 10
stepped:
func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class ''
kwargs: kwargs:
hooks: [ order_expired ] hooks: [ post_step ]
callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks ''
callback_class_kwargs: callback_class_kwargs:
static_score: -10 static_score: -1
# json_states: # json_states:
# func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' # func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks ''
# kwargs: # kwargs:
......
#X### #X##
T___# T__#
#___# U__P
U___P ##W#
##W##
...@@ -196,17 +196,8 @@ class EnvGymWrapper(Env): ...@@ -196,17 +196,8 @@ class EnvGymWrapper(Env):
observation = self.get_vector_state() observation = self.get_vector_state()
reward = -1 reward = self.env.score - self.prev_score
if ( self.prev_score = self.env.score
self.env.score > self.prev_score
and self.env.score != 0
):
self.prev_score = self.env.score
reward = 100
elif self.env.score < self.prev_score:
self.prev_score = 0
reward = -1
terminated = self.env.game_ended terminated = self.env.game_ended
truncated = self.env.game_ended truncated = self.env.game_ended
info = {} info = {}
...@@ -283,7 +274,7 @@ def main(): ...@@ -283,7 +274,7 @@ def main():
# # save_code=True, # optional # # save_code=True, # optional
# ) # )
env = make_vec_env(EnvGymWrapper, n_envs=64) env = make_vec_env(EnvGymWrapper, n_envs=8)
# env = EnvGymWrapper() # env = EnvGymWrapper()
model_classes = [A2C, DQN, PPO] model_classes = [A2C, DQN, PPO]
...@@ -317,7 +308,7 @@ def main(): ...@@ -317,7 +308,7 @@ def main():
check_env(env) check_env(env)
obs, info = env.reset() obs, info = env.reset()
while True: while True:
time.sleep(1 / 30) time.sleep(1 / 10)
action, _states = model.predict(obs, deterministic=False) action, _states = model.predict(obs, deterministic=False)
obs, reward, terminated, truncated, info = env.step(int(action)) obs, reward, terminated, truncated, info = env.step(int(action))
env.render() env.render()
......
...@@ -48,7 +48,7 @@ from overcooked_simulator.hooks import ( ...@@ -48,7 +48,7 @@ from overcooked_simulator.hooks import (
ACTION_ON_NOT_REACHABLE_COUNTER, ACTION_ON_NOT_REACHABLE_COUNTER,
ACTION_PUT, ACTION_PUT,
ACTION_INTERACT_START, ACTION_INTERACT_START,
ITEM_INFO_CONFIG, ITEM_INFO_CONFIG, POST_STEP,
) )
from overcooked_simulator.order import ( from overcooked_simulator.order import (
OrderManager, OrderManager,
...@@ -756,7 +756,7 @@ class Environment: ...@@ -756,7 +756,7 @@ class Environment:
self.order_manager.progress(passed_time=passed_time, now=self.env_time) self.order_manager.progress(passed_time=passed_time, now=self.env_time)
for effect_manager in self.effect_manager.values(): for effect_manager in self.effect_manager.values():
effect_manager.progress(passed_time=passed_time, now=self.env_time) effect_manager.progress(passed_time=passed_time, now=self.env_time)
# self.hook(POST_STEP, passed_time=passed_time) self.hook(POST_STEP, passed_time=passed_time)
def get_state(self): def get_state(self):
"""Get the current state of the game environment. The state here is accessible by the current python objects. """Get the current state of the game environment. The state here is accessible by the current python objects.
......
File added
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment