diff --git a/overcooked_simulator/game_content/agents/arch_config.yml b/overcooked_simulator/game_content/agents/arch_config.yml new file mode 100644 index 0000000000000000000000000000000000000000..e7108df0e6c1b618eae57b948a6ed139e4fae445 --- /dev/null +++ b/overcooked_simulator/game_content/agents/arch_config.yml @@ -0,0 +1,24 @@ +concurrency: MultiProcessing + +communication: + communication_prefs: + - !name:ipaacar_com_service.communications.ipaacar_com.IPAACARInfo + +modules: + connection: + module_info: !name:cocosy_agent.modules.connection_module.ConnectionModule + mean_frequency_step: 2 # 2: every 0.5 seconds + working_memory: + module_info: !name:cocosy_agent.modules.working_memory_module.WorkingMemoryModule + subtask_selection: + module_info: !name:cocosy_agent.modules.random_subtask_module.RandomSubtaskModule + action_execution: + module_info: !name:cocosy_agent.modules.action_execution_module.ActionExecutionModule + mean_frequency_step: 10 # 2: every 0.5 seconds + # gui: + # module_info: !name:aaambos.std.guis.pysimplegui.pysimplegui_window.PySimpleGUIWindowModule + # window_title: Counting GUI + # topics_to_show: [["SubtaskDecision", "cocosy_agent.conventions.communication.SubtaskDecision", ["task_type"]], ["ActionControl", "cocosy_agent.conventions.communication.ActionControl", ["action_type"]]] + status_manager: + module_info: !name:aaambos.std.modules.module_status_manager.ModuleStatusManager + gui: false \ No newline at end of file diff --git a/overcooked_simulator/game_content/agents/random_agent.py b/overcooked_simulator/game_content/agents/random_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..9b666bb8c2e4e06d8b250bf8d743eb2965b65e26 --- /dev/null +++ b/overcooked_simulator/game_content/agents/random_agent.py @@ -0,0 +1,221 @@ +import argparse +import asyncio +import dataclasses +import json +import random +import time +from collections import defaultdict +from datetime import datetime, timedelta + +import numpy as np +from websockets import connect + +from overcooked_simulator.overcooked_environment import ( + ActionType, + Action, + InterActionData, +) +from overcooked_simulator.utils import custom_asdict_factory + + +async def agent(): + parser = argparse.ArgumentParser("Random agent") + parser.add_argument("--uri", type=str) + parser.add_argument("--player_id", type=str) + parser.add_argument("--player_hash", type=str) + parser.add_argument("--step_time", type=float, default=0.5) + + args = parser.parse_args() + + async with connect(args.uri) as websocket: + await websocket.send( + json.dumps({"type": "ready", "player_hash": args.player_hash}) + ) + await websocket.recv() + + ended = False + + counters = None + + player_info = {} + current_agent_pos = None + interaction_counter = None + + last_interacting = False + last_interact_progress = None + + threshold = datetime.max + + task_type = None + task_args = None + + started_interaction = False + still_interacting = False + current_nearest_counter_id = None + + while not ended: + time.sleep(args.step_time) + await websocket.send( + json.dumps({"type": "get_state", "player_hash": args.player_hash}) + ) + state = json.loads(await websocket.recv()) + + if counters is None: + counters = defaultdict(list) + for counter in state["counters"]: + counters[counter["type"]].append(counter) + + for player in state["players"]: + if player["id"] == args.player_id: + player_info = player + current_agent_pos = player["pos"] + if player["current_nearest_counter_id"]: + if ( + current_nearest_counter_id + != player["current_nearest_counter_id"] + ): + for counter in state["counters"]: + if ( + counter["id"] + == player["current_nearest_counter_id"] + ): + interaction_counter = counter + current_nearest_counter_id = player[ + "current_nearest_counter_id" + ] + break + if last_interacting: + if ( + not interaction_counter + or not interaction_counter["occupied_by"] + or isinstance(interaction_counter["occupied_by"], list) + or ( + interaction_counter["occupied_by"][ + "progress_percentage" + ] + == 1.0 + ) + ): + last_interacting = False + last_interact_progress = None + else: + if ( + interaction_counter + and interaction_counter["occupied_by"] + and not isinstance(interaction_counter["occupied_by"], list) + ): + if ( + last_interact_progress + != interaction_counter["occupied_by"][ + "progress_percentage" + ] + ): + last_interact_progress = interaction_counter[ + "occupied_by" + ]["progress_percentage"] + last_interacting = True + + break + + if task_type: + if threshold < datetime.now(): + print( + args.player_hash, args.player_id, "---Threshold---Too long---" + ) + task_type = None + match task_type: + case "GOTO": + diff = np.array(task_args) - np.array(current_agent_pos) + dist = np.linalg.norm(diff) + if dist > 1.2: + if dist != 0: + await websocket.send( + json.dumps( + { + "type": "action", + "action": dataclasses.asdict( + Action( + args.player_id, + ActionType.MOVEMENT, + (diff / dist).tolist(), + args.step_time + 0.01, + ), + dict_factory=custom_asdict_factory, + ), + "player_hash": args.player_hash, + } + ) + ) + await websocket.recv() + else: + task_type = None + task_args = None + case "INTERACT": + if not started_interaction or ( + still_interacting and interaction_counter + ): + if not started_interaction: + started_interaction = True + + still_interacting = True + await websocket.send( + json.dumps( + { + "type": "action", + "action": dataclasses.asdict( + Action( + args.player_id, + ActionType.INTERACT, + InterActionData.START, + ), + dict_factory=custom_asdict_factory, + ), + "player_hash": args.player_hash, + } + ) + ) + await websocket.recv() + else: + still_interacting = False + started_interaction = False + task_type = None + task_args = None + case "PUT": + await websocket.send( + json.dumps( + { + "type": "action", + "action": dataclasses.asdict( + Action( + args.player_id, + ActionType.PUT, + "pickup", + ), + dict_factory=custom_asdict_factory, + ), + "player_hash": args.player_hash, + } + ) + ) + await websocket.recv() + task_type = None + task_args = None + case None: + ... + + if not task_type: + task_type = random.choice(["GOTO", "PUT", "INTERACT"]) + threshold = datetime.now() + timedelta(seconds=15.0) + if task_type == "GOTO": + counter_type = random.choice(list(counters.keys())) + task_args = random.choice(counters[counter_type])["pos"] + print(args.player_hash, args.player_id, task_type, counter_type) + else: + print(args.player_hash, args.player_id, task_type) + task_args = None + + ended = state["ended"] + + +if __name__ == "__main__": + asyncio.run(agent()) diff --git a/overcooked_simulator/game_content/agents/run_config.yml b/overcooked_simulator/game_content/agents/run_config.yml new file mode 100644 index 0000000000000000000000000000000000000000..f9c6cdf64e2133ae910dd13d90a8ec734368cd00 --- /dev/null +++ b/overcooked_simulator/game_content/agents/run_config.yml @@ -0,0 +1,15 @@ +general: + agent_name: cocosy_agent + instance: _dev + local_agent_directories: ~/aaambos_agents + plus: + agent_websocket: ws://localhost:8000:/ws/player/MY_CLIENT_ID + player_hash: abcdefghijklmnopqrstuvwxyz + agent_id: 1 + +logging: + log_level_command_line: INFO + +supervisor: + run_time_manager_class: !name:aaambos.std.supervision.instruction_run_time_manager.instruction_run_time_manager.InstructionRunTimeManager + diff --git a/overcooked_simulator/game_content/environment_config.yaml b/overcooked_simulator/game_content/environment_config.yaml index 5dd5a82db0a9201d5d9dbb2ccbf1b7310625192d..6e58cec45fa0d7c81701afe42a427e8475102447 100644 --- a/overcooked_simulator/game_content/environment_config.yaml +++ b/overcooked_simulator/game_content/environment_config.yaml @@ -85,7 +85,7 @@ orders: player_config: radius: 0.4 - player_speed_units_per_seconds: 8 + player_speed_units_per_seconds: 6 interaction_range: 1.6 diff --git a/overcooked_simulator/game_content/layouts/empty.layout b/overcooked_simulator/game_content/layouts/empty.layout index 2fa1dd8c29c076e9d94e9a305843ff87bebb29f1..1160842744d30ff014b5e02ac7b5bea7d2421e3d 100644 --- a/overcooked_simulator/game_content/layouts/empty.layout +++ b/overcooked_simulator/game_content/layouts/empty.layout @@ -1,7 +1,8 @@ -______ -______ -______ -______ -______ -______ -_____P \ No newline at end of file +_______ +_______ +_______ +_______ +__A____ +_______ +_______ +______P \ No newline at end of file diff --git a/overcooked_simulator/game_content/layouts/large.layout b/overcooked_simulator/game_content/layouts/large.layout new file mode 100644 index 0000000000000000000000000000000000000000..6933567897246f90838b6a7efd8f15e48ca5b9bf --- /dev/null +++ b/overcooked_simulator/game_content/layouts/large.layout @@ -0,0 +1,23 @@ +#QU#F###O#T#################N###L###B# +#____________________________________# +#____________________________________M +#____________________________________# +#____________________________________# +#____________________________________K +W____________________________________I +#____________________________________# +#____________________________________# +#__A_____A___________________________D +#____________________________________# +#____________________________________# +#____________________________________# +#____________________________________# +#____________________________________# +C____________________________________E +#____________________________________# +#____________________________________# +#____________________________________# +#____________________________________# +C____________________________________G +#____________________________________# +#P#####S+####X#####S+################# \ No newline at end of file diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py index 1323b350ede52892918ed48c2de1409baba6b798..b6c2675bb1c74c8a750d12c65cb7717f5f70f8d6 100644 --- a/overcooked_simulator/game_server.py +++ b/overcooked_simulator/game_server.py @@ -220,7 +220,9 @@ class EnvironmentHandler: self.envs[env_id].last_step_time = time.time_ns() self.envs[env_id].environment.reset_env_time() - def get_state(self, player_hash: str) -> str: # -> StateRepresentation as json + def get_state( + self, player_hash: str + ) -> str | int: # -> StateRepresentation as json """Get the current state representation of the environment for a player. Args: @@ -237,6 +239,10 @@ class EnvironmentHandler: return self.envs[ self.player_data[player_hash].env_id ].environment.get_json_state() + if player_hash not in self.player_data: + return 1 + if self.player_data[player_hash].env_id not in self.envs: + return 2 def pause_env(self, manager_id: str, env_id: str, reason: str): """Pause the specified environment. @@ -599,7 +605,17 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul } case PlayerRequestType.GET_STATE: - return environment_handler.get_state(message_dict["player_hash"]) + state = environment_handler.get_state(message_dict["player_hash"]) + if isinstance(state, int): + return { + "request_type": message_dict["type"], + "status": 400, + "msg": "env id of player not in running envs" + if state == 2 + else "player hash unknown", + "player_hash": None, + } + return state case PlayerRequestType.ACTION: assert ( diff --git a/overcooked_simulator/gui_2d_vis/drawing.py b/overcooked_simulator/gui_2d_vis/drawing.py index 60a751c9783df1e58ef2c4741691f1c2c6964b49..4c6bd198404743e6d3a2179d803015e57a582b94 100644 --- a/overcooked_simulator/gui_2d_vis/drawing.py +++ b/overcooked_simulator/gui_2d_vis/drawing.py @@ -108,6 +108,7 @@ class Visualizer: screen: pygame.Surface, state: dict, grid_size: int, + controlled_player_idxs: list[int], ): """Draws the game state on the given surface. @@ -131,6 +132,14 @@ class Visualizer: grid_size, ) + for idx, col in zip(controlled_player_idxs, [colors["blue"], colors["red"]]): + pygame.draw.circle( + screen, + col, + np.array(state["players"][idx]["pos"]) * grid_size + (grid_size // 2), + (grid_size / 2), + ) + self.draw_players( screen, state["players"], @@ -148,7 +157,6 @@ class Visualizer: height: The kitchen height. grid_size: The gridsize to base the background shapes on. """ - block_size = grid_size // 2 # Set the size of the grid block surface.fill(colors[self.config["Kitchen"]["ground_tiles_color"]]) for x in range(0, width, block_size): @@ -230,7 +238,7 @@ class Visualizer: if USE_PLAYER_COOK_SPRITES: pygame.draw.circle( screen, - self.player_colors[p_idx], + colors[self.player_colors[p_idx]], pos - facing * grid_size * 0.25, grid_size * 0.2, ) @@ -278,7 +286,7 @@ class Visualizer: ) if player_dict["holding"] is not None: - holding_item_pos = pos + (20 * facing) + holding_item_pos = pos + (grid_size * 0.5 * facing) self.draw_item( pos=holding_item_pos, grid_size=grid_size, diff --git a/overcooked_simulator/gui_2d_vis/gui_theme.json b/overcooked_simulator/gui_2d_vis/gui_theme.json index cabbe0368805a697727c738673b1783ac1e56e7e..8e7e819a989131890f69df5df701b99f810ecb50 100644 --- a/overcooked_simulator/gui_2d_vis/gui_theme.json +++ b/overcooked_simulator/gui_2d_vis/gui_theme.json @@ -6,7 +6,7 @@ "disabled_bg": "#25292e", "selected_bg": "#193754", "dark_bg": "#15191e", - "normal_text": "#c5cbd8", + "normal_text": "#000000", "hovered_text": "#FFFFFF", "selected_text": "#FFFFFF", "disabled_text": "#6d736f", @@ -92,5 +92,70 @@ "normal_border": "#000000", "normal_text": "#000000" } + }, + "#players": { + "colours": { + "dark_bg": "#fffacd", + "normal_border": "#fffacd" + } + }, + "#players_players": { + "colours": { + "dark_bg": "#fffacd" + } + }, + "#players_bots": { + "colours": { + "dark_bg": "#fffacd" + } + }, + "#number_players_label": { + "colours": { + "dark_bg": "#fffacd", + "normal_text": "#000000" + }, + "font": { + "size": 14, + "bold": 1 + } + }, + "#number_bots_label": { + "colours": { + "dark_bg": "#fffacd", + "normal_text": "#000000" + }, + "font": { + "size": 14, + "bold": 1, + "colour": "#000000" + } + }, + "#multiple_keysets_button": { + "font": { + "size": 12, + "bold": 1, + "colour": "#000000" + } + }, + "#split_players_button": { + "font": { + "size": 12, + "bold": 1, + "colour": "#000000" + } + }, + "#controller_button": { + "font": { + "size": 12, + "bold": 1, + "colour": "#000000" + } + }, + "#quantity_button": { + "font": { + "size": 24, + "bold": 1, + "colour": "#000000" + } } } \ No newline at end of file diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py index 7de54adc03572259ef3c8ed5c850715c60f1c5b9..9f66332533000c02de2a640d8e55e4884e684588 100644 --- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py +++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py @@ -5,6 +5,7 @@ import logging import random import sys from enum import Enum +from subprocess import Popen import numpy as np import pygame @@ -45,7 +46,14 @@ class PlayerKeySet: First four keys are for movement. Order: Down, Up, Left, Right. 5th key is for interacting with counters. 6th key ist for picking up things or dropping them. """ - def __init__(self, player_name: str | int, keys: list[pygame.key]): + def __init__( + self, + move_keys: list[pygame.key], + interact_key: pygame.key, + pickup_key: pygame.key, + switch_key: pygame.key, + players: list[int], + ): """Creates a player key set which contains information about which keyboard keys control the player. Movement keys in the following order: Down, Up, Left, Right @@ -54,14 +62,32 @@ class PlayerKeySet: player_name: The name of the player to control. keys: The keys which control this player in the following order: Down, Up, Left, Right, Interact, Pickup. """ - self.name = player_name - self.player_keys = keys - self.move_vectors = [[-1, 0], [1, 0], [0, -1], [0, 1]] - self.key_to_movement = { - key: vec for (key, vec) in zip(self.player_keys[:-2], self.move_vectors) + self.move_vectors: list[list[int]] = [[-1, 0], [1, 0], [0, -1], [0, 1]] + self.key_to_movement: dict[pygame.key, list[int]] = { + key: vec for (key, vec) in zip(move_keys, self.move_vectors) } - self.interact_key = self.player_keys[-2] - self.pickup_key = self.player_keys[-1] + self.move_keys: list[pygame.key] = move_keys + self.interact_key: pygame.key = interact_key + self.pickup_key: pygame.key = pickup_key + self.switch_key: pygame.key = switch_key + self.controlled_players: list[int] = players + self.current_player: int = players[0] if players else 0 + self.current_idx = 0 + self.other_keyset: list[PlayerKeySet] = [] + + def set_controlled_players(self, controlled_players: list[int]) -> None: + self.controlled_players = controlled_players + self.current_player = self.controlled_players[0] + self.current_idx = 0 + + def next_player(self) -> None: + self.current_idx = (self.current_idx + 1) % len(self.controlled_players) + if self.other_keyset: + for ok in self.other_keyset: + if ok.current_idx == self.current_idx: + self.next_player() + return + self.current_player = self.controlled_players[self.current_idx] class PyGameGUI: @@ -69,8 +95,6 @@ class PyGameGUI: def __init__( self, - player_names: list[str | int], - player_keys: list[pygame.key], url: str, port: int, manager_ids: list[str], @@ -79,15 +103,8 @@ class PyGameGUI: self.FPS = 60 self.running = True - self.player_names = player_names - self.player_keys = player_keys - - self.player_key_sets: list[PlayerKeySet] = [ - PlayerKeySet(player_name, keys) - for player_name, keys in zip( - self.player_names, self.player_keys[: len(self.player_names)] - ) - ] + self.reset_gui_values() + self.key_sets: list[PlayerKeySet] = [] self.websocket_url = f"ws://{url}:{port}/ws/player/" self.websockets = {} @@ -124,7 +141,8 @@ class PyGameGUI: self.manager: pygame_gui.UIManager self.vis = Visualizer(self.visualization_config) - self.vis.create_player_colors(len(self.player_names)) + + self.sub_processes = [] def get_window_sizes(self, state: dict): kitchen_width = state["kitchen"]["width"] @@ -167,23 +185,60 @@ class PyGameGUI: grid_size, ) + def setup_player_keys(self, n=1, disjunct=False): + if n: + players = list(range(self.number_humans_to_be_added)) + key_set1 = PlayerKeySet( + move_keys=[pygame.K_a, pygame.K_d, pygame.K_w, pygame.K_s], + interact_key=pygame.K_f, + pickup_key=pygame.K_e, + switch_key=pygame.K_SPACE, + players=players, + ) + key_set2 = PlayerKeySet( + move_keys=[pygame.K_LEFT, pygame.K_RIGHT, pygame.K_UP, pygame.K_DOWN], + interact_key=pygame.K_i, + pickup_key=pygame.K_o, + switch_key=pygame.K_p, + players=players, + ) + key_sets = [key_set1, key_set2] + + if disjunct: + key_set1.set_controlled_players(players[::2]) + key_set2.set_controlled_players(players[1::2]) + elif n > 1: + key_set1.set_controlled_players(players) + key_set2.set_controlled_players(players) + key_set1.other_keyset = [key_set2] + key_set2.other_keyset = [key_set1] + key_set2.next_player() + return key_sets[:n] + else: + return [] + def handle_keys(self): """Handles keyboard inputs. Sends action for the respective players. When a key is held down, every frame an action is sent in this function. """ + keys = pygame.key.get_pressed() - for player_idx, key_set in enumerate(self.player_key_sets): - relevant_keys = [keys[k] for k in key_set.player_keys] - if any(relevant_keys[:-2]): + for key_set in self.key_sets: + current_player_name = str(key_set.current_player) + relevant_keys = [keys[k] for k in key_set.move_keys] + if any(relevant_keys): move_vec = np.zeros(2) - for idx, pressed in enumerate(relevant_keys[:-2]): + for idx, pressed in enumerate(relevant_keys): if pressed: move_vec += key_set.move_vectors[idx] if np.linalg.norm(move_vec) != 0: move_vec = move_vec / np.linalg.norm(move_vec) action = Action( - key_set.name, ActionType.MOVEMENT, move_vec, duration=1 / self.FPS + current_player_name, + ActionType.MOVEMENT, + move_vec, + duration=1 / self.FPS, ) self.send_action(action) @@ -195,22 +250,27 @@ class PyGameGUI: Args: event: Pygame event for extracting the key action. """ - for key_set in self.player_key_sets: + + for key_set in self.key_sets: + current_player_name = str(key_set.current_player) if event.key == key_set.pickup_key and event.type == pygame.KEYDOWN: - action = Action(key_set.name, ActionType.PUT, "pickup") + action = Action(current_player_name, ActionType.PUT, "pickup") self.send_action(action) if event.key == key_set.interact_key: if event.type == pygame.KEYDOWN: action = Action( - key_set.name, ActionType.INTERACT, InterActionData.START + current_player_name, ActionType.INTERACT, InterActionData.START ) self.send_action(action) elif event.type == pygame.KEYUP: action = Action( - key_set.name, ActionType.INTERACT, InterActionData.STOP + current_player_name, ActionType.INTERACT, InterActionData.STOP ) self.send_action(action) + if event.key == key_set.switch_key: + if event.type == pygame.KEYDOWN: + key_set.next_player() def init_ui_elements(self): self.manager = pygame_gui.UIManager((self.window_width, self.window_height)) @@ -218,28 +278,28 @@ class PyGameGUI: self.start_button = pygame_gui.elements.UIButton( relative_rect=pygame.Rect( - ( - (self.window_width // 2) - self.buttons_width // 2, - (self.window_height / 2) - self.buttons_height // 2, - ), - (self.buttons_width, self.buttons_height), + (0, 0), (self.buttons_width, self.buttons_height) ), text="Start Game", manager=self.manager, + anchors={"center": "center"}, ) self.start_button.can_hover() - self.quit_button = pygame_gui.elements.UIButton( - relative_rect=pygame.Rect( - ( - (self.window_width - self.buttons_width), - 0, - ), - (self.buttons_width, self.buttons_height), + quit_rect = pygame.Rect( + ( + 0, + 0, ), + (self.buttons_width, self.buttons_height), + ) + quit_rect.topright = (0, 0) + self.quit_button = pygame_gui.elements.UIButton( + relative_rect=quit_rect, text="Quit Game", manager=self.manager, object_id="#quit_button", + anchors={"right": "right", "top": "top"}, ) self.quit_button.can_hover() @@ -340,6 +400,171 @@ class PyGameGUI: object_id="#score_label", ) + ####################### + + player_selection_rect = pygame.Rect( + (0, 0), + ( + self.window_width * 0.9, + (self.window_height // 3), + ), + ) + player_selection_rect.bottom = -10 + self.player_selection_container = pygame_gui.elements.UIPanel( + player_selection_rect, + manager=self.manager, + object_id="#players", + anchors={"bottom": "bottom", "centerx": "centerx"}, + ) + + multiple_keysets_button_rect = pygame.Rect((0, 0), (190, 50)) + self.multiple_keysets_button = pygame_gui.elements.UIButton( + relative_rect=multiple_keysets_button_rect, + manager=self.manager, + container=self.player_selection_container, + text="not set", + anchors={"left": "left", "centery": "centery"}, + object_id="#multiple_keysets_button", + ) + + split_players_button_rect = pygame.Rect((0, 0), (190, 50)) + self.split_players_button = pygame_gui.elements.UIButton( + relative_rect=split_players_button_rect, + manager=self.manager, + container=self.player_selection_container, + text="not set", + anchors={"centerx": "centerx", "centery": "centery"}, + object_id="#split_players_button", + ) + if self.multiple_keysets: + self.split_players_button.show() + else: + self.split_players_button.hide() + + xbox_controller_button_rect = pygame.Rect((0, 0), (190, 50)) + xbox_controller_button_rect.right = 0 + self.xbox_controller_button = pygame_gui.elements.UIButton( + relative_rect=xbox_controller_button_rect, + manager=self.manager, + container=self.player_selection_container, + text="Controller?", + anchors={"right": "right", "centery": "centery"}, + object_id="#controller_button", + ) + + ######## + # + # panel = pygame_gui.elements.UIPanel( + # pygame.Rect((50, 50), (700, 500)), + # manager=manager, + # anchors={ + # "left": "left", + # "right": "right", + # "top": "top", + # "bottom": "bottom", + # }, + # ) + + players_container_rect = pygame.Rect( + (0, 0), + ( + self.window_width * 0.6, + self.player_selection_container.get_abs_rect().height // 3, + ), + ) + self.player_number_container = pygame_gui.elements.UIPanel( + relative_rect=players_container_rect, + manager=self.manager, + object_id="#players_players", + container=self.player_selection_container, + anchors={"top": "top", "centerx": "centerx"}, + ) + + bot_container_rect = pygame.Rect( + (0, 0), + ( + self.window_width * 0.6, + self.player_selection_container.get_abs_rect().height // 3, + ), + ) + bot_container_rect.bottom = 0 + self.bot_number_container = pygame_gui.elements.UIPanel( + relative_rect=bot_container_rect, + manager=self.manager, + object_id="#players_bots", + container=self.player_selection_container, + anchors={"bottom": "bottom", "centerx": "centerx"}, + ) + + number_players_rect = pygame.Rect((0, 0), (200, 200)) + self.added_players_label = pygame_gui.elements.UILabel( + number_players_rect, + manager=self.manager, + object_id="#number_players_label", + container=self.player_number_container, + text=f"Humans to be added: {self.number_humans_to_be_added}", + anchors={"center": "center"}, + ) + + number_bots_rect = pygame.Rect((0, 0), (200, 200)) + self.added_bots_label = pygame_gui.elements.UILabel( + number_bots_rect, + manager=self.manager, + object_id="#number_bots_label", + container=self.bot_number_container, + text=f"Bots to be added: {self.number_bots_to_be_added}", + anchors={"center": "center"}, + ) + + size = 50 + add_player_button_rect = pygame.Rect((0, 0), (size, size)) + add_player_button_rect.right = 0 + self.add_human_player_button = pygame_gui.elements.UIButton( + relative_rect=add_player_button_rect, + text="+", + manager=self.manager, + object_id="#quantity_button", + container=self.player_number_container, + anchors={"right": "right", "centery": "centery"}, + ) + self.add_human_player_button.can_hover() + + remove_player_button_rect = pygame.Rect((0, 0), (size, size)) + remove_player_button_rect.left = 0 + self.remove_human_button = pygame_gui.elements.UIButton( + relative_rect=remove_player_button_rect, + text="-", + manager=self.manager, + object_id="#quantity_button", + container=self.player_number_container, + anchors={"left": "left", "centery": "centery"}, + ) + self.remove_human_button.can_hover() + + add_bot_button_rect = pygame.Rect((0, 0), (size, size)) + add_bot_button_rect.right = 0 + self.add_bot_button = pygame_gui.elements.UIButton( + relative_rect=add_bot_button_rect, + text="+", + manager=self.manager, + object_id="#quantity_button", + container=self.bot_number_container, + anchors={"right": "right", "centery": "centery"}, + ) + self.add_bot_button.can_hover() + + remove_bot_button_rect = pygame.Rect((0, 0), (size, size)) + remove_bot_button_rect.left = 0 + self.remove_bot_button = pygame_gui.elements.UIButton( + relative_rect=remove_bot_button_rect, + text="-", + manager=self.manager, + object_id="#quantity_button", + container=self.bot_number_container, + anchors={"left": "left", "centery": "centery"}, + ) + self.remove_bot_button.can_hover() + def draw(self, state): """Main visualization function. @@ -348,6 +573,7 @@ class PyGameGUI: self.game_screen, state, self.grid_size, + [k.current_player for k in self.key_sets], ) # self.manager.draw_ui(self.main_window) @@ -414,6 +640,8 @@ class PyGameGUI: self.timer_label.hide() self.orders_label.hide() self.conclusion_label.hide() + + self.player_selection_container.show() case MenuStates.Game: self.start_button.hide() self.back_button.hide() @@ -425,6 +653,9 @@ class PyGameGUI: self.timer_label.show() self.orders_label.show() self.conclusion_label.hide() + + self.player_selection_container.hide() + case MenuStates.End: self.start_button.hide() self.back_button.show() @@ -436,6 +667,8 @@ class PyGameGUI: self.orders_label.hide() self.conclusion_label.show() + self.player_selection_container.hide() + def update_score_label(self, state): score = state["score"] self.score_label.set_text(f"Score {score}") @@ -464,7 +697,7 @@ class PyGameGUI: seed = 161616161616 creation_json = CreateEnvironmentConfig( manager_id=self.manager_id, - number_players=2, + number_players=self.number_players, environment_settings={"all_player_can_pause_game": False}, item_info_config=item_info, environment_config=environment_config, @@ -483,18 +716,86 @@ class PyGameGUI: assert isinstance(env_info, dict), "Env info must be a dictionary" self.current_env_id = env_info["env_id"] self.player_info = env_info["player_info"] - for player_id, player_info in env_info["player_info"].items(): + + state = None + for p, (player_id, player_info) in enumerate(env_info["player_info"].items()): + if p < self.number_humans_to_be_added: + websocket = connect(self.websocket_url + player_info["client_id"]) + websocket.send( + json.dumps( + {"type": "ready", "player_hash": player_info["player_hash"]} + ) + ) + assert ( + json.loads(websocket.recv())["status"] == 200 + ), "not accepted player" + self.websockets[player_id] = websocket + else: + player_hash = player_info["player_hash"] + print( + f'--general_plus="agent_websocket:{self.websocket_url + player_info["client_id"]};player_hash:{player_hash};agent_id:{player_id}"' + ) + # sub = Popen( + # " ".join( + # [ + # "exec", + # "aaambos", + # "run", + # "--arch_config", + # str( + # ROOT_DIR / "game_content" / "agents" / "arch_config.yml" + # ), + # "--run_config", + # str( + # ROOT_DIR / "game_content" / "agents" / "run_config.yml" + # ), + # f'--general_plus="agent_websocket:{self.websocket_url + player_info["client_id"]};player_hash:{player_hash};agent_id:{player_id}"', + # f"--instance={player_hash}", + # ] + # ), + # shell=True, + # ) + sub = Popen( + " ".join( + [ + "python", + str( + ROOT_DIR / "game_content" / "agents" / "random_agent.py" + ), + f'--uri {self.websocket_url + player_info["client_id"]}', + f"--player_hash {player_hash}", + f"--player_id {player_id}", + ] + ), + shell=True, + ) + self.sub_processes.append(sub) + + if p + 1 == self.number_humans_to_be_added: + self.state_player_id = player_id + websocket.send( + json.dumps( + {"type": "get_state", "player_hash": player_info["player_hash"]} + ) + ) + state = json.loads(websocket.recv()) + + if not self.number_humans_to_be_added: + player_id = "0" + player_info = env_info["player_info"][player_id] websocket = connect(self.websocket_url + player_info["client_id"]) websocket.send( json.dumps({"type": "ready", "player_hash": player_info["player_hash"]}) ) assert json.loads(websocket.recv())["status"] == 200, "not accepted player" self.websockets[player_id] = websocket - self.state_player_id = player_id - websocket.send( - json.dumps({"type": "get_state", "player_hash": player_info["player_hash"]}) - ) - state = json.loads(websocket.recv()) + self.state_player_id = player_id + websocket.send( + json.dumps( + {"type": "get_state", "player_hash": player_info["player_hash"]} + ) + ) + state = json.loads(websocket.recv()) ( self.window_width, @@ -507,6 +808,20 @@ class PyGameGUI: def start_button_press(self): self.menu_state = MenuStates.Game + self.number_players = ( + self.number_humans_to_be_added + self.number_bots_to_be_added + ) + self.vis.create_player_colors(self.number_players) + + if self.split_players: + assert ( + self.number_humans_to_be_added > 1 + ), "Not enough players for key configuration." + num_key_set = 2 if self.multiple_keysets else 1 + self.key_sets = self.setup_player_keys( + min(self.number_humans_to_be_added, num_key_set), self.split_players + ) + self.setup_environment() self.set_window_size() @@ -519,6 +834,9 @@ class PyGameGUI: def back_button_press(self): self.menu_state = MenuStates.Start self.reset_window_size() + + self.update_selection_elements() + log.debug("Pressed back button") def quit_button_press(self): @@ -527,6 +845,8 @@ class PyGameGUI: log.debug("Pressed quit button") def reset_button_press(self): + # self.reset_gui_values() + requests.post( f"{self.request_url}/manage/stop_env", json={ @@ -552,6 +872,41 @@ class PyGameGUI: self.reset_window_size() log.debug("Pressed finished button") + def reset_gui_values(self): + self.currently_controlled_player_idx = 0 + self.number_humans_to_be_added = 1 + self.number_bots_to_be_added = 0 + self.split_players = False + self.multiple_keysets = False + self.player_minimum = 1 + + def update_selection_elements(self): + if self.number_humans_to_be_added <= self.player_minimum: + self.remove_human_button.disable() + self.number_humans_to_be_added = self.player_minimum + else: + self.remove_human_button.enable() + self.number_humans_to_be_added = max( + self.player_minimum, self.number_humans_to_be_added + ) + + text = "WASD+ARROW" if self.multiple_keysets else "WASD" + self.multiple_keysets_button.set_text(text) + # self.split_players_button + self.added_players_label.set_text( + f"Humans to be added: {self.number_humans_to_be_added}" + ) + self.added_bots_label.set_text( + f"Bots to be added: {self.number_bots_to_be_added}" + ) + text = "Yes" if self.split_players else "No" + self.split_players_button.set_text(f"Split players: {text}") + + if self.multiple_keysets: + self.split_players_button.show() + else: + self.split_players_button.hide() + def send_action(self, action: Action): """Sends an action to the game environment. @@ -587,12 +942,22 @@ class PyGameGUI: } ) ) - # self.websocket.send(json.dumps("get_state")) - # state_dict = json.loads(self.websocket.recv()) state = json.loads(self.websockets[self.state_player_id].recv()) return state def disconnect_websockets(self): + for sub in self.sub_processes: + try: + sub.kill() + # pgrp = os.getpgid(sub.pid) + # os.killpg(pgrp, signal.SIGINT) + # subprocess.run( + # "kill $(ps aux | grep 'aaambos' | awk '{print $2}')", shell=True + # ) + except ProcessLookupError: + pass + + self.sub_processes = [] for websocket in self.websockets.values(): websocket.close() @@ -610,6 +975,8 @@ class PyGameGUI: self.init_ui_elements() self.manage_button_visibility() + self.update_selection_elements() + # Game loop self.running = True while self.running: @@ -624,6 +991,11 @@ class PyGameGUI: if event.type == pygame_gui.UI_BUTTON_PRESSED: match event.ui_element: case self.start_button: + if not ( + self.number_humans_to_be_added + + self.number_bots_to_be_added + ): + continue self.start_button_press() case self.back_button: self.back_button_press() @@ -640,6 +1012,33 @@ class PyGameGUI: self.disconnect_websockets() self.start_button_press() + case self.add_human_player_button: + self.number_humans_to_be_added += 1 + case self.remove_human_button: + self.number_humans_to_be_added = max( + 0, self.number_humans_to_be_added - 1 + ) + case self.add_bot_button: + self.number_bots_to_be_added += 1 + case self.remove_bot_button: + self.number_bots_to_be_added = max( + 0, self.number_bots_to_be_added - 1 + ) + case self.multiple_keysets_button: + self.multiple_keysets = not self.multiple_keysets + self.split_players = False + case self.split_players_button: + self.split_players = not self.split_players + if self.split_players: + self.player_minimum = 2 + else: + self.player_minimum = 1 + + case self.xbox_controller_button: + print("xbox_controller_button pressed.") + + self.update_selection_elements() + self.manage_button_visibility() if ( @@ -696,26 +1095,8 @@ class PyGameGUI: sys.exit() -def main( - url: str, port: int, manager_ids: list[str], enable_websocket_logging: bool = False -): - # TODO maybe read the player names and keyboard keys from config file? - setup_logging(enable_websocket_logging) - - keys1 = [ - pygame.K_LEFT, - pygame.K_RIGHT, - pygame.K_UP, - pygame.K_DOWN, - pygame.K_SPACE, - pygame.K_i, - ] - keys2 = [pygame.K_a, pygame.K_d, pygame.K_w, pygame.K_s, pygame.K_f, pygame.K_e] - - number_players = 2 +def main(url: str, port: int, manager_ids: list[str]): gui = PyGameGUI( - list(map(str, range(number_players))), - [keys1, keys2], url=url, port=port, manager_ids=manager_ids, diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py index 67640bce5ad7ee51115e620e03774d5ea7c482e8..45a1f7c3bd177773500c4a04a628a04b4e6b745b 100644 --- a/overcooked_simulator/overcooked_environment.py +++ b/overcooked_simulator/overcooked_environment.py @@ -14,6 +14,7 @@ from typing import Literal, TypedDict, Callable, Tuple import numpy as np import numpy.typing as npt import yaml +from scipy.spatial import distance_matrix from overcooked_simulator.counter_factory import CounterFactory from overcooked_simulator.counters import ( @@ -55,6 +56,9 @@ from overcooked_simulator.utils import create_init_env_time, get_closest log = logging.getLogger(__name__) +PREVENT_SQUEEZING_INTO_OTHER_PLAYERS = True + + class ActionType(Enum): """The 3 different types of valid actions. They can be extended via the `Action.action_data` attribute.""" @@ -222,8 +226,17 @@ class Environment: ) = self.parse_layout_file() self.hook(LAYOUT_FILE_PARSED) - self.world_borders_x = [-0.5, self.kitchen_width - 0.5] - self.world_borders_y = [-0.5, self.kitchen_height - 0.5] + self.counter_positions = np.array([c.pos for c in self.counters]) + + self.world_borders = np.array( + [[-0.5, self.kitchen_width - 0.5], [-0.5, self.kitchen_height - 0.5]], + dtype=float, + ) + + self.player_movement_speed = self.environment_config["player_config"][ + "player_speed_units_per_seconds" + ] + self.player_radius = self.environment_config["player_config"]["radius"] progress_counter_classes = list( filter( @@ -261,7 +274,7 @@ class Environment: environment_config=env_config, layout_config=self.layout_config, seed=seed, - env_start_time_worldtime=datetime.now() + env_start_time_worldtime=datetime.now(), ) @property @@ -269,6 +282,15 @@ class Environment: """Whether the game is over or not based on the calculated `Environment.env_time_end`""" return self.env_time >= self.env_time_end + def set_collision_arrays(self): + number_players = len(self.players) + self.world_borders_lower = self.world_borders[np.newaxis, :, 0].repeat( + number_players, axis=0 + ) + self.world_borders_upper = self.world_borders[np.newaxis, :, 1].repeat( + number_players, axis=0 + ) + def get_env_time(self): """the internal time of the environment. An environment starts always with the time from `create_init_env_time`. @@ -514,7 +536,7 @@ class Environment: facing_counter = get_closest(player.facing_point, self.counters) return facing_counter - def perform_movement(self, player: Player, duration: timedelta): + def perform_movement(self, duration: timedelta): """Moves a player in the direction specified in the action.action. If the player collides with a counter or other player through this movement, then they are not moved. (The extended code with the two ifs is for sliding movement at the counters, which feels a bit smoother. @@ -526,145 +548,112 @@ class Environment: Detects collisions with other players and pushes them out of the way. Args: - player: The player to move. duration: The duration for how long the movement to perform. """ - old_pos = player.pos.copy() - - move_vector = player.current_movement - d_time = duration.total_seconds() - step = move_vector * (player.player_speed_units_per_seconds * d_time) - - player.move(step) - if self.detect_collision(player): - collided_players = self.get_collided_players(player) - for collided_player in collided_players: - pushing_vector = collided_player.pos - player.pos - if np.linalg.norm(pushing_vector) != 0: - pushing_vector = pushing_vector / np.linalg.norm(pushing_vector) - - old_pos_other = collided_player.pos.copy() - collided_player.current_movement = pushing_vector - self.perform_movement(collided_player, duration) - if self.detect_collision_counters( - collided_player - ) or self.detect_collision_world_bounds(collided_player): - collided_player.move_abs(old_pos_other) - player.move_abs(old_pos) - - old_pos = player.pos.copy() - - step_sliding = step.copy() - step_sliding[0] = 0 - player.move(step_sliding * 0.5) - player.turn(step) - - if self.detect_collision(player): - player.move_abs(old_pos) - - old_pos = player.pos.copy() - - step_sliding = step.copy() - step_sliding[1] = 0 - player.move(step_sliding * 0.5) - player.turn(step) - - if self.detect_collision(player): - player.move_abs(old_pos) - - if self.counters: - closest_counter = self.get_facing_counter(player) - player.current_nearest_counter = ( - closest_counter if player.can_reach(closest_counter) else None - ) - def detect_collision(self, player: Player): - """Detect collisions between the player and other players or counters. - - Args: - player: The player for which to check collisions. - - Returns: True if the player is intersecting with any object in the environment. - """ - return ( - len(self.get_collided_players(player)) != 0 - or self.detect_collision_counters(player) - or self.detect_collision_world_bounds(player) + player_positions = np.array([p.pos for p in self.players.values()], dtype=float) + player_movement_vectors = np.array( + [ + p.current_movement if self.env_time <= p.movement_until else [0, 0] + for p in self.players.values() + ], + dtype=float, ) + number_players = len(player_positions) - def get_collided_players(self, player: Player) -> list[Player]: - """Detects collisions between the queried player and other players. Returns the list of the collided players. - A player is modelled as a circle. Collision is detected if the distance between the players is smaller - than the sum of the radius's. - - Args: - player: The player to check collisions with other players for. - - Returns: The list of other players the player collides with. - - """ - other_players = filter(lambda p: p.name != player.name, self.players.values()) - - def collide(p): - return np.linalg.norm(player.pos - p.pos) <= player.radius + p.radius - - return list(filter(collide, other_players)) - - def detect_player_collision(self, player: Player): - """Detects collisions between the queried player and other players. - A player is modelled as a circle. Collision is detected if the distance between the players is smaller - than the sum of the radius's. - - Args: - player: The player to check collisions with other players for. - - Returns: True if the player collides with other players, False if not. + targeted_positions = player_positions + ( + player_movement_vectors * (self.player_movement_speed * d_time) + ) - """ - other_players = filter(lambda p: p.name != player.name, self.players.values()) + # Collisions player between player + distances_players_after_scipy = distance_matrix( + targeted_positions, targeted_positions + ) - def collide(p): - return np.linalg.norm(player.pos - p.pos) <= (player.radius + p.radius) + player_diff_vecs = -( + player_positions[:, np.newaxis, :] - player_positions[np.newaxis, :, :] + ) + collision_idxs = distances_players_after_scipy < (2 * self.player_radius) + eye_idxs = np.eye(number_players, number_players, dtype=bool) + collision_idxs[eye_idxs] = False - return any(map(collide, other_players)) + # Player push players around + player_diff_vecs[collision_idxs == False] = 0 + push_vectors = np.sum(player_diff_vecs, axis=0) - def detect_collision_counters(self, player: Player): - """Checks for collisions of the queried player with each counter. + updated_movement = push_vectors + player_movement_vectors + new_positions = player_positions + ( + updated_movement * (self.player_movement_speed * d_time) + ) - Args: - player: The player to check collisions with counters for. + # Collisions players counters + counter_diff_vecs = ( + new_positions[:, np.newaxis, :] - self.counter_positions[np.newaxis, :, :] + ) + counter_distances = np.max((np.abs(counter_diff_vecs)), axis=2) + # counter_distances = np.linalg.norm(counter_diff_vecs, axis=2) + closest_counter_positions = self.counter_positions[ + np.argmin(counter_distances, axis=1) + ] + + nearest_counter_to_player = closest_counter_positions - new_positions + + collided = np.min(counter_distances, axis=1) < self.player_radius + 0.5 + relevant_axes = np.abs(nearest_counter_to_player).argmax(axis=1) + + for idx, player in enumerate(player_positions): + axis = relevant_axes[idx] + + if collided[idx]: + # collide with counter left or top + if nearest_counter_to_player[idx][axis] < 0: + updated_movement[idx, axis] = max(updated_movement[idx, axis], 0) + # collide with counter right or bottom + if nearest_counter_to_player[idx][axis] > 0: + updated_movement[idx, axis] = min(updated_movement[idx, axis], 0) + + new_positions = player_positions + ( + updated_movement * (self.player_movement_speed * d_time) + ) - Returns: True if the player collides with any counter, False if not. + # Check if pushed players collide with counters or second closest is to close + counter_diff_vecs = ( + new_positions[:, np.newaxis, :] - self.counter_positions[np.newaxis, :, :] + ) + counter_distances = np.max((np.abs(counter_diff_vecs)), axis=2) + collided2 = np.min(counter_distances, axis=1) < self.player_radius + 0.5 + # player do not move if they collide after pushing/sliding + new_positions[collided2] = player_positions[collided2] + # Players that pushed the player that can not be pushed do also no movement + # in the future these players could slide around the player? + for idx, collides in enumerate(collided2): + if collides: + new_positions[collision_idxs[idx]] = player_positions[ + collision_idxs[idx] + ] - """ - return any( - map( - lambda counter: self.detect_collision_player_counter(player, counter), - self.counters, + # Check if two moving players collide into each other: No movement (Future: slide?) + if PREVENT_SQUEEZING_INTO_OTHER_PLAYERS: + distances_players_after_scipy = distance_matrix( + new_positions, new_positions ) + collision_idxs = distances_players_after_scipy < (2 * self.player_radius) + collision_idxs[eye_idxs] = False + collision_idxs = np.any(collision_idxs, axis=1) + new_positions[collision_idxs] = player_positions[collision_idxs] + + # Collisions player world borders + new_positions = np.clip( + new_positions, + self.world_borders_lower + self.player_radius, + self.world_borders_upper - self.player_radius, ) - @staticmethod - def detect_collision_player_counter(player: Player, counter: Counter): - """Checks if the player and counter collide (overlap). - A counter is modelled as a rectangle (square actually), a player is modelled as a circle. - The distance of the player position (circle center) and the counter rectangle is calculated, if it is - smaller than the player radius, a collision is detected. - - Args: - player: The player to check the collision for. - counter: The counter to check the collision for. - - Returns: True if player and counter overlap, False if not. - - """ - cx, cy = player.pos - dx = max(np.abs(cx - counter.pos[0]) - 1 / 2, 0) - dy = max(np.abs(cy - counter.pos[1]) - 1 / 2, 0) - distance = np.linalg.norm([dx, dy]) - # TODO: Efficiency improvement by checking only nearest counters? Quadtree...? - return distance < player.radius + for idx, p in enumerate(self.players.values()): + if not (new_positions[idx] == player_positions[idx]).all(): + p.turn(player_movement_vectors[idx]) + p.move_abs(new_positions[idx]) def add_player(self, player_name: str, pos: npt.NDArray = None): """Add a player to the environment. @@ -702,6 +691,7 @@ class Environment: log.debug("No free positions left in kitchens") player.update_facing_point() + self.set_collision_arrays() self.hook(PLAYER_ADDED, player_name=player_name, pos=pos) def detect_collision_world_bounds(self, player: Player): @@ -734,8 +724,8 @@ class Environment: else: for player in self.players.values(): player.progress(passed_time, self.env_time) - if self.env_time <= player.movement_until: - self.perform_movement(player, passed_time) + + self.perform_movement(passed_time) for counter in self.progressing_counters: counter.progress(passed_time=passed_time, now=self.env_time) diff --git a/overcooked_simulator/player.py b/overcooked_simulator/player.py index 036f83b21842465f49e785efa8be2e35dfedb944..5633a491248ae8b9448d457607b86b8e74121c08 100644 --- a/overcooked_simulator/player.py +++ b/overcooked_simulator/player.py @@ -57,15 +57,9 @@ class Player: self.holding: Optional[Item] = None """What item the player is holding.""" + self.player_config = player_config + """See `PlayerConfig`.""" - self.radius: float = player_config.radius - """See `PlayerConfig.radius`.""" - self.player_speed_units_per_seconds: float | int = ( - player_config.player_speed_units_per_seconds - ) - """See `PlayerConfig.move_dist`.""" - self.interaction_range: float = player_config.interaction_range - """See `PlayerConfig.interaction_range`.""" self.facing_direction: npt.NDArray[float] = np.array([0, 1]) """Current direction the player looks.""" self.last_interacted_counter: Optional[ @@ -127,7 +121,9 @@ class Player: def update_facing_point(self): """Update facing point on the player border circle based on the radius.""" - self.facing_point = self.pos + (self.facing_direction * self.radius * 0.5) + self.facing_point = self.pos + ( + self.facing_direction * self.player_config.radius * 0.5 + ) def can_reach(self, counter: Counter): """Checks whether the player can reach the counter in question. Simple check if the distance is not larger @@ -140,7 +136,10 @@ class Player: True if the counter is in range of the player, False if not. """ - return np.linalg.norm(counter.pos - self.facing_point) <= self.interaction_range + return ( + np.linalg.norm(counter.pos - self.facing_point) + <= self.player_config.interaction_range + ) def put_action(self, counter: Counter): """Performs the pickup-action with the counter. Handles the logic of what the player is currently holding, diff --git a/overcooked_simulator/utils.py b/overcooked_simulator/utils.py index b5376c240afe6c19083f3bacc97e8ecef9ea1277..b78d44af869d9a53ec17f5f2cd8eda191e0b4e17 100644 --- a/overcooked_simulator/utils.py +++ b/overcooked_simulator/utils.py @@ -22,6 +22,7 @@ from overcooked_simulator import ROOT_DIR if TYPE_CHECKING: from overcooked_simulator.counters import Counter +from overcooked_simulator.player import Player def create_init_env_time(): @@ -46,6 +47,18 @@ def get_closest(point: npt.NDArray[float], counters: list[Counter]): ] +def get_collided_players( + player_idx, players: list[Player], player_radius: float +) -> list[Player]: + player_positions = np.array([p.pos for p in players], dtype=float) + distances = distance_matrix(player_positions, player_positions)[player_idx] + player_radiuses = np.array([player_radius for p in players], dtype=float) + collisions = distances <= player_radiuses + player_radius + collisions[player_idx] = False + + return [players[idx] for idx, val in enumerate(collisions) if val] + + def get_touching_counters(target: Counter, counters: list[Counter]) -> list[Counter]: return list( filter( diff --git a/tests/test_start.py b/tests/test_start.py index 86b0d10395229378b7bd4a31a6707ca2bda6e1bc..8efe8da9ecf1c67974f88a18b1bcdb0b1cd3786a 100644 --- a/tests/test_start.py +++ b/tests/test_start.py @@ -46,6 +46,7 @@ def layout_config(): with open(layout_path, "r") as file: layout = file.read() return layout + env.add_player("0") @pytest.fixture @@ -80,7 +81,7 @@ def test_movement(env_config, layout_empty_config, item_info): player_name = "1" start_pos = np.array([3, 4]) env.add_player(player_name, start_pos) - env.players[player_name].player_speed_units_per_seconds = 1 + env.player_movement_speed = 1 move_direction = np.array([1, 0]) move_action = Action(player_name, ActionType.MOVEMENT, move_direction, duration=0.1) do_moves_number = 3 @@ -89,22 +90,19 @@ def test_movement(env_config, layout_empty_config, item_info): env.step(timedelta(seconds=0.1)) expected = start_pos + do_moves_number * ( - move_direction - * env.players[player_name].player_speed_units_per_seconds - * move_action.duration + move_direction * env.player_movement_speed * move_action.duration ) - assert np.isclose( np.linalg.norm(expected - env.players[player_name].pos), 0 ), "Performed movement do not move the player as expected." -def test_player_speed_units_per_seconds(env_config, layout_empty_config, item_info): +def test_player_movement_speed(env_config, layout_empty_config, item_info): env = Environment(env_config, layout_empty_config, item_info, as_files=False) player_name = "1" start_pos = np.array([3, 4]) env.add_player(player_name, start_pos) - env.players[player_name].player_speed_units_per_seconds = 2 + env.player_movement_speed = 2 move_direction = np.array([1, 0]) move_action = Action(player_name, ActionType.MOVEMENT, move_direction, duration=0.1) do_moves_number = 3 @@ -113,9 +111,7 @@ def test_player_speed_units_per_seconds(env_config, layout_empty_config, item_in env.step(timedelta(seconds=0.1)) expected = start_pos + do_moves_number * ( - move_direction - * env.players[player_name].player_speed_units_per_seconds - * move_action.duration + move_direction * env.player_movement_speed * move_action.duration ) assert np.isclose( @@ -123,36 +119,6 @@ def test_player_speed_units_per_seconds(env_config, layout_empty_config, item_in ), "Performed movement do not move the player as expected." -def test_collision_detection(env_config, layout_config, item_info): - env = Environment(env_config, layout_config, item_info, as_files=False) - - counter_pos = np.array([1, 2]) - counter = Counter(pos=counter_pos, hook=Hooks(env)) - env.counters = [counter] - env.add_player("1", np.array([1, 1])) - env.add_player("2", np.array([1, 4])) - - player1 = env.players["1"] - player2 = env.players["2"] - - assert not env.detect_collision_counters(player1), "Should not collide" - assert not env.detect_player_collision(player1), "Should not collide yet." - - assert not env.detect_collision(player1), "Does not collide yet." - - player1.move_abs(counter_pos) - assert env.detect_collision_counters( - player1 - ), "Player and counter at same pos. Not detected." - player2.move_abs(counter_pos) - assert env.detect_player_collision(player1), "Players at same pos. Not detected." - - player1.move_abs(np.array([-1, -1])) - assert env.detect_collision_world_bounds( - player1 - ), "Player collides with world bounds." - - def test_player_reach(env_config, layout_empty_config, item_info): env = Environment(env_config, layout_empty_config, item_info, as_files=False) @@ -160,7 +126,7 @@ def test_player_reach(env_config, layout_empty_config, item_info): counter = Counter(pos=counter_pos, hook=Hooks(env)) env.counters = [counter] env.add_player("1", np.array([2, 4])) - env.players["1"].player_speed_units_per_seconds = 1 + env.player_movement_speed = 1 player = env.players["1"] assert not player.can_reach(counter), "Player is too far away." @@ -182,7 +148,7 @@ def test_pickup(env_config, layout_config, item_info): env.add_player("1", np.array([2, 3])) player = env.players["1"] - player.player_speed_units_per_seconds = 1 + env.player_movement_speed = 1 move_down = Action("1", ActionType.MOVEMENT, np.array([0, -1]), duration=1) move_up = Action("1", ActionType.MOVEMENT, np.array([0, 1]), duration=1) @@ -244,7 +210,7 @@ def test_processing(env_config, layout_config, item_info): tomato = Item(name="Tomato", item_info=None) env.add_player("1", np.array([2, 3])) player = env.players["1"] - player.player_speed_units_per_seconds = 1 + env.player_movement_speed = 1 player.holding = tomato move = Action("1", ActionType.MOVEMENT, np.array([0, -1]), duration=1) @@ -276,6 +242,7 @@ def test_time_passed(): layouts_folder / "empty.layout", ROOT_DIR / "game_content" / "item_info.yaml", ) + env.add_player("0") env.reset_env_time() passed_time = timedelta(seconds=10) env.step(passed_time) @@ -297,6 +264,8 @@ def test_time_limit(): layouts_folder / "empty.layout", ROOT_DIR / "game_content" / "item_info.yaml", ) + env.add_player("0") + env.reset_env_time() assert not env.game_ended, "Game has not ended yet"