From 4848c3107182de15c190285caaeb6466e9827570 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florian=20Schr=C3=B6der?=
 <fschroeder@techfak.uni-bielefeld.de>
Date: Thu, 18 Jan 2024 14:01:08 +0100
Subject: [PATCH] fix bugs

---
 .../game_content/environment_config.yaml      |   8 +-
 .../gui_2d_vis/overcooked_gui.py              |  66 +++++-----
 overcooked_simulator/order.py                 | 117 ++++++++++--------
 3 files changed, 102 insertions(+), 89 deletions(-)

diff --git a/overcooked_simulator/game_content/environment_config.yaml b/overcooked_simulator/game_content/environment_config.yaml
index 5dab9e50..17d69768 100644
--- a/overcooked_simulator/game_content/environment_config.yaml
+++ b/overcooked_simulator/game_content/environment_config.yaml
@@ -15,13 +15,13 @@ orders:
         b: 50
     max_orders: 5
     num_start_meals: 3
-    sample_on_dur: false
+    sample_on_dur: true
     sample_on_dur_func:
       func: uniform
       kwargs:
-        a: 30
-        b: 50
-    sample_on_serving: true
+        a: 20
+        b: 30
+    sample_on_serving: false
     score_calc_gen_kwargs:
       other: 0
       scores:
diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
index db5c8188..9dc13066 100644
--- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py
+++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
@@ -433,7 +433,12 @@ class PyGameGUI:
                         pygame.draw.circle(screen, color, pos, radius)
 
     def draw_item(
-        self, pos: npt.NDArray[float], item: Item, scale: float = 1.0, plate=False
+        self,
+        pos: npt.NDArray[float],
+        item: Item,
+        scale: float = 1.0,
+        plate=False,
+        screen=None,
     ):
         """Visualization of an item at the specified position. On a counter or in the hands of the player.
         The visual composition of the item is read in from visualization.yaml file, where it is specified as
@@ -443,6 +448,8 @@ class PyGameGUI:
             pos: The position of the item to draw.
             item: The item do be drawn in the game.
             scale: Rescale the item by this factor.
+            screen: the pygame screen to draw on.
+            plate: item is on a plate (soup is drawn differently on a plate than in a pot)
         """
 
         if not isinstance(item, list):
@@ -451,7 +458,10 @@ class PyGameGUI:
                 if "Soup" in item.name and plate:
                     item_key += "Plate"
                 self.draw_thing(
-                    pos, self.visualization_config[item_key]["parts"], scale=scale
+                    pos,
+                    self.visualization_config[item_key]["parts"],
+                    scale=scale,
+                    screen=screen,
                 )
 
         if isinstance(item, (Item, Plate)) and item.progress_percentage > 0.0:
@@ -463,7 +473,9 @@ class PyGameGUI:
                 and item.content_ready.name in self.visualization_config
             ):
                 self.draw_thing(
-                    pos, self.visualization_config[item.content_ready.name]["parts"]
+                    pos,
+                    self.visualization_config[item.content_ready.name]["parts"],
+                    screen=screen,
                 )
             else:
                 triangle_offsets = create_polygon(len(item.content_list), length=10)
@@ -474,6 +486,7 @@ class PyGameGUI:
                         o,
                         scale=scale,
                         plate=isinstance(item, Plate),
+                        screen=screen,
                     )
 
         # if isinstance(item, Meal):
@@ -549,36 +562,24 @@ class PyGameGUI:
         self.timer_label.set_text(f"Time remaining: {display_time}")
 
     def draw_orders(self, state):
-        # print(state["orders"])
-        # for o in state["orders"]:
-        #     print(o.meal.name)
-        # orders = [
-        #     "Burger",
-        #     "TomatoSoupPlate",
-        #     "OnionSoupPlate",
-        #     "OnionSoupPlate",
-        #     "OnionSoupPlate",
-        #     "OnionSoupPlate",
-        # ]
-        orders = [o.meal.name for o in state["orders"]]
         orders_width = self.game_width - 100
         orders_height = self.screen_margin
-        self.orders_screen = pygame.Surface(
+        order_screen = pygame.Surface(
             (orders_width, orders_height),
         )
 
         bg_color = colors[self.visualization_config["GameWindow"]["background_color"]]
-        pygame.draw.rect(self.orders_screen, bg_color, self.orders_screen.get_rect())
+        pygame.draw.rect(order_screen, bg_color, order_screen.get_rect())
 
         order_rects_start = (orders_height // 2) - (self.grid_size // 2)
 
-        for idx, order in enumerate(orders):
+        for idx, order in enumerate([o.meal for o in state["orders"]]):
             order_upper_left = [
                 order_rects_start + idx * self.grid_size * 1.5,
                 order_rects_start,
             ]
             pygame.draw.rect(
-                self.orders_screen,
+                order_screen,
                 colors["red"],
                 pygame.Rect(
                     order_upper_left[0],
@@ -594,27 +595,21 @@ class PyGameGUI:
             self.draw_thing(
                 center,
                 self.visualization_config["Plate"]["parts"],
-                screen=self.orders_screen,
+                screen=order_screen,
             )
-            self.draw_thing(
+            self.draw_item(
                 center,
-                self.visualization_config[order]["parts"],
-                screen=self.orders_screen,
+                order,
+                plate=True,
+                screen=order_screen,
             )
 
-        orders_rect = self.orders_screen.get_rect()
+        orders_rect = order_screen.get_rect()
         orders_rect.center = [
             self.screen_margin + (orders_width // 2),
             orders_height // 2,
         ]
-        self.main_window.blit(self.orders_screen, orders_rect)
-
-        self.orders_label = pygame_gui.elements.UILabel(
-            text="Orders:",
-            relative_rect=pygame.Rect(0, 0, self.screen_margin, self.screen_margin),
-            manager=self.manager,
-            object_id="#orders_label",
-        )
+        self.main_window.blit(order_screen, orders_rect)
 
     def draw(self, state):
         """Main visualization function.
@@ -732,6 +727,13 @@ class PyGameGUI:
             object_id="#timer_label",
         )
 
+        self.orders_label = pygame_gui.elements.UILabel(
+            text="Orders:",
+            relative_rect=pygame.Rect(0, 0, self.screen_margin, self.screen_margin),
+            manager=self.manager,
+            object_id="#orders_label",
+        )
+
     def set_window_size(self, window_width, window_height, game_width, game_height):
         self.game_screen = pygame.Surface(
             (
diff --git a/overcooked_simulator/order.py b/overcooked_simulator/order.py
index fc10a405..28023389 100644
--- a/overcooked_simulator/order.py
+++ b/overcooked_simulator/order.py
@@ -6,14 +6,14 @@ from collections import deque
 from datetime import datetime, timedelta
 from typing import Callable, Tuple, Any, Deque
 
-from overcooked_simulator.game_items import Item, Plate
+from overcooked_simulator.game_items import Item, Plate, ItemInfo
 
 log = logging.getLogger(__name__)
 
 
 @dataclasses.dataclass
 class Order:
-    meal: Item
+    meal: ItemInfo
     start_time: datetime
     max_duration: timedelta
     score_calc: Callable[[timedelta, ...], float]
@@ -44,8 +44,8 @@ class Order:
 
 
 class OrderGeneration:
-    def __init__(self, available_meals: dict[str, Item], **kwargs):
-        self.available_meals: list[Item] = list(available_meals.values())
+    def __init__(self, available_meals: dict[str, ItemInfo], **kwargs):
+        self.available_meals: list[ItemInfo] = list(available_meals.values())
 
     @abstractmethod
     def init_orders(self, now) -> list[Order]:
@@ -67,16 +67,16 @@ class RandomOrderKwarg:
     max_orders: int
     duration_sample: dict
     score_calc_gen_func: Callable[
-        [Item, timedelta, datetime, Any], Callable[[timedelta, Order], float]
+        [ItemInfo, timedelta, datetime, Any], Callable[[timedelta, Order], float]
     ]
     score_calc_gen_kwargs: dict
 
 
 class RandomOrderGeneration(OrderGeneration):
-    def __init__(self, available_meals: dict[str, Item], **kwargs):
+    def __init__(self, available_meals: dict[str, ItemInfo], **kwargs):
         super().__init__(available_meals, **kwargs)
         self.kwargs: RandomOrderKwarg = RandomOrderKwarg(**kwargs["kwargs"])
-        self.next_order_time: datetime | None = None
+        self.next_order_time: datetime | None = datetime.max
         self.number_cur_orders = 0
         self.needed_orders: int = 0
         """For the sample on dur but when it was restricted due to max order number."""
@@ -97,35 +97,33 @@ class RandomOrderGeneration(OrderGeneration):
         self.number_cur_orders -= len(new_finished_orders)
         if self.kwargs.sample_on_serving:
             if new_finished_orders:
-                self.number_cur_orders += len(new_finished_orders)
-                return self.create_orders_for_meals(
-                    random.choices(self.available_meals, k=len(new_finished_orders)),
-                    now,
-                    True,
-                )
-        if self.kwargs.sample_on_dur:
-            if self.needed_orders:
-                self.needed_orders -= len(new_finished_orders)
-                self.needed_orders = max(self.needed_orders, 0)
-                self.number_cur_orders += len(new_finished_orders)
+                self.create_random_next_time_delta(now)
+                return []
+        if self.needed_orders:
+            self.needed_orders -= len(new_finished_orders)
+            self.needed_orders = max(self.needed_orders, 0)
+            self.number_cur_orders += len(new_finished_orders)
+            return self.create_orders_for_meals(
+                random.choices(self.available_meals, k=len(new_finished_orders)),
+                now,
+            )
+        if self.next_order_time <= now:
+            if self.number_cur_orders >= self.kwargs.max_orders:
+                self.needed_orders += 1
+            else:
+                if self.kwargs.sample_on_dur:
+                    self.create_random_next_time_delta(now)
+                else:
+                    self.next_order_time = datetime.max
+                self.number_cur_orders += 1
                 return self.create_orders_for_meals(
-                    random.choices(self.available_meals, k=len(new_finished_orders)),
+                    [random.choice(self.available_meals)],
                     now,
                 )
-            if self.next_order_time < now:
-                if self.number_cur_orders >= self.kwargs.max_orders:
-                    self.needed_orders += 1
-                else:
-                    self.create_random_next_time_delta(now)
-                    self.number_cur_orders += 1
-                    return self.create_orders_for_meals(
-                        random.choice(self.available_meals),
-                        now,
-                    )
         return []
 
     def create_orders_for_meals(
-        self, meals: list[Item], now: datetime, no_time_limit: bool = False
+        self, meals: list[ItemInfo], now: datetime, no_time_limit: bool = False
     ) -> list[Order]:
         orders = []
         for meal in meals:
@@ -158,13 +156,14 @@ class RandomOrderGeneration(OrderGeneration):
         return orders
 
     def create_random_next_time_delta(self, now: datetime):
-        self.next_order_time = timedelta(
+        self.next_order_time = now + timedelta(
             seconds=int(
-                getattr(random, self.kwargs.duration_sample["func"])(
-                    **self.kwargs.duration_sample["kwargs"]
+                getattr(random, self.kwargs.sample_on_dur_func["func"])(
+                    **self.kwargs.sample_on_dur_func["kwargs"]
                 )
             )
         )
+        log.info(f"Next order in {self.next_order_time}")
 
 
 def simple_score_calc_gen_func(
@@ -182,7 +181,7 @@ def simple_score_calc_gen_func(
 
 
 class OrderAndScoreManager:
-    def __init__(self, order_config, available_meals: dict[str, Item]):
+    def __init__(self, order_config, available_meals: dict[str, ItemInfo]):
         self.score = 0
         self.order_gen: OrderGeneration = order_config["order_gen_class"](
             available_meals=available_meals, kwargs=order_config["kwargs"]
@@ -196,7 +195,17 @@ class OrderAndScoreManager:
         # TODO log who / which player served which meal -> for split scores
         self.served_meals: list[Tuple[Item, datetime]] = []
         self.last_finished = []
-        self.penalty_timers = []
+        self.next_relevant_time = datetime.max
+
+    def update_next_relevant_time(self):
+        next_relevant_time = datetime.max
+        for order in self.open_orders:
+            next_relevant_time = min(
+                next_relevant_time, order.start_time + order.max_duration
+            )
+            for penalty in order._timed_penalties:
+                next_relevant_time = min(next_relevant_time, penalty[0])
+        self.next_relevant_time = next_relevant_time
 
     def serve_meal(self, item: Item, env_time: datetime) -> bool:
         if isinstance(item, Plate):
@@ -250,26 +259,28 @@ class OrderAndScoreManager:
         )
         self.open_orders.extend(new_orders)
         self.last_finished = []
-
-        remove_orders = []
-        for index, order in enumerate(self.open_orders):
-            if now >= order.start_time + order.max_duration:
-                remove_orders.append(index)
-            remove_penalties = []
-            for index, (penalty_time, penalty) in enumerate(order.timed_penalties):
-                if penalty_time < now:
-                    self.score -= penalty
-                    remove_penalties.append(index)
-
-            for index in remove_penalties:
-                # or del order.timed_penalties[index]
-                order.timed_penalties.pop(index)
-
-        for remove_order in remove_orders:
-            del self.open_orders[remove_order]
+        if new_orders or self.next_relevant_time <= now:
+            remove_orders = []
+            for index, order in enumerate(self.open_orders):
+                if now >= order.start_time + order.max_duration:
+                    remove_orders.append(index)
+                remove_penalties = []
+                for i, (penalty_time, penalty) in enumerate(order.timed_penalties):
+                    if penalty_time < now:
+                        self.score -= penalty
+                        remove_penalties.append(i)
+
+                for i in remove_penalties:
+                    # or del order.timed_penalties[i]
+                    order.timed_penalties.pop(i)
+
+            for remove_order in remove_orders:
+                del self.open_orders[remove_order]
+
+            self.update_next_relevant_time()
 
     def find_order_for_meal(self, meal) -> Tuple[Order, int] | None:
-        for neg_index, order in enumerate(reversed(self.open_orders)):
+        for neg_index, order in enumerate(self.open_orders):
             if order.meal.name == meal.name:
                 return order, len(self.open_orders) - neg_index - 1
 
-- 
GitLab