From f5b714afe86e405c210b383b495b714acd2793be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20Schr=C3=B6der?= <fschroeder@techfak.uni-bielefeld.de> Date: Fri, 19 Jan 2024 14:22:31 +0100 Subject: [PATCH] update order doc strings and init timed penalties --- .../game_content/environment_config.yaml | 2 +- overcooked_simulator/order.py | 39 +++++++++++++------ 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/overcooked_simulator/game_content/environment_config.yaml b/overcooked_simulator/game_content/environment_config.yaml index 1dbb4f56..8c29958e 100644 --- a/overcooked_simulator/game_content/environment_config.yaml +++ b/overcooked_simulator/game_content/environment_config.yaml @@ -29,7 +29,7 @@ orders: b: 60 max_orders: 6 # maximum number of active orders at the same time - num_start_meals: 3 + num_start_meals: 2 # number of orders generated at the start of the environment sample_on_dur_random_func: # 'random' library call with getattr, kwargs are passed to the function diff --git a/overcooked_simulator/order.py b/overcooked_simulator/order.py index ae3447aa..050ec15d 100644 --- a/overcooked_simulator/order.py +++ b/overcooked_simulator/order.py @@ -47,9 +47,9 @@ class Order: timed_penalties: list[ Tuple[timedelta, float] | Tuple[timedelta, float, int, timedelta] ] - """list of timed penalties when the order is not fulfilled.""" + """List of timed penalties when the order is not fulfilled.""" expired_penalty: float - """the penalty to the score if the order expires""" + """The penalty to the score if the order expires""" finished_info: dict[str, Any] = dataclasses.field(default_factory=dict) """Is set after the order is completed.""" @@ -388,17 +388,24 @@ class OrderAndScoreManager: self.order_gen: OrderGeneration = order_config["order_gen_class"]( available_meals=available_meals, kwargs=order_config["order_gen_kwargs"] ) - self.kwargs_for_func = order_config["order_gen_kwargs"] - self.serving_not_ordered_meals = order_config["serving_not_ordered_meals"] + self.serving_not_ordered_meals: Callable[ + [Item], Tuple[bool, float] + ] = order_config["serving_not_ordered_meals"] + """Function that decides if not ordered meals can be served and what score it gives""" self.available_meals = available_meals + """The meals for that orders can be sampled from.""" self.open_orders: Deque[Order] = deque() + """Current open orders. This attribute is used for the environment state.""" - # for logs or history in the future # TODO log who / which player served which meal -> for split scores self.served_meals: list[Tuple[Item, datetime]] = [] - self.last_finished = [] - self.next_relevant_time = datetime.max - self.last_expired = [] + """List of served meals. Maybe for the end screen.""" + self.last_finished: list[Order] = [] + """Cache last finished orders for `OrderGeneration.get_orders` call. From the served meals.""" + self.next_relevant_time: datetime = datetime.max + """For reduced order checking. Store the next time when to create an order or check for penalties.""" + self.last_expired: list[Order] = [] + """Cache last expired orders for `OrderGeneration.get_orders` call.""" def update_next_relevant_time(self): next_relevant_time = datetime.max @@ -455,6 +462,7 @@ class OrderAndScoreManager: def create_init_orders(self, env_time): """Create the initial orders in an environment.""" init_orders = self.order_gen.init_orders(env_time) + self.setup_penalties(new_orders=init_orders, env_time=env_time) self.open_orders.extend(init_orders) def progress(self, passed_time: timedelta, now: datetime): @@ -465,17 +473,23 @@ class OrderAndScoreManager: new_finished_orders=self.last_finished, expired_orders=self.last_expired, ) + self.setup_penalties(new_orders=new_orders, env_time=now) self.open_orders.extend(new_orders) self.last_finished = [] self.last_expired = [] if new_orders or self.next_relevant_time <= now: - remove_orders = [] + # reduce checking calls + + remove_orders: list[int] = [] for index, order in enumerate(self.open_orders): if now >= order.start_time + order.max_duration: + # orders expired self.increment_score(order.expired_penalty) remove_orders.append(index) + continue # no penalties for expired orders remove_penalties = [] for i, (penalty_time, penalty) in enumerate(order.timed_penalties): + # check penalties if penalty_time < now: self.score -= penalty remove_penalties.append(i) @@ -483,7 +497,8 @@ class OrderAndScoreManager: for i in reversed(remove_penalties): # or del order.timed_penalties[index] order.timed_penalties.pop(i) - expired_orders = [] + + expired_orders: list[Order] = [] for remove_order in reversed(remove_orders): expired_orders.append(self.open_orders[remove_order]) del self.open_orders[remove_order] @@ -496,6 +511,8 @@ class OrderAndScoreManager: if order.meal.name == meal.name: return order, index - def setup_penalties(self, new_orders: list[Order], env_time: datetime): + @staticmethod + def setup_penalties(new_orders: list[Order], env_time: datetime): + """Call the `Order.create_penalties` method for new orders.""" for order in new_orders: order.create_penalties(env_time) -- GitLab