Skip to content
Snippets Groups Projects
Commit 8721e435 authored by Florian Schröder's avatar Florian Schröder
Browse files

Refactor recipe validation and improve graph generation

The 'recipes.py' module is renamed to 'validation.py' and the 'RecipeValidation' class is renamed to 'Validation'. The code is refactored to use f-strings for better readability. Graph generation has also been improved by simplifying the way edges are added and returning information as a dictionary. An unnecessary png file generation has been removed. Changes have been made where these classes are imported and used.
parent a737c646
No related branches found
No related tags found
1 merge request!71Resolve "Refactoring Environment class + file"
Pipeline #47674 passed
...@@ -372,12 +372,12 @@ websockets, ...@@ -372,12 +372,12 @@ websockets,
- the **orders**, how to sample incoming orders and their attributes, - the **orders**, how to sample incoming orders and their attributes,
- the **player**/agent, that interacts in the environment, - the **player**/agent, that interacts in the environment,
- the **pygame 2d visualization**, GUI, drawing, and video generation, - the **pygame 2d visualization**, GUI, drawing, and video generation,
- the **recipe** validation and graph generation,
- the **recording**, via hooks, actions, environment configs, states, etc. can be recorded in files, - the **recording**, via hooks, actions, environment configs, states, etc. can be recorded in files,
- the **scores**, via hooks, events can affect the scores, - the **scores**, via hooks, events can affect the scores,
- type hints are defined in **state representation** for the json state and **server results** for the data returned by - type hints are defined in **state representation** for the json state and **server results** for the data returned by
the game server in post requests. the game server in post requests.
- **util**ity code. - **util**ity code,
- the config **validation** and graph generation.
""" """
......
...@@ -52,12 +52,12 @@ from cooperative_cuisine.orders import ( ...@@ -52,12 +52,12 @@ from cooperative_cuisine.orders import (
OrderConfig, OrderConfig,
) )
from cooperative_cuisine.player import Player, PlayerConfig from cooperative_cuisine.player import Player, PlayerConfig
from cooperative_cuisine.recipes import RecipeValidation
from cooperative_cuisine.state_representation import InfoMsg from cooperative_cuisine.state_representation import InfoMsg
from cooperative_cuisine.utils import ( from cooperative_cuisine.utils import (
create_init_env_time, create_init_env_time,
get_closest, get_closest,
) )
from cooperative_cuisine.validation import Validation
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -216,7 +216,7 @@ class Environment: ...@@ -216,7 +216,7 @@ class Environment:
"""Counters that needs to be called in the step function via the `progress` method.""" """Counters that needs to be called in the step function via the `progress` method."""
self.overwrite_counters(self.counters) self.overwrite_counters(self.counters)
self.recipe_validation = RecipeValidation( self.recipe_validation = Validation(
meals=[m for m in self.item_info.values() if m.type == ItemType.Meal] meals=[m for m in self.item_info.values() if m.type == ItemType.Meal]
if self.environment_config["meals"]["all"] if self.environment_config["meals"]["all"]
else [ else [
...@@ -228,7 +228,7 @@ class Environment: ...@@ -228,7 +228,7 @@ class Environment:
order_manager=self.order_manager, order_manager=self.order_manager,
) )
meals_to_be_ordered = self.recipe_validation.validate_environment() meals_to_be_ordered = self.recipe_validation.validate_environment(self.counters)
assert meals_to_be_ordered, "Need possible meals for order generation." assert meals_to_be_ordered, "Need possible meals for order generation."
available_meals = {meal: self.item_info[meal] for meal in meals_to_be_ordered} available_meals = {meal: self.item_info[meal] for meal in meals_to_be_ordered}
......
...@@ -4,7 +4,7 @@ from concurrent.futures import ThreadPoolExecutor ...@@ -4,7 +4,7 @@ from concurrent.futures import ThreadPoolExecutor
from typing import TypedDict, Tuple, Iterator from typing import TypedDict, Tuple, Iterator
import networkx as nx import networkx as nx
from networkx import DiGraph, Graph from networkx import DiGraph
from cooperative_cuisine import ROOT_DIR from cooperative_cuisine import ROOT_DIR
from cooperative_cuisine.counters import ( from cooperative_cuisine.counters import (
...@@ -24,7 +24,7 @@ class MealGraphDict(TypedDict): ...@@ -24,7 +24,7 @@ class MealGraphDict(TypedDict):
layout: dict[str, Tuple[float, float]] layout: dict[str, Tuple[float, float]]
class RecipeValidation: class Validation:
def __init__(self, meals, item_info, order_manager): def __init__(self, meals, item_info, order_manager):
self.meals: list[ItemInfo] = meals self.meals: list[ItemInfo] = meals
self.item_info: dict[str, ItemInfo] = item_info self.item_info: dict[str, ItemInfo] = item_info
...@@ -54,13 +54,13 @@ class RecipeValidation: ...@@ -54,13 +54,13 @@ class RecipeValidation:
graph.add_edge(item_name, item_info.equipment.name) graph.add_edge(item_name, item_info.equipment.name)
return graph return graph
def get_meal_graph(self, meal: ItemInfo) -> tuple[Graph, dict[str, list[float]]]: def get_meal_graph(self, meal: ItemInfo) -> MealGraphDict:
graph = DiGraph(directed=True, rankdir="LR") graph = DiGraph(directed=True, rankdir="LR")
root = meal.name + "_0" root = f"{meal.name}_0"
graph.add_node(root) graph.add_node(root)
add_queue = ["Plate_0", root] add_queue = [root] # Add "Plate_0" if dishwashing should be part of the recipe
start = True start = True
while add_queue: while add_queue:
...@@ -76,16 +76,13 @@ class RecipeValidation: ...@@ -76,16 +76,13 @@ class RecipeValidation:
if current_info.needs: if current_info.needs:
if len(current_info.needs) == 1: if len(current_info.needs) == 1:
need = current_info.needs[0] + f"_{current_index}" need = f"{current_info.needs[0]}_{current_index}"
add_queue.append(need) add_queue.append(need)
if current_info.equipment: if current_info.equipment:
equip_id = current_info.equipment.name + f"_{current_index}" equip_id = f"{current_info.equipment.name}_{current_index}"
if current_info.equipment.equipment: if current_info.equipment.equipment:
equip_equip_id = ( equip_equip_id = f"{current_info.equipment.equipment.name}_{current_index}"
current_info.equipment.equipment.name
+ f"_{current_index}"
)
graph.add_edge(equip_equip_id, current) graph.add_edge(equip_equip_id, current)
graph.add_edge(equip_id, equip_equip_id) graph.add_edge(equip_id, equip_equip_id)
graph.add_edge(need, equip_id) graph.add_edge(need, equip_id)
...@@ -97,32 +94,25 @@ class RecipeValidation: ...@@ -97,32 +94,25 @@ class RecipeValidation:
elif len(current_info.needs) > 1: elif len(current_info.needs) > 1:
for idx, item_name in enumerate(current_info.needs): for idx, item_name in enumerate(current_info.needs):
add_queue.append(item_name + f"_{idx}") add_queue.append(f"{item_name}_{idx}")
if current_info.equipment and current_info.equipment.equipment: if current_info.equipment and current_info.equipment.equipment:
equip_id = current_info.equipment.name + f"_{current_index}" equip_id = f"{current_info.equipment.name}_{current_index}"
equip_equip_id = ( equip_equip_id = f"{current_info.equipment.equipment.name}_{current_index}"
current_info.equipment.equipment.name
+ f"_{current_index}"
)
graph.add_edge(equip_equip_id, current) graph.add_edge(equip_equip_id, current)
graph.add_edge(equip_id, equip_equip_id) graph.add_edge(equip_id, equip_equip_id)
graph.add_edge(item_name + f"_{idx}", equip_id) graph.add_edge(f"{item_name}_{idx}", equip_id)
else: else:
graph.add_edge( graph.add_edge(
item_name + f"_{idx}", f"{item_name}_{idx}",
current, current,
) )
agraph = nx.nx_agraph.to_agraph(graph) return {
layout = nx.nx_agraph.graphviz_layout(graph, prog="dot") "meal": meal.name,
agraph.draw( "edges": list(graph.edges),
ROOT_DIR / "generated" / f"recipe_graph_{meal.name}.png", "layout": nx.nx_agraph.graphviz_layout(graph, prog="dot"),
format="png", }
prog="dot",
)
return graph, layout
def reduce_item_node(self, graph, base_ingredients, item, visited): def reduce_item_node(self, graph, base_ingredients, item, visited):
visited.append(item) visited.append(item)
...@@ -274,7 +264,7 @@ class RecipeValidation: ...@@ -274,7 +264,7 @@ class RecipeValidation:
return meals_to_be_ordered return meals_to_be_ordered
# print("FINAL MEALS:", meals_to_be_ordered) # print("FINAL MEALS:", meals_to_be_ordered)
def get_recipe_graphs(self) -> list: def get_recipe_graphs(self) -> list[MealGraphDict]:
os.makedirs(ROOT_DIR / "generated", exist_ok=True) os.makedirs(ROOT_DIR / "generated", exist_ok=True)
# time_start = time.time() # time_start = time.time()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment