Skip to content
Snippets Groups Projects
Commit 0a75551a authored by Florian Schröder's avatar Florian Schröder
Browse files

Refactor progress counter handling in overcooked simulator

The commit refines the way the 'progress' method is called on counters in the overcooked simulator. Utilizing Python's introspection functions, the update identifies the relevant counter classes dynamically on setup, and stores them. Simplified calls to these relevant counters are then performed during steps in the simulation. Additionally, the code now checks the superclass with the built-in 'issubclass' function to create counters.
parent 4cf989f6
No related branches found
No related tags found
1 merge request!34Resolve "Counter Factory"
Pipeline #44799 failed
......@@ -151,11 +151,11 @@ class CounterFactory:
kwargs = {
"pos": pos,
}
if counter_class.__name__ in [CuttingBoard.__name__, Sink.__name__]:
if issubclass(counter_class, (CuttingBoard, Sink)):
kwargs["transitions"] = self.filter_item_info(
by_equipment_name=counter_class.__name__
)
elif counter_class.__name__ == PlateDispenser.__name__:
elif issubclass(counter_class, PlateDispenser):
kwargs.update(
{
"plate_transitions": self.filter_item_info(
......@@ -165,7 +165,7 @@ class CounterFactory:
"dispensing": self.item_info["Plate"],
}
)
elif counter_class.__name__ == ServingWindow.__name__:
elif issubclass(counter_class, ServingWindow):
kwargs.update(self.serving_window_additional_kwargs)
return counter_class(**kwargs)
......
from __future__ import annotations
import dataclasses
import inspect
import json
import logging
import random
import sys
from datetime import timedelta, datetime
from enum import Enum
from pathlib import Path
......@@ -17,9 +19,7 @@ from scipy.spatial import distance_matrix
from overcooked_simulator.counter_factory import CounterFactory
from overcooked_simulator.counters import (
Counter,
CuttingBoard,
ServingWindow,
CookingCounter,
Sink,
PlateDispenser,
SinkAddon,
......@@ -167,6 +167,24 @@ class Environment:
self.free_positions,
) = self.parse_layout_file()
progress_counter_classes = list(
filter(
lambda cl: hasattr(cl, "progress"),
dict(
inspect.getmembers(
sys.modules["overcooked_simulator.counters"], inspect.isclass
)
).values(),
)
)
self.progressing_counters = list(
filter(
lambda c: c.__class__ in progress_counter_classes,
self.counters,
)
)
"""Counters that needs to be called in the step function via the `progress` method."""
self.post_counter_setup()
self.env_time: datetime = create_init_env_time()
......@@ -583,11 +601,8 @@ class Environment:
if self.env_time <= player.movement_until:
self.perform_movement(player, passed_time)
for counter in self.counters:
if isinstance(
counter, (CuttingBoard, CookingCounter, Sink, PlateDispenser)
):
counter.progress(passed_time=passed_time, now=self.env_time)
for counter in self.progressing_counters:
counter.progress(passed_time=passed_time, now=self.env_time)
self.order_and_score.progress(passed_time=passed_time, now=self.env_time)
def get_state(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment