From 2299740ec5cd017b498a39b21383ef86d3932ab8 Mon Sep 17 00:00:00 2001
From: fheinrich <fheinrich@techfak.uni-bielefeld.de>
Date: Fri, 19 Jan 2024 15:25:22 +0100
Subject: [PATCH] Added reset button and reset functionality in api

---
 overcooked_simulator/api_call.py              |  14 ++
 overcooked_simulator/fastapi_game_server.py   | 147 ++++++++++++++++++
 .../game_content/environment_config.yaml      |   2 +-
 .../gui_2d_vis/overcooked_gui.py              |  56 +++++--
 .../overcooked_environment.py                 |  12 +-
 5 files changed, 213 insertions(+), 18 deletions(-)
 create mode 100644 overcooked_simulator/api_call.py
 create mode 100644 overcooked_simulator/fastapi_game_server.py

diff --git a/overcooked_simulator/api_call.py b/overcooked_simulator/api_call.py
new file mode 100644
index 00000000..ea733f53
--- /dev/null
+++ b/overcooked_simulator/api_call.py
@@ -0,0 +1,14 @@
+# websocket_client.py
+import asyncio
+
+import websockets
+
+
+async def send_message():
+    uri = "ws://127.0.0.1:8000/ws"
+    async with websockets.connect(uri) as websocket:
+        await websocket.send("Hello, server!")
+        response = await websocket.recv()
+        print(response)
+
+asyncio.run(send_message())
diff --git a/overcooked_simulator/fastapi_game_server.py b/overcooked_simulator/fastapi_game_server.py
new file mode 100644
index 00000000..2a9fa234
--- /dev/null
+++ b/overcooked_simulator/fastapi_game_server.py
@@ -0,0 +1,147 @@
+import json
+import logging
+import threading
+from contextlib import asynccontextmanager
+
+import numpy as np
+import uvicorn
+from fastapi import FastAPI
+from fastapi import WebSocket
+from starlette.websockets import WebSocketDisconnect
+
+from overcooked_simulator import ROOT_DIR
+from overcooked_simulator.game_server import setup_logging
+from overcooked_simulator.overcooked_environment import Action
+from overcooked_simulator.simulation_runner import Simulator
+
+log = logging.getLogger(__name__)
+setup_logging()
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    setup_logging()
+    yield
+    for thread in threading.enumerate():
+        if isinstance(thread, Simulator):
+            thread.stop()
+            thread.join()
+
+
+app = FastAPI(lifespan=lifespan)
+
+WEBSOCKET_URL = "localhost"
+WEBSOCKET_PORT = 8000
+
+
+class OvercookedAPI:
+    simulator: Simulator
+
+    def __init__(self):
+        self.setup_game()
+
+    def setup_game(self):
+        self.simulator = Simulator(
+            ROOT_DIR / "game_content" / "environment_config.yaml",
+            ROOT_DIR / "game_content" / "layouts" / "basic.layout",
+            600,
+        )
+        number_player = 2
+        for i in range(number_player):
+            player_name = f"p{i}"
+            self.simulator.register_player(player_name)
+        self.simulator.start()
+
+    def get_state(self):
+        return self.simulator.get_state_simple_json()
+
+    def reset_game(self):
+        self.simulator.stop()
+        self.setup_game()
+
+
+class ConnectionManager:
+    def __init__(self):
+        self.active_connections: list[WebSocket] = []
+
+    async def connect(self, websocket: WebSocket):
+        await websocket.accept()
+        self.active_connections.append(websocket)
+
+    def disconnect(self, websocket: WebSocket):
+        self.active_connections.remove(websocket)
+
+    async def send_personal_message(self, message: str, websocket: WebSocket):
+        await websocket.send_text(message)
+
+    async def broadcast(self, message: str):
+        for connection in self.active_connections:
+            await connection.send_text(message)
+
+
+manager = ConnectionManager()
+oc_api: OvercookedAPI = OvercookedAPI()
+
+
+def parse_action(message: str) -> Action:
+    if message.replace('"', "") != "get_state":
+        message_dict = json.loads(message)
+        if message_dict["act_type"] == "movement":
+            if isinstance(message_dict["value"], list):
+                x, y = message_dict["value"]
+            elif isinstance(message_dict["value"], str):
+                x, y = (
+                    message_dict["value"]
+                    .replace(" ", "")
+                    .replace("[", "")
+                    .replace("]", "")
+                    .split(",")
+                )
+            else:
+                x, y = 0, 0
+            value = np.array([x, y], dtype=float)
+        else:
+            value = None
+        action = Action(message_dict["player_name"], message_dict["act_type"], value)
+        return action
+
+
+def manage_message(message: str):
+    answer = None
+    print("MESSAGE:", message)
+
+    if "get_state" in message:
+        return oc_api.get_state()
+
+    if "reset_game" in message:
+        oc_api.reset_game()
+        return "Reset game."
+
+    action = parse_action(message)
+    oc_api.simulator.enter_action(action)
+    return oc_api.get_state()
+
+
+@app.get("/")
+def read_root():
+    return {"OVER": "COOKED"}
+
+
+@app.websocket("/ws/{client_id}")
+async def websocket_endpoint(websocket: WebSocket, client_id: int):
+    await manager.connect(websocket)
+    log.debug(f"Client #{client_id} connected")
+    try:
+        while True:
+            message = await websocket.receive_text()
+            answer = manage_message(message)
+            print("ANSWER:", answer)
+            await manager.send_personal_message(answer, websocket)
+
+    except WebSocketDisconnect:
+        manager.disconnect(websocket)
+        log.debug(f"Client #{client_id} disconnected")
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host=WEBSOCKET_URL, port=WEBSOCKET_PORT)
diff --git a/overcooked_simulator/game_content/environment_config.yaml b/overcooked_simulator/game_content/environment_config.yaml
index b6b7c579..a88b5d10 100644
--- a/overcooked_simulator/game_content/environment_config.yaml
+++ b/overcooked_simulator/game_content/environment_config.yaml
@@ -5,7 +5,7 @@ plates:
   # seconds until the dirty plate arrives.
 
 game:
-  time_limit_seconds: 180
+  time_limit_seconds: 20
 
 meals:
   all: false
diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
index fc0dfcf5..b6e6cfdf 100644
--- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py
+++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
@@ -106,7 +106,8 @@ class PyGameGUI:
             )
         ]
 
-        self.websocket_url = "ws://localhost:8765"
+        # self.websocket_url = "ws://localhost:8765"
+        self.websocket_url = "ws://localhost:8000/ws/29"
 
         # TODO cache loaded images?
         with open(ROOT_DIR / "gui_2d_vis" / "visualization.yaml", "r") as file:
@@ -666,6 +667,20 @@ class PyGameGUI:
         )
         self.quit_button.can_hover()
 
+        self.reset_button = pygame_gui.elements.UIButton(
+            relative_rect=pygame.Rect(
+                (
+                    (self.screen_margin + self.game_width),
+                    self.screen_margin,
+                ),
+                (self.screen_margin, 100),
+            ),
+            text="RESET",
+            manager=self.manager,
+            object_id="#quit_button",
+        )
+        self.reset_button.can_hover()
+
         self.finished_button = pygame_gui.elements.UIButton(
             relative_rect=pygame.Rect(
                 (
@@ -805,7 +820,7 @@ class PyGameGUI:
         self.menu_state = MenuStates.Game
 
         with connect(self.websocket_url) as websocket:
-            state = self.request_state(websocket)
+            state = self.request_state()
 
         (
             self.window_width,
@@ -831,6 +846,13 @@ class PyGameGUI:
         self.running = False
         log.debug("Pressed quit button")
 
+    def reset_button_press(self):
+        _ = self.websocket_communicate("reset_game")
+
+        # self.websocket.send(json.dumps("reset_game"))
+        # answer = self.websocket.recv()
+        log.debug("Pressed reset button")
+
     def finished_button_press(self):
         self.menu_state = MenuStates.End
         self.reset_window_size()
@@ -851,15 +873,21 @@ class PyGameGUI:
             "act_type": action.act_type,
             "value": value,
         }
-        _ = self.websocket_communicate(message_dict, websocket)
-
-    def websocket_communicate(self, message_dict: dict | str, websocket):
-        websocket.send(json.dumps(message_dict))
-        answer = websocket.recv()
-        return json.loads(answer)
-
-    def request_state(self, websocket):
-        state_dict = self.websocket_communicate("get_state", websocket)
+        _ = self.websocket_communicate(message_dict)
+
+    def websocket_communicate(self, message_dict: dict | str):
+        self.websocket.send(json.dumps(message_dict))
+        answer = self.websocket.recv()
+        try:
+            answer = json.loads(answer)
+        except json.decoder.JSONDecodeError:
+            answer = None
+        return answer
+
+    def request_state(self):
+        state_dict = self.websocket_communicate("get_state")
+        # self.websocket.send(json.dumps("get_state"))
+        # state_dict = json.loads(self.websocket.recv())
         return state_dict
 
     def start_pygame(self):
@@ -877,6 +905,7 @@ class PyGameGUI:
         self.manage_button_visibility()
 
         with connect(self.websocket_url) as websocket:
+            self.websocket = websocket
             # Game loop
             self.running = True
             while self.running:
@@ -898,6 +927,9 @@ class PyGameGUI:
                                     self.finished_button_press()
                                 case self.quit_button:
                                     self.quit_button_press()
+                                case self.reset_button:
+                                    self.reset_button_press()
+                                    self.start_button_press()
 
                             self.manage_button_visibility()
 
@@ -926,7 +958,7 @@ class PyGameGUI:
                             pass
 
                         case MenuStates.Game:
-                            state = self.request_state(websocket)
+                            state = self.request_state()
 
                             self.draw_background()
 
diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py
index b8259ee2..28e35742 100644
--- a/overcooked_simulator/overcooked_environment.py
+++ b/overcooked_simulator/overcooked_environment.py
@@ -571,10 +571,13 @@ class Environment:
         self.env_time += passed_time
 
         with self.lock:
-            for counter in self.counters:
-                if isinstance(counter, (CuttingBoard, Stove, Sink, PlateDispenser)):
-                    counter.progress(passed_time=passed_time, now=self.env_time)
-            self.order_and_score.progress(passed_time=passed_time, now=self.env_time)
+            if not self.game_ended:
+                for counter in self.counters:
+                    if isinstance(counter, (CuttingBoard, Stove, Sink, PlateDispenser)):
+                        counter.progress(passed_time=passed_time, now=self.env_time)
+                self.order_and_score.progress(
+                    passed_time=passed_time, now=self.env_time
+                )
 
     def get_state(self):
         """Get the current state of the game environment. The state here is accessible by the current python objects.
@@ -629,7 +632,6 @@ class Environment:
             }
             counters.append(counter_dict)
 
-        print(self.game_ended)
         gamestate_dict = {
             "players": players,
             "counters": counters,
-- 
GitLab