diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9b2bf6ac8d8d6a02e9785ce6c35ef7f348c59298..2e8793959ae1ebe78b898fe22fef6651b3bafabe 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,13 +15,13 @@ repos: - id: check-merge-conflict - id: check-ast - repo: https://github.com/psf/black - rev: 23.12.0 + rev: 23.12.1 hooks: - id: black files: app args: [--check] - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: 'v0.1.8' + rev: 'v0.1.11' hooks: - id: ruff - repo: https://github.com/PyCQA/isort @@ -31,7 +31,7 @@ repos: files: app args: [-c] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.7.1 + rev: v1.8.0 hooks: - id: mypy files: app diff --git a/app/api/endpoints/resource_version.py b/app/api/endpoints/resource_version.py index ce923649aaa8ece90cb5a38bd3c22650dcf764e6..2da37ad792458bcd60e403aaf78389f70696325b 100644 --- a/app/api/endpoints/resource_version.py +++ b/app/api/endpoints/resource_version.py @@ -1,8 +1,9 @@ -from typing import Annotated, Any, Awaitable, Callable, List, Optional +from typing import Annotated, Any, Awaitable, Callable, List, Union from clowmdb.models import ResourceVersion from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, status from opentelemetry import trace +from pydantic.json_schema import SkipJsonSchema from app.api.dependencies import ( AuthorizationDependency, @@ -43,7 +44,7 @@ async def list_resource_versions( resource: CurrentResource, current_user: CurrentUser, version_status: Annotated[ - Optional[List[ResourceVersion.Status]], + Union[List[ResourceVersion.Status], SkipJsonSchema[None]], Query( description=f"Which versions to include in the response. Permission `resource:read_any` required, current user is the maintainer, then only permission `resource:read` required. 
Default `{ResourceVersion.Status.LATEST.name}` and `{ResourceVersion.Status.SYNCHRONIZED.name}`.", # noqa: E501 @@ -92,7 +93,6 @@ async def request_resource_version( resource_version_in: ResourceVersionIn, current_user: CurrentUser, db: DBSession, - s3: S3Resource, background_tasks: BackgroundTasks, ) -> ResourceVersionOut: """ @@ -131,13 +131,11 @@ async def request_resource_version( resource_version_out = ResourceVersionOut.from_db_resource_version(resource_version) background_tasks.add_task( give_permission_to_s3_resource_version, - s3=s3, resource_version=resource_version_out, maintainer_id=current_user.uid, ) background_tasks.add_task( add_s3_resource_version_info, - s3=s3, s3_resource_version_info=S3ResourceVersionInfo.from_models( resource=resource, resource_version=resource.versions[0], maintainer=current_user ), @@ -233,7 +231,7 @@ async def request_resource_version_sync( detail=f"Can't request sync for resource version with status {resource_version.status.name}", ) resource_version_out = ResourceVersionOut.from_db_resource_version(resource_version) - if await get_s3_resource_version_obj(s3, ResourceVersionOut.from_db_resource_version(resource_version)) is None: + if await get_s3_resource_version_obj(ResourceVersionOut.from_db_resource_version(resource_version), s3=s3) is None: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Missing resource at S3 path {resource_version_out.s3_path}", @@ -244,7 +242,6 @@ async def request_resource_version_sync( resource_version_out.status = ResourceVersion.Status.SYNC_REQUESTED background_tasks.add_task( remove_permission_to_s3_resource_version, - s3=s3, resource_version=resource_version_out, ) return resource_version_out @@ -296,7 +293,7 @@ async def resource_version_sync( detail=f"Can't sync resource version with status {resource_version.status.name}", ) resource_version_out = ResourceVersionOut.from_db_resource_version(resource_version) - if await get_s3_resource_version_obj(s3, 
ResourceVersionOut.from_db_resource_version(resource_version)) is None: + if await get_s3_resource_version_obj(ResourceVersionOut.from_db_resource_version(resource_version), s3=s3) is None: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Missing resource at S3 path {resource_version_out.s3_path}", @@ -306,9 +303,7 @@ async def resource_version_sync( ) resource_version_out.status = ResourceVersion.Status.SYNCHRONIZING - background_tasks.add_task( - synchronize_cluster_resource, db=db, slurm_client=slurm_client, resource_version=resource_version_out - ) + background_tasks.add_task(synchronize_cluster_resource, resource_version=resource_version_out) return resource_version_out @@ -352,9 +347,7 @@ async def resource_version_latest( detail=f"Can't set resource version to {ResourceVersion.Status.LATEST.name} with status {resource_version.status.name}", ) resource_version_out = ResourceVersionOut.from_db_resource_version(resource_version) - background_tasks.add_task( - set_cluster_resource_version_latest, db=db, slurm_client=slurm_client, resource_version=resource_version_out - ) + background_tasks.add_task(set_cluster_resource_version_latest, resource_version=resource_version_out) return resource_version_out @@ -397,9 +390,7 @@ async def delete_resource_version_cluster( ) resource_version.status = ResourceVersion.Status.CLUSTER_DELETED resource_version_out = ResourceVersionOut.from_db_resource_version(resource_version) - background_tasks.add_task( - delete_cluster_resource_version, slurm_client=slurm_client, resource_version=resource_version_out - ) + background_tasks.add_task(delete_cluster_resource_version, resource_version=resource_version_out) return resource_version_out @@ -443,13 +434,11 @@ async def delete_resource_version_s3( resource_version.status = ResourceVersion.Status.S3_DELETED background_tasks.add_task( delete_s3_resource_version, - s3=s3, resource_id=resource.resource_id, resource_version_id=resource_version.resource_version_id, ) 
background_tasks.add_task( remove_permission_to_s3_resource_version, - s3=s3, resource_version=ResourceVersionOut.from_db_resource_version(resource_version), ) diff --git a/app/api/endpoints/resources.py b/app/api/endpoints/resources.py index 615566e639f644401b0307b12464a6ffb777b05a..02f9f14a38a869b97f4d6d0a938e0ecebb9a338b 100644 --- a/app/api/endpoints/resources.py +++ b/app/api/endpoints/resources.py @@ -1,26 +1,19 @@ -from typing import Annotated, Any, Awaitable, Callable, List, Optional +from typing import Annotated, Any, Awaitable, Callable, List, Union from clowmdb.models import ResourceVersion -from fastapi import APIRouter, BackgroundTasks, Depends, Query, status +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, status from opentelemetry import trace +from pydantic.json_schema import SkipJsonSchema -from app.api.dependencies import ( - AuthorizationDependency, - CurrentResource, - CurrentUser, - DBSession, - S3Resource, - SlurmClient, -) +from app.api.dependencies import AuthorizationDependency, CurrentResource, CurrentUser, DBSession from app.api.resource_cluster_utils import delete_cluster_resource from app.api.resource_s3_utils import ( add_s3_resource_version_info, - delete_s3_bucket_policy_stmt, + delete_resource_policy_stmt, delete_s3_resource, give_permission_to_s3_resource, give_permission_to_s3_resource_version, ) -from app.core.config import settings from app.crud import CRUDResource from app.schemas.resource import ResourceIn, ResourceOut, S3ResourceVersionInfo from app.utils.otlp import start_as_current_span_async @@ -38,7 +31,7 @@ async def list_resources( current_user: CurrentUser, db: DBSession, maintainer_id: Annotated[ - Optional[str], + Union[str, SkipJsonSchema[None]], Query( max_length=64, description="Filter for resource by maintainer. 
If current user is the same as maintainer ID, permission `resource:list` required, otherwise `resource:list_filter`.", @@ -46,13 +39,13 @@ async def list_resources( ), ] = None, version_status: Annotated[ - Optional[List[ResourceVersion.Status]], + Union[List[ResourceVersion.Status], SkipJsonSchema[None]], Query( description=f"Which versions of the resource to include in the response. Permission `resource:list_filter` required, unless `maintainer_id` is provided and current user is maintainer, then only permission `resource:list` required. Default `{ResourceVersion.Status.LATEST.name}` and `{ResourceVersion.Status.SYNCHRONIZED.name}`.", # noqa: E501 ), ] = None, - name_substring: Annotated[Optional[str], Query(max_length=32)] = None, + name_substring: Annotated[Union[str, SkipJsonSchema[None]], Query(max_length=32)] = None, ) -> List[ResourceOut]: """ List all resources. @@ -101,12 +94,11 @@ async def list_resources( @router.post("", summary="Request a new resource", status_code=status.HTTP_201_CREATED) @start_as_current_span_async("api_request_resource", tracer=tracer) -async def request_resource( +async def create_resource( authorization: Authorization, current_user: CurrentUser, resource_in: ResourceIn, db: DBSession, - s3: S3Resource, background_tasks: BackgroundTasks, ) -> ResourceOut: """ @@ -133,18 +125,20 @@ async def request_resource( current_span = trace.get_current_span() current_span.set_attribute("resource_in", resource_in.model_dump_json(indent=2)) await authorization("create") + if await CRUDResource.get_by_name(db=db, name=resource_in.name) is not None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=f"Resource with name '{resource_in.name}' already exists" + ) resource = await CRUDResource.create(db, resource_in=resource_in, maintainer_id=current_user.uid) resource_out = ResourceOut.from_db_resource(db_resource=resource) - background_tasks.add_task(give_permission_to_s3_resource, s3=s3, resource=resource_out) + 
background_tasks.add_task(give_permission_to_s3_resource, resource=resource_out) background_tasks.add_task( give_permission_to_s3_resource_version, - s3=s3, resource_version=resource_out.versions[0], maintainer_id=resource_out.maintainer_id, ) background_tasks.add_task( add_s3_resource_version_info, - s3=s3, s3_resource_version_info=S3ResourceVersionInfo.from_models( resource=resource, resource_version=resource.versions[0], maintainer=current_user ), @@ -159,7 +153,7 @@ async def get_resource( resource: CurrentResource, current_user: CurrentUser, version_status: Annotated[ - Optional[List[ResourceVersion.Status]], + Union[List[ResourceVersion.Status], SkipJsonSchema[None]], Query( description=f"Which versions of the resource to include in the response. Permission `resource:read_any` required, unless the current user is the maintainer, then only permission `resource:read` required. Default `{ResourceVersion.Status.LATEST.name}` and `{ResourceVersion.Status.SYNCHRONIZED.name}`.", # noqa: E501 @@ -202,9 +196,7 @@ async def delete_resource( authorization: Authorization, resource: CurrentResource, db: DBSession, - s3: S3Resource, background_tasks: BackgroundTasks, - slurm_client: SlurmClient, ) -> None: """ Delete a resources. @@ -223,18 +215,10 @@ async def delete_resource( Entrypoint for new BackgroundTasks. Provided by FastAPI. s3 : types_aiobotocore_s3.service_resource import S3ServiceResource S3 Service to perform operations on buckets. Dependency Injection. - slurm_client : app.slurm.rest_client.SlurmClient - Slurm client with an open connection. 
Dependency Injection """ trace.get_current_span().set_attribute("resource_id", str(resource.resource_id)) await authorization("delete") await CRUDResource.delete(db, resource_id=resource.resource_id) - background_tasks.add_task( - delete_s3_bucket_policy_stmt, - s3=s3, - bucket_name=settings.RESOURCE_BUCKET, - sid=[str(resource.resource_id)] - + [str(resource_version.resource_version_id) for resource_version in resource.versions], - ) - background_tasks.add_task(delete_s3_resource, s3=s3, resource_id=resource.resource_id) - background_tasks.add_task(delete_cluster_resource, slurm_client=slurm_client, resource_id=resource.resource_id) + background_tasks.add_task(delete_resource_policy_stmt, resource=ResourceOut.from_db_resource(resource)) + background_tasks.add_task(delete_s3_resource, resource_id=resource.resource_id) + background_tasks.add_task(delete_cluster_resource, resource_id=resource.resource_id) diff --git a/app/api/resource_cluster_utils.py b/app/api/resource_cluster_utils.py index 6bc92257370584d643bd42539fce0f9f33955939..e3e11b1e4fc191ae0953df9e31fedc44d69cd00b 100644 --- a/app/api/resource_cluster_utils.py +++ b/app/api/resource_cluster_utils.py @@ -1,14 +1,14 @@ from asyncio import sleep as async_sleep from pathlib import Path -from typing import Optional +from typing import Any, AsyncGenerator, Optional from uuid import UUID from clowmdb.models import ResourceVersion -from httpx import HTTPError +from httpx import AsyncClient, HTTPError from mako.template import Template from opentelemetry import trace -from sqlalchemy.ext.asyncio import AsyncSession +from app.api.dependencies import get_db, get_slurm_client from app.core.config import settings from app.crud import CRUDResourceVersion from app.schemas.resource_version import ResourceVersionOut, resource_dir_name, resource_version_dir_name @@ -23,9 +23,17 @@ delete_resource_version_script_template = Template(filename="app/mako_templates/ tracer = trace.get_tracer_provider().get_tracer(__name__) -async 
def synchronize_cluster_resource( - db: AsyncSession, slurm_client: SlurmClient, resource_version: ResourceVersionOut -) -> None: +async def get_async_http_client() -> AsyncGenerator[AsyncClient, None]: + async with AsyncClient() as client: + yield client + + +async def update_db_resource_wrapper(**kwargs: Any) -> None: + async for db in get_db(): + await CRUDResourceVersion.update_status(db=db, **kwargs) + + +async def synchronize_cluster_resource(resource_version: ResourceVersionOut) -> None: """ Synchronize a resource to the cluster @@ -33,19 +41,15 @@ async def synchronize_cluster_resource( ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - slurm_client : app.slurm.rest_client.SlurmClient - Slurm Rest Client to communicate with Slurm cluster. resource_version : app.schemas.resource_version.ResourceVersionOut Resource version schema to synchronize. """ - synchronization_script = synchronization_script_template.render( s3_url=settings.OBJECT_GATEWAY_URI, resource_version_s3_folder=f"{settings.RESOURCE_BUCKET}/{resource_version_dir_name(resource_version.resource_id, resource_version.resource_version_id)}", ) failed_job = AsyncJob( - CRUDResourceVersion.update_status, - db=db, + update_db_resource_wrapper, resource_version_id=resource_version.resource_version_id, status=ResourceVersion.Status.SYNC_REQUESTED, resource_id=None, @@ -71,26 +75,26 @@ async def synchronize_cluster_resource( }, }, ) - # Try to start the job on the slurm cluster - slurm_job_id = await slurm_client.submit_job(job_submission=job_submission) - await _monitor_proper_job_execution( - slurm_client=slurm_client, - slurm_job_id=slurm_job_id, - success_job=AsyncJob( - CRUDResourceVersion.update_status, - db=db, - resource_version_id=resource_version.resource_version_id, - status=ResourceVersion.Status.SYNCHRONIZED, - resource_id=None, - ), - failed_job=failed_job, - ) + async for client in get_async_http_client(): + slurm_client = 
get_slurm_client(client) + # Try to start the job on the slurm cluster + slurm_job_id = await slurm_client.submit_job(job_submission=job_submission) + await _monitor_proper_job_execution( + slurm_client=slurm_client, + slurm_job_id=slurm_job_id, + success_job=AsyncJob( + update_db_resource_wrapper, + resource_version_id=resource_version.resource_version_id, + status=ResourceVersion.Status.SYNCHRONIZED, + resource_id=None, + ), + failed_job=failed_job, + ) except (HTTPError, KeyError): # pragma: no cover await failed_job() async def delete_cluster_resource( - slurm_client: SlurmClient, resource_id: UUID, ) -> None: """ @@ -98,8 +102,6 @@ async def delete_cluster_resource( Parameters ---------- - slurm_client : app.slurm.rest_client.SlurmClient - Slurm Rest Client to communicate with Slurm cluster. resource_id : uuid.UUID ID of the resource to delete """ @@ -119,25 +121,21 @@ async def delete_cluster_resource( }, }, ) - # Try to start the job on the slurm cluster - slurm_job_id = await slurm_client.submit_job(job_submission=job_submission) - await _monitor_proper_job_execution(slurm_client=slurm_client, slurm_job_id=slurm_job_id) + async for client in get_async_http_client(): + slurm_client = get_slurm_client(client) + # Try to start the job on the slurm cluster + slurm_job_id = await slurm_client.submit_job(job_submission=job_submission) + await _monitor_proper_job_execution(slurm_client=slurm_client, slurm_job_id=slurm_job_id) except (HTTPError, KeyError): # pragma: no cover pass -async def set_cluster_resource_version_latest( - db: AsyncSession, slurm_client: SlurmClient, resource_version: ResourceVersionOut -) -> None: +async def set_cluster_resource_version_latest(resource_version: ResourceVersionOut) -> None: """ Synchronize a resource to the cluster Parameters ---------- - db : sqlalchemy.ext.asyncio.AsyncSession. - Async database session to perform query on. 
- slurm_client : app.slurm.rest_client.SlurmClient - Slurm Rest Client to communicate with Slurm cluster. resource_version : app.schemas.resource_version.ResourceVersionOut Resource version schema to synchronize. """ @@ -173,26 +171,26 @@ async def set_cluster_resource_version_latest( }, }, ) - # Try to start the job on the slurm cluster - slurm_job_id = await slurm_client.submit_job(job_submission=job_submission) + async for client in get_async_http_client(): + slurm_client = get_slurm_client(client) + # Try to start the job on the slurm cluster + slurm_job_id = await slurm_client.submit_job(job_submission=job_submission) - await _monitor_proper_job_execution( - slurm_client=slurm_client, - slurm_job_id=slurm_job_id, - success_job=AsyncJob( - CRUDResourceVersion.update_status, - db=db, - status=ResourceVersion.Status.LATEST, - resource_version_id=resource_version.resource_version_id, - resource_id=resource_version.resource_id, - ), - ) + await _monitor_proper_job_execution( + slurm_client=slurm_client, + slurm_job_id=slurm_job_id, + success_job=AsyncJob( + update_db_resource_wrapper, + status=ResourceVersion.Status.LATEST, + resource_version_id=resource_version.resource_version_id, + resource_id=resource_version.resource_id, + ), + ) except (HTTPError, KeyError): # pragma: no cover pass async def delete_cluster_resource_version( - slurm_client: SlurmClient, resource_version: ResourceVersionOut, ) -> None: """ @@ -200,8 +198,6 @@ async def delete_cluster_resource_version( Parameters ---------- - slurm_client : app.slurm.rest_client.SlurmClient - Slurm Rest Client to communicate with Slurm cluster. resource_version : app.schemas.resource_version.ResourceVersionOut Resource version schema to delete. 
""" @@ -237,9 +233,11 @@ async def delete_cluster_resource_version( }, }, ) - # Try to start the job on the slurm cluster - slurm_job_id = await slurm_client.submit_job(job_submission=job_submission) - await _monitor_proper_job_execution(slurm_client=slurm_client, slurm_job_id=slurm_job_id) + async for client in get_async_http_client(): + slurm_client = get_slurm_client(client) + # Try to start the job on the slurm cluster + slurm_job_id = await slurm_client.submit_job(job_submission=job_submission) + await _monitor_proper_job_execution(slurm_client=slurm_client, slurm_job_id=slurm_job_id) except (HTTPError, KeyError): # pragma: no cover pass diff --git a/app/api/resource_s3_utils.py b/app/api/resource_s3_utils.py index 4e274d2adcee8d3fd3f5e409a69855aab11a7240..c6a0c60509a4d2fecc8ad486653200fe0f9336cd 100644 --- a/app/api/resource_s3_utils.py +++ b/app/api/resource_s3_utils.py @@ -1,5 +1,5 @@ from io import BytesIO -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Optional from uuid import UUID from opentelemetry import trace @@ -7,6 +7,7 @@ from opentelemetry import trace from app.core.config import settings from app.s3.s3_resource import ( add_s3_bucket_policy_stmt, + boto_session, delete_s3_bucket_policy_stmt, delete_s3_objects, get_s3_object, @@ -30,14 +31,21 @@ else: tracer = trace.get_tracer_provider().get_tracer(__name__) -async def give_permission_to_s3_resource(s3: S3ServiceResource, resource: ResourceOut) -> None: +async def get_s3_resource() -> AsyncGenerator[S3ServiceResource, None]: + async with boto_session.resource( + service_name="s3", + endpoint_url=str(settings.OBJECT_GATEWAY_URI)[:-1], + verify=str(settings.OBJECT_GATEWAY_URI).startswith("https"), + ) as s3_resource: + yield s3_resource + + +async def give_permission_to_s3_resource(resource: ResourceOut) -> None: """ Give the maintainer permissions to list the S3 objects of his resource. 
Parameters ---------- - s3 : types_aiobotocore_s3.service_resource import S3ServiceResource - S3 Service to perform operations on buckets. resource : app.schemas.resource.ResourceOut Information about the resource. """ @@ -49,19 +57,16 @@ async def give_permission_to_s3_resource(s3: S3ServiceResource, resource: Resour "Action": ["s3:ListBucket"], "Condition": {"StringLike": {"s3:prefix": resource_dir_name(resource.resource_id) + "/*"}}, } - await add_s3_bucket_policy_stmt(s3=s3, bucket_name=settings.RESOURCE_BUCKET, stmt=policy) + async for s3 in get_s3_resource(): + await add_s3_bucket_policy_stmt(s3=s3, bucket_name=settings.RESOURCE_BUCKET, stmt=policy) -async def give_permission_to_s3_resource_version( - s3: S3ServiceResource, resource_version: ResourceVersionOut, maintainer_id: str -) -> None: +async def give_permission_to_s3_resource_version(resource_version: ResourceVersionOut, maintainer_id: str) -> None: """ Give the maintainer permissions to upload the resource to the appropriate S3 key. Parameters ---------- - s3 : types_aiobotocore_s3.service_resource import S3ServiceResource - S3 Service to perform operations on buckets. resource_version : app.schemas.resource_version.ResourceVersionOut Information about the resource version. 
maintainer_id : str @@ -74,10 +79,11 @@ async def give_permission_to_s3_resource_version( "Resource": f"arn:aws:s3:::{resource_version.s3_path[5:]}", "Action": ["s3:DeleteObject", "s3:PutObject"], } - await add_s3_bucket_policy_stmt(s3=s3, bucket_name=settings.RESOURCE_BUCKET, stmt=policy) + async for s3 in get_s3_resource(): + await add_s3_bucket_policy_stmt(s3=s3, bucket_name=settings.RESOURCE_BUCKET, stmt=policy) -async def add_s3_resource_version_info(s3: S3ServiceResource, s3_resource_version_info: S3ResourceVersionInfo) -> None: +async def add_s3_resource_version_info(s3_resource_version_info: S3ResourceVersionInfo) -> None: """ Upload the resource version information to S3 for documentation @@ -89,85 +95,109 @@ async def add_s3_resource_version_info(s3: S3ServiceResource, s3_resource_versio Resource information object that will be uploaded to S3. """ buf = BytesIO(s3_resource_version_info.model_dump_json(indent=2).encode("utf-8")) - await upload_obj( - s3=s3, - bucket_name=settings.RESOURCE_BUCKET, - key=s3_resource_version_info.s3_path(), - obj_stream=buf, - ExtraArgs={"ContentType": "application/json"}, - ) + async for s3 in get_s3_resource(): + await upload_obj( + s3=s3, + bucket_name=settings.RESOURCE_BUCKET, + key=s3_resource_version_info.s3_path(), + obj_stream=buf, + ExtraArgs={"ContentType": "application/json"}, + ) -async def remove_permission_to_s3_resource_version(s3: S3ServiceResource, resource_version: ResourceVersionOut) -> None: +async def remove_permission_to_s3_resource_version(resource_version: ResourceVersionOut) -> None: """ Remove the permission of the maintainer to upload a resource to S3. Parameters ---------- - s3 : types_aiobotocore_s3.service_resource import S3ServiceResource - S3 Service to perform operations on buckets. resource_version : app.schemas.resource_version.ResourceVersionOut Information about the resource version. 
""" - await delete_s3_bucket_policy_stmt( - s3, bucket_name=settings.RESOURCE_BUCKET, sid=str(resource_version.resource_version_id) - ) + async for s3 in get_s3_resource(): + await delete_s3_bucket_policy_stmt( + s3, bucket_name=settings.RESOURCE_BUCKET, sid=str(resource_version.resource_version_id) + ) -async def delete_s3_resource(s3: S3ServiceResource, resource_id: UUID) -> None: +async def delete_s3_resource(resource_id: UUID) -> None: """ Delete all objects related to a resource in S3. Parameters ---------- - s3 : types_aiobotocore_s3.service_resource import S3ServiceResource - S3 Service to perform operations on buckets. resource_id : uuid.UUID ID of the resource """ - await delete_s3_objects(s3, bucket_name=settings.RESOURCE_BUCKET, prefix=resource_dir_name(resource_id) + "/") + async for s3 in get_s3_resource(): + await delete_s3_objects(s3, bucket_name=settings.RESOURCE_BUCKET, prefix=resource_dir_name(resource_id) + "/") -async def delete_s3_resource_version(s3: S3ServiceResource, resource_id: UUID, resource_version_id: UUID) -> None: +async def delete_s3_resource_version(resource_id: UUID, resource_version_id: UUID) -> None: """ Delete all objects related to a resource version in S3. Parameters ---------- - s3 : types_aiobotocore_s3.service_resource import S3ServiceResource - S3 Service to perform operations on buckets. 
resource_id : uuid.UUID ID of the resource resource_version_id : uuid.UUID ID of the resource version """ - await delete_s3_objects( - s3, bucket_name=settings.RESOURCE_BUCKET, prefix=resource_version_dir_name(resource_id, resource_version_id) - ) + async for s3 in get_s3_resource(): + await delete_s3_objects( + s3, bucket_name=settings.RESOURCE_BUCKET, prefix=resource_version_dir_name(resource_id, resource_version_id) + ) async def get_s3_resource_version_obj( - s3: S3ServiceResource, resource_version: ResourceVersionOut + resource_version: ResourceVersionOut, + s3: Optional[S3ServiceResource] = None, ) -> Optional[ObjectSummary]: """ Delete all objects related to a resource version in S3. Parameters ---------- - s3 : types_aiobotocore_s3.service_resource import S3ServiceResource - S3 Service to perform operations on buckets. resource_version : app.schemas.resource_version.ResourceVersionOut Information about the resource version. + s3 : types_aiobotocore_s3.service_resource import S3ServiceResource | None + S3 Service to perform operations on buckets. If None, a new connection will be established. Returns ------- obj : types_aiobotocore_s3.service_resource.ObjectSummary | None The object summary of the resource version if it exists. 
""" - return await get_s3_object( - s3=s3, - bucket_name=settings.RESOURCE_BUCKET, - key=resource_version_key( - resource_id=resource_version.resource_id, resource_version_id=resource_version.resource_version_id - ), - ) + + async def func(s3_inner: S3ServiceResource) -> Optional[ObjectSummary]: + return await get_s3_object( + s3=s3_inner, + bucket_name=settings.RESOURCE_BUCKET, + key=resource_version_key( + resource_id=resource_version.resource_id, resource_version_id=resource_version.resource_version_id + ), + ) + + if s3 is None: + async for s3_outer in get_s3_resource(): + return await func(s3_outer) + return await func(s3) # type: ignore[arg-type] + + +async def delete_resource_policy_stmt(resource: ResourceOut) -> None: + """ + Delete all bucket policy statements regarding this resource. + + Parameters + ---------- + resource : app.schemas.resource.ResourceOut + Information about the resource. + """ + async for s3 in get_s3_resource(): + await delete_s3_bucket_policy_stmt( + s3=s3, + bucket_name=settings.RESOURCE_BUCKET, + sid=[str(resource.resource_id)] + + [str(resource_version.resource_version_id) for resource_version in resource.versions], + ) diff --git a/app/crud/crud_resource.py b/app/crud/crud_resource.py index 20a3b235d9da287c77feed67e9e3476f0683e469..ece5f60ec9975d9ceffa51d1a761cedf8d104164 100644 --- a/app/crud/crud_resource.py +++ b/app/crud/crud_resource.py @@ -38,6 +38,28 @@ class CRUDResource: ): return await db.scalar(stmt) + @staticmethod + async def get_by_name(db: AsyncSession, name: str) -> Optional[Resource]: + """ + Get a resource by its name. + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + name: str + Name of a resource.
+ + Returns + ------- + resource : clowmdb.models.Resource | None + The resource with the given name if it exists, None otherwise + """ + + stmt = select(Resource).where(Resource.name == name) + with tracer.start_as_current_span("db_get_resource_by_name", attributes={"name": name, "sql_query": str(stmt)}): + return await db.scalar(stmt) + @staticmethod async def list_resources( db: AsyncSession, diff --git a/app/tests/api/test_resource.py b/app/tests/api/test_resource.py index 9559fb13faeeace6f47ee34e8ce3dc03a6aa385f..57d14f32218536c0f19ef6c86f825bb64371781d 100644 --- a/app/tests/api/test_resource.py +++ b/app/tests/api/test_resource.py @@ -82,6 +82,33 @@ class TestResourceRouteCreate(_TestResourceRoutes): resource_db = await db.scalar(select(Resource).where(Resource._resource_id == resource_out.resource_id.bytes)) assert resource_db is not None + @pytest.mark.asyncio + async def test_create_duplicated_resource_route( + self, client: AsyncClient, random_user: UserWithAuthHeader, db: AsyncSession, random_resource: Resource + ) -> None: + """ + Test for creating a duplicated resource. + + Parameters + ---------- + client : httpx.AsyncClient + HTTP Client to perform the request on. + random_user : app.tests.utils.user.UserWithAuthHeader + Random user for testing. + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + random_resource : clowmdb.models.Resource + Random resource for testing. 
+ """ + resource = ResourceIn( + release=random_lower_string(6), + name=random_resource.name, + description=random_lower_string(16), + source=random_lower_string(8), + ) + response = await client.post(self.base_path, json=resource.model_dump(), headers=random_user.auth_headers) + assert response.status_code == status.HTTP_400_BAD_REQUEST + class TestResourceRouteDelete(_TestResourceRoutes): @pytest.mark.asyncio diff --git a/app/tests/conftest.py b/app/tests/conftest.py index 57daa41f575025d99988a985440c507e305d85b4..9aa6dc593cd6b0a1826f5d9be660aa91d2547e7a 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -3,17 +3,17 @@ import json from functools import partial from io import BytesIO from secrets import token_urlsafe -from typing import AsyncIterator, Dict, Iterator +from typing import AsyncIterator, Iterator import httpx import pytest import pytest_asyncio from clowmdb.db.session import get_async_session from clowmdb.models import Resource, ResourceVersion -from pytrie import SortedStringTrie as Trie from sqlalchemy import delete, update from sqlalchemy.ext.asyncio import AsyncSession +from app.api import dependencies, resource_cluster_utils, resource_s3_utils from app.api.dependencies import get_db, get_decode_jwt_function, get_httpx_client, get_s3_resource from app.core.config import settings from app.main import app @@ -64,6 +64,52 @@ def mock_opa_service() -> Iterator[MockOpaService]: mock_opa.reset() +@pytest_asyncio.fixture(scope="session") +async def mock_client( + mock_opa_service: MockOpaService, + mock_slurm_cluster: MockSlurmCluster, +) -> AsyncIterator[httpx.AsyncClient]: + def mock_request_handler(request: httpx.Request) -> httpx.Response: + url = str(request.url) + handler: MockHTTPService + if url.startswith(str(settings.SLURM_ENDPOINT)): + handler = mock_slurm_cluster + elif url.startswith(str(settings.OPA_URI)): + handler = mock_opa_service + else: + handler = DefaultMockHTTPService() + return handler.handle_request(request) + + 
async with httpx.AsyncClient(transport=httpx.MockTransport(mock_request_handler)) as http_client: + yield http_client + + +@pytest.fixture(autouse=True) +def monkeypatch_background_tasks( + monkeypatch: pytest.MonkeyPatch, + mock_client: httpx.AsyncClient, + db: AsyncSession, + mock_s3_service: MockS3ServiceResource, +) -> None: + """ + Patch the functions to get resources in background tasks with mock resources. + """ + + async def get_http_client() -> AsyncIterator[httpx.AsyncClient]: + yield mock_client + + async def get_patch_db() -> AsyncIterator[AsyncSession]: + yield db + + async def get_s3() -> AsyncIterator[MockS3ServiceResource]: + yield mock_s3_service + + monkeypatch.setattr(dependencies, "get_db", get_patch_db) + monkeypatch.setattr(resource_s3_utils, "get_s3_resource", get_s3) + monkeypatch.setattr(resource_cluster_utils, "get_async_http_client", get_http_client) + monkeypatch.setenv("OTLP_GRPC_ENDPOINT", "") + + @pytest_asyncio.fixture(scope="module") async def client( mock_s3_service: MockS3ServiceResource, @@ -75,12 +121,6 @@ async def client( Fixture for creating a TestClient and perform HTTP Request on it. Overrides several dependencies. 
""" - endpoints: Dict[str, MockHTTPService] = { - str(settings.SLURM_ENDPOINT): mock_slurm_cluster, - str(settings.OPA_URI): mock_opa_service, - } - # data structure to easily find the appropriate mock request based on the URL - t = Trie(**endpoints) async def get_mock_httpx_client( raise_opa_error: bool = False, raise_slurm_error: bool = False, raise_error: bool = False @@ -106,7 +146,14 @@ async def client( def mock_request_handler(request: httpx.Request) -> httpx.Response: url = str(request.url) - return t.longest_prefix_value(url, DefaultMockHTTPService()).handle_request(request, **errors) + handler: MockHTTPService + if url.startswith(str(settings.SLURM_ENDPOINT)): + handler = mock_slurm_cluster + elif url.startswith(str(settings.OPA_URI)): + handler = mock_opa_service + else: + handler = DefaultMockHTTPService() + return handler.handle_request(request, **errors) async with httpx.AsyncClient(transport=httpx.MockTransport(mock_request_handler)) as http_client: yield http_client diff --git a/app/tests/crud/test_resource.py b/app/tests/crud/test_resource.py index b5a149891c8fb3be5c7333c8018bd102ce5e84b9..bc6e41c102d1447973385186fc65df30eb7ac968 100644 --- a/app/tests/crud/test_resource.py +++ b/app/tests/crud/test_resource.py @@ -43,6 +43,35 @@ class TestResourceCRUDGet: resource = await CRUDResource.get(db, resource_id=uuid4()) assert resource is None + @pytest.mark.asyncio + async def test_get_resource_by_name(self, db: AsyncSession, random_resource: Resource) -> None: + """ + Test for getting a resource by name from the database + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + random_resource : clowmdb.models.Resource + Random resource for testing. 
+ """ + resource = await CRUDResource.get_by_name(db, name=random_resource.name) + assert resource is not None + assert resource == random_resource + + @pytest.mark.asyncio + async def test_get_non_existing_resource_by_name(self, db: AsyncSession) -> None: + """ + Test for getting a non-existing resource by name from the database + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + """ + resource = await CRUDResource.get_by_name(db=db, name=random_lower_string()) + assert resource is None + class TestResourceCRUDList: @pytest.mark.asyncio diff --git a/pyproject.toml b/pyproject.toml index e48e191ed42e52535b20dea692f0263c77c27b55..c1825ab66eb937ecd4857c45c036dadee5c9f29c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ line-length = 120 [tool.ruff] line-length = 120 -target-version = "py311" +target-version = "py312" [tool.mypy] plugins = ["pydantic.mypy", "sqlalchemy.ext.mypy.plugin"] diff --git a/requirements-dev.txt b/requirements-dev.txt index 174bb00e549b7d9625c6a149ca30934f701a1477..bb9b40c398cf581c4283c86c02f1b1ec65e57f08 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,15 +2,14 @@ pytest>=7.4.0,<7.5.0 pytest-asyncio>=0.21.0,<0.22.0 pytest-cov>=4.1.0,<4.2.0 -coverage[toml]>=7.3.0,<7.4.0 +coverage[toml]>=7.4.0,<7.5.0 # Linters ruff>=0.1.0,<0.2.0 black>=23.12.0,<24.0.0 isort>=5.13.0,<5.14.0 -mypy>=1.7.0,<1.8.0 +mypy>=1.8.0,<1.9.0 # stubs for mypy -types-aiobotocore-lite[s3]>=2.8.0,<2.9.0 +types-aiobotocore-lite[s3]>=2.9.0,<2.10.0 types-requests # Miscellaneous pre-commit>=3.6.0,<3.7.0 -PyTrie>=0.4.0,<0.5.0 diff --git a/requirements.txt b/requirements.txt index 9b690e87b5fc1394d0da433925b5c64bca20c83a..2b4edb3782e55603019497fd150047c520fbf438 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,24 +2,21 @@ clowmdb>=2.3.0,<2.4.0 # Webserver packages -anyio>=3.7.0,<4.0.0 -fastapi>=0.105.0,<0.106.0 +fastapi>=0.108.0,<0.109.0 pydantic>=2.5.0,<2.6.0 
pydantic-settings>=2.1.0,<2.2.0 -uvicorn>=0.24.0,<0.25.0 -python-multipart +uvicorn>=0.25.0,<0.26.0 # Database packages PyMySQL>=1.1.0,<1.2.0 SQLAlchemy>=2.0.0,<2.1.0 aiomysql>=0.2.0,<0.3.0 # Security packages -authlib>=1.2.0,<1.3.0 +authlib>=1.3.0,<1.4.0 # S3 packages aioboto3>=12.0.0,<13.0.0 # Miscellaneous tenacity>=8.2.0,<8.3.0 -httpx>=0.25.0,<0.26.0 -itsdangerous +httpx>=0.26.0,<0.27.0 # template engine mako>=1.3.0,<1.4.0 python-dotenv