From 815a67fceb784be5aa8070b25d71af27ac5ffce2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20G=C3=B6bel?= <dgoebel@techfak.uni-bielefeld.de> Date: Mon, 4 Sep 2023 14:37:26 +0200 Subject: [PATCH] Add Mock services for Slurm and OPA to improve tests #45 --- app/api/dependencies.py | 6 +- app/api/endpoints/workflow_execution.py | 4 +- app/slurm/slurm_rest_client.py | 4 +- app/tests/api/test_workflow.py | 2 - app/tests/api/test_workflow_execution.py | 182 ++++++++++++++++++++++- app/tests/conftest.py | 86 ++++++++--- app/tests/crud/test_workflow_version.py | 160 ++++++++++++++------ app/tests/mocks/__init__.py | 3 + app/tests/mocks/authorization_service.py | 46 ------ app/tests/mocks/mock_opa_service.py | 92 ++++++++++++ app/tests/mocks/mock_slurm_cluster.py | 164 ++++++++++++++++++++ app/tests/mocks/slurm_cluster.py | 22 --- app/tests/utils/utils.py | 41 ----- mako_templates/nextflow_command.template | 2 +- 14 files changed, 625 insertions(+), 189 deletions(-) delete mode 100644 app/tests/mocks/authorization_service.py create mode 100644 app/tests/mocks/mock_opa_service.py create mode 100644 app/tests/mocks/mock_slurm_cluster.py delete mode 100644 app/tests/mocks/slurm_cluster.py diff --git a/app/api/dependencies.py b/app/api/dependencies.py index a16210b..b94a488 100644 --- a/app/api/dependencies.py +++ b/app/api/dependencies.py @@ -25,14 +25,14 @@ else: bearer_token = HTTPBearer(description="JWT Header") -def get_s3_resource() -> S3ServiceResource: - return s3_resource # pragma: no cover +def get_s3_resource() -> S3ServiceResource: # pragma: no cover + return s3_resource S3Service = Annotated[S3ServiceResource, Depends(get_s3_resource)] -async def get_db() -> AsyncIterator[AsyncSession]: +async def get_db() -> AsyncIterator[AsyncSession]: # pragma: no cover """ Get a Session with the database. diff --git a/app/api/endpoints/workflow_execution.py b/app/api/endpoints/workflow_execution.py index 5b01aca..ca3c30b 100644 --- a/app/api/endpoints/workflow_execution.py +++ b/app/api/endpoints/workflow_execution.py @@ -102,7 +102,7 @@ async def start_workflow( if workflow_version.status not in [WorkflowVersion.Status.PUBLISHED, WorkflowVersion.Status.CREATED]: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail=f"Workflow version with status {workflow_version.status.name} can't be started.", + detail=f"Workflow version with status {workflow_version.status} can't be started.", ) # Check active workflow execution limit await check_active_workflow_execution_limit(db, current_user.uid) @@ -490,6 +490,6 @@ async def cancel_workflow_execution( raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot cancel workflow execution that is finished." ) - if workflow_execution.slurm_job_id > 0: + if workflow_execution.slurm_job_id >= 0: background_tasks.add_task(slurm_client.cancel_job, job_id=workflow_execution.slurm_job_id) await CRUDWorkflowExecution.cancel(db, workflow_execution.execution_id) diff --git a/app/slurm/slurm_rest_client.py b/app/slurm/slurm_rest_client.py index 0f1cd00..c525e86 100644 --- a/app/slurm/slurm_rest_client.py +++ b/app/slurm/slurm_rest_client.py @@ -37,6 +37,8 @@ class SlurmClient: Script to execute on the slurm cluster. execution_id : uuid.UUID ID of the workflow execution. + scm_file_name : str | None, default None + Name of a SCM file if the workflow is in a private repository Returns ------- @@ -72,4 +74,4 @@ class SlurmClient: job_id : int ID of the job to cancel. """ - await self._client.delete(f"{settings.SLURM_ENDPOINT}/slurm/{self.version}/job/{job_id}", headers=self._headers) + await self._client.delete(f"{settings.SLURM_ENDPOINT}slurm/{self.version}/job/{job_id}", headers=self._headers) diff --git a/app/tests/api/test_workflow.py b/app/tests/api/test_workflow.py index 0b5112d..cc5b51c 100644 --- a/app/tests/api/test_workflow.py +++ b/app/tests/api/test_workflow.py @@ -439,9 +439,7 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): repository_url="https://github.de/example-user/example", modes=[workflow_mode, workflow_mode], ) - print(workflow) response = await client.post(self.base_path, json=workflow.model_dump(), headers=random_user.auth_headers) - print(response.json()) assert response.status_code == status.HTTP_201_CREATED created_workflow = response.json() assert len(created_workflow["versions"][0]["modes"]) == 2 diff --git a/app/tests/api/test_workflow_execution.py b/app/tests/api/test_workflow_execution.py index c6f3407..c66ebc7 100644 --- a/app/tests/api/test_workflow_execution.py +++ b/app/tests/api/test_workflow_execution.py @@ -12,7 +12,7 @@ from app.core.config import settings from app.schemas.workflow import WorkflowOut from app.schemas.workflow_execution import DevWorkflowExecutionIn, WorkflowExecutionIn from app.scm import SCM -from app.tests.mocks.mock_s3_resource import MockS3ServiceResource +from app.tests.mocks import MockS3ServiceResource, MockSlurmCluster from app.tests.utils.bucket import add_permission_for_bucket from app.tests.utils.user import UserWithAuthHeader from app.tests.utils.utils import random_hex_string, random_lower_string @@ -29,7 +29,9 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): client: AsyncClient, random_user: UserWithAuthHeader, random_workflow_version: WorkflowVersion, + random_workflow: Workflow, mock_s3_service: MockS3ServiceResource, + mock_slurm_cluster: MockSlurmCluster, ) -> None: """ Test for starting a workflow execution. @@ -44,6 +46,8 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): Random workflow version for testing. mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource Mock S3 Service to manipulate objects. + mock_slurm_cluster : app.tests.mocks.mock_slurm_cluster.MockSlurmCluster + Mock Slurm cluster to inspect submitted jobs. """ execution_in = WorkflowExecutionIn(workflow_version_id=random_workflow_version.git_commit_hash, parameters={}) response = await client.post(self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump()) @@ -54,10 +58,26 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): assert execution_response["status"] == WorkflowExecution.WorkflowExecutionStatus.PENDING assert execution_response["workflow_version_id"] == random_workflow_version.git_commit_hash + execution_id = UUID(execution_response["execution_id"]) + assert ( f"params-{UUID(hex=execution_response['execution_id']).hex }.json" in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys() ) + job = mock_slurm_cluster.get_job_by_name(str(execution_id)) + assert job is not None + assert job["job"]["name"] == str(execution_id) + assert job["job"]["current_working_directory"] == settings.SLURM_WORKING_DIRECTORY + assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16] + assert "NXF_SCM_FILE" not in job["job"]["environment"].keys() + + nextflow_script = job["script"] + assert "-hub github" in nextflow_script + assert "-entry" not in nextflow_script + assert f"-revision {random_workflow_version.git_commit_hash}" in nextflow_script + assert f"run {random_workflow.repository_url}" in nextflow_script + assert "-with-report" not in nextflow_script + assert "-with-timeline" not in nextflow_script @pytest.mark.asyncio async def test_start_private_github_workflow_execution( @@ -67,6 +87,7 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): random_workflow_version: WorkflowVersion, random_private_workflow: WorkflowOut, mock_s3_service: MockS3ServiceResource, + mock_slurm_cluster: MockSlurmCluster, ) -> None: """ Test for starting a workflow execution from a private GitHub repository. @@ -83,6 +104,8 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): Random private workflow for testing. mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource Mock S3 Service to manipulate objects. + mock_slurm_cluster : app.tests.mocks.mock_slurm_cluster.MockSlurmCluster + Mock Slurm cluster to inspect submitted jobs. """ execution_in = WorkflowExecutionIn(workflow_version_id=random_workflow_version.git_commit_hash, parameters={}) @@ -99,6 +122,21 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys() ) + execution_id = UUID(execution_response["execution_id"]) + job = mock_slurm_cluster.get_job_by_name(str(execution_id)) + assert job is not None + assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16] + assert "NXF_SCM_FILE" in job["job"]["environment"].keys() + assert ( + job["job"]["environment"]["NXF_SCM_FILE"] + == f"{settings.PARAMS_BUCKET_MOUNT_PATH}/{random_workflow_version.workflow_id.hex}.scm" + ) + + nextflow_script = job["script"] + assert "-hub github" in nextflow_script + assert "-entry" not in nextflow_script + assert f"-revision {random_workflow_version.git_commit_hash}" in nextflow_script + @pytest.mark.asyncio async def test_start_gitlab_workflow_execution( self, @@ -107,6 +145,7 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): random_user: UserWithAuthHeader, random_workflow_version: WorkflowVersion, mock_s3_service: MockS3ServiceResource, + mock_slurm_cluster: MockSlurmCluster, ) -> None: """ Test for starting a workflow execution. @@ -122,6 +161,8 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): Random workflow version for testing. mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource Mock S3 Service to manipulate objects. + mock_slurm_cluster : app.tests.mocks.mock_slurm_cluster.MockSlurmCluster + Mock Slurm cluster to inspect submitted jobs. """ stmt = ( update(Workflow) @@ -144,6 +185,17 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys() ) + execution_id = UUID(execution_response["execution_id"]) + job = mock_slurm_cluster.get_job_by_name(str(execution_id)) + assert job is not None + assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16] + assert "NXF_SCM_FILE" not in job["job"]["environment"].keys() + + nextflow_script = job["script"] + assert "-hub gitlab" in nextflow_script + assert "-entry" not in nextflow_script + assert f"-revision {random_workflow_version.git_commit_hash}" in nextflow_script + @pytest.mark.asyncio async def test_start_too_many_workflow_executions( self, @@ -364,6 +416,7 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): random_user: UserWithAuthHeader, random_bucket: Bucket, random_workflow_version: WorkflowVersion, + mock_slurm_cluster: MockSlurmCluster, ) -> None: """ Test for starting a workflow execution where the report output bucket is an existing bucket. @@ -378,6 +431,8 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): Random bucket for testing. random_workflow_version : clowmdb.models.WorkflowVersion Random workflow version for testing. + mock_slurm_cluster : app.tests.mocks.mock_slurm_cluster.MockSlurmCluster + Mock Slurm cluster to inspect submitted jobs. """ execution_in = WorkflowExecutionIn( workflow_version_id=random_workflow_version.git_commit_hash, @@ -387,6 +442,16 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): response = await client.post(self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump()) assert response.status_code == status.HTTP_201_CREATED + execution_id = UUID(response.json()["execution_id"]) + job = mock_slurm_cluster.get_job_by_name(str(execution_id)) + assert job is not None + assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16] + + nextflow_script = job["script"] + assert f"-with-report s3://{random_bucket.name}/report-" in nextflow_script + assert f"-with-timeline s3://{random_bucket.name}/timeline-" in nextflow_script + assert f"-revision {random_workflow_version.git_commit_hash}" in nextflow_script + @pytest.mark.asyncio async def test_start_workflow_execution_with_bad_report_bucket( self, @@ -421,6 +486,7 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): random_user: UserWithAuthHeader, random_workflow_version: WorkflowVersion, random_workflow_mode: WorkflowMode, + mock_slurm_cluster: MockSlurmCluster, ) -> None: """ Test for starting a workflow execution where the report output bucket is a non-existing bucket. @@ -433,18 +499,34 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): Random user for testing. random_workflow_version : clowmdb.models.WorkflowVersion Random workflow version for testing. + mock_slurm_cluster : app.tests.mocks.mock_slurm_cluster.MockSlurmCluster + Mock Slurm cluster to inspect submitted jobs. """ execution_in = WorkflowExecutionIn( workflow_version_id=random_workflow_version.git_commit_hash, parameters={}, mode=random_workflow_mode.mode_id, ).model_dump_json() - response = await client.post(self.base_path, headers=random_user.auth_headers, data=execution_in) + response = await client.post( + self.base_path, headers=random_user.auth_headers, data=execution_in # type: ignore[arg-type] + ) assert response.status_code == status.HTTP_201_CREATED + execution_id = UUID(response.json()["execution_id"]) + job = mock_slurm_cluster.get_job_by_name(str(execution_id)) + assert job is not None + assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16] + + nextflow_script = job["script"] + assert f"-entry {random_workflow_mode.entrypoint}" in nextflow_script + assert f"-revision {random_workflow_version.git_commit_hash}" in nextflow_script @pytest.mark.asyncio async def test_start_workflow_execution_with_non_existing_workflow_mode( - self, client: AsyncClient, random_user: UserWithAuthHeader, random_workflow_version: WorkflowVersion + self, + client: AsyncClient, + random_user: UserWithAuthHeader, + random_workflow_version: WorkflowVersion, + mock_slurm_cluster: MockSlurmCluster, ) -> None: """ Test for starting a workflow execution where the report output bucket is a non-existing bucket. @@ -457,11 +539,15 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): Random user for testing. random_workflow_version : clowmdb.models.WorkflowVersion Random workflow version for testing. + mock_slurm_cluster : app.tests.mocks.mock_slurm_cluster.MockSlurmCluster + Mock Slurm cluster to inspect submitted jobs. """ execution_in = WorkflowExecutionIn( workflow_version_id=random_workflow_version.git_commit_hash, parameters={}, mode=uuid4() ).model_dump_json() - response = await client.post(self.base_path, headers=random_user.auth_headers, data=execution_in) + response = await client.post( + self.base_path, headers=random_user.auth_headers, data=execution_in # type: ignore[arg-type] + ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -472,6 +558,7 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): client: AsyncClient, random_user: UserWithAuthHeader, mock_s3_service: MockS3ServiceResource, + mock_slurm_cluster: MockSlurmCluster, ) -> None: """ Test for starting a workflow execution with an arbitrary GitHub repository. @@ -484,6 +571,8 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): Random user for testing. mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource Mock S3 Service to manipulate objects. + mock_slurm_cluster : app.tests.mocks.mock_slurm_cluster.MockSlurmCluster + Mock Slurm cluster to inspect submitted jobs. """ execution_in = DevWorkflowExecutionIn( git_commit_hash=random_hex_string(), repository_url="https://github.com/example-user/example", parameters={} @@ -501,12 +590,25 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys() ) + execution_id = UUID(execution_response["execution_id"]) + job = mock_slurm_cluster.get_job_by_name(str(execution_id)) + assert job is not None + assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16] + assert "NXF_SCM_FILE" not in job["job"]["environment"].keys() + + nextflow_script = job["script"] + assert "-hub github" in nextflow_script + assert "-entry" not in nextflow_script + assert f"-revision {execution_in.git_commit_hash}" in nextflow_script + assert f"run {execution_in.repository_url}" in nextflow_script + @pytest.mark.asyncio async def test_start_dev_workflow_execution_from_gitlab( self, client: AsyncClient, random_user: UserWithAuthHeader, mock_s3_service: MockS3ServiceResource, + mock_slurm_cluster: MockSlurmCluster, ) -> None: """ Test for starting a workflow execution with an arbitrary Gitlab repository. @@ -519,6 +621,8 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): Random user for testing. mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource Mock S3 Service to manipulate objects. + mock_slurm_cluster : app.tests.mocks.mock_slurm_cluster.MockSlurmCluster + Mock Slurm cluster to inspect submitted jobs. """ execution_in = DevWorkflowExecutionIn( git_commit_hash=random_hex_string(), repository_url="https://gitlab.com/example-user/example", parameters={} @@ -536,12 +640,26 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys() ) + execution_id = UUID(execution_response["execution_id"]) + + job = mock_slurm_cluster.get_job_by_name(str(execution_id)) + assert job is not None + assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16] + assert "NXF_SCM_FILE" not in job["job"]["environment"].keys() + + nextflow_script = job["script"] + assert "-hub gitlab" in nextflow_script + assert "-entry" not in nextflow_script + assert f"-revision {execution_in.git_commit_hash}" in nextflow_script + assert f"run {execution_in.repository_url}" in nextflow_script + @pytest.mark.asyncio async def test_start_dev_workflow_execution_from_private_gitlab( self, client: AsyncClient, random_user: UserWithAuthHeader, mock_s3_service: MockS3ServiceResource, + mock_slurm_cluster: MockSlurmCluster, ) -> None: """ Test for starting a workflow execution with an arbitrary private Gitlab repository. @@ -554,6 +672,8 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): Random user for testing. mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource Mock S3 Service to manipulate objects. + mock_slurm_cluster : app.tests.mocks.mock_slurm_cluster.MockSlurmCluster + Mock Slurm cluster to inspect submitted jobs. """ token = random_lower_string(15) execution_in = DevWorkflowExecutionIn( @@ -573,7 +693,7 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): execution_id = UUID(hex=execution_response["execution_id"]) # Check if params file is created - params_file_name = f"params-{execution_id.hex }.json" + params_file_name = f"params-{execution_id.hex}.json" assert params_file_name in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys() # Check if SCM file is created @@ -592,6 +712,19 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): assert provider.name == f"repo{execution_id.hex}" assert provider.platform == "gitlab" assert provider.server == "https://gitlab.com" + + job = mock_slurm_cluster.get_job_by_name(str(execution_id)) + assert job is not None + assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16] + assert "NXF_SCM_FILE" in job["job"]["environment"].keys() + assert job["job"]["environment"]["NXF_SCM_FILE"] == f"{settings.PARAMS_BUCKET_MOUNT_PATH}/{scm_file_name}" + + nextflow_script = job["script"] + assert f"-hub repo{execution_id.hex}" in nextflow_script + assert "-entry" not in nextflow_script + assert f"-revision {execution_in.git_commit_hash}" in nextflow_script + assert f"run {execution_in.repository_url}" in nextflow_script + # Clean up after test mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(params_file_name).delete() scm_obj.delete() @@ -602,6 +735,7 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): client: AsyncClient, random_user: UserWithAuthHeader, mock_s3_service: MockS3ServiceResource, + mock_slurm_cluster: MockSlurmCluster, ) -> None: """ Test for starting a workflow execution with an arbitrary private GitHub repository without a provided username. @@ -614,6 +748,8 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): Random user for testing. mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource Mock S3 Service to manipulate objects. + mock_slurm_cluster : app.tests.mocks.mock_slurm_cluster.MockSlurmCluster + Mock Slurm cluster to inspect submitted jobs. """ token = random_lower_string(15) execution_in = DevWorkflowExecutionIn( @@ -633,7 +769,7 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): execution_id = UUID(hex=execution_response["execution_id"]) # Check if params file is created - params_file_name = f"params-{execution_id.hex }.json" + params_file_name = f"params-{execution_id.hex}.json" assert params_file_name in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys() # Check if SCM file is created @@ -651,6 +787,19 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): assert provider.password == token assert provider.name == "github" assert provider.user == "example-user" + + job = mock_slurm_cluster.get_job_by_name(str(execution_id)) + assert job is not None + assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16] + assert "NXF_SCM_FILE" in job["job"]["environment"].keys() + assert job["job"]["environment"]["NXF_SCM_FILE"] == f"{settings.PARAMS_BUCKET_MOUNT_PATH}/{scm_file_name}" + + nextflow_script = job["script"] + assert "-hub github" in nextflow_script + assert "-entry" not in nextflow_script + assert f"-revision {execution_in.git_commit_hash}" in nextflow_script + assert f"run {execution_in.repository_url}" in nextflow_script + # Clean up after test mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(params_file_name).delete() scm_obj.delete() @@ -661,6 +810,7 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): client: AsyncClient, random_user: UserWithAuthHeader, mock_s3_service: MockS3ServiceResource, + mock_slurm_cluster: MockSlurmCluster, ) -> None: """ Test for starting a workflow execution with an arbitrary private GitHub repository with a provided username. @@ -673,6 +823,8 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): Random user for testing. mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource Mock S3 Service to manipulate objects. + mock_slurm_cluster : app.tests.mocks.mock_slurm_cluster.MockSlurmCluster + Mock Slurm cluster to inspect submitted jobs. """ token = random_lower_string(15) execution_in = DevWorkflowExecutionIn( @@ -711,6 +863,19 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes): assert provider.password == token assert provider.name == "github" assert provider.user == "bilbobaggins" + + job = mock_slurm_cluster.get_job_by_name(str(execution_id)) + assert job is not None + assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16] + assert "NXF_SCM_FILE" in job["job"]["environment"].keys() + assert job["job"]["environment"]["NXF_SCM_FILE"] == f"{settings.PARAMS_BUCKET_MOUNT_PATH}/{scm_file_name}" + + nextflow_script = job["script"] + assert "-hub github" in nextflow_script + assert "-entry" not in nextflow_script + assert f"-revision {execution_in.git_commit_hash}" in nextflow_script + assert f"run {execution_in.repository_url}" in nextflow_script + # Clean up after test mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(params_file_name).delete() scm_obj.delete() @@ -920,6 +1085,7 @@ class TestWorkflowExecutionRoutesCancel(_TestWorkflowExecutionRoutes): client: AsyncClient, random_user: UserWithAuthHeader, random_workflow_execution: WorkflowExecution, + mock_slurm_cluster: MockSlurmCluster, ) -> None: """ Test for canceling a workflow execution. @@ -932,12 +1098,16 @@ class TestWorkflowExecutionRoutesCancel(_TestWorkflowExecutionRoutes): Random user for testing. random_workflow_execution : clowmdb.models.WorkflowExecution Random workflow execution for testing. + mock_slurm_cluster : app.tests.mocks.mock_slurm_cluster.MockSlurmCluster + Mock Slurm cluster to inspect submitted jobs. """ response = await client.post( "/".join([self.base_path, str(random_workflow_execution.execution_id), "cancel"]), headers=random_user.auth_headers, ) assert response.status_code == status.HTTP_204_NO_CONTENT + job_active = mock_slurm_cluster.job_active(random_workflow_execution.slurm_job_id) + assert not job_active @pytest.mark.asyncio async def test_cancel_finished_workflow_execution( diff --git a/app/tests/conftest.py b/app/tests/conftest.py index 0b1758f..934b11a 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -2,7 +2,7 @@ import asyncio from functools import partial from io import BytesIO from secrets import token_urlsafe -from typing import AsyncIterator, Callable, Dict, Iterator +from typing import AsyncIterator, Iterator import httpx import pytest @@ -16,6 +16,7 @@ from clowmdb.models import ( WorkflowVersion, workflow_mode_association_table, ) +from fastapi import status from sqlalchemy import insert, select, update from sqlalchemy.ext.asyncio import AsyncSession @@ -25,10 +26,10 @@ from app.git_repository import build_repository from app.main import app from app.schemas.workflow import WorkflowOut from app.scm import SCM, Provider -from app.tests.mocks.mock_s3_resource import MockS3ServiceResource +from app.tests.mocks import MockOpaService, MockS3ServiceResource, MockSlurmCluster from app.tests.utils.bucket import create_random_bucket from app.tests.utils.user import UserWithAuthHeader, create_random_user, decode_mock_token, get_authorization_headers -from app.tests.utils.utils import handle_http_request, random_hex_string, random_lower_string +from app.tests.utils.utils import random_hex_string, random_lower_string jwt_secret = token_urlsafe(32) @@ -58,31 +59,56 @@ def mock_s3_service() -> Iterator[MockS3ServiceResource]: mock_s3.delete_bucket(name=settings.WORKFLOW_BUCKET, force_delete=True) +@pytest.fixture(scope="session") +def mock_slurm_cluster() -> Iterator[MockSlurmCluster]: + mock_slurm = MockSlurmCluster() + yield mock_slurm + mock_slurm.reset() + + +@pytest.fixture(scope="session") +def mock_opa_service() -> Iterator[MockOpaService]: + mock_opa = MockOpaService() + yield mock_opa + mock_opa.reset() + + @pytest_asyncio.fixture(scope="module") -async def client(mock_s3_service: MockS3ServiceResource, db: AsyncSession) -> AsyncIterator[httpx.AsyncClient]: +async def client( + mock_s3_service: MockS3ServiceResource, + db: AsyncSession, + mock_opa_service: MockOpaService, + mock_slurm_cluster: MockSlurmCluster, +) -> AsyncIterator[httpx.AsyncClient]: """ Fixture for creating a TestClient and perform HTTP Request on it. - Overrides the dependency for the RGW admin operations. + Overrides serveral dependencies. """ - def get_mock_s3() -> MockS3ServiceResource: - return mock_s3_service - - def get_decode_token_function() -> Callable[[str], Dict[str, str]]: - # Override the decode_jwt function with mock function for tests and inject random shared secret - return partial(decode_mock_token, secret=jwt_secret) - async def get_mock_httpx_client(raise_error: bool = False) -> AsyncIterator[httpx.AsyncClient]: # raises an 404 error if the query parameter 'raise_error' is true def mock_request_handler(request: httpx.Request) -> httpx.Response: - return handle_http_request(request, raise_error) + url = str(request.url) + if url.startswith(str(settings.OPA_URI)): + return mock_opa_service.handle_request(request) + elif url.startswith(str(settings.SLURM_ENDPOINT)): + return mock_slurm_cluster.handle_request(request) + elif raise_error: + return httpx.Response(status_code=status.HTTP_404_NOT_FOUND, json={}) + return httpx.Response( + status_code=status.HTTP_200_OK, + json={ + # When checking if a file exists in a git repository, the GitHub API expects this in a response + "download_url": "https://example.com" + }, + ) async with httpx.AsyncClient(transport=httpx.MockTransport(mock_request_handler)) as http_client: yield http_client app.dependency_overrides[get_httpx_client] = get_mock_httpx_client - app.dependency_overrides[get_s3_resource] = get_mock_s3 - app.dependency_overrides[get_decode_jwt_function] = get_decode_token_function + app.dependency_overrides[get_s3_resource] = lambda: mock_s3_service + app.dependency_overrides[get_decode_jwt_function] = lambda: partial(decode_mock_token, secret=jwt_secret) app.dependency_overrides[get_db] = lambda: db async with httpx.AsyncClient(app=app, base_url="http://localhost") as ac: yield ac @@ -101,34 +127,40 @@ async def db() -> AsyncIterator[AsyncSession]: @pytest_asyncio.fixture(scope="function") -async def random_user(db: AsyncSession) -> AsyncIterator[UserWithAuthHeader]: +async def random_user(db: AsyncSession, mock_opa_service: MockOpaService) -> AsyncIterator[UserWithAuthHeader]: """ - Create a random user and deletes him afterwards. + Create a random user and deletes him afterward. """ user = await create_random_user(db) + mock_opa_service.add_user(user.uid, privileged=True) yield UserWithAuthHeader(user=user, auth_headers=get_authorization_headers(uid=user.uid, secret=jwt_secret)) + mock_opa_service.delete_user(user.uid) await db.delete(user) await db.commit() @pytest_asyncio.fixture(scope="module") -async def random_second_user(db: AsyncSession) -> AsyncIterator[UserWithAuthHeader]: +async def random_second_user(db: AsyncSession, mock_opa_service: MockOpaService) -> AsyncIterator[UserWithAuthHeader]: """ - Create a random second user and deletes him afterwards. + Create a random second user and deletes him afterward. """ user = await create_random_user(db) + mock_opa_service.add_user(user.uid) yield UserWithAuthHeader(user=user, auth_headers=get_authorization_headers(uid=user.uid, secret=jwt_secret)) + mock_opa_service.delete_user(user.uid) await db.delete(user) await db.commit() @pytest_asyncio.fixture(scope="module") -async def random_third_user(db: AsyncSession) -> AsyncIterator[UserWithAuthHeader]: +async def random_third_user(db: AsyncSession, mock_opa_service: MockOpaService) -> AsyncIterator[UserWithAuthHeader]: """ - Create a random third user and deletes him afterwards. + Create a random third user and deletes him afterward. """ user = await create_random_user(db) + mock_opa_service.add_user(user.uid) yield UserWithAuthHeader(user=user, auth_headers=get_authorization_headers(uid=user.uid, secret=jwt_secret)) + mock_opa_service.delete_user(user.uid) await db.delete(user) await db.commit() @@ -237,16 +269,24 @@ async def random_workflow_execution( random_workflow_version: WorkflowVersion, random_user: UserWithAuthHeader, mock_s3_service: MockS3ServiceResource, + mock_slurm_cluster: MockSlurmCluster, ) -> AsyncIterator[WorkflowExecution]: """ Create a random workflow execution. Will be deleted, when the user is deleted. """ execution = WorkflowExecution( - user_id=random_user.user.uid, workflow_version_id=random_workflow_version.git_commit_hash, slurm_job_id=1 + user_id=random_user.user.uid, + workflow_version_id=random_workflow_version.git_commit_hash, + slurm_job_id=-1, ) db.add(execution) await db.commit() - await db.refresh(execution) + slurm_job_id = mock_slurm_cluster.add_workflow_execution({"job": {"name": str(execution.execution_id)}}) + await db.execute( + update(WorkflowExecution) + .where(WorkflowExecution._execution_id == execution.execution_id.bytes) + .values(slurm_job_id=slurm_job_id) + ) mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(f"params-{execution.execution_id.hex}.json").upload_fileobj( BytesIO(b"{}") ) diff --git a/app/tests/crud/test_workflow_version.py b/app/tests/crud/test_workflow_version.py index 001593e..c0d6734 100644 --- a/app/tests/crud/test_workflow_version.py +++ b/app/tests/crud/test_workflow_version.py @@ -10,7 +10,9 @@ from app.tests.utils.utils import random_hex_string class TestWorkflowVersionCRUDGet: @pytest.mark.asyncio - async def test_get_specific_workflow_version(self, db: AsyncSession, random_workflow: WorkflowOut) -> None: + async def test_get_specific_workflow_version( + self, db: AsyncSession, random_workflow_version: WorkflowVersion + ) -> None: """ Test for getting a workflow version by its id from CRUD Repository. @@ -18,16 +20,17 @@ class TestWorkflowVersionCRUDGet: ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - random_workflow : app.schemas.workflow.WorkflowOut - Random bucket for testing. + random_workflow_version : clowmdb.model.WorkflowVersion + Random workflow version for testing. """ - workflow = await CRUDWorkflowVersion.get(db, random_workflow.versions[0].git_commit_hash) - assert workflow is not None - assert workflow.workflow_id == random_workflow.workflow_id + version = await CRUDWorkflowVersion.get(db, random_workflow_version.git_commit_hash) + assert version is not None + assert version.workflow_id == random_workflow_version.workflow_id + assert version.git_commit_hash == random_workflow_version.git_commit_hash @pytest.mark.asyncio async def test_get_specific_workflow_version_with_populated_workflow( - self, db: AsyncSession, random_workflow: WorkflowOut + self, db: AsyncSession, random_workflow_version: WorkflowVersion ) -> None: """ Test for getting a workflow version by its id with populated workflow from CRUD Repository. @@ -36,19 +39,18 @@ class TestWorkflowVersionCRUDGet: ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - random_workflow : app.schemas.workflow.WorkflowOut - Random bucket for testing. + random_workflow_version : clowmdb.model.WorkflowVersion + Random workflow version for testing. """ - workflow = await CRUDWorkflowVersion.get( - db, random_workflow.versions[0].git_commit_hash, populate_workflow=True - ) - assert workflow is not None - assert workflow.workflow_id == random_workflow.workflow_id - assert workflow.workflow.workflow_id == random_workflow.workflow_id + version = await CRUDWorkflowVersion.get(db, random_workflow_version.git_commit_hash, populate_workflow=True) + assert version is not None + assert version.workflow_id == random_workflow_version.workflow_id + assert version.git_commit_hash == random_workflow_version.git_commit_hash + assert version.workflow.workflow_id == random_workflow_version.workflow_id @pytest.mark.asyncio async def test_get_latest_unpublished_workflow_version( - self, db: AsyncSession, random_workflow: WorkflowOut + self, db: AsyncSession, random_workflow_version: WorkflowVersion ) -> None: """ Test for getting the latest workflow version from a workflow which is not published from CRUD Repository. @@ -57,16 +59,20 @@ class TestWorkflowVersionCRUDGet: ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - random_workflow : app.schemas.workflow.WorkflowOut - Random bucket for testing. + random_workflow_version : clowmdb.model.WorkflowVersion + Random workflow version for testing. """ - workflow_version = await CRUDWorkflowVersion.get_latest(db, random_workflow.workflow_id, published=False) + workflow_version = await CRUDWorkflowVersion.get_latest( + db, random_workflow_version.workflow_id, published=False + ) assert workflow_version is not None - assert workflow_version.workflow_id == random_workflow.workflow_id - assert workflow_version.git_commit_hash == random_workflow.versions[0].git_commit_hash + assert workflow_version.workflow_id == random_workflow_version.workflow_id + assert workflow_version.git_commit_hash == random_workflow_version.git_commit_hash @pytest.mark.asyncio - async def test_get_latest_published_workflow_version(self, db: AsyncSession, random_workflow: WorkflowOut) -> None: + async def test_get_latest_published_workflow_version( + self, db: AsyncSession, random_workflow_version: WorkflowVersion + ) -> None: """ Test for getting the latest workflow version from a workflow which is published from CRUD Repository. @@ -74,14 +80,14 @@ class TestWorkflowVersionCRUDGet: ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - random_workflow : app.schemas.workflow.WorkflowOut - Random bucket for testing. + random_workflow_version : clowmdb.model.WorkflowVersion + Random workflow version for testing. """ - workflow_version = await CRUDWorkflowVersion.get_latest(db, random_workflow.workflow_id) + workflow_version = await CRUDWorkflowVersion.get_latest(db, random_workflow_version.workflow_id) assert workflow_version is None @pytest.mark.asyncio - async def test_get_all_workflow_versions(self, db: AsyncSession, random_workflow: WorkflowOut) -> None: + async def test_get_all_workflow_versions(self, db: AsyncSession, random_workflow_version: WorkflowVersion) -> None: """ Test for getting all versions from a workflow from the CRUD Repository. @@ -89,16 +95,16 @@ class TestWorkflowVersionCRUDGet: ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - random_workflow : app.schemas.workflow.WorkflowOut - Random bucket for testing. + random_workflow_version : clowmdb.model.WorkflowVersion + Random workflow version for testing. """ - workflow_versions = await CRUDWorkflowVersion.list(db, random_workflow.workflow_id) + workflow_versions = await CRUDWorkflowVersion.list(db, random_workflow_version.workflow_id) assert len(workflow_versions) == 1 - assert workflow_versions[0].git_commit_hash == random_workflow.versions[0].git_commit_hash + assert workflow_versions[0].git_commit_hash == random_workflow_version.git_commit_hash @pytest.mark.asyncio async def test_get_all_workflow_version_with_specific_status( - self, db: AsyncSession, random_workflow: WorkflowOut + self, db: AsyncSession, random_workflow_version: WorkflowVersion ) -> None: """ Test for getting all versions with a specific status from a workflow from the CRUD Repository. @@ -107,14 +113,14 @@ class TestWorkflowVersionCRUDGet: ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - random_workflow : app.schemas.workflow.WorkflowOut - Random bucket for testing. + random_workflow_version : clowmdb.model.WorkflowVersion + Random workflow version for testing. """ workflow_versions = await CRUDWorkflowVersion.list( - db, random_workflow.workflow_id, version_status=[random_workflow.versions[0].status] + db, random_workflow_version.workflow_id, version_status=[random_workflow_version.status] ) assert len(workflow_versions) == 1 - assert workflow_versions[0].git_commit_hash == random_workflow.versions[0].git_commit_hash + assert workflow_versions[0].git_commit_hash == random_workflow_version.git_commit_hash class TestWorkflowVersionCRUDCreate: @@ -147,7 +153,9 @@ class TestWorkflowVersionCRUDCreate: class TestWorkflowVersionCRUDUpdate: @pytest.mark.asyncio - async def test_update_workflow_version_status(self, db: AsyncSession, random_workflow: WorkflowOut) -> None: + async def test_update_workflow_version_status( + self, db: AsyncSession, random_workflow_version: WorkflowVersion + ) -> None: """ Test for creating a workflow version in CRUD Repository. @@ -155,16 +163,84 @@ class TestWorkflowVersionCRUDUpdate: ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - random_workflow : app.schemas.workflow.WorkflowOut - Random bucket for testing. + random_workflow_version : clowmdb.model.WorkflowVersion + Random workflow version for testing. """ await CRUDWorkflowVersion.update_status( - db, git_commit_hash=random_workflow.versions[0].git_commit_hash, status=WorkflowVersion.Status.PUBLISHED + db, git_commit_hash=random_workflow_version.git_commit_hash, status=WorkflowVersion.Status.PUBLISHED ) - stmt = select(WorkflowVersion).where( - WorkflowVersion.git_commit_hash == random_workflow.versions[0].git_commit_hash - ) + stmt = select(WorkflowVersion).where(WorkflowVersion.git_commit_hash == random_workflow_version.git_commit_hash) version = await db.scalar(stmt) assert version assert version.status == WorkflowVersion.Status.PUBLISHED + + @pytest.mark.asyncio + async def test_update_workflow_version_icon( + self, db: AsyncSession, random_workflow_version: WorkflowVersion + ) -> None: + """ + Test for updating the worklfow version icon + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + random_workflow_version : clowmdb.model.WorkflowVersion + Random workflow version for testing. + """ + new_slug = random_hex_string() + await CRUDWorkflowVersion.update_icon( + db, git_commit_hash=random_workflow_version.git_commit_hash, icon_slug=new_slug + ) + + stmt = select(WorkflowVersion).where(WorkflowVersion.git_commit_hash == random_workflow_version.git_commit_hash) + version = await db.scalar(stmt) + assert version + assert version.git_commit_hash == random_workflow_version.git_commit_hash + assert version.icon_slug == new_slug + + @pytest.mark.asyncio + async def test_remove_workflow_version_icon( + self, db: AsyncSession, random_workflow_version: WorkflowVersion + ) -> None: + """ + Test for removeing the workflow version icon + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + random_workflow_version : clowmdb.model.WorkflowVersion + Random workflow version for testing. + """ + await CRUDWorkflowVersion.update_icon( + db, git_commit_hash=random_workflow_version.git_commit_hash, icon_slug=None + ) + + stmt = select(WorkflowVersion).where(WorkflowVersion.git_commit_hash == random_workflow_version.git_commit_hash) + version = await db.scalar(stmt) + assert version + assert version.git_commit_hash == random_workflow_version.git_commit_hash + assert version.icon_slug is None + + +class TestWorkflowVersionCRUDCheck: + @pytest.mark.asyncio + async def test_check_icon_dependency(self, db: AsyncSession, random_workflow_version: WorkflowVersion) -> None: + """ + Test for checking if a workflow version is dependent on an icon + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + random_workflow_version : clowmdb.model.WorkflowVersion + Random workflow version for testing. + """ + dependent1 = await CRUDWorkflowVersion.icon_exists(db, random_workflow_version.icon_slug) + + assert dependent1 + + dependent2 = await CRUDWorkflowVersion.icon_exists(db, random_hex_string()) + assert not dependent2 diff --git a/app/tests/mocks/__init__.py b/app/tests/mocks/__init__.py index e69de29..9fb9a28 100644 --- a/app/tests/mocks/__init__.py +++ b/app/tests/mocks/__init__.py @@ -0,0 +1,3 @@ +from .mock_opa_service import MockOpaService # noqa: F401 +from .mock_s3_resource import MockS3ServiceResource # noqa: F401 +from .mock_slurm_cluster import MockSlurmCluster # noqa: F401 diff --git a/app/tests/mocks/authorization_service.py b/app/tests/mocks/authorization_service.py deleted file mode 100644 index 3793001..0000000 --- a/app/tests/mocks/authorization_service.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Dict -from uuid import uuid4 - -from fastapi import status -from httpx import Response - -from app.schemas.security import AuthzResponse - - -def handle_request(body: Dict[str, str]) -> Response: - """ - Handle a request to the authorization service during testing. - - Parameters - ---------- - body : Dict[str, str] - Body of the request. - - Returns - ------- - response : httpx.Response - Mock response. - """ - response_body = AuthzResponse(result=not request_admin_permission(body), decision_id=str(uuid4())).model_dump() - return Response(status_code=status.HTTP_200_OK, json=response_body) - - -def request_admin_permission(body: Dict[str, str]) -> bool: - """ - Helper function to determine if the authorization request needs the 'administrator' role. - - Parameters - ---------- - body : Dict[str, str] - Body of the request. - - Returns - ------- - decision : bool - Flag if the request needs the 'administrator' role - """ - operation = body["operation"] - checks = "any" in operation - if "bucket_permission" in body["resource"]: - checks = checks or "all" in operation - return checks diff --git a/app/tests/mocks/mock_opa_service.py b/app/tests/mocks/mock_opa_service.py new file mode 100644 index 0000000..639e0a0 --- /dev/null +++ b/app/tests/mocks/mock_opa_service.py @@ -0,0 +1,92 @@ +import json +from typing import Dict +from uuid import uuid4 + +from fastapi import status +from httpx import Request, Response + +from app.schemas.security import AuthzRequest, AuthzResponse + + +class MockOpaService: + """ + Class to mock the Open Policy Agent service. + Has a simplified role management. A user can be either "Admin" or "Normal User". + """ + + def __init__(self) -> None: + self._users: Dict[str, bool] = {} + + def add_user(self, uid: str, privileged: bool = False) -> None: + """ + Add a user to the mock service. + + Parameters + ---------- + uid : str + ID of a user. + privileged : bool, default False + Flag if the user is an Admin or not. + """ + self._users[uid] = privileged + + def delete_user(self, uid: str) -> None: + """ + Delete a user in the mock service. + + Parameters + ---------- + uid : str + ID of the user to delete. + """ + if uid in self._users.keys(): + del self._users[uid] + + def reset(self) -> None: + """ + Reset the mock service to its initial state. + """ + self._users = {} + + def handle_request(self, request: Request) -> Response: + """ + Handle the raw request that is sent to the mock service. + + Parameters + ---------- + request: httpx.Request + Raw HTTP request object. + + Returns + ------- + response : httpx.Response + Appropriate response to the received request. + """ + authz_request = AuthzRequest(**json.loads(request.read().decode("utf-8"))["input"]) + if authz_request.uid not in self._users: + result = False + else: + result = not MockOpaService.request_admin_permission(authz_request) or self._users[authz_request.uid] + return Response( + status_code=status.HTTP_200_OK, json=AuthzResponse(result=result, decision_id=str(uuid4())).model_dump() + ) + + @staticmethod + def request_admin_permission(authz_request: AuthzRequest) -> bool: + """ + Helper function to determine if the authorization request needs the 'administrator' role. + + Parameters + ---------- + authz_request : app.schemas.security.AuthzRequest + Body of the request. + + Returns + ------- + decision : bool + Flag if the request needs the 'administrator' role + """ + checks = "any" in authz_request.operation + if "bucket_permission" in authz_request.resource: + checks = checks or "all" in authz_request.operation + return checks diff --git a/app/tests/mocks/mock_slurm_cluster.py b/app/tests/mocks/mock_slurm_cluster.py new file mode 100644 index 0000000..3e2b1fa --- /dev/null +++ b/app/tests/mocks/mock_slurm_cluster.py @@ -0,0 +1,164 @@ +import json +import re +from typing import Any, Dict, List, Optional + +from fastapi import status +from httpx import Headers, Request, Response + +SlurmRequestBody = Dict[str, Any] + + +class MockSlurmCluster: + """ + Class to mock the Rest API of a Slurm cluster + """ + + _method_not_allowed_response = Response( + status_code=status.HTTP_405_METHOD_NOT_ALLOWED, text="Requested REST method is not defined at URL." + ) + + def __init__(self, version: str = "v0.0.38") -> None: + self._request_bodies: List[SlurmRequestBody] = [] + self._job_states: List[bool] = [] + self.base_path = f"slurm/{version}" + self._job_path_regex = re.compile(f"^/slurm/{re.escape(version)}/job/[\d]*$") + + def handle_request(self, request: Request) -> Response: + """ + Handle the raw request that is sent to the API. + + Parameters + ---------- + request: httpx.Request + Raw HTTP request object. + + Returns + ------- + response : httpx.Response + Appropriate response to the request + """ + # Authorize request + error_response = MockSlurmCluster.authorize_request(request.headers) + if error_response is not None: + return error_response + # If a job should be submitted + if request.url.path == f"/{self.base_path}/job/submit": + # Route supports POST Method + if request.method == "POST": + # Parse request and save it + job_id = self.add_workflow_execution(json.loads(request.read().decode("utf-8"))) + # Return index into list with saved requests + return Response(status_code=status.HTTP_200_OK, json={"job_id": job_id}) + else: + return self._method_not_allowed_response + # If a job should be canceled + elif self._job_path_regex.match(request.url.path) is not None: + # Route supports DELETE Method + if request.method == "DELETE": + # Parse job ID from url path + job_id = int(request.url.path.split("/")[-1]) + # Check if job exists and send appropriate response + if job_id < len(self._request_bodies): + self._job_states[job_id] = False + return Response(status_code=status.HTTP_200_OK, text="") + else: + return Response(status_code=status.HTTP_404_NOT_FOUND, text=f"Job with ID {job_id} not found") + else: + return self._method_not_allowed_response + # Requested route is not mocked + return Response(status_code=status.HTTP_404_NOT_FOUND, text="Unable find requested URL.") + + @staticmethod + def authorize_request(headers: Headers) -> Optional[Response]: + """ + Check if the authorization headers are present in a request + + Parameters + ---------- + headers : httpx.Headers + Request headers + + Returns + ------- + response : httpx.Response | None, default None + Response with an error message if the request is not authorized, None otherwise. + """ + if headers.get("X-SLURM-USER-TOKEN") is None or headers.get("X-SLURM-USER-NAME") is None: + return Response(status_code=status.HTTP_401_UNAUTHORIZED, text="Authentication failure") + return None + + def reset(self) -> None: + """ + Resets the mock service to its initial state. + """ + self._request_bodies = [] + + def add_workflow_execution(self, job: SlurmRequestBody) -> int: + """ + Add a workflow execution to the list. + + Parameters + ---------- + job : Dict[str, Any] + The requests body for submiting a slurm job. + + Returns + ------- + job_id : int + The assigned job id. + """ + self._request_bodies.append(job) + self._job_states.append(True) + return len(self._request_bodies) - 1 + + def get_job(self, job_id: int) -> SlurmRequestBody: + """ + Get a job by its ID. + + Parameters + ---------- + job_id : int + The ID of a job. + + Returns + ------- + job : Dict[str, Any] + The request body of the job with the given ID. + """ + if job_id < 0: + raise IndexError("Index must be larger than 0") + return self._request_bodies[job_id] + + def job_active(self, job_id: int) -> bool: + """ + Check if a job is active or got canceled. + + Parameters + ---------- + job_id : int + The ID of a job. + + Returns + ------- + job : bool + Flag if the job is still active. + """ + if job_id < 0: + raise IndexError("Index must be larger than 0") + return self._job_states[job_id] + + def get_job_by_name(self, job_name: str) -> Optional[SlurmRequestBody]: + """ + Get a job by its name. + + Parameters + ---------- + job_name : str + Name of a job. + + Returns + ------- + job : Dict[str, Any] | None + The request body of the job with the given name if it exists. + """ + return next((job for job in self._request_bodies if job["job"]["name"] == job_name), None) diff --git a/app/tests/mocks/slurm_cluster.py b/app/tests/mocks/slurm_cluster.py deleted file mode 100644 index 2049a57..0000000 --- a/app/tests/mocks/slurm_cluster.py +++ /dev/null @@ -1,22 +0,0 @@ -from fastapi import status -from httpx import Response - - -def handle_request(http_method: str) -> Response: - """ - Handle a request to a Slurm cluster during testing. - - Parameters - ---------- - http_method : str - Used HTTP method of request. - - Returns - ------- - response : httpx.Response - Mock response. - """ - if http_method == "POST": - # If the method is POST, then a job should be submitted and a job request is expected in the response - return Response(status_code=status.HTTP_200_OK, json={"job_id": 1}) - return Response(status_code=status.HTTP_200_OK, json={}) diff --git a/app/tests/utils/utils.py b/app/tests/utils/utils.py index 84d2219..daba17c 100644 --- a/app/tests/utils/utils.py +++ b/app/tests/utils/utils.py @@ -1,13 +1,5 @@ import random import string -from typing import Dict - -import httpx -from fastapi import status - -from app.core.config import settings -from app.tests.mocks.authorization_service import handle_request as auth_handle_request -from app.tests.mocks.slurm_cluster import handle_request as slurm_handle_request def random_lower_string(length: int = 32) -> str: @@ -54,36 +46,3 @@ def random_ipv4_string() -> str: Random IPv4 address. """ return ".".join(str(random.randint(0, 255)) for _ in range(4)) - - -def handle_http_request(request: httpx.Request, raise_error: bool = False) -> httpx.Response: - """ - Handler for a mock HTTP request. Forwards it to the appropriate handler based on the requested URL. - - Parameters - ---------- - request : httpx.Request - Mock request. - raise_error : bool, default False - Flag if to return an 404 status code response. - - Returns - ------- - response : httpx.Response - Generated mock request. - """ - url = str(request.url) - if url.startswith(str(settings.OPA_URI)): - request_body: Dict[str, str] = eval(request.content.decode("utf-8"))["input"] - return auth_handle_request(body=request_body) - elif url.startswith(str(settings.SLURM_ENDPOINT)): - return slurm_handle_request(request.method) - elif raise_error: - return httpx.Response(status_code=status.HTTP_404_NOT_FOUND, json={}) - return httpx.Response( - status_code=status.HTTP_200_OK, - json={ - # When checking if a file exists in a git repository, the GitHub API expects this in a response - "download_url": "https://example.com" - }, - ) diff --git a/mako_templates/nextflow_command.template b/mako_templates/nextflow_command.template index 9c07e2c..7804b17 100644 --- a/mako_templates/nextflow_command.template +++ b/mako_templates/nextflow_command.template @@ -11,7 +11,7 @@ ${nx_bin} run ${repo.repo_url} \ -with-timeline s3://${report_output_bucket}/timeline-${execution_id.hex}.html \ % endif % if workflow_entrypoint is not None: --entry workflow_entrypoint \ +-entry ${workflow_entrypoint} \ % endif % if configuration is not None: -c ${configuration} \ -- GitLab