-
Fabian Heinrich authored
# Conflicts: # overcooked_simulator/game_content/environment_config.yaml # overcooked_simulator/game_server.py # overcooked_simulator/order.py # overcooked_simulator/overcooked_environment.py # overcooked_simulator/utils.py
Fabian Heinrich authored# Conflicts: # overcooked_simulator/game_content/environment_config.yaml # overcooked_simulator/game_server.py # overcooked_simulator/order.py # overcooked_simulator/overcooked_environment.py # overcooked_simulator/utils.py
order.py 25.25 KiB
"""
You can configure the order creation/generation via the `environment_config.yml`.
It is very configurable by letting you reference own Python classes and functions.
```yaml
orders:
  serving_not_ordered_meals: null
  order_gen_class: !!python/name:overcooked_simulator.order.RandomOrderGeneration ''
  order_gen_kwargs:
    ...
```
`serving_not_ordered_meals` expects a function. It receives a meal as an argument and should return a
tuple of a bool and a score. If the bool is true, the meal is accepted and the score is added to the total score.
Otherwise, the meal is not accepted for serving.
The `order_gen_class` should be a child of the `OrderGeneration` class. The `order_gen_kwargs` then depend on the
referenced class.
This file defines the following classes:
- `Order`
- `OrderGeneration`
- `OrderAndScoreManager`
Further, it defines some implementations for the basic order generation based on random sampling:
- `RandomOrderGeneration`
- `simple_score_calc_gen_func`
- `simple_expired_penalty`
- `zero`
For an easier usage of the random orders, also some classes for type hints and dataclasses are defined:
- `RandomOrderKwarg`
- `RandomFuncConfig`
- `ScoreCalcFuncType`
- `ScoreCalcGenFuncType`
- `ExpiredPenaltyFuncType`
For scoring the use of the trashcan, the example function `penalty_for_each_item` is defined. You can set/replace it
in the `environment_config` (key `penalty_for_trash` under `orders`).
## Code Documentation
"""
from __future__ import annotations
import dataclasses
import logging
import uuid
from abc import abstractmethod
from collections import deque
from datetime import datetime, timedelta
from random import Random
from typing import Callable, Tuple, Any, Deque, Protocol, TypedDict, Type
from overcooked_simulator.game_items import Item, Plate, ItemInfo
from overcooked_simulator.hooks import (
Hooks,
SERVE_NOT_ORDERED_MEAL,
SERVE_WITHOUT_PLATE,
COMPLETED_ORDER,
INIT_ORDERS,
NEW_ORDERS,
ORDER_DURATION_SAMPLE,
)
# Module-level logger; serving, score and order-creation messages below are emitted through it.
log: logging.Logger = logging.getLogger(__name__)
"""The logger for this module."""
# Written into the "category" field of every entry produced by `OrderAndScoreManager.order_state`.
ORDER_CATEGORY: str = "Order"
"""The string for the `category` value in the json state representation for all orders."""
class OrderConfig(TypedDict):
    """The configuration of the orders in the `environment_config` under the `orders` key.

    Besides the keys below, `OrderAndScoreManager` also reads an optional `penalty_for_trash` key: a function that
    receives the trashed item (or list of items) and returns the score change. If the key is missing, no penalty
    is applied.
    """

    order_gen_class: Type[OrderGeneration]
    """The class that should handle the order generation."""
    order_gen_kwargs: dict[str, Any]
    """The additional kwargs for the order gen class."""
    serving_not_ordered_meals: Callable[[Item], Tuple[bool, float]]
    """Decides whether a meal without a matching open order is accepted.

    It receives the meal and returns `(accept, score)`; the score is only added when `accept` is true. If this
    value is `None` (or otherwise falsy), meals without an order are always rejected."""
@dataclasses.dataclass
class Order:
"""Datawrapper for Orders"""
meal: ItemInfo
"""The meal to serve and that should be cooked."""
start_time: datetime
"""The start time relative to the env_time. On which the order is returned from the get_orders func."""
max_duration: timedelta
"""The duration after which the order expires."""
score_calc: ScoreCalcFuncType
"""The function that calculates the score of the served meal/fulfilled order."""
timed_penalties: list[
Tuple[timedelta, float] | Tuple[timedelta, float, int, timedelta]
]
"""List of timed penalties when the order is not fulfilled."""
expired_penalty: float
"""The penalty to the score if the order expires"""
uuid: str = dataclasses.field(default_factory=lambda: uuid.uuid4().hex)
"""The unique identifier for the order."""
finished_info: dict[str, Any] = dataclasses.field(default_factory=dict)
"""Is set after the order is completed."""
_timed_penalties: list[Tuple[datetime, float]] = dataclasses.field(
default_factory=list
)
"""Converted penalties the env is working with from the `timed_penalties`"""
def create_penalties(self, env_time: datetime):
"""Create the general timed penalties list to check for penalties after some time the order is still not
fulfilled."""
for penalty_info in self.timed_penalties:
match penalty_info:
case (offset, penalty):
self._timed_penalties.append((env_time + offset, penalty))
case (duration, penalty, number_repeat, offset):
self._timed_penalties.extend(
[
(env_time + offset + (duration * i), penalty)
for i in range(number_repeat)
]
)
class OrderGeneration:
    """Base class for generating orders.

    You can set your child class via the `environment_config.yml`.
    Example:
    ```yaml
    orders:
      order_gen_class: !!python/name:overcooked_simulator.order.RandomOrderGeneration ''
      order_gen_kwargs:
        ...
    ```
    `OrderAndScoreManager` forwards the `order_gen_kwargs` mapping as a single keyword argument named `kwargs`,
    so a subclass finds it as `kwargs["kwargs"]` inside its `**kwargs`.

    Note: this class does not use `abc.ABCMeta`, so the `@abstractmethod` markers are not enforced when a
    subclass is instantiated. Subclasses are nevertheless expected to override `init_orders` and `get_orders`.
    """

    def __init__(self, available_meals: dict[str, ItemInfo], hook: Hooks, random: Random, **kwargs):
        """
        Args:
            available_meals: the meals orders can be created for, keyed by meal name.
            hook: the hook manager of the environment.
            random: the random instance to sample with.
            **kwargs: the configured order generation kwargs (see the class docstring).
        """
        self.available_meals: list[ItemInfo] = list(available_meals.values())
        """Available meals restricted through the `environment_config.yml`."""
        self.hook = hook
        """Reference to the hook manager."""
        self.random = random
        """Random instance."""

    @abstractmethod
    def init_orders(self, now) -> list[Order]:
        """Get the orders the environment starts with.

        Args:
            now: the environment time at the start.

        Returns:
            The initial orders.
        """
        ...

    @abstractmethod
    def get_orders(
        self,
        passed_time: timedelta,
        now: datetime,
        new_finished_orders: list[Order],
        expired_orders: list[Order],
    ) -> list[Order]:
        """New orders for each progress call. Should often be the empty list.

        Args:
            passed_time: the environment time passed since the last call.
            now: the current environment time.
            new_finished_orders: orders fulfilled since the last call.
            expired_orders: orders that expired since the last call.

        Returns:
            The orders that become active now.
        """
        ...
class OrderAndScoreManager:
    """The Order and Score Manager that is called from the serving window.

    It keeps the open orders, asks the configured `OrderGeneration` for new ones, applies expiry and timed
    penalties, and holds the running score.
    """

    def __init__(
        self,
        order_config,
        available_meals: dict[str, ItemInfo],
        hook: Hooks,
        random: Random,
    ):
        """
        Args:
            order_config: the `orders` section of the environment config (see `OrderConfig`); an optional
                `penalty_for_trash` key is also read.
            available_meals: the meals orders can be created for, keyed by meal name.
            hook: the hook manager of the environment.
            random: the random instance passed on to the order generation.
        """
        self.random = random
        """Random instance."""
        self.score: float = 0.0
        """The current score of the environment."""
        self.order_gen: OrderGeneration = order_config["order_gen_class"](
            available_meals=available_meals,
            hook=hook,
            random=random,
            # The order generation receives its config as one keyword argument named `kwargs`.
            kwargs=order_config["order_gen_kwargs"],
        )
        """The order generation."""
        self.serving_not_ordered_meals: Callable[
            [Item], Tuple[bool, float]
        ] = order_config["serving_not_ordered_meals"]
        """Function that decides if not ordered meals can be served and what score it gives"""
        self.penalty_for_trash: Callable[[Item | list[Item]], float] = (
            order_config["penalty_for_trash"]
            if "penalty_for_trash" in order_config
            else zero
        )
        """Function that calculates the penalty for items which were put into the trashcan."""
        self.available_meals = available_meals
        """The meals for that orders can be sampled from."""
        self.open_orders: Deque[Order] = deque()
        """Current open orders. This attribute is used for the environment state."""
        # TODO log who / which player served which meal -> for split scores
        self.served_meals: list[Tuple[Item, datetime]] = []
        """List of served meals. Maybe for the end screen."""
        self.last_finished: list[Order] = []
        """Cache last finished orders for `OrderGeneration.get_orders` call. From the served meals."""
        self.next_relevant_time: datetime = datetime.max
        """For reduced order checking. Store the next time when to create an order or check for penalties."""
        self.last_expired: list[Order] = []
        """Cache last expired orders for `OrderGeneration.get_orders` call."""
        self.hook = hook
        """Reference to the hook manager."""

    def update_next_relevant_time(self):
        """Recompute `next_relevant_time`: the earliest expiry or pending timed penalty of all open orders."""
        next_relevant_time = datetime.max
        for order in self.open_orders:
            next_relevant_time = min(
                next_relevant_time, order.start_time + order.max_duration
            )
            for penalty in order._timed_penalties:
                next_relevant_time = min(next_relevant_time, penalty[0])
        self.next_relevant_time = next_relevant_time

    def serve_meal(self, item: Item, env_time: datetime) -> bool:
        """Is called by the ServingWindow to serve a meal.

        Args:
            item: the item the player tries to serve.
            env_time: the current environment time.

        Returns:
            True if the meal is accepted and should be "deleted" from the hands of the player, otherwise False.
        """
        if not isinstance(item, Plate):
            self.hook(SERVE_WITHOUT_PLATE, item=item)
            log.info(f"Do not serve item {item}")
            return False

        meal = item.get_potential_meal()
        if meal is None or meal.name not in self.available_meals:
            log.info(f"Do not serve item {item}")
            return False

        order_and_index = self.find_order_for_meal(meal)
        if order_and_index is None:
            if self.serving_not_ordered_meals:
                accept, score = self.serving_not_ordered_meals(meal)
                self.hook(
                    SERVE_NOT_ORDERED_MEAL,
                    accept=accept,
                    score=score,
                    meal=meal,
                )
                if accept:
                    log.info(
                        f"Serving meal without order {meal.name!r} with score {score}"
                    )
                    self.increment_score(score)
                    self.served_meals.append((meal, env_time))
                return accept
            log.info(f"Do not serve meal {meal.name!r} because it is not ordered")
            return False

        order, index = order_and_index
        score = order.score_calc(
            relative_order_time=env_time - order.start_time,
            order=order,
        )
        self.increment_score(score)
        order.finished_info = {
            "end_time": env_time,
            "score": score,
        }
        log.info(f"Serving meal {meal.name!r} with order with score {score}")
        self.last_finished.append(order)
        del self.open_orders[index]
        self.served_meals.append((meal, env_time))
        self.hook(COMPLETED_ORDER, score=score, order=order, meal=meal)
        return True

    def increment_score(self, score: int | float):
        """Add a value to the current score and log it."""
        self.score += score
        log.debug(f"Score: {self.score}")

    def create_init_orders(self, env_time):
        """Create the initial orders in an environment."""
        init_orders = self.order_gen.init_orders(env_time)
        self.hook(INIT_ORDERS)
        self.setup_penalties(new_orders=init_orders, env_time=env_time)
        self.open_orders.extend(init_orders)
        # `next_relevant_time` starts at `datetime.max`; without this, expiry and penalties of the initial
        # orders would only be checked after the first new order arrives.
        self.update_next_relevant_time()

    def progress(self, passed_time: timedelta, now: datetime):
        """Check order generation, expired orders, and due timed penalties.

        Args:
            passed_time: the environment time passed since the last call.
            now: the current environment time.
        """
        new_orders = self.order_gen.get_orders(
            passed_time=passed_time,
            now=now,
            new_finished_orders=self.last_finished,
            expired_orders=self.last_expired,
        )
        if new_orders:
            self.hook(NEW_ORDERS, new_orders=new_orders)
            self.setup_penalties(new_orders=new_orders, env_time=now)
            self.open_orders.extend(new_orders)
        # The caches were handed to the order generation; start collecting anew.
        self.last_finished = []
        self.last_expired = []
        if new_orders or self.next_relevant_time <= now:
            # reduce checking calls: only scan the orders when something can have changed
            self._remove_expired_and_apply_penalties(now)
            self.update_next_relevant_time()

    def _remove_expired_and_apply_penalties(self, now: datetime):
        """Remove expired orders (adding their `expired_penalty`) and apply due timed penalties of the rest."""
        remove_orders: list[int] = []
        for index, order in enumerate(self.open_orders):
            if now >= order.start_time + order.max_duration:
                # orders expired
                self.increment_score(order.expired_penalty)
                remove_orders.append(index)
                continue  # no penalties for expired orders
            # Use the absolute `(datetime, penalty)` pairs built by `Order.create_penalties`; the raw
            # `timed_penalties` hold relative offsets that cannot be compared with `now`.
            remove_penalties: list[int] = []
            for i, (penalty_time, penalty) in enumerate(order._timed_penalties):
                if penalty_time < now:
                    # TODO add hook
                    # Timed penalties are subtracted (unlike `expired_penalty`, which is added).
                    self.score -= penalty
                    remove_penalties.append(i)
            for i in reversed(remove_penalties):
                order._timed_penalties.pop(i)

        expired_orders: list[Order] = []
        # Delete from the back so the remaining indices stay valid.
        for remove_order in reversed(remove_orders):
            expired_orders.append(self.open_orders[remove_order])
            del self.open_orders[remove_order]
        self.last_expired = expired_orders

    def find_order_for_meal(self, meal) -> Tuple[Order, int] | None:
        """Get the order that will be fulfilled for a meal.

        At the moment this is the oldest open order with the same meal name.

        Returns:
            `(order, index in open_orders)` or `None` if no open order matches.
        """
        for index, order in enumerate(self.open_orders):
            if order.meal.name == meal.name:
                return order, index
        return None

    @staticmethod
    def setup_penalties(new_orders: list[Order], env_time: datetime):
        """Call the `Order.create_penalties` method for new orders."""
        for order in new_orders:
            order.create_penalties(env_time)

    def order_state(self) -> list[dict]:
        """Similar to the `to_dict` in `Item` and `Counter`. Relevant for the state of the environment"""
        return [
            {
                "id": order.uuid,
                "category": ORDER_CATEGORY,
                "meal": order.meal.name,
                "start_time": order.start_time.isoformat(),
                "max_duration": order.max_duration.total_seconds(),
            }
            for order in self.open_orders
        ]

    def apply_penalty_for_using_trash(self, remove: Item | list[Item]) -> float:
        """Is called if an item is put into the trashcan.

        Returns:
            The score change computed by `penalty_for_trash` (already added to the score).
        """
        penalty = self.penalty_for_trash(remove)
        self.increment_score(penalty)
        return penalty
class ScoreCalcFuncType(Protocol):
    """Call signature of `Order.score_calc`, the function produced by `RandomOrderKwarg.score_calc_gen_func`.

    It computes the score of a fulfilled order.

    Args:
        relative_order_time: `timedelta` how long the order was open before it was fulfilled.
        order: `Order` the fulfilled order.

    Returns:
        `float`: the score for the fulfilled order.
    """

    def __call__(
        self,
        relative_order_time: timedelta,
        order: Order,
    ) -> float:
        ...
class ScoreCalcGenFuncType(Protocol):
    """Call signature expected for `RandomOrderKwarg.score_calc_gen_func`.

    Builds the score function of one order from its meal, duration and creation time.

    Args:
        meal: `ItemInfo` the meal the order asks for.
        duration: `timedelta` how long until the order expires.
        now: `datetime` the environment time at which the order is created.
        kwargs: `dict` the static kwargs from the `environment_config.yml`.

    Returns:
        `ScoreCalcFuncType` the function that scores the fulfilled order.
    """

    def __call__(
        self, meal: ItemInfo, duration: timedelta, now: datetime, kwargs: dict, **other_kwargs
    ) -> ScoreCalcFuncType:
        ...
class ExpiredPenaltyFuncType(Protocol):
    """Call signature expected for `RandomOrderKwarg.expired_penalty_func`; `zero` is an example.

    Args:
        item: `ItemInfo` the meal of the order that may expire. The penalty is computed when the order
            is created, before it becomes active.
        **kwargs: the static `expired_penalty_kwargs`.
    """

    def __call__(
        self,
        item: ItemInfo,
        **kwargs,
    ) -> float:
        ...
def zero(item: ItemInfo, **kwargs) -> float:
"""Example and default for the `RandomOrderKwarg.expired_penalty_func` function.
Just no penalty for expired orders.
Returns:
zero / 0.0
"""
return 0.0
class RandomFuncConfig(TypedDict):
    """Types of the dict for sampling with different random functions from the [`random` library](https://docs.python.org/3/library/random.html).

    The function is looked up by name with `getattr` on the environment's `random.Random` instance and called
    with the given kwargs.

    Example:
        Sampling [uniform](https://docs.python.org/3/library/random.html#random.uniform)ly between `10` and `20`.
        ```yaml
        func: uniform
        kwargs:
          a: 10
          b: 20
        ```
        Or in Python:
        ```python
        random_func = {'func': 'uniform', 'kwargs': {'a': 10, 'b': 20}}
        ```
    """

    func: str
    """The name of a method of `random.Random` (e.g. `uniform`), not the function object itself."""
    kwargs: dict
    """The keyword arguments passed to that method."""
@dataclasses.dataclass
class RandomOrderKwarg:
    """Typed form of the `order_gen_kwargs` for the `RandomOrderGeneration`.

    `RandomOrderGeneration.__init__` builds it with `RandomOrderKwarg(**kwargs["kwargs"])`, so every key of the
    configured mapping has to match one of the fields below.
    """

    num_start_meals: int
    """Number of meals sampled at the start."""
    sample_on_serving: bool
    """Only sample the delay for the next order after a meal was served."""
    sample_on_dur_random_func: RandomFuncConfig
    """How to sample the delay of the next incoming order. Either after a new meal was served or the last order was
    generated (based on the `sample_on_serving` attribute)."""
    max_orders: int
    """How many orders can maximally be active at the same time."""
    order_duration_random_func: RandomFuncConfig
    """How long the order is alive until it expires. If `sample_on_serving` is `true` all orders have no expire time."""
    score_calc_gen_func: ScoreCalcGenFuncType
    """The function that generates the `Order.score_calc` for each order."""
    score_calc_gen_kwargs: dict
    """The additional static kwargs for `score_calc_gen_func`."""
    expired_penalty_func: ExpiredPenaltyFuncType = zero
    """The function that calculates the penalty for a meal that was not served. It is called as
    `expired_penalty_func(meal, **expired_penalty_kwargs)` when the order is created."""
    expired_penalty_kwargs: dict = dataclasses.field(default_factory=dict)
    """The additional static kwargs for the `expired_penalty_func`."""
class RandomOrderGeneration(OrderGeneration):
    """A simple order generation based on random sampling with two options.

    Either sample the delay when a new order should come in after the last order comes in or after a meal was served
    (and an order got removed).

    To configure it align your kwargs with the `RandomOrderKwarg` class.

    You can set this order generation in your `environment_config.yml` with
    ```yaml
    orders:
      order_gen_class: !!python/name:overcooked_simulator.order.RandomOrderGeneration ''
      order_gen_kwargs:
        order_duration_random_func:
          # how long should the orders be alive
          # 'random' library call with getattr, kwargs are passed to the function
          func: uniform
          kwargs:
            a: 40
            b: 60
        max_orders: 6
        # maximum number of active orders at the same time
        num_start_meals: 3
        # number of orders generated at the start of the environment
        sample_on_dur_random_func:
          # 'random' library call with getattr, kwargs are passed to the function
          func: uniform
          kwargs:
            a: 10
            b: 20
        sample_on_serving: false
        # Sample the delay for the next order only after a meal was served.
        score_calc_gen_func: !!python/name:overcooked_simulator.order.simple_score_calc_gen_func ''
        score_calc_gen_kwargs:
          # the kwargs for the score_calc_gen_func
          other: 0
          scores:
            Burger: 15
            OnionSoup: 10
            Salad: 5
            TomatoSoup: 10
        expired_penalty_func: !!python/name:overcooked_simulator.order.simple_expired_penalty ''
        expired_penalty_kwargs:
          default: -5
    ```
    """

    def __init__(self, available_meals: dict[str, ItemInfo], hook: Hooks, random: Random, **kwargs):
        super().__init__(available_meals, hook, random, **kwargs)
        self.kwargs: RandomOrderKwarg = RandomOrderKwarg(**kwargs["kwargs"])
        """The parsed order generation config (`OrderAndScoreManager` passes it as the keyword `kwargs`)."""
        self.next_order_time: datetime = datetime.max
        """When the next order is due; `datetime.max` means none is scheduled."""
        self.number_cur_orders: int = 0
        """Number of orders from this generator that are currently open (compared against `max_orders`)."""
        self.needed_orders: int = 0
        """Orders that became due while `max_orders` orders were open; delivered once slots free up."""

    def init_orders(self, now) -> list[Order]:
        """Create `num_start_meals` random orders and, unless `sample_on_serving`, schedule the next order.

        With `sample_on_serving` the created orders have no time limit.
        """
        self.number_cur_orders = self.kwargs.num_start_meals
        if not self.kwargs.sample_on_serving:
            self.create_random_next_time_delta(now)
        return self.create_orders_for_meals(
            self.random.choices(self.available_meals, k=self.kwargs.num_start_meals),
            now,
            self.kwargs.sample_on_serving,
        )

    def get_orders(
        self,
        passed_time: timedelta,
        now: datetime,
        new_finished_orders: list[Order],
        expired_orders: list[Order],
    ) -> list[Order]:
        """Return the orders that become active now (often none).

        Backlogged orders (`needed_orders`) are delivered as soon as finished *or* expired orders free slots.
        When the scheduled order time is reached, the following order is scheduled right away, so a full order
        list postpones exactly one order per due time instead of one per call.
        """
        # Free the slots of the orders that left the open list since the last call.
        self.number_cur_orders -= len(new_finished_orders)
        self.number_cur_orders -= len(expired_orders)
        if self.kwargs.sample_on_serving and new_finished_orders:
            # The delay until the next order only starts after a serve.
            self.create_random_next_time_delta(now)
            return []

        free_slots = self.kwargs.max_orders - self.number_cur_orders
        if self.needed_orders and free_slots > 0:
            number_new = min(self.needed_orders, free_slots)
            self.needed_orders -= number_new
            self.number_cur_orders += number_new
            return self.create_orders_for_meals(
                self.random.choices(self.available_meals, k=number_new),
                now,
                self.kwargs.sample_on_serving,
            )

        if self.next_order_time <= now:
            # Schedule the following order first; otherwise a full order list would re-enter this branch
            # (and grow `needed_orders`) on every call.
            if self.kwargs.sample_on_serving:
                self.next_order_time = datetime.max
            else:
                self.create_random_next_time_delta(now)
            if free_slots <= 0:
                self.needed_orders += 1
                return []
            self.number_cur_orders += 1
            return self.create_orders_for_meals(
                [self.random.choice(self.available_meals)],
                now,
                self.kwargs.sample_on_serving,
            )
        return []

    def create_orders_for_meals(
        self, meals: list[ItemInfo], now: datetime, no_time_limit: bool = False
    ) -> list[Order]:
        """Create one `Order` per meal, starting at `now`.

        Args:
            meals: the meals to create orders for.
            now: the environment time used as the start time.
            no_time_limit: if true, the orders only expire at `datetime.max`; otherwise the duration is sampled
                with `order_duration_random_func`.

        Returns:
            The new orders.
        """
        orders = []
        for meal in meals:
            if no_time_limit:
                duration = datetime.max - now
            else:
                duration = timedelta(
                    seconds=getattr(
                        self.random, self.kwargs.order_duration_random_func["func"]
                    )(**self.kwargs.order_duration_random_func["kwargs"])
                )
                self.hook(
                    ORDER_DURATION_SAMPLE,
                    duration=duration,
                )
            log.info(f"Create order for meal {meal} with duration {duration}")
            orders.append(
                Order(
                    meal=meal,
                    start_time=now,
                    max_duration=duration,
                    score_calc=self.kwargs.score_calc_gen_func(
                        meal=meal,
                        duration=duration,
                        now=now,
                        kwargs=self.kwargs.score_calc_gen_kwargs,
                    ),
                    timed_penalties=[],
                    expired_penalty=self.kwargs.expired_penalty_func(
                        meal, **self.kwargs.expired_penalty_kwargs
                    ),
                )
            )

        return orders

    def create_random_next_time_delta(self, now: datetime):
        """Set `next_order_time` to `now` plus a delay sampled with `sample_on_dur_random_func` (seconds)."""
        self.next_order_time = now + timedelta(
            seconds=getattr(self.random, self.kwargs.sample_on_dur_random_func["func"])(
                **self.kwargs.sample_on_dur_random_func["kwargs"]
            )
        )
        log.info(f"Next order in {self.next_order_time}")
def simple_score_calc_gen_func(
    meal: ItemInfo, duration: timedelta, now: datetime, kwargs: dict, **other_kwargs
) -> Callable:
    """An example for the `RandomOrderKwarg.score_calc_gen_func` that selects the score for an order based on its meal.

    The score is looked up in the `scores` mapping by meal name; meals without an entry get the `other` score.
    Duration and time are ignored.

    Example:
        ```yaml
        score_calc_gen_func: !!python/name:overcooked_simulator.order.simple_score_calc_gen_func ''
        score_calc_gen_kwargs:
          # the kwargs for the score_calc_gen_func
          other: 0
          scores:
            Burger: 15
            OnionSoup: 10
            Salad: 5
            TomatoSoup: 10
        ```

    Args:
        meal: the meal of the order (unused).
        duration: the order duration (unused).
        now: the creation time (unused).
        kwargs: must contain `scores` (meal name -> score) and `other` (fallback score).

    Returns:
        A `ScoreCalcFuncType` function.

    Raises:
        KeyError: if `scores` or `other` is missing from `kwargs`.
    """
    scores = kwargs["scores"]
    other = kwargs["other"]

    def score_calc(relative_order_time: timedelta, order: Order) -> float:
        """Score of the order's meal from `scores`, or `other` if it has no entry."""
        return scores.get(order.meal.name, other)

    return score_calc
def simple_expired_penalty(item: ItemInfo, default: float, **kwargs) -> float:
    """Example `RandomOrderKwarg.expired_penalty_func`: the same configured penalty for every meal.

    Example:
        ```yaml
        expired_penalty_func: !!python/name:overcooked_simulator.order.simple_expired_penalty ''
        expired_penalty_kwargs:
          default: -5
        ```

    Args:
        item: the meal of the order (ignored).
        default: the penalty returned for every meal.
        **kwargs: ignored.

    Returns:
        `default`, unchanged.
    """
    static_penalty = default
    return static_penalty
def serving_not_ordered_meals_with_zero_score(meal: Item) -> Tuple[bool, float | int]:
    """Accept every meal that has no matching order, without changing the score.

    Usable as `serving_not_ordered_meals` in the `environment_config`.

    Returns:
        `(True, 0)`.
    """
    accepted, score_change = True, 0
    return accepted, score_change
def penalty_for_each_item(remove: Item | list[Item]) -> float:
    """Example `penalty_for_trash` function: every trashed item costs 5 points.

    Args:
        remove: a single item, or a list of items, put into the trashcan.

    Returns:
        `-5` for a single item, otherwise `-5` times the length of the list.
    """
    item_count = len(remove) if isinstance(remove, list) else 1
    return -5 * item_count