diff --git a/cooperative_cuisine/configs/layouts/overcooked-ai/1-cramped-room.layout b/cooperative_cuisine/configs/layouts/overcooked-ai/1-cramped-room.layout
deleted file mode 100644
index 49655d1ce3c4217a1fd2385cd5b33918fd395137..0000000000000000000000000000000000000000
--- a/cooperative_cuisine/configs/layouts/overcooked-ai/1-cramped-room.layout
+++ /dev/null
@@ -1,4 +0,0 @@
-##U##
-NA_AN
-#___#
-#P#$#
\ No newline at end of file
diff --git a/cooperative_cuisine/configs/layouts/overcooked-ai/2-asymmetric-advantages.layout b/cooperative_cuisine/configs/layouts/overcooked-ai/2-asymmetric-advantages.layout
deleted file mode 100644
index 3b87f3e58caaa5a5c59f732889986704ce2dce8f..0000000000000000000000000000000000000000
--- a/cooperative_cuisine/configs/layouts/overcooked-ai/2-asymmetric-advantages.layout
+++ /dev/null
@@ -1,5 +0,0 @@
-#########
-N_#$#N#_$
-#_A_U_A_#
-#___U___#
-###P#P###
\ No newline at end of file
diff --git a/cooperative_cuisine/configs/layouts/overcooked-ai/3-coordination-ring.layout b/cooperative_cuisine/configs/layouts/overcooked-ai/3-coordination-ring.layout
deleted file mode 100644
index f83c2adf9045825e86de62ae3626e3325cbb0a7f..0000000000000000000000000000000000000000
--- a/cooperative_cuisine/configs/layouts/overcooked-ai/3-coordination-ring.layout
+++ /dev/null
@@ -1,5 +0,0 @@
-###U#
-#__AU
-P_#_#
-NA__#
-#N$##
diff --git a/cooperative_cuisine/configs/layouts/overcooked-ai/4-forced-coordination.layout b/cooperative_cuisine/configs/layouts/overcooked-ai/4-forced-coordination.layout
deleted file mode 100644
index bc835ae37db299eb0cf7c61f7c5de85478becb99..0000000000000000000000000000000000000000
--- a/cooperative_cuisine/configs/layouts/overcooked-ai/4-forced-coordination.layout
+++ /dev/null
@@ -1,5 +0,0 @@
-###U#
-N_#AU
-N_#_#
-PA#_#
-###$#
diff --git a/cooperative_cuisine/configs/layouts/overcooked-ai/5-counter-circuit.layout b/cooperative_cuisine/configs/layouts/overcooked-ai/5-counter-circuit.layout
deleted file mode 100644
index 6c5839500ace8c8314b87b17dde31c257c994243..0000000000000000000000000000000000000000
--- a/cooperative_cuisine/configs/layouts/overcooked-ai/5-counter-circuit.layout
+++ /dev/null
@@ -1,5 +0,0 @@
-###UU###
-#A_____#
-P_####_$
-#_____A#
-###NN###
\ No newline at end of file
diff --git a/cooperative_cuisine/reinforcement_learning/rl.layout b/cooperative_cuisine/configs/layouts/rl/rl.layout
similarity index 100%
rename from cooperative_cuisine/reinforcement_learning/rl.layout
rename to cooperative_cuisine/configs/layouts/rl/rl.layout
diff --git a/cooperative_cuisine/configs/layouts/rl_small.layout b/cooperative_cuisine/configs/layouts/rl/rl_small.layout
similarity index 100%
rename from cooperative_cuisine/configs/layouts/rl_small.layout
rename to cooperative_cuisine/configs/layouts/rl/rl_small.layout
diff --git a/cooperative_cuisine/orders.py b/cooperative_cuisine/orders.py
index 8ee060e8a4488028f70bca0e6a15468c7c5c95a2..968a042a739e234f574ddd764e4d5cad1a88fdd9 100644
--- a/cooperative_cuisine/orders.py
+++ b/cooperative_cuisine/orders.py
@@ -651,6 +651,7 @@ class DeterministicOrderGeneration(OrderGeneration):
             self.current_queue[0].start -= diff_to_next
             self.next_order_time = self.current_queue[0].start
         orders.extend(self.get_orders(passed_time, now, [], []))
+        log.info(f"Created orders: {orders}.")
         return orders

     def parse_timed_orders(self) -> list[ParsedTimedOrder]:
diff --git a/cooperative_cuisine/reinforcement_learning/__init__.py b/cooperative_cuisine/reinforcement_learning/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..79d7ca298abb7b947888d0b6b54f72c3e4f922e6 100644
--- a/cooperative_cuisine/reinforcement_learning/__init__.py
+++ b/cooperative_cuisine/reinforcement_learning/__init__.py
@@ -0,0 +1,24 @@
+
+"""
+## Reinforcement Learning Module Overview
+
+The reinforcement learning module consists of several key functions designed to:
+
+- **Train the agent**
+- **Test the agent**
+- **Utilize the environment**
+
+### Configurations
+
+All hyperparameters related to the reinforcement learning agent and the environment are configurable via specific configuration files. These configurations are managed with **Hydra**, allowing specific config items or entire config files to be selected directly from the command line.
+
+### Layouts
+
+Several layouts are predefined in the `cooperative_cuisine/configs/layouts` directory. The layout path can be selected within the corresponding config file.
+
+Additionally, **Overcooked-AI** layouts can be transformed into the cooperative cuisine format using the `convert_overcooked_ai_layouts.py` script. To use this script:
+
+1. Specify the path of the Overcooked-AI layout file as a command-line argument.
+2. The script will generate the corresponding layout file and save it in the `configs/layouts/overcooked-ai` directory.
+
+"""
\ No newline at end of file
diff --git a/cooperative_cuisine/reinforcement_learning/config/additional_configs/additional_config_base.yaml b/cooperative_cuisine/reinforcement_learning/config/additional_configs/additional_config_base.yaml
index 56ca8938c0d1c4e40cc3f3bb0a7be96978ad3531..4ba0f49b5cb740b6bb6f1ba3d36fa1518950ef6b 100644
--- a/cooperative_cuisine/reinforcement_learning/config/additional_configs/additional_config_base.yaml
+++ b/cooperative_cuisine/reinforcement_learning/config/additional_configs/additional_config_base.yaml
@@ -1,7 +1,6 @@
-order_generator: "random_orders.yaml" # Here the filename of the converter should be given.
The converter class needs to be called StateConverter and implement the abstract StateToObservationConverter class state_converter: - _target_: "cooperative_cuisine.reinforcement_learning.obs_converter.base_converter_onehot.BaseStateConverterOnehot" + _target_: "cooperative_cuisine.reinforcement_learning.obs_converter.base_converter.BaseStateConverter" log_path: "logs/reinforcement_learning" checkpoint_path: "rl_agent_checkpoints" render_mode: "rgb_array" diff --git a/cooperative_cuisine/reinforcement_learning/config/environment/environment_config_rl.yaml b/cooperative_cuisine/reinforcement_learning/config/environment/environment_config_rl.yaml index 5151d980a30826f68925fc320b03e6275a5077d1..4a089e918efc603910fdd85521a649e4e65ba07e 100644 --- a/cooperative_cuisine/reinforcement_learning/config/environment/environment_config_rl.yaml +++ b/cooperative_cuisine/reinforcement_learning/config/environment/environment_config_rl.yaml @@ -11,7 +11,7 @@ game: undo_dispenser_pickup: true validate_recipes: false - +layout_name: configs/layouts/rl/rl_small.layout layout_chars: _: Free @@ -51,6 +51,9 @@ layout_chars: orders: + order_generator: + _target_: "cooperative_cuisine.orders.RandomOrderGeneration" + _partial_: true meals: all: false # if all: false -> only orders for these meals are generated @@ -97,38 +100,59 @@ effect_manager: { } # spreading_duration: [ 5, 10 ] # fire_burns_ingredients_and_meals: true - hook_callbacks: # # --------------- Scoring --------------- orders: hooks: [ completed_order ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true callback_class_kwargs: static_score: 0.95 serve_not_ordered_meals: hooks: [ serve_not_ordered_meal ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true callback_class_kwargs: static_score: 0.95 trashcan_usages: hooks: [ trashcan_usage ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true callback_class_kwargs: static_score: -0.2 item_cut: hooks: [ cutting_board_100 ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true callback_class_kwargs: static_score: 0.1 stepped: hooks: [ post_step ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true callback_class_kwargs: static_score: -0.01 combine: hooks: [ drop_off_on_cooking_equipment ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true callback_class_kwargs: static_score: 0.01 start_interact: hooks: [ player_start_interaction ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true callback_class_kwargs: static_score: 0.01 + # json_states: # hooks: [ json_state ] # record_class: !!python/name:cooperative_cuisine.recording.LogRecorder '' diff --git a/cooperative_cuisine/reinforcement_learning/config/environment/environment_config_rl_deterministic_order_generation.yaml b/cooperative_cuisine/reinforcement_learning/config/environment/environment_config_rl_deterministic_order_generation.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6f093db50a610e71580442aa3853906dc9c8748c --- /dev/null +++ b/cooperative_cuisine/reinforcement_learning/config/environment/environment_config_rl_deterministic_order_generation.yaml @@ -0,0 +1,166 @@ + +plates: + clean_plates: 2 + dirty_plates: 0 + plate_delay: [ 2, 4 ] + return_dirty: False + # range of seconds until the dirty plate arrives. 
+ +game: + time_limit_seconds: 300 + undo_dispenser_pickup: true + validate_recipes: false + +layout_name: configs/layouts/rl/rl_small.layout + +layout_chars: + _: Free + hash: Counter # # + A: Agent + pipe: Extinguisher + P: PlateDispenser + C: CuttingBoard + X: Trashcan + $: ServingWindow + S: Sink + +: SinkAddon + at: Plate # @ just a clean plate on a counter + U: Pot # with Stove + Q: Pan # with Stove + O: Peel # with Oven + F: Basket # with DeepFryer + T: Tomato + N: Onion # oNioN + L: Lettuce + K: Potato # Kartoffel + I: Fish # fIIIsh + D: Dough + E: Cheese # chEEEse + G: Sausage # sausaGe + B: Bun + M: Meat + question: Counter # ? mushroom + ↓: Counter + ^: Counter + right: Counter + left: Counter + wave: Free # ~ Water + minus: Free # - Ice + dquote: Counter # " wall/truck + p: Counter # second plate return ?? + +orders: + order_generator: + _target_: "cooperative_cuisine.orders.DeterministicOrderGeneration" + _partial_: true + meals: + all: false + # if all: false -> only orders for these meals are generated + # TODO: what if this list is empty? + list: + - TomatoSoup + - OnionSoup + #- Salad + # - FriedFish + # the class to that receives the kwargs. Should be a child class of OrderGeneration in orders.py + order_gen_kwargs: + # structure: [meal_name, start, duration] (start and duration as seconds or timedeltas https://github.com/wroberts/pytimeparse) + timed_orders: + - [ TomatoSoup, 0:00, 0:10 ] + - [ OnionSoup, 0:00, 0:10 ] + - [ TomatoSoup, 0:10, 0:10 ] + - [ TomatoSoup, 0:15, 0:06 ] + never_no_order: False + never_no_order_update_all_remaining: False + serving_not_ordered_meals: null + +player_config: + radius: 0.4 + speed_units_per_seconds: 1 + interaction_range: 1.6 + restricted_view: False + view_angle: 95 + +effect_manager: { } +# FireManager: +# class: !!python/name:cooperative_cuisine.effects.FireEffectManager '' +# kwargs: +# spreading_duration: [ 5, 10 ] +# fire_burns_ingredients_and_meals: true + +hook_callbacks: + # # --------------- Scoring --------------- + orders: + hooks: [ completed_order ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true + callback_class_kwargs: + static_score: 0.95 + + serve_not_ordered_meals: + hooks: [ serve_not_ordered_meal ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true + callback_class_kwargs: + static_score: 0.95 + trashcan_usages: + hooks: [ trashcan_usage ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true + callback_class_kwargs: + static_score: -0.2 + item_cut: + hooks: [ cutting_board_100 ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true + callback_class_kwargs: + static_score: 0.1 + stepped: + hooks: [ post_step ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true + callback_class_kwargs: + static_score: -0.01 + combine: + hooks: [ drop_off_on_cooking_equipment ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true + callback_class_kwargs: + static_score: 0.01 + start_interact: + hooks: [ player_start_interaction ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true + callback_class_kwargs: + static_score: 0.01 + +# json_states: +# hooks: [ json_state ] +# record_class: !!python/name:cooperative_cuisine.recording.LogRecorder '' +# record_class_kwargs: +# record_path: USER_LOG_DIR/ENV_NAME/json_states.jsonl +# actions: +# hooks: [ 
pre_perform_action ] +# record_class: !!python/name:cooperative_cuisine.recording.LogRecorder '' +# record_class_kwargs: +# record_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl +# random_env_events: +# hooks: [ order_duration_sample, plate_out_of_kitchen_time ] +# record_class: !!python/name:cooperative_cuisine.recording.LogRecorder '' +# record_class_kwargs: +# record_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl +# add_hook_ref: true +# env_configs: +# hooks: [ env_initialized, item_info_config ] +# record_class: !!python/name:cooperative_cuisine.recording.LogRecorder '' +# record_class_kwargs: +# record_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl +# add_hook_ref: true + diff --git a/cooperative_cuisine/reinforcement_learning/config/environment/environment_config_rl_small_rewards.yaml b/cooperative_cuisine/reinforcement_learning/config/environment/environment_config_rl_small_rewards.yaml index c7f0fbf30e7717294ef1a6d169918b5f241b5bbf..1c96c365dba53332a29b1113c5fb2b68362bbaee 100644 --- a/cooperative_cuisine/reinforcement_learning/config/environment/environment_config_rl_small_rewards.yaml +++ b/cooperative_cuisine/reinforcement_learning/config/environment/environment_config_rl_small_rewards.yaml @@ -11,6 +11,7 @@ game: undo_dispenser_pickup: true validate_recipes: false +layout_name: configs/layouts/rl/rl_small.layout layout_chars: @@ -51,6 +52,9 @@ layout_chars: orders: + order_generator: + _target_: "cooperative_cuisine.orders.RandomOrderGeneration" + _partial_: true meals: all: true # if all: false -> only orders for these meals are generated @@ -97,38 +101,58 @@ effect_manager: { } # spreading_duration: [ 5, 10 ] # fire_burns_ingredients_and_meals: true - hook_callbacks: # # --------------- Scoring --------------- orders: hooks: [ completed_order ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true callback_class_kwargs: static_score: 0.1 serve_not_ordered_meals: hooks: [ serve_not_ordered_meal ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true callback_class_kwargs: static_score: 0.1 trashcan_usages: hooks: [ trashcan_usage ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true callback_class_kwargs: static_score: -0.2 item_cut: hooks: [ cutting_board_100 ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true callback_class_kwargs: - static_score: 0.0 + static_score: 0 stepped: hooks: [ post_step ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true callback_class_kwargs: - static_score: -0.0 + static_score: 0 combine: hooks: [ drop_off_on_cooking_equipment ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true callback_class_kwargs: - static_score: 0.0 + static_score: 0 start_interact: hooks: [ player_start_interaction ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true callback_class_kwargs: - static_score: 0.0 + static_score: 0 # json_states: # hooks: [ json_state ] # record_class: !!python/name:cooperative_cuisine.recording.LogRecorder '' diff --git a/cooperative_cuisine/reinforcement_learning/config/environment/order_config.yaml b/cooperative_cuisine/reinforcement_learning/config/environment/order_config.yaml deleted file mode 100644 index 9cc3de7dbd523c5f814d87d86cb6ac51e807b999..0000000000000000000000000000000000000000 --- 
a/cooperative_cuisine/reinforcement_learning/config/environment/order_config.yaml +++ /dev/null @@ -1,33 +0,0 @@ -orders: - meals: - all: true - # if all: false -> only orders for these meals are generated - # TODO: what if this list is empty? - list: - - TomatoSoup - - OnionSoup - - Salad - #order_gen_class: !!python/name:cooperative_cuisine.orders.RandomOrderGeneration '' - # the class to that receives the kwargs. Should be a child class of OrderGeneration in orders.py - order_gen_kwargs: - order_duration_random_func: - # how long should the orders be alive - # 'random' library call with getattr, kwargs are passed to the function - func: uniform - kwargs: - a: 40 - b: 60 - max_orders: 6 - # maximum number of active orders at the same time - num_start_meals: 2 - # number of orders generated at the start of the environment - sample_on_dur_random_func: - # 'random' library call with getattr, kwargs are passed to the function - func: uniform - kwargs: - a: 10 - b: 20 - sample_on_serving: false - # Sample the delay for the next order only after a meal was served. - serving_not_ordered_meals: true - # can meals that are not ordered be served / dropped on the serving window \ No newline at end of file diff --git a/cooperative_cuisine/reinforcement_learning/config/environment/overcooked-ai_environment_config.yaml b/cooperative_cuisine/reinforcement_learning/config/environment/overcooked-ai_environment_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7c9d41c70a7333a9816325eb8f822ccbc0ba5e70 --- /dev/null +++ b/cooperative_cuisine/reinforcement_learning/config/environment/overcooked-ai_environment_config.yaml @@ -0,0 +1,179 @@ + +plates: + clean_plates: 1 + dirty_plates: 0 + plate_delay: [ 0, 0 ] + return_dirty: False + # range of seconds until the dirty plate arrives. + +game: + time_limit_seconds: 300 + undo_dispenser_pickup: true + validate_recipes: false + +layout_name: configs/layouts/rl/rl_small.layout + + +layout_chars: + _: Free + hash: Counter # # + A: Agent + pipe: Extinguisher + P: PlateDispenser + C: CuttingBoard + X: Trashcan + $: ServingWindow + S: Sink + +: SinkAddon + at: Plate # @ just a clean plate on a counter + U: Pot # with Stove + Q: Pan # with Stove + O: Peel # with Oven + F: Basket # with DeepFryer + T: Tomato + N: Onion # oNioN + L: Lettuce + K: Potato # Kartoffel + I: Fish # fIIIsh + D: Dough + E: Cheese # chEEEse + G: Sausage # sausaGe + B: Bun + M: Meat + question: Counter # ? mushroom + ↓: Counter + ^: Counter + right: Counter + left: Counter + wave: Free # ~ Water + minus: Free # - Ice + dquote: Counter # " wall/truck + p: Counter # second plate return ?? + + +orders: + order_generator: + _target_: "cooperative_cuisine.orders.RandomOrderGeneration" + _partial_: true + meals: + all: false + # if all: false -> only orders for these meals are generated + # TODO: what if this list is empty? + list: + - TomatoSoup + - OnionSoup + - Salad + # the class to that receives the kwargs. 
Should be a child class of OrderGeneration in orders.py + order_gen_kwargs: + order_duration_random_func: + # how long should the orders be alive + # 'random' library call with getattr, kwargs are passed to the function + func: uniform + kwargs: + a: 40 + b: 60 + max_orders: 6 + # maximum number of active orders at the same time + num_start_meals: 2 + # number of orders generated at the start of the environment + sample_on_dur_random_func: + # 'random' library call with getattr, kwargs are passed to the function + func: uniform + kwargs: + a: 10 + b: 20 + sample_on_serving: false + # Sample the delay for the next order only after a meal was served. + serving_not_ordered_meals: true + # can meals that are not ordered be served / dropped on the serving window + +player_config: + radius: 0.1 + speed_units_per_seconds: 1 + interaction_range: 1 + restricted_view: False + view_angle: 60 + +effect_manager: { } +# FireManager: +# class: !!python/name:cooperative_cuisine.effects.FireEffectManager '' +# kwargs: +# spreading_duration: [ 5, 10 ] +# fire_burns_ingredients_and_meals: true + + +hook_callbacks: + # # --------------- Scoring --------------- + orders: + hooks: [ completed_order ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true + callback_class_kwargs: + static_score: 5 + + serve_not_ordered_meals: + hooks: [ serve_not_ordered_meal ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true + callback_class_kwargs: + static_score: 3 + trashcan_usages: + hooks: [ trashcan_usage ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true + callback_class_kwargs: + static_score: 0 + item_cut: + hooks: [ cutting_board_100 ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true + callback_class_kwargs: + static_score: 0 + stepped: + hooks: [ post_step ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true + callback_class_kwargs: + static_score: 0 + combine: + hooks: [ drop_off_on_cooking_equipment ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true + callback_class_kwargs: + static_score: 0 + start_interact: + hooks: [ player_start_interaction ] + callback_class: + _target_: "cooperative_cuisine.scores.ScoreViaHooks" + _partial_: true + callback_class_kwargs: + static_score: 0 +# json_states: +# hooks: [ json_state ] +# record_class: !!python/name:cooperative_cuisine.recording.LogRecorder '' +# record_class_kwargs: +# record_path: USER_LOG_DIR/ENV_NAME/json_states.jsonl +# actions: +# hooks: [ pre_perform_action ] +# record_class: !!python/name:cooperative_cuisine.recording.LogRecorder '' +# record_class_kwargs: +# record_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl +# random_env_events: +# hooks: [ order_duration_sample, plate_out_of_kitchen_time ] +# record_class: !!python/name:cooperative_cuisine.recording.LogRecorder '' +# record_class_kwargs: +# record_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl +# add_hook_ref: true +# env_configs: +# hooks: [ env_initialized, item_info_config ] +# record_class: !!python/name:cooperative_cuisine.recording.LogRecorder '' +# record_class_kwargs: +# record_path: USER_LOG_DIR/ENV_NAME/LOG_RECORD_NAME.jsonl +# add_hook_ref: true + diff --git a/cooperative_cuisine/reinforcement_learning/config/item_info/item_info_overcooked-ai.yaml b/cooperative_cuisine/reinforcement_learning/config/item_info/item_info_overcooked-ai.yaml new file 
mode 100644 index 0000000000000000000000000000000000000000..62d1fab6cdf3dd40c297d1b75d9bb1d764117cf6 --- /dev/null +++ b/cooperative_cuisine/reinforcement_learning/config/item_info/item_info_overcooked-ai.yaml @@ -0,0 +1,219 @@ +CuttingBoard: + type: Equipment + +Sink: + type: Equipment + +Stove: + type: Equipment + +DeepFryer: + type: Equipment + +Oven: + type: Equipment + +Pot: + type: Equipment + equipment: Stove + +Pan: + type: Equipment + equipment: Stove + +Basket: + type: Equipment + equipment: DeepFryer + +Peel: + type: Equipment + equipment: Oven + +DirtyPlate: + type: Equipment + +Plate: + type: Equipment + needs: [ DirtyPlate ] + seconds: 2.0 + equipment: Sink + +# -------------------------------------------------------------------------------- + +Tomato: + type: Ingredient + +Lettuce: + type: Ingredient + +Onion: + type: Ingredient + +Meat: + type: Ingredient + +Bun: + type: Ingredient + +Potato: + type: Ingredient + +Fish: + type: Ingredient + +Dough: + type: Ingredient + +Cheese: + type: Ingredient + +Sausage: + type: Ingredient + +# Chopped things +ChoppedTomato: + type: Ingredient + needs: [ Tomato ] + seconds: 4.0 + equipment: CuttingBoard + +ChoppedLettuce: + type: Ingredient + needs: [ Lettuce ] + seconds: 3.0 + equipment: CuttingBoard + +ChoppedOnion: + type: Ingredient + needs: [ Onion ] + seconds: 4.0 + equipment: CuttingBoard + +RawPatty: + type: Ingredient + needs: [ Meat ] + seconds: 4.0 + equipment: CuttingBoard + +RawChips: + type: Ingredient + needs: [ Potato ] + seconds: 4.0 + equipment: CuttingBoard + +ChoppedFish: + type: Ingredient + needs: [ Fish ] + seconds: 4.0 + equipment: CuttingBoard + +PizzaBase: + type: Ingredient + needs: [ Dough ] + seconds: 4.0 + equipment: CuttingBoard + +GratedCheese: + type: Ingredient + needs: [ Cheese ] + seconds: 4.0 + equipment: CuttingBoard + +ChoppedSausage: + type: Ingredient + needs: [ Sausage ] + seconds: 4.0 + equipment: CuttingBoard + +CookedPatty: + type: Ingredient + seconds: 5.0 + needs: [ RawPatty ] + equipment: Pan + +# -------------------------------------------------------------------------------- + +Chips: + type: Meal + seconds: 5.0 + needs: [ RawChips ] + equipment: Basket + +FriedFish: + type: Meal + seconds: 5.0 + needs: [ ChoppedFish ] + equipment: Basket + +Burger: + type: Meal + needs: [ Bun, ChoppedLettuce, ChoppedTomato, CookedPatty ] + equipment: ~ + +Salad: + type: Meal + needs: [ ChoppedLettuce, ChoppedTomato ] + equipment: ~ + +TomatoSoup: + type: Meal + needs: [Tomato,Tomato, Tomato ] + seconds: 1 + equipment: Pot + +OnionSoup: + type: Meal + needs: [ Onion, Onion, Onion ] + seconds: 1 + equipment: Pot + +FishAndChips: + type: Meal + needs: [ FriedFish, Chips ] + equipment: ~ + +Pizza: + type: Meal + needs: [ PizzaBase, ChoppedTomato, GratedCheese, ChoppedSausage ] + seconds: 7.0 + equipment: Peel + +# -------------------------------------------------------------------------------- + +BurntCookedPatty: + type: Waste + seconds: 10.0 + needs: [ CookedPatty ] + equipment: Pan + +BurntChips: + type: Waste + seconds: 10.0 + needs: [ Chips ] + equipment: Basket + +BurntFriedFish: + type: Waste + seconds: 10.0 + needs: [ FriedFish ] + equipment: Basket + +BurntTomatoSoup: + type: Waste + needs: [ TomatoSoup ] + seconds: 20.0 + equipment: Pot + +BurntOnionSoup: + type: Waste + needs: [ OnionSoup ] + seconds: 20.0 + equipment: Pot + +BurntPizza: + type: Waste + needs: [ Pizza ] + seconds: 10.0 + equipment: Peel + + diff --git a/cooperative_cuisine/reinforcement_learning/config/model/PPO.yaml 
b/cooperative_cuisine/reinforcement_learning/config/model/PPO.yaml
index 2e868217f39c270db820e8bb7070a70d17905a1f..ece8b516ae1682f108330bd16ed4d4f59ea36716 100644
--- a/cooperative_cuisine/reinforcement_learning/config/model/PPO.yaml
+++ b/cooperative_cuisine/reinforcement_learning/config/model/PPO.yaml
@@ -8,10 +8,10 @@ model_type_inference:
   _partial_: true
   _target_: stable_baselines3.PPO.load
 total_timesteps: 3_000_000 # Hendric suggests more like 300_000_000 steps
-number_envs_parallel: 64
+number_envs_parallel: 16
 learning_rate: 0.0003
 n_steps: 2048
-batch_size: 64
+batch_size: 16
 n_epochs: 10
 gamma: 0.99
 gae_lambda: 0.95
diff --git a/cooperative_cuisine/reinforcement_learning/config/random_orders.yaml b/cooperative_cuisine/reinforcement_learning/config/random_orders.yaml
deleted file mode 100644
index e8a93ff75c0fe47a9909028c8905a2c8511e2fce..0000000000000000000000000000000000000000
--- a/cooperative_cuisine/reinforcement_learning/config/random_orders.yaml
+++ /dev/null
@@ -1,2 +0,0 @@
- order_gen_class: !!python/name:cooperative_cuisine.orders.RandomOrderGeneration ''
- callback_class: !!python/name:cooperative_cuisine.scores.ScoreViaHooks ''
\ No newline at end of file
diff --git a/cooperative_cuisine/reinforcement_learning/config/rl_config.yaml b/cooperative_cuisine/reinforcement_learning/config/rl_config.yaml
index 4fd36579d767cad57a3807caefa21f7e70c9866a..af22536dd8f119d89f0ab4c187b6bed30d191769 100644
--- a/cooperative_cuisine/reinforcement_learning/config/rl_config.yaml
+++ b/cooperative_cuisine/reinforcement_learning/config/rl_config.yaml
@@ -1,5 +1,5 @@
 defaults:
-  - environment: environment_config_rl
-  - item_info: item_info_rl
+  - environment: overcooked-ai_environment_config
+  - item_info: item_info_overcooked-ai
   - model: PPO
   - additional_configs: additional_config_base
\ No newline at end of file
diff --git a/cooperative_cuisine/reinforcement_learning/convert_overcooked_ai_layouts.py b/cooperative_cuisine/reinforcement_learning/convert_overcooked_ai_layouts.py
new file mode 100644
index 0000000000000000000000000000000000000000..595ecf954f4bd5c1a3f0c9d31003f19416325f4a
--- /dev/null
+++ b/cooperative_cuisine/reinforcement_learning/convert_overcooked_ai_layouts.py
@@ -0,0 +1,50 @@
+import argparse
+from pathlib import Path, PurePath
+from cooperative_cuisine import ROOT_DIR
+
+
+def convert_overcooked_ai_layouts():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--file", dest="inputfile", help="Input file path", required=True
+    )
+    args = parser.parse_args()
+    filepath = PurePath(args.inputfile)
+    print(filepath)
+    conversion_dict = {
+        " ": "_",
+        "X": "#",
+        "O": "N",
+        "T": "T",
+        "P": "U",
+        "D": "P",
+        "S": "$",
+        "1": "A",
+        "2": "A"
+    }
+
+    savepath = Path(ROOT_DIR) / "configs" / "layouts" / "overcooked-ai" / filepath.name
+    with open(args.inputfile, "r") as f:
+        layoutfile = f.read()
+    # The Overcooked-AI layout file is a Python dict literal with a "grid" key.
+    layout = eval(layoutfile)
+    lines = layout["grid"].split("\n")
+    additional_info = []
+    for key in layout:
+        if key != "grid":
+            additional_info.append(
+                '; {}: {}'.format(key, str(layout[key]).replace("'", "").replace("None", "null")))
+
+    with open(savepath, "w+") as f:
+        for line in lines:
+            line = line.lstrip()
+            for char in line:
+                f.write(conversion_dict[char])
+            f.write("\n")
+        for info in additional_info:
+            f.write(info)
+            f.write("\n")
+
+
+if __name__ == "__main__":
+    convert_overcooked_ai_layouts()
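For reference, this is what the mapping above does to a grid. A minimal runnable sketch (the input grid is made up for illustration; it happens to reproduce the `1-cramped-room.layout` removed at the top of this diff):

```python
# Same character mapping as conversion_dict in convert_overcooked_ai_layouts.py.
conversion_dict = {" ": "_", "X": "#", "O": "N", "T": "T", "P": "U", "D": "P", "S": "$", "1": "A", "2": "A"}

overcooked_grid = "XXPXX\nO1 2O\nX   X\nXDXSX"
converted = "\n".join(
    "".join(conversion_dict[char] for char in row) for row in overcooked_grid.split("\n")
)
print(converted)
# ##U##
# NA_AN
# #___#
# #P#$#
```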
diff --git a/cooperative_cuisine/reinforcement_learning/gym_env.py b/cooperative_cuisine/reinforcement_learning/gym_env.py
index 0a853999dc10e0a4b9f116bfc6e0ab00cc7e63e4..96842152ad3f1c02a04219af8984c5b716c59c71 100644
--- a/cooperative_cuisine/reinforcement_learning/gym_env.py
+++ b/cooperative_cuisine/reinforcement_learning/gym_env.py
@@ -36,6 +36,18 @@ class SimpleActionSpace(Enum):


 def get_env_action(player_id, simple_action, duration):
+    """
+    Args:
+        player_id: id of the player
+        simple_action: an action in the form of a SimpleActionSpace
+        duration: how long the action should be performed
+
+    Returns: a concrete Action for the environment
+
+    """
+
     match simple_action:
         case SimpleActionSpace.Up:
             return Action(
@@ -82,9 +94,6 @@
             )


-layout_path: Path = ROOT_DIR / "reinforcement_learning" / "rl_small.layout"
-with open(layout_path, "r") as file:
-    layout = file.read()
 with open(ROOT_DIR / "pygame_2d_vis" / "visualization.yaml", "r") as file:
     visualization_config = yaml.safe_load(file)

@@ -94,6 +103,12 @@
 visualizer.set_grid_size(40)


 def shuffle_counters(env):
+    """
+    Shuffles the counters of an environment in place.
+
+    Args:
+        env: the environment object
+    """
     sample_counter = []
     other_counters = []
     for counter in env.counters:
@@ -110,11 +125,10 @@


 class StateToObservationConverter:
-    '''
-
-
+    """
+    Abstract definition of a class that gets an environment and outputs a state representation for RL.
+    """

-    '''
     @abstractmethod
     def setup(self, env):
         ...
@@ -132,25 +146,31 @@ class EnvGymWrapper(Env):

     metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 10}

     def __init__(self, config):
+        """
+        Initializes all necessary variables.
+
+        Args:
+            config: the RL and environment configuration provided by Hydra
+        """
         super().__init__()
         self.randomize_counter_placement = False
-        self.use_rgb_obs = True  # if False uses simple vectorized state
+        self.use_rgb_obs = False  # if False uses simple vectorized state
         self.full_vector_state = True
         config_env = OmegaConf.to_container(config.environment, resolve=True)
         config_item_info = OmegaConf.to_container(config.item_info, resolve=True)
-        order_generator = config.additional_configs.order_generator
-        custom_config_path = ROOT_DIR / "reinforcement_learning" / "config" / order_generator
-        with open(custom_config_path, "r") as file:
-            custom_classes = yaml.load(file, Loader=yaml.Loader)
-        for key, value in config_env['hook_callbacks'].items():
-            value['callback_class'] = custom_classes['callback_class']
-        config_env["orders"]["order_gen_class"] = custom_classes['order_gen_class']
+        for val in config_env['hook_callbacks']:
+            config_env['hook_callbacks'][val]["callback_class"] = instantiate(config_env['hook_callbacks'][val]["callback_class"])
+        config_env["orders"]["order_gen_class"] = instantiate(config_env["orders"]["order_generator"])
         self.config_env = config_env
         self.config_item_info = config_item_info
+        layout_file = config_env["layout_name"]
+        layout_path: Path = ROOT_DIR / layout_file
+        with open(layout_path, "r") as file:
+            self.layout = file.read()
         self.env: Environment = Environment(
             env_config=deepcopy(config_env),
-            layout_config=layout,
+            layout_config=self.layout,
             item_info=deepcopy(config_item_info),
             as_files=False,
             yaml_already_loaded=True,
@@ -197,6 +217,10 @@
         self.prev_score = 0

     def step(self, action):
+        """
+        Takes one step in the environment and returns the observation, reward,
+        whether the episode terminated or was truncated, and additional info.
+        """
         # this is simply a work-around to enable no action which is necessary for the play_gym.py
         if action == 8:
             observation = self.get_observation()
@@ -231,10 +255,14 @@
         return observation, reward, terminated, truncated, info

     def reset(self, seed=None, options=None):
-        del visualizer.surface_cache_dict[self.env.env_name]
+        """
+        Resets the environment according to the configs.
+        """
+        if self.env.env_name in visualizer.surface_cache_dict:
+            del visualizer.surface_cache_dict[self.env.env_name]
         self.env: Environment = Environment(
             env_config=deepcopy(self.config_env),
-            layout_config=layout,
+            layout_config=self.layout,
             item_info=deepcopy(self.config_item_info),
             as_files=False,
             env_name=uuid.uuid4().hex,
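Because `state_converter` in `additional_config_base.yaml` is instantiated via Hydra's `_target_`, other observation encodings can be plugged in by pointing that entry at any implementation of the `StateToObservationConverter` interface above. A minimal sketch, not part of this change (the class must be named `StateConverter` per the config comment; the `pos` attribute on players is an assumption here):

```python
import numpy as np

from cooperative_cuisine.reinforcement_learning.gym_env import StateToObservationConverter


class StateConverter(StateToObservationConverter):
    """Toy converter: encodes only the position of player "0" on the kitchen grid."""

    def setup(self, env):
        # Called once after the environment was created.
        self.grid_width, self.grid_height = int(env.kitchen_width), int(env.kitchen_height)

    def convert_state_to_observation(self, env) -> np.ndarray:
        obs = np.zeros((self.grid_width, self.grid_height), dtype=np.float32)
        x, y = env.players["0"].pos  # assumption: players expose a 2D position
        obs[int(x), int(y)] = 1.0
        return obs.flatten()
```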
+ + """ self.grid_width, self.grid_height = int(env.kitchen_width), int( env.kitchen_height) def convert_state_to_observation(self, env) -> np.ndarray: + """ + Convert the environment into an onehot encoding + Args: + env: The environment object used + + Returns: An encoding for the environment state that is not onehot + + """ + grid_base_array = np.zeros( ( self.grid_width, @@ -115,18 +142,31 @@ class BaseStateConverter(StateToObservationConverter): if item.name == "Pot": if len(item.content_list) > 0: if item.content_list[0].name == "TomatoSoup": - item_name = "PotDone" + item_name = "PotDone_Tomato" + if item.content_list[0].name == "OnionSoup": + item_name = "PotDone_Onion" elif len(item.content_list) == 1: - item_name = "PotOne" + if item.content_list[0].name == "Tomato": + item_name = "PotOne_Tomato" + if item.content_list[0].name == "Onion": + item_name = "PotOne_Onion" elif len(item.content_list) == 2: - item_name = "PotTwo" + if item.content_list[0].name == "Tomato": + item_name = "PotTwo_Tomato" + if item.content_list[0].name == "Onion": + item_name = "PotTwo_Onion" elif len(item.content_list) == 3: - item_name = "PotThree" + if item.content_list[0].name == "Tomato": + item_name = "PotThree_Tomato" + if item.content_list[0].name == "Onion": + item_name = "PotThree_Onion" if "Plate" in item.name: content_list = [i.name for i in item.content_list] match content_list: case ["TomatoSoup"]: item_name = "PlateTomatoSoup" + case ["OnionSoup"]: + item_name = "PlateOnionSoup" case ["ChoppedTomato"]: item_name = "PlateChoppedTomato" case ["ChoppedLettuce"]: diff --git a/cooperative_cuisine/reinforcement_learning/obs_converter/base_converter_onehot.py b/cooperative_cuisine/reinforcement_learning/obs_converter/base_converter_onehot.py index d3a7d877db6cc4786c948fc522b156dba19eb00b..1a5ce31a7ba25c110c8a8978c79e385538386392 100644 --- a/cooperative_cuisine/reinforcement_learning/obs_converter/base_converter_onehot.py +++ b/cooperative_cuisine/reinforcement_learning/obs_converter/base_converter_onehot.py @@ -8,15 +8,24 @@ from cooperative_cuisine.reinforcement_learning.gym_env import StateToObservatio class BaseStateConverterOnehot(StateToObservationConverter): + """ + Converts an environment state to an Onehot Encoding + """ + def __init__(self): + """ + Constructor setting basic variables as attributes. 
+ + """ self.onehot = True - self.grid_height = None - self.grid_width = None + self.grid_height: int | None = None + self.grid_width: int | None = None self.counter_list = [ "Empty", "Counter", "PlateDispenser", "TomatoDispenser", + "OnionDispenser", "ServingWindow", "PlateReturn", "Trashcan", @@ -28,27 +37,51 @@ class BaseStateConverterOnehot(StateToObservationConverter): self.item_list = [ "None", "Pot", - "PotOne", - "PotTwo", - "PotThree", - "PotDone", + "PotOne_Tomato", + "PotTwo_Tomato", + "PotThree_Tomato", + "PotDone_Tomato", + "PotOne_Onion", + "PotTwo_Onion", + "PotThree_Onion", + "PotDone_Onion", "Tomato", + "Onion", "ChoppedTomato", "Plate", "PlateTomatoSoup", + "PlateOnionSoup", "PlateSalad", "Lettuce", "PlateChoppedTomato", "PlateChoppedLettuce", "ChoppedLettuce", + "ChoppedOnion", ] self.player = "0" def setup(self, env): + """ + Set the grid width and height according to the present environment + + Args: + env: The environment object used + """ + self.grid_width, self.grid_height = int(env.kitchen_width), int( env.kitchen_height) def convert_state_to_observation(self, env) -> np.ndarray: + + """ + Convert the environment into an onehot encoding + Args: + env: The environment object used + + Returns: An onehot encoding for the environment state + + """ + grid_base_array = np.zeros( ( self.grid_width, @@ -92,7 +125,7 @@ class BaseStateConverterOnehot(StateToObservationConverter): player_item_one_hot = self.vectorize_item( env.players[self.player].holding, self.item_list ) - + # simply concat all entities to one large vector final = np.concatenate( ( counters.flatten(), @@ -116,22 +149,36 @@ class BaseStateConverterOnehot(StateToObservationConverter): else: item_name = item.name + # different naming convention for the different pots to include the progress. New implementation should be found here if isinstance(item, CookingEquipment): if item.name == "Pot": if len(item.content_list) > 0: if item.content_list[0].name == "TomatoSoup": - item_name = "PotDone" + item_name = "PotDone_Tomato" + if item.content_list[0].name == "OnionSoup": + item_name = "PotDone_Onion" elif len(item.content_list) == 1: - item_name = "PotOne" + if item.content_list[0].name == "Tomato": + item_name = "PotOne_Tomato" + if item.content_list[0].name == "Onion": + item_name = "PotOne_Onion" elif len(item.content_list) == 2: - item_name = "PotTwo" + if item.content_list[0].name == "Tomato": + item_name = "PotTwo_Tomato" + if item.content_list[0].name == "Onion": + item_name = "PotTwo_Onion" elif len(item.content_list) == 3: - item_name = "PotThree" + if item.content_list[0].name == "Tomato": + item_name = "PotThree_Tomato" + if item.content_list[0].name == "Onion": + item_name = "PotThree_Onion" if "Plate" in item.name: content_list = [i.name for i in item.content_list] match content_list: case ["TomatoSoup"]: item_name = "PlateTomatoSoup" + case ["OnionSoup"]: + item_name = "PlateOnionSoup" case ["ChoppedTomato"]: item_name = "PlateChoppedTomato" case ["ChoppedLettuce"]: diff --git a/cooperative_cuisine/reinforcement_learning/overcooked_ai.md b/cooperative_cuisine/reinforcement_learning/overcooked_ai.md new file mode 100644 index 0000000000000000000000000000000000000000..7e2ae81eec713b7b5d424fa40664f08bf38ebf4d --- /dev/null +++ b/cooperative_cuisine/reinforcement_learning/overcooked_ai.md @@ -0,0 +1,13 @@ +# Overcooked-AI and Cooperative Cuisine + +## Use the overcooked-AI levels and configs in cooperative cuisine +All the layouts from overcooked-AI can be used within cooperative cuisine. 
Dedicated configs are defined and can be loaded via Hydra.
+The overcooked-ai_environment_config.yaml must be chosen as the environment config. Under layout_name, any layout from Overcooked-AI can be selected.
+Additionally, the item_info config must be item_info_overcooked-ai.yaml.
+With these configs, the layouts and rewards from Overcooked-AI are used.
+
+## How is the connection between Overcooked-AI and cooperative cuisine defined?
+Cooperative Cuisine is highly modular because it uses Hydra as its config manager.
+The parameters used by Overcooked-AI are therefore simply set in the dedicated config file.
+The layout format differs, which is why a mapping is defined that converts an Overcooked-AI layout into a cooperative cuisine layout.
+The layout file has to be present in cooperative_cuisine/configs/layouts/overcooked-ai.
diff --git a/cooperative_cuisine/reinforcement_learning/play_gym.py b/cooperative_cuisine/reinforcement_learning/play_gym.py
index 96c7e73636cf46f961323885a6fbf117184fdffd..1d29bd4a7aaa75ebea585b53171ec8ef5b299a5f 100644
--- a/cooperative_cuisine/reinforcement_learning/play_gym.py
+++ b/cooperative_cuisine/reinforcement_learning/play_gym.py
@@ -7,6 +7,9 @@ from gym_env import EnvGymWrapper, SimpleActionSpace

 @hydra.main(version_base="1.3", config_path="config", config_name="rl_config")
 def main(cfg: DictConfig):
+    """
+    Enables a human to steer the agent in the RL environment via the keyboard.
+    """
     env = EnvGymWrapper(cfg)
     env.render_mode = "rgb_array"
     play(env, keys_to_action={"a": 2, "d": 3, "w": 0, "s": 1, " ": 4, "k": 5}, noop=8)
diff --git a/cooperative_cuisine/reinforcement_learning/rl_small.layout b/cooperative_cuisine/reinforcement_learning/rl_small.layout
deleted file mode 100644
index 1743aba48c998e33800d0e4366a84766bf67fb24..0000000000000000000000000000000000000000
--- a/cooperative_cuisine/reinforcement_learning/rl_small.layout
+++ /dev/null
@@ -1,4 +0,0 @@
-##X#
-T__L
-U__P
-#C$#
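Since the files above are Hydra config groups, the Overcooked-AI setup can be selected without editing rl_config.yaml. A sketch using Hydra's compose API (assumes the working directory is cooperative_cuisine/reinforcement_learning; the command-line equivalent would be `python train_single_agent.py environment=overcooked-ai_environment_config item_info=item_info_overcooked-ai`):

```python
from hydra import compose, initialize

# Build the same cfg object that @hydra.main hands to the scripts below.
with initialize(version_base="1.3", config_path="config"):
    cfg = compose(
        config_name="rl_config",
        overrides=[
            "environment=overcooked-ai_environment_config",
            "item_info=item_info_overcooked-ai",
        ],
    )
print(cfg.environment.layout_name)
```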
+ """ additional_config = OmegaConf.to_container(cfg.additional_configs, resolve=True) - model_save_path = additional_config["log_path"] + "/" + additional_config["checkpoint_path"] + "/" + \ - additional_config["project_name"] + "_" + OmegaConf.to_container(cfg.model, resolve=True)[ - "model_name"] + model_save_path = Path(additional_config["log_path"]) / Path(additional_config["checkpoint_path"]) / Path( + additional_config["project_name"] + "_" + OmegaConf.to_container(cfg.model, resolve=True)["model_name"]) model_class = call(cfg.model.model_type_inference) model = model_class(model_save_path) env = EnvGymWrapper(cfg) - #check_env(env) + # check_env(env) obs, info = env.reset() print(obs) while True: diff --git a/cooperative_cuisine/reinforcement_learning/train_single_agent.py b/cooperative_cuisine/reinforcement_learning/train_single_agent.py index d039aa9a98df90a85984c74684236dc3b5ba2dd8..aaad353f444464e2c64075e207d9c65bc7a50947 100644 --- a/cooperative_cuisine/reinforcement_learning/train_single_agent.py +++ b/cooperative_cuisine/reinforcement_learning/train_single_agent.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Any import wandb from omegaconf import DictConfig, OmegaConf @@ -17,13 +18,17 @@ from hydra.utils import instantiate @hydra.main(version_base="1.3", config_path="config", config_name="rl_config") def main(cfg: DictConfig): - additional_configs = OmegaConf.to_container(cfg.additional_configs, resolve=True) - rl_logs = Path(additional_configs["log_path"]) + """ + trains an agent from scratch and saves the model to the specified path + All configs are managed with hydra. + """ + additional_configs: dict[str, Any] = OmegaConf.to_container(cfg.additional_configs, resolve=True) + rl_logs: Path = Path(additional_configs["log_path"]) rl_logs.mkdir(exist_ok=True) - rl_agent_checkpoints = rl_logs / Path(additional_configs["checkpoint_path"]) + rl_agent_checkpoints: Path = rl_logs / Path(additional_configs["checkpoint_path"]) rl_agent_checkpoints.mkdir(exist_ok=True) - config = OmegaConf.to_container(cfg.model, resolve=True) - debug = additional_configs["debug_mode"] + config: dict[str, Any] = OmegaConf.to_container(cfg.model, resolve=True) + debug: bool = additional_configs["debug_mode"] vec_env = additional_configs["vec_env"] number_envs_parallel = config["number_envs_parallel"] model_class = instantiate(cfg.model.model_type)