From 9c5d4d7aade58ab4f6219d7be3abf36600335fcc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20G=C3=B6bel?= <dgoebel@techfak.uni-bielefeld.de>
Date: Fri, 6 Oct 2023 14:50:51 +0200
Subject: [PATCH] Fix race condition during job monitoring

#56
---
 app/api/utils.py               | 24 ++++++++++++++----------
 app/slurm/slurm_rest_client.py |  2 +-
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/app/api/utils.py b/app/api/utils.py
index ef183fb..37c531c 100644
--- a/app/api/utils.py
+++ b/app/api/utils.py
@@ -212,7 +212,9 @@ async def start_workflow_execution(
         await CRUDWorkflowExecution.update_slurm_job_id(
             db, slurm_job_id=slurm_job_id, execution_id=execution.execution_id
         )
-        await _monitor_proper_job_execution(db=db, slurm_client=slurm_client, execution_id=execution.execution_id)
+        await _monitor_proper_job_execution(
+            db=db, slurm_client=slurm_client, execution_id=execution.execution_id, slurm_job_id=slurm_job_id
+        )
     except (ConnectError, ConnectTimeout):  # pragma: no cover
         # Mark job as aborted when there is an error
         await CRUDWorkflowExecution.cancel(
@@ -221,7 +223,7 @@ async def start_workflow_execution(
 
 
 async def _monitor_proper_job_execution(
-    db: AsyncSession, slurm_client: SlurmClient, execution_id: UUID
+    db: AsyncSession, slurm_client: SlurmClient, execution_id: UUID, slurm_job_id: int
 ) -> None:  # pragma: no cover
     """
     Checks every settings.SLURM_JOB_STATUS_CHECK_INTERVAL seconds if the slurm job is still running as long as
@@ -235,17 +237,19 @@ async def _monitor_proper_job_execution(
         Slurm Rest Client to communicate with Slurm cluster.
     execution_id : uuid.UUID
         ID of the workflow execution
+    slurm_job_id : int
+        ID of the slurm job to monitor
     """
     while True:
         await async_sleep(settings.SLURM_JOB_STATUS_CHECK_INTERVAL)
-        execution = await CRUDWorkflowExecution.get(db, execution_id=execution_id)
-        if execution is None or execution.end_time is not None:
-            break
-        if await slurm_client.is_job_finished(execution.slurm_job_id):
-            # Mark job as finished with an error when the slurm job is finished
-            await CRUDWorkflowExecution.cancel(
-                db, execution_id=execution_id, status=WorkflowExecution.WorkflowExecutionStatus.ERROR
-            )
+        if await slurm_client.is_job_finished(slurm_job_id):
+            execution = await CRUDWorkflowExecution.get(db, execution_id=execution_id)
+            # Check if the execution is marked as finished in the database
+            if execution is not None and execution.end_time is None:
+                # Mark job as finished with an error
+                await CRUDWorkflowExecution.cancel(
+                    db, execution_id=execution_id, status=WorkflowExecution.WorkflowExecutionStatus.ERROR
+                )
             break
 
 
diff --git a/app/slurm/slurm_rest_client.py b/app/slurm/slurm_rest_client.py
index 4e7f8f8..9c658e4 100644
--- a/app/slurm/slurm_rest_client.py
+++ b/app/slurm/slurm_rest_client.py
@@ -98,6 +98,6 @@ class SlurmClient:
             return True
         try:
             job_state = response.json()["jobs"][0]["job_state"]
-            return job_state == "COMPLETED" or job_state == "FAILED"
+            return job_state == "COMPLETED" or job_state == "FAILED" or job_state == "CANCELLED"
         except (KeyError, IndexError):
             return True
-- 
GitLab