diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 51be2846e3cf0c92f204586eca2e1ae2aec9202e..9b2bf6ac8d8d6a02e9785ce6c35ef7f348c59298 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,11 +21,11 @@ repos: files: app args: [--check] - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: 'v0.1.7' + rev: 'v0.1.8' hooks: - id: ruff - repo: https://github.com/PyCQA/isort - rev: 5.13.1 + rev: 5.13.2 hooks: - id: isort files: app diff --git a/app/api/dependencies.py b/app/api/dependencies.py index 0b491edef86419832e372889f9f1db10300e6ba7..c177dd83393eb15df007edd0063f8c7ff1f9666c 100644 --- a/app/api/dependencies.py +++ b/app/api/dependencies.py @@ -15,7 +15,7 @@ from app.core.config import settings from app.core.security import decode_token, request_authorization from app.crud import CRUDResource, CRUDUser from app.schemas.security import JWT, AuthzRequest, AuthzResponse -from app.slurm.rest_client import SlurmClient +from app.slurm.rest_client import SlurmClient as _SlurmClient from app.utils.otlp import start_as_current_span_async if TYPE_CHECKING: @@ -29,6 +29,19 @@ tracer = trace.get_tracer_provider().get_tracer(__name__) async def get_s3_resource(request: Request) -> S3ServiceResource: # pragma: no cover + """ + Get an async S3 service with an open connection. + + Parameters + ---------- + request : fastapi.Request + Request object from FastAPI where the client is attached to. + + Returns + ------- + client : types_aiobotocore_s3.service_resource.S3ServiceResource + Async S3 resource with open connection + """ return request.app.s3_resource @@ -56,17 +69,45 @@ DBSession = Annotated[AsyncSession, Depends(get_db)] async def get_httpx_client(request: Request) -> AsyncClient: # pragma: no cover - # Fetch open http client from the app + """ + Get an async http client with an open connection. + + Parameters + ---------- + request : fastapi.Request + Request object from FastAPI where the client is attached to. + + Returns + ------- + client : httpx.AsyncClient + Async http client with open connection + """ return request.app.requests_client HTTPClient = Annotated[AsyncClient, Depends(get_httpx_client)] -def get_slurm_client(client: HTTPClient) -> SlurmClient: +def get_slurm_client(client: HTTPClient) -> _SlurmClient: + """ + Get an async slurm client. + + Parameters + ---------- + client : httpx.AsyncClient + Async http client to inject into the slurm client + + Returns + ------- + slurm_client : app.slurm.rest_client.SlurmClient + Slurm client with an open connection. + """ return SlurmClient(client=client) +SlurmClient = Annotated[_SlurmClient, Depends(get_slurm_client)] + + def get_decode_jwt_function() -> Callable[[str], Dict[str, str]]: # pragma: no cover """ Get function to decode and verify the JWT.
diff --git a/app/api/endpoints/resource_version.py b/app/api/endpoints/resource_version.py index 3db5b21614020acd13e38ed9488f6194aace81ce..ce923649aaa8ece90cb5a38bd3c22650dcf764e6 100644 --- a/app/api/endpoints/resource_version.py +++ b/app/api/endpoints/resource_version.py @@ -1,7 +1,7 @@ from typing import Annotated, Any, Awaitable, Callable, List, Optional from clowmdb.models import ResourceVersion -from fastapi import APIRouter, Depends, Query, status +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, status from opentelemetry import trace from app.api.dependencies import ( @@ -10,10 +10,25 @@ from app.api.dependencies import ( CurrentResourceVersion, CurrentUser, DBSession, + S3Resource, + SlurmClient, +) +from app.api.resource_cluster_utils import ( + delete_cluster_resource_version, + set_cluster_resource_version_latest, + synchronize_cluster_resource, +) +from app.api.resource_s3_utils import ( + add_s3_resource_version_info, + delete_s3_resource_version, + get_s3_resource_version_obj, + give_permission_to_s3_resource_version, + remove_permission_to_s3_resource_version, ) from app.crud import CRUDResourceVersion +from app.schemas.resource import S3ResourceVersionInfo from app.schemas.resource_version import ResourceVersionIn, ResourceVersionOut -from app.utils.otlp import start_as_current_span_async +from app.utils import start_as_current_span_async router = APIRouter(prefix="/resources/{rid}/versions", tags=["ResourceVersion"]) resource_authorization = AuthorizationDependency(resource="resource") @@ -77,6 +92,8 @@ async def request_resource_version( resource_version_in: ResourceVersionIn, current_user: CurrentUser, db: DBSession, + s3: S3Resource, + background_tasks: BackgroundTasks, ) -> ResourceVersionOut: """ Request a new resource version. @@ -95,6 +112,10 @@ async def request_resource_version( Current user. Dependency injection. db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. Dependency Injection. + background_tasks : fastapi.BackgroundTasks + Entrypoint for new BackgroundTasks. Provided by FastAPI. + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource + S3 Service to perform operations on buckets. Dependency Injection. """ current_span = trace.get_current_span() current_span.set_attributes( @@ -107,7 +128,21 @@ async def request_resource_version( resource_version = await CRUDResourceVersion.create( db, resource_id=resource.resource_id, release=resource_version_in.release ) - return ResourceVersionOut.from_db_resource_version(resource_version) + resource_version_out = ResourceVersionOut.from_db_resource_version(resource_version) + background_tasks.add_task( + give_permission_to_s3_resource_version, + s3=s3, + resource_version=resource_version_out, + maintainer_id=current_user.uid, + ) + background_tasks.add_task( + add_s3_resource_version_info, + s3=s3, + s3_resource_version_info=S3ResourceVersionInfo.from_models( + resource=resource, resource_version=resource.versions[0], maintainer=current_user + ), + ) + return resource_version_out @router.get("/{rvid}", summary="Get version of a resource") @@ -157,6 +192,8 @@ async def request_resource_version_sync( resource_version: CurrentResourceVersion, current_user: CurrentUser, db: DBSession, + s3: S3Resource, + background_tasks: BackgroundTasks, ) -> ResourceVersionOut: """ Request the synchronization of a resource version to the cluster. 
@@ -175,6 +212,10 @@ async def request_resource_version_sync( Resource Version associated with the ID in the path. Dependency Injection. db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. Dependency Injection. + background_tasks : fastapi.BackgroundTasks + Entrypoint for new BackgroundTasks. Provided by FastAPI. + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource + S3 Service to perform operations on buckets. Dependency Injection. """ trace.get_current_span().set_attributes( {"resource_id": str(resource.resource_id), "resource_version_id": str(resource_version.resource_version_id)} @@ -183,11 +224,30 @@ async def request_resource_version_sync( if current_user.uid != resource.maintainer_id: rbac_operation = "request_sync_any" await authorization(rbac_operation) + if resource_version.status not in [ + ResourceVersion.Status.RESOURCE_REQUESTED, + ResourceVersion.Status.CLUSTER_DELETED, + ]: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Can't request sync for resource version with status {resource_version.status.name}", + ) + resource_version_out = ResourceVersionOut.from_db_resource_version(resource_version) + if await get_s3_resource_version_obj(s3, ResourceVersionOut.from_db_resource_version(resource_version)) is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Missing resource at S3 path {resource_version_out.s3_path}", + ) await CRUDResourceVersion.update_status( db, resource_version_id=resource_version.resource_version_id, status=ResourceVersion.Status.SYNC_REQUESTED ) - resource_version.status = ResourceVersion.Status.SYNC_REQUESTED - return ResourceVersionOut.from_db_resource_version(resource_version) + resource_version_out.status = ResourceVersion.Status.SYNC_REQUESTED + background_tasks.add_task( + remove_permission_to_s3_resource_version, + s3=s3, + resource_version=resource_version_out, + ) + return resource_version_out @router.put("/{rvid}/sync", summary="Synchronize resource version with cluster") @@ -197,6 +257,9 @@ async def resource_version_sync( resource: CurrentResource, resource_version: CurrentResourceVersion, db: DBSession, + s3: S3Resource, + background_tasks: BackgroundTasks, + slurm_client: SlurmClient, ) -> ResourceVersionOut: """ Synchronize the resource version to the cluster. @@ -213,22 +276,51 @@ async def resource_version_sync( Resource Version associated with the ID in the path. Dependency Injection. db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. Dependency Injection. + slurm_client : app.slurm.rest_client.SlurmClient + Slurm client with an open connection. Dependency Injection + background_tasks : fastapi.BackgroundTasks + Entrypoint for new BackgroundTasks. Provided by FastAPI. + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource + S3 Service to perform operations on buckets. Dependency Injection. 
""" trace.get_current_span().set_attributes( {"resource_id": str(resource.resource_id), "resource_version_id": str(resource_version.resource_version_id)} ) await authorization("sync") + if resource_version.status not in [ + ResourceVersion.Status.SYNC_REQUESTED, + ResourceVersion.Status.CLUSTER_DELETED, + ]: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Can't sync resource version with status {resource_version.status.name}", + ) + resource_version_out = ResourceVersionOut.from_db_resource_version(resource_version) + if await get_s3_resource_version_obj(s3, ResourceVersionOut.from_db_resource_version(resource_version)) is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Missing resource at S3 path {resource_version_out.s3_path}", + ) await CRUDResourceVersion.update_status( db, resource_version_id=resource_version.resource_version_id, status=ResourceVersion.Status.SYNCHRONIZING ) - resource_version.status = ResourceVersion.Status.SYNCHRONIZING - return ResourceVersionOut.from_db_resource_version(resource_version) + resource_version_out.status = ResourceVersion.Status.SYNCHRONIZING + + background_tasks.add_task( + synchronize_cluster_resource, db=db, slurm_client=slurm_client, resource_version=resource_version_out + ) + return resource_version_out @router.put("/{rvid}/latest", summary="Set resource version to latest") @start_as_current_span_async("api_resource_version_set_latest", tracer=tracer) async def resource_version_latest( - authorization: Authorization, resource: CurrentResource, resource_version: CurrentResourceVersion + authorization: Authorization, + db: DBSession, + resource: CurrentResource, + resource_version: CurrentResourceVersion, + background_tasks: BackgroundTasks, + slurm_client: SlurmClient, ) -> ResourceVersionOut: """ Set the resource version as the latest version. @@ -239,16 +331,31 @@ async def resource_version_latest( ---------- authorization : Callable[[str], Awaitable[Any]] Async function to ask the auth service for authorization. Dependency Injection. + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. Dependency Injection. resource : clowmdb.models.Resource Resource associated with the ID in the path. Dependency Injection. resource_version : clowmdb.models.ResourceVersion Resource Version associated with the ID in the path. Dependency Injection. + slurm_client : app.slurm.rest_client.SlurmClient + Slurm client with an open connection. Dependency Injection + background_tasks : fastapi.BackgroundTasks + Entrypoint for new BackgroundTasks. Provided by FastAPI. 
""" trace.get_current_span().set_attributes( {"resource_id": str(resource.resource_id), "resource_version_id": str(resource_version.resource_version_id)} ) await authorization("set_latest") - return resource_version + if resource_version.status != ResourceVersion.Status.SYNCHRONIZED: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Can't set resource version to {ResourceVersion.Status.LATEST.name} with status {resource_version.status.name}", + ) + resource_version_out = ResourceVersionOut.from_db_resource_version(resource_version) + background_tasks.add_task( + set_cluster_resource_version_latest, db=db, slurm_client=slurm_client, resource_version=resource_version_out + ) + return resource_version_out @router.delete("/{rvid}/cluster", summary="Delete resource version on cluster") @@ -258,6 +365,8 @@ async def delete_resource_version_cluster( resource: CurrentResource, resource_version: CurrentResourceVersion, db: DBSession, + background_tasks: BackgroundTasks, + slurm_client: SlurmClient, ) -> ResourceVersionOut: """ Delete the resource version on the cluster. @@ -274,6 +383,10 @@ async def delete_resource_version_cluster( Resource Version associated with the ID in the path. Dependency Injection. db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. Dependency Injection. + slurm_client : app.slurm.rest_client.SlurmClient + Slurm client with an open connection. Dependency Injection + background_tasks : fastapi.BackgroundTasks + Entrypoint for new BackgroundTasks. Provided by FastAPI. """ trace.get_current_span().set_attributes( {"resource_id": str(resource.resource_id), "resource_version_id": str(resource_version.resource_version_id)} @@ -283,7 +396,11 @@ async def delete_resource_version_cluster( db, resource_version_id=resource_version.resource_version_id, status=ResourceVersion.Status.CLUSTER_DELETED ) resource_version.status = ResourceVersion.Status.CLUSTER_DELETED - return ResourceVersionOut.from_db_resource_version(resource_version) + resource_version_out = ResourceVersionOut.from_db_resource_version(resource_version) + background_tasks.add_task( + delete_cluster_resource_version, slurm_client=slurm_client, resource_version=resource_version_out + ) + return resource_version_out @router.delete("/{rvid}/s3", summary="Delete resource version in S3") @@ -293,6 +410,8 @@ async def delete_resource_version_s3( resource: CurrentResource, resource_version: CurrentResourceVersion, db: DBSession, + s3: S3Resource, + background_tasks: BackgroundTasks, ) -> ResourceVersionOut: """ Delete the resource version in the S3 bucket. @@ -309,6 +428,10 @@ async def delete_resource_version_s3( Resource Version associated with the ID in the path. Dependency Injection. db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. Dependency Injection. + background_tasks : fastapi.BackgroundTasks + Entrypoint for new BackgroundTasks. Provided by FastAPI. + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource + S3 Service to perform operations on buckets. Dependency Injection. 
""" trace.get_current_span().set_attributes( {"resource_id": str(resource.resource_id), "resource_version_id": str(resource_version.resource_version_id)} @@ -318,4 +441,16 @@ async def delete_resource_version_s3( ) await authorization("delete_s3") resource_version.status = ResourceVersion.Status.S3_DELETED + background_tasks.add_task( + delete_s3_resource_version, + s3=s3, + resource_id=resource.resource_id, + resource_version_id=resource_version.resource_version_id, + ) + background_tasks.add_task( + remove_permission_to_s3_resource_version, + s3=s3, + resource_version=ResourceVersionOut.from_db_resource_version(resource_version), + ) + return ResourceVersionOut.from_db_resource_version(resource_version) diff --git a/app/api/endpoints/resources.py b/app/api/endpoints/resources.py index c37dbc381f806d161e5625c2dad7977ba9cc4ba6..615566e639f644401b0307b12464a6ffb777b05a 100644 --- a/app/api/endpoints/resources.py +++ b/app/api/endpoints/resources.py @@ -1,12 +1,28 @@ from typing import Annotated, Any, Awaitable, Callable, List, Optional from clowmdb.models import ResourceVersion -from fastapi import APIRouter, Depends, Query, status +from fastapi import APIRouter, BackgroundTasks, Depends, Query, status from opentelemetry import trace -from app.api.dependencies import AuthorizationDependency, CurrentResource, CurrentUser, DBSession +from app.api.dependencies import ( + AuthorizationDependency, + CurrentResource, + CurrentUser, + DBSession, + S3Resource, + SlurmClient, +) +from app.api.resource_cluster_utils import delete_cluster_resource +from app.api.resource_s3_utils import ( + add_s3_resource_version_info, + delete_s3_bucket_policy_stmt, + delete_s3_resource, + give_permission_to_s3_resource, + give_permission_to_s3_resource_version, +) +from app.core.config import settings from app.crud import CRUDResource -from app.schemas.resource import ResourceIn, ResourceOut +from app.schemas.resource import ResourceIn, ResourceOut, S3ResourceVersionInfo from app.utils.otlp import start_as_current_span_async router = APIRouter(prefix="/resources", tags=["Resource"]) @@ -88,8 +104,10 @@ async def list_resources( async def request_resource( authorization: Authorization, current_user: CurrentUser, - resource: ResourceIn, + resource_in: ResourceIn, db: DBSession, + s3: S3Resource, + background_tasks: BackgroundTasks, ) -> ResourceOut: """ Request a new resources. @@ -102,16 +120,36 @@ async def request_resource( Async function to ask the auth service for authorization. Dependency Injection. current_user : clowmdb.models.User Current user. Dependency injection. - resource : app.schemas.resource.ResourceIn + resource_in : app.schemas.resource.ResourceIn Data about the new resource. HTTP Body. db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. Dependency Injection. + background_tasks : fastapi.BackgroundTasks + Entrypoint for new BackgroundTasks. Provided by FastAPI. + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource + S3 Service to perform operations on buckets. Dependency Injection. 
+ """ current_span = trace.get_current_span() - current_span.set_attribute("resource_in", resource.model_dump_json(indent=2)) + current_span.set_attribute("resource_in", resource_in.model_dump_json(indent=2)) await authorization("create") - resource = await CRUDResource.create(db, resource_in=resource, maintainer_id=current_user.uid) - return ResourceOut.from_db_resource(db_resource=resource) + resource = await CRUDResource.create(db, resource_in=resource_in, maintainer_id=current_user.uid) + resource_out = ResourceOut.from_db_resource(db_resource=resource) + background_tasks.add_task(give_permission_to_s3_resource, s3=s3, resource=resource_out) + background_tasks.add_task( + give_permission_to_s3_resource_version, + s3=s3, + resource_version=resource_out.versions[0], + maintainer_id=resource_out.maintainer_id, + ) + background_tasks.add_task( + add_s3_resource_version_info, + s3=s3, + s3_resource_version_info=S3ResourceVersionInfo.from_models( + resource=resource, resource_version=resource.versions[0], maintainer=current_user + ), + ) + return resource_out @router.get("/{rid}", summary="Get a resource") @@ -164,6 +202,9 @@ async def delete_resource( authorization: Authorization, resource: CurrentResource, db: DBSession, + s3: S3Resource, + background_tasks: BackgroundTasks, + slurm_client: SlurmClient, ) -> None: """ Delete a resources. @@ -178,7 +219,22 @@ async def delete_resource( Resource associated with the ID in the path. Dependency Injection. db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. Dependency Injection. + background_tasks : fastapi.BackgroundTasks + Entrypoint for new BackgroundTasks. Provided by FastAPI. + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource + S3 Service to perform operations on buckets. Dependency Injection. + slurm_client : app.slurm.rest_client.SlurmClient + Slurm client with an open connection. 
Dependency Injection """ trace.get_current_span().set_attribute("resource_id", str(resource.resource_id)) await authorization("delete") await CRUDResource.delete(db, resource_id=resource.resource_id) + background_tasks.add_task( + delete_s3_bucket_policy_stmt, + s3=s3, + bucket_name=settings.RESOURCE_BUCKET, + sid=[str(resource.resource_id)] + + [str(resource_version.resource_version_id) for resource_version in resource.versions], + ) + background_tasks.add_task(delete_s3_resource, s3=s3, resource_id=resource.resource_id) + background_tasks.add_task(delete_cluster_resource, slurm_client=slurm_client, resource_id=resource.resource_id) diff --git a/app/api/resource_cluster_utils.py b/app/api/resource_cluster_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6bc92257370584d643bd42539fce0f9f33955939 --- /dev/null +++ b/app/api/resource_cluster_utils.py @@ -0,0 +1,289 @@ +from asyncio import sleep as async_sleep +from pathlib import Path +from typing import Optional +from uuid import UUID + +from clowmdb.models import ResourceVersion +from httpx import HTTPError +from mako.template import Template +from opentelemetry import trace +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.config import settings +from app.crud import CRUDResourceVersion +from app.schemas.resource_version import ResourceVersionOut, resource_dir_name, resource_version_dir_name +from app.slurm import SlurmClient, SlurmJobSubmission +from app.utils import AsyncJob, Job +from app.utils.backoff_strategy import ExponentialBackoff + +synchronization_script_template = Template(filename="app/mako_templates/synchronize_resource_version.sh.tmpl") +set_latest_script_template = Template(filename="app/mako_templates/set_latest_resource_version.sh.tmpl") +delete_resource_script_template = Template(filename="app/mako_templates/delete_resource.sh.tmpl") +delete_resource_version_script_template = Template(filename="app/mako_templates/delete_resource_version.sh.tmpl") +tracer = trace.get_tracer_provider().get_tracer(__name__) + + +async def synchronize_cluster_resource( + db: AsyncSession, slurm_client: SlurmClient, resource_version: ResourceVersionOut +) -> None: + """ + Synchronize a resource to the cluster + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + slurm_client : app.slurm.rest_client.SlurmClient + Slurm Rest Client to communicate with Slurm cluster. + resource_version : app.schemas.resource_version.ResourceVersionOut + Resource version schema to synchronize. 
+ """ + + synchronization_script = synchronization_script_template.render( + s3_url=settings.OBJECT_GATEWAY_URI, + resource_version_s3_folder=f"{settings.RESOURCE_BUCKET}/{resource_version_dir_name(resource_version.resource_id, resource_version.resource_version_id)}", + ) + failed_job = AsyncJob( + CRUDResourceVersion.update_status, + db=db, + resource_version_id=resource_version.resource_version_id, + status=ResourceVersion.Status.SYNC_REQUESTED, + resource_id=None, + ) + try: + job_submission = SlurmJobSubmission( + script=synchronization_script.strip(), + job={ + "current_working_directory": settings.SLURM_WORKING_DIRECTORY, + "name": f"Synchronize {str(resource_version.resource_version_id)}", + "requeue": False, + "standard_output": str( + Path(settings.SLURM_WORKING_DIRECTORY) + / f"slurm_synchronize-{str(resource_version.resource_version_id)}.out" + ), + "environment": { + "AWS_ACCESS_KEY_ID": settings.BUCKET_CEPH_ACCESS_KEY, + "AWS_SECRET_ACCESS_KEY": settings.BUCKET_CEPH_SECRET_KEY, + "RESOURCE_VERSION_PATH": str( + Path(settings.RESOURCE_CLUSTER_PATH) + / resource_version_dir_name(resource_version.resource_id, resource_version.resource_version_id) + ), + }, + }, + ) + # Try to start the job on the slurm cluster + slurm_job_id = await slurm_client.submit_job(job_submission=job_submission) + await _monitor_proper_job_execution( + slurm_client=slurm_client, + slurm_job_id=slurm_job_id, + success_job=AsyncJob( + CRUDResourceVersion.update_status, + db=db, + resource_version_id=resource_version.resource_version_id, + status=ResourceVersion.Status.SYNCHRONIZED, + resource_id=None, + ), + failed_job=failed_job, + ) + except (HTTPError, KeyError): # pragma: no cover + await failed_job() + + +async def delete_cluster_resource( + slurm_client: SlurmClient, + resource_id: UUID, +) -> None: + """ + Synchronize a resource to the cluster + + Parameters + ---------- + slurm_client : app.slurm.rest_client.SlurmClient + Slurm Rest Client to communicate with Slurm cluster. + resource_id : uuid.UUID + ID of the resource to delete + """ + + delete_script = delete_resource_script_template.render() + + try: + job_submission = SlurmJobSubmission( + script=delete_script.strip(), + job={ + "current_working_directory": settings.SLURM_WORKING_DIRECTORY, + "name": f"Delete {str(resource_id)}", + "requeue": False, + "standard_output": str(Path(settings.SLURM_WORKING_DIRECTORY) / f"slurm_delete-{str(resource_id)}.out"), + "environment": { + "RESOURCE_PATH": str(Path(settings.RESOURCE_CLUSTER_PATH) / resource_dir_name(resource_id)), + }, + }, + ) + # Try to start the job on the slurm cluster + slurm_job_id = await slurm_client.submit_job(job_submission=job_submission) + await _monitor_proper_job_execution(slurm_client=slurm_client, slurm_job_id=slurm_job_id) + except (HTTPError, KeyError): # pragma: no cover + pass + + +async def set_cluster_resource_version_latest( + db: AsyncSession, slurm_client: SlurmClient, resource_version: ResourceVersionOut +) -> None: + """ + Synchronize a resource to the cluster + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + slurm_client : app.slurm.rest_client.SlurmClient + Slurm Rest Client to communicate with Slurm cluster. + resource_version : app.schemas.resource_version.ResourceVersionOut + Resource version schema to synchronize. 
+ """ + + set_latest_script = set_latest_script_template.render() + + try: + job_submission = SlurmJobSubmission( + script=set_latest_script.strip(), + job={ + "current_working_directory": settings.SLURM_WORKING_DIRECTORY, + "name": f"SET Latest {str(resource_version.resource_version_id)}", + "requeue": False, + "standard_output": str( + Path(settings.SLURM_WORKING_DIRECTORY) + / f"slurm_set-latest-{str(resource_version.resource_version_id)}.out" + ), + "environment": { + "RESOURCE_VERSION_PATH": str( + Path(settings.RESOURCE_CLUSTER_PATH) + / resource_version_dir_name( + resource_id=resource_version.resource_id, + resource_version_id=resource_version.resource_version_id, + ) + ), + "LATEST_VERSION_PATH": str( + Path(settings.RESOURCE_CLUSTER_PATH) + / resource_dir_name( + resource_id=resource_version.resource_id, + ) + / "latest" + ), + }, + }, + ) + # Try to start the job on the slurm cluster + slurm_job_id = await slurm_client.submit_job(job_submission=job_submission) + + await _monitor_proper_job_execution( + slurm_client=slurm_client, + slurm_job_id=slurm_job_id, + success_job=AsyncJob( + CRUDResourceVersion.update_status, + db=db, + status=ResourceVersion.Status.LATEST, + resource_version_id=resource_version.resource_version_id, + resource_id=resource_version.resource_id, + ), + ) + except (HTTPError, KeyError): # pragma: no cover + pass + + +async def delete_cluster_resource_version( + slurm_client: SlurmClient, + resource_version: ResourceVersionOut, +) -> None: + """ + Synchronize a resource to the cluster + + Parameters + ---------- + slurm_client : app.slurm.rest_client.SlurmClient + Slurm Rest Client to communicate with Slurm cluster. + resource_version : app.schemas.resource_version.ResourceVersionOut + Resource version schema to delete. + """ + + delete_script = delete_resource_version_script_template.render() + + try: + job_submission = SlurmJobSubmission( + script=delete_script.strip(), + job={ + "current_working_directory": settings.SLURM_WORKING_DIRECTORY, + "name": f"Delete {str(resource_version.resource_version_id)}", + "requeue": False, + "standard_output": str( + Path(settings.SLURM_WORKING_DIRECTORY) + / f"slurm_delete-version-{str(resource_version.resource_version_id)}.out" + ), + "environment": { + "RESOURCE_VERSION_PATH": str( + Path(settings.RESOURCE_CLUSTER_PATH) + / resource_version_dir_name( + resource_id=resource_version.resource_id, + resource_version_id=resource_version.resource_version_id, + ) + ), + "LATEST_VERSION_PATH": str( + Path(settings.RESOURCE_CLUSTER_PATH) + / resource_dir_name( + resource_id=resource_version.resource_id, + ) + / "latest" + ), + }, + }, + ) + # Try to start the job on the slurm cluster + slurm_job_id = await slurm_client.submit_job(job_submission=job_submission) + await _monitor_proper_job_execution(slurm_client=slurm_client, slurm_job_id=slurm_job_id) + except (HTTPError, KeyError): # pragma: no cover + pass + + +async def _monitor_proper_job_execution( + slurm_client: SlurmClient, slurm_job_id: int, success_job: Optional[Job] = None, failed_job: Optional[Job] = None +) -> None: + """ + Check in an interval based on a backoff strategy if the slurm job is still running + the workflow execution in the database is not marked as finished. + + Parameters + ---------- + slurm_client : app.slurm.rest_client.SlurmClient + Slurm Rest Client to communicate with Slurm cluster. 
+ slurm_job_id : int + ID of the slurm job to monitor + success_job : app.utils.Job | None + Function to execute after the slurm job was successful + failed_job : app.utils.Job | None + Function to execute after the slurm job was unsuccessful + """ + # exponential backoff to 5 minutes + sleep_generator = ExponentialBackoff(max_value=300) + + # Closure around monitor_job code + async def monitor_job() -> None: + with tracer.start_span("monitor_job") as span: + span.set_attributes({"slurm_job_id": slurm_job_id}) + job_state = await slurm_client.job_state(slurm_job_id) + if job_state != SlurmClient.JobState.RUNNING: + if job_state == SlurmClient.JobState.SUCCESS and success_job is not None: + if success_job.is_async: + await success_job() + else: # pragma: no cover + success_job() + elif job_state == SlurmClient.JobState.ERROR and failed_job is not None: + if failed_job.is_async: + await failed_job() + else: # pragma: no cover + failed_job() + sleep_generator.close() + + await monitor_job() + for sleep_time in sleep_generator: # pragma: no cover + await async_sleep(sleep_time) + await monitor_job() diff --git a/app/api/resource_s3_utils.py b/app/api/resource_s3_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4e274d2adcee8d3fd3f5e409a69855aab11a7240 --- /dev/null +++ b/app/api/resource_s3_utils.py @@ -0,0 +1,173 @@ +from io import BytesIO +from typing import TYPE_CHECKING, Any, Dict, Optional +from uuid import UUID + +from opentelemetry import trace + +from app.core.config import settings +from app.s3.s3_resource import ( + add_s3_bucket_policy_stmt, + delete_s3_bucket_policy_stmt, + delete_s3_objects, + get_s3_object, + upload_obj, +) +from app.schemas.resource import ResourceOut, S3ResourceVersionInfo +from app.schemas.resource_version import ( + ResourceVersionOut, + resource_dir_name, + resource_version_dir_name, + resource_version_key, +) + +if TYPE_CHECKING: + from types_aiobotocore_s3.service_resource import ObjectSummary, S3ServiceResource +else: + S3ServiceResource = object + ObjectSummary = object + + +tracer = trace.get_tracer_provider().get_tracer(__name__) + + +async def give_permission_to_s3_resource(s3: S3ServiceResource, resource: ResourceOut) -> None: + """ + Give the maintainer permissions to list the S3 objects of his resource. + + Parameters + ---------- + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource + S3 Service to perform operations on buckets. + resource : app.schemas.resource.ResourceOut + Information about the resource. + """ + policy: Dict[str, Any] = { + "Sid": str(resource.resource_id), + "Effect": "Allow", + "Principal": {"AWS": f"arn:aws:iam:::user/{resource.maintainer_id}"}, + "Resource": f"arn:aws:s3:::{settings.RESOURCE_BUCKET}", + "Action": ["s3:ListBucket"], + "Condition": {"StringLike": {"s3:prefix": resource_dir_name(resource.resource_id) + "/*"}}, + } + await add_s3_bucket_policy_stmt(s3=s3, bucket_name=settings.RESOURCE_BUCKET, stmt=policy) + + +async def give_permission_to_s3_resource_version( + s3: S3ServiceResource, resource_version: ResourceVersionOut, maintainer_id: str +) -> None: + """ + Give the maintainer permissions to upload the resource to the appropriate S3 key. + + Parameters + ---------- + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource + S3 Service to perform operations on buckets. + resource_version : app.schemas.resource_version.ResourceVersionOut + Information about the resource version. 
+ maintainer_id : str + ID of the maintainer + """ + policy: Dict[str, Any] = { + "Sid": str(resource_version.resource_version_id), + "Effect": "Allow", + "Principal": {"AWS": f"arn:aws:iam:::user/{maintainer_id}"}, + "Resource": f"arn:aws:s3:::{resource_version.s3_path[5:]}", + "Action": ["s3:DeleteObject", "s3:PutObject"], + } + await add_s3_bucket_policy_stmt(s3=s3, bucket_name=settings.RESOURCE_BUCKET, stmt=policy) + + +async def add_s3_resource_version_info(s3: S3ServiceResource, s3_resource_version_info: S3ResourceVersionInfo) -> None: + """ + Upload the resource version information to S3 for documentation + + Parameters + ---------- + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource + S3 Service to perform operations on buckets. + s3_resource_version_info : app.schemas.resource.S3ResourceVersionInfo + Resource information object that will be uploaded to S3. + """ + buf = BytesIO(s3_resource_version_info.model_dump_json(indent=2).encode("utf-8")) + await upload_obj( + s3=s3, + bucket_name=settings.RESOURCE_BUCKET, + key=s3_resource_version_info.s3_path(), + obj_stream=buf, + ExtraArgs={"ContentType": "application/json"}, + ) + + +async def remove_permission_to_s3_resource_version(s3: S3ServiceResource, resource_version: ResourceVersionOut) -> None: + """ + Remove the permission of the maintainer to upload a resource to S3. + + Parameters + ---------- + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource + S3 Service to perform operations on buckets. + resource_version : app.schemas.resource_version.ResourceVersionOut + Information about the resource version. + """ + await delete_s3_bucket_policy_stmt( + s3, bucket_name=settings.RESOURCE_BUCKET, sid=str(resource_version.resource_version_id) + ) + + +async def delete_s3_resource(s3: S3ServiceResource, resource_id: UUID) -> None: + """ + Delete all objects related to a resource in S3. + + Parameters + ---------- + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource + S3 Service to perform operations on buckets. + resource_id : uuid.UUID + ID of the resource + """ + await delete_s3_objects(s3, bucket_name=settings.RESOURCE_BUCKET, prefix=resource_dir_name(resource_id) + "/") + + +async def delete_s3_resource_version(s3: S3ServiceResource, resource_id: UUID, resource_version_id: UUID) -> None: + """ + Delete all objects related to a resource version in S3. + + Parameters + ---------- + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource + S3 Service to perform operations on buckets. + resource_id : uuid.UUID + ID of the resource + resource_version_id : uuid.UUID + ID of the resource version + """ + await delete_s3_objects( + s3, bucket_name=settings.RESOURCE_BUCKET, prefix=resource_version_dir_name(resource_id, resource_version_id) + ) + + +async def get_s3_resource_version_obj( + s3: S3ServiceResource, resource_version: ResourceVersionOut +) -> Optional[ObjectSummary]: + """ + Get the object summary of a resource version from S3 if it exists. + + Parameters + ---------- + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource + S3 Service to perform operations on buckets. + resource_version : app.schemas.resource_version.ResourceVersionOut + Information about the resource version. + + Returns + ------- + obj : types_aiobotocore_s3.service_resource.ObjectSummary | None + The object summary of the resource version if it exists.
+ """ + return await get_s3_object( + s3=s3, + bucket_name=settings.RESOURCE_BUCKET, + key=resource_version_key( + resource_id=resource_version.resource_id, resource_version_id=resource_version.resource_version_id + ), + ) diff --git a/app/api/utils.py b/app/api/utils.py deleted file mode 100644 index f5ce50aca23b53cb83a0737365d05493f6541666..0000000000000000000000000000000000000000 --- a/app/api/utils.py +++ /dev/null @@ -1,75 +0,0 @@ -from asyncio import sleep as async_sleep - -from httpx import HTTPError -from mako.template import Template -from opentelemetry import trace -from sqlalchemy.ext.asyncio import AsyncSession - -from app.core.config import settings -from app.slurm import SlurmClient, SlurmJobSubmission -from app.utils.backoff_strategy import ExponentialBackoff - -synchronization_script_template = Template(filename="app/mako_templates/synchronize_resource_version.sh.tmpl") - -tracer = trace.get_tracer_provider().get_tracer(__name__) - - -async def synchronize_resource( - db: AsyncSession, - slurm_client: SlurmClient, -) -> None: - """ - Synchronize a resource to the cluster - - Parameters - ---------- - db : sqlalchemy.ext.asyncio.AsyncSession. - Async database session to perform query on. - slurm_client : app.slurm.rest_client.SlurmClient - Slurm Rest Client to communicate with Slurm cluster. - """ - - synchronization_script = synchronization_script_template.render() - - try: - job_submission = SlurmJobSubmission( - script=synchronization_script.strip(), - job={ - "current_working_directory": settings.SLURM_WORKING_DIRECTORY, - "name": "somename", - "requeue": False, - }, - ) - # Try to start the job on the slurm cluster - slurm_job_id = await slurm_client.submit_job(job_submission=job_submission) - await _monitor_proper_job_execution(db=db, slurm_client=slurm_client, slurm_job_id=slurm_job_id) - except (HTTPError, KeyError): - # Mark job as aborted when there is an error - pass - - -async def _monitor_proper_job_execution( - db: AsyncSession, slurm_client: SlurmClient, slurm_job_id: int -) -> None: # pragma: no cover - """ - Check in an interval based on a backoff strategy if the slurm job is still running - the workflow execution in the database is not marked as finished. - - Parameters - ---------- - db : sqlalchemy.ext.asyncio.AsyncSession. - Async database session to perform query on. - slurm_client : app.slurm.rest_client.SlurmClient - Slurm Rest Client to communicate with Slurm cluster. 
- slurm_job_id : int - ID of the slurm job to monitor - """ - # exponential to 5 minutes - sleep_generator = ExponentialBackoff(initial_delay=30, max_value=300) - for sleep_time in sleep_generator: - await async_sleep(sleep_time) - with tracer.start_span("monitor_job") as span: - span.set_attributes({"slurm_job_id": slurm_job_id}) - if await slurm_client.is_job_finished(slurm_job_id): - await db.close() # Reset connection - sleep_generator.close() diff --git a/app/crud/crud_resource_version.py b/app/crud/crud_resource_version.py index a7c3b02c49d2d550df02c39910e4920110c27244..f19c901b037937df6429e0a63f800579a867e408 100644 --- a/app/crud/crud_resource_version.py +++ b/app/crud/crud_resource_version.py @@ -38,7 +38,12 @@ class CRUDResourceVersion: return resource_version_db @staticmethod - async def update_status(db: AsyncSession, resource_version_id: UUID, status: ResourceVersion.Status) -> None: + async def update_status( + db: AsyncSession, + status: ResourceVersion.Status, + resource_version_id: UUID, + resource_id: Optional[UUID] = None, + ) -> None: """ Update the status of a resource version. @@ -47,20 +52,40 @@ class CRUDResourceVersion: db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. resource_version_id : uuid.UUID - Git commit git_commit_hash of the version. + ID of a resource version. + resource_id : uuid.UUID | None + ID of a resource. Must be set if `status` is LATEST. status : clowmdb.models.ResourceVersion.Status New status of the resource version """ - stmt = ( + if status == ResourceVersion.Status.LATEST: + if resource_id is None: + raise ValueError("Parameter 'resource_version_id' or 'resource_id' must not be None") + stmt1 = ( + update(ResourceVersion) + .values(status=ResourceVersion.Status.SYNCHRONIZED.name) + .where(ResourceVersion._resource_id == resource_id.bytes) + ) + with tracer.start_as_current_span( + "db_remove_latest_tag", + attributes={"status": status.name, "resource_id": str(resource_id), "sql_query": str(stmt1)}, + ): + await db.execute(stmt1) + + stmt2 = ( update(ResourceVersion) - .where(ResourceVersion._resource_version_id == resource_version_id.bytes) .values(status=status.name) + .where(ResourceVersion._resource_version_id == resource_version_id.bytes) ) with tracer.start_as_current_span( "db_update_resource_version_status", - attributes={"status": status.name, "resource_version_id": str(resource_version_id), "sql_query": str(stmt)}, + attributes={ + "status": status.name, + "resource_version_id": str(resource_version_id), + "sql_query": str(stmt2), + }, ): - await db.execute(stmt) + await db.execute(stmt2) await db.commit() @staticmethod @@ -77,7 +102,7 @@ class CRUDResourceVersion: resource_version_id : uuid.UUID ID of a resource version. resource_id : uuid.UUID | None - + ID of a resource Returns ------- diff --git a/app/main.py b/app/main.py index 92df34d900352c3e99545fbb35b4b0268b2062c5..538783032d5e4f1dc032be59d1d2963b6d78d1c1 100644 --- a/app/main.py +++ b/app/main.py @@ -21,7 +21,7 @@ from app.api.api import api_router from app.api.middleware.etagmiddleware import HashJSONResponse from app.api.miscellaneous_endpoints import miscellaneous_router from app.core.config import settings -from app.s3.s3_resource import botosession +from app.s3.s3_resource import boto_session description = """ This is the resource service from the CloWM Service. 
@@ -35,15 +35,14 @@ def custom_generate_unique_id(route: APIRoute) -> str: @asynccontextmanager async def lifespan(fastapi_app: FastAPI) -> AsyncGenerator[None, None]: # pragma: no cover # Create a http client once instead for every request and attach it to the app - async with AsyncClient() as client: + async with AsyncClient() as client, boto_session.resource( + service_name="s3", + endpoint_url=str(settings.OBJECT_GATEWAY_URI)[:-1], + verify=str(settings.OBJECT_GATEWAY_URI).startswith("https"), + ) as s3_resource: fastapi_app.requests_client = client # type: ignore[attr-defined] - async with botosession.resource( - service_name="s3", - endpoint_url=str(settings.OBJECT_GATEWAY_URI)[:-1], - verify=str(settings.OBJECT_GATEWAY_URI).startswith("https"), - ) as s3_resource: - fastapi_app.s3_resource = s3_resource # type: ignore[attr-defined] - yield + fastapi_app.s3_resource = s3_resource # type: ignore[attr-defined] + yield app = FastAPI( diff --git a/app/mako_templates/delete_resource.sh.tmpl b/app/mako_templates/delete_resource.sh.tmpl index a9bf588e2f88457fdf73ac7361ef1d596fb81453..d27fb0892cddb3170d5e20b23823bbd969616370 100644 --- a/app/mako_templates/delete_resource.sh.tmpl +++ b/app/mako_templates/delete_resource.sh.tmpl @@ -1 +1,9 @@ #!/bin/bash + +# Env variables +# RESOURCE_PATH: Path to the resource folder on the cluster + +if [ -d "$RESOURCE_PATH" ]; then + echo "Remove $RESOURCE_PATH" + rm -r "$RESOURCE_PATH" +fi diff --git a/app/mako_templates/delete_resource_version.sh.tmpl b/app/mako_templates/delete_resource_version.sh.tmpl index a9bf588e2f88457fdf73ac7361ef1d596fb81453..6dfbc4630e2ed1c3a2c99c9e023e866225ec0133 100644 --- a/app/mako_templates/delete_resource_version.sh.tmpl +++ b/app/mako_templates/delete_resource_version.sh.tmpl @@ -1 +1,18 @@ #!/bin/bash + +# Env variables +# RESOURCE_VERSION_PATH: Path to the resource folder on the cluster +# LATEST_VERSION_PATH: Path to the latest version symlink of the resource + +if [ -d "$RESOURCE_VERSION_PATH" ]; then + echo "Remove $RESOURCE_VERSION_PATH" + rm -r "$RESOURCE_VERSION_PATH" +fi + +# -L: True if file exists and is a symbolic link +# -e: True if file exists +# if symlink exists and links to a non-existing file +if [ -L "$LATEST_VERSION_PATH" ] && [ ! 
-e "$LATEST_VERSION_PATH" ]; then + echo "Remove latest link $LATEST_VERSION_PATH" + rm "$LATEST_VERSION_PATH" +fi diff --git a/app/mako_templates/set_latest_resource_version.sh.tmpl b/app/mako_templates/set_latest_resource_version.sh.tmpl index a9bf588e2f88457fdf73ac7361ef1d596fb81453..1b02e2f0121a635abe18b7c47be82d4d13e49b01 100644 --- a/app/mako_templates/set_latest_resource_version.sh.tmpl +++ b/app/mako_templates/set_latest_resource_version.sh.tmpl @@ -1 +1,14 @@ #!/bin/bash + +# Env variables +# RESOURCE_VERSION_PATH: Path to the resource folder on the cluster +# LATEST_VERSION_PATH: Path to the latest version symlink of the resource + +if [ -d "$RESOURCE_VERSION_PATH" ]; then + if [ -L "$LATEST_VERSION_PATH" ]; then + echo "Remove latest link $LATEST_VERSION_PATH" + rm "$LATEST_VERSION_PATH" + fi + echo "Set link from $RESOURCE_VERSION_PATH to $LATEST_VERSION_PATH" + ln -s "$RESOURCE_VERSION_PATH" "$LATEST_VERSION_PATH" +fi diff --git a/app/mako_templates/synchronize_resource_version.sh.tmpl b/app/mako_templates/synchronize_resource_version.sh.tmpl index a9bf588e2f88457fdf73ac7361ef1d596fb81453..1ea476d894fa0b377bd4320b595868dec615e3a1 100644 --- a/app/mako_templates/synchronize_resource_version.sh.tmpl +++ b/app/mako_templates/synchronize_resource_version.sh.tmpl @@ -1 +1,30 @@ -#!/bin/bash +#!/bin/bash -eu + +# Env variables +# RESOURCE_VERSION_PATH: Path to the resource folder on the cluster +# AWS_ACCESS_KEY_ID: S3 access key +# AWS_SECRET_ACCESS_KEY: S3 secret key +RESOURCE_PATH=$(dirname "$RESOURCE_VERSION_PATH") + +if [ ! -d "$RESOURCE_PATH" ]; then + echo "Create $RESOURCE_PATH" + mkdir -p "$RESOURCE_PATH" +fi + +if [ ! -d "$RESOURCE_VERSION_PATH" ]; then + docker run --rm \ + -u "$(id -u):$(id -g)" \ + -e S3_ENDPOINT_URL="${s3_url}" \ + -e AWS_ACCESS_KEY_ID="$AWS_ACCESS_KEY_ID" \ + -e AWS_SECRET_ACCESS_KEY="$AWS_SECRET_ACCESS_KEY" \ + -v "$RESOURCE_PATH":"$RESOURCE_PATH" \ + peakcom/s5cmd:v2.2.2 \ + cp --if-source-newer "s3://${resource_version_s3_folder}/*" "$RESOURCE_VERSION_PATH/" + + cd "$RESOURCE_VERSION_PATH" + gzip -t resource.tar.gz + echo "Extract" + tar -xvf resource.tar.gz + tar --compare --file=resource.tar.gz | awk '!/Mode/ && !/Uid/ && !/Gid/ && !/time/' + rm resource.tar.gz +fi diff --git a/app/s3/s3_resource.py b/app/s3/s3_resource.py index 5f02f913e12f0ff0b96d5803a3d0c6c18a729527..76328bae70cbcb9ae7b2386f4aac74fd9126d43e 100644 --- a/app/s3/s3_resource.py +++ b/app/s3/s3_resource.py @@ -1,4 +1,6 @@ -from typing import TYPE_CHECKING, Optional +import json +from io import BytesIO +from typing import TYPE_CHECKING, Any, Dict, List, Optional import aioboto3 from botocore.exceptions import ClientError @@ -16,37 +18,105 @@ else: tracer = trace.get_tracer_provider().get_tracer(__name__) -botosession = aioboto3.Session( +boto_session = aioboto3.Session( aws_access_key_id=settings.BUCKET_CEPH_ACCESS_KEY, aws_secret_access_key=settings.BUCKET_CEPH_SECRET_KEY, ) -async def get_s3_bucket_policy(s3: S3ServiceResource, bucket_name: str) -> Optional[str]: - with tracer.start_as_current_span("s3_get_bucket_policy", attributes={"bucket_name": bucket_name}) as span: - s3_policy = await s3.BucketPolicy(bucket_name=bucket_name) +async def load_s3_bucket_policy(s3_policy: BucketPolicy) -> Dict[str, Any]: + """ + Loads hte bucket policy form the BucketPolicy object. If there is no policy, an empty one will be generated. + + Parameters + ---------- + s3_policy : types_aiobotocore_s3.service_resource.BucketPolicy + BucketPolivy object to load the policy from. 
+ + Returns + ------- + policy : dict[str, Any] + The serialized bucket policy. + """ + with tracer.start_as_current_span("s3_load_bucket_policy", attributes={"bucket_name": s3_policy.bucket_name}): try: - policy = await s3_policy.policy - except ClientError: - return None - span.set_attribute("policy", policy) - return policy + return json.loads(await s3_policy.policy) + except ClientError: # pragma: no cover + return {"Version": "2012-10-17", "Statement": []} + +async def add_s3_bucket_policy_stmt(s3: S3ServiceResource, bucket_name: str, stmt: Dict[str, Any]) -> None: + """ + Add a Statement to a bucket policy. If it doesn't exist, then create a policy. -async def put_s3_bucket_policy(s3: S3ServiceResource, bucket_name: str, policy: str) -> None: + Parameters + ---------- + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource + S3 Service to perform operations on buckets. + bucket_name : str + Name of the bucket. + stmt : dict[str, Any] + Statement to add to the policy. + """ with tracer.start_as_current_span( - "s3_put_bucket_policy", attributes={"bucket_name": bucket_name, "policy": policy} + "s3_add_bucket_policy_statement", attributes={"bucket_name": bucket_name, "policy_stmt": json.dumps(stmt)} ): - bucket_policy = await s3.BucketPolicy(bucket_name=bucket_name) - await bucket_policy.put(Policy=policy) + s3_policy = await s3.BucketPolicy(bucket_name=bucket_name) + policy = await load_s3_bucket_policy(s3_policy) + policy["Statement"].append(stmt) + await s3_policy.put(Policy=json.dumps(policy)) -async def get_s3_bucket_object(s3: S3ServiceResource, bucket_name: str, key: str) -> Optional[ObjectSummary]: + +async def delete_s3_bucket_policy_stmt(s3: S3ServiceResource, bucket_name: str, sid: str | List[str]) -> None: + """ + Delete one or multiple Statement based on the Sid from a bucket policy. + + Parameters + ---------- + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource + S3 Service to perform operations on buckets. + bucket_name : str + Name of the bucket. + sid : str | list[str] + ID or IDs of the statement(s). + """ + with tracer.start_as_current_span( + "s3_delete_bucket_policy_statement", attributes={"bucket_name": bucket_name, "sid": sid} + ): + s3_policy = await s3.BucketPolicy(bucket_name=bucket_name) + policy = await load_s3_bucket_policy(s3_policy) + + if isinstance(sid, list): + policy["Statement"] = [stmt for stmt in policy["Statement"] if stmt["Sid"] not in sid] + else: + policy["Statement"] = [stmt for stmt in policy["Statement"] if stmt["Sid"] != sid] + await s3_policy.put(Policy=json.dumps(policy)) + + +async def get_s3_object(s3: S3ServiceResource, bucket_name: str, key: str) -> Optional[ObjectSummary]: + """ + Get the object summary from S3 if the object exists. + + Parameters + ---------- + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource + S3 Service to perform operations on buckets. + bucket_name : str + Name of the bucket. + key : str + Key of the objects. + + Returns + ------- + obj : types_aiobotocore_s3.service_resource.ObjectSummary | None + The object summary if the object exists. 
+ """ with tracer.start_as_current_span( "s3_get_object_meta_data", attributes={"bucket_name": bucket_name, "key": key} ) as span: - obj = await s3.ObjectSummary(bucket_name=bucket_name, key=key) try: + obj = await s3.ObjectSummary(bucket_name=bucket_name, key=key) await obj.load() except ClientError: return None @@ -54,3 +124,48 @@ async def get_s3_bucket_object(s3: S3ServiceResource, bucket_name: str, key: str {"size": await obj.size, "last_modified": (await obj.last_modified).isoformat(), "etag": await obj.e_tag} ) return obj + + +async def delete_s3_objects(s3: S3ServiceResource, bucket_name: str, prefix: str) -> None: + """ + Delete multiple S3 objects based on a prefix. + + Parameters + ---------- + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource + S3 Service to perform operations on buckets. + bucket_name : str + Name of the bucket. + prefix : str + Prefix of the keys that should get deleted. + """ + with tracer.start_as_current_span( + "s3_delete_objects", attributes={"bucket_name": bucket_name, "prefix": prefix} + ) as span: + bucket = await s3.Bucket(bucket_name) + deleted_objs = await bucket.objects.filter(Prefix=prefix).delete() + deleted_keys = [obj["Key"] for obj in deleted_objs[0]["Deleted"]] + span.set_attribute("keys", deleted_keys) + + +async def upload_obj(s3: S3ServiceResource, obj_stream: BytesIO, bucket_name: str, key: str, **kwargs: Any) -> None: + """ + Upload a object in a buffer to + Parameters + ---------- + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource + S3 Service to perform operations on buckets. + obj_stream : io.BytesIO + Stream from which the object content is read and uploaded. + bucket_name : str + Name of the bucket. + key : str + Key of the objects. + kwargs : Any + Additional arguments to the upload_fileobj function + + """ + with tracer.start_as_current_span("s3_upload_object", attributes={"bucket_name": bucket_name, "key": key}): + bucket = await s3.Bucket(bucket_name) + with obj_stream as f: + await bucket.upload_fileobj(f, Key=key, **kwargs) diff --git a/app/schemas/resource.py b/app/schemas/resource.py index 679c9e9ef9a25878a0995823979f222e13d90511..48f10d66e57ad36dd734f789f9239ab52aed21f1 100644 --- a/app/schemas/resource.py +++ b/app/schemas/resource.py @@ -1,10 +1,10 @@ from typing import List, Optional, Sequence from uuid import UUID -from clowmdb.models import Resource, ResourceVersion +from clowmdb.models import Resource, ResourceVersion, User from pydantic import BaseModel, Field -from app.schemas.resource_version import ResourceVersionIn, ResourceVersionOut +from app.schemas.resource_version import ResourceVersionIn, ResourceVersionOut, resource_version_dir_name class BaseResource(BaseModel): @@ -63,3 +63,33 @@ class ResourceOut(BaseResource): maintainer_id=db_resource.maintainer_id, versions=[ResourceVersionOut.from_db_resource_version(v) for v in temp_versions], ) + + +class S3ResourceVersionInfo(BaseResource): + release: str = Field( + ..., + description="Short tag describing the version of the resource", + examples=["01-2023"], + ) + resource_id: UUID = Field(..., description="ID of the resource", examples=["4c072e39-2bd9-4fa3-b564-4d890e240ccd"]) + resource_version_id: UUID = Field( + ..., description="ID of the resource version", examples=["fb4cee12-1e91-49f3-905f-808845c7c1f4"] + ) + maintainer_id: str = Field(..., description="ID of the maintainer", examples=["28c5353b8bb34984a8bd4169ba94c606"]) + maintainer: str = Field(..., description="Name of the maintainer", examples=["Bilbo Baggins"]) 
+ + def s3_path(self) -> str: + return resource_version_dir_name(self.resource_id, self.resource_version_id) + "/clowm_resinfo.json" + + @staticmethod + def from_models(resource: Resource, resource_version: ResourceVersion, maintainer: User) -> "S3ResourceVersionInfo": + return S3ResourceVersionInfo( + name=resource.name, + description=resource.short_description, + source=resource.source, + resource_id=resource.resource_id, + resource_version_id=resource_version.resource_version_id, + maintainer=maintainer.display_name, + maintainer_id=maintainer.uid, + release=resource_version.release, + ) diff --git a/app/schemas/resource_version.py b/app/schemas/resource_version.py index ac5b77aac8a7ffbff9cb374792c07aefb137e897..faca74360d1007da6f7a186a5e17269a3bb865ab 100644 --- a/app/schemas/resource_version.py +++ b/app/schemas/resource_version.py @@ -53,7 +53,7 @@ class ResourceVersionOut(BaseResourceVersion): ) @property def s3_path(self) -> str: - return f"s3://{settings.RESOURCE_BUCKET}/CLDB-{self.resource_id.hex[:8]}/{self.resource_version_id.hex}/resource.tar.gz" + return f"s3://{settings.RESOURCE_BUCKET}/{resource_version_key(self.resource_id, self.resource_version_id)}" @staticmethod def from_db_resource_version(db_resource_version: ResourceVersion) -> "ResourceVersionOut": @@ -64,3 +64,15 @@ class ResourceVersionOut(BaseResourceVersion): resource_id=db_resource_version.resource_id, created_at=db_resource_version.created_at, ) + + +def resource_dir_name(resource_id: UUID) -> str: + return f"CLDB-{resource_id.hex[:8]}" + + +def resource_version_dir_name(resource_id: UUID, resource_version_id: UUID) -> str: + return resource_dir_name(resource_id) + f"/{resource_version_id.hex}" + + +def resource_version_key(resource_id: UUID, resource_version_id: UUID) -> str: + return resource_version_dir_name(resource_id, resource_version_id) + "/resource.tar.gz" diff --git a/app/slurm/rest_client.py b/app/slurm/rest_client.py index b1eff52198b84a31becc8716d2340b4c59421b64..7a38021fac72b41b128f5bedfda2f60d42378452 100644 --- a/app/slurm/rest_client.py +++ b/app/slurm/rest_client.py @@ -1,3 +1,5 @@ +from enum import Enum, unique + from fastapi import status from httpx import AsyncClient from opentelemetry import trace @@ -9,6 +11,16 @@ tracer = trace.get_tracer_provider().get_tracer(__name__) class SlurmClient: + @unique + class JobState(str, Enum): + """ + Enumeration for the possible states of a slurm job. + """ + + RUNNING: str = "RUNNING" + SUCCESS: str = "SUCCESS" + ERROR: str = "ERROR" + def __init__(self, client: AsyncClient, version: str = "v0.0.38"): """ Initialize the client to communicate with a Slurm cluster. @@ -41,8 +53,8 @@ class SlurmClient: with tracer.start_as_current_span( "slurm_submit_job", attributes={ - "parameters": job_submission.job.model_dump_json(exclude_none=True), - "job_script": "\n".join(job_submission.script.split("\n")[5:]).strip(), + "parameters": job_submission.job.model_dump_filtered_json(exclude_none=True), + "job_script": job_submission.script, }, ) as span: response = await self._client.post( @@ -57,23 +69,9 @@ class SlurmClient: span.set_attribute("job_id", job_id) return job_id - async def cancel_job(self, job_id: int) -> None: + async def job_state(self, job_id: int) -> JobState: """ - Cancel a Slurm job on the cluster. - - Parameters - ---------- - job_id : int - ID of the job to cancel. 
- """ - with tracer.start_as_current_span("slurm_cancel_job"): - await self._client.delete( - f"{settings.SLURM_ENDPOINT}slurm/{self.version}/job/{job_id}", headers=self._headers - ) - - async def is_job_finished(self, job_id: int) -> bool: # pragma: no cover - """ - Check if the job with the given is completed + Check the current state of a slurm job Parameters ---------- @@ -82,20 +80,28 @@ class SlurmClient: Returns ------- - finished : bool - Flag if the job is finished + finished : SlurmClient.JobState + The state if the specified job + + Notes + ----- + If the job is not found, then the function returns SUCCESS """ with tracer.start_as_current_span("slurm_check_job_status") as span: response = await self._client.get( f"{settings.SLURM_ENDPOINT}slurm/{self.version}/job/{job_id}", headers=self._headers ) span.set_attribute("slurm.job-status.request.code", response.status_code) - if response.status_code != status.HTTP_200_OK: - return True + if response.status_code != status.HTTP_200_OK: # pragma: no cover + return SlurmClient.JobState.SUCCESS try: job_state = response.json()["jobs"][0]["job_state"] span.set_attribute("slurm.job-status.state", job_state) - return job_state == "COMPLETED" or job_state == "FAILED" or job_state == "CANCELLED" - except (KeyError, IndexError) as ex: + if job_state == "COMPLETED": + return SlurmClient.JobState.SUCCESS + elif job_state in ["FAILED", "CANCELLED"]: + return SlurmClient.JobState.ERROR + return SlurmClient.JobState.RUNNING + except (KeyError, IndexError) as ex: # pragma: no cover span.record_exception(ex) - return True + return SlurmClient.JobState.ERROR diff --git a/app/slurm/schemas.py b/app/slurm/schemas.py index 3f3730779f78a73afe545c6cd018b4fbcbef5a8c..7b44eda6d63dada32e77313bd90c990b520c1f33 100644 --- a/app/slurm/schemas.py +++ b/app/slurm/schemas.py @@ -16,6 +16,11 @@ class SlurmJobProperties(BaseModel): ) argv: Optional[List[str]] = Field(None, description="Arguments to the script.") + def model_dump_filtered_json(self, *args: Any, **kwargs: Any) -> str: + filtered_environment = {k: v for k, v in self.environment.items() if "KEY" not in k} + filtered_model = self.model_copy(update={"environment": filtered_environment}) + return filtered_model.model_dump_json(*args, **kwargs) + class SlurmJobSubmission(BaseModel): script: str = Field(..., description="Executable script (full contents) to run in batch step") diff --git a/app/tests/api/test_resource.py b/app/tests/api/test_resource.py index 204c6d06791c8cadcb32a62472d27fb9fd05baa9..9559fb13faeeace6f47ee34e8ce3dc03a6aa385f 100644 --- a/app/tests/api/test_resource.py +++ b/app/tests/api/test_resource.py @@ -1,13 +1,18 @@ +import json from uuid import uuid4 import pytest +from botocore.exceptions import ClientError from clowmdb.models import Resource, ResourceVersion from fastapi import status from httpx import AsyncClient from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession +from app.core.config import settings from app.schemas.resource import ResourceIn, ResourceOut +from app.schemas.resource_version import resource_version_key +from app.tests.mocks.mock_s3_resource import MockS3ServiceResource from app.tests.utils.user import UserWithAuthHeader from app.tests.utils.utils import CleanupList, random_lower_string @@ -19,7 +24,12 @@ class _TestResourceRoutes: class TestResourceRouteCreate(_TestResourceRoutes): @pytest.mark.asyncio async def test_create_resource_route( - self, client: AsyncClient, random_user: UserWithAuthHeader, db: AsyncSession, cleanup: 
CleanupList + self, + client: AsyncClient, + random_user: UserWithAuthHeader, + db: AsyncSession, + cleanup: CleanupList, + mock_s3_service: MockS3ServiceResource, ) -> None: """ Test for creating a new resource. @@ -32,8 +42,10 @@ class TestResourceRouteCreate(_TestResourceRoutes): Random user for testing. db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - cleanup : list[sqlalchemy.sql.dml.Delete] - List were to append sql deletes that gets executed after the test + cleanup : app.tests.utils.utils.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test. + mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource + Mock S3 Service to manipulate objects. """ resource = ResourceIn( release=random_lower_string(6), @@ -45,7 +57,13 @@ class TestResourceRouteCreate(_TestResourceRoutes): assert response.status_code == status.HTTP_201_CREATED resource_out = ResourceOut.model_validate_json(response.content) - cleanup.append(delete(Resource).where(Resource._resource_id == resource_out.resource_id.bytes)) + async def delete_resource() -> None: + await db.execute(delete(Resource).where(Resource._resource_id == resource_out.resource_id.bytes)) + await db.commit() + + cleanup.add_task(delete_resource) + s3_policy = await mock_s3_service.BucketPolicy(settings.RESOURCE_BUCKET) + cleanup.add_task(s3_policy.delete) assert resource_out.name == resource.name assert len(resource_out.versions) == 1 @@ -54,6 +72,13 @@ class TestResourceRouteCreate(_TestResourceRoutes): assert resource_version.status == ResourceVersion.Status.RESOURCE_REQUESTED assert resource_version.release == resource.release + # test if the maintainer get permission to upload to the resource bucket + s3_policy = await mock_s3_service.BucketPolicy(settings.RESOURCE_BUCKET) + policy = json.loads(await s3_policy.policy) + assert sum(1 for stmt in policy["Statement"] if stmt["Sid"] == str(resource_out.resource_id)) == 1 + assert sum(1 for stmt in policy["Statement"] if stmt["Sid"] == str(resource_version.resource_version_id)) == 1 + + # test that the resource got actually created resource_db = await db.scalar(select(Resource).where(Resource._resource_id == resource_out.resource_id.bytes)) assert resource_db is not None @@ -61,7 +86,12 @@ class TestResourceRouteCreate(_TestResourceRoutes): class TestResourceRouteDelete(_TestResourceRoutes): @pytest.mark.asyncio async def test_delete_resource_route( - self, client: AsyncClient, random_resource: Resource, random_user: UserWithAuthHeader + self, + client: AsyncClient, + random_resource: Resource, + random_user: UserWithAuthHeader, + mock_s3_service: MockS3ServiceResource, + cleanup: CleanupList, ) -> None: """ Test for deleting a resource. @@ -74,12 +104,35 @@ class TestResourceRouteDelete(_TestResourceRoutes): Random resource for testing. random_user : app.tests.utils.user.UserWithAuthHeader Random user for testing. + cleanup : app.tests.utils.utils.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test. + mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource + Mock S3 Service to manipulate objects. 
""" response = await client.delete( f"{self.base_path}/{str(random_resource.resource_id)}", headers=random_user.auth_headers ) assert response.status_code == status.HTTP_204_NO_CONTENT + # test if all permission for the maintainer got deleted in S3 + s3_policy = await mock_s3_service.BucketPolicy(settings.RESOURCE_BUCKET) + policy = json.loads(await s3_policy.policy) + assert ( + sum( + 1 + for stmt in policy["Statement"] + if stmt["Sid"] + in [str(random_resource.resource_id)] + [str(rv.resource_version_id) for rv in random_resource.versions] + ) + == 0 + ) + + # test if the resource got deleted in S3 + s3_key = resource_version_key(random_resource.resource_id, random_resource.versions[0].resource_version_id) + with pytest.raises(ClientError): + obj = await mock_s3_service.ObjectSummary(settings.RESOURCE_BUCKET, s3_key) + await obj.load() + class TestResourceRouteGet(_TestResourceRoutes): @pytest.mark.asyncio diff --git a/app/tests/api/test_resource_version.py b/app/tests/api/test_resource_version.py index 4b3fa5cd4d51732bb5b589a7b80de8af97c57725..b6303e3defacb65175030220cb639b45b2197e6b 100644 --- a/app/tests/api/test_resource_version.py +++ b/app/tests/api/test_resource_version.py @@ -1,13 +1,19 @@ +import json +from io import BytesIO from uuid import uuid4 import pytest +from botocore.exceptions import ClientError from clowmdb.models import Resource, ResourceVersion from fastapi import status from httpx import AsyncClient -from app.schemas.resource_version import ResourceVersionIn +from app.core.config import settings +from app.schemas.resource_version import ResourceVersionIn, ResourceVersionOut, resource_version_key +from app.tests.mocks.mock_s3_resource import MockS3ServiceResource +from app.tests.mocks.mock_slurm_cluster import MockSlurmCluster from app.tests.utils.user import UserWithAuthHeader -from app.tests.utils.utils import random_lower_string +from app.tests.utils.utils import CleanupList, random_lower_string class _TestResourceVersionRoutes: @@ -17,7 +23,12 @@ class _TestResourceVersionRoutes: class TestResourceVersionRouteCreate(_TestResourceVersionRoutes): @pytest.mark.asyncio async def test_create_resource_route( - self, client: AsyncClient, random_user: UserWithAuthHeader, random_resource: Resource + self, + client: AsyncClient, + random_user: UserWithAuthHeader, + random_resource: Resource, + cleanup: CleanupList, + mock_s3_service: MockS3ServiceResource, ) -> None: """ Test for creating a new resource version. @@ -28,6 +39,12 @@ class TestResourceVersionRouteCreate(_TestResourceVersionRoutes): HTTP Client to perform the request on. random_user : app.tests.utils.user.UserWithAuthHeader Random user for testing. + random_resource : clowmdb.models.Resource + Random resource for testing. + cleanup : app.tests.utils.utils.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test. + mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource + Mock S3 Service to manipulate objects. 
""" resource_version = ResourceVersionIn(release=random_lower_string(6)) response = await client.post( @@ -36,6 +53,25 @@ class TestResourceVersionRouteCreate(_TestResourceVersionRoutes): headers=random_user.auth_headers, ) assert response.status_code == status.HTTP_201_CREATED + resource_version_out = ResourceVersionOut.model_validate_json(response.content) + + # Load bucket policy + s3_policy = await mock_s3_service.BucketPolicy(settings.RESOURCE_BUCKET) + policy = json.loads(await s3_policy.policy) + + # delete create bucket policy statement after the test + async def delete_policy_stmt() -> None: + policy["Statement"] = [ + stmt for stmt in policy["Statement"] if stmt["Sid"] != str(resource_version_out.resource_version_id) + ] + await s3_policy.put(json.dumps(policy)) + + cleanup.add_task(delete_policy_stmt) + + # check if a bucket policy statement was created + assert ( + sum(1 for stmt in policy["Statement"] if stmt["Sid"] == str(resource_version_out.resource_version_id)) == 1 + ) class TestResourceVersionRouteGet(_TestResourceVersionRoutes): @@ -114,6 +150,8 @@ class TestResourceVersionRouteGet(_TestResourceVersionRoutes): HTTP Client to perform the request on. random_user : app.tests.utils.user.UserWithAuthHeader Random user for testing. + random_resource : clowmdb.models.Resource + Random resource for testing. """ response = await client.get( f"{self.base_path}/{str(random_resource.resource_id)}/versions/{str(uuid4())}", @@ -128,7 +166,7 @@ class TestResourceVersionRouteDelete(_TestResourceVersionRoutes): self, client: AsyncClient, random_resource: Resource, - random_resource_version: ResourceVersion, + random_resource_version_states: ResourceVersion, random_user: UserWithAuthHeader, ) -> None: """ @@ -140,8 +178,8 @@ class TestResourceVersionRouteDelete(_TestResourceVersionRoutes): HTTP Client to perform the request on. random_resource : clowmdb.models.Resource Random resource for testing. - random_resource_version : clowmdb.models.Resource - Random resource version for testing. + random_resource_version_states : clowmdb.models.Resource + Random resource version with all possible states for testing. random_user : app.tests.utils.user.UserWithAuthHeader Random user for testing. """ @@ -151,7 +189,7 @@ class TestResourceVersionRouteDelete(_TestResourceVersionRoutes): self.base_path, str(random_resource.resource_id), "versions", - str(random_resource_version.resource_version_id), + str(random_resource_version_states.resource_version_id), "cluster", ] ), @@ -164,8 +202,9 @@ class TestResourceVersionRouteDelete(_TestResourceVersionRoutes): self, client: AsyncClient, random_resource: Resource, - random_resource_version: ResourceVersion, + random_resource_version_states: ResourceVersion, random_user: UserWithAuthHeader, + mock_s3_service: MockS3ServiceResource, ) -> None: """ Test for deleting a resource version in the S3 Bucket @@ -176,10 +215,12 @@ class TestResourceVersionRouteDelete(_TestResourceVersionRoutes): HTTP Client to perform the request on. random_resource : clowmdb.models.Resource Random resource for testing. - random_resource_version : clowmdb.models.Resource - Random resource version for testing. + random_resource_version_states : clowmdb.models.Resource + Random resource version with all possible states for testing. random_user : app.tests.utils.user.UserWithAuthHeader Random user for testing. + mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource + Mock S3 Service to manipulate objects. 
""" response = await client.delete( "/".join( @@ -187,7 +228,7 @@ class TestResourceVersionRouteDelete(_TestResourceVersionRoutes): self.base_path, str(random_resource.resource_id), "versions", - str(random_resource_version.resource_version_id), + str(random_resource_version_states.resource_version_id), "s3", ] ), @@ -195,15 +236,38 @@ class TestResourceVersionRouteDelete(_TestResourceVersionRoutes): ) assert response.status_code == status.HTTP_200_OK + # test that the resource is deleted in S3 + with pytest.raises(ClientError): + obj = await mock_s3_service.ObjectSummary( + bucket_name=settings.RESOURCE_BUCKET, + key=resource_version_key( + random_resource.resource_id, random_resource_version_states.resource_version_id + ), + ) + await obj.load() + + # test that the maintainer has no permission to upload a new resource to S3 + s3_policy = await mock_s3_service.BucketPolicy(settings.RESOURCE_BUCKET) + policy = json.loads(await s3_policy.policy) + assert ( + sum( + 1 + for stmt in policy["Statement"] + if stmt["Sid"] == str(random_resource_version_states.resource_version_id) + ) + == 0 + ) + class TestResourceVersionRoutePut(_TestResourceVersionRoutes): @pytest.mark.asyncio - async def test_request_sync_resource_version_route( + async def test_request_sync_resource_version_route_without_uploaded_resource( self, client: AsyncClient, random_resource: Resource, - random_resource_version: ResourceVersion, + random_resource_version_states: ResourceVersion, random_user: UserWithAuthHeader, + mock_s3_service: MockS3ServiceResource, ) -> None: """ Test for requesting a synchronization of the resource version to the cluster. @@ -214,31 +278,151 @@ class TestResourceVersionRoutePut(_TestResourceVersionRoutes): HTTP Client to perform the request on. random_resource : clowmdb.models.Resource Random resource for testing. - random_resource_version : clowmdb.models.Resource - Random resource version for testing. + random_resource_version_states : clowmdb.models.Resource + Random resource version with all possible states for testing. random_user : app.tests.utils.user.UserWithAuthHeader Random user for testing. + mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource + Mock S3 Service to manipulate objects. """ + s3_bucket = await mock_s3_service.Bucket(settings.RESOURCE_BUCKET) + s3_key = resource_version_key(random_resource.resource_id, random_resource_version_states.resource_version_id) + await s3_bucket.delete_objects({"Objects": [{"Key": s3_key}]}) + response = await client.put( "/".join( [ self.base_path, str(random_resource.resource_id), "versions", - str(random_resource_version.resource_version_id), + str(random_resource_version_states.resource_version_id), "request_sync", ] ), headers=random_user.auth_headers, ) - assert response.status_code == status.HTTP_200_OK + assert response.status_code == status.HTTP_400_BAD_REQUEST + + @pytest.mark.asyncio + async def test_request_sync_resource_version_route_with_uploaded_resource( + self, + client: AsyncClient, + random_resource: Resource, + random_resource_version_states: ResourceVersion, + random_user: UserWithAuthHeader, + mock_s3_service: MockS3ServiceResource, + cleanup: CleanupList, + ) -> None: + """ + Test for requesting a synchronization of the resource version to the cluster. + + Parameters + ---------- + client : httpx.AsyncClient + HTTP Client to perform the request on. + random_resource : clowmdb.models.Resource + Random resource for testing. 
+ random_resource_version_states : clowmdb.models.Resource + Random resource version with all possible states for testing. + random_user : app.tests.utils.user.UserWithAuthHeader + Random user for testing. + mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource + Mock S3 Service to manipulate objects. + cleanup : app.tests.utils.utils.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test. + """ + previous_state = random_resource_version_states.status + if previous_state == ResourceVersion.Status.RESOURCE_REQUESTED: + s3_bucket = await mock_s3_service.Bucket(settings.RESOURCE_BUCKET) + s3_key = resource_version_key( + random_resource.resource_id, random_resource_version_states.resource_version_id + ) + await s3_bucket.upload_fileobj( + BytesIO(b"content"), + Key=s3_key, + ) + cleanup.add_task( + s3_bucket.delete_objects, + Delete={"Objects": [{"Key": s3_key}]}, + ) + + response = await client.put( + "/".join( + [ + self.base_path, + str(random_resource.resource_id), + "versions", + str(random_resource_version_states.resource_version_id), + "request_sync", + ] + ), + headers=random_user.auth_headers, + ) + + if previous_state in [ResourceVersion.Status.RESOURCE_REQUESTED, ResourceVersion.Status.CLUSTER_DELETED]: + assert response.status_code == status.HTTP_200_OK + # test if the maintainer gets a permission to upload to the bucket + s3_policy = await mock_s3_service.BucketPolicy(settings.RESOURCE_BUCKET) + policy = json.loads(await s3_policy.policy) + assert ( + sum( + 1 + for stmt in policy["Statement"] + if stmt["Sid"] == str(random_resource_version_states.resource_version_id) + ) + == 0 + ) + else: + assert response.status_code == status.HTTP_400_BAD_REQUEST + + @pytest.mark.asyncio + async def test_sync_resource_version_route_without_uploaded_resource( + self, + client: AsyncClient, + random_resource: Resource, + random_resource_version_states: ResourceVersion, + random_user: UserWithAuthHeader, + mock_s3_service: MockS3ServiceResource, + ) -> None: + """ + Test for synchronizing a resource version to the cluster but missing the resource in S3. + + Parameters + ---------- + client : httpx.AsyncClient + HTTP Client to perform the request on. + random_resource : clowmdb.models.Resource + Random resource for testing. + random_resource_version_states : clowmdb.models.Resource + Random resource version with all possible states for testing. + random_user : app.tests.utils.user.UserWithAuthHeader + Random user for testing. + mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource + Mock S3 Service to manipulate objects. 
+ """ + s3_bucket = await mock_s3_service.Bucket(settings.RESOURCE_BUCKET) + s3_key = resource_version_key(random_resource.resource_id, random_resource_version_states.resource_version_id) + await s3_bucket.delete_objects({"Objects": [{"Key": s3_key}]}) + response = await client.put( + "/".join( + [ + self.base_path, + str(random_resource.resource_id), + "versions", + str(random_resource_version_states.resource_version_id), + "sync", + ] + ), + headers=random_user.auth_headers, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.asyncio async def test_sync_resource_version_route( self, client: AsyncClient, random_resource: Resource, - random_resource_version: ResourceVersion, + random_resource_version_states: ResourceVersion, random_user: UserWithAuthHeader, ) -> None: """ @@ -250,31 +434,89 @@ class TestResourceVersionRoutePut(_TestResourceVersionRoutes): HTTP Client to perform the request on. random_resource : clowmdb.models.Resource Random resource for testing. - random_resource_version : clowmdb.models.Resource - Random resource version for testing. + random_resource_version_states : clowmdb.models.Resource + Random resource version with all possible states for testing. random_user : app.tests.utils.user.UserWithAuthHeader Random user for testing. """ + previous_status = random_resource_version_states.status response = await client.put( "/".join( [ self.base_path, str(random_resource.resource_id), "versions", - str(random_resource_version.resource_version_id), + str(random_resource_version_states.resource_version_id), "sync", ] ), headers=random_user.auth_headers, ) - assert response.status_code == status.HTTP_200_OK + if previous_status in [ + ResourceVersion.Status.SYNC_REQUESTED, + ResourceVersion.Status.CLUSTER_DELETED, + ]: + assert response.status_code == status.HTTP_200_OK + else: + assert response.status_code == status.HTTP_400_BAD_REQUEST + + @pytest.mark.asyncio + async def test_fail_sync_resource_version_route( + self, + client: AsyncClient, + random_resource: Resource, + random_resource_version_states: ResourceVersion, + random_user: UserWithAuthHeader, + mock_slurm_cluster: MockSlurmCluster, + cleanup: CleanupList, + ) -> None: + """ + Test for synchronizing a resource version to the cluster. + + Parameters + ---------- + client : httpx.AsyncClient + HTTP Client to perform the request on. + random_resource : clowmdb.models.Resource + Random resource for testing. + random_resource_version_states : clowmdb.models.Resource + Random resource version with all possible states for testing. + random_user : app.tests.utils.user.UserWithAuthHeader + Random user for testing. 
+ """ + previous_status = random_resource_version_states.status + mock_slurm_cluster.fail_jobs = True + + def reset_slurm_config() -> None: + mock_slurm_cluster.fail_jobs = False + + cleanup.add_task(reset_slurm_config) + response = await client.put( + "/".join( + [ + self.base_path, + str(random_resource.resource_id), + "versions", + str(random_resource_version_states.resource_version_id), + "sync", + ] + ), + headers=random_user.auth_headers, + ) + if previous_status in [ + ResourceVersion.Status.SYNC_REQUESTED, + ResourceVersion.Status.CLUSTER_DELETED, + ]: + assert response.status_code == status.HTTP_200_OK + else: + assert response.status_code == status.HTTP_400_BAD_REQUEST @pytest.mark.asyncio async def test_set_latest_resource_version_route( self, client: AsyncClient, random_resource: Resource, - random_resource_version: ResourceVersion, + random_resource_version_states: ResourceVersion, random_user: UserWithAuthHeader, ) -> None: """ @@ -286,21 +528,25 @@ class TestResourceVersionRoutePut(_TestResourceVersionRoutes): HTTP Client to perform the request on. random_resource : clowmdb.models.Resource Random resource for testing. - random_resource_version : clowmdb.models.Resource - Random resource version for testing. + random_resource_version_states : clowmdb.models.Resource + Random resource version with all possible states for testing. random_user : app.tests.utils.user.UserWithAuthHeader Random user for testing. """ + previous_state = random_resource_version_states.status response = await client.put( "/".join( [ self.base_path, str(random_resource.resource_id), "versions", - str(random_resource_version.resource_version_id), + str(random_resource_version_states.resource_version_id), "latest", ] ), headers=random_user.auth_headers, ) - assert response.status_code == status.HTTP_200_OK + if previous_state == ResourceVersion.Status.SYNCHRONIZED: + assert response.status_code == status.HTTP_200_OK + else: + assert response.status_code == status.HTTP_400_BAD_REQUEST diff --git a/app/tests/conftest.py b/app/tests/conftest.py index 4a3200156f416e73a5b07fa2c24bc05895728a78..57daa41f575025d99988a985440c507e305d85b4 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -1,7 +1,9 @@ import asyncio +import json from functools import partial +from io import BytesIO from secrets import token_urlsafe -from typing import AsyncGenerator, AsyncIterator, Dict, Iterator +from typing import AsyncIterator, Dict, Iterator import httpx import pytest @@ -9,12 +11,13 @@ import pytest_asyncio from clowmdb.db.session import get_async_session from clowmdb.models import Resource, ResourceVersion from pytrie import SortedStringTrie as Trie -from sqlalchemy import delete +from sqlalchemy import delete, update from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies import get_db, get_decode_jwt_function, get_httpx_client, get_s3_resource from app.core.config import settings from app.main import app +from app.schemas.resource_version import resource_version_key from app.tests.mocks import DefaultMockHTTPService, MockHTTPService from app.tests.mocks.mock_opa_service import MockOpaService from app.tests.mocks.mock_s3_resource import MockS3ServiceResource @@ -35,8 +38,8 @@ def event_loop() -> Iterator: loop.close() -@pytest.fixture(scope="session") -async def mock_s3_service() -> AsyncGenerator[MockS3ServiceResource, None]: +@pytest_asyncio.fixture(scope="session") +async def mock_s3_service() -> AsyncIterator[MockS3ServiceResource]: """ Fixture for creating a mock object for the rgwadmin 
package. """ @@ -168,7 +171,12 @@ async def random_third_user(db: AsyncSession, mock_opa_service: MockOpaService) @pytest_asyncio.fixture(scope="function") -async def random_resource(db: AsyncSession, random_user: UserWithAuthHeader) -> AsyncIterator[Resource]: +async def random_resource( + db: AsyncSession, + random_user: UserWithAuthHeader, + mock_s3_service: MockS3ServiceResource, + cleanup: CleanupList, +) -> AsyncIterator[Resource]: """ Create a random resource and deletes it afterward. """ @@ -186,6 +194,13 @@ async def random_resource(db: AsyncSession, random_user: UserWithAuthHeader) -> db.add(resource_version_db) await db.commit() await db.refresh(resource_db, attribute_names=["versions"]) + + s3_policy = await mock_s3_service.BucketPolicy(settings.RESOURCE_BUCKET) + policy = json.loads(await s3_policy.policy) + policy["Statement"].append({"Sid": str(resource_db.resource_id)}) + await s3_policy.put(Policy=json.dumps(policy)) + cleanup.add_task(s3_policy.delete) + yield resource_db await db.execute(delete(Resource).where(Resource._resource_id == resource_db.resource_id.bytes)) await db.commit() @@ -200,13 +215,73 @@ async def random_resource_version(db: AsyncSession, random_resource: Resource) - return resource_version +@pytest_asyncio.fixture(scope="function", params=[status for status in ResourceVersion.Status]) +async def random_resource_version_states( + db: AsyncSession, + random_resource: Resource, + request: pytest.FixtureRequest, + mock_s3_service: MockS3ServiceResource, + cleanup: CleanupList, +) -> ResourceVersion: + """ + Create a random resource version with all possible resource version status. + """ + resource_version: ResourceVersion = random_resource.versions[0] + stmt = ( + update(ResourceVersion) + .where(ResourceVersion._resource_version_id == resource_version.resource_version_id.bytes) + .values(status=request.param) + ) + await db.execute(stmt) + await db.commit() + await db.refresh(resource_version, attribute_names=["status"]) + + if request.param == ResourceVersion.Status.RESOURCE_REQUESTED: + # create a permission for the maintainer to upload the resource to S3 + s3_policy = await mock_s3_service.BucketPolicy(settings.RESOURCE_BUCKET) + policy = json.loads(await s3_policy.policy) + policy["Statement"].append({"Sid": str(resource_version.resource_version_id)}) + await s3_policy.put(Policy=json.dumps(policy)) + + # delete policy statement after the test + async def delete_policy_stmt() -> None: + policy["Statement"] = [ + s3_stmt + for s3_stmt in policy["Statement"] + if s3_stmt["Sid"] != str(resource_version.resource_version_id) + ] + await s3_policy.put(Policy=json.dumps(policy)) + + cleanup.add_task(delete_policy_stmt) + + if request.param in [ + ResourceVersion.Status.SYNC_REQUESTED, + ResourceVersion.Status.SYNCHRONIZING, + ResourceVersion.Status.SYNCHRONIZED, + ResourceVersion.Status.LATEST, + ResourceVersion.Status.DENIED, + ResourceVersion.Status.CLUSTER_DELETED, + ]: + # create the resource object in S3 for appropriate resource version states + s3_bucket = await mock_s3_service.Bucket(settings.RESOURCE_BUCKET) + s3_key = resource_version_key(random_resource.resource_id, random_resource.versions[0].resource_version_id) + await s3_bucket.upload_fileobj( + BytesIO(b"content"), + Key=s3_key, + ) + cleanup.add_task( + s3_bucket.delete_objects, + Delete={"Objects": [{"Key": s3_key}]}, + ) + + return resource_version + + @pytest_asyncio.fixture(scope="function") async def cleanup(db: AsyncSession) -> AsyncIterator[CleanupList]: """ - Yield a list with sql 
delete statements that gets executed after a (failed) test + Yields a Cleanup object where (async) functions can be registered which get executed after a (failed) test """ - to_delete: CleanupList = [] - yield to_delete - for stmt in to_delete: - await db.execute(stmt) - await db.commit() + cleanup_list = CleanupList() + yield cleanup_list + await cleanup_list.empty_queue() diff --git a/app/tests/crud/test_resource.py b/app/tests/crud/test_resource.py index 1f1c50eb49ca2df99e02e209e3dcf3e88e6844bd..b5a149891c8fb3be5c7333c8018bd102ce5e84b9 100644 --- a/app/tests/crud/test_resource.py +++ b/app/tests/crud/test_resource.py @@ -176,8 +176,8 @@ class TestResourceCRUDCreate: Async database session to perform query on. random_user : app.tests.utils.user.UserWithAuthHeader Random user for testing. - cleanup : list[sqlalchemy.sql.dml.Delete] - List were to append sql deletes that gets executed after the test + cleanup : app.tests.utils.utils.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test. """ resource_in = ResourceIn( name=random_lower_string(8), @@ -188,7 +188,12 @@ class TestResourceCRUDCreate: resource = await CRUDResource.create(db, resource_in=resource_in, maintainer_id=random_user.user.uid) assert resource is not None - cleanup.append(delete(Resource).where(Resource._resource_id == resource.resource_id.bytes)) + + async def delete_resource() -> None: + await db.execute(delete(Resource).where(Resource._resource_id == resource.resource_id.bytes)) + await db.commit() + + cleanup.add_task(delete_resource) created_resource = await db.scalar( select(Resource) diff --git a/app/tests/crud/test_resource_version.py b/app/tests/crud/test_resource_version.py index 04254c68dee69bbdbbdacce671752a1f05df6e2a..abc6a966cf0f5d2c0c66bda2eda537ce4d66b238 100644 --- a/app/tests/crud/test_resource_version.py +++ b/app/tests/crud/test_resource_version.py @@ -111,6 +111,54 @@ class TestResourceCRUDUpdate: assert updated_resource_version.status == ResourceVersion.Status.S3_DELETED + @pytest.mark.asyncio + async def test_update_resource_version_to_latest_without_resource_id( + self, db: AsyncSession, random_resource_version: ResourceVersion + ) -> None: + """ + Test for updating the resource version status from the CRUD Repository. + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + random_resource_version : clowmdb.models.ResourceVersion + Random resource for testing. + """ + with pytest.raises(ValueError): + await CRUDResourceVersion.update_status( + db, + resource_version_id=random_resource_version.resource_version_id, + status=ResourceVersion.Status.LATEST, + ) + + @pytest.mark.asyncio + async def test_update_resource_version_to_latest_with_resource_id( + self, db: AsyncSession, random_resource_version: ResourceVersion + ) -> None: + """ + Test for updating the resource version status from the CRUD Repository. + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + random_resource_version : clowmdb.models.ResourceVersion + Random resource for testing. 
+ """ + await CRUDResourceVersion.update_status( + db, + resource_version_id=random_resource_version.resource_version_id, + status=ResourceVersion.Status.LATEST, + resource_id=random_resource_version.resource_id, + ) + updated_resource_version = await db.scalars( + select(ResourceVersion) + .where(ResourceVersion._resource_id == random_resource_version.resource_id.bytes) + .where(ResourceVersion.status == ResourceVersion.Status.LATEST.name) + ) + assert sum(1 for _ in updated_resource_version) == 1 + @pytest.mark.asyncio async def test_update_non_existing_resource_version( self, db: AsyncSession, random_resource_version: ResourceVersion @@ -156,8 +204,8 @@ class TestResourceCRUDCreate: Async database session to perform query on. random_resource : clowmdb.models.Resource Random resource for testing. - cleanup : list[sqlalchemy.sql.dml.Delete] - List were to append sql deletes that gets executed after the test + cleanup : app.tests.utils.utils.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test. """ release = random_lower_string(8) @@ -167,11 +215,15 @@ class TestResourceCRUDCreate: assert resource_version is not None - cleanup.append( - delete(ResourceVersion).where( - ResourceVersion._resource_version_id == resource_version.resource_version_id.bytes + async def delete_resource_version() -> None: + await db.execute( + delete(ResourceVersion).where( + ResourceVersion._resource_version_id == resource_version.resource_version_id.bytes + ) ) - ) + await db.commit() + + cleanup.add_task(delete_resource_version) created_resource_version = await db.scalar( select(ResourceVersion).where( diff --git a/app/tests/mocks/mock_s3_resource.py b/app/tests/mocks/mock_s3_resource.py index 80a5d1cd6fc9f477479018dc051beac475a51c50..97897eb276cbd5c9019a7bc6b1a60d6df048f39c 100644 --- a/app/tests/mocks/mock_s3_resource.py +++ b/app/tests/mocks/mock_s3_resource.py @@ -1,5 +1,6 @@ from datetime import datetime -from typing import TYPE_CHECKING, Dict, List, Optional +from io import BytesIO +from typing import TYPE_CHECKING, Any, Dict, List, Optional from botocore.exceptions import ClientError @@ -23,7 +24,7 @@ class MockS3Object: Name of the corresponding bucket. """ - def __init__(self, bucket_name: str, key: str) -> None: + def __init__(self, bucket_name: str, key: str, content: str = "") -> None: """ Initialize a MockS3Object. @@ -37,6 +38,7 @@ class MockS3Object: self.key = key self.bucket_name = bucket_name self.content_type = "text/plain" + self.content = content def __repr__(self) -> str: return f"MockS3Object(key={self.key}, bucket={self.bucket_name})" @@ -65,7 +67,7 @@ class MockS3ObjectSummary: Hash of the object content """ - def __init__(self, bucket_name: str, key: str) -> None: + def __init__(self, bucket_name: str, key: str, content: str = "") -> None: """ Initialize a MockS3ObjectSummary. 
@@ -78,9 +80,22 @@ class MockS3ObjectSummary: """ self.key = key self.bucket_name = bucket_name - self.size = 100 - self.last_modified = datetime.now() - self.etag = random_hex_string(32) + self._size = 100 + self._last_modified = datetime.now() + self._etag = random_hex_string(32) + self.obj = MockS3Object(bucket_name, key, content) + + @property + async def last_modified(self) -> datetime: + return self._last_modified + + @property + async def e_tag(self) -> str: + return self._etag + + @property + async def size(self) -> int: + return self._size def __repr__(self) -> str: return f"MockS3ObjectSummary(key={self.key}, bucket={self.bucket_name})" @@ -94,7 +109,10 @@ class MockS3ObjectSummary: sObject : app.tests.mocks.mock_s3_resource.MockS3Object The corresponding S3Object. """ - return MockS3Object(self.bucket_name, self.key) + return self.obj + + async def load(self) -> None: + pass class MockS3BucketPolicy: @@ -116,7 +134,7 @@ class MockS3BucketPolicy: def __init__(self, bucket_name: str): self.bucket_name = bucket_name - self.policy: str = "" + self._policy: str = '{"Version": "2012-10-17", "Statement": []}' async def put(self, Policy: str) -> None: """ @@ -127,11 +145,18 @@ class MockS3BucketPolicy: Policy : str The new policy as str. """ - self.policy = Policy + self._policy = Policy async def load(self) -> None: pass + async def delete(self) -> None: + self._policy = '{"Version": "2012-10-17", "Statement": []}' + + @property + async def policy(self) -> str: + return self._policy + class MockS3CorsRule: """ @@ -227,8 +252,9 @@ class MockS3Bucket: Delete a MockS3ObjectSummary from the list """ - def __init__(self, obj_list: Optional[List[MockS3ObjectSummary]] = None) -> None: + def __init__(self, parent_bucket: "MockS3Bucket", obj_list: Optional[List[MockS3ObjectSummary]] = None) -> None: self._objs: List[MockS3ObjectSummary] = [] if obj_list is None else obj_list + self._bucket = parent_bucket def all(self) -> List[MockS3ObjectSummary]: """ @@ -252,7 +278,7 @@ class MockS3Bucket: """ self._objs.append(obj) - def delete(self, key: str) -> None: + async def delete(self, key: Optional[str] = None) -> list[dict[str, list[dict[str, str]]]]: """ Delete a MockS3ObjectSummary from the list. @@ -261,7 +287,12 @@ class MockS3Bucket: key : str Key of the object to delete """ - self._objs = [obj for obj in self._objs if obj.key != key] + if key is not None: + self._objs = [obj for obj in self._objs if obj.key != key] + return [{"Deleted": [{"Key": key}]}] + pre_deleted = self._objs.copy() + await self._bucket.delete_objects({"Objects": [{"Key": obj.key} for obj in self._objs]}) + return [{"Deleted": [{"Key": obj.key} for obj in pre_deleted]}] def filter(self, Prefix: str) -> "MockS3Bucket.MockS3ObjectList": """ @@ -277,7 +308,9 @@ class MockS3Bucket: obj_list : app.tests.mocks.mock_s3_resource.MockS3Bucket.MockS3ObjectList The filtered list. 
""" - return MockS3Bucket.MockS3ObjectList(obj_list=list(filter(lambda x: x.key.startswith(Prefix), self._objs))) + return MockS3Bucket.MockS3ObjectList( + obj_list=[obj for obj in self._objs if obj.key.startswith(Prefix)], parent_bucket=self._bucket + ) def __init__(self, name: str, parent_service: "MockS3ServiceResource"): """ @@ -292,7 +325,7 @@ class MockS3Bucket: """ self.name = name self.creation_date: datetime = datetime.now() - self.objects = MockS3Bucket.MockS3ObjectList() + self.objects = MockS3Bucket.MockS3ObjectList(parent_bucket=self) self._parent_service: MockS3ServiceResource = parent_service self.policy = MockS3BucketPolicy(name) self.cors = MockS3CorsRule() @@ -342,7 +375,7 @@ class MockS3Bucket: } """ for key_object in Delete["Objects"]: - self.objects.delete(key=key_object["Key"]) + await self.objects.delete(key=key_object["Key"]) def get_objects(self) -> List[MockS3ObjectSummary]: """ @@ -368,6 +401,9 @@ class MockS3Bucket: """ self.objects.add(obj) + async def upload_fileobj(self, Fileobj: BytesIO, Key: str, **kwargs: Any) -> None: + self.add_object(MockS3ObjectSummary(bucket_name=self.name, key=Key, content=Fileobj.read().decode("utf-8"))) + def __repr__(self) -> str: return f"MockS3Bucket(name={self.name}, objects={[obj.key for obj in self.get_objects()]})" diff --git a/app/tests/mocks/mock_slurm_cluster.py b/app/tests/mocks/mock_slurm_cluster.py index 4e042ca621d5a000d39f680b9044a020f140a45e..db77a1c9bf6990c982bf5533af41cde15190ca43 100644 --- a/app/tests/mocks/mock_slurm_cluster.py +++ b/app/tests/mocks/mock_slurm_cluster.py @@ -25,6 +25,8 @@ class MockSlurmCluster(MockHTTPService): self._job_states: List[bool] = [] self.base_path = f"slurm/{version}" self._job_path_regex = re.compile(f"^/slurm/{re.escape(version)}/job/[\d]*$") + self.fail_jobs = False + self.run_jobs = False def handle_request(self, request: Request, **kwargs: bool) -> Response: """ @@ -58,19 +60,22 @@ class MockSlurmCluster(MockHTTPService): return self._method_not_allowed_response # If a job should be canceled elif self._job_path_regex.match(request.url.path) is not None: - # Route supports DELETE Method - if request.method == "DELETE": + # Route supports GET Method + if request.method == "GET": # Parse job ID from url path job_id = int(request.url.path.split("/")[-1]) - # Check if job exists and send appropriate response - if job_id < len(self._request_bodies): - self._job_states[job_id] = False - return Response(status_code=status.HTTP_200_OK, text="") - else: + if job_id >= len(self._request_bodies): return Response(status_code=status.HTTP_404_NOT_FOUND, text=f"Job with ID {job_id} not found") - elif request.method == "GET": # Get status of job for slurm job monitoring - return Response(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, text="") + state = "COMPLETED" + if self.run_jobs: + state = "RUNNING" + elif self.fail_jobs: + state = "FAILED" + return Response( + status_code=status.HTTP_200_OK, + json={"jobs": [{"job_state": state}]}, + ) else: return self._method_not_allowed_response # Requested route is not mocked @@ -100,6 +105,7 @@ class MockSlurmCluster(MockHTTPService): Resets the mock service to its initial state. """ self._request_bodies = [] + self._job_states = [] def add_workflow_execution(self, job: SlurmRequestBody) -> int: """ @@ -108,7 +114,7 @@ class MockSlurmCluster(MockHTTPService): Parameters ---------- job : Dict[str, Any] - The requests body for submiting a slurm job. + The requests body for submitting a slurm job. 
Returns ------- diff --git a/app/tests/unit/test_backoff_strategy.py b/app/tests/unit/test_backoff_strategy.py index a9b188b47fee37f88eb39626df4cef791ea26d92..0dd8f93d5b49b4df6d101e69960c258644cb2e07 100644 --- a/app/tests/unit/test_backoff_strategy.py +++ b/app/tests/unit/test_backoff_strategy.py @@ -1,16 +1,16 @@ -from math import ceil, floor, log2 +from math import floor, log2 from app.utils.backoff_strategy import ExponentialBackoff, LinearBackoff, NoBackoff class TestExponentialBackoffStrategy: - def test_exponential_without_initial_delay(self) -> None: + def test_exponential(self) -> None: """ Test generating a bounded exponential backoff strategy series in seconds without an initial delay. """ for max_val in range(1023, 1026): # numbers around a power of 2 index_at_maximum = floor(log2(max_val)) - sleep_generator = ExponentialBackoff(initial_delay=0, max_value=max_val) + sleep_generator = ExponentialBackoff(max_value=max_val) for i, sleep in enumerate(sleep_generator): if i < index_at_maximum: @@ -23,11 +23,11 @@ class TestExponentialBackoffStrategy: else: assert False, "Iteration should have stopped" - def test_unbounded_exponential_without_initial_delay(self) -> None: + def test_unbounded_exponential(self) -> None: """ Test generating an unbounded exponential backoff strategy series in seconds without an initial delay """ - sleep_generator = ExponentialBackoff(initial_delay=0, max_value=-1) + sleep_generator = ExponentialBackoff(max_value=-1) for i, sleep in enumerate(sleep_generator): assert sleep == 2 ** (i + 1) @@ -36,30 +36,9 @@ class TestExponentialBackoffStrategy: elif i > 20: assert False, "Iteration should have stopped" - def test_exponential_with_initial_delay(self) -> None: - """ - Test generating a bounded exponential backoff strategy series in seconds with an initial delay - """ - for max_val in range(1023, 1026): # numbers around a power of 2 - index_at_maximum = ceil(log2(max_val)) - sleep_generator = ExponentialBackoff(initial_delay=30, max_value=max_val) - - for i, sleep in enumerate(sleep_generator): - if i == 0: - assert sleep == 30 - elif i < index_at_maximum: - assert sleep == 2**i - elif i == index_at_maximum: - assert sleep == max_val - elif i == index_at_maximum + 1: - assert sleep == max_val - sleep_generator.close() - else: - assert False, "Iteration should have stopped" - class TestLinearBackoffStrategy: - def test_linear_without_initial_delay(self) -> None: + def test_linear(self) -> None: """ Test generating a bounded linear backoff strategy series in seconds without an initial delay """ @@ -67,7 +46,7 @@ class TestLinearBackoffStrategy: repetition_factor = 5 for max_val in range((linear_backoff * repetition_factor) - 1, (linear_backoff * repetition_factor) + 2): index_at_maximum = max_val // linear_backoff - sleep_generator = LinearBackoff(initial_delay=0, backoff=linear_backoff, max_value=max_val) + sleep_generator = LinearBackoff(backoff=linear_backoff, max_value=max_val) for i, sleep in enumerate(sleep_generator): if i < index_at_maximum: @@ -80,11 +59,11 @@ class TestLinearBackoffStrategy: else: assert False, "Iteration should have stopped" - def test_unbounded_linear_without_initial_delay(self) -> None: + def test_unbounded_linear(self) -> None: """ Test generating an unbounded linear backoff strategy series in seconds without an initial delay """ - sleep_generator = LinearBackoff(initial_delay=0, backoff=6, max_value=-1) + sleep_generator = LinearBackoff(backoff=6, max_value=-1) for i, sleep in enumerate(sleep_generator): assert sleep == 6 * (i + 
1) @@ -93,36 +72,13 @@ class TestLinearBackoffStrategy: elif i > 200: assert False, "Iteration should have stopped" - def test_linear_with_initial_delay(self) -> None: - """ - Test generating a bounded linear backoff strategy series in seconds with an initial delay - """ - linear_backoff = 5 - repetition_factor = 5 - for max_val in range((linear_backoff * repetition_factor) - 1, (linear_backoff * repetition_factor) + 2): - index_at_maximum = (max_val // linear_backoff) + 1 - sleep_generator = LinearBackoff(initial_delay=30, backoff=linear_backoff, max_value=max_val) - - for i, sleep in enumerate(sleep_generator): - if i == 0: - assert sleep == 30 - elif i < index_at_maximum: - assert sleep == linear_backoff * i - elif i == index_at_maximum: - assert sleep == max_val - elif i == index_at_maximum + 1: - assert sleep == max_val - sleep_generator.close() - else: - assert False, "Iteration should have stopped" - class TestNoBackoffStrategy: - def test_no_backoff_without_initial_delay(self) -> None: + def test_no_backoff(self) -> None: """ Test generating no backoff strategy series in seconds without an initial delay """ - sleep_generator = NoBackoff(initial_delay=0, constant_value=40) + sleep_generator = NoBackoff(constant_value=40) for i, sleep in enumerate(sleep_generator): assert sleep == 40 @@ -130,19 +86,3 @@ class TestNoBackoffStrategy: sleep_generator.close() elif i > 20: assert False, "Iteration should have stopped" - - def test_no_backoff_with_initial_delay(self) -> None: - """ - Test generating no backoff strategy series in seconds with an initial delay - """ - sleep_generator = NoBackoff(initial_delay=20, constant_value=40) - - for i, sleep in enumerate(sleep_generator): - if i == 0: - assert sleep == 20 - else: - assert sleep == 40 - if i == 20: - sleep_generator.close() - elif i > 20: - assert False, "Iteration should have stopped" diff --git a/app/tests/utils/utils.py b/app/tests/utils/utils.py index 18687fb64e90134d5fb7306b6ae55e6211cce761..0eb0bb7bc6be2e4f914ec94bb54f7820aa633b35 100644 --- a/app/tests/utils/utils.py +++ b/app/tests/utils/utils.py @@ -1,10 +1,45 @@ import random import string -from typing import List +from typing import Any, Callable, List, ParamSpec -from sqlalchemy.sql.dml import Delete +from app.utils import Job -CleanupList = List[Delete] +P = ParamSpec("P") + + +class CleanupList: + """ + Helper object to hold a queue of functions that can be executed later + """ + + def __init__(self) -> None: + self.queue: List[Job] = [] + + def add_task(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None: + """ + Add a (async) function to the queue. + + Parameters + ---------- + func : Callable[P, Any] + Function to register. + args : P.args + Arguments to the function. + kwargs : P.kwargs + Keyword arguments to the function. + """ + self.queue.append(Job(func, *args, **kwargs)) + + async def empty_queue(self) -> None: + """ + Empty the queue by executing the registered functions. 
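Editor's note: a short usage sketch of the new CleanupList helper, mirroring what the reworked cleanup fixture does; note that empty_queue pops from the end of the list, so registered tasks run in reverse order:

import asyncio

from app.tests.utils.utils import CleanupList


async def demo() -> None:
    cleanup = CleanupList()

    async def drop_policy_statement() -> None:  # stands in for e.g. s3_policy.delete
        print("policy statement removed")

    cleanup.add_task(drop_policy_statement)
    cleanup.add_task(print, "sync cleanup ran")  # plain callables work as well

    # Runs LIFO: "sync cleanup ran" first, then "policy statement removed"
    await cleanup.empty_queue()


asyncio.run(demo())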
+ """ + while len(self.queue) > 0: + func = self.queue.pop() + if func.is_async: + await func() + else: + func() def random_lower_string(length: int = 32) -> str: diff --git a/app/utils/__init__.py b/app/utils/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..4698bb6d780afbd46ce1090c1d1dfac8670fdece 100644 --- a/app/utils/__init__.py +++ b/app/utils/__init__.py @@ -0,0 +1,3 @@ +from .backoff_strategy import ExponentialBackoff # noqa: F401 +from .job import AsyncJob, Job # noqa: F401 +from .otlp import start_as_current_span_async # noqa: F401 diff --git a/app/utils/backoff_strategy.py b/app/utils/backoff_strategy.py index 3e12ddc7ad50f4a72dc7d7a863bfd5ef0493a103..2dcdc631b88034459b948390259f187d7dbd49d0 100644 --- a/app/utils/backoff_strategy.py +++ b/app/utils/backoff_strategy.py @@ -5,18 +5,14 @@ from typing import Any, Type class BackoffStrategy(ABC, Generator): - def __init__(self, initial_delay: int = 0): + def __init__(self) -> None: """ Initialize the class BackoffStrategy Parameters ---------- - initial_delay : int, default 0 - The initial delay in seconds that should be emitted first. if smaller than 1 then don't emit this value. """ self._current_val = 0 - self._delay = initial_delay - self._delay_first_iteration = initial_delay > 0 self._iteration = 0 self._stop_next = False @@ -66,10 +62,6 @@ class BackoffStrategy(ABC, Generator): if self._stop_next: raise StopIteration self._iteration += 1 - if self._delay_first_iteration: - self._delay_first_iteration = False - self._iteration -= 1 - return self._delay self._current_val = self._compute_next_value(self._iteration) return self._current_val @@ -85,18 +77,16 @@ class ExponentialBackoff(BackoffStrategy): An exponential Backoff strategy based on the power of two. The generated values should be put into a sleep function. """ - def __init__(self, initial_delay: int = 0, max_value: int = 300): + def __init__(self, max_value: int = 300): """ Initialize the exponential BackoffStrategy class Parameters ---------- - initial_delay : int, default 0 - The initial delay in seconds that should be emitted first. if smaller than 1 then don't emit this value. max_value : int, default 300 The maximum this generator can emit. If smaller than 1 then this series is unbounded. """ - super().__init__(initial_delay=initial_delay) + super().__init__() self.max_value = max_value self._reached_max = False @@ -115,18 +105,16 @@ class NoBackoff(BackoffStrategy): No Backoff strategy. It always emits a constant value. The generated values should be put into a sleep function. """ - def __init__(self, initial_delay: int = 0, constant_value: int = 30): + def __init__(self, constant_value: int = 30): """ Initialize the no BackoffStrategy class Parameters ---------- - initial_delay : int, default 0 - The initial delay in seconds that should be emitted first. if smaller than 1 then don't emit this value. constant_value : int, default 30 The constant value this generator should emit. """ - super().__init__(initial_delay=initial_delay) + super().__init__() self._val = constant_value def _compute_next_value(self, iteration: int) -> int: @@ -138,20 +126,18 @@ class LinearBackoff(BackoffStrategy): A linear Backoff strategy. The generated values should be put into a sleep function. 
""" - def __init__(self, initial_delay: int = 0, backoff: int = 5, max_value: int = 300): + def __init__(self, backoff: int = 5, max_value: int = 300): """ Initialize the linear BackoffStrategy class Parameters ---------- - initial_delay : int, default 0 - The initial delay in seconds that should be emitted first. if smaller than 1 then don't emit this value. backoff : int, default 5 The linear factor that is added each iteration. max_value : int, default 300 The maximum this generator can emit. If smaller than 1 then this series is unbounded. """ - super().__init__(initial_delay=initial_delay) + super().__init__() self.max_value = max_value self._backoff = backoff self._reached_max = False diff --git a/app/utils/job.py b/app/utils/job.py new file mode 100644 index 0000000000000000000000000000000000000000..d563d1e804005d49d9b19fe19f7a2ed285148ffd --- /dev/null +++ b/app/utils/job.py @@ -0,0 +1,28 @@ +from inspect import iscoroutinefunction +from typing import Callable, Generic, ParamSpec, TypeVar + +P = ParamSpec("P") +T = TypeVar("T") + + +class Job(Generic[P, T]): + def __init__(self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None: + self.func = func + self.args = args + self.kwargs = kwargs + + @property + def is_async(self) -> bool: + return iscoroutinefunction(self.func) + + def __call__(self) -> T: + return self.func(*self.args, **self.kwargs) + + +class AsyncJob(Job): + def __init__(self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None: + super().__init__(func, *args, **kwargs) + assert iscoroutinefunction(self.func) + + async def __call__(self) -> T: + return await super().__call__()