Skip to content
Snippets Groups Projects
random_agent.py 14 KiB
Newer Older
  • Learn to ignore specific revisions
  • import dataclasses
    import json
    import random
    import time
    from collections import defaultdict
    from datetime import datetime, timedelta
    
    
    import numpy.typing as npt
    
    from cooperative_cuisine.action import ActionType, InterActionData, Action
    
    from cooperative_cuisine.state_representation import (
        create_movement_graph,
        astar_heuristic,
        restrict_movement_graph,
    )
    
    from cooperative_cuisine.utils import custom_asdict_factory
    
    Fabian Heinrich's avatar
    Fabian Heinrich committed
    # Seconds a task may run before the agent abandons it; added to
    # ``datetime.now()`` to form the per-task deadline when a new task is chosen.
    TIME_TO_STOP_ACTION: float = 3.0


    # Passed as ``diagonal=`` to ``create_movement_graph`` when the agent builds
    # its path-finding graph from the first full state.
    DIAGONAL_MOVEMENTS: bool = True
    
    def get_free_neighbours(
        state: dict, counter_pos: list[float] | tuple[float, float] | npt.NDArray
    ) -> list[tuple[float, float]]:
        width, height = state["kitchen"]["width"], state["kitchen"]["height"]
        free_space = np.ones((width, height), dtype=bool)
        for counter in state["counters"]:
            grid_idx = np.array(counter["pos"]).astype(int)
    
    Fabian Heinrich's avatar
    Fabian Heinrich committed
            free_space[grid_idx[0], grid_idx[1]] = False
    
        i, j = np.array(counter_pos).astype(int)
        free = []
    
        for x, y in [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]:
            if 0 < x < width and 0 < y < height and free_space[x, y]:
                free.append((x, y))
        return free
    
    
    
    async def agent():
        parser = argparse.ArgumentParser("Random agent")
        parser.add_argument("--uri", type=str)
        parser.add_argument("--player_id", type=str)
        parser.add_argument("--player_hash", type=str)
    
        parser.add_argument("--step_time", type=float, default=0.1)
    
        async with (connect(args.uri) as websocket):
    
            await websocket.send(
                json.dumps({"type": "ready", "player_hash": args.player_hash})
            )
            await websocket.recv()
    
            all_counters = None
    
            player_info = {}
            current_agent_pos = None
            interaction_counter = None
    
            last_interacting = False
            last_interact_progress = None
    
            started_interaction = False
            still_interacting = False
            current_nearest_counter_id = None
    
            while not ended:
                time.sleep(args.step_time)
                await websocket.send(
                    json.dumps({"type": "get_state", "player_hash": args.player_hash})
                )
                state = json.loads(await websocket.recv())
    
                if not state["all_players_ready"]:
                    continue
    
                if movement_graph is None:
    
                    movement_graph = create_movement_graph(
                        state, diagonal=DIAGONAL_MOVEMENTS
                    )
    
                if counters is None:
                    counters = defaultdict(list)
                    for counter in state["counters"]:
                        counters[counter["type"]].append(counter)
    
                    all_counters = state["counters"]
    
    
                for player in state["players"]:
                    if player["id"] == args.player_id:
                        player_info = player
                        current_agent_pos = player["pos"]
                        if player["current_nearest_counter_id"]:
                            if (
                                current_nearest_counter_id
                                != player["current_nearest_counter_id"]
                            ):
                                for counter in state["counters"]:
                                    if (
                                        counter["id"]
                                        == player["current_nearest_counter_id"]
                                    ):
                                        interaction_counter = counter
                                        current_nearest_counter_id = player[
                                            "current_nearest_counter_id"
                                        ]
                                        break
                        if last_interacting:
                            if (
                                not interaction_counter
                                or not interaction_counter["occupied_by"]
                                or isinstance(interaction_counter["occupied_by"], list)
                                or (
                                    interaction_counter["occupied_by"][
                                        "progress_percentage"
                                    ]
                                    == 1.0
                                )
                            ):
                                last_interacting = False
                                last_interact_progress = None
                        else:
    
                                interaction_counter
                                and interaction_counter["occupied_by"]
                                and not isinstance(interaction_counter["occupied_by"], list)
    
                                if (
                                    last_interact_progress
                                    != interaction_counter["occupied_by"][
                                        "progress_percentage"
                                    ]
                                ):
                                    last_interact_progress = interaction_counter[
                                        "occupied_by"
                                    ]["progress_percentage"]
                                    last_interacting = True
    
                        break
    
                if task_type:
                    if threshold < datetime.now():
                        print(
                            args.player_hash, args.player_id, "---Threshold---Too long---"
                        )
                        task_type = None
                    match task_type:
                        case "GOTO":
    
                            target_diff = np.array(task_args) - np.array(current_agent_pos)
                            target_dist = np.linalg.norm(target_diff)
    
                            source = tuple(
                                np.round(np.array(current_agent_pos)).astype(int)
                            )
                            target = tuple(np.array(task_args).astype(int))
                            target_free_spaces = get_free_neighbours(state, target)
                            paths = []
                            for free in target_free_spaces:
                                try:
                                    path = networkx.astar_path(
    
                                            graph=movement_graph,
    
                                            player_positions=[
                                                p["pos"]
                                                for p in state["players"]
                                                if p["id"] != args.player_id
                                            ],
                                        )
                                        if AVOID_OTHER_PLAYERS
                                        else movement_graph,
    
                                        source,
                                        free,
                                        heuristic=astar_heuristic,
                                    )
                                    paths.append(path)
                                except networkx.exception.NetworkXNoPath:
                                    pass
                                except networkx.exception.NodeNotFound:
                                    pass
    
                            if paths:
                                shortest_path = paths[np.argmin([len(p) for p in paths])]
                                if len(shortest_path) > 1:
                                    node_diff = shortest_path[1] - np.array(
                                        current_agent_pos
                                    )
                                    node_dist = np.linalg.norm(node_diff)
                                    movement = node_diff / node_dist
                                else:
                                    movement = target_diff / target_dist
    
                                do_movement = True
    
                                # no paths available
    
                                # task_type = None
                                # task_args = None
                                do_movement = False
    
                            if target_dist > 1.2 and do_movement:
    
                                if target_dist != 0:
    
                                    if ADD_RANDOM_MOVEMENTS:
                                        random_small_rotation_angle = (
                                            np.random.random() * np.pi * 0.1
                                        )
                                        rotation_matrix = np.array(
                                            [
                                                [
                                                    np.cos(random_small_rotation_angle),
                                                    -np.sin(random_small_rotation_angle),
                                                ],
                                                [
                                                    np.sin(random_small_rotation_angle),
                                                    np.cos(random_small_rotation_angle),
                                                ],
                                            ]
                                        )
                                        movement = rotation_matrix @ movement
    
                                    await websocket.send(
                                        json.dumps(
                                            {
                                                "type": "action",
                                                "action": dataclasses.asdict(
                                                    Action(
                                                        args.player_id,
                                                        ActionType.MOVEMENT,
    
                                                        args.step_time + 0.01,
                                                    ),
                                                    dict_factory=custom_asdict_factory,
                                                ),
                                                "player_hash": args.player_hash,
                                            }
                                        )
                                    )
                                    await websocket.recv()
                            else:
    
                                # Target reached here.
    
                                print("TARGET REACHED")
    
                                task_type = None
                                task_args = None
                        case "INTERACT":
                            if not started_interaction or (
                                still_interacting and interaction_counter
                            ):
                                if not started_interaction:
                                    started_interaction = True
    
                                still_interacting = True
                                await websocket.send(
    
                                    json.dumps(
                                        {
                                            "type": "action",
                                            "action": dataclasses.asdict(
                                                Action(
                                                    args.player_id,
    
                                                    ActionType.INTERACT,
                                                    InterActionData.START,
    
                                                ),
                                                dict_factory=custom_asdict_factory,
                                            ),
                                            "player_hash": args.player_hash,
                                        }
                                    )
                                )
    
                                await websocket.recv()
                            else:
                                still_interacting = False
                                started_interaction = False
                                task_type = None
                                task_args = None
                        case "PUT":
                            await websocket.send(
    
                                json.dumps(
                                    {
                                        "type": "action",
                                        "action": dataclasses.asdict(
                                            Action(
                                                args.player_id,
    
                                                ActionType.PICK_UP_DROP,
                                                None,
    
                                            ),
                                            dict_factory=custom_asdict_factory,
                                        ),
                                        "player_hash": args.player_hash,
                                    }
                                )
                            )
    
                    # task_type = random.choice(["GOTO", "PUT", "INTERACT"])
                    task_type = random.choice(["GOTO"])
    
    Fabian Heinrich's avatar
    Fabian Heinrich committed
                    threshold = datetime.now() + timedelta(seconds=TIME_TO_STOP_ACTION)
    
                        # counter_type = random.choice(list(counters.keys()))
                        # task_args = random.choice(counters[counter_type])["pos"]
    
                        random_counter = random.choice(all_counters)
                        counter_type = random_counter["type"]
                        task_args = random_counter["pos"]
    
                        print(
                            args.player_hash,
                            args.player_id,
                            task_type,
                            counter_type,
                            task_args,
                        )
    
                    else:
                        print(args.player_hash, args.player_id, task_type)
    
    
                ended = state["ended"]
    
    
    if __name__ == "__main__":
        # Script entry point: drive the agent coroutine to completion on a fresh
        # event loop.
        main_coroutine = agent()
        asyncio.run(main_coroutine)