diff --git a/cooperative_cuisine/configs/study/level1/level1_config.yaml b/cooperative_cuisine/configs/study/level1/level1_config.yaml index 7dad55acb1a62c84628b0b4b29411c4483a5fc95..d5ff23731253b196263ea97da05146f4f4f1ce74 100644 --- a/cooperative_cuisine/configs/study/level1/level1_config.yaml +++ b/cooperative_cuisine/configs/study/level1/level1_config.yaml @@ -3,6 +3,7 @@ plates: dirty_plates: 0 plate_delay: [ 5, 10 ] # range of seconds until the dirty plate arrives. + return_dirty: False game: time_limit_seconds: 300 diff --git a/cooperative_cuisine/environment.py b/cooperative_cuisine/environment.py index 8b289e15d3ad9723d593deeab3be66f8173cb057..e05e74fa77e66b21b420034f0d1147ba810177d9 100644 --- a/cooperative_cuisine/environment.py +++ b/cooperative_cuisine/environment.py @@ -6,32 +6,36 @@ import json import logging import os import sys +import warnings from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from datetime import timedelta, datetime from enum import Enum from pathlib import Path from random import Random -from typing import Literal, TypedDict, Callable, Tuple +from typing import Literal, TypedDict, Callable, Tuple, Iterator import networkx +import networkx.drawing.layout import numpy as np import numpy.typing as npt import yaml from networkx import DiGraph +from networkx import Graph from scipy.spatial import distance_matrix from cooperative_cuisine import ROOT_DIR from cooperative_cuisine.counter_factory import CounterFactory from cooperative_cuisine.counters import ( Counter, + Dispenser, PlateConfig, + PlateDispenser, + CuttingBoard, + CookingCounter, ) from cooperative_cuisine.effect_manager import EffectManager -from cooperative_cuisine.game_items import ( - ItemInfo, - ItemType, -) +from cooperative_cuisine.game_items import ItemInfo, ItemType, Item from cooperative_cuisine.hooks import ( ITEM_INFO_LOADED, LAYOUT_FILE_PARSED, @@ -192,7 +196,6 @@ class Environment: """The loaded item info dict. Keys are the item names.""" self.hook(ITEM_INFO_LOADED, item_info=item_info, as_files=as_files) - # self.validate_item_info() if self.environment_config["meals"]["all"]: self.allowed_meal_names = set( [ @@ -208,11 +211,6 @@ class Environment: self.order_manager = OrderManager( order_config=self.environment_config["orders"], - available_meals={ - item: info - for item, info in self.item_info.items() - if info.type == ItemType.Meal and item in self.allowed_meal_names - }, hook=self.hook, random=self.random, ) @@ -289,6 +287,12 @@ class Environment: self.counter_positions = np.array([c.pos for c in self.counters]) + # TODO Maybe validation can be turned off in config... + meals_to_be_ordered = self.validate_environment() + assert meals_to_be_ordered, "Need possible meals for order generation." + + available_meals = {meal: self.item_info[meal] for meal in meals_to_be_ordered} + self.order_manager.set_available_meals(available_meals) self.order_manager.create_init_orders(self.env_time) self.start_time = self.env_time """The relative env time when it started.""" @@ -370,6 +374,251 @@ class Environment: item_info.equipment = item_lookup[item_info.equipment] return item_lookup + @staticmethod + def infer_recipe_graph(item_info) -> DiGraph: + colors = { + ItemType.Ingredient: "black", + ItemType.Equipment: "red", + ItemType.Meal: "green", + ItemType.Waste: "brown", + } + + graph = DiGraph(directed=True) + for item_name, item_info in item_info.items(): + graph.add_node(item_name, color=colors.get(item_info.type, "blue")) + if item_info.equipment is None: + for item in item_info.needs: + graph.add_edge(item, item_name) + else: + if len(item_info.needs) > 0: + for item in item_info.needs: + graph.add_edge(item, item_info.equipment.name) + graph.add_edge(item_info.equipment.name, item_name) + else: + graph.add_edge(item_name, item_info.equipment.name) + return graph + + def get_meal_graph(self, meal: ItemInfo) -> tuple[Graph, dict[str, list[float]]]: + graph = DiGraph(directed=True, rankdir="LR") + + root = meal.name + "_0" + + graph.add_node(root) + add_queue = ["Plate_0", root] + + start = True + while add_queue: + current = add_queue.pop() + + current_info = self.item_info[current.split("_")[0]] + current_index = current.split("_")[-1] + + if start: + graph.add_edge("Plate_0", current) + current = "Plate_0" + start = False + + if current_info.needs: + if len(current_info.needs) == 1: + need = current_info.needs[0] + f"_{current_index}" + add_queue.append(need) + + if current_info.equipment: + equip_id = current_info.equipment.name + f"_{current_index}" + if current_info.equipment.equipment: + equip_equip_id = ( + current_info.equipment.equipment.name + + f"_{current_index}" + ) + graph.add_edge(equip_equip_id, current) + graph.add_edge(equip_id, equip_equip_id) + graph.add_edge(need, equip_id) + else: + graph.add_edge(equip_id, current) + graph.add_edge(need, equip_id) + else: + graph.add_edge(need, current) + + elif len(current_info.needs) > 1: + for idx, item_name in enumerate(current_info.needs): + add_queue.append(item_name + f"_{idx}") + + if current_info.equipment and current_info.equipment.equipment: + equip_id = current_info.equipment.name + f"_{current_index}" + equip_equip_id = ( + current_info.equipment.equipment.name + + f"_{current_index}" + ) + graph.add_edge(equip_equip_id, current) + graph.add_edge(equip_id, equip_equip_id) + graph.add_edge(item_name + f"_{idx}", equip_id) + else: + graph.add_edge( + item_name + f"_{idx}", + current, + ) + + agraph = networkx.nx_agraph.to_agraph(graph) + layout = networkx.nx_agraph.graphviz_layout(graph, prog="dot") + agraph.draw( + ROOT_DIR / "generated" / f"recipe_graph_{meal.name}.png", + format="png", + prog="dot", + ) + + return graph, layout + + def reduce_item_node(self, graph, base_ingredients, item, visited): + visited.append(item) + if item in base_ingredients: + return True + else: + return all( + self.reduce_item_node(graph, base_ingredients, pred, visited) + for pred in graph.predecessors(item) + if pred not in visited + ) + + def assert_equipment_is_present(self): + expected = set( + name + for name, info in self.item_info.items() + if info.type == ItemType.Equipment and "Plate" not in info.name + ) + counters = set(c.__class__.__name__ for c in self.counters).union( + set(c.name for c in self.counters if hasattr(c, "name")) + ) + items = set( + c.occupied_by.name + for c in self.counters + if c.occupied_by is not None and isinstance(c.occupied_by, Item) + ) + for equipment in expected: + if equipment not in counters and equipment not in items: + raise ValueError( + f"Equipment '{equipment}' from config files not found in the environment layout.\n" + f"Config Equipment: {sorted(expected)}\n" + f"Layout Counters: {sorted(counters)}\n" + f"Layout Items: {sorted(items)}" + ) + + def assert_plate_cycle_present(self): + for plate in ["Plate", "DirtyPlate"]: + if plate not in self.item_info: + raise ValueError(f"{plate} not found in item info") + + relevant_counters = ["PlateDispenser", "ServingWindow"] + for counter in self.counters: + if isinstance(counter, PlateDispenser): + if counter.plate_config.return_dirty: + relevant_counters = [ + "PlateDispenser", + "ServingWindow", + "Sink", + "SinkAddon", + ] + + counter_names = [c.__class__.__name__ for c in self.counters] + for counter in relevant_counters: + if counter not in counter_names: + raise ValueError(f"{counter} not found in counters") + + @staticmethod + def assert_no_orphans(graph: DiGraph): + orphans = [ + n + for n in graph.nodes() + if graph.in_degree(n) == 0 and graph.out_degree(n) == 0 + ] + if orphans: + raise ValueError( + f"Expected all items to be part of a recipe, but found orphans: {orphans}" + ) + + @staticmethod + def assert_roots_are_dispensable(graph, base_ingredients): + root_nodes = [ + n for n in graph.nodes() if graph.in_degree(n) == 0 and "Plate" not in n + ] + if set(root_nodes) != set(base_ingredients): + raise ValueError( + f"Expected root nodes in the recipe graph and dispensable items to be identical, but found\n " + f"Root nodes: {sorted(root_nodes)}\n" + f"Dispensable items: {sorted(base_ingredients)}" + ) + + def assert_meals_are_reducible(self, graph, base_ingredients): + meals = [n for n in graph.nodes() if self.item_info[n].type == ItemType.Meal] + + for meal in meals: + visited = [] + if not self.reduce_item_node(graph, base_ingredients, meal, visited): + raise ValueError( + f"Meal '{meal}' can not be reduced to base ingredients" + ) + + def get_requirements(self, item_name: str) -> Iterator[str]: + """ + Get all base ingredients and equipment required to create the given meal. + """ + item = self.item_info[item_name] + is_equipment = item.type == ItemType.Equipment + is_base_ingredient = item.type == ItemType.Ingredient and not item.needs + + if is_equipment or is_base_ingredient: + yield item_name + for need in item.needs: + yield from self.get_requirements(need) + if item.equipment is not None: + yield from self.get_requirements(item.equipment.name) + + def get_item_info_requirements(self) -> dict[str, set[str]]: + recipes = {} + for item_name, item_info in self.item_info.items(): + if item_info.type == ItemType.Meal: + requirements = set(r for r in self.get_requirements(item_name)) + recipes[item_name] = requirements | {"Plate"} + return recipes + + def get_layout_requirements(self): + layout_requirements = set() + for counter in self.counters: + if isinstance(counter, (Dispenser, PlateDispenser)): + layout_requirements.add(counter.dispensing.name) + if isinstance(counter, CuttingBoard): + layout_requirements.add("CuttingBoard") + if isinstance(counter, CookingCounter): + layout_requirements.add(counter.name) + if counter.occupied_by is not None and hasattr(counter.occupied_by, "name"): + layout_requirements.add(counter.occupied_by.name) + return layout_requirements + + def validate_environment(self): + graph = self.infer_recipe_graph(self.item_info) + os.makedirs(ROOT_DIR / "generated", exist_ok=True) + networkx.nx_agraph.to_agraph(graph).draw( + ROOT_DIR / "generated" / "recipe_graph.png", format="png", prog="dot" + ) + + expected = self.get_item_info_requirements() + present = self.get_layout_requirements() + possible_meals = set(meal for meal in expected if expected[meal] <= present) + defined_meals = set( + possible_meals + if self.environment_config["meals"]["all"] + else self.environment_config["meals"]["list"] + ) + + # print(f"Ordered meals: {defined_meals}, Possible meals: {possible_meals}") + if len(defined_meals - possible_meals) > 0: + warnings.warn( + f"Ordered meals are not possible: {defined_meals - possible_meals}" + ) + + meals_to_be_ordered = possible_meals.intersection(defined_meals) + return meals_to_be_ordered + # print("FINAL MEALS:", meals_to_be_ordered) + def get_meal_graph(self, meal: ItemInfo) -> dict: graph = DiGraph( directed=True, rankdir="LR", graph_attr={"nslimit": "0", "nslimit1": "2"} @@ -924,20 +1173,15 @@ class Environment: def get_recipe_graphs(self) -> list: os.makedirs(ROOT_DIR / "generated", exist_ok=True) - if self.environment_config["meals"]["all"]: - meals = [m for m in self.item_info.values() if m.type == ItemType.Meal] - else: - meals = [ - self.item_info[m] - for m in self.environment_config["meals"]["list"] - if self.item_info[m].type == ItemType.Meal - ] - - # print(list(m.name for m in meals)) # time_start = time.time() - # graph_dicts = list(map(self.get_meal_graph, meals)) - with ThreadPoolExecutor(max_workers=len(meals)) as executor: - graph_dicts = list(executor.map(self.get_meal_graph, meals)) + with ThreadPoolExecutor( + max_workers=len(self.order_manager.available_meals) + ) as executor: + graph_dicts = list( + executor.map( + self.get_meal_graph, self.order_manager.available_meals.values() + ) + ) # print("DURATION", time.time() - time_start) return graph_dicts diff --git a/cooperative_cuisine/game_items.py b/cooperative_cuisine/game_items.py index 6ed9f4d25daa7e0eb3b7e844df4c340e29f4e997..c728897de471bce2cb395743e03e1668fe537a5a 100644 --- a/cooperative_cuisine/game_items.py +++ b/cooperative_cuisine/game_items.py @@ -115,6 +115,8 @@ class ItemInfo: """Internally set in CookingEquipment""" def __post_init__(self): + if self.seconds < 0.0: + raise ValueError(f"Expected seconds >= 0 for item '{self.name}', but got {self.seconds} in item info") self.type = ItemType(self.type) if self.effect_type: self.effect_type = EffectType(self.effect_type) diff --git a/cooperative_cuisine/orders.py b/cooperative_cuisine/orders.py index e1e082e15536bcb3996b2469f1e3e82c222e0b09..a5b48ec743e6b70980ea02f706b0bb5946c8d858 100644 --- a/cooperative_cuisine/orders.py +++ b/cooperative_cuisine/orders.py @@ -105,12 +105,11 @@ class OrderGeneration: def __init__( self, - available_meals: dict[str, ItemInfo], hook: Hooks, random: Random, **kwargs, ): - self.available_meals: list[ItemInfo] = list(available_meals.values()) + self.available_meals: list[ItemInfo] | None = None """Available meals restricted through the `environment_config.yml`.""" self.hook = hook """Reference to the hook manager.""" @@ -140,14 +139,12 @@ class OrderManager: def __init__( self, order_config, - available_meals: dict[str, ItemInfo], hook: Hooks, random: Random, ): self.random = random """Random instance.""" self.order_gen: OrderGeneration = order_config["order_gen_class"]( - available_meals=available_meals, hook=hook, random=random, kwargs=order_config["order_gen_kwargs"], @@ -158,7 +155,7 @@ class OrderManager: ] = order_config["serving_not_ordered_meals"] """Function that decides if not ordered meals can be served and what score it gives""" - self.available_meals = available_meals + self.available_meals = None """The meals for that orders can be sampled from.""" self.open_orders: Deque[Order] = deque() """Current open orders. This attribute is used for the environment state.""" @@ -176,6 +173,10 @@ class OrderManager: self.hook = hook """Reference to the hook manager.""" + def set_available_meals(self, available_meals): + self.available_meals = available_meals + self.order_gen.available_meals = list(available_meals.values()) + def update_next_relevant_time(self): """For more efficient checking when to do something in the progress call.""" next_relevant_time = datetime.max @@ -363,12 +364,11 @@ class RandomOrderGeneration(OrderGeneration): def __init__( self, - available_meals: dict[str, ItemInfo], hook: Hooks, random: Random, **kwargs, ): - super().__init__(available_meals, hook, random, **kwargs) + super().__init__(hook, random, **kwargs) self.kwargs: RandomOrderKwarg = RandomOrderKwarg(**kwargs["kwargs"]) self.next_order_time: datetime | None = datetime.max self.number_cur_orders: int = 0 diff --git a/cooperative_cuisine/pygame_2d_vis/drawing.py b/cooperative_cuisine/pygame_2d_vis/drawing.py index 164a20d86c73af7433d68c958bc0b55eb2da71b4..0031ca7204136918ac36c94da8bd1bb334cefc5d 100644 --- a/cooperative_cuisine/pygame_2d_vis/drawing.py +++ b/cooperative_cuisine/pygame_2d_vis/drawing.py @@ -637,7 +637,9 @@ class Visualizer: pos, grid_size, self.config["Counter"]["parts"], - orientation=counter_dict["orientation"], + orientation=counter_dict["orientation"] + if "orientation" in counter_dict + else None, ) if counter_type in self.config: self.draw_thing( @@ -878,7 +880,9 @@ class Visualizer: self.draw_gamescreen(screen, state, grid_size, [0 for _ in state["players"]]) pygame.image.save(screen, filename) - def get_state_image(self, grid_size: int, state: dict) -> npt.NDArray[np.uint8]: + def get_state_image( + self, grid_size: int, save_folder: dict + ) -> npt.NDArray[np.uint8]: width = int(np.ceil(state["kitchen"]["width"] * grid_size)) height = int(np.ceil(state["kitchen"]["height"] * grid_size)) diff --git a/setup.py b/setup.py index db8f170a0466a1aabca8db8554b369640b1dd150..bafd4edde0be2f57fef2f6ab094a643f2731b421 100644 --- a/setup.py +++ b/setup.py @@ -23,11 +23,12 @@ requirements = [ "websockets>=12.0", "requests>=2.31.0", "platformdirs>=4.1.0", - "tqdm>=4.65.0", - "networkx", "matplotlib>=3.8.0", "pygraphviz>=1.9", "pydot>=2.0.0", + "networkx>=3.2.1", + "tqdm>=4.65.0", + "networkx", ] test_requirements = ["pytest>=3", "pytest-cov>=4.1"] diff --git a/tests/test_start.py b/tests/test_start.py index e69b1ac95c1590adff21abcaa9291f7352ddb8d6..fd7e3addee6018e26691c138746a6aa6813f9806 100644 --- a/tests/test_start.py +++ b/tests/test_start.py @@ -253,7 +253,7 @@ def test_time_passed(): np.random.seed(42) env = Environment( ROOT_DIR / "configs" / "environment_config.yaml", - layouts_folder / "empty.layout", + layouts_folder / "basic.layout", ROOT_DIR / "configs" / "item_info.yaml", ) env.add_player("0") @@ -275,7 +275,7 @@ def test_time_limit(): np.random.seed(42) env = Environment( ROOT_DIR / "configs" / "environment_config.yaml", - layouts_folder / "empty.layout", + layouts_folder / "basic.layout", ROOT_DIR / "configs" / "item_info.yaml", ) env.add_player("0")