"""Benchmark the per-frame rendering time of the pygame 2D visualizer with different cache settings."""
import json
from datetime import timedelta
import time
from pathlib import Path
import cv2
import numpy as np
import yaml
from matplotlib import pyplot as plt
from cooperative_cuisine import ROOT_DIR
from cooperative_cuisine.action import ActionType, Action
from cooperative_cuisine.environment import Environment
from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer, CacheFlags
def run_benchmark(GRID_SIZE, ITERS, OPTIMIZATIONS):
    """Measure the per-frame time of rendering environment states with the 2D visualizer.

    An ``Environment`` is built from the module-level ``environment_config``,
    ``layout`` and ``item_info`` strings, and the module-level ``PLAYER`` is
    added to it. For each of ``ITERS`` frames, the function does the following:

    * moves the player in direction ``(sin(i / 12), cos(i / 12))``;
    * steps the environment by 1/60 s;
    * renders the player's state with ``Visualizer.get_state_image``;
    * shows the result, resized to 400x400, in an OpenCV window named "env".

    Args:
        GRID_SIZE: Value passed to ``Visualizer.set_grid_size``.
        ITERS: Number of frames to simulate and render.
        OPTIMIZATIONS: Settings dict. Only ``OPTIMIZATIONS["cache_counters"]``
            is consumed here (as ``cache_flags`` for ``get_state_image``); the
            whole dict is printed for reference. Other keys, for example
            ``"use_array3d"``, are not read by this function.

    Returns:
        The mean timed duration per frame in seconds, or ``None`` when no
        frames were rendered (``ITERS <= 0``).
    """
    env = Environment(
        env_config=environment_config,
        layout_config=layout,
        item_info=item_info,
        as_files=False,
    )
    env.add_player(PLAYER)

    visualizer: Visualizer = Visualizer(config=visualization_config)
    visualizer.create_player_colors(n=1)
    visualizer.set_grid_size(grid_size=GRID_SIZE)

    print("\n")
    print(f"Optimizations: {OPTIMIZATIONS}")

    # Loop invariants: computed once instead of on every frame.
    cache_flags = OPTIMIZATIONS["cache_counters"]
    step_seconds = 1 / 60
    step = timedelta(seconds=step_seconds)

    durations = []
    for i in range(ITERS):
        # Steer the player along a smooth curve so the scene changes between frames.
        action = Action(
            PLAYER,
            ActionType.MOVEMENT,
            np.array([np.sin(i / 12), np.cos(i / 12)]),
            step_seconds,
        )
        env.perform_action(action)
        env.step(step)
        json_dict = env.get_state(player_id=PLAYER)

        # NOTE: the timed section covers more than get_state_image. It also
        # includes the resize, imshow, and the cv2.waitKey(1) call, which
        # waits up to 1 ms for a key press.
        time_start = time.perf_counter()
        observation = visualizer.get_state_image(state=json_dict, cache_flags=cache_flags)
        img = cv2.resize(
            observation.transpose((1, 0, 2)),
            (400, 400),
            interpolation=cv2.INTER_NEAREST,
        )
        # Reverse the channel order for display (assumes the visualizer emits RGB
        # and OpenCV expects BGR; TODO confirm).
        cv2.imshow("env", img[:, :, ::-1])
        cv2.waitKey(1)
        durations.append(time.perf_counter() - time_start)

    if not durations:
        # np.mean([]) would print nan and emit a RuntimeWarning.
        print("Mean duration: n/a (no iterations)")
        return None

    mean_duration = float(np.mean(durations))
    print(f"Mean duration: {mean_duration:.6f}s")
    return mean_duration
if __name__ == "__main__":
    # NOTE(review): PLAYER, GRID_SIZE and ITERS are used below and by
    # run_benchmark, but the original script never defined them, so it
    # failed with NameError. The values here are placeholders; adjust them
    # to the scenario you want to measure.
    PLAYER = "0"  # assumes player ids are strings; TODO confirm against Environment.add_player
    GRID_SIZE = 40
    ITERS = 1000

    environment_config_path: Path = ROOT_DIR / "configs" / "environment_config.yaml"
    # Alternative layout: ROOT_DIR / "configs" / "layouts" / "basic.layout"
    layout_path: Path = ROOT_DIR / "reinforcement_learning" / "rl.layout"
    item_info_path: Path = ROOT_DIR / "configs" / "item_info.yaml"

    # run_benchmark reads item_info, layout, environment_config,
    # visualization_config and PLAYER as module-level globals. They are
    # therefore assigned here at module scope, not inside a main() function.
    with open(item_info_path, "r", encoding="utf-8") as file:
        item_info = file.read()
    with open(layout_path, "r", encoding="utf-8") as file:
        layout = file.read()
    with open(environment_config_path, "r", encoding="utf-8") as file:
        environment_config = file.read()
    with open(ROOT_DIR / "pygame_2d_vis" / "visualization.yaml", "r", encoding="utf-8") as file:
        visualization_config = yaml.safe_load(file)

    OPTIMIZATIONS = {
        "use_array3d": True,
        "cache_counters": CacheFlags.BACKGROUND | CacheFlags.COUNTERS,
    }
    run_benchmark(GRID_SIZE, ITERS, OPTIMIZATIONS)