From f5b714afe86e405c210b383b495b714acd2793be Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florian=20Schr=C3=B6der?=
 <fschroeder@techfak.uni-bielefeld.de>
Date: Fri, 19 Jan 2024 14:22:31 +0100
Subject: [PATCH] update order doc strings and init timed penalties

---
 .../game_content/environment_config.yaml      |  2 +-
 overcooked_simulator/order.py                 | 39 +++++++++++++------
 2 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/overcooked_simulator/game_content/environment_config.yaml b/overcooked_simulator/game_content/environment_config.yaml
index 1dbb4f56..8c29958e 100644
--- a/overcooked_simulator/game_content/environment_config.yaml
+++ b/overcooked_simulator/game_content/environment_config.yaml
@@ -29,7 +29,7 @@ orders:
         b: 60
     max_orders: 6
     # maximum number of active orders at the same time
-    num_start_meals: 3
+    num_start_meals: 2
     # number of orders generated at the start of the environment
     sample_on_dur_random_func:
       # 'random' library call with getattr, kwargs are passed to the function
diff --git a/overcooked_simulator/order.py b/overcooked_simulator/order.py
index ae3447aa..050ec15d 100644
--- a/overcooked_simulator/order.py
+++ b/overcooked_simulator/order.py
@@ -47,9 +47,9 @@ class Order:
     timed_penalties: list[
         Tuple[timedelta, float] | Tuple[timedelta, float, int, timedelta]
     ]
-    """list of timed penalties when the order is not fulfilled."""
+    """List of timed penalties when the order is not fulfilled."""
     expired_penalty: float
-    """the penalty to the score if the order expires"""
+    """The penalty to the score if the order expires"""
 
     finished_info: dict[str, Any] = dataclasses.field(default_factory=dict)
     """Is set after the order is completed."""
@@ -388,17 +388,24 @@ class OrderAndScoreManager:
         self.order_gen: OrderGeneration = order_config["order_gen_class"](
             available_meals=available_meals, kwargs=order_config["order_gen_kwargs"]
         )
-        self.kwargs_for_func = order_config["order_gen_kwargs"]
-        self.serving_not_ordered_meals = order_config["serving_not_ordered_meals"]
+        self.serving_not_ordered_meals: Callable[
+            [Item], Tuple[bool, float]
+        ] = order_config["serving_not_ordered_meals"]
+        """Function that decides if not ordered meals can be served and what score it gives"""
         self.available_meals = available_meals
+        """The meals for that orders can be sampled from."""
         self.open_orders: Deque[Order] = deque()
+        """Current open orders. This attribute is used for the environment state."""
 
-        # for logs or history in the future
         # TODO log who / which player served which meal -> for split scores
         self.served_meals: list[Tuple[Item, datetime]] = []
-        self.last_finished = []
-        self.next_relevant_time = datetime.max
-        self.last_expired = []
+        """List of served meals. Maybe for the end screen."""
+        self.last_finished: list[Order] = []
+        """Cache last finished orders for `OrderGeneration.get_orders` call. From the served meals."""
+        self.next_relevant_time: datetime = datetime.max
+        """For reduced order checking. Store the next time when to create an order or check for penalties."""
+        self.last_expired: list[Order] = []
+        """Cache last expired orders for `OrderGeneration.get_orders` call."""
 
     def update_next_relevant_time(self):
         next_relevant_time = datetime.max
@@ -455,6 +462,7 @@ class OrderAndScoreManager:
     def create_init_orders(self, env_time):
         """Create the initial orders in an environment."""
         init_orders = self.order_gen.init_orders(env_time)
+        self.setup_penalties(new_orders=init_orders, env_time=env_time)
         self.open_orders.extend(init_orders)
 
     def progress(self, passed_time: timedelta, now: datetime):
@@ -465,17 +473,23 @@ class OrderAndScoreManager:
             new_finished_orders=self.last_finished,
             expired_orders=self.last_expired,
         )
+        self.setup_penalties(new_orders=new_orders, env_time=now)
         self.open_orders.extend(new_orders)
         self.last_finished = []
         self.last_expired = []
         if new_orders or self.next_relevant_time <= now:
-            remove_orders = []
+            # reduce checking calls
+
+            remove_orders: list[int] = []
             for index, order in enumerate(self.open_orders):
                 if now >= order.start_time + order.max_duration:
+                    # orders expired
                     self.increment_score(order.expired_penalty)
                     remove_orders.append(index)
+                    continue  # no penalties for expired orders
                 remove_penalties = []
                 for i, (penalty_time, penalty) in enumerate(order.timed_penalties):
+                    # check penalties
                     if penalty_time < now:
                         self.score -= penalty
                         remove_penalties.append(i)
@@ -483,7 +497,8 @@ class OrderAndScoreManager:
                 for i in reversed(remove_penalties):
                     # or del order.timed_penalties[index]
                     order.timed_penalties.pop(i)
-            expired_orders = []
+
+            expired_orders: list[Order] = []
             for remove_order in reversed(remove_orders):
                 expired_orders.append(self.open_orders[remove_order])
                 del self.open_orders[remove_order]
@@ -496,6 +511,8 @@ class OrderAndScoreManager:
             if order.meal.name == meal.name:
                 return order, index
 
-    def setup_penalties(self, new_orders: list[Order], env_time: datetime):
+    @staticmethod
+    def setup_penalties(new_orders: list[Order], env_time: datetime):
+        """Call the `Order.create_penalties` method for new orders."""
         for order in new_orders:
             order.create_penalties(env_time)
-- 
GitLab