diff --git a/cooperative_cuisine/__init__.py b/cooperative_cuisine/__init__.py index 8a467c2fe985f8960dcf782690b497a4066ba3e3..18fc18e11ab51ab9b8d97b4c3e40a3972dc591ea 100644 --- a/cooperative_cuisine/__init__.py +++ b/cooperative_cuisine/__init__.py @@ -372,12 +372,12 @@ websockets, - the **orders**, how to sample incoming orders and their attributes, - the **player**/agent, that interacts in the environment, - the **pygame 2d visualization**, GUI, drawing, and video generation, -- the **recipe** validation and graph generation, - the **recording**, via hooks, actions, environment configs, states, etc. can be recorded in files, - the **scores**, via hooks, events can affect the scores, - type hints are defined in **state representation** for the json state and **server results** for the data returned by the game server in post requests. -- **util**ity code. +- **util**ity code, +- the config **validation** and graph generation. """ diff --git a/cooperative_cuisine/environment.py b/cooperative_cuisine/environment.py index 5410e5e4355a3d29fa14bc4e7b8fca79ace95004..cbfa3de8d2e4ccb2970a4a463e014f6bb2d34bd8 100644 --- a/cooperative_cuisine/environment.py +++ b/cooperative_cuisine/environment.py @@ -52,12 +52,12 @@ from cooperative_cuisine.orders import ( OrderConfig, ) from cooperative_cuisine.player import Player, PlayerConfig -from cooperative_cuisine.recipes import RecipeValidation from cooperative_cuisine.state_representation import InfoMsg from cooperative_cuisine.utils import ( create_init_env_time, get_closest, ) +from cooperative_cuisine.validation import Validation log = logging.getLogger(__name__) @@ -216,7 +216,7 @@ class Environment: """Counters that needs to be called in the step function via the `progress` method.""" self.overwrite_counters(self.counters) - self.recipe_validation = RecipeValidation( + self.recipe_validation = Validation( meals=[m for m in self.item_info.values() if m.type == ItemType.Meal] if self.environment_config["meals"]["all"] else [ @@ -228,7 +228,7 @@ class Environment: order_manager=self.order_manager, ) - meals_to_be_ordered = self.recipe_validation.validate_environment() + meals_to_be_ordered = self.recipe_validation.validate_environment(self.counters) assert meals_to_be_ordered, "Need possible meals for order generation." available_meals = {meal: self.item_info[meal] for meal in meals_to_be_ordered} diff --git a/cooperative_cuisine/recipes.py b/cooperative_cuisine/validation.py similarity index 87% rename from cooperative_cuisine/recipes.py rename to cooperative_cuisine/validation.py index cdf34b259801011beff9492b3dfb58a60f72c65f..f4e8420435b89b03b3df85e3a6bc5a2b36a2d216 100644 --- a/cooperative_cuisine/recipes.py +++ b/cooperative_cuisine/validation.py @@ -4,7 +4,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import TypedDict, Tuple, Iterator import networkx as nx -from networkx import DiGraph, Graph +from networkx import DiGraph from cooperative_cuisine import ROOT_DIR from cooperative_cuisine.counters import ( @@ -24,7 +24,7 @@ class MealGraphDict(TypedDict): layout: dict[str, Tuple[float, float]] -class RecipeValidation: +class Validation: def __init__(self, meals, item_info, order_manager): self.meals: list[ItemInfo] = meals self.item_info: dict[str, ItemInfo] = item_info @@ -54,13 +54,13 @@ class RecipeValidation: graph.add_edge(item_name, item_info.equipment.name) return graph - def get_meal_graph(self, meal: ItemInfo) -> tuple[Graph, dict[str, list[float]]]: + def get_meal_graph(self, meal: ItemInfo) -> MealGraphDict: graph = DiGraph(directed=True, rankdir="LR") - root = meal.name + "_0" + root = f"{meal.name}_0" graph.add_node(root) - add_queue = ["Plate_0", root] + add_queue = [root] # Add "Plate_0" if dishwashing should be part of the recipe start = True while add_queue: @@ -76,16 +76,13 @@ class RecipeValidation: if current_info.needs: if len(current_info.needs) == 1: - need = current_info.needs[0] + f"_{current_index}" + need = f"{current_info.needs[0]}_{current_index}" add_queue.append(need) if current_info.equipment: - equip_id = current_info.equipment.name + f"_{current_index}" + equip_id = f"{current_info.equipment.name}_{current_index}" if current_info.equipment.equipment: - equip_equip_id = ( - current_info.equipment.equipment.name - + f"_{current_index}" - ) + equip_equip_id = f"{current_info.equipment.equipment.name}_{current_index}" graph.add_edge(equip_equip_id, current) graph.add_edge(equip_id, equip_equip_id) graph.add_edge(need, equip_id) @@ -97,32 +94,25 @@ class RecipeValidation: elif len(current_info.needs) > 1: for idx, item_name in enumerate(current_info.needs): - add_queue.append(item_name + f"_{idx}") + add_queue.append(f"{item_name}_{idx}") if current_info.equipment and current_info.equipment.equipment: - equip_id = current_info.equipment.name + f"_{current_index}" - equip_equip_id = ( - current_info.equipment.equipment.name - + f"_{current_index}" - ) + equip_id = f"{current_info.equipment.name}_{current_index}" + equip_equip_id = f"{current_info.equipment.equipment.name}_{current_index}" graph.add_edge(equip_equip_id, current) graph.add_edge(equip_id, equip_equip_id) - graph.add_edge(item_name + f"_{idx}", equip_id) + graph.add_edge(f"{item_name}_{idx}", equip_id) else: graph.add_edge( - item_name + f"_{idx}", + f"{item_name}_{idx}", current, ) - agraph = nx.nx_agraph.to_agraph(graph) - layout = nx.nx_agraph.graphviz_layout(graph, prog="dot") - agraph.draw( - ROOT_DIR / "generated" / f"recipe_graph_{meal.name}.png", - format="png", - prog="dot", - ) - - return graph, layout + return { + "meal": meal.name, + "edges": list(graph.edges), + "layout": nx.nx_agraph.graphviz_layout(graph, prog="dot"), + } def reduce_item_node(self, graph, base_ingredients, item, visited): visited.append(item) @@ -274,7 +264,7 @@ class RecipeValidation: return meals_to_be_ordered # print("FINAL MEALS:", meals_to_be_ordered) - def get_recipe_graphs(self) -> list: + def get_recipe_graphs(self) -> list[MealGraphDict]: os.makedirs(ROOT_DIR / "generated", exist_ok=True) # time_start = time.time()