From 49231a767d17962b32fad15c5b58237f9a32883f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20Schr=C3=B6der?= <fschroeder@techfak.uni-bielefeld.de> Date: Thu, 29 Feb 2024 17:55:33 +0100 Subject: [PATCH] Refactor Action classes into separate file Action, ActionType, and InterActionData classes have been moved from the environment module into a separate file named action.py. All imports have been adjusted to reflect this change. This provides a clearer structure and improves modularity in the code, as all action related classes and enums are now organized in a single module. --- cooperative_cuisine/__init__.py | 1 + cooperative_cuisine/action.py | 51 ++++++++++++++ .../configs/agents/random_agent.py | 6 +- cooperative_cuisine/environment.py | 68 +------------------ cooperative_cuisine/game_server.py | 3 +- cooperative_cuisine/pygame_2d_vis/gui.py | 6 +- .../pygame_2d_vis/video_replay.py | 3 +- .../reinforcement_learning/gym_env.py | 4 +- tests/test_start.py | 4 +- 9 files changed, 63 insertions(+), 83 deletions(-) create mode 100644 cooperative_cuisine/action.py diff --git a/cooperative_cuisine/__init__.py b/cooperative_cuisine/__init__.py index 23cdcc74..d1eb6431 100644 --- a/cooperative_cuisine/__init__.py +++ b/cooperative_cuisine/__init__.py @@ -356,6 +356,7 @@ num_bots: 0 # Structure of the Documentation The API documentation follows the file and content structure in the repo. On the left you can find the navigation panel that brings you to the implementation of +- the **action**s the player can perform, - the **counter factory** converts the characters in the layout file to counter instances, - the **counters**, including the kitchen utility objects like dispenser, cooking counter (stove, deep fryer, oven), sink, etc., diff --git a/cooperative_cuisine/action.py b/cooperative_cuisine/action.py new file mode 100644 index 00000000..a080070b --- /dev/null +++ b/cooperative_cuisine/action.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import dataclasses +from enum import Enum +from typing import Literal + +from numpy import typing as npt + + +class ActionType(Enum): + """The 3 different types of valid actions. They can be extended via the `Action.action_data` attribute.""" + + MOVEMENT = "movement" + """move the agent.""" + PUT = "pickup" + """interaction type 1, e.g., for pickup or drop off.""" + # TODO change value to put + INTERACT = "interact" + """interaction type 2, e.g., for progressing. Start and stop interaction via `keydown` and `keyup` actions.""" + + +class InterActionData(Enum): + """The data for the interaction action: `ActionType.MOVEMENT`.""" + + START = "keydown" + "start an interaction." + STOP = "keyup" + "stop an interaction without moving away." + + +@dataclasses.dataclass +class Action: + """Action class, specifies player, action type and action itself.""" + + player: str + """Id of the player.""" + action_type: ActionType + """Type of the action to perform. Defines what action data is valid.""" + action_data: npt.NDArray[float] | InterActionData | Literal["pickup"] + """Data for the action, e.g., movement vector or start and stop interaction.""" + duration: float | int = 0 + """Duration of the action (relevant for movement)""" + + def __repr__(self): + return f"Action({self.player},{self.action_type.value},{self.action_data},{self.duration})" + + def __post_init__(self): + if isinstance(self.action_type, str): + self.action_type = ActionType(self.action_type) + if isinstance(self.action_data, str) and self.action_data != "pickup": + self.action_data = InterActionData(self.action_data) diff --git a/cooperative_cuisine/configs/agents/random_agent.py b/cooperative_cuisine/configs/agents/random_agent.py index db05c7a9..f2208817 100644 --- a/cooperative_cuisine/configs/agents/random_agent.py +++ b/cooperative_cuisine/configs/agents/random_agent.py @@ -10,11 +10,7 @@ from datetime import datetime, timedelta import numpy as np from websockets import connect -from cooperative_cuisine.environment import ( - ActionType, - Action, - InterActionData, -) +from cooperative_cuisine.action import ActionType, InterActionData, Action from cooperative_cuisine.utils import custom_asdict_factory TIME_TO_STOP_ACTION = 3.0 diff --git a/cooperative_cuisine/environment.py b/cooperative_cuisine/environment.py index c0eb6804..958c1b91 100644 --- a/cooperative_cuisine/environment.py +++ b/cooperative_cuisine/environment.py @@ -1,13 +1,11 @@ from __future__ import annotations -import dataclasses import inspect import json import logging import sys from collections import defaultdict from datetime import timedelta, datetime -from enum import Enum from pathlib import Path from random import Random from typing import Literal, TypedDict, Callable @@ -16,6 +14,7 @@ import numpy as np import numpy.typing as npt import yaml +from cooperative_cuisine.action import ActionType, InterActionData, Action from cooperative_cuisine.counter_factory import ( CounterFactory, ) @@ -66,53 +65,6 @@ log = logging.getLogger(__name__) PREVENT_SQUEEZING_INTO_OTHER_PLAYERS = True -class ActionType(Enum): - """The 3 different types of valid actions. They can be extended via the `Action.action_data` attribute.""" - - MOVEMENT = "movement" - """move the agent.""" - PUT = "pickup" - """interaction type 1, e.g., for pickup or drop off. Maybe other words: transplace?""" - # TODO change value to put - INTERACT = "interact" - """interaction type 2, e.g., for progressing. Start and stop interaction via `keydown` and `keyup` actions.""" - - -class InterActionData(Enum): - """The data for the interaction action: `ActionType.MOVEMENT`.""" - - START = "keydown" - "start an interaction." - STOP = "keyup" - "stop an interaction without moving away." - - -@dataclasses.dataclass -class Action: - """Action class, specifies player, action type and action itself.""" - - player: str - """Id of the player.""" - action_type: ActionType - """Type of the action to perform. Defines what action data is valid.""" - action_data: npt.NDArray[float] | InterActionData | Literal["pickup"] - """Data for the action, e.g., movement vector or start and stop interaction.""" - duration: float | int = 0 - """Duration of the action (relevant for movement)""" - - def __repr__(self): - return f"Action({self.player},{self.action_type.value},{self.action_data},{self.duration})" - - def __post_init__(self): - if isinstance(self.action_type, str): - self.action_type = ActionType(self.action_type) - if isinstance(self.action_data, str) and self.action_data != "pickup": - self.action_data = InterActionData(self.action_data) - - -# TODO Abstract base class for different environments - - class EnvironmentConfig(TypedDict): plates: PlateConfig game: dict[ @@ -266,23 +218,9 @@ class Environment: ), ) - progress_counter_classes = list( - filter( - lambda cl: hasattr(cl, "progress"), - dict( - inspect.getmembers( - sys.modules["cooperative_cuisine.counters"], inspect.isclass - ) - ).values(), - ) - ) - self.progressing_counters = list( - filter( - lambda c: c.__class__ in progress_counter_classes, - self.counters, - ) - ) + self.progressing_counters = [] """Counters that needs to be called in the step function via the `progress` method.""" + self.overwrite_counters(self.counters) self.order_manager.create_init_orders(self.env_time) self.start_time = self.env_time diff --git a/cooperative_cuisine/game_server.py b/cooperative_cuisine/game_server.py index 56ae02dc..f28761b0 100644 --- a/cooperative_cuisine/game_server.py +++ b/cooperative_cuisine/game_server.py @@ -29,7 +29,8 @@ from pydantic import BaseModel from starlette.websockets import WebSocketDisconnect from typing_extensions import TypedDict -from cooperative_cuisine.environment import Action, Environment +from cooperative_cuisine.action import Action +from cooperative_cuisine.environment import Environment from cooperative_cuisine.server_results import ( CreateEnvResult, PlayerInfo, diff --git a/cooperative_cuisine/pygame_2d_vis/gui.py b/cooperative_cuisine/pygame_2d_vis/gui.py index a4ce3b32..dad44975 100644 --- a/cooperative_cuisine/pygame_2d_vis/gui.py +++ b/cooperative_cuisine/pygame_2d_vis/gui.py @@ -20,11 +20,7 @@ from pygame import mixer from websockets.sync.client import connect from cooperative_cuisine import ROOT_DIR -from cooperative_cuisine.environment import ( - Action, - ActionType, - InterActionData, -) +from cooperative_cuisine.action import ActionType, InterActionData, Action from cooperative_cuisine.game_server import CreateEnvironmentConfig from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer from cooperative_cuisine.pygame_2d_vis.game_colors import colors diff --git a/cooperative_cuisine/pygame_2d_vis/video_replay.py b/cooperative_cuisine/pygame_2d_vis/video_replay.py index 1c6c348a..acb08cba 100644 --- a/cooperative_cuisine/pygame_2d_vis/video_replay.py +++ b/cooperative_cuisine/pygame_2d_vis/video_replay.py @@ -43,7 +43,8 @@ from PIL import Image from tqdm import tqdm from cooperative_cuisine import ROOT_DIR -from cooperative_cuisine.environment import Environment, Action +from cooperative_cuisine.action import Action +from cooperative_cuisine.environment import Environment from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer from cooperative_cuisine.recording import FileRecorder diff --git a/cooperative_cuisine/reinforcement_learning/gym_env.py b/cooperative_cuisine/reinforcement_learning/gym_env.py index d5f79214..6b138cd9 100644 --- a/cooperative_cuisine/reinforcement_learning/gym_env.py +++ b/cooperative_cuisine/reinforcement_learning/gym_env.py @@ -21,12 +21,10 @@ from stable_baselines3.common.vec_env import VecVideoRecorder from wandb.integration.sb3 import WandbCallback from cooperative_cuisine import ROOT_DIR +from cooperative_cuisine.action import ActionType, InterActionData, Action from cooperative_cuisine.counters import Counter, CookingCounter, Dispenser from cooperative_cuisine.environment import ( Environment, - Action, - ActionType, - InterActionData, ) from cooperative_cuisine.game_items import CookingEquipment from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer diff --git a/tests/test_start.py b/tests/test_start.py index 18e8de29..d52e19db 100644 --- a/tests/test_start.py +++ b/tests/test_start.py @@ -4,12 +4,10 @@ import numpy as np import pytest from cooperative_cuisine import ROOT_DIR +from cooperative_cuisine.action import ActionType, InterActionData, Action from cooperative_cuisine.counters import Counter, CuttingBoard from cooperative_cuisine.environment import ( - Action, Environment, - ActionType, - InterActionData, ) from cooperative_cuisine.game_items import Item, ItemInfo, ItemType from cooperative_cuisine.game_server import PlayerRequestType -- GitLab