Skip to content
Snippets Groups Projects
Commit 3606ecfd authored by Florian Schröder's avatar Florian Schröder
Browse files

Refactor simulator environment and add counter factory

This commit simplifies the overcooked environment by refactoring the way counters are created and managed. A new CounterFactory class is introduced, offloading logic from the environment class. In addition, the symbol to character mapping and other environment-related configurations are moved to a separate YAML file. The .gitignore file is also updated to ignore the 'playground' directory. Making these changes enhances code maintainability and readability.
parent eff27363
No related branches found
No related tags found
1 merge request!34Resolve "Counter Factory"
Pipeline #44796 passed
# Created by https://www.toptal.com/developers/gitignore/api/python,intellij,visualstudiocode,pycharm,git,flask,django,docusaurus,ros,ros2,linux,macos,windows # Created by https://www.toptal.com/developers/gitignore/api/python,intellij,visualstudiocode,pycharm,git,flask,django,docusaurus,ros,ros2,linux,macos,windows
# Edit at https://www.toptal.com/developers/gitignore?templates=python,intellij,visualstudiocode,pycharm,git,flask,django,docusaurus,ros,ros2,linux,macos,windows # Edit at https://www.toptal.com/developers/gitignore?templates=python,intellij,visualstudiocode,pycharm,git,flask,django,docusaurus,ros,ros2,linux,macos,windows
playground
### Django ### ### Django ###
*.log *.log
*.pot *.pot
......
import inspect
import sys
from typing import Any
import numpy.typing as npt
from overcooked_simulator.counters import (
Counter,
CookingCounter,
Dispenser,
ServingWindow,
CuttingBoard,
PlateDispenser,
Sink,
PlateConfig,
)
from overcooked_simulator.game_items import ItemInfo, ItemType, CookingEquipment
def convert_words_to_chars(layout_chars_config: dict[str, str]) -> dict[str, str]:
word_refs = {
"underscore": "_",
"hash": "#",
"space": " ",
"dot": ".",
"comma": ",",
"semicolon": ";",
"colon": ":",
"plus": "+",
"minus": "-",
"exclamation": "!",
"question": "?",
"dquote": '"',
"squote": "'",
"star": "*",
"dollar": "$",
"euro": "",
"ampersand": "&",
"slash": "/",
"oparentheses": "(",
"cparentheses": ")",
"equal": "=",
"right": ">",
"left": "<",
"pipe": "|",
"at": "@",
"top": "^",
"tilde": "~",
}
return {word_refs.get(c, c): name for c, name in layout_chars_config.items()}
class CounterFactory:
additional_counter_names = {"Counter"}
def __init__(
self,
layout_chars_config: dict[str, str],
item_info: dict[str, ItemInfo],
serving_window_additional_kwargs: dict[str, Any],
plate_config: PlateConfig,
) -> None:
self.layout_chars_config = convert_words_to_chars(layout_chars_config)
self.item_info = item_info
self.serving_window_additional_kwargs = serving_window_additional_kwargs
self.plate_config = plate_config
self.no_counter_chars = set(
c
for c, name in self.layout_chars_config.items()
if name in ["Agent", "Free"]
)
self.counter_classes = dict(
inspect.getmembers(
sys.modules["overcooked_simulator.counters"], inspect.isclass
)
)
self.cooking_counter_equipments = {
cooking_counter: [
equipment
for equipment, e_info in self.item_info.items()
if e_info.equipment and e_info.equipment.name == cooking_counter
]
for cooking_counter, info in self.item_info.items()
if info.type == ItemType.Equipment and info.equipment is None
}
def get_counter_object(self, c: str, pos: npt.NDArray[float]) -> Counter:
assert self.can_map(c), f"Can't map counter char {c}"
counter_class = None
if self.layout_chars_config[c] in self.item_info:
item_info = self.item_info[self.layout_chars_config[c]]
if item_info.type == ItemType.Equipment and item_info.equipment:
if item_info.equipment.name in self.counter_classes:
counter_class = self.counter_classes[item_info.equipment.name]
else:
return CookingCounter(
name=item_info.equipment.name,
cooking_counter_equipments=self.cooking_counter_equipments,
pos=pos,
occupied_by=CookingEquipment(
name=item_info.name,
item_info=item_info,
transitions=self.filter_item_info(
by_equipment_name=item_info.name
),
),
)
elif item_info.type == ItemType.Ingredient:
return Dispenser(pos=pos, dispensing=item_info)
if counter_class is None:
counter_class = self.counter_classes[self.layout_chars_config[c]]
kwargs = {
"pos": pos,
}
if counter_class.__name__ in [CuttingBoard.__name__, Sink.__name__]:
kwargs["transitions"] = self.filter_item_info(
by_equipment_name=counter_class.__name__
)
elif counter_class.__name__ == PlateDispenser.__name__:
kwargs.update(
{
"plate_transitions": self.filter_item_info(
by_item_type=ItemType.Meal
),
"plate_config": self.plate_config,
"dispensing": self.item_info["Plate"],
}
)
elif counter_class.__name__ == ServingWindow.__name__:
kwargs.update(self.serving_window_additional_kwargs)
return counter_class(**kwargs)
def can_map(self, char) -> bool:
return char in self.layout_chars_config and (
not self.is_counter(char)
or self.layout_chars_config[char] in self.item_info
or self.layout_chars_config[char] in self.counter_classes
)
def is_counter(self, c: str) -> bool:
return c in self.layout_chars_config and c not in self.no_counter_chars
def map_not_counter(self, c: str) -> str:
assert self.can_map(c) and not self.is_counter(
c
), "Cannot map char {c} as a 'not counter'"
return self.layout_chars_config[c]
def filter_item_info(
self,
by_item_type: ItemType = None,
by_equipment_name: str = None,
) -> dict[str, ItemInfo]:
"""Filter the item info dict by item type or equipment name"""
if by_item_type is not None:
return {
name: info
for name, info in self.item_info.items()
if info.type == by_item_type
}
if by_equipment_name is not None:
return {
name: info
for name, info in self.item_info.items()
if info.equipment is not None
and info.equipment.name == by_equipment_name
}
return self.item_info
...@@ -73,6 +73,7 @@ class Counter: ...@@ -73,6 +73,7 @@ class Counter:
pos: npt.NDArray[float], pos: npt.NDArray[float],
occupied_by: Optional[Item] = None, occupied_by: Optional[Item] = None,
uid: hex = None, uid: hex = None,
**kwargs,
): ):
"""Constructor setting the arguments as attributes. """Constructor setting the arguments as attributes.
...@@ -178,13 +179,13 @@ class CuttingBoard(Counter): ...@@ -178,13 +179,13 @@ class CuttingBoard(Counter):
The character `C` in the `layout` file represents the CuttingBoard. The character `C` in the `layout` file represents the CuttingBoard.
""" """
def __init__(self, pos: np.ndarray, transitions: dict[str, ItemInfo]): def __init__(self, pos: np.ndarray, transitions: dict[str, ItemInfo], **kwargs):
self.progressing = False self.progressing = False
self.transitions = transitions self.transitions = transitions
self.inverted_transition_dict = { self.inverted_transition_dict = {
info.needs[0]: info for name, info in self.transitions.items() info.needs[0]: info for name, info in self.transitions.items()
} }
super().__init__(pos=pos) super().__init__(pos=pos, **kwargs)
def progress(self, passed_time: timedelta, now: datetime): def progress(self, passed_time: timedelta, now: datetime):
"""Called by environment step function for time progression. """Called by environment step function for time progression.
...@@ -256,12 +257,13 @@ class ServingWindow(Counter): ...@@ -256,12 +257,13 @@ class ServingWindow(Counter):
meals: set[str], meals: set[str],
env_time_func: Callable[[], datetime], env_time_func: Callable[[], datetime],
plate_dispenser: PlateDispenser = None, plate_dispenser: PlateDispenser = None,
**kwargs,
): ):
self.order_and_score = order_and_score self.order_and_score = order_and_score
self.plate_dispenser = plate_dispenser self.plate_dispenser = plate_dispenser
self.meals = meals self.meals = meals
self.env_time_func = env_time_func self.env_time_func = env_time_func
super().__init__(pos=pos) super().__init__(pos=pos, **kwargs)
def drop_off(self, item) -> Item | None: def drop_off(self, item) -> Item | None:
env_time = self.env_time_func() env_time = self.env_time_func()
...@@ -303,11 +305,12 @@ class Dispenser(Counter): ...@@ -303,11 +305,12 @@ class Dispenser(Counter):
Which also is easier for the visualization of the dispenser. Which also is easier for the visualization of the dispenser.
""" """
def __init__(self, pos: npt.NDArray[float], dispensing: ItemInfo): def __init__(self, pos: npt.NDArray[float], dispensing: ItemInfo, **kwargs):
self.dispensing = dispensing self.dispensing = dispensing
super().__init__( super().__init__(
pos=pos, pos=pos,
occupied_by=self.create_item(), occupied_by=self.create_item(),
**kwargs,
) )
def pick_up(self, on_hands: bool = True) -> Item | None: def pick_up(self, on_hands: bool = True) -> Item | None:
...@@ -548,8 +551,9 @@ class Sink(Counter): ...@@ -548,8 +551,9 @@ class Sink(Counter):
pos: npt.NDArray[float], pos: npt.NDArray[float],
transitions: dict[str, ItemInfo], transitions: dict[str, ItemInfo],
sink_addon: SinkAddon = None, sink_addon: SinkAddon = None,
**kwargs,
): ):
super().__init__(pos=pos) super().__init__(pos=pos, **kwargs)
self.progressing = False self.progressing = False
self.sink_addon: SinkAddon = sink_addon self.sink_addon: SinkAddon = sink_addon
"""The connected sink addon which will receive the clean plates""" """The connected sink addon which will receive the clean plates"""
......
...@@ -16,6 +16,31 @@ meals: ...@@ -16,6 +16,31 @@ meals:
- OnionSoup - OnionSoup
- Salad - Salad
layout_chars:
underscore: Free
hash: Counter
A: Agent
P: PlateDispenser
C: CuttingBoard
X: Trashcan
W: ServingWindow
S: Sink
plus: SinkAddon
U: Pot # with Stove
Q: Pan # with Stove
O: Peel # with Oven
F: Basket # with DeepFryer
T: Tomato
N: Onion # oNioN
L: Lettuce
K: Potato # Kartoffel
I: Fish # fIIIsh
D: Dough
E: Cheese # chEEEse
G: Sausage # sausaGe
B: Bun
M: Meat
orders: orders:
order_gen_class: !!python/name:overcooked_simulator.order.RandomOrderGeneration '' order_gen_class: !!python/name:overcooked_simulator.order.RandomOrderGeneration ''
# the class to that receives the kwargs. Should be a child class of OrderGeneration in order.py # the class to that receives the kwargs. Should be a child class of OrderGeneration in order.py
......
...@@ -35,7 +35,11 @@ from overcooked_simulator.server_results import ( ...@@ -35,7 +35,11 @@ from overcooked_simulator.server_results import (
PlayerInfo, PlayerInfo,
PlayerRequestResult, PlayerRequestResult,
) )
from overcooked_simulator.utils import setup_logging, url_and_port_arguments from overcooked_simulator.utils import (
setup_logging,
url_and_port_arguments,
disable_websocket_logging_arguments,
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -728,6 +732,7 @@ if __name__ == "__main__": ...@@ -728,6 +732,7 @@ if __name__ == "__main__":
) )
url_and_port_arguments(parser) url_and_port_arguments(parser)
disable_websocket_logging_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
setup_logging(args.enable_websocket_logging) setup_logging(args.enable_websocket_logging)
main(args.url, args.port) main(args.url, args.port)
......
...@@ -14,11 +14,10 @@ import numpy.typing as npt ...@@ -14,11 +14,10 @@ import numpy.typing as npt
import yaml import yaml
from scipy.spatial import distance_matrix from scipy.spatial import distance_matrix
from overcooked_simulator.counter_factory import CounterFactory
from overcooked_simulator.counters import ( from overcooked_simulator.counters import (
Counter, Counter,
CuttingBoard, CuttingBoard,
Trashcan,
Dispenser,
ServingWindow, ServingWindow,
CookingCounter, CookingCounter,
Sink, Sink,
...@@ -29,7 +28,6 @@ from overcooked_simulator.counters import ( ...@@ -29,7 +28,6 @@ from overcooked_simulator.counters import (
from overcooked_simulator.game_items import ( from overcooked_simulator.game_items import (
ItemInfo, ItemInfo,
ItemType, ItemType,
CookingEquipment,
) )
from overcooked_simulator.order import OrderAndScoreManager from overcooked_simulator.order import OrderAndScoreManager
from overcooked_simulator.player import Player, PlayerConfig from overcooked_simulator.player import Player, PlayerConfig
...@@ -114,7 +112,7 @@ class Environment: ...@@ -114,7 +112,7 @@ class Environment:
self.layout_config = layout_config self.layout_config = layout_config
# self.counter_side_length = 1 # -> this changed! is 1 now # self.counter_side_length = 1 # -> this changed! is 1 now
self.item_info = self.load_item_info(item_info) self.item_info: dict[str, ItemInfo] = self.load_item_info(item_info)
"""The loaded item info dict. Keys are the item names.""" """The loaded item info dict. Keys are the item names."""
# self.validate_item_info() # self.validate_item_info()
if self.environment_config["meals"]["all"]: if self.environment_config["meals"]["all"]:
...@@ -139,128 +137,35 @@ class Environment: ...@@ -139,128 +137,35 @@ class Environment:
) )
"""The manager for the orders and score update.""" """The manager for the orders and score update."""
cooking_counter_equipments = {
cooking_counter: [
equipment
for equipment, e_info in self.item_info.items()
if e_info.equipment and e_info.equipment.name == cooking_counter
]
for cooking_counter, info in self.item_info.items()
if info.type == ItemType.Equipment and info.equipment is None
}
self.SYMBOL_TO_CHARACTER_MAP = {
"#": Counter,
"C": lambda pos: CuttingBoard(
pos=pos,
transitions=self.filter_item_info(
self.item_info, by_equipment_name="CuttingBoard"
),
),
"X": Trashcan,
"W": lambda pos: ServingWindow(
pos,
self.order_and_score,
meals=self.allowed_meal_names,
env_time_func=self.get_env_time,
),
"T": lambda pos: Dispenser(pos, self.item_info["Tomato"]),
"L": lambda pos: Dispenser(pos, self.item_info["Lettuce"]),
"K": lambda pos: Dispenser(pos, self.item_info["Potato"]), # Kartoffel
"I": lambda pos: Dispenser(pos, self.item_info["Fish"]), # fIIIsh
"D": lambda pos: Dispenser(pos, self.item_info["Dough"]),
"E": lambda pos: Dispenser(pos, self.item_info["Cheese"]), # chEEEEse
"G": lambda pos: Dispenser(pos, self.item_info["Sausage"]), # sausaGe
"P": lambda pos: PlateDispenser(
plate_transitions=self.filter_item_info(
item_info=self.item_info, by_item_type=ItemType.Meal
),
pos=pos,
dispensing=self.item_info["Plate"],
plate_config=PlateConfig(
**(
self.environment_config["plates"]
if "plates" in self.environment_config
else {}
)
),
),
"N": lambda pos: Dispenser(pos, self.item_info["Onion"]), # N for oNioN
"_": "Free",
"A": "Agent",
"U": lambda pos: CookingCounter(
name="Stove",
cooking_counter_equipments=cooking_counter_equipments,
pos=pos,
occupied_by=CookingEquipment(
name="Pot",
item_info=self.item_info["Pot"],
transitions=self.filter_item_info(
self.item_info, by_equipment_name="Pot"
),
),
), # Stove with pot: U because it looks like a pot
"Q": lambda pos: CookingCounter(
name="Stove",
cooking_counter_equipments=cooking_counter_equipments,
pos=pos,
occupied_by=CookingEquipment(
name="Pan",
item_info=self.item_info["Pan"],
transitions=self.filter_item_info(
self.item_info, by_equipment_name="Pan"
),
),
), # Stove with pan: Q because it looks like a pan
"O": lambda pos: CookingCounter(
name="Oven",
cooking_counter_equipments=cooking_counter_equipments,
pos=pos,
occupied_by=CookingEquipment(
name="Peel",
item_info=self.item_info["Peel"],
transitions=self.filter_item_info(
self.item_info, by_equipment_name="Peel"
),
),
),
"F": lambda pos: CookingCounter(
name="DeepFryer",
cooking_counter_equipments=cooking_counter_equipments,
pos=pos,
occupied_by=CookingEquipment(
name="Basket",
item_info=self.item_info["Basket"],
transitions=self.filter_item_info(
self.item_info, by_equipment_name="Basket"
),
),
), # Stove with pan: Q because it looks like a pan
"B": lambda pos: Dispenser(pos, self.item_info["Bun"]),
"M": lambda pos: Dispenser(pos, self.item_info["Meat"]),
"S": lambda pos: Sink(
pos,
transitions=self.filter_item_info(
item_info=self.item_info, by_equipment_name="Sink"
),
),
"+": SinkAddon,
}
"""Map of the characters in the layout file to callables returning the object/counter. In the future,
maybe replaced with a factory and the characters defined elsewhere in an config."""
self.kitchen_height: int = 0 self.kitchen_height: int = 0
"""The height of the kitchen, is set by the `Environment.parse_layout_file` method""" """The height of the kitchen, is set by the `Environment.parse_layout_file` method"""
self.kitchen_width: int = 0 self.kitchen_width: int = 0
"""The width of the kitchen, is set by the `Environment.parse_layout_file` method""" """The width of the kitchen, is set by the `Environment.parse_layout_file` method"""
self.counter_factory = CounterFactory(
layout_chars_config=self.environment_config["layout_chars"],
item_info=self.item_info,
serving_window_additional_kwargs={
"order_and_score": self.order_and_score,
"meals": self.allowed_meal_names,
"env_time_func": self.get_env_time,
},
plate_config=PlateConfig(
**(
self.environment_config["plates"]
if "plates" in self.environment_config
else {}
)
),
)
( (
self.counters, self.counters,
self.designated_player_positions, self.designated_player_positions,
self.free_positions, self.free_positions,
) = self.parse_layout_file() ) = self.parse_layout_file()
self.init_counters() self.post_counter_setup()
self.env_time: datetime = create_init_env_time() self.env_time: datetime = create_init_env_time()
"""the internal time of the environment. An environment starts always with the time from """the internal time of the environment. An environment starts always with the time from
...@@ -379,6 +284,7 @@ class Environment: ...@@ -379,6 +284,7 @@ class Environment:
else: else:
lines = self.layout_config.split("\n") lines = self.layout_config.split("\n")
self.kitchen_height = len(lines) self.kitchen_height = len(lines)
print(self.kitchen_height)
for line in lines: for line in lines:
line = line.replace("\n", "").replace(" ", "") # remove newline char line = line.replace("\n", "").replace(" ", "") # remove newline char
...@@ -386,17 +292,20 @@ class Environment: ...@@ -386,17 +292,20 @@ class Environment:
for character in line: for character in line:
character = character.capitalize() character = character.capitalize()
pos = np.array([current_x, current_y]) pos = np.array([current_x, current_y])
counter_class = self.SYMBOL_TO_CHARACTER_MAP[character] assert self.counter_factory.can_map(
if not isinstance(counter_class, str): character
counter = counter_class(pos) ), f"{character=} in layout file can not be mapped"
counters.append(counter) if self.counter_factory.is_counter(character):
counters.append(
self.counter_factory.get_counter_object(character, pos)
)
else: else:
if counter_class == "Agent": match self.counter_factory.map_not_counter(character):
designated_player_positions.append( case "Agent":
np.array([current_x, current_y]) designated_player_positions.append(pos)
) case "Free":
elif counter_class == "Free": free_positions.append(np.array([current_x, current_y]))
free_positions.append(np.array([current_x, current_y]))
current_x += 1 current_x += 1
if current_x > self.kitchen_width: if current_x > self.kitchen_width:
self.kitchen_width = current_x self.kitchen_width = current_x
...@@ -715,14 +624,14 @@ class Environment: ...@@ -715,14 +624,14 @@ class Environment:
assert StateRepresentation.model_validate_json(json_data=json_data) assert StateRepresentation.model_validate_json(json_data=json_data)
return json_data return json_data
def init_counters(self): def post_counter_setup(self):
"""Initialize the counters in the environment. """Initialize the counters in the environment.
Connect the `ServingWindow`(s) with the `PlateDispenser`. Connect the `ServingWindow`(s) with the `PlateDispenser`.
Find and connect the `SinkAddon`s with the `Sink`s Find and connect the `SinkAddon`s with the `Sink`s
""" """
plate_dispenser = self.get_counter_of_type(PlateDispenser) plate_dispenser = self.get_counter_of_type(PlateDispenser)
assert len(plate_dispenser) > 0, "No Plate Return in the environment" assert len(plate_dispenser) > 0, "No Plate Dispenser in the environment"
sink_addons = self.get_counter_of_type(SinkAddon) sink_addons = self.get_counter_of_type(SinkAddon)
...@@ -760,25 +669,3 @@ class Environment: ...@@ -760,25 +669,3 @@ class Environment:
"""Reset the env time to the initial time, defined by `create_init_env_time`.""" """Reset the env time to the initial time, defined by `create_init_env_time`."""
self.env_time = create_init_env_time() self.env_time = create_init_env_time()
log.debug(f"Reset env time to {self.env_time}") log.debug(f"Reset env time to {self.env_time}")
@staticmethod
def filter_item_info(
item_info: dict[str, ItemInfo],
by_item_type: ItemType = None,
by_equipment_name: str = None,
) -> dict[str, ItemInfo]:
"""Filter the item info dict by item type or equipment name"""
if by_item_type is not None:
return {
name: info
for name, info in item_info.items()
if info.type == by_item_type
}
if by_equipment_name is not None:
return {
name: info
for name, info in item_info.items()
if info.equipment is not None
and info.equipment.name == by_equipment_name
}
return item_info
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment