Skip to content
Snippets Groups Projects
Commit 97539d86 authored by Florian Schröder's avatar Florian Schröder
Browse files

Refactor code for clarity and type hinting

The code has been refactored to clarify the functionality of various components and improve type hinting. Documentation for methods has been expanded. The return types of to_dict methods in several classes have been changed to return CounterState. The Counter class now correctly handles cases where an Item or Iterable[Item] is occupied. An unnecessary Boolean attribute, all_players_ready, has been commented out. Manual normalization of orientation vectors has been introduced.
parent b0201cd6
No related branches found
No related tags found
No related merge requests found
Pipeline #48037 passed
"""
The actions a player can perform.
The `Action` class is used to represent the incoming action data. There are three types of actions:
- `movement`: Move the player for a specified duration into a direction
- `pick_up_drop`: Based on the situation, pick up or drop off an item from a counter to the players hand or the other way around.
- `interact`: Interact with a counter to increase the progress, e.g., the CuttingBoard to chop ingredients or the Sink to clean plates.
The `action_data` depends on the `action_type`:
- `movement`: a 2d list/array in which direction to move. (Movement vector: complete 360° are allowed, e.g. [0.435889894, 0.9]).
- `pick_up_drop`: None,
- `interact`: `InterActionData` either "keydown" or "keyup" (start and stop the interaction).
The duration part is only needed for the `movement` action. For real-time interactions/games: 1/fps.
"""
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
......
...@@ -41,7 +41,7 @@ from collections import deque ...@@ -41,7 +41,7 @@ from collections import deque
from collections.abc import Iterable from collections.abc import Iterable
from datetime import datetime, timedelta from datetime import datetime, timedelta
from random import Random from random import Random
from typing import TYPE_CHECKING, Optional, Callable, Set from typing import TYPE_CHECKING, Callable, Set
from cooperative_cuisine.hooks import ( from cooperative_cuisine.hooks import (
Hooks, Hooks,
...@@ -63,6 +63,7 @@ from cooperative_cuisine.hooks import ( ...@@ -63,6 +63,7 @@ from cooperative_cuisine.hooks import (
PLATE_OUT_OF_KITCHEN_TIME, PLATE_OUT_OF_KITCHEN_TIME,
DROP_OFF_ON_COOKING_EQUIPMENT, DROP_OFF_ON_COOKING_EQUIPMENT,
) )
from cooperative_cuisine.state_representation import CounterState
if TYPE_CHECKING: if TYPE_CHECKING:
from cooperative_cuisine.environment import ( from cooperative_cuisine.environment import (
...@@ -85,9 +86,6 @@ from cooperative_cuisine.items import ( ...@@ -85,9 +86,6 @@ from cooperative_cuisine.items import (
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
"""The logger for this module.""" """The logger for this module."""
COUNTER_CATEGORY = "Counter"
"""The string for the `category` value in the json state representation for all counters."""
class Counter: class Counter:
"""Simple class for a counter at a specified position (center of counter). Can hold things on top. """Simple class for a counter at a specified position (center of counter). Can hold things on top.
...@@ -99,7 +97,7 @@ class Counter: ...@@ -99,7 +97,7 @@ class Counter:
self, self,
pos: npt.NDArray[float], pos: npt.NDArray[float],
hook: Hooks, hook: Hooks,
occupied_by: Optional[Item] = None, occupied_by: Item | None = None,
uid: hex = None, uid: hex = None,
**kwargs, **kwargs,
): ):
...@@ -113,11 +111,11 @@ class Counter: ...@@ -113,11 +111,11 @@ class Counter:
"""A unique id for better tracking in GUIs with assets which instance moved or changed.""" """A unique id for better tracking in GUIs with assets which instance moved or changed."""
self.pos: npt.NDArray[float] = pos self.pos: npt.NDArray[float] = pos
"""The position of the counter.""" """The position of the counter."""
self.occupied_by: Optional[Item] = occupied_by self.occupied_by: Item | Iterable[Item] | None = occupied_by
"""What is on top of the counter, e.g., `Item`s.""" """What is on top of the counter, e.g., `Item`s."""
self.active_effects: list[Effect] = [] self.active_effects: list[Effect] = []
"""The effects that currently affect the usage of the counter.""" """The effects that currently affect the usage of the counter."""
self.hook = hook self.hook: Hooks = hook
"""Reference to the hook manager.""" """Reference to the hook manager."""
self.orientation: npt.NDArray[float] = np.array([0, 1], dtype=float) self.orientation: npt.NDArray[float] = np.array([0, 1], dtype=float)
"""In what direction the counter is facing.""" """In what direction the counter is facing."""
...@@ -127,13 +125,25 @@ class Counter: ...@@ -127,13 +125,25 @@ class Counter:
"""Is something on top of the counter.""" """Is something on top of the counter."""
return self.occupied_by is not None return self.occupied_by is not None
def set_orientation(self, orientation: npt.NDArray[float]) -> None: def set_orientation(self, orientation: npt.NDArray[float]):
"""This method sets the orientation of an object to the specified numpy array.
The provided numpy array is normalized if its norm is not equal to 1, ensuring that it represents a valid
orientation. The resulting orientation is then assigned to the 'orientation' attribute of the object. If the
norm of the provided numpy array is already equal to 1, it is assigned directly to the 'orientation'
attribute without normalization.
Args:
orientation: A 2D numpy array representing the orientation of an object.
"""
if not np.isclose(np.linalg.norm(orientation), 1): if not np.isclose(np.linalg.norm(orientation), 1):
self.orientation = orientation / np.linalg.norm(orientation) self.orientation = orientation / np.linalg.norm(orientation)
else: else:
self.orientation = orientation self.orientation = orientation
def pick_up(self, on_hands: bool = True, player: str = "0") -> Item | None: def pick_up(
self, on_hands: bool = True, player: str = "0"
) -> Item | None | Iterable[Item]:
"""Gets called upon a player performing the pickup action. If the counter can give something to """Gets called upon a player performing the pickup action. If the counter can give something to
the player, it does so. In the standard counter this is when an item is on the counter. the player, it does so. In the standard counter this is when an item is on the counter.
...@@ -235,6 +245,22 @@ class Counter: ...@@ -235,6 +245,22 @@ class Counter:
def _do_single_tool_interaction( def _do_single_tool_interaction(
passed_time: timedelta, tool: Item, target: Item | Counter passed_time: timedelta, tool: Item, target: Item | Counter
) -> bool: ) -> bool:
"""This method is used to perform a single tool interaction on a target entity.
It calculates the progress of the interaction based on the amount of time passed and the properties of the
tool and target entity. If the tool has suitable effects for the target entity, the progress percentage is
updated and the method returns True. If the progress exceeds the maximum value, the effect is removed from
the target entity and the method returns True. Otherwise, the method returns False indicating that the tool
interaction was unsuccessful.
Args:
passed_time: A timedelta object representing the amount of time passed.
tool: An Item object representing the tool being used.
target: An Item or Counter object representing the target entity.
Returns:
A boolean value indicating whether the tool interaction was successful.
"""
suitable_effects = [ suitable_effects = [
e for e in target.active_effects if e.name in tool.item_info.needs e for e in target.active_effects if e.name in tool.item_info.needs
] ]
...@@ -249,13 +275,19 @@ class Counter: ...@@ -249,13 +275,19 @@ class Counter:
return False return False
def do_hand_free_interaction(self, passed_time: timedelta, now: datetime): def do_hand_free_interaction(self, passed_time: timedelta, now: datetime):
"""Called by environment step function for time progression.
Args:
passed_time: the time passed since the last progress call
now: the current env time. **Not the same as `datetime.now`**.
"""
... ...
def to_dict(self) -> dict: def to_dict(self) -> CounterState:
"""For the state representation. Only the relevant attributes are put into the dict.""" """For the state representation. Only the relevant attributes are put into the dict."""
return { return {
"id": self.uuid, "id": self.uuid,
"category": COUNTER_CATEGORY, "category": "Counter",
"type": self.__class__.__name__, "type": self.__class__.__name__,
"pos": self.pos.tolist(), "pos": self.pos.tolist(),
"orientation": self.orientation.tolist(), "orientation": self.orientation.tolist(),
...@@ -468,7 +500,7 @@ class Dispenser(Counter): ...@@ -468,7 +500,7 @@ class Dispenser(Counter):
} }
return Item(**kwargs) return Item(**kwargs)
def to_dict(self) -> dict: def to_dict(self) -> CounterState:
d = super().to_dict() d = super().to_dict()
d.update((("type", self.__repr__()),)) d.update((("type", self.__repr__()),))
return d return d
...@@ -694,7 +726,7 @@ class CookingCounter(Counter): ...@@ -694,7 +726,7 @@ class CookingCounter(Counter):
def __repr__(self): def __repr__(self):
return f"{self.name}(pos={self.pos},occupied_by={self.occupied_by})" return f"{self.name}(pos={self.pos},occupied_by={self.occupied_by})"
def to_dict(self) -> dict: def to_dict(self) -> CounterState:
d = super().to_dict() d = super().to_dict()
d.update((("type", self.name),)) d.update((("type", self.name),))
return d return d
...@@ -741,7 +773,6 @@ class Sink(Counter): ...@@ -741,7 +773,6 @@ class Sink(Counter):
return len(self.occupied_by) != 0 return len(self.occupied_by) != 0
def do_hand_free_interaction(self, passed_time: timedelta, now: datetime): def do_hand_free_interaction(self, passed_time: timedelta, now: datetime):
"""Called by environment step function for time progression"""
if ( if (
self.occupied self.occupied
and self.occupied_by[-1].name in self.transition_needs and self.occupied_by[-1].name in self.transition_needs
......
...@@ -130,7 +130,7 @@ class FireEffectManager(EffectManager): ...@@ -130,7 +130,7 @@ class FireEffectManager(EffectManager):
"""A boolean indicating whether the fire burns ingredients and meals.""" """A boolean indicating whether the fire burns ingredients and meals."""
self.effect_to_timer: dict[str:datetime] = {} self.effect_to_timer: dict[str:datetime] = {}
"""A dictionary mapping effect uuids to their corresponding timers.""" """A dictionary mapping effect uuids to their corresponding timers."""
self.next_finished_timer = datetime.max self.next_finished_timer: datetime = datetime.max
"""A datetime representing the time when the next effect will finish.""" """A datetime representing the time when the next effect will finish."""
self.active_effects: list[Tuple[Effect, Item | Counter]] = [] self.active_effects: list[Tuple[Effect, Item | Counter]] = []
"""A list of tuples representing the active effects and their target items or counters.""" """A list of tuples representing the active effects and their target items or counters."""
......
...@@ -182,7 +182,7 @@ class StateRepresentation(BaseModel): ...@@ -182,7 +182,7 @@ class StateRepresentation(BaseModel):
info_msg: list[tuple[str, str]] info_msg: list[tuple[str, str]]
"""Info messages for the players to be displayed.""" """Info messages for the players to be displayed."""
# is added: # is added:
all_players_ready: bool # all_players_ready: bool
"""Added by the game server, indicate if all players are ready and actions are passed to the environment.""" """Added by the game server, indicate if all players are ready and actions are passed to the environment."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment