diff --git a/overcooked_simulator/counter_factory.py b/overcooked_simulator/counter_factory.py index 93193edad4afa17ad4e1c40e39f60a40c880ceaf..02abcf47a6e4f1ce408bd673528934a96f127585 100644 --- a/overcooked_simulator/counter_factory.py +++ b/overcooked_simulator/counter_factory.py @@ -51,6 +51,7 @@ from overcooked_simulator.counters import ( Trashcan, ) from overcooked_simulator.game_items import ItemInfo, ItemType, CookingEquipment, Plate +from overcooked_simulator.hooks import Hooks from overcooked_simulator.order import OrderAndScoreManager from overcooked_simulator.utils import get_closest @@ -111,6 +112,7 @@ class CounterFactory: serving_window_additional_kwargs: dict[str, Any], plate_config: PlateConfig, order_and_score: OrderAndScoreManager, + hook: Hooks, ) -> None: """Constructor for the `CounterFactory` class. Set up the attributes necessary to instantiate the counters. @@ -166,6 +168,9 @@ class CounterFactory: } """A dictionary mapping cooking counters to the list of equipment items associated with them.""" + self.hook = hook + """Reference to the hook manager.""" + def get_counter_object(self, c: str, pos: npt.NDArray[float]) -> Counter: """Create and returns a counter object based on the provided character and position.""" @@ -188,14 +193,16 @@ class CounterFactory: by_equipment_name=item_info.name ), ), + hook=self.hook, ) elif item_info.type == ItemType.Ingredient: - return Dispenser(pos=pos, dispensing=item_info) + return Dispenser(pos=pos, hook=self.hook, dispensing=item_info) if counter_class is None: counter_class = self.counter_classes[self.layout_chars_config[c]] kwargs = { "pos": pos, + "hook": self.hook, } if issubclass(counter_class, (CuttingBoard, Sink)): kwargs["transitions"] = self.filter_item_info( diff --git a/overcooked_simulator/counters.py b/overcooked_simulator/counters.py index 7e07048f5e544318c199b5652d64c6aff03bf571..d4b72eec6029dd96ca6c8d5604deb64b12b0c970 100644 --- a/overcooked_simulator/counters.py +++ b/overcooked_simulator/counters.py @@ -42,6 +42,29 @@ from collections.abc import Iterable from datetime import datetime, timedelta from typing import TYPE_CHECKING, Optional, Callable, Set +from overcooked_simulator.hooks import ( + Hooks, + POST_DISPENSER_PICK_UP, + PRE_DISPENSER_PICK_UP, + CUTTING_BOARD_PROGRESS, + CUTTING_BOARD_100, + CUTTING_BOARD_START_INTERACT, + CUTTING_BOARD_END_INTERACT, + PRE_COUNTER_PICK_UP, + POST_COUNTER_PICK_UP, + PRE_SERVING, + POST_SERVING, + NO_SERVING, + DIRTY_PLATE_ARRIVES, + TRASHCAN_USAGE, + PLATE_CLEANED, + SINK_START_INTERACT, + SINK_END_INTERACT, + ADDED_PLATE_TO_SINK, + DROP_ON_SINK_ADDON, + PICK_UP_FROM_SINK_ADDON, +) + if TYPE_CHECKING: from overcooked_simulator.overcooked_environment import ( OrderAndScoreManager, @@ -74,6 +97,7 @@ class Counter: def __init__( self, pos: npt.NDArray[float], + hook: Hooks, occupied_by: Optional[Item] = None, uid: hex = None, **kwargs, @@ -90,6 +114,8 @@ class Counter: """The position of the counter.""" self.occupied_by: Optional[Item] = occupied_by """What is on top of the counter, e.g., `Item`s.""" + self.hook = hook + """Reference to the hook manager.""" @property def occupied(self) -> bool: @@ -105,16 +131,36 @@ class Counter: Returns: The item which the counter is occupied by. None if nothing is there. """ + self.hook(PRE_COUNTER_PICK_UP, counter=self, on_hands=on_hands) if on_hands: if self.occupied_by: occupied_by = self.occupied_by self.occupied_by = None + self.hook( + POST_COUNTER_PICK_UP, + counter=self, + on_hands=on_hands, + return_this=occupied_by, + ) return occupied_by return None if self.occupied_by and isinstance(self.occupied_by, CookingEquipment): - return self.occupied_by.release() + return_this = self.occupied_by.release() + self.hook( + POST_COUNTER_PICK_UP, + counter=self, + on_hands=on_hands, + return_this=return_this, + ) + return return_this occupied_by = self.occupied_by self.occupied_by = None + self.hook( + POST_COUNTER_PICK_UP, + counter=self, + on_hands=on_hands, + return_this=occupied_by, + ) return occupied_by def can_drop_off(self, item: Item) -> bool: @@ -186,7 +232,7 @@ class CuttingBoard(Counter): The character `C` in the `layout` file represents the CuttingBoard. """ - def __init__(self, pos: np.ndarray, transitions: dict[str, ItemInfo], **kwargs): + def __init__(self, transitions: dict[str, ItemInfo], **kwargs): self.progressing: bool = False """Is a player progressing/cutting on the board.""" self.transitions: dict[str, ItemInfo] = transitions @@ -197,7 +243,7 @@ class CuttingBoard(Counter): } """For faster accessing the needed item. Keys are the ingredients that the player can put and chop on the board.""" - super().__init__(pos=pos, **kwargs) + super().__init__(**kwargs) def progress(self, passed_time: timedelta, now: datetime): """Called by environment step function for time progression. @@ -222,11 +268,18 @@ class CuttingBoard(Counter): self.occupied_by.progress( equipment=self.__class__.__name__, percent=percent ) + self.hook( + CUTTING_BOARD_PROGRESS, + counter=self, + percent=percent, + passed_time=passed_time, + ) if self.occupied_by.progress_percentage == 1.0: self.occupied_by.reset() self.occupied_by.name = self.inverted_transition_dict[ self.occupied_by.name ].name + self.hook(CUTTING_BOARD_100, counter=self) def start_progress(self): """Starts the cutting process.""" @@ -239,10 +292,12 @@ class CuttingBoard(Counter): def interact_start(self): """Handles player interaction, starting to hold key down.""" self.start_progress() + self.hook(CUTTING_BOARD_START_INTERACT, counter=self) def interact_stop(self): """Handles player interaction, stopping to hold key down.""" self.pause_progress() + self.hook(CUTTING_BOARD_END_INTERACT, counter=self) def to_dict(self) -> dict: d = super().to_dict() @@ -264,7 +319,6 @@ class ServingWindow(Counter): def __init__( self, - pos: npt.NDArray[float], order_and_score: OrderAndScoreManager, meals: set[str], env_time_func: Callable[[], datetime], @@ -281,14 +335,17 @@ class ServingWindow(Counter): """All allowed meals by the `environment_config.yml`.""" self.env_time_func: Callable[[], datetime] = env_time_func """Reference to get the current env time by calling the `env_time_func`.""" - super().__init__(pos=pos, **kwargs) + super().__init__(**kwargs) def drop_off(self, item) -> Item | None: env_time = self.env_time_func() + self.hook(PRE_SERVING, counter=self, item=item, env_time=env_time) if self.order_and_score.serve_meal(item=item, env_time=env_time): if self.plate_dispenser is not None: self.plate_dispenser.update_plate_out_of_kitchen(env_time=env_time) + self.hook(POST_SERVING, counter=self, item=item, env_time=env_time) return None + self.hook(NO_SERVING, counter=self, item=item, env_time=env_time) return item def can_drop_off(self, item: Item) -> bool: @@ -322,18 +379,24 @@ class Dispenser(Counter): Which also is easier for the visualization of the dispenser. """ - def __init__(self, pos: npt.NDArray[float], dispensing: ItemInfo, **kwargs): + def __init__(self, dispensing: ItemInfo, **kwargs): self.dispensing: ItemInfo = dispensing """`ItemInfo` what the the Dispenser is dispensing. One ready object always is on top of the counter.""" super().__init__( - pos=pos, occupied_by=self.create_item(), **kwargs, ) def pick_up(self, on_hands: bool = True) -> Item | None: + self.hook(PRE_DISPENSER_PICK_UP, counter=self, on_hands=on_hands) return_this = self.occupied_by self.occupied_by = self.create_item() + self.hook( + POST_DISPENSER_PICK_UP, + counter=self, + on_hands=on_hands, + return_this=return_this, + ) return return_this def drop_off(self, item: Item) -> Item | None: @@ -390,13 +453,12 @@ class PlateDispenser(Counter): def __init__( self, - pos: npt.NDArray[float], dispensing: ItemInfo, plate_config: PlateConfig, plate_transitions: dict[str, ItemInfo], **kwargs, ) -> None: - super().__init__(pos=pos, **kwargs) + super().__init__(**kwargs) self.dispensing: ItemInfo = dispensing """Plate ItemInfo.""" self.occupied_by: deque = deque() @@ -467,6 +529,7 @@ class PlateDispenser(Counter): idx_delete = [] for i, times in enumerate(self.out_of_kitchen_timer): if times < now: + self.hook(DIRTY_PLATE_ARRIVES, counter=self, times=times, now=now) idx_delete.append(i) log.debug("Add dirty plate") self.add_dirty_plate() @@ -501,10 +564,8 @@ class Trashcan(Counter): The character `X` in the `layout` file represents the Trashcan. """ - def __init__( - self, order_and_score: OrderAndScoreManager, pos: npt.NDArray[float], **kwargs - ): - super().__init__(pos, **kwargs) + def __init__(self, order_and_score: OrderAndScoreManager, **kwargs): + super().__init__(**kwargs) self.order_and_score: OrderAndScoreManager = order_and_score """Reference to the `OrderAndScoreManager`, because unnecessary removed items can affect the score.""" @@ -513,11 +574,14 @@ class Trashcan(Counter): def drop_off(self, item: Item) -> Item | None: if isinstance(item, CookingEquipment): - self.order_and_score.apply_penalty_for_using_trash(item.content_list) + penalty = self.order_and_score.apply_penalty_for_using_trash( + item.content_list + ) item.reset_content() return item else: - self.order_and_score.apply_penalty_for_using_trash(item) + penalty = self.order_and_score.apply_penalty_for_using_trash(item) + self.hook(TRASHCAN_USAGE, counter=self, item=item, penalty=penalty) return None def can_drop_off(self, item: Item) -> bool: @@ -588,12 +652,11 @@ class Sink(Counter): def __init__( self, - pos: npt.NDArray[float], transitions: dict[str, ItemInfo], sink_addon: SinkAddon = None, **kwargs, ): - super().__init__(pos=pos, **kwargs) + super().__init__(**kwargs) self.progressing: bool = False """If a player currently cleans a plate.""" self.sink_addon: SinkAddon = sink_addon @@ -630,6 +693,7 @@ class Sink(Counter): equipment=self.__class__.__name__, percent=percent ) if self.occupied_by[-1].progress_percentage == 1.0: + self.hook(PLATE_CLEANED, counter=self) self.occupied_by[-1].reset() self.occupied_by[-1].name = name plate = self.occupied_by.pop() @@ -648,16 +712,19 @@ class Sink(Counter): def interact_start(self): """Handles player interaction, starting to hold key down.""" self.start_progress() + self.hook(SINK_START_INTERACT, counter=self) def interact_stop(self): """Handles player interaction, stopping to hold key down.""" self.pause_progress() + self.hook(SINK_END_INTERACT, counter=self) def can_drop_off(self, item: Item) -> bool: return isinstance(item, Plate) and not item.clean def drop_off(self, item: Plate) -> Item | None: self.occupied_by.appendleft(item) + self.hook(ADDED_PLATE_TO_SINK, counter=self, item=item) return None def pick_up(self, on_hands: bool = True) -> Item | None: @@ -681,8 +748,8 @@ class SinkAddon(Counter): The character `+` in the `layout` file represents the SinkAddon. """ - def __init__(self, pos: npt.NDArray[float], occupied_by=None): - super().__init__(pos=pos) + def __init__(self, occupied_by=None, **kwargs): + super().__init__(**kwargs) # maybe check if occupied by is already a list or deque? self.occupied_by: deque = deque([occupied_by]) if occupied_by else deque() """The stack of clean plates.""" @@ -691,6 +758,7 @@ class SinkAddon(Counter): return self.occupied_by and self.occupied_by[-1].can_combine(item) def drop_off(self, item: Item) -> Item | None: + self.hook(DROP_ON_SINK_ADDON, counter=self, item=item) return self.occupied_by[-1].combine(item) def add_clean_plate(self, plate: Plate): @@ -699,4 +767,5 @@ class SinkAddon(Counter): def pick_up(self, on_hands: bool = True) -> Item | None: if self.occupied_by: + self.hook(PICK_UP_FROM_SINK_ADDON) return self.occupied_by.pop() diff --git a/overcooked_simulator/hooks.py b/overcooked_simulator/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..e147cfd3d45fbffc75c4982a7bd10c08d83fabe8 --- /dev/null +++ b/overcooked_simulator/hooks.py @@ -0,0 +1,107 @@ +from collections import defaultdict +from typing import Callable + +# TODO add player_id as kwarg to all hooks -> pass player id to all methods + +ITEM_INFO_LOADED = "item_info_load" +"""Called after the item info is loaded and stored in the env attribute `item_info`. The kwargs are the passed config +(`item_info`) to the environment from which it was loaded and if it is a file path or the config string (`as_files`)""" + +LAYOUT_FILE_PARSED = "layout_file_parsed" +"""After the layout file was parsed. No additional kwargs. Everything is stored in the env.""" + +ENV_INITIALIZED = "env_initialized" +"""At the end of the __init__ method. No additional kwargs. Everything is stored in the env.""" + +PRE_PERFORM_ACTION = "pre_perform_action" +"""Before an action is performed / entered into the environment. `action` kwarg with the entered action.""" + +POST_PERFORM_ACTION = "post_perfom_action" +"""After an action is performed / entered into the environment. `action` kwarg with the entered action.""" + +# TODO Pre and Post Perform Movement + +PLAYER_ADDED = "player_added" +"""Called after a player has been added. Kwargs: `player_name` and `pos`.""" + +GAME_ENDED_STEP = "game_ended_step" + +PRE_STATE = "pre_state" + +STATE_DICT = "state_dict" + +JSON_STATE = "json_state" + +PRE_RESET_ENV_TIME = "pre_reset_env_time" + +POST_RESET_ENV_TIME = "post_reset_env_time" + +PRE_COUNTER_PICK_UP = "pre_counter_pick_up" +POST_COUNTER_PICK_UP = "post_counter_pick_up" + +PRE_COUNTER_DROP_OFF = "pre_counter_drop_off" +POST_COUNTER_DROP_OFF = "post_counter_drop_off" + +PRE_DISPENSER_PICK_UP = "dispenser_pick_up" +POST_DISPENSER_PICK_UP = "dispenser_pick_up" + +CUTTING_BOARD_PROGRESS = "cutting_board_progress" +CUTTING_BOARD_100 = "cutting_board_100" + +CUTTING_BOARD_START_INTERACT = "cutting_board_start_interaction" +CUTTING_BOARD_END_INTERACT = "cutting_board_end_interact" + +PRE_SERVING = "pre_serving" +POST_SERVING = "post_serving" +NO_SERVING = "no_serving" + +# TODO drop off + +DIRTY_PLATE_ARRIVES = "dirty_plate_arrives" + +TRASHCAN_USAGE = "trashcan_usage" + +PLATE_CLEANED = "plate_cleaned" + +SINK_START_INTERACT = "sink_start_interact" + +SINK_END_INTERACT = "sink_end_interact" + +ADDED_PLATE_TO_SINK = "added_plate_to_sink" + +DROP_ON_SINK_ADDON = "drop_on_sink_addon" + +PICK_UP_FROM_SINK_ADDON = "pick_up_from_sink_addon" + +SERVE_NOT_ORDERED_MEAL = "serve_not_ordered_meal" + +SERVE_WITHOUT_PLATE = "serve_without_plate" + +COMPLETED_ORDER = "completed_order" + +INIT_ORDERS = "init_orders" + +NEW_ORDERS = "new_orders" + +ACTION_ON_NOT_REACHABLE_COUNTER = "action_on_not_reachable_counter" + +ACTION_PUT = "action_put" + +ACTION_INTERACT_START = "action_interact_start" + + +class Hooks: + def __init__(self, env): + self.hooks = defaultdict(list) + self.env = env + + def __call__(self, hook_ref, **kwargs): + for callback in self.hooks[hook_ref]: + callback(hook_ref=hook_ref, env=self.env, **kwargs) + + def register_callback(self, hook_ref: str | list[str], callback: Callable): + if isinstance(hook_ref, (tuple, list, set)): # TODO check for iterable + for ref in hook_ref: + self.hooks[ref].append(callback) + else: + self.hooks[hook_ref].append(callback) diff --git a/overcooked_simulator/order.py b/overcooked_simulator/order.py index 0c954882a3ef3eb0708ffb63a71abc090814f562..299bcdd6e3c7afbfcd9b2a57c816a911606106f2 100644 --- a/overcooked_simulator/order.py +++ b/overcooked_simulator/order.py @@ -54,6 +54,14 @@ from datetime import datetime, timedelta from typing import Callable, Tuple, Any, Deque, Protocol, TypedDict, Type from overcooked_simulator.game_items import Item, Plate, ItemInfo +from overcooked_simulator.hooks import ( + Hooks, + SERVE_NOT_ORDERED_MEAL, + SERVE_WITHOUT_PLATE, + COMPLETED_ORDER, + INIT_ORDERS, + NEW_ORDERS, +) log = logging.getLogger(__name__) """The logger for this module.""" @@ -153,7 +161,7 @@ class OrderGeneration: class OrderAndScoreManager: """The Order and Score Manager that is called from the serving window.""" - def __init__(self, order_config, available_meals: dict[str, ItemInfo]): + def __init__(self, order_config, available_meals: dict[str, ItemInfo], hook: Hooks): self.score: float = 0.0 """The current score of the environment.""" self.order_gen: OrderGeneration = order_config["order_gen_class"]( @@ -187,6 +195,9 @@ class OrderAndScoreManager: self.last_expired: list[Order] = [] """Cache last expired orders for `OrderGeneration.get_orders` call.""" + self.hook = hook + """Reference to the hook manager.""" + def update_next_relevant_time(self): """For more efficient checking when to do something in the progress call.""" next_relevant_time = datetime.max @@ -209,6 +220,12 @@ class OrderAndScoreManager: if order is None: if self.serving_not_ordered_meals: accept, score = self.serving_not_ordered_meals(meal) + self.hook( + SERVE_NOT_ORDERED_MEAL, + accept=accept, + score=score, + meal=meal, + ) if accept: log.info( f"Serving meal without order {meal.name!r} with score {score}" @@ -236,7 +253,10 @@ class OrderAndScoreManager: self.last_finished.append(order) del self.open_orders[index] self.served_meals.append((meal, env_time)) + self.hook(COMPLETED_ORDER, score=score, order=order, meal=meal) return True + else: + self.hook(SERVE_WITHOUT_PLATE, item=item) log.info(f"Do not serve item {item}") return False @@ -248,6 +268,7 @@ class OrderAndScoreManager: def create_init_orders(self, env_time): """Create the initial orders in an environment.""" init_orders = self.order_gen.init_orders(env_time) + self.hook(INIT_ORDERS) self.setup_penalties(new_orders=init_orders, env_time=env_time) self.open_orders.extend(init_orders) @@ -259,6 +280,8 @@ class OrderAndScoreManager: new_finished_orders=self.last_finished, expired_orders=self.last_expired, ) + if new_orders: + self.hook(NEW_ORDERS, new_orders=new_orders) self.setup_penalties(new_orders=new_orders, env_time=now) self.open_orders.extend(new_orders) self.last_finished = [] @@ -277,6 +300,7 @@ class OrderAndScoreManager: for i, (penalty_time, penalty) in enumerate(order.timed_penalties): # check penalties if penalty_time < now: + # TODO add hook self.score -= penalty remove_penalties.append(i) @@ -318,9 +342,11 @@ class OrderAndScoreManager: for order in self.open_orders ] - def apply_penalty_for_using_trash(self, remove: Item | list[Item]): + def apply_penalty_for_using_trash(self, remove: Item | list[Item]) -> float: """Is called if a item is put into the trashcan.""" - self.increment_score(self.penalty_for_trash(remove)) + penalty = self.penalty_for_trash(remove) + self.increment_score(penalty) + return penalty class ScoreCalcFuncType(Protocol): diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py index 4aee44b774204eec337a256003a2321cf3110f72..f7880728e98448d12f26539a8ba3612b1406bd86 100644 --- a/overcooked_simulator/overcooked_environment.py +++ b/overcooked_simulator/overcooked_environment.py @@ -9,7 +9,7 @@ import sys from datetime import timedelta, datetime from enum import Enum from pathlib import Path -from typing import Literal, TypedDict +from typing import Literal, TypedDict, Callable import numpy as np import numpy.typing as npt @@ -24,6 +24,24 @@ from overcooked_simulator.game_items import ( ItemInfo, ItemType, ) +from overcooked_simulator.hooks import ( + ITEM_INFO_LOADED, + LAYOUT_FILE_PARSED, + ENV_INITIALIZED, + PRE_PERFORM_ACTION, + POST_PERFORM_ACTION, + PLAYER_ADDED, + GAME_ENDED_STEP, + PRE_STATE, + STATE_DICT, + JSON_STATE, + PRE_RESET_ENV_TIME, + POST_RESET_ENV_TIME, + Hooks, + ACTION_ON_NOT_REACHABLE_COUNTER, + ACTION_PUT, + ACTION_INTERACT_START, +) from overcooked_simulator.order import ( OrderAndScoreManager, OrderConfig, @@ -106,6 +124,9 @@ class Environment: item_info: Path | str, as_files: bool = True, ): + self.hook: Hooks = Hooks(self) + """Hook manager. Register callbacks and create hook points with additional kwargs.""" + self.players: dict[str, Player] = {} """the player, keyed by their id/name.""" @@ -126,6 +147,8 @@ class Environment: self.item_info: dict[str, ItemInfo] = self.load_item_info(item_info) """The loaded item info dict. Keys are the item names.""" + self.hook(ITEM_INFO_LOADED, item_info=item_info, as_files=as_files) + # self.validate_item_info() if self.environment_config["meals"]["all"]: self.allowed_meal_names = set( @@ -147,6 +170,7 @@ class Environment: for item, info in self.item_info.items() if info.type == ItemType.Meal and item in self.allowed_meal_names }, + hook=self.hook, ) """The manager for the orders and score update.""" @@ -170,6 +194,7 @@ class Environment: ) ), order_and_score=self.order_and_score, + hook=self.hook, ) ( @@ -177,6 +202,7 @@ class Environment: self.designated_player_positions, self.free_positions, ) = self.parse_layout_file() + self.hook(LAYOUT_FILE_PARSED) self.world_borders_x = [-0.5, self.kitchen_width - 0.5] self.world_borders_y = [-0.5, self.kitchen_height - 0.5] @@ -211,6 +237,8 @@ class Environment: """The relative env time when it will stop/end""" log.debug(f"End time: {self.env_time_end}") + self.hook(ENV_INITIALIZED) + @property def game_ended(self) -> bool: """Whether the game is over or not based on the calculated `Environment.env_time_end`""" @@ -351,6 +379,7 @@ class Environment: action: The action to be performed """ assert action.player in self.players.keys(), "Unknown player." + self.hook(PRE_PERFORM_ACTION, action=action) player = self.players[action.player] if action.action_type == ActionType.MOVEMENT: @@ -362,16 +391,23 @@ class Environment: counter = self.get_facing_counter(player) if player.can_reach(counter): if action.action_type == ActionType.PUT: - player.pick_action(counter) - + player.put_action(counter) + self.hook(ACTION_PUT, action=action, counter=counter) elif action.action_type == ActionType.INTERACT: if action.action_data == InterActionData.START: player.perform_interact_hold_start(counter) player.last_interacted_counter = counter + self.hook(ACTION_INTERACT_START, action=action, counter=counter) + else: + self.hook( + ACTION_ON_NOT_REACHABLE_COUNTER, action=action, counter=counter + ) if action.action_data == InterActionData.STOP: if player.last_interacted_counter: player.perform_interact_hold_stop(player.last_interacted_counter) + self.hook(POST_PERFORM_ACTION, action=action) + def get_facing_counter(self, player: Player): """Determines the counter which the player is looking at. Adds a multiple of the player facing direction onto the player position and finds the closest @@ -572,6 +608,8 @@ class Environment: log.debug("No free positions left in kitchens") player.update_facing_point() + self.hook(PLAYER_ADDED, player_name=player_name, pos=pos) + def detect_collision_world_bounds(self, player: Player): """Checks for detections of the player and the world bounds. @@ -594,9 +632,12 @@ class Environment: """Performs a step of the environment. Affects time based events such as cooking or cutting things, orders and time limits. """ + # self.hook(PRE_STEP, passed_time=passed_time) self.env_time += passed_time - if not self.game_ended: + if self.game_ended: + self.hook(GAME_ENDED_STEP) + else: for player in self.players.values(): if self.env_time <= player.movement_until: self.perform_movement(player, passed_time) @@ -605,6 +646,8 @@ class Environment: counter.progress(passed_time=passed_time, now=self.env_time) self.order_and_score.progress(passed_time=passed_time, now=self.env_time) + # self.hook(POST_STEP, passed_time=passed_time) + def get_state(self): """Get the current state of the game environment. The state here is accessible by the current python objects. @@ -622,6 +665,7 @@ class Environment: } def get_json_state(self, player_id: str = None): + self.hook(PRE_STATE, player_id=player_id) state = { "players": [p.to_dict() for p in self.players.values()], "counters": [c.to_dict() for c in self.counters], @@ -634,11 +678,18 @@ class Environment: self.env_time_end - self.env_time, timedelta(0) ).total_seconds(), } + self.hook(STATE_DICT, state=state) json_data = json.dumps(state) + self.hook(JSON_STATE, json_data=json_data) assert StateRepresentation.model_validate_json(json_data=json_data) return json_data def reset_env_time(self): """Reset the env time to the initial time, defined by `create_init_env_time`.""" + self.hook(PRE_RESET_ENV_TIME) self.env_time = create_init_env_time() + self.hook(POST_RESET_ENV_TIME) log.debug(f"Reset env time to {self.env_time}") + + def register_callback_for_hook(self, hook_ref: str | list[str], callback: Callable): + self.hook.register_callback(hook_ref, callback) diff --git a/overcooked_simulator/player.py b/overcooked_simulator/player.py index 5f87b87f7032e54f12cace9fecda1d882aac5c54..0507e279ab80b765ead610623cd29991e37b0332 100644 --- a/overcooked_simulator/player.py +++ b/overcooked_simulator/player.py @@ -138,7 +138,7 @@ class Player: """ return np.linalg.norm(counter.pos - self.facing_point) <= self.interaction_range - def pick_action(self, counter: Counter): + def put_action(self, counter: Counter): """Performs the pickup-action with the counter. Handles the logic of what the player is currently holding, what is currently on the counter and what can be picked up or combined in hand. diff --git a/tests/test_start.py b/tests/test_start.py index 1dbdafdafd48a79f29969cc686bc87665562799e..cd4bb7abc7831770c0201ec45eb4b17ef119483f 100644 --- a/tests/test_start.py +++ b/tests/test_start.py @@ -2,6 +2,7 @@ from datetime import timedelta import numpy as np import pytest +from overcooked_simulator.hook import Hooks from overcooked_simulator import ROOT_DIR from overcooked_simulator.counters import Counter, CuttingBoard @@ -126,7 +127,7 @@ def test_collision_detection(env_config, layout_config, item_info): env = Environment(env_config, layout_config, item_info, as_files=False) counter_pos = np.array([1, 2]) - counter = Counter(counter_pos) + counter = Counter(pos=counter_pos, hook=Hooks(env)) env.counters = [counter] env.add_player("1", np.array([1, 1])) env.add_player("2", np.array([1, 4])) @@ -156,7 +157,7 @@ def test_player_reach(env_config, layout_empty_config, item_info): env = Environment(env_config, layout_empty_config, item_info, as_files=False) counter_pos = np.array([2, 2]) - counter = Counter(counter_pos) + counter = Counter(pos=counter_pos, hook=Hooks(env)) env.counters = [counter] env.add_player("1", np.array([2, 4])) env.players["1"].player_speed_units_per_seconds = 1 @@ -175,7 +176,7 @@ def test_pickup(env_config, layout_config, item_info): env = Environment(env_config, layout_config, item_info, as_files=False) counter_pos = np.array([2, 2]) - counter = Counter(counter_pos) + counter = Counter(pos=counter_pos, hook=Hooks(env)) counter.occupied_by = Item(name="Tomato", item_info=None) env.counters = [counter] @@ -226,7 +227,8 @@ def test_processing(env_config, layout_config, item_info): env = Environment(env_config, layout_config, item_info, as_files=False) counter_pos = np.array([2, 2]) counter = CuttingBoard( - counter_pos, + pos=counter_pos, + hook=Hooks(env), transitions={ "ChoppedTomato": ItemInfo( name="ChoppedTomato",