Skip to content
Snippets Groups Projects
hooks.py 4.86 KiB
Newer Older
  • Learn to ignore specific revisions
  • from __future__ import annotations
    
    from abc import abstractmethod
    
    from collections import defaultdict
    
    from functools import partial
    
    from typing import Callable, Any, TYPE_CHECKING, Type
    
    if TYPE_CHECKING:
        from overcooked_simulator.overcooked_environment import Environment
    
    
    # TODO add player_id as kwarg to all hooks -> pass player id to all methods
    
    # Hook identifiers. Each constant is the string key under which callbacks are registered and triggered
    # (see `Hooks.register_callback` / `Hooks.__call__`). The bare strings following a constant describe it.

    # NOTE: the string value is "item_info_load" (no trailing "ed"), unlike the constant name.
    ITEM_INFO_LOADED = "item_info_load"
    """Called after the item info is loaded and stored in the env attribute `item_info`. The kwargs are the passed config 
    (`item_info`) to the environment from which it was loaded and if it is a file path or the config string (`as_files`)"""

    LAYOUT_FILE_PARSED = "layout_file_parsed"
    """After the layout file was parsed. No additional kwargs. Everything is stored in the env."""


    # NOTE(review): no description in this file; presumably related to the item info config — confirm where
    # the hook is triggered.
    ITEM_INFO_CONFIG = "item_info_config"


    ENV_INITIALIZED = "env_initialized"
    """At the end of the __init__ method. No additional kwargs. Everything is stored in the env."""

    PRE_PERFORM_ACTION = "pre_perform_action"
    """Before an action is performed / entered into the environment. `action` kwarg with the entered action."""


    POST_PERFORM_ACTION = "post_perform_action"

    """After an action is performed / entered into the environment. `action` kwarg with the entered action."""

    # TODO Pre and Post Perform Movement

    PLAYER_ADDED = "player_added"
    """Called after a player has been added. Kwargs: `player_name` and `pos`."""
    
    # The hook ids below have no description in this file. The kwargs a callback receives are decided where
    # the hook is triggered, not in this module. The remarks are inferred from the names only — TODO confirm
    # against the code that triggers each hook.

    # presumably triggered in the step in which the game ends
    GAME_ENDED_STEP = "game_ended_step"

    # presumably related to building the environment state (before creation / as dict / as JSON)
    PRE_STATE = "pre_state"

    STATE_DICT = "state_dict"

    JSON_STATE = "json_state"

    # presumably before / after the environment time is reset
    PRE_RESET_ENV_TIME = "pre_reset_env_time"

    POST_RESET_ENV_TIME = "post_reset_env_time"

    # presumably before / after an item is picked up from a counter
    PRE_COUNTER_PICK_UP = "pre_counter_pick_up"
    POST_COUNTER_PICK_UP = "post_counter_pick_up"

    # presumably before / after an item is dropped off on a counter
    PRE_COUNTER_DROP_OFF = "pre_counter_drop_off"
    POST_COUNTER_DROP_OFF = "post_counter_drop_off"
    
    # Hook ids for picking up from a dispenser. Callbacks are looked up by the string value, so the two
    # constants must hold different strings; otherwise a callback registered for one also runs for the other.
    # PRE keeps its id without a `pre_` prefix so that existing registrations that use the literal string
    # "dispenser_pick_up" keep working.
    PRE_DISPENSER_PICK_UP = "dispenser_pick_up"
    POST_DISPENSER_PICK_UP = "post_dispenser_pick_up"
    
    # The hook ids below have no description in this file. The kwargs a callback receives are decided where
    # the hook is triggered, not in this module. Remarks marked "presumably" are inferred from the names
    # only — TODO confirm against the code that triggers each hook.

    # presumably cutting board progress updates and completion (100 %)
    CUTTING_BOARD_PROGRESS = "cutting_board_progress"
    CUTTING_BOARD_100 = "cutting_board_100"

    # NOTE: the start id ends in "interaction" while the end id ends in "interact"; keep this in mind when
    # referring to these hooks by their string value.
    CUTTING_BOARD_START_INTERACT = "cutting_board_start_interaction"
    CUTTING_BOARD_END_INTERACT = "cutting_board_end_interact"

    # presumably before / after a meal is served, and when serving does not happen
    PRE_SERVING = "pre_serving"
    POST_SERVING = "post_serving"
    NO_SERVING = "no_serving"

    # TODO drop off


    PLATE_OUT_OF_KITCHEN_TIME = "plate_out_of_kitchen_time"


    DIRTY_PLATE_ARRIVES = "dirty_plate_arrives"

    # `add_dummy_callbacks` prints "You used the trashcan!" for this hook.
    TRASHCAN_USAGE = "trashcan_usage"

    PLATE_CLEANED = "plate_cleaned"

    # `add_dummy_callbacks` prints "You started to use the Sink!" for this hook.
    SINK_START_INTERACT = "sink_start_interact"

    SINK_END_INTERACT = "sink_end_interact"

    ADDED_PLATE_TO_SINK = "added_plate_to_sink"

    DROP_ON_SINK_ADDON = "drop_on_sink_addon"

    PICK_UP_FROM_SINK_ADDON = "pick_up_from_sink_addon"

    # `add_dummy_callbacks` prints a "meal that was not ordered" message for this hook.
    SERVE_NOT_ORDERED_MEAL = "serve_not_ordered_meal"

    SERVE_WITHOUT_PLATE = "serve_without_plate"


    ORDER_DURATION_SAMPLE = "order_duration_sample"


    # `add_dummy_callbacks` prints "You completed an order!" for this hook.
    COMPLETED_ORDER = "completed_order"

    INIT_ORDERS = "init_orders"

    NEW_ORDERS = "new_orders"


    ORDER_EXPIRED = "order_expired"


    # presumably player actions: acting on a counter that is out of reach, putting, starting an interaction
    ACTION_ON_NOT_REACHABLE_COUNTER = "action_on_not_reachable_counter"

    ACTION_PUT = "action_put"

    ACTION_INTERACT_START = "action_interact_start"
    
    
    class Hooks:
        """Registry that maps hook references to callbacks and dispatches to them.

        Callbacks are kept per hook reference in registration order. Calling the
        instance with a hook reference invokes every callback registered for it.
        """

        def __init__(self, env):
            """Create an empty registry.

            Args:
                env: Object that is handed to every callback as the `env` keyword argument.
            """
            self.hooks = defaultdict(list)
            self.env = env

        def __call__(self, hook_ref, **kwargs):
            """Invoke all callbacks registered for `hook_ref`.

            Each callback is called as `callback(hook_ref=hook_ref, env=self.env, **kwargs)`.
            Triggering a hook without registered callbacks is a no-op.

            Args:
                hook_ref: The hook reference that was triggered.
                **kwargs: Hook specific keyword arguments forwarded to the callbacks.
            """
            # Use `.get` instead of indexing: indexing the defaultdict would insert an empty list for every
            # triggered hook that has no callbacks and let the registry grow on pure reads.
            for callback in self.hooks.get(hook_ref, ()):
                callback(hook_ref=hook_ref, env=self.env, **kwargs)

        def register_callback(self, hook_ref: str | list[str], callback: Callable):
            """Register `callback` for one or several hook references.

            Args:
                hook_ref: A single hook reference, or any non-string iterable of hook references
                    (list, tuple, set, frozenset, generator, ...). The callback is registered for each of them.
                callback: Callable invoked with the keyword arguments `hook_ref`, `env` and the hook specific
                    kwargs when the hook is triggered.
            """
            if isinstance(hook_ref, (str, bytes)):
                # Strings are iterable, but they denote exactly one hook.
                refs = (hook_ref,)
            else:
                try:
                    # Materialise first so that one-shot iterables (generators) are consumed exactly once.
                    refs = tuple(hook_ref)
                except TypeError:
                    # Not iterable: treat it as a single hook reference.
                    refs = (hook_ref,)
            for ref in refs:
                self.hooks[ref].append(callback)
    
    
    
    def print_hook_callback(text, env, **kwargs):
        """Print the current environment time followed by `text`.

        Any further keyword arguments (for example `hook_ref`) are accepted and ignored, so the function can
        be registered as a hook callback once `text` is bound, e.g. via `functools.partial`.

        Args:
            text: Message to print.
            env: Object providing the `env_time` attribute.
            **kwargs: Ignored.
        """
        timestamp = env.env_time
        print(timestamp, text)
    
    class HookCallbackClass:
        """Base class for callback objects that can be registered on hooks.

        Subclasses are expected to implement `__call__`.

        NOTE(review): the class does not derive from `abc.ABC`, so the `@abstractmethod` marker on `__call__`
        is not enforced at instantiation time.
        """

        def __init__(self, name: str, env: Environment, **kwargs):
            """Remember the callback's name and the environment.

            Args:
                name: Name of this callback instance.
                env: The environment the callback belongs to.
                **kwargs: Accepted for subclasses; ignored here.
            """
            self.env = env
            self.name = name

        @abstractmethod
        def __call__(self, hook_ref: str, env: Environment, **kwargs):
            """Handle a triggered hook; the base implementation does nothing.

            Args:
                hook_ref: Reference of the hook that was triggered.
                env: The environment passed by the dispatcher.
                **kwargs: Hook specific keyword arguments.
            """
            return None
    
    
    def hooks_via_callback_class(
        name: str,
        env: Environment,
        hooks: list[str],
        callback_class: Type[HookCallbackClass],
        callback_class_kwargs: dict[str, Any],
    ):
        """Instantiate a callback class once and register the instance for several hooks.

        Args:
            name: Name passed to the callback class.
            env: Environment passed to the callback class; its `register_callback_for_hook` method is used
                for the registration.
            hooks: Hook references the instance is registered for.
            callback_class: Class (or factory) called with `name`, `env` and `callback_class_kwargs`.
            callback_class_kwargs: Additional keyword arguments for `callback_class`.
        """
        callback_instance = callback_class(name=name, env=env, **callback_class_kwargs)
        # The very same instance serves all listed hooks.
        for hook_ref in hooks:
            env.register_callback_for_hook(hook_ref, callback_instance)
    
    
    
    def add_dummy_callbacks(env):
        """Register example callbacks that print a message for a handful of hooks.

        For each listed hook, `print_hook_callback` with a fixed `text` is registered through
        `env.register_callback_for_hook`.

        Args:
            env: Environment offering `register_callback_for_hook(hook_ref, callback)`.
        """
        demo_messages = (
            (SERVE_NOT_ORDERED_MEAL, "You tried to served a meal that was not ordered!"),
            (SINK_START_INTERACT, "You started to use the Sink!"),
            (COMPLETED_ORDER, "You completed an order!"),
            (TRASHCAN_USAGE, "You used the trashcan!"),
        )
        for hook_ref, message in demo_messages:
            printer = partial(print_hook_callback, text=message)
            env.register_callback_for_hook(hook_ref, printer)