diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 6bfc4d740e72b58f291b0b44b6b958e9fe22822e..d8e1afad0b9c7a5006b9e6b12a2c0402eade3121 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,11 +1,11 @@ pytest: stage: test script: - - apt-get update -qy - - apt-get install -y python3-dev python3-pip - - pip install pytest - - pip install . - - pytest --junitxml=report.xml + - apt-get update -qy + - apt-get install -y python3-dev python3-pip + - pip install pytest + - pip install . + - pytest --junitxml=report.xml artifacts: when: always reports: @@ -13,14 +13,14 @@ pytest: pages: script: - - apt-get update -qy - - apt-get install -y python3-dev python3-pip - - pip install pdoc - - pip install . - - pdoc --output-dir public overcooked_simulator --logo https://gitlab.ub.uni-bielefeld.de/uploads/-/system/project/avatar/6780/Cooking-Vector-Illustration-Icon-Graphics-4267218-1-580x435.jpg + - apt-get update -qy + - apt-get install -y python3-dev python3-pip + - pip install pdoc + - pip install . + - pdoc --output-dir public overcooked_simulator --logo https://gitlab.ub.uni-bielefeld.de/uploads/-/system/project/avatar/6780/Cooking-Vector-Illustration-Icon-Graphics-4267218-1-580x435.jpg --docformat google artifacts: paths: - - public + - public rules: - - if: $CI_COMMIT_BRANCH == "main" + - if: $CI_COMMIT_BRANCH == "main" diff --git a/README.md b/README.md index e266fce58b7a96956c513e436c9cd7a6eaa62986..7222709f10f94cc0b0ab6343079ab35dfd8f970e 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,69 @@ # Overcooked Simulator -[API Docs](https://scs.pages.ub.uni-bielefeld.de/cocosy/overcooked-simulator) +[Documentation](https://scs.pages.ub.uni-bielefeld.de/cocosy/overcooked-simulator) -The real-time overcooked simulation for a cognitive cooperative system +The real-time overcooked simulation for a cognitive cooperative system. + +**The name ist still work in progress and we will probably change it.** + +## Installation + +You have two options to install the environment. Either clone it and install it locally or install it in your +site-packages. +You need a Python 3.10 or higher environment. Either conda or PyEnv. + +### Local Editable Installation + +In your `repo`, `PyCharmProjects` or similar directory with the correct environment active: + +```bash +git clone https://gitlab.ub.uni-bielefeld.de/scs/cocosy/overcooked-simulator.git +cd overcooked_simulator +pip install -e . +`` + +#### Run +You can use it in your Python code or run the `main.py`from the command line: +```bash +python3 overcooked_simulator/main.py +``` + +### Library Installation + +The correct environment needs to be active: + +```bash +pip install overcooked-environment@git+https://gitlab.ub.uni-bielefeld.de/scs/cocosy/overcooked-simulator@main +``` + +#### Run + +You can now use the environment and/or simulator in your python code. Just by importing +it `import overcooked_environment` + +## Configuration + +The environment configuration is currently done with 3 config files + GUI configuration. + +### Item Config + +The item config defines which ingredients, cooking equipment and meals can exist and how meals and processed ingredients +can be cooked/created. + +### Layout Config + +You can define the layout of the kitchen via a layout file. The position of counters are based on a grid system, even +when the players do not move grid steps but continuous steps. Each character defines a different type of counter. + +### Environment Config + +The environment config defines how a level/environment is defined. Here, the available plates, meals, order and player +configuration is done. + +### PyGame Visualization Config + +Here the visualisation for all objects is defined. Reference the images or define a list of base shapes that represent +the counters, ingredients, meals and players. ## Troubleshooting @@ -11,4 +72,4 @@ if you have a conda environment: ```bash conda install -c conda-forge libstdcxx-ng -``` \ No newline at end of file +``` diff --git a/overcooked_simulator/__init__.py b/overcooked_simulator/__init__.py index 7d983330e5810899308a2e0b5b06990a754af5f6..2c7bb07c452710b36d9c0a5aa12cd56325842530 100644 --- a/overcooked_simulator/__init__.py +++ b/overcooked_simulator/__init__.py @@ -1,19 +1,67 @@ """ -This is the documentation of Overcooked Simulator. -It contains of +This is the documentation of the Overcooked Simulator. # About the package +The package contains of an environment for cooperation between players/agents. A PyGameGUI visualizes the game to +human or visual agents in 2D. A 3D web-enabled version (for example for online studies, currently under development) +can be found [here](https://gitlab.ub.uni-bielefeld.de/scs/cocosy/godot-overcooked-3d-visualization) + # Background / Literature +The overcooked/cooking domain is a well established cooperation domain/task. There exists +environments designed for reinforcement learning agents as well as the game and adaptations of the game for human +players in a more "real-time" environment. They all mostly differ in the visual and graphics dimension. 2D versions +like overcooked-ai, ... are most known in the community. But more visual appealing 3D versions for cooperation with +humans are getting developed more frequently (cite,...). Besides, the general adaptations of the original overcooked +game. +With this overcooked-simulator, we want to bring both worlds together: the reinforcement learning and real-time playable +environment with an appealing visualisation. Enable the potential of developing artificial agents that play with humans +like a "real" cooperative / human partner. # Usage / Examples +Our overcooked simulator is designed for real time interaction but also with reinforcement learning in mind (gymnasium environment). +It focuses on configurability, extensibility and appealing visualization options. + +## Human Player +Start `main.py` in your python/conda environment: +```bash +python overcooked_simulator/main.py +``` + +## Connect with player and receive game state +... + +## Direct integration into your code. +Initialize an environment.... + +**TODO** JSON State description. + # Citation +# Structure of the Documentation +The API documentation follows the file and content structure in the repo. +On the left you can find the navigation panel that brings you to the implementation of +- the **counters**, including the kitchen utility objects like dispenser, stove, sink, etc., +- the **game items**, the holdable ingredients, cooking equipment, composed ingredients, and meals, +- in **main**, you find an example how to start a simulation, +- the **orders**, how to sample incoming orders and their attributes, +- the **environment**, handles the incoming actions and provides the state, +- the **player**/agent, that interacts in the environment, +- a **simulation runner**, that calls the step function of the environment for a real-time interaction, and +- **util**ity code. + """ import os from pathlib import Path ROOT_DIR = Path(os.path.dirname(os.path.abspath(__file__))) # This is your Project Root +"""A path variable to get access to the layouts coming with the package. For example, +```python +from overcooked_simulator import ROOT_DIR + +environment_config_path = ROOT_DIR / "game_content" / "environment_config.yaml" +``` +""" diff --git a/overcooked_simulator/counters.py b/overcooked_simulator/counters.py index d62109ebaf861a8daf05e1e10740d80cfe5dc441..02ae69135bac11693ebcfc39d671a54d5528e102 100644 --- a/overcooked_simulator/counters.py +++ b/overcooked_simulator/counters.py @@ -1,11 +1,44 @@ +"""All counters are derived from the `Counter` class. Counters implement the `Counter.pick_up` method, which defines +what should happen when the agent wants to pick something up from the counter. On the other side, +the `Counter.drop_off` method receives the item what should be put on the counter. Before that the +`Counter.can_drop_off` method checked if the item can be put on the counter. The progress on Counters or on objects +on the counters are handled via the Counters. They have the task to delegate the progress call via the +`progress` method, e.g., the `CuttingBoard.progress`. On which type of counter the progress method is called is currently defined in the +environment class. + +Inside the item_info.yaml, equipment needs to be defined. It includes counters that are part of the interaction/requirements for the interaction. + + CuttingBoard: + type: Equipment + + Sink: + type: Equipment + + Stove: + type: Equipment + +The defined counter classes are: +- `Counter` +- `CuttingBoard` +- `ServingWindow` +- `Dispenser` +- `PlateDispenser` +- `Trashcan` +- `Stove` (maybe abstracted in a class for all cooking machine counters (stove, deep fryer, oven)) +- `Sink` +- `SinkAddon` + +## Code Documentation +""" from __future__ import annotations +import dataclasses import logging import uuid from collections import deque from collections.abc import Iterable from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Optional, Callable +from typing import TYPE_CHECKING, Optional, Callable, TypedDict if TYPE_CHECKING: from overcooked_simulator.overcooked_environment import ( @@ -28,8 +61,31 @@ log = logging.getLogger(__name__) COUNTER_CATEGORY = "Counter" +class TransitionsValueDict(TypedDict): + """The values in the transitions dicts of the `CookingEquipment`.""" + + seconds: int | float + """The needed seconds to progress for the transition.""" + needs: list[str] + """The names of the needed items for the transition.""" + info: ItemInfo | str + """The ItemInfo of the resulting item.""" + + +class TransitionsValueByNameDict(TypedDict): + """The values in the transitions dicts of the `CuttingBoard` and the `Sink`.""" + + seconds: int | float + """The needed seconds to progress for the transition.""" + result: str + """The new name of the item after the transition.""" + + class Counter: - """Simple class for a counter at a specified position (center of counter). Can hold things on top.""" + """Simple class for a counter at a specified position (center of counter). Can hold things on top. + + The character `#` in the `layout` file represents the standard Counter. + """ def __init__( self, @@ -37,6 +93,12 @@ class Counter: occupied_by: Optional[Item] = None, uid: hex = None, ): + """Constructor setting the arguments as attributes. + + Args: + pos: Position of the counter in the environment. 2-element vector. + occupied_by: The item on top of the counter. + """ self.uuid = uuid.uuid4().hex if uid is None else None self.pos: npt.NDArray[float] = pos self.occupied_by: Optional[Item] = occupied_by @@ -45,12 +107,14 @@ class Counter: def occupied(self): return self.occupied_by is not None - def pick_up(self, on_hands: bool = True): + def pick_up(self, on_hands: bool = True) -> Item | None: """Gets called upon a player performing the pickup action. If the counter can give something to the player, it does so. In the standard counter this is when an item is on the counter. - Returns: The item which the counter is occupied by. None if nothing is there. + Args: + on_hands: Will the item be put on empty hands or on a cooking equipment. + Returns: The item which the counter is occupied by. None if nothing is there. """ if on_hands: if self.occupied_by: @@ -82,8 +146,8 @@ class Counter: Args: item: The item to be placed on the counter. - Returns: TODO Return information, whether the score is affected (Serving Window?) - + Returns: + Item or None what should be put back on the players hand, e.g., the cooking equipment. """ if self.occupied_by is None: self.occupied_by = item @@ -121,13 +185,36 @@ class Counter: class CuttingBoard(Counter): - def __init__(self, pos: np.ndarray, transitions: dict): + """Cutting ingredients on. The requirement in a new object could look like + + ```yaml + ChoppedTomato: + type: Ingredient + needs: [ Tomato ] + seconds: 4.0 + equipment: CuttingBoard + ``` + The character `C` in the `layout` file represents the CuttingBoard. + """ + + def __init__( + self, pos: np.ndarray, transitions: dict[str, TransitionsValueByNameDict] + ): self.progressing = False self.transitions = transitions - super().__init__(pos) + super().__init__(pos=pos) def progress(self, passed_time: timedelta, now: datetime): - """Called by environment step function for time progression""" + """Called by environment step function for time progression. + + Args: + passed_time: the time passed since the last progress call + now: the current env time. **Not the same as `datetime.now`**. + + Checks if the item on the board is in the allowed transitions via a Cutting board. Pass the progress call to + the item on the board. If the progress on the item reaches 100% it changes the name of the item based on the + "goal" name in the transition definition. + """ if ( self.occupied and self.progressing @@ -169,9 +256,19 @@ class CuttingBoard(Counter): class ServingWindow(Counter): + """The orders and scores are updated based on completed and dropped off meals. The plate dispenser is pinged for the info about a plate outside of the kitchen. + + All items in the `item_info.yml` with the type meal are considered to be servable, if they are ordered. Not + ordered meals can also be served, if a `serving_not_ordered_meals` function is set in the `environment_config.yml`. + + The plate dispenser will put after some time a dirty plate on itself after a meal was served. + + The character `W` in the `layout` file represents the ServingWindow. + """ + def __init__( self, - pos, + pos: npt.NDArray[float], order_and_score: OrderAndScoreManager, meals: set[str], env_time_func: Callable[[], datetime], @@ -181,7 +278,7 @@ class ServingWindow(Counter): self.plate_dispenser = plate_dispenser self.meals = meals self.env_time_func = env_time_func - super().__init__(pos) + super().__init__(pos=pos) def drop_off(self, item) -> Item | None: env_time = self.env_time_func() @@ -197,7 +294,7 @@ class ServingWindow(Counter): or (len(item.content_list) == 1 and item.content_list[0].name in self.meals) ) - def pick_up(self, on_hands: bool = True): + def pick_up(self, on_hands: bool = True) -> Item | None: pass def add_plate_dispenser(self, plate_dispenser): @@ -205,14 +302,32 @@ class ServingWindow(Counter): class Dispenser(Counter): - def __init__(self, pos, dispensing: ItemInfo): + """The class for all dispensers except plate dispenser. Here ingredients can be grabbed from the player/agent. + + At the moment all ingredients have an unlimited stock. + + The character for each dispenser in the `layout` file is currently hard coded in the environment class: + ```yaml + T: Tomato + L: Lettuce + N: Onion # N for oNioN + B: Bun + M: Meat + ``` + The plan is to put the info also in the config. + + In the implementation, an instance of the item to dispense is always on top of the dispenser. + Which also is easier for the visualization of the dispenser. + """ + + def __init__(self, pos: npt.NDArray[float], dispensing: ItemInfo): self.dispensing = dispensing super().__init__( - pos, - self.create_item(), + pos=pos, + occupied_by=self.create_item(), ) - def pick_up(self, on_hands: bool = True): + def pick_up(self, on_hands: bool = True) -> Item | None: return_this = self.occupied_by self.occupied_by = self.create_item() return return_this @@ -227,7 +342,7 @@ class Dispenser(Counter): def __repr__(self): return f"{self.dispensing.name}Dispenser" - def create_item(self): + def create_item(self) -> Item: kwargs = { "name": self.dispensing.name, "item_info": self.dispensing, @@ -240,21 +355,52 @@ class Dispenser(Counter): return d +@dataclasses.dataclass +class PlateConfig: + """Configure the initial and behavior of the plates in the environment.""" + + clean_plates: int = 0 + """clean plates at the start.""" + dirty_plates: int = 3 + """dirty plates at the start.""" + plate_delay: list[int, int] = dataclasses.field(default_factory=lambda: [5, 10]) + """The uniform sampling range for the plate delay between serving and return in seconds.""" + + class PlateDispenser(Counter): + """At the moment, one and only one plate dispenser must exist in an environment, because only at one place the dirty + plates should arrive. + + How many plates should exist at the start of the level on the plate dispenser is defined in the `environment_config.yml`: + ```yaml + plates: + clean_plates: 1 + dirty_plates: 2 + plate_delay: [ 5, 10 ] + # seconds until the dirty plate arrives. + ``` + + The character `P` in the `layout` file represents the PlateDispenser. + """ + def __init__( - self, pos, dispensing, plate_config, plate_transitions, **kwargs + self, + pos: npt.NDArray[float], + dispensing: ItemInfo, + plate_config: PlateConfig, + plate_transitions: dict, + **kwargs, ) -> None: - super().__init__(pos, **kwargs) + super().__init__(pos=pos, **kwargs) self.dispensing = dispensing self.occupied_by = deque() self.out_of_kitchen_timer = [] - self.plate_config = {"plate_delay": [5, 10]} - self.plate_config.update(plate_config) + self.plate_config = plate_config self.next_plate_time = datetime.max - self.plate_transitions = plate_transitions + self.plate_transitions: dict[str, TransitionsValueDict] = plate_transitions self.setup_plates() - def pick_up(self, on_hands: bool = True): + def pick_up(self, on_hands: bool = True) -> Item | None: if self.occupied_by: return self.occupied_by.pop() @@ -262,14 +408,8 @@ class PlateDispenser(Counter): return not self.occupied_by or self.occupied_by[-1].can_combine(item) def drop_off(self, item: Item) -> Item | None: - """Takes the thing dropped of by the player. - - Args: - item: The item to be placed on the counter. - - Returns: TODO Return information, whether the score is affected (Serving Window?) - - """ + """At the moment items can be put on the top of the plate dispenser or the top plate if it is clean and can + be put on a plate.""" if not self.occupied_by: self.occupied_by.append(item) elif self.occupied_by[-1].can_combine(item): @@ -284,8 +424,8 @@ class PlateDispenser(Counter): # not perfect identical to datetime.now but based on framerate enough. time_plate_to_add = env_time + timedelta( seconds=np.random.uniform( - low=self.plate_config["plate_delay"][0], - high=self.plate_config["plate_delay"][1], + low=self.plate_config.plate_delay[0], + high=self.plate_config.plate_delay[1], ) ) log.debug(f"New plate out of kitchen until {time_plate_to_add}") @@ -295,15 +435,17 @@ class PlateDispenser(Counter): def setup_plates(self): """Create plates based on the config. Clean and dirty ones.""" - if "dirty_plates" in self.plate_config: + if self.plate_config.dirty_plates > 0: + log.info(f"Setup {self.plate_config.dirty_plates} dirty plates.") self.occupied_by.extend( - [self.create_item() for _ in range(self.plate_config["dirty_plates"])] + [self.create_item() for _ in range(self.plate_config.dirty_plates)] ) - if "clean_plates" in self.plate_config: + if self.plate_config.clean_plates > 0: + log.info(f"Setup {self.plate_config.dirty_plates} clean plates.") self.occupied_by.extend( [ self.create_item(clean=True) - for _ in range(self.plate_config["clean_plates"]) + for _ in range(self.plate_config.clean_plates) ] ) @@ -327,7 +469,7 @@ class PlateDispenser(Counter): def __repr__(self): return "PlateReturn" - def create_item(self, clean: bool = False): + def create_item(self, clean: bool = False) -> Plate: kwargs = { "clean": clean, "transitions": self.plate_transitions, @@ -336,8 +478,13 @@ class PlateDispenser(Counter): return Plate(**kwargs) -class Trash(Counter): - def pick_up(self, on_hands: bool = True): +class Trashcan(Counter): + """Ingredients and content on a cooking equipment can be removed from the environment via the trash. + + The character `X` in the `layout` file represents the Trashcan. + """ + + def pick_up(self, on_hands: bool = True) -> Item | None: pass def drop_off(self, item: Item) -> Item | None: @@ -351,6 +498,16 @@ class Trash(Counter): class Stove(Counter): + """Cooking machine. Currently, the stove which can have a pot and pan on top. In the future one class for stove, + deep fryer, and oven. + + The character depends on the cooking equipment on top of it: + ```yaml + U: Stove with a pot + Q: Stove with a pan + ``` + """ + def can_drop_off(self, item: Item) -> bool: if self.occupied_by is None: return isinstance(item, CookingEquipment) and item.name in ["Pot", "Pan"] @@ -368,12 +525,32 @@ class Stove(Counter): class Sink(Counter): - def __init__(self, pos, transitions, sink_addon=None): - super().__init__(pos) + """The counter in which the dirty plates can be washed to clean plates. + + Needs a `SinkAddon`. The closest is calculated during initialisation, should not be seperated by each other (needs + to touch the sink). + + The logic is similar to the CuttingBoard because there is no additional cooking equipment between the object to + progress and the counter. When the progress on the dirty plate is done, it is set to clean and is passed to the + `SinkAddon`. + + The character `S` in the `layout` file represents the Sink. + """ + + def __init__( + self, + pos: npt.NDArray[float], + transitions: dict[str, TransitionsValueByNameDict], + sink_addon: SinkAddon = None, + ): + super().__init__(pos=pos) self.progressing = False self.sink_addon: SinkAddon = sink_addon + """The connected sink addon which will receive the clean plates""" self.occupied_by = deque() + """The queue of dirty plates. Only the one on the top is progressed.""" self.transitions = transitions + """The allowed transitions for the items in the sink. Here only clean plates transfer from dirty plates.""" @property def occupied(self): @@ -425,10 +602,10 @@ class Sink(Counter): self.occupied_by.appendleft(item) return None - def pick_up(self, on_hands: bool = True): - return + def pick_up(self, on_hands: bool = True) -> Item | None: + return None - def set_addon(self, sink_addon): + def set_addon(self, sink_addon: SinkAddon): self.sink_addon = sink_addon def to_dict(self) -> dict: @@ -438,27 +615,27 @@ class Sink(Counter): class SinkAddon(Counter): - def __init__(self, pos, occupied_by=None): - super().__init__(pos) + """The counter on which the clean plates appear after cleaning them in the `Sink` + + It needs to be set close to/touching the `Sink`. + + The character `+` in the `layout` file represents the SinkAddon. + """ + + def __init__(self, pos: npt.NDArray[float], occupied_by=None): + super().__init__(pos=pos) + # maybe check if occupied by is already a list or deque? self.occupied_by = deque([occupied_by]) if occupied_by else deque() def can_drop_off(self, item: Item) -> bool: return self.occupied_by and self.occupied_by[-1].can_combine(item) def drop_off(self, item: Item) -> Item | None: - """Takes the thing dropped of by the player. - - Args: - item: The item to be placed on the counter. - - Returns: - - """ return self.occupied_by[-1].combine(item) def add_clean_plate(self, plate: Plate): self.occupied_by.appendleft(plate) - def pick_up(self, on_hands: bool = True): + def pick_up(self, on_hands: bool = True) -> Item | None: if self.occupied_by: return self.occupied_by.pop() diff --git a/overcooked_simulator/game_content/environment_config.yaml b/overcooked_simulator/game_content/environment_config.yaml index d5eb376d0a489e799a5e213dbe5add26764f2f93..64b7efa95099079f30add7085f9dd1cc38913a41 100644 --- a/overcooked_simulator/game_content/environment_config.yaml +++ b/overcooked_simulator/game_content/environment_config.yaml @@ -2,7 +2,7 @@ plates: clean_plates: 1 dirty_plates: 2 plate_delay: [ 5, 10 ] - # seconds until the dirty plate arrives. + # range of seconds until the dirty plate arrives. game: time_limit_seconds: 300 @@ -17,8 +17,10 @@ meals: - Salad orders: - kwargs: - duration_sample: + order_gen_class: !!python/name:overcooked_simulator.order.RandomOrderGeneration '' + # the class to that receives the kwargs. Should be a child class of OrderGeneration in order.py + order_gen_kwargs: + order_duration_random_func: # how long should the orders be alive # 'random' library call with getattr, kwargs are passed to the function func: uniform @@ -27,19 +29,16 @@ orders: b: 60 max_orders: 6 # maximum number of active orders at the same time - num_start_meals: 3 + num_start_meals: 2 # number of orders generated at the start of the environment - sample_on_dur: true - # if true, the next order is generated based on the sample_on_dur_func method in seconds - # if sample_on_serving is also true, the value is sampled after a meal was served, otherwise it is sampled directly after an order generation. - sample_on_dur_func: + sample_on_dur_random_func: # 'random' library call with getattr, kwargs are passed to the function func: uniform kwargs: a: 10 b: 20 sample_on_serving: false - # The sample time for a new incoming order is only generated after a meal was served. + # Sample the delay for the next order only after a meal was served. score_calc_gen_func: !!python/name:overcooked_simulator.order.simple_score_calc_gen_func '' score_calc_gen_kwargs: # the kwargs for the score_calc_gen_func @@ -54,8 +53,6 @@ orders: default: -5 serving_not_ordered_meals: null # a func that calcs a store for not ordered but served meals. Input: meal - order_gen_class: !!python/name:overcooked_simulator.order.RandomOrderGeneration '' - # the class to that receives the kwargs. Should be a child class of OrderGeneration in order.py player_config: radius: 0.4 diff --git a/overcooked_simulator/game_items.py b/overcooked_simulator/game_items.py index ac244df9660d113248437c27d5911a12b880051e..8f1b39bd52c35fdc923ce714a88b4c7f815fe471 100644 --- a/overcooked_simulator/game_items.py +++ b/overcooked_simulator/game_items.py @@ -1,3 +1,23 @@ +""" +The game items that a player can hold. + +They have methods that +- check if items can be combined (`Item.can_combine`): cooking equipment and ingredients, and so on, +- combine the items after a successful check (`Item.combine`), +- and a method to call the progress on the items (`Item.progress`) + +All game items need to be specified in the `item_info.yml`. + +The following classes are used for the base for all game items: +- `Item`: ingredients and meals. +- `CookingEquipment`: pots, pans, etc. +- `Plate`: clean and dirty plates. + +The `ItemInfo` is the dataclass for the items in the `item_info.yml`. + +## Code Documentation +""" + from __future__ import annotations import collections @@ -22,42 +42,52 @@ class ItemType(Enum): @dataclasses.dataclass class ItemInfo: + """Wrapper for the info in the `item_info.yml`. + + Example: + A simple example for the tomato soup with 6 game items. + ```yaml + CuttingBoard: + type: Equipment + + Stove: + type: Equipment + + Pot: + type: Equipment + equipment: Stove + + Tomato: + type: Ingredient + + ChoppedTomato: + type: Ingredient + needs: [ Tomato ] + seconds: 4.0 + equipment: CuttingBoard + + TomatoSoup: + type: Meal + needs: [ ChoppedTomato, ChoppedTomato, ChoppedTomato ] + seconds: 6.0 + equipment: Pot + ``` + """ + type: ItemType = dataclasses.field(compare=False) + """Type of the item. Either `Ingredient`, `Meal` or `Equipment`.""" name: str = dataclasses.field(compare=True) + """The name of the item, is set automatically by the "group" name of the item.""" seconds: float = dataclasses.field(compare=False, default=0) - needs: list[ItemInfo] = dataclasses.field(compare=False, default_factory=list) + """If progress is needed this argument defines how long it takes to complete the process in seconds.""" + needs: list[str] = dataclasses.field(compare=False, default_factory=list) + """The ingredients/items which are needed to create the item/start the progress.""" equipment: ItemInfo | None = dataclasses.field(compare=False, default=None) - - _start_meals: list[ItemInfo] = dataclasses.field( - compare=False, default_factory=list - ) + """On which the item can be created. `null`, `~` (None) converts to Plate.""" def __post_init__(self): self.type = ItemType(self.type) - def add_start_meal_to_equipment(self, start_item: ItemInfo): - self._start_meals.append(start_item) - - def sort_start_meals(self): - self._start_meals.sort(key=lambda item_info: len(item_info.needs)) - - # def can_start_meal(self, items: list[Item]): - # return items and self._return_start_meal(items) is not None - - # def start_meal(self, items: list[Item]) -> Item: - # return self._return_start_meal(items).create_item(parts=items) - - def _return_start_meal(self, items: list[Item]) -> ItemInfo | None: - for meal in self._start_meals: - satisfied = [False for _ in range(len(items))] - for i, p in enumerate(items): - for _, n in enumerate(meal.needs): - if not satisfied[i] and p.name == n: - satisfied[i] = True - break - if all(satisfied): - return meal - class Item: """Base class for game items which can be held by a player.""" @@ -119,16 +149,24 @@ class Item: class CookingEquipment(Item): + """Pot, Pan, ... that can hold items. It holds the progress of the content (e.g., the soup) in itself ( + progress_percentage) and not in the items in the content list.""" + item_category = "Cooking Equipment" def __init__(self, transitions: dict, *args, **kwargs): super().__init__(*args, **kwargs) self.transitions = transitions self.active_transition: Optional[dict] = None + """The info how and when to convert the content_list to a new item.""" # TODO change content ready just to str (name of the item)? self.content_ready: Item | None = None + """Helper attribute that can have a ready meal which is also represented via it ingredients in the + content_list. But soups or other processed meals are not covered here. For a Burger or Salad, this attribute + is set.""" self.content_list: list[Item] = [] + """The items that the equipment holds.""" log.debug(f"Initialize {self.name}: {self.transitions}") @@ -192,12 +230,6 @@ class CookingEquipment(Item): # todo set active transition for fire/burnt? - # def can_release_content(self) -> bool: - # return ( - # self.content - # and isinstance(self.content, ProgressibleItem) - # and self.content.finished - # ) def reset_content(self): self.content_list = [] self.content_ready = None @@ -239,9 +271,12 @@ class CookingEquipment(Item): class Plate(CookingEquipment): + """The plate can have to states: clean and dirty. In the clean state it can hold content/other items.""" + def __init__(self, transitions, clean, *args, **kwargs): self.clean = clean self.meals = set(transitions.keys()) + """All meals can be hold by a clean plate""" super().__init__( name=self.create_name(), transitions={ @@ -269,6 +304,8 @@ class Plate(CookingEquipment): and not self.content_list and self.clean ): + # additional check for meals in the content list of another equipment, + # e.g., Soups which is not covered by the normal transition checks. return other.content_list[0].name in self.meals return False elif self.clean: diff --git a/overcooked_simulator/gui_2d_vis/__init__.py b/overcooked_simulator/gui_2d_vis/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..9c531caa6003ae868d4eeef943dfa92d5c349fda 100644 --- a/overcooked_simulator/gui_2d_vis/__init__.py +++ b/overcooked_simulator/gui_2d_vis/__init__.py @@ -0,0 +1,20 @@ +""" +2D visualization of the overcooked simulator. + +You can select the layout and start an environment: +- You can play the overcooked simulator. You can quit the application in the top right or end the level in the bottom right: [Screenshot](https://gitlab.ub.uni-bielefeld.de/scs/cocosy/overcooked-simulator/-/raw/main/overcooked_simulator/gui_2d_vis/images/overcooked-start-screen.png?ref_type=heads) +- The orders are pictured in the top, the current score in the bottom left and the remaining time in the bottom: [Screenshot](https://gitlab.ub.uni-bielefeld.de/scs/cocosy/overcooked-simulator/-/raw/main/overcooked_simulator/gui_2d_vis/images/overcooked-level-screen.png?ref_type=heads) +- The final screen after ending a level shows the score: [Screenshot](https://gitlab.ub.uni-bielefeld.de/scs/cocosy/overcooked-simulator/-/raw/main/overcooked_simulator/gui_2d_vis/images/overcooked-end-screen.png?ref_type=heads) + +The keys for the control of the players are: + +### Player 1: +- Movement: `W`, `A`, `S`, `D`, +- Pickup: `E` +- Interact: `F` + +### Player 2: +- Movement: `⬆`, `⬅`, `⬇`, `➡` (arrow keys) +- Pickup: `I` +- Interact: `SPACE` +""" diff --git a/overcooked_simulator/gui_2d_vis/images/overcooked-end-screen.png b/overcooked_simulator/gui_2d_vis/images/overcooked-end-screen.png new file mode 100644 index 0000000000000000000000000000000000000000..d678687d61277930d839882aa5f6956caeb5e9a6 Binary files /dev/null and b/overcooked_simulator/gui_2d_vis/images/overcooked-end-screen.png differ diff --git a/overcooked_simulator/gui_2d_vis/images/overcooked-level-screen.png b/overcooked_simulator/gui_2d_vis/images/overcooked-level-screen.png new file mode 100644 index 0000000000000000000000000000000000000000..0f2cc384aacabcd80d3b88542fd9a3345506b35d Binary files /dev/null and b/overcooked_simulator/gui_2d_vis/images/overcooked-level-screen.png differ diff --git a/overcooked_simulator/gui_2d_vis/images/overcooked-start-screen.png b/overcooked_simulator/gui_2d_vis/images/overcooked-start-screen.png new file mode 100644 index 0000000000000000000000000000000000000000..420d1eedb6cf57a493aba7dbabe7e3ba80ecb1b4 Binary files /dev/null and b/overcooked_simulator/gui_2d_vis/images/overcooked-start-screen.png differ diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py index 3c51e41a8f5ce255f00ca1821c4cdf789ee2284e..b47875f5598276c30e2621c049f41f42d9ba4177 100644 --- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py +++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py @@ -16,7 +16,11 @@ from overcooked_simulator import ROOT_DIR from overcooked_simulator.game_server import CreateEnvironmentConfig from overcooked_simulator.gui_2d_vis.drawing import Visualizer from overcooked_simulator.gui_2d_vis.game_colors import colors -from overcooked_simulator.overcooked_environment import Action +from overcooked_simulator.overcooked_environment import ( + Action, + ActionType, + InterActionData, +) class MenuStates(Enum): @@ -38,9 +42,12 @@ class PlayerKeySet: def __init__(self, player_name: str | int, keys: list[pygame.key]): """Creates a player key set which contains information about which keyboard keys control the player. + Movement keys in the following order: Down, Up, Left, Right - Args: player_name: The name of the player to control. - keys: The keys which control this player in the following order: Down, Up, Left, Right, Interact, Pickup. + + Args: + player_name: The name of the player to control. + keys: The keys which control this player in the following order: Down, Up, Left, Right, Interact, Pickup. """ self.name = player_name self.player_keys = keys @@ -167,29 +174,32 @@ class PyGameGUI: if np.linalg.norm(move_vec) != 0: move_vec = move_vec / np.linalg.norm(move_vec) - action = Action( - key_set.name, "movement", move_vec, duration=1 / (self.FPS) - ) + action = Action(key_set.name, ActionType.MOVEMENT, move_vec) self.send_action(action) def handle_key_event(self, event): """Handles key events for the pickup and interaction keys. Pickup is a single action, for interaction keydown and keyup is necessary, because the player has to be able to hold the key down. + Args: event: Pygame event for extracting the key action. """ for key_set in self.player_key_sets: if event.key == key_set.pickup_key and event.type == pygame.KEYDOWN: - action = Action(key_set.name, "pickup", "pickup") + action = Action(key_set.name, ActionType.PUT, "pickup") self.send_action(action) if event.key == key_set.interact_key: if event.type == pygame.KEYDOWN: - action = Action(key_set.name, "interact", "keydown") + action = Action( + key_set.name, ActionType.INTERACT, InterActionData.START + ) self.send_action(action) elif event.type == pygame.KEYUP: - action = Action(key_set.name, "interact", "keyup") + action = Action( + key_set.name, ActionType.INTERACT, InterActionData.STOP + ) self.send_action(action) def init_ui_elements(self): diff --git a/overcooked_simulator/gui_2d_vis/visualization.yaml b/overcooked_simulator/gui_2d_vis/visualization.yaml index 8d4d52deb8f111a56dda65c44501aca3414fe979..a4ba4f7ec354aa825fe8eb09b44ee19995015eaa 100644 --- a/overcooked_simulator/gui_2d_vis/visualization.yaml +++ b/overcooked_simulator/gui_2d_vis/visualization.yaml @@ -39,7 +39,7 @@ PlateDispenser: width: 0.95 color: cadetblue1 -Trash: +Trashcan: parts: - type: image path: images/trash3.png diff --git a/overcooked_simulator/order.py b/overcooked_simulator/order.py index 13ce34c391152499782c9439bfcd6115cbc84521..f158170f73e539cc774b6901a6a3e8bcdfbb9db1 100644 --- a/overcooked_simulator/order.py +++ b/overcooked_simulator/order.py @@ -1,3 +1,46 @@ +""" +You can configure the order creation/generation via the `environment_config.yml`. + +It is very configurable by letting you reference own Python classes and functions. + +```yaml +orders: + serving_not_ordered_meals: null + order_gen_class: !!python/name:overcooked_simulator.order.RandomOrderGeneration '' + order_gen_kwargs: + ... +``` + +`serving_not_ordered_meals` expects a function. It received a meal as an argument and should return a +tuple of a bool and the score. If the bool is true, the score will be added to the score. Otherwise, it will not +accept the meal for serving. + +The `order_gen_class` should be a child of the `OrderGeneration` class. The `order_gen_kwargs` depend then on your +class referenced. + +This file defines the following classes: +- `Order` +- `OrderGeneration` +- `OrderAndScoreManager` + +Further, it defines same implementations for the basic order generation based on random sampling: +- `RandomOrderGeneration` +- `simple_score_calc_gen_func` +- `simple_expired_penalty` +- `zero` + +For an easier usage of the random orders, also some classes for type hints and dataclasses are defined: +- `RandomOrderKwarg` +- `RandomFuncConfig` +- `ScoreCalcFuncType` +- `ScoreCalcGenFuncType` +- `ExpiredPenaltyFuncType` + + +## Code Documentation +""" +from __future__ import annotations + import dataclasses import logging import random @@ -5,7 +48,7 @@ import uuid from abc import abstractmethod from collections import deque from datetime import datetime, timedelta -from typing import Callable, Tuple, Any, Deque +from typing import Callable, Tuple, Any, Deque, Protocol, TypedDict from overcooked_simulator.game_items import Item, Plate, ItemInfo @@ -16,20 +59,30 @@ ORDER_CATEGORY = "Order" @dataclasses.dataclass class Order: + """Datawrapper for Orders""" + meal: ItemInfo + """The meal to serve and that should be cooked.""" start_time: datetime + """The start time relative to the env_time. On which the order is returned from the get_orders func.""" max_duration: timedelta - score_calc: Callable[[timedelta, ...], float] + """The duration after which the order expires.""" + score_calc: ScoreCalcFuncType + """The function that calculates the score of the served meal/fulfilled order.""" timed_penalties: list[ Tuple[timedelta, float] | Tuple[timedelta, float, int, timedelta] ] + """List of timed penalties when the order is not fulfilled.""" expired_penalty: float + """The penalty to the score if the order expires""" uuid: str = dataclasses.field(default_factory=lambda: uuid.uuid4().hex) finished_info: dict[str, Any] = dataclasses.field(default_factory=dict) + """Is set after the order is completed.""" _timed_penalties: list[Tuple[datetime, float]] = dataclasses.field( default_factory=list ) + """Converted penalties the env is working with from the `timed_penalties`""" def order_time(self, env_time: datetime) -> timedelta: return self.start_time - env_time @@ -49,11 +102,24 @@ class Order: class OrderGeneration: + """Base class for generating orders. + + You can set your child class via the `environment_config.yml`. + Example: + ```yaml + orders: + order_gen_class: !!python/name:overcooked_simulator.order.RandomOrderGeneration '' + kwargs: + ... + ``` + """ + def __init__(self, available_meals: dict[str, ItemInfo], **kwargs): self.available_meals: list[ItemInfo] = list(available_meals.values()) @abstractmethod def init_orders(self, now) -> list[Order]: + """Get the orders the environment starts with.""" ... @abstractmethod @@ -64,30 +130,322 @@ class OrderGeneration: new_finished_orders: list[Order], expired_orders: list[Order], ) -> list[Order]: + """Orders for each progress call. Should often be the empty list.""" + ... + + +class OrderAndScoreManager: + """The Order and Score Manager that is called from the serving window.""" + + def __init__(self, order_config, available_meals: dict[str, ItemInfo]): + self.score = 0 + self.order_gen: OrderGeneration = order_config["order_gen_class"]( + available_meals=available_meals, kwargs=order_config["order_gen_kwargs"] + ) + self.serving_not_ordered_meals: Callable[ + [Item], Tuple[bool, float] + ] = order_config["serving_not_ordered_meals"] + """Function that decides if not ordered meals can be served and what score it gives""" + self.available_meals = available_meals + """The meals for that orders can be sampled from.""" + self.open_orders: Deque[Order] = deque() + """Current open orders. This attribute is used for the environment state.""" + + # TODO log who / which player served which meal -> for split scores + self.served_meals: list[Tuple[Item, datetime]] = [] + """List of served meals. Maybe for the end screen.""" + self.last_finished: list[Order] = [] + """Cache last finished orders for `OrderGeneration.get_orders` call. From the served meals.""" + self.next_relevant_time: datetime = datetime.max + """For reduced order checking. Store the next time when to create an order or check for penalties.""" + self.last_expired: list[Order] = [] + """Cache last expired orders for `OrderGeneration.get_orders` call.""" + + def update_next_relevant_time(self): + next_relevant_time = datetime.max + for order in self.open_orders: + next_relevant_time = min( + next_relevant_time, order.start_time + order.max_duration + ) + for penalty in order._timed_penalties: + next_relevant_time = min(next_relevant_time, penalty[0]) + self.next_relevant_time = next_relevant_time + + def serve_meal(self, item: Item, env_time: datetime) -> bool: + if isinstance(item, Plate): + meal = item.get_potential_meal() + if meal is not None: + if meal.name in self.available_meals: + order = self.find_order_for_meal(meal) + if order is None: + if self.serving_not_ordered_meals: + accept, score = self.serving_not_ordered_meals(meal) + if accept: + log.info( + f"Serving meal without order {meal.name} with score {score}" + ) + self.increment_score(score) + self.served_meals.append((meal, env_time)) + return accept + log.info( + f"Do not serve meal {meal.name} because it is not ordered" + ) + return False + order, index = order + score = order.score_calc( + relative_order_time=env_time - order.start_time, + order=order, + ) + self.increment_score(score) + order.finished_info = { + "end_time": env_time, + "score": score, + } + log.info(f"Serving meal {meal.name} with order with score {score}") + self.last_finished.append(order) + del self.open_orders[index] + self.served_meals.append((meal, env_time)) + return True + log.info(f"Do not serve item {item}") + return False + + def increment_score(self, score: int | float): + self.score += score + log.debug(f"Score: {self.score}") + + def create_init_orders(self, env_time): + """Create the initial orders in an environment.""" + init_orders = self.order_gen.init_orders(env_time) + self.setup_penalties(new_orders=init_orders, env_time=env_time) + self.open_orders.extend(init_orders) + + def progress(self, passed_time: timedelta, now: datetime): + """Check expired orders and check order generation.""" + new_orders = self.order_gen.get_orders( + passed_time=passed_time, + now=now, + new_finished_orders=self.last_finished, + expired_orders=self.last_expired, + ) + self.setup_penalties(new_orders=new_orders, env_time=now) + self.open_orders.extend(new_orders) + self.last_finished = [] + self.last_expired = [] + if new_orders or self.next_relevant_time <= now: + # reduce checking calls + + remove_orders: list[int] = [] + for index, order in enumerate(self.open_orders): + if now >= order.start_time + order.max_duration: + # orders expired + self.increment_score(order.expired_penalty) + remove_orders.append(index) + continue # no penalties for expired orders + remove_penalties = [] + for i, (penalty_time, penalty) in enumerate(order.timed_penalties): + # check penalties + if penalty_time < now: + self.score -= penalty + remove_penalties.append(i) + + for i in reversed(remove_penalties): + # or del order.timed_penalties[index] + order.timed_penalties.pop(i) + + expired_orders: list[Order] = [] + for remove_order in reversed(remove_orders): + expired_orders.append(self.open_orders[remove_order]) + del self.open_orders[remove_order] + self.last_expired = expired_orders + + self.update_next_relevant_time() + + def find_order_for_meal(self, meal) -> Tuple[Order, int] | None: + for index, order in enumerate(self.open_orders): + if order.meal.name == meal.name: + return order, index + + @staticmethod + def setup_penalties(new_orders: list[Order], env_time: datetime): + """Call the `Order.create_penalties` method for new orders.""" + for order in new_orders: + order.create_penalties(env_time) + + def order_state(self) -> list[dict]: + return [ + { + "id": order.uuid, + "category": ORDER_CATEGORY, + "meal": order.meal.name, + "start_time": order.start_time.isoformat(), + "max_duration": order.max_duration.total_seconds(), + } + for order in self.open_orders + ] + + +class ScoreCalcFuncType(Protocol): + """Typed kwargs of the expected `Order.score_calc` function. Which is also returned by the + `RandomOrderKwarg.score_calc_gen_func`. + + The function should calculate the score for the completed orders. + + Args: + relative_order_time: `timedelta` the duration how long the order was active. + order: `Order` the order that was completed. + + Returns: + `float`: the score for a completed order and duration of the order. + """ + + def __call__(self, relative_order_time: timedelta, order: Order) -> float: + ... + + +class ScoreCalcGenFuncType(Protocol): + """Typed kwargs of the expected function for the `RandomOrderKwarg.score_calc_gen_func`. + + Generate the ScoreCalcFunc for an order based on its meal, duration etc. + + Args: + meal: `ItemInfo` the type of meal the order orders. + duration: `timedelta` the duration after the order expires. + now: `datetime` the environment time the order is created. + kwargs: `dict` the static kwargs defined in the `environment_config.yml` + + Returns: + `ScoreCalcFuncType` a reference to a function that calculates the score for a completed meal. + """ + + def __call__( + self, + meal: ItemInfo, + duration: timedelta, + now: datetime, + kwargs: dict, + **other_kwargs, + ) -> ScoreCalcFuncType: + ... + + +class ExpiredPenaltyFuncType(Protocol): + """Typed kwargs of the expected function for the `RandomOrderKwarg.expired_penalty_func`. + + An example is the `zero` function. + + Args: + item: `ItemInfo` the meal of the order that expired. It is calculated before the order is active. + """ + + def __call__(self, item: ItemInfo, **kwargs) -> float: ... def zero(item: ItemInfo, **kwargs) -> float: + """Example and default for the `RandomOrderKwarg.expired_penalty_func` function. + + Just no penalty for expired orders. + + Returns: + zero / 0.0 + """ return 0.0 +class RandomFuncConfig(TypedDict): + """Types of the dict for sampling with different random functions from the [`random` library](https://docs.python.org/3/library/random.html). + + Example: + Sampling [uniform](https://docs.python.org/3/library/random.html#random.uniform)ly between `10` and `20`. + ```yaml + func: uniform + kwargs: + a: 10 + b: 20 + ``` + + Or in Python: + ```python + random_func = {'func': 'uniform', 'kwargs': {'a': 10, 'b': 20}} + ``` + """ + + func: Callable + """the name of a functions in the `random` library.""" + kwargs: dict + """the kwargs of the functions in the `random` library.""" + + @dataclasses.dataclass class RandomOrderKwarg: num_start_meals: int + """Number of meals sampled at the start.""" sample_on_serving: bool - sample_on_dur: bool - sample_on_dur_func: dict + """Only sample the delay for the next order after a meal was served.""" + sample_on_dur_random_func: RandomFuncConfig + """How to sample the delay of the next incoming order. Either after a new meal was served or the last order was + generated (based on the `sample_on_serving` attribute).""" max_orders: int - duration_sample: dict - score_calc_gen_func: Callable[ - [ItemInfo, timedelta, datetime, Any], Callable[[timedelta, Order], float] - ] + """How many orders can maximally be active at the same time.""" + order_duration_random_func: RandomFuncConfig + """How long the order is alive until it expires. If `sample_on_serving` is `true` all orders have no expire time.""" + score_calc_gen_func: ScoreCalcGenFuncType + """The function that generates the `Order.score_calc` for each order.""" score_calc_gen_kwargs: dict + """The additional static kwargs for `score_calc_gen_func`.""" expired_penalty_func: Callable[[ItemInfo], float] = zero + """The function that calculates the penalty for a meal that was not served.""" expired_penalty_kwargs: dict = dataclasses.field(default_factory=dict) + """The additional static kwargs for the `expired_penalty_func`.""" class RandomOrderGeneration(OrderGeneration): + """A simple order generation based on random sampling with two options. + + Either sample the delay when a new order should come in after the last order comes in or after a meal was served + (and an order got removed). + + To configure it align your kwargs with the `RandomOrderKwarg` class. + + You can set this order generation in your `environment_config.yml` with + ```yaml + orders: + order_gen_class: !!python/name:overcooked_simulator.order.RandomOrderGeneration '' + kwargs: + order_duration_random_func: + # how long should the orders be alive + # 'random' library call with getattr, kwargs are passed to the function + func: uniform + kwargs: + a: 40 + b: 60 + max_orders: 6 + # maximum number of active orders at the same time + num_start_meals: 3 + # number of orders generated at the start of the environment + sample_on_dur_random_func: + # 'random' library call with getattr, kwargs are passed to the function + func: uniform + kwargs: + a: 10 + b: 20 + sample_on_serving: false + # Sample the delay for the next order only after a meal was served. + score_calc_gen_func: !!python/name:overcooked_simulator.order.simple_score_calc_gen_func '' + score_calc_gen_kwargs: + # the kwargs for the score_calc_gen_func + other: 0 + scores: + Burger: 15 + OnionSoup: 10 + Salad: 5 + TomatoSoup: 10 + expired_penalty_func: !!python/name:overcooked_simulator.order.simple_expired_penalty '' + expired_penalty_kwargs: + default: -5 + ``` + """ + def __init__(self, available_meals: dict[str, ItemInfo], **kwargs): super().__init__(available_meals, **kwargs) self.kwargs: RandomOrderKwarg = RandomOrderKwarg(**kwargs["kwargs"]) @@ -98,7 +456,7 @@ class RandomOrderGeneration(OrderGeneration): def init_orders(self, now) -> list[Order]: self.number_cur_orders = self.kwargs.num_start_meals - if self.kwargs.sample_on_dur: + if not self.kwargs.sample_on_serving: self.create_random_next_time_delta(now) return self.create_orders_for_meals( random.choices(self.available_meals, k=self.kwargs.num_start_meals), @@ -131,7 +489,7 @@ class RandomOrderGeneration(OrderGeneration): if self.number_cur_orders >= self.kwargs.max_orders: self.needed_orders += 1 else: - if self.kwargs.sample_on_dur: + if not self.kwargs.sample_on_serving: self.create_random_next_time_delta(now) else: self.next_order_time = datetime.max @@ -151,9 +509,9 @@ class RandomOrderGeneration(OrderGeneration): duration = datetime.max - now else: duration = timedelta( - seconds=getattr(random, self.kwargs.duration_sample["func"])( - **self.kwargs.duration_sample["kwargs"] - ) + seconds=getattr( + random, self.kwargs.order_duration_random_func["func"] + )(**self.kwargs.order_duration_random_func["kwargs"]) ) log.info(f"Create order for meal {meal} with duration {duration}") orders.append( @@ -178,8 +536,8 @@ class RandomOrderGeneration(OrderGeneration): def create_random_next_time_delta(self, now: datetime): self.next_order_time = now + timedelta( - seconds=getattr(random, self.kwargs.sample_on_dur_func["func"])( - **self.kwargs.sample_on_dur_func["kwargs"] + seconds=getattr(random, self.kwargs.sample_on_dur_random_func["func"])( + **self.kwargs.sample_on_dur_random_func["kwargs"] ) ) log.info(f"Next order in {self.next_order_time}") @@ -188,6 +546,21 @@ class RandomOrderGeneration(OrderGeneration): def simple_score_calc_gen_func( meal: Item, duration: timedelta, now: datetime, kwargs: dict, **other_kwargs ) -> Callable: + """An example for the `RandomOrderKwarg.score_calc_gen_func` that selects the score for an order based on its meal from a list. + + Example: + ```yaml + score_calc_gen_func: !!python/name:overcooked_simulator.order.simple_score_calc_gen_func '' + score_calc_gen_kwargs: + # the kwargs for the score_calc_gen_func + other: 0 + scores: + Burger: 15 + OnionSoup: 10 + Salad: 5 + TomatoSoup: 10 + ``` + """ scores = kwargs["scores"] other = kwargs["other"] @@ -200,170 +573,15 @@ def simple_score_calc_gen_func( def simple_expired_penalty(item: ItemInfo, default: float, **kwargs) -> float: - return default - - -class OrderAndScoreManager: - def __init__(self, order_config, available_meals: dict[str, ItemInfo]): - self.score = 0 - self.order_gen: OrderGeneration = order_config["order_gen_class"]( - available_meals=available_meals, kwargs=order_config["kwargs"] - ) - self.kwargs_for_func = order_config["kwargs"] - self.serving_not_ordered_meals = order_config["serving_not_ordered_meals"] - self.available_meals = available_meals - self.open_orders: Deque[Order] = deque() + """Example for the `RandomOrderKwarg.expired_penalty_func` function. - # for logs or history in the future - # TODO log who / which player served which meal -> for split scores - self.served_meals: list[Tuple[Item, datetime]] = [] - self.last_finished = [] - self.next_relevant_time = datetime.max - self.last_expired = [] + A static default. - def update_next_relevant_time(self): - next_relevant_time = datetime.max - for order in self.open_orders: - next_relevant_time = min( - next_relevant_time, order.start_time + order.max_duration - ) - for penalty in order._timed_penalties: - next_relevant_time = min(next_relevant_time, penalty[0]) - self.next_relevant_time = next_relevant_time - - def serve_meal(self, item: Item, env_time: datetime) -> bool: - if isinstance(item, Plate): - meal = item.get_potential_meal() - if meal is not None: - if meal.name in self.available_meals: - order = self.find_order_for_meal(meal) - if order is None: - if self.serving_not_ordered_meals: - accept, score = self.serving_not_ordered_meals(meal) - if accept: - log.info( - f"Serving meal without order {meal.name} with score {score}" - ) - self.score += score - self.served_meals.append((meal, env_time)) - return accept - log.info( - f"Do not serve meal {meal.name} because it is not ordered" - ) - return False - order, index = order - score = order.score_calc( - relative_order_time=env_time - order.start_time, - order=order, - ) - self.score += score - order.finished_info = { - "end_time": env_time, - "score": score, - } - log.info(f"Serving meal {meal.name} with order with score {score}") - self.last_finished.append(order) - del self.open_orders[index] - self.served_meals.append((meal, env_time)) - return True - log.info(f"Do not serve item {item}") - return False - - def increment_score(self, score: int): - self.score += score - log.debug(f"Score: {self.score}") - - def create_init_orders(self, env_time): - init_orders = self.order_gen.init_orders(env_time) - self.open_orders.extend(init_orders) - - def progress(self, passed_time: timedelta, now: datetime): - new_orders = self.order_gen.get_orders( - passed_time=passed_time, - now=now, - new_finished_orders=self.last_finished, - expired_orders=self.last_expired, - ) - self.open_orders.extend(new_orders) - self.last_finished = [] - self.last_expired = [] - if new_orders or self.next_relevant_time <= now: - remove_orders = [] - for index, order in enumerate(self.open_orders): - if now >= order.start_time + order.max_duration: - self.score += order.expired_penalty - remove_orders.append(index) - remove_penalties = [] - for i, (penalty_time, penalty) in enumerate(order.timed_penalties): - if penalty_time < now: - self.score -= penalty - remove_penalties.append(i) - - for i in reversed(remove_penalties): - # or del order.timed_penalties[index] - order.timed_penalties.pop(i) - expired_orders = [] - for remove_order in reversed(remove_orders): - expired_orders.append(self.open_orders[remove_order]) - del self.open_orders[remove_order] - self.last_expired = expired_orders - - self.update_next_relevant_time() - - def find_order_for_meal(self, meal) -> Tuple[Order, int] | None: - for index, order in enumerate(self.open_orders): - if order.meal.name == meal.name: - return order, index - - def setup_penalties(self, new_orders: list[Order], env_time: datetime): - for order in new_orders: - order.create_penalties(env_time) - - def order_state(self) -> list[dict]: - return [ - { - "id": order.uuid, - "category": ORDER_CATEGORY, - "meal": order.meal.name, - "start_time": order.start_time.isoformat(), - "max_duration": order.max_duration.total_seconds(), - } - for order in self.open_orders - ] - - -if __name__ == "__main__": - import yaml - - order_config = yaml.safe_load( - """orders: - kwargs: - duration_sample: - func: uniform - kwargs: - a: 30 - b: 50 - max_orders: 5 - num_start_meals: 3 - sample_on_dur: false - sample_on_dur_func: - func: uniform - kwargs: - a: 30 - b: 50 - sample_on_serving: true - score_calc_gen_func: null - score_calc_gen_kwargs: - other: 0 - scores: - Burger: 15 - OnionSoup: 10 - Salad: 5 - TomatoSoup: 10 - score_calc_gen_func: ~'' - order_gen_class: ~ - serving_not_ordered_meals: null""" - ) - order_config["orders"]["order_gen_class"] = RandomOrderGeneration - order_config["orders"]["kwargs"]["score_calc_gen_func"] = simple_score_calc_gen_func - print(yaml.dump(order_config)) + Example: + ```yaml + expired_penalty_func: !!python/name:overcooked_simulator.order.simple_expired_penalty '' + expired_penalty_kwargs: + default: -5 + ``` + """ + return default diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py index 68f7177428491d655cb0757f2df040189104ef9a..9154646dba08e8775fd07a4256ca92a4728ed5e1 100644 --- a/overcooked_simulator/overcooked_environment.py +++ b/overcooked_simulator/overcooked_environment.py @@ -1,13 +1,15 @@ from __future__ import annotations -import dataclasses import datetime import json +import dataclasses import logging import random from datetime import timedelta +from enum import Enum from pathlib import Path from threading import Lock +from typing import Literal, Any import numpy as np import numpy.typing as npt @@ -17,13 +19,14 @@ from scipy.spatial import distance_matrix from overcooked_simulator.counters import ( Counter, CuttingBoard, - Trash, + Trashcan, Dispenser, ServingWindow, Stove, Sink, PlateDispenser, SinkAddon, + PlateConfig, ) from overcooked_simulator.game_items import ( ItemInfo, @@ -31,24 +34,52 @@ from overcooked_simulator.game_items import ( CookingEquipment, ) from overcooked_simulator.order import OrderAndScoreManager -from overcooked_simulator.player import Player +from overcooked_simulator.player import Player, PlayerConfig from overcooked_simulator.state_representation import StateRepresentation from overcooked_simulator.utils import create_init_env_time log = logging.getLogger(__name__) +class ActionType(Enum): + """The 3 different types of valid actions. They can be extended via the `Action.action_data` attribute.""" + + MOVEMENT = "movement" + """move the agent.""" + PUT = "pickup" + """interaction type 1, e.g., for pickup or drop off. Maybe other words: transplace?""" + # TODO change value to put + INTERACT = "interact" + """interaction type 2, e.g., for progressing. Start and stop interaction via `keydown` and `keyup` actions.""" + + +class InterActionData(Enum): + """The data for the interaction action: `ActionType.MOVEMENT`.""" + + START = "keydown" + "start an interaction." + STOP = "keyup" + "stop an interaction without moving away." + + @dataclasses.dataclass class Action: """Action class, specifies player, action type and action itself.""" player: str - act_type: str - action: str | list | npt.NDArray[float] + """Id of the player.""" + action_type: ActionType + """Type of the action to perform. Defines what action data is valid.""" + action_data: npt.NDArray[float] | InterActionData | Literal["pickup"] + """Data for the action, e.g., movement vector or start and stop interaction.""" duration: float | int = 0 + """Duration of the action (relevant for movement)""" def __repr__(self): - return f"Action({self.player},{self.act_type},{self.action})" + return f"Action({self.player},{self.action_type.value},{self.action_data},{self.duration})" + + +# TODO Abstract base class for different environments class Environment: @@ -68,7 +99,9 @@ class Environment: as_files: bool = True, ): self.lock = Lock() + """temporal lock for GUI until it uses the json state.""" self.players: dict[str, Player] = {} + """the player, keyed by their id/name.""" self.as_files = as_files @@ -81,7 +114,8 @@ class Environment: # self.counter_side_length = 1 # -> this changed! is 1 now self.item_info = self.load_item_info(item_info) - self.validate_item_info() + """The loaded item info dict. Keys are the item names.""" + # self.validate_item_info() if self.environment_config["meals"]["all"]: self.allowed_meal_names = set( [ @@ -92,7 +126,8 @@ class Environment: ) else: self.allowed_meal_names = set(self.environment_config["meals"]["list"]) - + """The allowed meals depend on the `environment_config.yml` configured behaviour. Either all meals that + are possible or only a limited subset.""" self.order_and_score = OrderAndScoreManager( order_config=self.environment_config["orders"], available_meals={ @@ -101,6 +136,7 @@ class Environment: if info.type == ItemType.Meal and item in self.allowed_meal_names }, ) + """The manager for the orders and score update.""" plate_transitions = { item: { "seconds": info.seconds, @@ -122,7 +158,7 @@ class Environment: and info.equipment.name == "CuttingBoard" }, ), - "X": Trash, + "X": Trashcan, "W": lambda pos: ServingWindow( pos, self.order_and_score, @@ -135,9 +171,13 @@ class Environment: plate_transitions=plate_transitions, pos=pos, dispensing=self.item_info["Plate"], - plate_config=self.environment_config["plates"] - if "plates" in self.environment_config - else {}, + plate_config=PlateConfig( + **( + self.environment_config["plates"] + if "plates" in self.environment_config + else {} + ) + ), ), "N": lambda pos: Dispenser(pos, self.item_info["Onion"]), # N for oNioN "_": "Free", @@ -186,9 +226,13 @@ class Environment: ), "+": SinkAddon, } + """Map of the characters in the layout file to callables returning the object/counter. In the future, + maybe replaced with a factory and the characters defined elsewhere in an config.""" self.kitchen_height: int = 0 + """The height of the kitchen, is set by the `Environment.parse_layout_file` method""" self.kitchen_width: int = 0 + """The width of the kitchen, is set by the `Environment.parse_layout_file` method""" ( self.counters, @@ -199,21 +243,30 @@ class Environment: self.init_counters() self.env_time: datetime.datetime = create_init_env_time() + """the internal time of the environment. An environment starts always with the time from + `create_init_env_time`.""" self.order_and_score.create_init_orders(self.env_time) self.start_time = self.env_time + """The relative env time when it started.""" self.env_time_end = self.env_time + timedelta( seconds=self.environment_config["game"]["time_limit_seconds"] ) + """The relative env time when it will stop/end""" log.debug(f"End time: {self.env_time_end}") def get_env_time(self): + """the internal time of the environment. An environment starts always with the time from `create_init_env_time`. + + Utility method to pass a reference to the serving window.""" return self.env_time @property def game_ended(self) -> bool: + """Whether the game is over or not based on the calculated `Environment.env_time_end`""" return self.env_time >= self.env_time_end def load_item_info(self, data) -> dict[str, ItemInfo]: + """Load `item_info.yml` if only the path is given, create ItemInfo classes and replace equipment strings with item infos.""" if self.as_files: with open(data, "r") as file: item_lookup = yaml.safe_load(file) @@ -225,15 +278,11 @@ class Environment: for item_name, item_info in item_lookup.items(): if item_info.equipment: item_info.equipment = item_lookup[item_info.equipment] - item_info.equipment.add_start_meal_to_equipment(item_info) - for item_name, item_info in item_lookup.items(): - if item_info.type == ItemType.Equipment: - # first select meals with smaller needs / ingredients - item_info.sort_start_meals() return item_lookup def validate_item_info(self): - pass + """TODO""" + raise NotImplementedError # infos = {t: [] for t in ItemType} # graph = nx.DiGraph() # for info in self.item_info.values(): @@ -348,24 +397,23 @@ class Environment: assert action.player in self.players.keys(), "Unknown player." player = self.players[action.player] - if action.act_type == "movement": + if action.action_type == ActionType.MOVEMENT: player.set_movement( - action.action, + action.action_data, self.env_time + datetime.timedelta(seconds=action.duration), ) - else: counter = self.get_facing_counter(player) if player.can_reach(counter): - if action.act_type == "pickup": + if action.action_type == ActionType.PUT: with self.lock: player.pick_action(counter) - elif action.act_type == "interact": - if action.action == "keydown": + elif action.action_type == ActionType.INTERACT: + if action.action_data == InterActionData.START: player.perform_interact_hold_start(counter) player.last_interacted_counter = counter - if action.action == "keyup": + if action.action_data == InterActionData.STOP: if player.last_interacted_counter: player.perform_interact_hold_stop(player.last_interacted_counter) @@ -491,7 +539,7 @@ class Environment: other_players = filter(lambda p: p.name != player.name, self.players.values()) def collide(p): - return np.linalg.norm(player.pos - p.pos) <= (player.radius) + (p.radius) + return np.linalg.norm(player.pos - p.pos) <= player.radius + p.radius return list(filter(collide, other_players)) @@ -509,7 +557,7 @@ class Environment: other_players = filter(lambda p: p.name != player.name, self.players.values()) def collide(p): - return np.linalg.norm(player.pos - p.pos) <= ((player.radius) + (p.radius)) + return np.linalg.norm(player.pos - p.pos) <= (player.radius + p.radius) return any(map(collide, other_players)) @@ -529,12 +577,12 @@ class Environment: ) ) - def detect_collision_player_counter(self, player: Player, counter: Counter): + @staticmethod + def detect_collision_player_counter(player: Player, counter: Counter): """Checks if the player and counter collide (overlap). A counter is modelled as a rectangle (square actually), a player is modelled as a circle. The distance of the player position (circle center) and the counter rectangle is calculated, if it is smaller than the player radius, a collision is detected. - TODO: Efficiency improvement by checking only nearest counters? Quadtree...? Args: player: The player to check the collision for. @@ -547,12 +595,27 @@ class Environment: dx = max(np.abs(cx - counter.pos[0]) - 1 / 2, 0) dy = max(np.abs(cy - counter.pos[1]) - 1 / 2, 0) distance = np.linalg.norm([dx, dy]) - return distance < (player.radius) + # TODO: Efficiency improvement by checking only nearest counters? Quadtree...? + return distance < player.radius def add_player(self, player_name: str, pos: npt.NDArray = None): + """Add a player to the environment. + + Args: + player_name: The id/name of the player to reference actions and in the state. + pos: The optional init position of the player. + """ log.debug(f"Add player {player_name} to the game") player = Player( - player_name, player_config=self.environment_config["player_config"], pos=pos + player_name, + player_config=PlayerConfig( + **( + self.environment_config["player_config"] + if "player_config" in self.environment_config + else {} + ) + ), + pos=pos, ) self.players[player.name] = player if player.pos is None: @@ -634,6 +697,11 @@ class Environment: return json_data def init_counters(self): + """Initialize the counters in the environment. + + Connect the `ServingWindow`(s) with the `PlateDispenser`. + Find and connect the `SinkAddon`s with the `Sink`s + """ plate_dispenser = self.get_counter_of_type(PlateDispenser) assert len(plate_dispenser) > 0, "No Plate Return in the environment" @@ -642,25 +710,34 @@ class Environment: for counter in self.counters: match counter: case ServingWindow(): + counter: ServingWindow # Pycharm type checker does now work for match statements? counter.add_plate_dispenser(plate_dispenser[0]) case Sink(pos=pos): + counter: Sink # Pycharm type checker does now work for match statements? assert len(sink_addons) > 0, "No SinkAddon but normal Sink" closest_addon = self.get_closest(pos, sink_addons) assert 1 - (1 * 0.05) <= np.linalg.norm( closest_addon.pos - pos ), f"No SinkAddon connected to Sink at pos {pos}" counter.set_addon(closest_addon) - pass @staticmethod - def get_closest(pos: npt.NDArray[float], counter: list[Counter]): - return min(counter, key=lambda c: np.linalg.norm(c.pos - pos)) + def get_closest(pos: npt.NDArray[float], counters: list[Counter]): + """Find the closest counter for a position + + Args: + pos: the position to find the closest one from. Needs to be the same shape as the `Counter.pos` array. + counters: target to find the closest one. + """ + return min(counters, key=lambda c: np.linalg.norm(c.pos - pos)) def get_counter_of_type(self, counter_type) -> list[Counter]: + """Filter all counters in the environment for a counter type.""" return list( filter(lambda counter: isinstance(counter, counter_type), self.counters) ) def reset_env_time(self): + """Reset the env time to the initial time, defined by `create_init_env_time`.""" self.env_time = create_init_env_time() log.debug(f"Reset env time to {self.env_time}") diff --git a/overcooked_simulator/player.py b/overcooked_simulator/player.py index e75be94681015999b966adf9e9abe1fd2e6d8f2e..da4c9a8c87bd300f911db048efa9045801181903 100644 --- a/overcooked_simulator/player.py +++ b/overcooked_simulator/player.py @@ -1,7 +1,16 @@ +"""The player contains the logic which method to call on counters and items for a pick action: + +* If the player **holds nothing**, it **picks up** the content from the counter. +* If the **item** the player **holds** can be **dropped** on the counter it will do so. +* If the counter is not a sink or plate dispenser, it checks if it **can combine the content** on the counter **with the +holding object**. If so, it picks up the content and combines it on its hands. +""" + +import dataclasses import datetime import logging from collections import deque -from typing import Optional, Any +from typing import Optional import numpy as np import numpy.typing as npt @@ -13,6 +22,18 @@ from overcooked_simulator.state_representation import PlayerState log = logging.getLogger(__name__) +@dataclasses.dataclass +class PlayerConfig: + """Configure the player attributes in the `environment.yml`.""" + + radius: float = 0.4 + """The size of the player. The size of a counter is 1""" + move_dist: float = 0.15 + """The move distance/speed of the player per action call.""" + interaction_range: float = 1.6 + """How far player can interact with counters.""" + + class Player: """Class representing a player in the game environment. A player consists of a name, their position and what the player is currently holding in the hands. @@ -23,28 +44,39 @@ class Player: def __init__( self, name: str, - player_config: dict[str, Any], + player_config: PlayerConfig, pos: Optional[npt.NDArray[float]] = None, ): self.name: str = name + """Reference for the player""" self.player_config = player_config + """Player configuration from the `environment.yml`""" + self.pos: npt.NDArray[float] | None = None + """The initial/suggested position of the player.""" if pos is not None: self.pos: npt.NDArray[float] = np.array(pos, dtype=float) - else: - self.pos = None self.holding: Optional[Item] = None + """What item the player is holding.""" self.radius: float = self.player_config["radius"] + """See `PlayerConfig.radius`.""" self.move_speed: int = self.player_config["player_speed_units_per_seconds"] + """See `PlayerConfig.move_dist`.""" self.interaction_range: int = self.player_config["interaction_range"] + """See `PlayerConfig.interaction_range`.""" self.facing_direction: npt.NDArray[float] = np.array([0, 1]) + """Current direction the player looks.""" self.last_interacted_counter: Optional[ Counter ] = None # needed to stop progress when moved away + """With which counter the player interacted with in the last environment step.""" self.current_nearest_counter: Optional[Counter] = None + """The counter to interact with.""" self.facing_point: npt.NDArray[float] = np.zeros(2, float) + """A point on the "circle" of the players border in the `facing_direction` with which the closest counter is + calculated with.""" self.current_movement: npt.NDArray[2] = np.zeros(2, float) self.movement_until: datetime.datetime = datetime.datetime.min @@ -85,6 +117,7 @@ class Player: self.update_facing_point() def update_facing_point(self): + """Update facing point on the player border circle based on the radius.""" self.facing_point = self.pos + (self.facing_direction * self.radius * 0.5) def can_reach(self, counter: Counter): @@ -125,7 +158,8 @@ class Player: if isinstance(self.holding, Plate): log.debug(self.holding.clean) - def perform_interact_hold_start(self, counter: Counter): + @staticmethod + def perform_interact_hold_start(counter: Counter): """Starts an interaction with the counter. Should be called for a keydown event, for holding down a key on the keyboard. @@ -134,7 +168,8 @@ class Player: """ counter.interact_start() - def perform_interact_hold_stop(self, counter: Counter): + @staticmethod + def perform_interact_hold_stop(counter: Counter): """Stops an interaction with the counter. Should be called for a keyup event, for letting go of a keyboard key. diff --git a/overcooked_simulator/simulation_runner.py b/overcooked_simulator/simulation_runner.py index a2127f7f13f0e14a9b257f4688099884a4983a01..6df6a33e0df43ec1a43240e9fcbea53908e7b35e 100644 --- a/overcooked_simulator/simulation_runner.py +++ b/overcooked_simulator/simulation_runner.py @@ -17,6 +17,8 @@ class Simulator(Thread): Main Simulator class which runs the game environment. Players can be registered in the game. The simulator is run as its own thread. + Is a child class of the `Thread` class from the `threading` library. + Typical usage example: ```python sim = Simulator() @@ -33,15 +35,28 @@ class Simulator(Thread): item_info_path=ROOT_DIR / "game_content" / "item_info.yaml", seed: int = 8654321, ): + """Constructor of the `Simulator class. + + Args: + env_config_path: Path to the environment configuration file. + layout_path: Path to the layout file. + frequency: Frequency of the environment step function call. + item_info_path: Path to the item information configuration file. + seed: Random seed to set the numpy random number generator. + """ # TODO look at https://builtin.com/data-science/numpy-random-seed to change to other random np.random.seed(seed) self.finished: bool = False + """The environment runs as long it is `True`""" self.step_frequency: int = frequency + """Frequency of the environment step function call.""" self.preferred_sleep_time_ns: float = 1e9 / self.step_frequency + """If the environment step call would need no computation time. The duration for one "frame".""" self.env: Environment = Environment( env_config_path, layout_path, item_info_path ) + """Reference to the `Environment`.""" super().__init__() @@ -63,14 +78,23 @@ class Simulator(Thread): Returns: The current state of the game. Currently, as dict with lists of environment objects. """ - return self.env.get_state() + def get_state_json(self): + """Get the current game state in json-like dict. + + Returns: + The gamest ate encoded in a json style nested dict. + """ + + return self.env.get_state_json() + def register_player(self, player_name: str, pos=None): """Adds a player to the environment. Args: - player: The player to be added. + player_name: the reference to the player (name/id). + pos: optional position of the player. """ self.env.add_player(player_name, pos) @@ -80,13 +104,11 @@ class Simulator(Thread): Args: players: List of players to be added. """ - for p in players: self.register_player(p) def run(self): """Starts the simulator thread. Runs in a loop until stopped.""" - overslept_in_ns = 0 self.env.reset_env_time() last_step_start = time.time_ns() diff --git a/overcooked_simulator/utils.py b/overcooked_simulator/utils.py index 4e4ab7adac96f14999878ac6acd442f2bc8e03b5..dfb5da0068a533134ad72e85bf14289d295cc585 100644 --- a/overcooked_simulator/utils.py +++ b/overcooked_simulator/utils.py @@ -2,6 +2,7 @@ from datetime import datetime def create_init_env_time(): + """Init time of the environment time, because all environments should have the same internal time.""" return datetime( year=2000, month=1, day=1, hour=0, minute=0, second=0, microsecond=0 ) diff --git a/tests/test_start.py b/tests/test_start.py index 3dc676af7bc5c79d6a2d4bb1b78643c50f6b1721..059fc4a9dddefba6c982ba14119bc12c6820c32b 100644 --- a/tests/test_start.py +++ b/tests/test_start.py @@ -7,7 +7,12 @@ import pytest from overcooked_simulator import ROOT_DIR from overcooked_simulator.counters import Counter, CuttingBoard from overcooked_simulator.game_items import Item -from overcooked_simulator.overcooked_environment import Action, Environment +from overcooked_simulator.overcooked_environment import ( + Action, + Environment, + ActionType, + InterActionData, +) from overcooked_simulator.simulation_runner import Simulator from overcooked_simulator.utils import create_init_env_time @@ -89,7 +94,7 @@ def test_movement(): start_pos = np.array([1, 2]) sim.register_player(player_name, start_pos) move_direction = np.array([1, 0]) - move_action = Action(player_name, "movement", move_direction) + move_action = Action(player_name, ActionType.MOVEMENT, move_direction) do_moves_number = 6 for i in range(do_moves_number): sim.enter_action(move_action) @@ -157,7 +162,7 @@ def test_player_reach(): do_moves_number = 30 for i in range(do_moves_number): - move_action = Action("p1", "movement", np.array([0, -1])) + move_action = Action("p1", ActionType.MOVEMENT, np.array([0, -1])) sim.enter_action(move_action) assert player.can_reach(counter), "Player can reach counter?" @@ -177,9 +182,9 @@ def test_pickup(): sim.register_player("p1", np.array([2, 3])) player = sim.env.players["p1"] - move_down = Action("p1", "movement", np.array([0, -1])) - move_up = Action("p1", "movement", np.array([0, 1])) - pick = Action("p1", "pickup", "pickup") + move_down = Action("p1", ActionType.MOVEMENT, np.array([0, -1])) + move_up = Action("p1", ActionType.MOVEMENT, np.array([0, 1])) + pick = Action("p1", ActionType.PUT, "pickup") sim.enter_action(move_down) assert player.can_reach(counter), "Player can reach counter?" @@ -232,13 +237,13 @@ def test_processing(): player = sim.env.players["p1"] player.holding = tomato - move = Action("p1", "movement", np.array([0, -1])) - pick = Action("p1", "pickup", "pickup") + move = Action("p1", ActionType.MOVEMENT, np.array([0, -1])) + pick = Action("p1", ActionType.PUT, "pickup") sim.enter_action(move) sim.enter_action(pick) - hold_down = Action("p1", "interact", "keydown") + hold_down = Action("p1", ActionType.INTERACT, InterActionData.START) sim.enter_action(hold_down) assert tomato.name != "ChoppedTomato", "Tomato is not finished yet." @@ -247,7 +252,7 @@ def test_processing(): assert tomato.name == "ChoppedTomato", "Tomato should be finished." - button_up = Action("p1", "interact", "keyup") + button_up = Action("p1", ActionType.INTERACT, InterActionData.STOP) sim.enter_action(button_up) finally: sim.stop()