diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index eccf65b4903be975909710869d7e80e8e8890ae6..a3d623f0252638a3c0c684a32866b635734ed935 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -1,4 +1,4 @@
-image: ${CI_DEPENDENCY_PROXY_DIRECT_GROUP_IMAGE_PREFIX}/python:3.11-slim
+image: ${CI_DEPENDENCY_PROXY_DIRECT_GROUP_IMAGE_PREFIX}/python:3.12-slim
 
 variables:
   PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
@@ -49,7 +49,7 @@ integration-test-job: # Runs integration tests with the database
       MYSQL_DATABASE: "$DB_DATABASE"
       MYSQL_USER: "$DB_USER"
       MYSQL_PASSWORD: "$DB_PASSWORD"
-    - name: $CI_REGISTRY/cmg/clowm/clowm-database:v2.3
+    - name: $CI_REGISTRY/cmg/clowm/clowm-database:v3.0
       alias: upgrade-db
   script:
     - python app/check_database_connection.py
@@ -77,7 +77,7 @@ e2e-test-job: # Runs e2e tests on the API endpoints
       MYSQL_DATABASE: "$DB_DATABASE"
       MYSQL_USER: "$DB_USER"
       MYSQL_PASSWORD: "$DB_PASSWORD"
-    - name: $CI_REGISTRY/cmg/clowm/clowm-database:v2.3
+    - name: $CI_REGISTRY/cmg/clowm/clowm-database:v3.0
       alias: upgrade-db
   script:
     - python app/check_database_connection.py
@@ -135,30 +135,30 @@ lint-test-job: # Runs linters checks on code
 
 publish-dev-docker-container-job:
   stage: deploy
   image:
-    name: gcr.io/kaniko-project/executor:v1.17.0-debug
+    name: gcr.io/kaniko-project/executor:v1.20.0-debug
     entrypoint: [""]
   dependencies: []
   only:
     refs:
-      - development
+      - main
   before_script:
     - echo "{\"auths\":{\"${CI_REGISTRY}\":{\"auth\":\"$(printf "%s:%s" "${CI_REGISTRY_USER}" "${CI_REGISTRY_PASSWORD}" | base64 | tr -d '\n')\"},\"$CI_DEPENDENCY_PROXY_SERVER\":{\"auth\":\"$(printf "%s:%s" ${CI_DEPENDENCY_PROXY_USER} "${CI_DEPENDENCY_PROXY_PASSWORD}" | base64 | tr -d '\n')\"}}}" > /kaniko/.docker/config.json
   script:
     - /kaniko/executor --context "${CI_PROJECT_DIR}"
       --dockerfile "${CI_PROJECT_DIR}/Dockerfile"
-      --destination "${CI_REGISTRY_IMAGE}:dev-${CI_COMMIT_SHA}"
-      --destination "${CI_REGISTRY_IMAGE}:dev-latest"
+      --destination "${CI_REGISTRY_IMAGE}:main-${CI_COMMIT_SHA}"
+      --destination "${CI_REGISTRY_IMAGE}:main-latest"
     - /kaniko/executor --context "${CI_PROJECT_DIR}"
       --dockerfile "${CI_PROJECT_DIR}/Dockerfile-Gunicorn"
-      --destination "${CI_REGISTRY_IMAGE}:dev-${CI_COMMIT_SHA}-gunicorn"
-      --destination "${CI_REGISTRY_IMAGE}:dev-latest-gunicorn"
+      --destination "${CI_REGISTRY_IMAGE}:main-${CI_COMMIT_SHA}-gunicorn"
+      --destination "${CI_REGISTRY_IMAGE}:main-latest-gunicorn"
 
 publish-docker-container-job:
   stage: deploy
   image:
-    name: gcr.io/kaniko-project/executor:v1.17.0-debug
+    name: gcr.io/kaniko-project/executor:v1.20.0-debug
     entrypoint: [""]
   dependencies: []
   only:
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3c51ad992d1c0b6446e52061aad32294fa355a8b..a6a07950c9b673cc93736742574512507c0346ef 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -15,23 +15,23 @@ repos:
   - id: check-merge-conflict
   - id: check-ast
 - repo: https://github.com/psf/black
-  rev: 23.11.0
+  rev: 23.12.1
   hooks:
   - id: black
     files: app
     args: [--check]
 - repo: https://github.com/charliermarsh/ruff-pre-commit
-  rev: 'v0.1.7'
+  rev: 'v0.1.14'
   hooks:
   - id: ruff
 - repo: https://github.com/PyCQA/isort
-  rev: 5.12.0
+  rev: 5.13.2
   hooks:
   - id: isort
     files: app
     args: [-c]
 - repo: https://github.com/pre-commit/mirrors-mypy
-  rev: v1.7.1
+  rev: v1.8.0
   hooks:
   - id: mypy
     files: app
diff --git a/Dockerfile b/Dockerfile
index d364a9b2efbb0abeb5800317b3544c8c059f56f7..4169882763ee9c166ccf56afd63d7bb21cc7dcbf 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,11 +1,16 @@
-FROM python:3.11-slim
-EXPOSE 8000
+FROM python:3.12-slim
+ENV PORT=8000
+EXPOSE $PORT
+
 # dumb-init forwards the kill signal to the python process
-RUN apt-get update && apt-get -y install dumb-init curl
+RUN apt-get update && apt-get -y install dumb-init
+RUN apt-get clean
 ENTRYPOINT ["/usr/bin/dumb-init", "--"]
+STOPSIGNAL SIGINT
+RUN pip install --no-cache-dir httpx[cli] "uvicorn<0.28.0"
 
-HEALTHCHECK --interval=30s --timeout=4s CMD curl -f http://localhost:8000/health || exit 1
+HEALTHCHECK --interval=30s --timeout=2s CMD httpx http://localhost:$PORT/health || exit 1
 
 RUN useradd -m worker
 USER worker
@@ -13,10 +18,13 @@ WORKDIR /home/worker/code
 ENV PYTHONPATH=/home/worker/code
 ENV PATH="/home/worker/.local/bin:${PATH}"
 
+COPY ./start_service_uvicorn.sh /home/worker/code/start.sh
+COPY ./scripts/prestart.sh /home/worker/code/prestart.sh
+
 COPY --chown=worker:worker requirements.txt ./requirements.txt
 RUN pip install --user --no-cache-dir --upgrade -r requirements.txt
 
-COPY --chown=worker:worker . .
+COPY --chown=worker:worker ./app /home/worker/code/app
 
-CMD ["./start_service.sh"]
+CMD ["./start.sh"]
diff --git a/Dockerfile-Gunicorn b/Dockerfile-Gunicorn
index 9c460ba123efe26687258bc06f8f76c5e36fa43d..7c25aa03d7e03c79d9366d5caae07c9f72c35eb2 100644
--- a/Dockerfile-Gunicorn
+++ b/Dockerfile-Gunicorn
@@ -1,15 +1,20 @@
-FROM tiangolo/uvicorn-gunicorn-fastapi:python3.11-slim
-EXPOSE 8000
+FROM python:3.12-slim
 ENV PORT=8000
+EXPOSE $PORT
+WORKDIR /app/
+ENV PYTHONPATH=/app
 
-RUN pip install --no-cache-dir httpx[cli]
+RUN pip install --no-cache-dir httpx[cli] "gunicorn<21.3.0" "uvicorn<0.28.0"
+COPY ./gunicorn_conf.py /app/gunicorn_conf.py
+COPY ./start_service_gunicorn.sh /app/start.sh
 
-HEALTHCHECK --interval=30s --timeout=4s CMD httpx http://localhost:$PORT/health || exit 1
+HEALTHCHECK --interval=30s --timeout=2s CMD httpx http://localhost:$PORT/health || exit 1
 
 COPY ./scripts/prestart.sh /app/prestart.sh
 COPY ./requirements.txt /app/requirements.txt
 RUN pip install --no-cache-dir --upgrade -r requirements.txt
 
-COPY ./mako_templates /app/mako_templates
 COPY ./app /app/app
+
+CMD ["./start.sh"]
diff --git a/README.md b/README.md
index b95748ea4c5b859426475bdfc9e1f67e7ad6358d..b6acee7bcf6acb5d00cea8bbc5a20424e0c127bf 100644
--- a/README.md
+++ b/README.md
@@ -51,3 +51,6 @@ prefix `NXF_` are also supported when submitting a job to the slurm cluster, e.g.
 ```
 NXF_VER=23.04.0
 ```
+## License
+
+The API is licensed under the [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) license. See the [License](LICENSE) file for more information.
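> Editor's note: both Dockerfiles above now drive the `HEALTHCHECK` through the `httpx` CLI instead of `curl`, probing `http://localhost:$PORT/health`. A minimal sketch of the kind of `/health` route this probe expects is shown below; the route shape and `"OK"` payload are assumptions for illustration, since the real endpoint lives in the application code, not in this diff.

```python
# Minimal sketch of a /health route compatible with the new HEALTHCHECK.
# Assumption: the actual ClowM API defines its own health route elsewhere.
from fastapi import FastAPI, status

app = FastAPI()


@app.get("/health", status_code=status.HTTP_200_OK)
async def health_check() -> dict:
    # The httpx CLI exits non-zero when the connection fails, so any
    # successful response from this route marks the container healthy.
    return {"status": "OK"}
```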
diff --git a/app/api/background_tasks/__init__.py b/app/api/background_tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02089a789627e41a0f313848ce2d9301c5f63a4d --- /dev/null +++ b/app/api/background_tasks/__init__.py @@ -0,0 +1,8 @@ +from .background import ( # noqa:F401 + delete_remote_icon, + delete_s3_obj, + download_file_to_bucket, + process_and_upload_icon, + upload_scm_file, +) +from .cluster_utils import cancel_slurm_job, start_workflow_execution # noqa:F401 diff --git a/app/api/background_tasks/background.py b/app/api/background_tasks/background.py new file mode 100644 index 0000000000000000000000000000000000000000..433c18456a68d1d44f5ca5ebaf915ce5973caf64 --- /dev/null +++ b/app/api/background_tasks/background.py @@ -0,0 +1,92 @@ +from io import BytesIO +from pathlib import Path +from uuid import UUID + +from opentelemetry import trace +from PIL import Image + +from app.api import dependencies +from app.core.config import settings +from app.crud import CRUDWorkflowVersion +from app.git_repository import GitRepository +from app.scm import SCM + +tracer = trace.get_tracer_provider().get_tracer(__name__) + + +async def process_and_upload_icon(icon_slug: str, icon_buffer_file: Path) -> None: + """ + Process the icon and upload it to the S3 Icon Bucket + + Parameters + ---------- + icon_slug : str + Slug of the icon + icon_buffer_file : pathlib.Path + Path to the file containing the icon + """ + try: + im = Image.open(icon_buffer_file) + im.thumbnail((64, 64)) # Crop to 64x64 image + thumbnail_buffer = BytesIO() + im.save(thumbnail_buffer, "PNG") # save in buffer as PNG image + thumbnail_buffer.seek(0) + with tracer.start_as_current_span("s3_upload_workflow_version_icon", attributes={"icon": icon_slug}): + # Upload to bucket + s3 = dependencies.get_s3_resource() + s3.Bucket(name=settings.ICON_BUCKET).Object(key=icon_slug).upload_fileobj( + Fileobj=thumbnail_buffer, ExtraArgs={"ContentType": "image/png"} + ) + finally: + icon_buffer_file.unlink() + + +def upload_scm_file(scm: SCM, scm_file_id: UUID) -> None: + """ + Upload the SCM file of a private workflow into the PARAMS_BUCKET with the workflow id as key + + Parameters + ---------- + scm : app.scm.SCM + Python object representing the SCM file + scm_file_id : str + ID of the scm file for the name of the file in the bucket + """ + with BytesIO() as handle: + scm.serialize(handle) + handle.seek(0) + with tracer.start_as_current_span("s3_upload_workflow_credentials"): + s3 = dependencies.get_s3_resource() + s3.Bucket(settings.PARAMS_BUCKET).Object(SCM.generate_filename(scm_file_id)).upload_fileobj(handle) + + +def delete_s3_obj(bucket_name: str, key: str) -> None: + with tracer.start_as_current_span("s3_delete_object", attributes={"bucket_name": bucket_name, "key": key}): + s3 = dependencies.get_s3_resource() + s3.Bucket(bucket_name).Object(key).delete() + + +async def delete_remote_icon(icon_slug: str) -> None: + """ + Delete icon in S3 Bucket if there are no other workflow versions that depend on it + + Parameters + ---------- + icon_slug : str + Name of the icon file. 
+ """ + # If there are no more Workflow versions that have this icon, delete it in the S3 ICON_BUCKET + async for db in dependencies.get_db(): + check = await CRUDWorkflowVersion.icon_exists(db, icon_slug) + if not check: + delete_s3_obj(bucket_name=settings.ICON_BUCKET, key=icon_slug) + + +async def download_file_to_bucket(repo: GitRepository, *, filepath: str, bucket_name: str, key: str) -> None: + s3 = dependencies.get_s3_resource() + async with dependencies.get_background_http_client() as client: + await repo.copy_file_to_bucket( + filepath=filepath, + obj=s3.Bucket(name=bucket_name).Object(key=key), + client=client, + ) diff --git a/app/api/background_tasks/cluster_utils.py b/app/api/background_tasks/cluster_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9b840ae0fd79f61d8f66dcd2ee20a2f20295b703 --- /dev/null +++ b/app/api/background_tasks/cluster_utils.py @@ -0,0 +1,195 @@ +import json +import re +import shlex +from asyncio import sleep as async_sleep +from contextlib import asynccontextmanager +from os import environ +from pathlib import Path +from tempfile import SpooledTemporaryFile +from typing import Any, AsyncIterator, Dict, Optional, Union +from uuid import UUID + +import botocore.client +import dotenv +from clowmdb.models import WorkflowExecution +from httpx import HTTPError +from mako.template import Template +from opentelemetry import trace + +from app.api import dependencies +from app.core.config import MonitorJobBackoffStrategy, settings +from app.crud import CRUDWorkflowExecution +from app.git_repository.abstract_repository import GitRepository +from app.scm import SCM +from app.slurm import SlurmClient, SlurmJobSubmission +from app.utils.backoff_strategy import BackoffStrategy, ExponentialBackoff, LinearBackoff, NoBackoff + +nextflow_command_template = Template(filename="app/mako_templates/nextflow_command.tmpl") +# regex to find S3 files in parameters of workflow execution +s3_file_regex = re.compile( + r"s3://(?!(((2(5[0-5]|[0-4]\d)|[01]?\d{1,2})\.){3}(2(5[0-5]|[0-4]\d)|[01]?\d{1,2})$))[a-z\d][a-z\d.-]{1,61}[a-z\d][^\"]*" +) + +tracer = trace.get_tracer_provider().get_tracer(__name__) + +dotenv.load_dotenv() +execution_env: Dict[str, Union[str, int, bool]] = {key: val for key, val in environ.items() if key.startswith("NXF_")} +execution_env["SCM_DIR"] = str(Path(settings.SLURM_WORKING_DIRECTORY) / "scm") + + +@asynccontextmanager +async def get_async_slurm_client() -> AsyncIterator[SlurmClient]: # pragma: no cover + async with dependencies.get_background_http_client() as client: + yield SlurmClient(client) + + +async def cancel_slurm_job(job_id: int) -> None: + async with get_async_slurm_client() as slurm_client: + await slurm_client.cancel_job(job_id=job_id) + + +async def start_workflow_execution( + execution: WorkflowExecution, + parameters: Dict[str, Any], + git_repo: GitRepository, + scm_file_id: Optional[UUID] = None, + workflow_entrypoint: Optional[str] = None, +) -> None: + """ + Start a workflow on the Slurm cluster. + + Parameters + ---------- + execution : clowmdb.models.WorkflowExecution + Workflow execution to execute. + parameters : Dict[str, Any] + Parameters for the workflow. + git_repo : app.git_repository.abstract_repository.GitRepository + Git repository of the workflow version. + scm_file_id : UUID | None + ID of the SCM file for private git repositories. 
+    workflow_entrypoint : str | None
+        Entrypoint for the workflow by specifying the `-entry` parameter
+    """
+    s3 = dependencies.get_s3_resource()
+    # Upload parameters to the params bucket
+    params_file_name = f"params-{execution.execution_id.hex}.json"
+    with SpooledTemporaryFile(max_size=512000) as f:
+        f.write(json.dumps(parameters).encode("utf-8"))
+        f.seek(0)
+        with tracer.start_as_current_span("s3_upload_workflow_execution_parameters") as span:
+            span.set_attribute("workflow_execution_id", str(execution.execution_id))
+            s3.Bucket(name=settings.PARAMS_BUCKET).Object(key=params_file_name).upload_fileobj(f)
+    for key in parameters.keys():
+        if isinstance(parameters[key], str):
+            # Escape string parameters for bash shell
+            parameters[key] = shlex.quote(parameters[key]).replace("$", "\\$")
+
+    # Check if there is an SCM file for the workflow
+    if scm_file_id is not None:
+        try:
+            with tracer.start_as_current_span("s3_check_workflow_scm_file"):
+                s3.Bucket(settings.PARAMS_BUCKET).Object(SCM.generate_filename(scm_file_id)).load()
+        except botocore.client.ClientError:
+            scm_file_id = None
+
+    nextflow_script = nextflow_command_template.render(
+        repo=git_repo,
+        parameters=parameters,
+        execution_id=execution.execution_id,
+        settings=settings,
+        scm_file_id=scm_file_id,
+        debug_s3_path=execution.debug_path,
+        logs_s3_path=execution.logs_path,
+        provenance_s3_path=execution.provenance_path,
+        workflow_entrypoint=workflow_entrypoint,
+    )
+
+    # Set up env for the workflow execution
+    work_directory = str(Path(settings.SLURM_WORKING_DIRECTORY) / f"run-{execution.execution_id.hex}")
+    env = execution_env.copy()
+    env["TOWER_WORKSPACE_ID"] = execution.execution_id.hex[:16]
+    env["NXF_WORK"] = str(Path(work_directory) / "work")
+    env["NXF_ANSI_LOG"] = False
+    env["NXF_ASSETS"] = str(
+        Path(env.get("NXF_ASSETS", "$HOME/.nextflow/assets")) / f"{git_repo.name}_{git_repo.commit}"  # type: ignore[arg-type]
+    )
+
+    try:
+        job_submission = SlurmJobSubmission(
+            script=nextflow_script.strip(),
+            job={
+                "current_working_directory": settings.SLURM_WORKING_DIRECTORY,
+                "environment": env,
+                "name": execution.execution_id.hex,
+                "requeue": False,
+                "standard_output": str(
+                    Path(settings.SLURM_WORKING_DIRECTORY) / f"slurm-{execution.execution_id.hex}.out"
+                ),
+            },
+        )
+        async with get_async_slurm_client() as slurm_client:
+            # Try to start the job on the slurm cluster
+            slurm_job_id = await slurm_client.submit_job(job_submission=job_submission)
+            async for db in dependencies.get_db():
+                await CRUDWorkflowExecution.update_slurm_job_id(
+                    db, slurm_job_id=slurm_job_id, execution_id=execution.execution_id
+                )
+            if not settings.SLURM_JOB_MONITORING == MonitorJobBackoffStrategy.NOMONITORING:  # pragma: no cover
+                await _monitor_proper_job_execution(
+                    slurm_client=slurm_client, execution_id=execution.execution_id, slurm_job_id=slurm_job_id
+                )
+    except (HTTPError, KeyError):
+        # Mark job as aborted when there is an error
+        async for db in dependencies.get_db():
+            await CRUDWorkflowExecution.set_error(
+                db, execution_id=execution.execution_id, status=WorkflowExecution.WorkflowExecutionStatus.ERROR
+            )
+
+
+async def _monitor_proper_job_execution(
+    slurm_client: SlurmClient, execution_id: UUID, slurm_job_id: int
+) -> None:  # pragma: no cover
+    """
+    Check in an interval based on a backoff strategy if the slurm job is still running
+    while the workflow execution in the database is not marked as finished.
+
+    Parameters
+    ----------
+    slurm_client : app.slurm.rest_client.SlurmClient
+        Slurm Rest Client to communicate with Slurm cluster.
+ execution_id : uuid.UUID + ID of the workflow execution + slurm_job_id : int + ID of the slurm job to monitor + """ + if settings.SLURM_JOB_MONITORING == MonitorJobBackoffStrategy.EXPONENTIAL: + # exponential to 50 minutes + sleep_generator: BackoffStrategy = ExponentialBackoff(initial_delay=30, max_value=300) # type: ignore[no-redef] + elif settings.SLURM_JOB_MONITORING == MonitorJobBackoffStrategy.LINEAR: + # 5 seconds increase to 5 minutes + sleep_generator: BackoffStrategy = LinearBackoff( # type: ignore[no-redef] + initial_delay=30, backoff=5, max_value=300 + ) + elif settings.SLURM_JOB_MONITORING == MonitorJobBackoffStrategy.CONSTANT: + # constant 30 seconds polling + sleep_generator: BackoffStrategy = NoBackoff(initial_delay=30, constant_value=30) # type: ignore[no-redef] + else: + return + for sleep_time in sleep_generator: + await async_sleep(sleep_time) + with tracer.start_span( + "monitor_job", attributes={"execution_id": str(execution_id), "slurm_job_id": slurm_job_id} + ) as span: + if await slurm_client.is_job_finished(slurm_job_id): + async for db in dependencies.get_db(): + execution = await CRUDWorkflowExecution.get(db, execution_id=execution_id) + # Check if the execution is marked as finished in the database + if execution is not None: + span.set_attribute("workflow_execution_status", str(execution.status)) + if execution.end_time is None: + # Mark job as finished with an error + await CRUDWorkflowExecution.set_error( + db, execution_id=execution_id, status=WorkflowExecution.WorkflowExecutionStatus.ERROR + ) + sleep_generator.close() diff --git a/app/api/dependencies.py b/app/api/dependencies.py index 203df0a58626f7c581983439b8ada1c0ba6fb3d2..731a421f71ab89ad86fdc9bbcda9b77421bc5905 100644 --- a/app/api/dependencies.py +++ b/app/api/dependencies.py @@ -1,3 +1,4 @@ +from contextlib import asynccontextmanager from typing import TYPE_CHECKING, Annotated, AsyncIterator, Awaitable, Callable, Dict from uuid import UUID @@ -57,6 +58,12 @@ async def get_db() -> AsyncIterator[AsyncSession]: # pragma: no cover DBSession = Annotated[AsyncSession, Depends(get_db)] +@asynccontextmanager +async def get_background_http_client() -> AsyncIterator[AsyncClient]: # pragma: no cover + async with AsyncClient() as client: + yield client + + async def get_httpx_client(request: Request) -> AsyncClient: # pragma: no cover # Fetch open http client from the app return request.app.requests_client @@ -65,10 +72,13 @@ async def get_httpx_client(request: Request) -> AsyncClient: # pragma: no cover HTTPClient = Annotated[AsyncClient, Depends(get_httpx_client)] -def get_slurm_client(client: AsyncClient = Depends(get_httpx_client)) -> SlurmClient: +def get_slurm_client(client: HTTPClient) -> SlurmClient: # pragma: no cover return SlurmClient(client=client) +HTTPSlurmClient = Annotated[SlurmClient, Depends(get_slurm_client)] + + def get_decode_jwt_function() -> Callable[[str], Dict[str, str]]: # pragma: no cover """ Get function to decode and verify the JWT. 
@@ -86,9 +96,9 @@ def get_decode_jwt_function() -> Callable[[str], Dict[str, str]]: # pragma: no @start_as_current_span_async("decode_jwt", tracer=tracer) async def decode_bearer_token( - token: HTTPAuthorizationCredentials = Depends(bearer_token), - decode: Callable[[str], Dict[str, str]] = Depends(get_decode_jwt_function), - db: AsyncSession = Depends(get_db), + token: Annotated[HTTPAuthorizationCredentials, Depends(bearer_token)], + decode: Annotated[Callable[[str], Dict[str, str]], Depends(get_decode_jwt_function)], + db: DBSession, ) -> JWT: """ Get the decoded JWT or reject request if it is not valid or the user doesn't exist. @@ -120,6 +130,37 @@ async def decode_bearer_token( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Malformed JWT") +async def get_current_user(token: JWT = Depends(decode_bearer_token), db: AsyncSession = Depends(get_db)) -> User: + """ + Get the current user from the database based on the JWT. + + FastAPI Dependency Injection Function. + + Parameters + ---------- + token : app.schemas.security.JWT + The verified and decoded JWT. + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. Dependency Injection. + + Returns + ------- + user : clowmdb.models.User + User associated with the JWT sent with the HTTP request. + """ + try: + uid = UUID(token.sub) + except ValueError: # pragma: no cover + raise DecodeError() + user = await CRUDUser.get(db, uid) + if user: + return user + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + + +CurrentUser = Annotated[User, Depends(get_current_user)] + + class AuthorizationDependency: """ Class to parameterize the authorization request with the resource to perform an operation on. @@ -136,16 +177,17 @@ class AuthorizationDependency: def __call__( self, - token: JWT = Depends(decode_bearer_token), - client: AsyncClient = Depends(get_httpx_client), + user: CurrentUser, + client: HTTPClient, ) -> Callable[[str], Awaitable[AuthzResponse]]: """ Get the function to request the authorization service with the resource, JWT and HTTP Client already injected. Parameters ---------- - token : app.schemas.security.JWT - The verified and decoded JWT. Dependency Injection. + user : clowmdb.models.User + The current user based on the JWT. Dependency Injection. + client : httpx.AsyncClient HTTP Client with an open connection. Dependency Injection. @@ -156,42 +198,15 @@ class AuthorizationDependency: """ async def authorization_wrapper(operation: str) -> AuthzResponse: - params = AuthzRequest(operation=operation, resource=self.resource, uid=token.sub) + params = AuthzRequest(operation=operation, resource=self.resource, uid=user.lifescience_id) return await request_authorization(request_params=params, client=client) return authorization_wrapper -async def get_current_user(token: JWT = Depends(decode_bearer_token), db: AsyncSession = Depends(get_db)) -> User: - """ - Get the current user from the database based on the JWT. - - FastAPI Dependency Injection Function. - - Parameters - ---------- - token : app.schemas.security.JWT - The verified and decoded JWT. - db : sqlalchemy.ext.asyncio.AsyncSession. - Async database session to perform query on. Dependency Injection. - - Returns - ------- - user : clowmdb.models.User - User associated with the JWT sent with the HTTP request. 
- """ - user = await CRUDUser.get(db, token.sub) - if user: - return user - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") - - -CurrentUser = Annotated[User, Depends(get_current_user)] - - async def get_current_workflow( - wid: UUID = Path(..., description="ID of a workflow", examples=["0cc78936-381b-4bdd-999d-736c40591078"]), - db: AsyncSession = Depends(get_db), + wid: Annotated[UUID, Path(description="ID of a workflow", examples=["0cc78936-381b-4bdd-999d-736c40591078"])], + db: DBSession, ) -> Workflow: """ Get the workflow from the database with the ID given in the path. @@ -221,8 +236,10 @@ CurrentWorkflow = Annotated[Workflow, Depends(get_current_workflow)] async def get_current_workflow_execution( - eid: UUID = Path(..., description="ID of a workflow execution.", examples=["0cc78936-381b-4bdd-999d-736c40591078"]), - db: AsyncSession = Depends(get_db), + eid: Annotated[ + UUID, Path(description="ID of a workflow execution.", examples=["0cc78936-381b-4bdd-999d-736c40591078"]) + ], + db: DBSession, ) -> WorkflowExecution: """ Get the workflow execution from the database with the ID given in the path. @@ -251,12 +268,14 @@ async def get_current_workflow_execution( async def get_current_workflow_version( workflow: CurrentWorkflow, db: DBSession, - git_commit_hash: str = Path( - ..., - description="Git commit git_commit_hash of specific version.", - pattern=r"^([0-9a-f]{40}|latest)$", - examples=["ba8bcd9294c2c96aedefa1763a84a18077c50c0f"], - ), + git_commit_hash: Annotated[ + str, + Path( + description="Git commit git_commit_hash of specific version.", + pattern=r"^([0-9a-f]{40}|latest)$", + examples=["ba8bcd9294c2c96aedefa1763a84a18077c50c0f"], + ), + ], ) -> WorkflowVersion: """ Get the workflow version from the database with the ID given in the path. 
diff --git a/app/api/endpoints/workflow.py b/app/api/endpoints/workflow.py index 00a322c7af01ddd4ce5f0ccc0227af674ff83231..a64c96f8ead482a2486ab3cf6ee044cc9af2034c 100644 --- a/app/api/endpoints/workflow.py +++ b/app/api/endpoints/workflow.py @@ -1,13 +1,15 @@ from datetime import date -from typing import Annotated, Any, Awaitable, Callable, List, Optional, Set +from typing import Annotated, Any, Awaitable, Callable, List, Set, Union from uuid import UUID from clowmdb.models import Workflow, WorkflowMode, WorkflowVersion from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, Response, status from opentelemetry import trace +from pydantic.json_schema import SkipJsonSchema -from app.api.dependencies import AuthorizationDependency, CurrentUser, CurrentWorkflow, DBSession, HTTPClient, S3Service -from app.api.utils import check_repo, upload_scm_file +from app.api.background_tasks import delete_s3_obj, download_file_to_bucket, upload_scm_file +from app.api.dependencies import AuthorizationDependency, CurrentUser, CurrentWorkflow, DBSession, HTTPClient +from app.api.utils import check_repo from app.core.config import settings from app.crud import CRUDWorkflow, CRUDWorkflowVersion from app.crud.crud_workflow_mode import CRUDWorkflowMode @@ -26,31 +28,37 @@ Authorization = Annotated[Callable[[str], Awaitable[Any]], Depends(workflow_auth tracer = trace.get_tracer_provider().get_tracer(__name__) -@router.get("", status_code=status.HTTP_200_OK, summary="List workflows") +@router.get("", status_code=status.HTTP_200_OK, summary="List workflows", response_model_exclude_none=True) @start_as_current_span_async("api_workflow_list", tracer=tracer) async def list_workflows( db: DBSession, authorization: Authorization, current_user: CurrentUser, - name_substring: Optional[str] = Query( - None, - min_length=3, - max_length=30, - description="Filter workflows by a substring in their name.", - ), - version_status: Optional[List[WorkflowVersion.Status]] = Query( - None, - description=f"Which versions of the workflow to include in the response. Permission 'workflow:list_filter required', unless 'developer_id' is provided and current user is developer, then only permission 'workflow:list' required. Default {WorkflowVersion.Status.PUBLISHED.name} and {WorkflowVersion.Status.DEPRECATED.name}.", # noqa: E501 - ), - developer_id: Optional[str] = Query( - None, - description="Filter for workflow by developer. If current user is the same as developer ID, permission 'workflow:list' required, otherwise 'workflow:list_filter'.", # noqa: E501 - examples=["28c5353b8bb34984a8bd4169ba94c606"], - ), + name_substring: Annotated[ + Union[str, SkipJsonSchema[None]], + Query( + min_length=3, + max_length=30, + description="Filter workflows by a substring in their name.", + ), + ] = None, + version_status: Annotated[ + Union[List[WorkflowVersion.Status], SkipJsonSchema[None]], + Query( + description=f"Which versions of the workflow to include in the response. Permission `workflow:list_filter` required, unless `developer_id` is provided and current user is developer, then only permission `workflow:list` required. Default `{WorkflowVersion.Status.PUBLISHED.name}` and `{WorkflowVersion.Status.DEPRECATED.name}`.", # noqa: E501 + ), + ] = None, + developer_id: Annotated[ + Union[UUID, SkipJsonSchema[None]], + Query( + description="Filter for workflow by developer. 
If current user is the developer, permission `workflow:list` required, otherwise `workflow:list_filter`.", # noqa: E501 + examples=["1d3387f3-95c0-4813-8767-2cad87faeebf"], + ), + ] = None, ) -> List[WorkflowOut]: """ List all workflows.\n - Permission "workflow:list" required. + Permission `workflow:list` required. \f Parameters ---------- @@ -74,7 +82,7 @@ async def list_workflows( """ current_span = trace.get_current_span() if developer_id is not None: # pragma: no cover - current_span.set_attribute("developer_id", developer_id) + current_span.set_attribute("developer_id", str(developer_id)) if name_substring is not None: # pragma: no cover current_span.set_attribute("name_substring", name_substring) if version_status is not None and len(version_status) > 0: # pragma: no cover @@ -105,12 +113,11 @@ async def create_workflow( current_user: CurrentUser, authorization: Authorization, client: HTTPClient, - s3: S3Service, workflow: WorkflowIn, ) -> WorkflowOut: """ Create a new workflow.\n - Permission "workflow:create" required.\n + Permission `workflow:create` required.\n For private Gitlab repositories, a Project Access Token with the role Reporter and scope `read_api` is needed.\n For private GitHub repositories, a Personal Access Token (classic) with scope `repo` is needed. \f @@ -124,8 +131,6 @@ async def create_workflow( Async database session to perform query on. Dependency Injection. current_user : clowmdb.models.User Current user. Dependency Injection. - s3 : boto3_type_annotations.s3.ServiceResource - S3 Service to perform operations on buckets in Ceph. Dependency Injection. authorization : Callable[[str], Awaitable[Any]] Async function to ask the auth service for authorization. Dependency Injection. client : httpx.AsyncClient @@ -164,53 +169,64 @@ async def create_workflow( # If it is a private repository, create an SCM file and upload it to the params bucket scm_provider = SCMProvider.from_repo(repo, name=SCMProvider.generate_name(workflow_db.workflow_id)) if repo.token is not None or not isinstance(repo, GitHubRepository): - background_tasks.add_task(upload_scm_file, s3=s3, scm=SCM([scm_provider]), scm_file_id=workflow_db.workflow_id) + background_tasks.add_task(upload_scm_file, scm=SCM([scm_provider]), scm_file_id=workflow_db.workflow_id) # If there are workflow modes with alternative parameter schemas, cache them in the WORKFLOW Bucket if len(workflow.modes) > 0: for mode_db in workflow_db.versions[0].workflow_modes: background_tasks.add_task( - repo.copy_file_to_bucket, + download_file_to_bucket, + repo=repo, filepath=mode_db.schema_path, - obj=s3.Bucket(name=settings.WORKFLOW_BUCKET).Object( - key=f"{workflow.git_commit_hash}-{mode_db.mode_id.hex}.json" - ), - client=client, + bucket_name=settings.WORKFLOW_BUCKET, + key=f"{workflow.git_commit_hash}-{mode_db.mode_id.hex}.json", ) else: # Cache the parameter schema in the WORKFLOW Bucket background_tasks.add_task( - repo.copy_file_to_bucket, + download_file_to_bucket, + repo=repo, filepath="nextflow_schema.json", - obj=s3.Bucket(name=settings.WORKFLOW_BUCKET).Object(key=f"{workflow.git_commit_hash}.json"), - client=client, + bucket_name=settings.WORKFLOW_BUCKET, + key=f"{workflow.git_commit_hash}.json", ) trace.get_current_span().set_attribute("workflow_id", str(workflow_db.workflow_id)) return WorkflowOut.from_db_workflow(await CRUDWorkflow.get(db, workflow_db.workflow_id)) -@router.get("/developer_statistics", status_code=status.HTTP_200_OK, summary="Get anonymized workflow execution") +@router.get( + 
"/developer_statistics", + status_code=status.HTTP_200_OK, + summary="Get anonymized workflow execution", + response_model_exclude_none=True, +) @start_as_current_span_async("api_workflow_get_developer_statistics", tracer=tracer) async def get_developer_workflow_statistics( db: DBSession, authorization: Authorization, response: Response, current_user: CurrentUser, - developer_id: Optional[str] = Query( - None, - description="Filter by the developer of the workflows", - examples=["28c5353b8bb34984a8bd4169ba94c606"], - min_length=3, - max_length=64, - ), - workflow_ids: Optional[List[UUID]] = Query(None, description="Filter by workflow IDs", alias="workflow_id"), - start: Optional[date] = Query(None, description="Filter by workflow executions after this date"), - end: Optional[date] = Query(None, description="Filter by workflow executions before this date"), + developer_id: Annotated[ + Union[UUID, SkipJsonSchema[None]], + Query( + description="Filter by the developer of the workflows", + examples=["1d3387f3-95c0-4813-8767-2cad87faeebf"], + ), + ] = None, + workflow_ids: Annotated[ + Union[List[UUID], SkipJsonSchema[None]], Query(description="Filter by workflow IDs", alias="workflow_id") + ] = None, + start: Annotated[ + Union[date, SkipJsonSchema[None]], Query(description="Filter by workflow executions after this date") + ] = None, + end: Annotated[ + Union[date, SkipJsonSchema[None]], Query(description="Filter by workflow executions before this date") + ] = None, ) -> List[AnonymizedWorkflowExecution]: """ Get the workflow executions with meta information and anonymized user IDs.\n - Permission "workflow:read_statistics" required if the `developer_id` is the same as the uid of the current user, - other "workflow:read_statistics_any". + Permission `workflow:read_statistics` required if the `developer_id` is the same as the uid of the current user, + other `workflow:read_statistics_any`. \f Parameters ---------- @@ -238,7 +254,7 @@ async def get_developer_workflow_statistics( """ span = trace.get_current_span() if developer_id: # pragma: no cover - span.set_attribute("developer_id", developer_id) + span.set_attribute("developer_id", str(developer_id)) if workflow_ids: # pragma: no cover span.set_attribute("workflow_ids", [str(wid) for wid in workflow_ids]) if start: # pragma: no cover @@ -253,21 +269,23 @@ async def get_developer_workflow_statistics( ) -@router.get("/{wid}", status_code=status.HTTP_200_OK, summary="Get a workflow") +@router.get("/{wid}", status_code=status.HTTP_200_OK, summary="Get a workflow", response_model_exclude_none=True) @start_as_current_span_async("api_workflow_get", tracer=tracer) async def get_workflow( workflow: CurrentWorkflow, db: DBSession, current_user: CurrentUser, authorization: Authorization, - version_status: Optional[List[WorkflowVersion.Status]] = Query( - None, - description=f"Which versions of the workflow to include in the response. Permission 'workflow:read_any' required if you are not the developer of this workflow. Default {WorkflowVersion.Status.PUBLISHED.name} and {WorkflowVersion.Status.DEPRECATED.name}", # noqa: E501 - ), + version_status: Annotated[ + Union[List[WorkflowVersion.Status], SkipJsonSchema[None]], + Query( + description=f"Which versions of the workflow to include in the response. Permission `workflow:read_any` required if you are not the developer of this workflow. 
Default `{WorkflowVersion.Status.PUBLISHED.name}` and `{WorkflowVersion.Status.DEPRECATED.name}`", # noqa: E501 + ), + ] = None, ) -> WorkflowOut: """ Get a specific workflow.\n - Permission "workflow:read" required. + Permission `workflow:read` required. \f Parameters ---------- @@ -309,7 +327,7 @@ async def get_workflow_statistics( ) -> List[WorkflowStatistic]: """ Get the number of started workflow per day.\n - Permission "workflow:read" required. + Permission `workflow:read` required. \f Parameters ---------- @@ -341,12 +359,11 @@ async def delete_workflow( workflow: CurrentWorkflow, db: DBSession, authorization: Authorization, - s3: S3Service, current_user: CurrentUser, ) -> None: """ Delete a workflow.\n - Permission "workflow:delete" required. + Permission `workflow:delete` required. \f Parameters ---------- @@ -358,8 +375,6 @@ async def delete_workflow( Async database session to perform query on. Dependency Injection. authorization : Callable[[str], Awaitable[Any]] Async function to ask the auth service for authorization. Dependency Injection. - s3 : boto3_type_annotations.s3.ServiceResource - S3 Service to perform operations on buckets in Ceph. Dependency Injection. current_user : clowmdb.models.User Current user. Dependency Injection. """ @@ -369,7 +384,7 @@ async def delete_workflow( versions = await CRUDWorkflowVersion.list(db, workflow.workflow_id) # Delete SCM file for private repositories background_tasks.add_task( - s3.Bucket(name=settings.PARAMS_BUCKET).Object(key=SCM.generate_filename(workflow.workflow_id)).delete + delete_s3_obj, bucket_name=settings.PARAMS_BUCKET, key=SCM.generate_filename(workflow.workflow_id) ) # Delete files in buckets mode_ids: Set[UUID] = set() @@ -379,24 +394,26 @@ async def delete_workflow( for mode in version.workflow_modes: mode_ids.add(mode.mode_id) background_tasks.add_task( - s3.Bucket(name=settings.WORKFLOW_BUCKET) - .Object(key=f"{version.git_commit_hash}-{mode.mode_id.hex}.json") - .delete + delete_s3_obj, + bucket_name=settings.WORKFLOW_BUCKET, + key=f"{version.git_commit_hash}-{mode.mode_id.hex}.json", ) else: # Delete standard parameter schema of workflow background_tasks.add_task( - s3.Bucket(name=settings.WORKFLOW_BUCKET).Object(key=f"{version.git_commit_hash}.json").delete + delete_s3_obj, bucket_name=settings.WORKFLOW_BUCKET, key=f"{version.git_commit_hash}.json" ) # Delete icons of workflow version if version.icon_slug is not None: - background_tasks.add_task(s3.Bucket(name=settings.ICON_BUCKET).Object(key=version.icon_slug).delete) + background_tasks.add_task(delete_s3_obj, bucket_name=settings.ICON_BUCKET, key=version.icon_slug) await CRUDWorkflow.delete(db, workflow.workflow_id) if len(mode_ids) > 0: await CRUDWorkflowMode.delete(db, mode_ids) -@router.post("/{wid}/update", status_code=status.HTTP_201_CREATED, summary="Update a workflow") +@router.post( + "/{wid}/update", status_code=status.HTTP_201_CREATED, summary="Update a workflow", response_model_exclude_none=True +) @start_as_current_span_async("api_workflow_update", tracer=tracer) async def update_workflow( background_tasks: BackgroundTasks, @@ -404,13 +421,12 @@ async def update_workflow( client: HTTPClient, db: DBSession, current_user: CurrentUser, - s3: S3Service, authorization: Authorization, version_update: WorkflowUpdate, ) -> WorkflowVersionSchema: """ Create a new workflow version.\n - Permission "workflow:update" required. + Permission `workflow:update` required. 
\f Parameters ---------- @@ -422,8 +438,6 @@ async def update_workflow( Async database session to perform query on. Dependency Injection. current_user : clowmdb.models.User Current user. Dependency Injection. - s3 : boto3_type_annotations.s3.ServiceResource - S3 Service to perform operations on buckets in Ceph. Dependency Injection. authorization : Callable[[str], Awaitable[Any]] Async function to ask the auth service for authorization. Dependency Injection. client : httpx.AsyncClient @@ -489,19 +503,19 @@ async def update_workflow( if len(db_modes) > 0: for mode in db_modes: background_tasks.add_task( - repo.copy_file_to_bucket, + download_file_to_bucket, + repo=repo, filepath=mode.schema_path, - obj=s3.Bucket(name=settings.WORKFLOW_BUCKET).Object( - key=f"{version_update.git_commit_hash}-{mode.mode_id.hex}.json" - ), - client=client, + bucket_name=settings.WORKFLOW_BUCKET, + key=f"{version_update.git_commit_hash}-{mode.mode_id.hex}.json", ) else: background_tasks.add_task( - repo.copy_file_to_bucket, + download_file_to_bucket, filepath="nextflow_schema.json", - obj=s3.Bucket(name=settings.WORKFLOW_BUCKET).Object(key=f"{version_update.git_commit_hash}.json"), - client=client, + repo=repo, + bucket_name=settings.WORKFLOW_BUCKET, + key=f"{version_update.git_commit_hash}.json", ) # Create list with mode ids that are connected to the new workflow version @@ -510,7 +524,7 @@ async def update_workflow( db, git_commit_hash=version_update.git_commit_hash, version=version_update.version, - wid=workflow.workflow_id, + workflow_id=workflow.workflow_id, icon_slug=previous_version.icon_slug if previous_version else None, previous_version=previous_version.git_commit_hash if previous_version else None, modes=mode_ids, diff --git a/app/api/endpoints/workflow_credentials.py b/app/api/endpoints/workflow_credentials.py index 206c2f65f79917df9625c1ea4bacbbc01bc53a2d..6cfc4026cd481d79035b0b90826f030dbbc9bbc0 100644 --- a/app/api/endpoints/workflow_credentials.py +++ b/app/api/endpoints/workflow_credentials.py @@ -3,8 +3,9 @@ from typing import Annotated, Any, Awaitable, Callable from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status from opentelemetry import trace +from app.api.background_tasks.background import upload_scm_file from app.api.dependencies import AuthorizationDependency, CurrentUser, CurrentWorkflow, DBSession, HTTPClient, S3Service -from app.api.utils import check_repo, upload_scm_file +from app.api.utils import check_repo from app.core.config import settings from app.crud.crud_workflow import CRUDWorkflow from app.crud.crud_workflow_version import CRUDWorkflowVersion @@ -28,7 +29,7 @@ async def get_workflow_credentials( ) -> WorkflowCredentialsOut: """ Get the credentials for the repository of a workflow. Only the developer of a workflow can do this.\n - Permission "workflow:update" required. + Permission `workflow:update` required. \f Parameters ---------- @@ -61,13 +62,12 @@ async def update_workflow_credentials( db: DBSession, current_user: CurrentUser, authorization: Authorization, - s3: S3Service, background_tasks: BackgroundTasks, client: HTTPClient, ) -> None: """ Update the credentials for the repository of a workflow.\n - Permission "workflow:update" required. + Permission `workflow:update` required. \f Parameters ---------- @@ -81,8 +81,6 @@ async def update_workflow_credentials( Current user. Dependency Injection. authorization : Callable[[str], Awaitable[Any]] Async function to ask the auth service for authorization. Dependency Injection. 
- s3 : boto3_type_annotations.s3.ServiceResource - S3 Service to perform operations on buckets in Ceph. Dependency Injection. client : httpx.AsyncClient Http client with an open connection. Dependency Injection. background_tasks : fastapi.BackgroundTasks @@ -109,7 +107,6 @@ async def update_workflow_credentials( scm_provider = SCMProvider.from_repo(repo=repo, name=SCMProvider.generate_name(workflow.workflow_id)) background_tasks.add_task( upload_scm_file, - s3=s3, scm=SCM(providers=[scm_provider]), scm_file_id=workflow.workflow_id, ) @@ -128,7 +125,7 @@ async def delete_workflow_credentials( ) -> None: """ Delete the credentials for the repository of a workflow.\n - Permission "workflow:delete" required. + Permission `workflow:delete` required. \f Parameters ---------- @@ -161,7 +158,6 @@ async def delete_workflow_credentials( scm_provider = SCMProvider.from_repo(repo=repo, name=SCMProvider.generate_name(workflow.workflow_id)) background_tasks.add_task( upload_scm_file, - s3=s3, scm=SCM(providers=[scm_provider]), scm_file_id=workflow.workflow_id, ) diff --git a/app/api/endpoints/workflow_execution.py b/app/api/endpoints/workflow_execution.py index 8eedd237e7e7abad939a4bc6e7f8a2d1b8fcdca9..4893cb6ba2abd1eafbb53ae1427fa782205367c1 100644 --- a/app/api/endpoints/workflow_execution.py +++ b/app/api/endpoints/workflow_execution.py @@ -1,12 +1,15 @@ import json from tempfile import SpooledTemporaryFile -from typing import Annotated, Any, Awaitable, Callable, Dict, List, Optional +from typing import Annotated, Any, Awaitable, Callable, Dict, List, Union +from uuid import UUID import jsonschema from clowmdb.models import WorkflowExecution, WorkflowVersion from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, status from opentelemetry import trace +from pydantic.json_schema import SkipJsonSchema +from app.api.background_tasks import cancel_slurm_job, delete_s3_obj, start_workflow_execution, upload_scm_file from app.api.dependencies import ( AuthorizationDependency, CurrentUser, @@ -14,21 +17,13 @@ from app.api.dependencies import ( HTTPClient, S3Service, get_current_workflow_execution, - get_slurm_client, -) -from app.api.utils import ( - check_active_workflow_execution_limit, - check_buckets_access, - check_repo, - start_workflow_execution, - upload_scm_file, ) +from app.api.utils import check_active_workflow_execution_limit, check_buckets_access, check_repo from app.core.config import settings from app.crud import CRUDWorkflowExecution, CRUDWorkflowVersion from app.git_repository import GitHubRepository, build_repository from app.schemas.workflow_execution import DevWorkflowExecutionIn, WorkflowExecutionIn, WorkflowExecutionOut from app.scm import SCM, SCMProvider -from app.slurm.rest_client import SlurmClient from app.utils.otlp import start_as_current_span_async router = APIRouter(prefix="/workflow_executions", tags=["Workflow Execution"]) @@ -40,7 +35,9 @@ CurrentWorkflowExecution = Annotated[WorkflowExecution, Depends(get_current_work tracer = trace.get_tracer_provider().get_tracer(__name__) -@router.post("", status_code=status.HTTP_201_CREATED, summary="Start a new workflow execution") +@router.post( + "", status_code=status.HTTP_201_CREATED, summary="Start a new workflow execution", response_model_exclude_none=True +) @start_as_current_span_async("api_workflow_execution_start", tracer=tracer) async def start_workflow( background_tasks: BackgroundTasks, @@ -49,12 +46,11 @@ async def start_workflow( current_user: CurrentUser, authorization: Authorization, s3: 
S3Service,
-    slurm_client: SlurmClient = Depends(get_slurm_client),
 ) -> WorkflowExecutionOut:
     """
     Start a new workflow execution. Workflow versions with status `DEPRECATED` or `DENIED` can't be started.\n
-    Permission "workflow_execution:start" required if workflow versions status is `PUBLISHED`,
-    otherwise "workflow_execution:start_unpublished" required.
+    Permission `workflow_execution:start` required if the workflow version's status is `PUBLISHED`,
+    otherwise `workflow_execution:start_unpublished` required.
     \f
     Parameters
     ----------
@@ -70,8 +66,6 @@
         Async function to ask the auth service for authorization. Dependency Injection.
     s3 : boto3_type_annotations.s3.ServiceResource
         S3 Service to perform operations on buckets in Ceph. Dependency Injection.
-    slurm_client : app.slurm.rest_client.SlurmClient
-        Slurm Rest Client to communicate with Slurm cluster. Dependency Injection.
 
     Returns
     -------
@@ -94,23 +88,23 @@
     rbac_operation = "start" if workflow_version.status == WorkflowVersion.Status.PUBLISHED else "start_unpublished"
     await authorization(rbac_operation)
 
-    if len(workflow_version.workflow_modes) > 0 and workflow_execution_in.mode is None:
+    if len(workflow_version.workflow_modes) > 0 and workflow_execution_in.mode_id is None:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
-            detail=f"A workflow mode needs to be specified for thee workflow version '{workflow_execution_in.workflow_version_id}'",  # noqa: E501
+            detail=f"A workflow mode needs to be specified for the workflow version '{workflow_execution_in.workflow_version_id}'",  # noqa: E501
         )
 
     # If a workflow mode is specified, check that the mode is associated with the workflow version
     workflow_mode = None
-    if workflow_execution_in.mode is not None:
-        current_span.set_attribute("workflow_mode_id", str(workflow_execution_in.mode))
+    if workflow_execution_in.mode_id is not None:
+        current_span.set_attribute("workflow_mode_id", str(workflow_execution_in.mode_id))
         workflow_mode = next(
-            (mode for mode in workflow_version.workflow_modes if mode.mode_id == workflow_execution_in.mode), None
+            (mode for mode in workflow_version.workflow_modes if mode.mode_id == workflow_execution_in.mode_id), None
         )
         if workflow_mode is None:
             raise HTTPException(
                 status_code=status.HTTP_404_NOT_FOUND,
-                detail=f"Workflow mode '{workflow_execution_in.mode}' does not exist on version '{workflow_execution_in.workflow_version_id}'",  # noqa: E501
+                detail=f"Workflow mode '{workflow_execution_in.mode_id}' does not exist on version '{workflow_execution_in.workflow_version_id}'",  # noqa: E501
             )
 
     # Check that the version can be used for execution
@@ -125,8 +119,8 @@
     # Validate schema with saved schema in bucket
     schema_name = (
         f"{workflow_execution_in.workflow_version_id}.json"
-        if workflow_execution_in.mode is None
-        else f"{workflow_execution_in.workflow_version_id}-{workflow_execution_in.mode.hex}.json"
+        if workflow_execution_in.mode_id is None
+        else f"{workflow_execution_in.workflow_version_id}-{workflow_execution_in.mode_id.hex}.json"
     )
     with SpooledTemporaryFile(max_size=512000) as f:
         with tracer.start_as_current_span("s3_download_workflow_parameter_schema"):
@@ -153,18 +147,15 @@
     )
 
     # Create execution in database
-    execution = await CRUDWorkflowExecution.create(db, execution=workflow_execution_in, owner_id=current_user.uid)
+    execution = await CRUDWorkflowExecution.create(db, execution=workflow_execution_in, executor_id=current_user.uid)
     # Start workflow execution
     background_tasks.add_task(
         start_workflow_execution,
-        s3=s3,
-        db=db,
         execution=execution,
         parameters=workflow_execution_in.parameters,
         git_repo=build_repository(
             url=workflow_version.workflow.repository_url, git_commit_hash=workflow_version.git_commit_hash
         ),
-        slurm_client=slurm_client,
         scm_file_id=workflow_version.workflow_id,
         workflow_entrypoint=workflow_mode.entrypoint if workflow_mode is not None else None,
     )
@@ -178,6 +169,7 @@
     status_code=status.HTTP_201_CREATED,
     summary="Start a workflow execution with arbitrary git repository",
     include_in_schema=settings.DEV_SYSTEM,
+    response_model_exclude_none=True,
 )
 @start_as_current_span_async("api_workflow_execution_start_arbitrary", tracer=tracer)
 async def start_arbitrary_workflow(
@@ -186,13 +178,11 @@
     db: DBSession,
     current_user: CurrentUser,
     client: HTTPClient,
-    s3: S3Service,
-    authorization: Callable[[str], Awaitable[Any]] = Depends(AuthorizationDependency(resource="workflow")),
-    slurm_client: SlurmClient = Depends(get_slurm_client),
+    authorization: Annotated[Callable[[str], Awaitable[Any]], Depends(AuthorizationDependency(resource="workflow"))],
 ) -> WorkflowExecutionOut:
     """
     Start a new workflow execution from an arbitrary git repository.\n
-    Permission "workflow:create" required.\n
+    Permission `workflow:create` required.\n
     For private Gitlab repositories, a Project Access Token with the role Reporter and scope `read_api` is needed.\n
     For private GitHub repositories, a Personal Access Token (classic) with scope `repo` is needed.
     \f
@@ -208,10 +198,6 @@
         Current user who will be the owner of the newly created bucket. Dependency Injection.
     authorization : Callable[[str], Awaitable[Any]]
         Async function to ask the auth service for authorization. Dependency Injection.
-    s3 : boto3_type_annotations.s3.ServiceResource
-        S3 Service to perform operations on buckets in Ceph. Dependency Injection.
-    slurm_client : app.slurm.rest_client.SlurmClient
-        Slurm Rest Client to communicate with Slurm cluster. Dependency Injection.
     client : httpx.AsyncClient
         Http client with an open connection. Dependency Injection.
@@ -285,22 +271,19 @@ async def start_arbitrary_workflow( execution = await CRUDWorkflowExecution.create( db, execution=workflow_execution_in, - owner_id=current_user.uid, + executor_id=current_user.uid, notes=execution_note, ) scm_provider = SCMProvider.from_repo(repo=repo, name=SCMProvider.generate_name(execution.execution_id)) if repo.token is not None or not isinstance(repo, GitHubRepository): - background_tasks.add_task(upload_scm_file, s3=s3, scm=SCM([scm_provider]), scm_file_id=execution.execution_id) + background_tasks.add_task(upload_scm_file, scm=SCM([scm_provider]), scm_file_id=execution.execution_id) background_tasks.add_task( start_workflow_execution, - s3=s3, - db=db, execution=execution, parameters=workflow_execution_in.parameters, git_repo=repo, - slurm_client=slurm_client, scm_file_id=execution.execution_id, workflow_entrypoint=workflow_execution_in.mode.entrypoint if workflow_execution_in.mode is not None else None, ) @@ -308,35 +291,40 @@ async def start_arbitrary_workflow( return WorkflowExecutionOut.from_db_model(execution) -@router.get("", status_code=status.HTTP_200_OK, summary="Get all workflow executions") +@router.get("", status_code=status.HTTP_200_OK, summary="Get all workflow executions", response_model_exclude_none=True) @start_as_current_span_async("api_workflow_execution_list", tracer=tracer) async def list_workflow_executions( db: DBSession, current_user: CurrentUser, authorization: Authorization, - user_id: Optional[str] = Query( - None, - description="Filter for workflow executions by a user. If none, Permission 'workflow_execution:read_any' required.", # noqa: E501 - examples=["28c5353b8bb34984a8bd4169ba94c606"], - ), - execution_status: Optional[List[WorkflowExecution.WorkflowExecutionStatus]] = Query( - None, description="Filter for status of workflow execution" - ), - workflow_version_id: Optional[str] = Query( - None, - description="Filter for workflow version", - examples=["ba8bcd9294c2c96aedefa1763a84a18077c50c0f"], - pattern=r"^[0-9a-f]{40}$", - ), + executor_id: Annotated[ + Union[UUID, SkipJsonSchema[None]], + Query( + description="Filter for workflow executions by a user. If none, Permission `workflow_execution:read_any` required.", # noqa: E501 + examples=["1d3387f3-95c0-4813-8767-2cad87faeebf"], + ), + ] = None, + execution_status: Annotated[ + Union[List[WorkflowExecution.WorkflowExecutionStatus], SkipJsonSchema[None]], + Query(description="Filter for status of workflow execution"), + ] = None, + workflow_version_id: Annotated[ + Union[str, SkipJsonSchema[None]], + Query( + description="Filter for workflow version", + examples=["ba8bcd9294c2c96aedefa1763a84a18077c50c0f"], + pattern=r"^[0-9a-f]{40}$", + ), + ] = None, ) -> List[WorkflowExecutionOut]: """ Get all workflow executions.\n - Permission "workflow_execution:list" required, if 'user_id' is the same as the current user, - otherwise "workflow_execution:list_all" required. + Permission `workflow_execution:list` required, if 'user_id' is the same as the current user, + otherwise `workflow_execution:list_all` required. \f Parameters ---------- - user_id : str | None, default None + executor_id : str | None, default None Filter for workflow executions by a user. Query Parameter. execution_status : List[clowmdb.models.WorkflowExecution.WorkflowExecutionStatus] | None, default None Filter for status of workflow execution. Query Parameter. @@ -355,17 +343,17 @@ async def list_workflow_executions( List of filtered workflow executions. 
""" current_span = trace.get_current_span() - if user_id is not None: # pragma: no cover - current_span.set_attribute("user_id", user_id) + if executor_id is not None: # pragma: no cover + current_span.set_attribute("user_id", str(executor_id)) if execution_status is not None and len(execution_status) > 0: # pragma: no cover current_span.set_attribute("execution_status", [stat.name for stat in execution_status]) if workflow_version_id is not None: # pragma: no cover current_span.set_attribute("git_commit_hash", workflow_version_id) - rbac_operation = "list" if user_id is not None and user_id == current_user.uid else "list_all" + rbac_operation = "list" if executor_id is not None and executor_id == current_user.uid else "list_all" await authorization(rbac_operation) executions = await CRUDWorkflowExecution.list( - db, uid=user_id, workflow_version_id=workflow_version_id, status_list=execution_status + db, executor_id=executor_id, workflow_version_id=workflow_version_id, status_list=execution_status ) return [ WorkflowExecutionOut.from_db_model( @@ -375,7 +363,9 @@ async def list_workflow_executions( ] -@router.get("/{eid}", status_code=status.HTTP_200_OK, summary="Get a workflow execution") +@router.get( + "/{eid}", status_code=status.HTTP_200_OK, summary="Get a workflow execution", response_model_exclude_none=True +) @start_as_current_span_async("api_workflow_execution_get", tracer=tracer) async def get_workflow_execution( workflow_execution: CurrentWorkflowExecution, @@ -384,8 +374,8 @@ async def get_workflow_execution( ) -> WorkflowExecutionOut: """ Get a specific workflow execution.\n - Permission "workflow_execution:read" required if the current user started the workflow execution, - otherwise "workflow_execution:read_any" required. + Permission `workflow_execution:read` required if the current user started the workflow execution, + otherwise `workflow_execution:read_any` required. \f Parameters ---------- @@ -402,7 +392,7 @@ async def get_workflow_execution( Workflow execution with the given ID. """ trace.get_current_span().set_attribute("execution_id", str(workflow_execution.execution_id)) - rbac_operation = "read" if workflow_execution.user_id == current_user.uid else "read_any" + rbac_operation = "read" if workflow_execution.executor_id == current_user.uid else "read_any" await authorization(rbac_operation) return WorkflowExecutionOut.from_db_model( workflow_execution, @@ -420,8 +410,8 @@ async def get_workflow_execution_params( ) -> Dict[str, Any]: """ Get the parameters of a specific workflow execution.\n - Permission "workflow_execution:read" required if the current user started the workflow execution, - otherwise "workflow_execution:read_any" required. + Permission `workflow_execution:read` required if the current user started the workflow execution, + otherwise `workflow_execution:read_any` required. \f Parameters ---------- @@ -440,7 +430,7 @@ async def get_workflow_execution_params( Workflow execution with the given id. 
""" trace.get_current_span().set_attribute("execution_id", str(workflow_execution.execution_id)) - rbac_operation = "read" if workflow_execution.user_id == current_user.uid else "read_any" + rbac_operation = "read" if workflow_execution.executor_id == current_user.uid else "read_any" await authorization(rbac_operation) params_file_name = f"params-{workflow_execution.execution_id.hex}.json" with SpooledTemporaryFile(max_size=512000) as f: @@ -456,13 +446,12 @@ async def delete_workflow_execution( db: DBSession, current_user: CurrentUser, authorization: Authorization, - s3: S3Service, workflow_execution: CurrentWorkflowExecution, ) -> None: """ Delete a specific workflow execution.\n - Permission "workflow_execution:delete" required if the current user started the workflow execution, - otherwise "workflow_execution:delete_any" required. + Permission `workflow_execution:delete` required if the current user started the workflow execution, + otherwise `workflow_execution:delete_any` required. \f Parameters ---------- @@ -476,11 +465,9 @@ async def delete_workflow_execution( Current user who will be the owner of the newly created bucket. Dependency Injection. authorization : Callable[[str], Awaitable[Any]] Async function to ask the auth service for authorization. Dependency Injection. - s3 : boto3_type_annotations.s3.ServiceResource - S3 Service to perform operations on buckets in Ceph. Dependency Injection. """ trace.get_current_span().set_attribute("execution_id", str(workflow_execution.execution_id)) - rbac_operation = "delete" if workflow_execution.user_id == current_user.uid else "delete_any" + rbac_operation = "delete" if workflow_execution.executor_id == current_user.uid else "delete_any" await authorization(rbac_operation) if workflow_execution.status in [ WorkflowExecution.WorkflowExecutionStatus.PENDING, @@ -491,7 +478,7 @@ async def delete_workflow_execution( status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot delete workflow execution that is not finished." ) background_tasks.add_task( - s3.Bucket(name=settings.PARAMS_BUCKET).Object(key=f"params-{workflow_execution.execution_id.hex}.json").delete + delete_s3_obj, bucket_name=settings.PARAMS_BUCKET, key=f"params-{workflow_execution.execution_id.hex}.json" ) await CRUDWorkflowExecution.delete(db, workflow_execution.execution_id) @@ -504,12 +491,11 @@ async def cancel_workflow_execution( current_user: CurrentUser, authorization: Authorization, workflow_execution: CurrentWorkflowExecution, - slurm_client: SlurmClient = Depends(get_slurm_client), ) -> None: """ Cancel a running workflow execution.\n - Permission "workflow_execution:cancel" required if the current user started the workflow execution, - otherwise "workflow_execution:cancel_any" required. + Permission `workflow_execution:cancel` required if the current user started the workflow execution, + otherwise `workflow_execution:cancel_any` required. \f Parameters ---------- @@ -523,11 +509,9 @@ async def cancel_workflow_execution( Current user who will be the owner of the newly created bucket. Dependency Injection. authorization : Callable[[str], Awaitable[Any]] Async function to ask the auth service for authorization. Dependency Injection. - slurm_client : app.slurm.rest_client.SlurmClient - Slurm Rest Client to communicate with Slurm cluster. Dependency Injection. 
""" trace.get_current_span().set_attribute("execution_id", str(workflow_execution.execution_id)) - rbac_operation = "cancel" if workflow_execution.user_id == current_user.uid else "cancel_any" + rbac_operation = "cancel" if workflow_execution.executor_id == current_user.uid else "cancel_any" await authorization(rbac_operation) if workflow_execution.status not in [ WorkflowExecution.WorkflowExecutionStatus.PENDING, @@ -538,5 +522,5 @@ async def cancel_workflow_execution( status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot cancel workflow execution that is finished." ) if workflow_execution.slurm_job_id >= 0: - background_tasks.add_task(slurm_client.cancel_job, job_id=workflow_execution.slurm_job_id) - await CRUDWorkflowExecution.cancel(db, workflow_execution.execution_id) + background_tasks.add_task(cancel_slurm_job, job_id=workflow_execution.slurm_job_id) + await CRUDWorkflowExecution.set_error(db, workflow_execution.execution_id) diff --git a/app/api/endpoints/workflow_mode.py b/app/api/endpoints/workflow_mode.py index 8f9aff8a678cec02277e4abce6505f29617afc12..7911554530c14bf7a1d1b136fe3bfa70ecedec0f 100644 --- a/app/api/endpoints/workflow_mode.py +++ b/app/api/endpoints/workflow_mode.py @@ -22,14 +22,16 @@ tracer = trace.get_tracer_provider().get_tracer(__name__) async def get_workflow_mode( db: DBSession, authorization: Authorization, - mode_id: UUID = Path( - ..., - description="ID of a workflow mode", - ), + mode_id: Annotated[ + UUID, + Path( + description="ID of a workflow mode", + ), + ], ) -> WorkflowModeOut: """ Get a workflow mode\n - Permission 'workflow:read' required + Permission `workflow:read` required \f Parameters ---------- diff --git a/app/api/endpoints/workflow_version.py b/app/api/endpoints/workflow_version.py index ff6bd4ef8f11dbd7adb9933c5e17d346180b8697..fad9b2df5b6a69f6ef0cc7cac9cea0f0ab941590 100644 --- a/app/api/endpoints/workflow_version.py +++ b/app/api/endpoints/workflow_version.py @@ -1,12 +1,14 @@ from enum import Enum, unique -from typing import Annotated, Any, Awaitable, Callable, List, Optional +from typing import Annotated, Any, Awaitable, Callable, List, Union from uuid import UUID from clowmdb.models import WorkflowVersion from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, Path, Query, UploadFile, status from fastapi.responses import StreamingResponse from opentelemetry import trace +from pydantic.json_schema import SkipJsonSchema +from app.api.background_tasks import delete_remote_icon from app.api.dependencies import ( AuthorizationDependency, CurrentUser, @@ -14,9 +16,8 @@ from app.api.dependencies import ( CurrentWorkflowVersion, DBSession, HTTPClient, - S3Service, ) -from app.api.utils import delete_remote_icon, upload_icon +from app.api.utils import upload_icon from app.core.config import settings from app.crud import CRUDWorkflowVersion from app.git_repository import build_repository @@ -53,21 +54,25 @@ class DocumentationEnum(Enum): return "nextflow_schema.json" -@router.get("", status_code=status.HTTP_200_OK, summary="Get all versions of a workflow") +@router.get( + "", status_code=status.HTTP_200_OK, summary="Get all versions of a workflow", response_model_exclude_none=True +) @start_as_current_span_async("api_workflow_version_list", tracer=tracer) async def list_workflow_version( current_user: CurrentUser, workflow: CurrentWorkflow, db: DBSession, authorization: Authorization, - version_status: Optional[List[WorkflowVersion.Status]] = Query( - None, - description=f"Which versions of the workflow to include in 
the response. Permission 'workflow:list_filter' required if you are not the developer of this workflow. Default {WorkflowVersion.Status.PUBLISHED.name} and {WorkflowVersion.Status.DEPRECATED.name}",  # noqa: E501
-    ),
+    version_status: Annotated[
+        Union[List[WorkflowVersion.Status], SkipJsonSchema[None]],
+        Query(
+            description=f"Which versions of the workflow to include in the response. Permission `workflow:list_filter` required if you are not the developer of this workflow. Default `{WorkflowVersion.Status.PUBLISHED.name}` and `{WorkflowVersion.Status.DEPRECATED.name}`",  # noqa: E501
+        ),
+    ] = None,
 ) -> List[WorkflowVersionSchema]:
     """
     List all versions of a Workflow.\n
-    Permission "workflow:list" required.
+    Permission `workflow:list` required.
     \f
     Parameters
     ----------
@@ -110,6 +115,7 @@ async def list_workflow_version(
     "/{git_commit_hash}",
     status_code=status.HTTP_200_OK,
     summary="Get a workflow version",
+    response_model_exclude_none=True,
 )
 @start_as_current_span_async("api_workflow_version_get", tracer=tracer)
 async def get_workflow_version(
@@ -117,17 +123,19 @@ async def get_workflow_version(
     db: DBSession,
     current_user: CurrentUser,
     authorization: Authorization,
-    git_commit_hash: str = Path(
-        ...,
-        description="Git commit git_commit_hash of specific version or 'latest'.",
-        pattern=r"^([0-9a-f]{40}|latest)$",
-        examples=["latest", "ba8bcd9294c2c96aedefa1763a84a18077c50c0f"],
-    ),
+    git_commit_hash: Annotated[
+        str,
+        Path(
+            description="Git commit hash of a specific version or `latest`.",
+            pattern=r"^([0-9a-f]{40}|latest)$",
+            examples=["latest", "ba8bcd9294c2c96aedefa1763a84a18077c50c0f"],
+        ),
+    ],
 ) -> WorkflowVersionSchema:
     """
     Get a specific version of a workflow.\n
-    Permission "workflow:read" required if the version is public or you are the developer of the workflow,
-    otherwise "workflow:read_any"
+    Permission `workflow:read` required if the version is public or you are the developer of the workflow,
+    otherwise `workflow:read_any` required.
     \f
     Parameters
     ----------
@@ -172,7 +180,12 @@ async def get_workflow_version(
     return WorkflowVersionSchema.from_db_version(version, load_modes=True)


-@router.patch("/{git_commit_hash}/status", status_code=status.HTTP_200_OK, summary="Update status of workflow version")
+@router.patch(
+    "/{git_commit_hash}/status",
+    status_code=status.HTTP_200_OK,
+    summary="Update status of workflow version",
+    response_model_exclude_none=True,
+)
 @start_as_current_span_async("api_workflow_version_status_update", tracer=tracer)
 async def update_workflow_version_status(
     version_status: WorkflowVersionStatus,
@@ -182,7 +195,7 @@ async def update_workflow_version_status(
 ) -> WorkflowVersionSchema:
     """
     Update the status of a workflow version.\n
-    Permission "workflow:update_status"
+    Permission `workflow:update_status` required.
     \f
     Parameters
     ----------
@@ -203,7 +216,7 @@ async def update_workflow_version_status(
     trace.get_current_span().set_attributes(
         {
             "workflow_id": str(workflow_version.workflow_id),
-            "git_commit_hash": workflow_version.git_commit_hash,
+            "workflow_version_id": workflow_version.git_commit_hash,
             "version_status": version_status.status.name,
         }
     )
@@ -213,7 +226,7 @@ async def update_workflow_version_status(
     return WorkflowVersionSchema.from_db_version(workflow_version, load_modes=True)


-@router.post("/{git_commit_hash}/deprecate", status_code=status.HTTP_200_OK, summary="Deprecate a workflow version")
+@router.post(
+    "/{git_commit_hash}/deprecate",
+    status_code=status.HTTP_200_OK,
+    summary="Deprecate a workflow version",
+    response_model_exclude_none=True,
+)
 @start_as_current_span_async("api_workflow_version_status_update", tracer=tracer)
 async def deprecate_workflow_version(
     workflow: CurrentWorkflow,
@@ -224,8 +242,8 @@ async def deprecate_workflow_version(
 ) -> WorkflowVersionSchema:
     """
     Deprecate a workflow version.\n
-    Permission "workflow:update" required if you are the developer of the workflow,
-    otherwise "workflow:read_status"
+    Permission `workflow:update` required if you are the developer of the workflow,
+    otherwise `workflow:update_status` required.
     \f
     Parameters
     ----------
@@ -246,7 +264,7 @@ async def deprecate_workflow_version(
         Version of the workflow with deprecated status
     """
     trace.get_current_span().set_attributes(
-        {"workflow_id": str(workflow_version.workflow_id), "git_commit_hash": workflow_version.git_commit_hash}
+        {"workflow_id": str(workflow_version.workflow_id), "workflow_version_id": workflow_version.git_commit_hash}
     )
     await authorization("update_status" if current_user.uid != workflow.developer_id else "update")
     await CRUDWorkflowVersion.update_status(db, workflow_version.git_commit_hash, WorkflowVersion.Status.DEPRECATED)
@@ -266,15 +284,15 @@ async def download_workflow_documentation(
     workflow_version: CurrentWorkflowVersion,
     authorization: Authorization,
     client: HTTPClient,
-    document: DocumentationEnum = Query(
-        DocumentationEnum.USAGE, description="Specify which type of documentation the client wants to fetch"
-    ),
-    mode_id: Optional[UUID] = Query(default=None, description="Workflow Mode"),
+    document: Annotated[
+        DocumentationEnum, Query(description="Specify which type of documentation the client wants to fetch")
+    ] = DocumentationEnum.USAGE,
+    mode_id: Annotated[Union[UUID, SkipJsonSchema[None]], Query(description="Workflow Mode")] = None,
 ) -> StreamingResponse:
     """
     Get the documentation for a specific workflow version.
     Streams the response directly from the right git repository.\n
-    Permission "workflow:read" required.
+    Permission `workflow:read` required.
     \f
     Parameters
     ----------
@@ -300,7 +318,7 @@ async def download_workflow_documentation(
     current_span.set_attributes(
         {
             "workflow_id": str(workflow_version.workflow_id),
-            "git_commit_hash": workflow_version.git_commit_hash,
+            "workflow_version_id": workflow_version.git_commit_hash,
             "document": document.name,
         }
     )
@@ -339,15 +357,14 @@ async def upload_workflow_version_icon(
     workflow: CurrentWorkflow,
     background_tasks: BackgroundTasks,
     workflow_version: CurrentWorkflowVersion,
-    s3: S3Service,
     authorization: Authorization,
     current_user: CurrentUser,
     db: DBSession,
-    icon: UploadFile = File(..., description="Optional Icon for the Workflow."),
+    icon: Annotated[UploadFile, File(description="Icon for the Workflow.")],
 ) -> IconUpdateOut:
     """
     Upload an icon for the workflow version and return the new icon URL.\n
-    Permission "workflow:update" required.
+    Permission `workflow:update` required.
     \f
     Parameters
     ----------
@@ -359,8 +376,6 @@ async def upload_workflow_version_icon(
         Workflow version with given ID. Dependency Injection.
     authorization : Callable[[str], Awaitable[Any]]
         Async function to ask the auth service for authorization. Dependency Injection.
-    s3 : boto3_type_annotations.s3.ServiceResource
-        S3 Service to perform operations on buckets in Ceph. Dependency Injection.
     current_user : clowmdb.models.User
         Current user. Dependency Injection.
     db : sqlalchemy.ext.asyncio.AsyncSession.
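The icon endpoints now hand only an icon slug and a staged file path to a background task instead of the live `UploadFile` and the S3 resource; all S3 work moves into the new `app/api/background_tasks` module, which is not part of this diff. A minimal sketch of what the relocated task might look like, assuming it keeps the 64x64 thumbnail logic of the removed `_process_and_upload_icon` helper (visible in the app/api/utils.py hunks further below) and cleans up the staging file afterwards; the bucket name and S3 client setup here are placeholders:

```python
# Hypothetical sketch of the relocated task in app/api/background_tasks.py.
# Only the signature process_and_upload_icon(icon_slug=..., icon_buffer_file=...)
# is confirmed by the call sites in this diff; everything else is an assumption.
from io import BytesIO
from pathlib import Path

import boto3
from PIL import Image

ICON_BUCKET = "clowm-icons"  # assumption: settings.ICON_BUCKET in the real code


def process_and_upload_icon(icon_slug: str, icon_buffer_file: Path) -> None:
    """Shrink the staged icon to at most 64x64, upload it as PNG, then drop the temp file."""
    try:
        im = Image.open(icon_buffer_file)
        im.thumbnail((64, 64))  # resize in place to fit 64x64
        buffer = BytesIO()
        im.save(buffer, "PNG")  # re-encode as PNG into an in-memory buffer
        buffer.seek(0)
        s3 = boto3.resource("s3")  # assumption: built from settings in the real code
        s3.Bucket(ICON_BUCKET).Object(icon_slug).upload_fileobj(
            Fileobj=buffer, ExtraArgs={"ContentType": "image/png"}
        )
    finally:
        # The endpoint staged the upload on disk; always clean it up.
        icon_buffer_file.unlink(missing_ok=True)
```

Staging the upload to a file instead of passing the request-scoped `UploadFile` stream means the background task no longer touches objects whose lifetime ends when the response is sent.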
@@ -375,18 +390,22 @@ async def upload_workflow_version_icon( """ current_span = trace.get_current_span() current_span.set_attributes( - {"workflow_id": str(workflow_version.workflow_id), "git_commit_hash": workflow_version.git_commit_hash} + {"workflow_id": str(workflow_version.workflow_id), "workflow_version_id": workflow_version.git_commit_hash} ) + if icon.content_type is not None: # pragma: no cover + current_span.set_attribute("content_type", icon.content_type) + if icon.filename is not None: # pragma: no cover + current_span.set_attribute("filename", icon.filename) await authorization("update") if current_user.uid != workflow.developer_id: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the developer can update his workflow") old_slug = workflow_version.icon_slug - icon_slug = upload_icon(s3=s3, background_tasks=background_tasks, icon=icon) + icon_slug = await upload_icon(background_tasks=background_tasks, icon=icon) current_span.set_attribute("icon_slug", icon_slug) await CRUDWorkflowVersion.update_icon(db, workflow_version.git_commit_hash, icon_slug) # Delete old icon if possible if old_slug is not None: - background_tasks.add_task(delete_remote_icon, s3=s3, db=db, icon_slug=old_slug) + background_tasks.add_task(delete_remote_icon, icon_slug=old_slug) return IconUpdateOut(icon_url=str(settings.OBJECT_GATEWAY_URI) + "/".join([settings.ICON_BUCKET, icon_slug])) @@ -402,12 +421,11 @@ async def delete_workflow_version_icon( background_tasks: BackgroundTasks, authorization: Authorization, current_user: CurrentUser, - s3: S3Service, db: DBSession, ) -> None: """ Delete the icon of the workflow version.\n - Permission "workflow:update" required. + Permission `workflow:update` required. \f Parameters ---------- @@ -421,18 +439,16 @@ async def delete_workflow_version_icon( Async function to ask the auth service for authorization. Dependency Injection. current_user : clowmdb.models.User Current user. Dependency Injection. - s3 : boto3_type_annotations.s3.ServiceResource - S3 Service to perform operations on buckets in Ceph. Dependency Injection. db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. Dependency Injection. 
""" current_span = trace.get_current_span() current_span.set_attributes( - {"workflow_id": str(workflow_version.workflow_id), "git_commit_hash": workflow_version.git_commit_hash} + {"workflow_id": str(workflow_version.workflow_id), "workflow_version_id": workflow_version.git_commit_hash} ) await authorization("update") if current_user.uid != workflow.developer_id: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the developer can update his workflow") if workflow_version.icon_slug is not None: - background_tasks.add_task(delete_remote_icon, s3=s3, db=db, icon_slug=workflow_version.icon_slug) + background_tasks.add_task(delete_remote_icon, icon_slug=workflow_version.icon_slug) await CRUDWorkflowVersion.update_icon(db, workflow_version.git_commit_hash, icon_slug=None) diff --git a/app/api/utils.py b/app/api/utils.py index 48c56f3f890092600e885680d1e55f715c10235d..5d5642729901275c79574a417535f70e74674c0b 100644 --- a/app/api/utils.py +++ b/app/api/utils.py @@ -1,39 +1,24 @@ import asyncio import json import re -import shlex -from asyncio import sleep as async_sleep -from io import BytesIO -from os import environ from pathlib import Path -from tempfile import SpooledTemporaryFile -from typing import TYPE_CHECKING, Any, BinaryIO, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Union from uuid import UUID, uuid4 -import botocore.client -import dotenv +from anyio import open_file from clowmdb.models import WorkflowExecution, WorkflowMode from fastapi import BackgroundTasks, HTTPException, UploadFile, status -from httpx import AsyncClient, HTTPError -from mako.template import Template +from httpx import AsyncClient from opentelemetry import trace from PIL import Image, UnidentifiedImageError from sqlalchemy.ext.asyncio import AsyncSession -from app.core.config import MonitorJobBackoffStrategy, settings -from app.crud import CRUDBucket, CRUDWorkflowExecution, CRUDWorkflowVersion +from app.api.background_tasks import process_and_upload_icon +from app.core.config import settings +from app.crud import CRUDBucket, CRUDWorkflowExecution from app.git_repository.abstract_repository import GitRepository from app.schemas.workflow_mode import WorkflowModeIn -from app.scm import SCM -from app.slurm import SlurmClient, SlurmJobSubmission -from app.utils.backoff_strategy import BackoffStrategy, ExponentialBackoff, LinearBackoff, NoBackoff -if TYPE_CHECKING: - from mypy_boto3_s3.service_resource import S3ServiceResource -else: - S3ServiceResource = object - -nextflow_command_template = Template(filename="mako_templates/nextflow_command.tmpl") # regex to find S3 files in parameters of workflow execution s3_file_regex = re.compile( r"s3://(?!(((2(5[0-5]|[0-4]\d)|[01]?\d{1,2})\.){3}(2(5[0-5]|[0-4]\d)|[01]?\d{1,2})$))[a-z\d][a-z\d.-]{1,61}[a-z\d][^\"]*" @@ -41,19 +26,13 @@ s3_file_regex = re.compile( tracer = trace.get_tracer_provider().get_tracer(__name__) -dotenv.load_dotenv() -execution_env: Dict[str, Union[str, int, bool]] = {key: val for key, val in environ.items() if key.startswith("NXF_")} -execution_env["SCM_DIR"] = str(Path(settings.SLURM_WORKING_DIRECTORY) / "scm") - -def upload_icon(s3: S3ServiceResource, background_tasks: BackgroundTasks, icon: UploadFile) -> str: +async def upload_icon(background_tasks: BackgroundTasks, icon: UploadFile) -> str: """ Upload an icon to the icon bucket. Parameters ---------- - s3 : boto3_type_annotations.s3.ServiceResource - S3 Service to perform operations on buckets in Ceph. 
background_tasks : fastapi.BackgroundTasks Entrypoint for new BackgroundTasks. icon : fastapi.UploadFile @@ -66,60 +45,23 @@ def upload_icon(s3: S3ServiceResource, background_tasks: BackgroundTasks, icon: """ try: Image.open(icon.file) + icon.file.seek(0) except UnidentifiedImageError: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="icon needs to be an image") icon_slug = f"{uuid4().hex}.png" - # Write the icon to the icon bucket in a background task - background_tasks.add_task(_process_and_upload_icon, s3=s3, icon_slug=icon_slug, icon_buffer=icon.file) + # Save the icon to a file to access it in a background task + icon_file = Path(f"/tmp/{icon_slug}") + async with await open_file(icon_file, "wb") as f: + size = 32000 + buffer = await icon.read(size) + while len(buffer) > 0: + await f.write(buffer) + buffer = await icon.read(size) + + background_tasks.add_task(process_and_upload_icon, icon_slug=icon_slug, icon_buffer_file=icon_file) return icon_slug -def _process_and_upload_icon(s3: S3ServiceResource, icon_slug: str, icon_buffer: BinaryIO) -> None: - """ - Process the icon and upload it to the S3 Icon Bucket - - Parameters - ---------- - s3 : boto3_type_annotations.s3.ServiceResource - S3 Service to perform operations on buckets in Ceph. - icon_slug : str - Slug of the icon - icon_buffer : typing.BinaryIO - Binary stream containing the image - """ - im = Image.open(icon_buffer) - im.thumbnail((64, 64)) # Crop to 64x64 image - thumbnail_buffer = BytesIO() - im.save(thumbnail_buffer, "PNG") # save in buffer as PNG image - thumbnail_buffer.seek(0) - with tracer.start_as_current_span("s3_upload_workflow_version_icon") as span: - span.set_attribute("icon", icon_slug) - # Upload to bucket - s3.Bucket(name=settings.ICON_BUCKET).Object(key=icon_slug).upload_fileobj( - Fileobj=thumbnail_buffer, ExtraArgs={"ContentType": "image/png"} - ) - - -async def delete_remote_icon(s3: S3ServiceResource, db: AsyncSession, icon_slug: str) -> None: - """ - Delete icon in S3 Bucket if there are no other workflow versions that depend on it - - Parameters - ---------- - s3 : boto3_type_annotations.s3.ServiceResource - S3 Service to perform operations on buckets in Ceph. - db : sqlalchemy.ext.asyncio.AsyncSession. - Async database session to perform query on. - icon_slug : str - Name of the icon file. - """ - # If there are no more Workflow versions that have this icon, delete it in the S3 ICON_BUCKET - if not await CRUDWorkflowVersion.icon_exists(db, icon_slug): - with tracer.start_as_current_span("s3_delete_workflow_version_icon") as span: - span.set_attribute("icon", icon_slug) - s3.Bucket(name=settings.ICON_BUCKET).Object(key=icon_slug).delete() - - async def check_repo( repo: GitRepository, client: AsyncClient, modes: Optional[Sequence[Union[WorkflowModeIn, WorkflowMode]]] = None ) -> None: @@ -150,162 +92,7 @@ async def check_repo( ) -async def start_workflow_execution( - s3: S3ServiceResource, - db: AsyncSession, - execution: WorkflowExecution, - parameters: Dict[str, Any], - git_repo: GitRepository, - slurm_client: SlurmClient, - scm_file_id: Optional[UUID] = None, - workflow_entrypoint: Optional[str] = None, -) -> None: - """ - Start a workflow on the Slurm cluster. - - Parameters - ---------- - s3 : boto3_type_annotations.s3.ServiceResource - S3 Service to perform operations on buckets in Ceph. - db : sqlalchemy.ext.asyncio.AsyncSession. - Async database session to perform query on. - execution : clowmdb.models.WorkflowExecution - Workflow execution to execute. 
- parameters : Dict[str, Any] - Parameters for the workflow. - git_repo : app.git_repository.abstract_repository.GitRepository - Git repository of the workflow version. - slurm_client : app.slurm.rest_client.SlurmClient - Slurm Rest Client to communicate with Slurm cluster. - scm_file_id : UUID | None - ID of the SCM file for private git repositories. - workflow_entrypoint : str | None - Entrypoint for the workflow by specifying the `-entry` parameter - """ - # Upload parameters to - params_file_name = f"params-{execution.execution_id.hex}.json" - with SpooledTemporaryFile(max_size=512000) as f: - f.write(json.dumps(parameters).encode("utf-8")) - f.seek(0) - with tracer.start_as_current_span("s3_upload_workflow_execution_parameters") as span: - span.set_attribute("workflow_execution_id", str(execution.execution_id)) - s3.Bucket(name=settings.PARAMS_BUCKET).Object(key=params_file_name).upload_fileobj(f) - for key in parameters.keys(): - if isinstance(parameters[key], str): - # Escape string parameters for bash shell - parameters[key] = shlex.quote(parameters[key]).replace("$", "\$") - - # Check if the there is an SCM file for the workflow - if scm_file_id is not None: - try: - with tracer.start_as_current_span("s3_check_workflow_scm_file"): - s3.Bucket(settings.PARAMS_BUCKET).Object(SCM.generate_filename(scm_file_id)).load() - except botocore.client.ClientError: - scm_file_id = None - - nextflow_script = nextflow_command_template.render( - repo=git_repo, - parameters=parameters, - execution_id=execution.execution_id, - settings=settings, - scm_file_id=scm_file_id, - debug_s3_path=execution.debug_path, - logs_s3_path=execution.logs_path, - provenance_s3_path=execution.provenance_path, - workflow_entrypoint=workflow_entrypoint, - ) - - # Setup env for the workflow execution - work_directory = str(Path(settings.SLURM_WORKING_DIRECTORY) / f"run-{execution.execution_id.hex}") - env = execution_env.copy() - env["TOWER_WORKSPACE_ID"] = execution.execution_id.hex[:16] - env["NXF_WORK"] = str(Path(work_directory) / "work") - env["NXF_ANSI_LOG"] = False - env["NXF_ASSETS"] = str( - Path(env.get("NXF_ASSETS", "$HOME/.nextflow/assets")) / f"{git_repo.name}_{git_repo.commit}" # type: ignore[arg-type] - ) - - try: - job_submission = SlurmJobSubmission( - script=nextflow_script.strip(), - job={ - "current_working_directory": settings.SLURM_WORKING_DIRECTORY, - "environment": env, - "name": execution.execution_id.hex, - "requeue": False, - "standard_output": str( - Path(settings.SLURM_WORKING_DIRECTORY) / f"slurm-{execution.execution_id.hex}.out" - ), - }, - ) - # Try to start the job on the slurm cluster - slurm_job_id = await slurm_client.submit_job(job_submission=job_submission) - await CRUDWorkflowExecution.update_slurm_job_id( - db, slurm_job_id=slurm_job_id, execution_id=execution.execution_id - ) - if not settings.SLURM_JOB_MONITORING == MonitorJobBackoffStrategy.NOMONITORING: # pragma: no cover - await _monitor_proper_job_execution( - db=db, slurm_client=slurm_client, execution_id=execution.execution_id, slurm_job_id=slurm_job_id - ) - except (HTTPError, KeyError): - # Mark job as aborted when there is an error - await CRUDWorkflowExecution.cancel( - db, execution_id=execution.execution_id, status=WorkflowExecution.WorkflowExecutionStatus.ERROR - ) - - -async def _monitor_proper_job_execution( - db: AsyncSession, slurm_client: SlurmClient, execution_id: UUID, slurm_job_id: int -) -> None: # pragma: no cover - """ - Check in an interval based on a backoff strategy if the slurm job is still running - 
the workflow execution in the database is not marked as finished.
-
-    Parameters
-    ----------
-    db : sqlalchemy.ext.asyncio.AsyncSession.
-        Async database session to perform query on.
-    slurm_client : app.slurm.rest_client.SlurmClient
-        Slurm Rest Client to communicate with Slurm cluster.
-    execution_id : uuid.UUID
-        ID of the workflow execution
-    slurm_job_id : int
-        ID of the slurm job to monitor
-    """
-    previous_span_link = None
-    if settings.SLURM_JOB_MONITORING == MonitorJobBackoffStrategy.EXPONENTIAL:
-        # exponential to 50 minutes
-        sleep_generator: BackoffStrategy = ExponentialBackoff(initial_delay=30, max_value=300)  # type: ignore[no-redef]
-    elif settings.SLURM_JOB_MONITORING == MonitorJobBackoffStrategy.LINEAR:
-        # 5 seconds increase to 5 minutes
-        sleep_generator: BackoffStrategy = LinearBackoff(  # type: ignore[no-redef]
-            initial_delay=30, backoff=5, max_value=300
-        )
-    elif settings.SLURM_JOB_MONITORING == MonitorJobBackoffStrategy.CONSTANT:
-        # constant 30 seconds polling
-        sleep_generator: BackoffStrategy = NoBackoff(initial_delay=30, constant_value=30)  # type: ignore[no-redef]
-    else:
-        return
-    for sleep_time in sleep_generator:
-        await async_sleep(sleep_time)
-        with tracer.start_span("monitor_job", links=previous_span_link) as span:
-            span.set_attributes({"execution_id": str(execution_id), "slurm_job_id": slurm_job_id})
-            if await slurm_client.is_job_finished(slurm_job_id):
-                await db.close()  # Reset connection
-                execution = await CRUDWorkflowExecution.get(db, execution_id=execution_id)
-                # Check if the execution is marked as finished in the database
-                if execution is not None:
-                    span.set_attribute("workflow_execution_status", str(execution.status))
-                    if execution.end_time is None:
-                        # Mark job as finished with an error
-                        await CRUDWorkflowExecution.cancel(
-                            db, execution_id=execution_id, status=WorkflowExecution.WorkflowExecutionStatus.ERROR
-                        )
-                sleep_generator.close()
-            previous_span_link = [trace.Link(span.get_span_context())]
-
-
-async def check_active_workflow_execution_limit(db: AsyncSession, uid: str) -> None:
+async def check_active_workflow_execution_limit(db: AsyncSession, uid: UUID) -> None:
     """
     Check the number of active workflow executions of a user and raise an HTTP exception if a new one
     would violate the limit of active workflow executions.
@@ -318,9 +105,7 @@ async def check_active_workflow_execution_limit(db: AsyncSession, uid: str) -> N
         ID of a user.
     """
     active_executions = await CRUDWorkflowExecution.list(
-        db,
-        uid=uid,
-        status_list=WorkflowExecution.WorkflowExecutionStatus.active_workflows(),
+        db, executor_id=uid, status_list=WorkflowExecution.WorkflowExecutionStatus.active_workflows()
     )
     if settings.ACTIVE_WORKFLOW_EXECUTION_LIMIT != -1 and len(active_executions) + 1 > settings.ACTIVE_WORKFLOW_EXECUTION_LIMIT:  # noqa: E501
         raise HTTPException(
@@ -330,7 +115,7 @@ async def check_buckets_access(
-    db: AsyncSession, parameters: Dict[str, Any], uid: str, other_buckets: Optional[List[Optional[str]]] = None
+    db: AsyncSession, parameters: Dict[str, Any], uid: UUID, other_buckets: Optional[List[Optional[str]]] = None
 ) -> None:
     """
     Check if the user has access to the buckets referenced in the workflow execution parameters.
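A pattern repeated across all the CRUD hunks below: user, workflow, and execution IDs change from hex strings to `uuid.UUID` objects, and filters now target private byte columns (`_uid`, `_workflow_id`, `_executor_id`) via `uid.bytes`. A minimal, self-contained sketch of the column mapping this implies, assuming clowmdb stores UUIDs as raw 16-byte binary columns (the models themselves are not shown in this diff; `User` and `display_name` here are illustrative stand-ins):

```python
# Sketch only: a stand-in for the clowmdb models, based on the
# `_uid == uid.bytes` comparisons visible throughout this diff.
from uuid import UUID, uuid4

from sqlalchemy import LargeBinary, String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user"

    # Raw 16-byte primary key; callers never touch the bytes directly.
    _uid: Mapped[bytes] = mapped_column("uid", LargeBinary(16), primary_key=True)
    display_name: Mapped[str] = mapped_column(String(256))

    @property
    def uid(self) -> UUID:
        # Expose the key as a uuid.UUID, matching the new CRUD signatures.
        return UUID(bytes=self._uid)


uid = uuid4()
# Filters must compare against the binary column, hence `.bytes`:
stmt = select(User).where(User._uid == uid.bytes)
print(stmt)  # SELECT ... FROM user WHERE user.uid = :uid_1
```

Keeping the bytes-to-UUID conversion at the model boundary is what lets the endpoints above accept `UUID` path and query parameters directly and drop the old hex-string plumbing.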
@@ -369,7 +154,7 @@ async def check_buckets_access( ) -async def _check_bucket_access(db: AsyncSession, uid: str, bucket_path: str) -> Optional[str]: +async def _check_bucket_access(db: AsyncSession, uid: UUID, bucket_path: str) -> Optional[str]: """ Check if the bucket exists and the user has READWRITE access to it. @@ -397,23 +182,3 @@ async def _check_bucket_access(db: AsyncSession, uid: str, bucket_path: str) -> error += f" and file/directory {file}" return error return None - - -def upload_scm_file(s3: S3ServiceResource, scm: SCM, scm_file_id: UUID) -> None: - """ - Upload the SCM file of a private workflow into the PARAMS_BUCKET with the workflow id as key - - Parameters - ---------- - s3 : boto3_type_annotations.s3.ServiceResource - S3 Service to perform operations on buckets in Ceph. - scm : app.scm.SCM - Python object representing the SCM file - scm_file_id : str - ID of the scm file for the name of the file in the bucket - """ - with BytesIO() as handle: - scm.serialize(handle) - handle.seek(0) - with tracer.start_as_current_span("s3_upload_workflow_credentials"): - s3.Bucket(settings.PARAMS_BUCKET).Object(SCM.generate_filename(scm_file_id)).upload_fileobj(handle) diff --git a/app/core/security.py b/app/core/security.py index b4240707c528a234df9ccffb05a3f7ba2948fc9c..70f7eb030aabad8c9b38ec5ae3efa4808a952ac6 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -70,7 +70,7 @@ async def request_authorization(request_params: AuthzRequest, client: AsyncClien response = await client.post( f"{settings.OPA_URI}v1/data{settings.OPA_POLICY_PATH}", json={"input": request_params.model_dump()} ) - parsed_response = AuthzResponse(**response.json()) + parsed_response = AuthzResponse.model_validate(response.json()) span.set_attribute("decision_id", str(parsed_response.decision_id)) if not parsed_response.result: # pragma: no cover raise HTTPException( diff --git a/app/crud/crud_bucket.py b/app/crud/crud_bucket.py index 0de0dadc33f84c8ba2425013be582b0d8d50fae7..9113057609ebc8eb41997ba0186b929f61ce463b 100644 --- a/app/crud/crud_bucket.py +++ b/app/crud/crud_bucket.py @@ -1,4 +1,5 @@ from typing import Optional +from uuid import UUID from clowmdb.models import Bucket, BucketPermission from opentelemetry import trace @@ -30,13 +31,15 @@ class CRUDBucket: Flag if the check was successful. """ stmt = select(Bucket).where(Bucket.name == bucket_name) - trace.get_current_span().set_attributes({"bucket_name": bucket_name, "sql_query": str(stmt)}) - bucket = await db.scalar(stmt) - return bucket is not None + with tracer.start_as_current_span( + "db_check_bucket_exists", attributes={"bucket_name": bucket_name, "sql_query": str(stmt)} + ): + bucket = await db.scalar(stmt) + return bucket is not None @staticmethod @start_as_current_span_async("db_check_bucket_access", tracer=tracer) - async def check_access(db: AsyncSession, bucket_name: str, uid: str, key: Optional[str] = None) -> bool: + async def check_access(db: AsyncSession, bucket_name: str, uid: UUID, key: Optional[str] = None) -> bool: """ Check if the given user has access to the bucket. @@ -46,7 +49,7 @@ class CRUDBucket: Async database session to perform query on. bucket_name : str Name of a bucket. - uid : str + uid : uuid.UUID UID of a user. key : str | None, default None Additional check if the user has access to the key in the bucket. @@ -56,10 +59,13 @@ class CRUDBucket: check : bool Flag if the check was successful. 
""" - current_span = trace.get_current_span() - stmt = select(Bucket).where(Bucket.name == bucket_name).where(Bucket.owner_id == uid) - current_span.set_attributes({"bucket_name": bucket_name, "sql_query": str(stmt)}) - bucket = await db.scalar(stmt) + trace.get_current_span().set_attributes({"bucket_name": bucket_name, "uid": str(uid), "key": str(key)}) + stmt = select(Bucket).where(Bucket.name == bucket_name).where(Bucket._owner_id == uid.bytes) + with tracer.start_as_current_span( + "db_check_bucket_access_get_bucket", + attributes={"bucket_name": bucket_name, "sql_query": str(stmt), "uid": str(uid)}, + ): + bucket = await db.scalar(stmt) # If the user is the owner of the bucket -> user has access if bucket is not None: return True @@ -67,7 +73,7 @@ class CRUDBucket: stmt = ( select(BucketPermission) .where(BucketPermission.bucket_name == bucket_name) # check bucket name - .where(BucketPermission.user_id == uid) # check grantee of permission + .where(BucketPermission._uid == uid.bytes) # check grantee of permission .where(BucketPermission.permissions == BucketPermission.Permission.READWRITE) # check READWRITE Permission .where( # check 'form' timestamp is no violated or_( @@ -82,9 +88,12 @@ class CRUDBucket: ) ) ) - current_span.set_attributes({"sql_query": str(stmt)}) - permission: Optional[BucketPermission] = await db.scalar(stmt) + with tracer.start_as_current_span( + "db_check_bucket_access_get_permission", + attributes={"bucket_name": bucket_name, "sql_query": str(stmt), "uid": str(uid)}, + ): + permission: Optional[BucketPermission] = await db.scalar(stmt) # If the user has no active READWRITE Permission for the bucket -> user has no access if permission is None: return False diff --git a/app/crud/crud_user.py b/app/crud/crud_user.py index 945af06c3071e9c92cb383e0a8fd3c0143b07017..89eb52735c270db8ae4de0fd3fc064b5a383b7ec 100644 --- a/app/crud/crud_user.py +++ b/app/crud/crud_user.py @@ -1,4 +1,5 @@ from typing import Optional +from uuid import UUID from clowmdb.models import User from opentelemetry import trace @@ -10,7 +11,7 @@ tracer = trace.get_tracer_provider().get_tracer(__name__) class CRUDUser: @staticmethod - async def get(db: AsyncSession, uid: str) -> Optional[User]: + async def get(db: AsyncSession, uid: UUID) -> Optional[User]: """ Get a user by its UID. @@ -18,7 +19,7 @@ class CRUDUser: ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - uid : str + uid : uuid.UUID UID of a user. 
        Returns
@@ -26,7 +27,6 @@
         user : clowmdb.models.User | None
             The user for the given UID if one exists, None otherwise
         """
-        with tracer.start_as_current_span("db_get_user", attributes={"uid": uid}) as span:
-            stmt = select(User).where(User.uid == uid)
-            span.set_attribute("sql_query", str(stmt))
+        stmt = select(User).where(User._uid == uid.bytes)
+        with tracer.start_as_current_span("db_get_user", attributes={"uid": str(uid), "sql_query": str(stmt)}):
             return await db.scalar(stmt)
diff --git a/app/crud/crud_workflow.py b/app/crud/crud_workflow.py
index 9b7a9f9339620adf0b2bd263fe357058fab53981..ddc842f5debcaea41b6ab3adc1e651dcd95e8ff2 100644
--- a/app/crud/crud_workflow.py
+++ b/app/crud/crud_workflow.py
@@ -1,7 +1,7 @@
 from datetime import date, datetime
 from hashlib import sha256
 from os import urandom
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional
 from uuid import UUID

 from clowmdb.models import Workflow, WorkflowExecution, WorkflowVersion
@@ -23,7 +23,7 @@ class CRUDWorkflow:
     async def list_workflows(
         db: AsyncSession,
         name_substring: Optional[str] = None,
-        developer_id: Optional[str] = None,
+        developer_id: Optional[UUID] = None,
         version_status: Optional[List[WorkflowVersion.Status]] = None,
     ) -> List[Workflow]:
         """
@@ -51,8 +51,8 @@ class CRUDWorkflow:
             span.set_attribute("name_substring", name_substring)
             stmt = stmt.where(Workflow.name.contains(name_substring))
         if developer_id is not None:
-            span.set_attribute("uid", developer_id)
-            stmt = stmt.where(Workflow.developer_id == developer_id)
+            span.set_attribute("uid", str(developer_id))
+            stmt = stmt.where(Workflow._developer_id == developer_id.bytes)
         if version_status is not None and len(version_status) > 0:
             span.set_attribute("status", [stat.name for stat in version_status])
             stmt = stmt.options(
@@ -64,7 +64,7 @@ class CRUDWorkflow:
         return [w for w in (await db.scalars(stmt)).unique().all() if len(w.versions) > 0]

     @staticmethod
-    async def delete(db: AsyncSession, workflow_id: Union[UUID, bytes]) -> None:
+    async def delete(db: AsyncSession, workflow_id: UUID) -> None:
         """
         Delete a workflow.

@@ -72,20 +72,18 @@ class CRUDWorkflow:
         ----------
         db : sqlalchemy.ext.asyncio.AsyncSession.
             Async database session to perform query on.
-        workflow_id : bytes | uuid.UUID
+        workflow_id : uuid.UUID
             UID of a workflow
         """
-        with tracer.start_as_current_span("db_delete_workflow") as span:
-            wid = workflow_id.bytes if isinstance(workflow_id, UUID) else workflow_id
-            stmt = delete(Workflow).where(Workflow._workflow_id == wid)
-            span.set_attributes({"workflow_id": str(workflow_id), "sql_query": str(stmt)})
+        stmt = delete(Workflow).where(Workflow._workflow_id == workflow_id.bytes)
+        with tracer.start_as_current_span(
+            "db_delete_workflow", attributes={"workflow_id": str(workflow_id), "sql_query": str(stmt)}
+        ):
             await db.execute(stmt)
             await db.commit()

     @staticmethod
-    async def update_credentials(
-        db: AsyncSession, workflow_id: Union[UUID, bytes], token: Optional[str] = None
-    ) -> None:
+    async def update_credentials(db: AsyncSession, workflow_id: UUID, token: Optional[str] = None) -> None:
         """
         Update the credentials token of a workflow.

@@ -93,22 +91,23 @@ class CRUDWorkflow:
         ----------
         db : sqlalchemy.ext.asyncio.AsyncSession.
             Async database session to perform query on.
-        workflow_id : bytes | uuid.UUID
+        workflow_id : uuid.UUID
             UID of a workflow
         token : str | None
             Token to save in the database.
            If None, the token in the database gets deleted
         """
-        with tracer.start_as_current_span("db_update_workflow_credentials") as span:
-            wid = workflow_id.bytes if isinstance(workflow_id, UUID) else workflow_id
-            stmt = update(Workflow).where(Workflow._workflow_id == wid).values(credentials_token=token)
-            span.set_attributes({"workflow_id": str(workflow_id), "sql_query": str(stmt), "delete": token is None})
+        stmt = update(Workflow).where(Workflow._workflow_id == workflow_id.bytes).values(credentials_token=token)
+        with tracer.start_as_current_span(
+            "db_update_workflow_credentials",
+            attributes={"workflow_id": str(workflow_id), "sql_query": str(stmt), "delete": token is None},
+        ):
             await db.execute(stmt)
             await db.commit()

     @staticmethod
     async def developer_statistics(
         db: AsyncSession,
-        developer_id: Optional[str] = None,
+        developer_id: Optional[UUID] = None,
         workflow_ids: Optional[List[UUID]] = None,
         start: Optional[date] = None,
         end: Optional[date] = None,
@@ -120,7 +119,7 @@
         ----------
         db : sqlalchemy.ext.asyncio.AsyncSession.
             Async database session to perform query on.
-        developer_id : str | None, default None
+        developer_id : uuid.UUID | None, default None
             Filter workflow by developer ID.
         workflow_ids : uuid.UUID | None, default None
             Filter workflows by ID.
@@ -139,11 +138,11 @@
             select(
                 cast(func.FROM_UNIXTIME(WorkflowExecution.start_time), Date).label("started_at"),
                 WorkflowExecution._execution_id,
-                WorkflowExecution.user_id,
+                WorkflowExecution._executor_id,
                 WorkflowExecution._workflow_mode_id,
                 WorkflowVersion.git_commit_hash,
                 Workflow._workflow_id,
-                Workflow.developer_id,
+                Workflow._developer_id,
                 WorkflowExecution.status,
             )
             .select_from(WorkflowExecution)
 .
             .where(WorkflowExecution.end_time != None)  # noqa:E711
         )
         if developer_id:
-            span.set_attribute("developer_id", developer_id)
-            stmt = stmt.where(Workflow.developer_id == developer_id)
+            span.set_attribute("developer_id", str(developer_id))
+            stmt = stmt.where(Workflow._developer_id == developer_id.bytes)
         if workflow_ids:
             span.set_attribute("workflow_ids", [str(wid) for wid in workflow_ids])
             stmt = stmt.where(Workflow._workflow_id.in_([wid.bytes for wid in workflow_ids]))
@@ -165,12 +164,12 @@
             span.set_attribute("end_date", end.isoformat())
             timestamp = round(datetime(year=end.year, month=end.month, day=end.day).timestamp())
             stmt = stmt.where(WorkflowExecution.start_time < timestamp)
-        user_hashes: Dict[str, str] = {}
+        user_hashes: Dict[bytes, str] = {}

-        def hash_user_id(uid: str) -> str:
+        def hash_user_id(uid: bytes) -> str:
             if uid not in user_hashes.keys():
                 hash_obj = sha256(usedforsecurity=True)
-                hash_obj.update(bytes.fromhex(uid if len(uid) % 2 == 0 else uid + "0"))
+                hash_obj.update(uid)
                 hash_obj.update(urandom(32))
                 user_hashes[uid] = hash_obj.hexdigest()
             return user_hashes[uid]
@@ -180,19 +179,19 @@
         return [
             AnonymizedWorkflowExecution(
                 workflow_execution_id=row._execution_id,
-                pseudo_uid=hash_user_id(row.user_id),
+                pseudo_uid=hash_user_id(row._executor_id),
                 workflow_mode_id=row._workflow_mode_id,
                 started_at=row.started_at,
                 workflow_id=row._workflow_id,
-                developer_id=row.developer_id,
-                git_commit_hash=row.git_commit_hash,
+                developer_id=row._developer_id,
+                workflow_version_id=row.git_commit_hash,
                 status=row.status,
             )
             for row in rows
         ]

     @staticmethod
-    async def statistics(db: AsyncSession, workflow_id: Union[bytes, UUID]) -> List[WorkflowStatistic]:
+    async def statistics(db: AsyncSession,
workflow_id: UUID) -> List[WorkflowStatistic]: """ Calculate the number of workflows started per day for a specific workflow @@ -200,7 +199,7 @@ class CRUDWorkflow: ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - workflow_id : bytes | uuid.UUID + workflow_id : uuid.UUID UID of a workflow. Returns @@ -208,21 +207,21 @@ class CRUDWorkflow: stats : List[app.schemas.Workflow.WorkflowStatistic] List of datapoints """ - with tracer.start_as_current_span("db_get_workflow_statistics") as span: - wid = workflow_id.bytes if isinstance(workflow_id, UUID) else workflow_id - stmt = ( - select(cast(func.FROM_UNIXTIME(WorkflowExecution.start_time), Date).label("day"), func.count()) - .select_from(WorkflowExecution) - .join(WorkflowVersion) - .where(WorkflowVersion._workflow_id == wid) - .group_by("day") - .order_by("day") - ) - span.set_attributes({"workflow_id": str(workflow_id), "sql_query": str(stmt)}) + stmt = ( + select(cast(func.FROM_UNIXTIME(WorkflowExecution.start_time), Date).label("day"), func.count()) + .select_from(WorkflowExecution) + .join(WorkflowVersion) + .where(WorkflowVersion._workflow_id == workflow_id.bytes) + .group_by("day") + .order_by("day") + ) + with tracer.start_as_current_span( + "db_get_workflow_statistics", attributes={"workflow_id": str(workflow_id), "sql_query": str(stmt)} + ): return [WorkflowStatistic(day=row.day, count=row.count) for row in await db.execute(stmt)] @staticmethod - async def get(db: AsyncSession, workflow_id: Union[UUID, bytes]) -> Optional[Workflow]: + async def get(db: AsyncSession, workflow_id: UUID) -> Optional[Workflow]: """ Get a workflow by its ID. @@ -230,7 +229,7 @@ class CRUDWorkflow: ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - workflow_id : bytes | uuid.UUID + workflow_id : uuid.UUID UID of a workflow. 
        Returns
@@ -238,14 +237,14 @@
         workflow : clowmdb.models.Workflow | None
             The workflow with the given ID if it exists, None otherwise
         """
-        with tracer.start_as_current_span("db_get_workflow") as span:
-            wid = workflow_id.bytes if isinstance(workflow_id, UUID) else workflow_id
-            stmt = (
-                select(Workflow)
-                .where(Workflow._workflow_id == wid)
-                .options(joinedload(Workflow.versions).selectinload(WorkflowVersion.workflow_modes))
-            )
-            span.set_attributes({"workflow_id": str(workflow_id), "sql_query": str(stmt)})
+        stmt = (
+            select(Workflow)
+            .where(Workflow._workflow_id == workflow_id.bytes)
+            .options(joinedload(Workflow.versions).selectinload(WorkflowVersion.workflow_modes))
+        )
+        with tracer.start_as_current_span(
+            "db_get_workflow", attributes={"workflow_id": str(workflow_id), "sql_query": str(stmt)}
+        ):
             return await db.scalar(stmt)

     @staticmethod
@@ -265,20 +264,21 @@
         workflow : clowmdb.models.Workflow | None
             The workflow with the given name if it exists, None otherwise
         """
-        with tracer.start_as_current_span("db_get_workflow_by_name") as span:
-            stmt = (
-                select(Workflow)
-                .where(Workflow.name == workflow_name)
-                .options(joinedload(Workflow.versions).selectinload(WorkflowVersion.workflow_modes))
-            )
-            span.set_attributes({"name": workflow_name, "sql_query": str(stmt)})
+        stmt = (
+            select(Workflow)
+            .where(Workflow.name == workflow_name)
+            .options(joinedload(Workflow.versions).selectinload(WorkflowVersion.workflow_modes))
+        )
+        with tracer.start_as_current_span(
+            "db_get_workflow_by_name", attributes={"name": workflow_name, "sql_query": str(stmt)}
+        ):
             return await db.scalar(stmt)

     @staticmethod
     async def create(
         db: AsyncSession,
         workflow: WorkflowIn,
-        developer: str,
+        developer_id: UUID,
         icon_slug: Optional[str] = None,
     ) -> Workflow:
         """
@@ -290,7 +290,7 @@
             Async database session to perform query on.
         workflow : app.schemas.workflow.WorkflowIn
             Parameters for creating the workflow
-        developer : str
+        developer_id : uuid.UUID
             UID of the developer
         icon_slug : str | None, default None
             Optional slug of the icon saved in the icon bucket
@@ -305,7 +305,7 @@
             name=workflow.name,
             repository_url=workflow.repository_url,
             short_description=workflow.short_description,
-            developer_id=developer,
+            _developer_id=developer_id.bytes,
             credentials_token=workflow.token,
         )
         db.add(workflow_db)
@@ -319,9 +319,9 @@
             db,
             git_commit_hash=workflow.git_commit_hash,
             version=workflow.initial_version,
-            wid=workflow_db.workflow_id,
+            workflow_id=workflow_db.workflow_id,
             icon_slug=icon_slug,
             modes=[mode.mode_id for mode in modes_db],
         )
-        span.set_attribute("workflow_id", workflow_db.workflow_id)
+        span.set_attribute("workflow_id", str(workflow_db.workflow_id))
         return await CRUDWorkflow.get(db, workflow_db.workflow_id)
diff --git a/app/crud/crud_workflow_execution.py b/app/crud/crud_workflow_execution.py
index 010aacf8583160c25043824bc7f482c4ebed1164..b837b196e23913336ff875c5c8526df3f46bee32 100644
--- a/app/crud/crud_workflow_execution.py
+++ b/app/crud/crud_workflow_execution.py
@@ -17,7 +17,7 @@ class CRUDWorkflowExecution:
     async def create(
         db: AsyncSession,
         execution: Union[WorkflowExecutionIn, DevWorkflowExecutionIn],
-        owner_id: str,
+        executor_id: UUID,
         notes: Optional[str] = None,
     ) -> WorkflowExecution:
         """
@@ -29,7 +29,7 @@
             Async database session to perform query on.
        execution : app.schemas.workflow_execution.WorkflowExecutionIn | DevWorkflowExecutionIn
             Workflow execution input parameters.
-        owner_id : str
+        executor_id : uuid.UUID
             User who started the workflow execution.
         notes : str | None, default None
             Notes to add to the workflow execution. Only used if 'execution' has type 'DevWorkflowExecutionIn'.

         Returns
         -------
         workflow_execution : clowmdb.models.WorkflowExecution
             The newly created workflow execution
         """
-        with tracer.start_as_current_span("db_create_workflow_execution") as span:
+        with tracer.start_as_current_span(
+            "db_create_workflow_execution", attributes={"executor_id": str(executor_id)}
+        ) as span:
             if isinstance(execution, WorkflowExecutionIn):
+                span.set_attribute("workflow_version_id", execution.workflow_version_id)
                 workflow_execution = WorkflowExecution(
-                    user_id=owner_id,
+                    _executor_id=executor_id.bytes,
                     workflow_version_id=execution.workflow_version_id,
                     notes=execution.notes,
                     slurm_job_id=-1,
-                    _workflow_mode_id=execution.mode.bytes if execution.mode is not None else None,
+                    _workflow_mode_id=execution.mode_id.bytes if execution.mode_id is not None else None,
                 )
             else:
+                span.set_attributes(
+                    {"git_commit_hash": execution.git_commit_hash, "repository_url": str(execution.repository_url)}
+                )
                 workflow_execution = WorkflowExecution(
-                    user_id=owner_id,
+                    _executor_id=executor_id.bytes,
                     workflow_version_id=None,
                     notes=notes,
                     slurm_job_id=-1,
@@ -59,26 +65,30 @@
         await db.commit()
         await db.refresh(workflow_execution)
         span.set_attribute("workflow_execution_id", str(workflow_execution.execution_id))
-        await db.execute(
-            update(WorkflowExecution)
-            .where(WorkflowExecution._execution_id == workflow_execution.execution_id.bytes)
-            .values(
-                logs_path=None
-                if execution.logs_s3_path is None
-                else execution.logs_s3_path + f"/run-{workflow_execution.execution_id.hex}",
-                debug_path=None
-                if execution.debug_s3_path is None
-                else execution.debug_s3_path + f"/run-{workflow_execution.execution_id.hex}",
-                provenance_path=None
-                if execution.provenance_s3_path is None
-                else execution.provenance_s3_path + f"/run-{workflow_execution.execution_id.hex}",
+        with tracer.start_as_current_span(
+            "db_create_workflow_execution_update_paths",
+            attributes={"execution_id": str(workflow_execution.execution_id)},
+        ):
+            await db.execute(
+                update(WorkflowExecution)
+                .where(WorkflowExecution._execution_id == workflow_execution.execution_id.bytes)
+                .values(
+                    logs_path=None
+                    if execution.logs_s3_path is None
+                    else execution.logs_s3_path + f"/run-{workflow_execution.execution_id.hex}",
+                    debug_path=None
+                    if execution.debug_s3_path is None
+                    else execution.debug_s3_path + f"/run-{workflow_execution.execution_id.hex}",
+                    provenance_path=None
+                    if execution.provenance_s3_path is None
+                    else execution.provenance_s3_path + f"/run-{workflow_execution.execution_id.hex}",
+                )
             )
-        )
-        await db.commit()
+            await db.commit()
         return workflow_execution

     @staticmethod
-    async def get(db: AsyncSession, execution_id: Union[bytes, UUID]) -> Optional[WorkflowExecution]:
+    async def get(db: AsyncSession, execution_id: UUID) -> Optional[WorkflowExecution]:
         """
         Get a workflow execution by its execution id from the database.

@@ -86,7 +96,7 @@
         ----------
         db : sqlalchemy.ext.asyncio.AsyncSession.
             Async database session to perform query on.
- execution_id : uuid.UUID | bytes + execution_id : uuid.UUID ID of the workflow execution Returns @@ -94,20 +104,20 @@ class CRUDWorkflowExecution: workflow_execution : clowmdb.models.WorkflowExecution The workflow execution with the given id if it exists, None otherwise """ - with tracer.start_as_current_span("db_get_workflow_execution") as span: - eid = execution_id.bytes if isinstance(execution_id, UUID) else execution_id - stmt = ( - select(WorkflowExecution) - .where(WorkflowExecution._execution_id == eid) - .options(joinedload(WorkflowExecution.workflow_version)) - ) - span.set_attributes({"workflow_execution_id": str(execution_id), "sql_query": str(stmt)}) + stmt = ( + select(WorkflowExecution) + .where(WorkflowExecution._execution_id == execution_id.bytes) + .options(joinedload(WorkflowExecution.workflow_version)) + ) + with tracer.start_as_current_span( + "db_get_workflow_execution", attributes={"workflow_execution_id": str(execution_id), "sql_query": str(stmt)} + ): return await db.scalar(stmt) @staticmethod async def list( db: AsyncSession, - uid: Optional[str] = None, + executor_id: Optional[UUID] = None, workflow_version_id: Optional[str] = None, status_list: Optional[List[WorkflowExecution.WorkflowExecutionStatus]] = None, ) -> Sequence[WorkflowExecution]: @@ -118,7 +128,7 @@ class CRUDWorkflowExecution: ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - uid : str | None, default None + executor_id : uuid.UUID | None, default None Filter for the user who started the workflow execution. workflow_version_id : str | None, default None Filter for the workflow version @@ -132,9 +142,9 @@ class CRUDWorkflowExecution: """ with tracer.start_as_current_span("db_list_workflow_executions") as span: stmt = select(WorkflowExecution).options(joinedload(WorkflowExecution.workflow_version)) - if uid is not None: - span.set_attribute("uid", uid) - stmt = stmt.where(WorkflowExecution.user_id == uid) + if executor_id is not None: + span.set_attribute("executor_id", str(executor_id)) + stmt = stmt.where(WorkflowExecution._executor_id == executor_id.bytes) if workflow_version_id is not None: span.set_attribute("git_commit_hash", workflow_version_id) stmt = stmt.where(WorkflowExecution.workflow_version_id == workflow_version_id) @@ -142,11 +152,10 @@ class CRUDWorkflowExecution: span.set_attribute("status", [stat.name for stat in status_list]) stmt = stmt.where(or_(*[WorkflowExecution.status == status for status in status_list])) span.set_attribute("sql_query", str(stmt)) - executions = (await db.scalars(stmt)).all() - return executions + return (await db.scalars(stmt)).all() @staticmethod - async def delete(db: AsyncSession, execution_id: Union[bytes, UUID]) -> None: + async def delete(db: AsyncSession, execution_id: UUID) -> None: """ Delete a workflow execution from the database. @@ -154,20 +163,21 @@ class CRUDWorkflowExecution: ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. 
-        execution_id : uuid.UUID | bytes
+        execution_id : uuid.UUID
             ID of the workflow execution
         """
-        with tracer.start_as_current_span("db_delete_workflow_execution") as span:
-            eid = execution_id.bytes if isinstance(execution_id, UUID) else execution_id
-            stmt = delete(WorkflowExecution).where(WorkflowExecution._execution_id == eid)
-            span.set_attributes({"workflow_execution_id": str(execution_id), "sql_query": str(stmt)})
+        stmt = delete(WorkflowExecution).where(WorkflowExecution._execution_id == execution_id.bytes)
+        with tracer.start_as_current_span(
+            "db_delete_workflow_execution",
+            attributes={"workflow_execution_id": str(execution_id), "sql_query": str(stmt)},
+        ):
             await db.execute(stmt)
             await db.commit()

     @staticmethod
-    async def cancel(
+    async def set_error(
         db: AsyncSession,
-        execution_id: Union[bytes, UUID],
+        execution_id: UUID,
         status: WorkflowExecution.WorkflowExecutionStatus = WorkflowExecution.WorkflowExecutionStatus.CANCELED,
     ) -> None:
         """
@@ -177,26 +187,25 @@
         ----------
         db : sqlalchemy.ext.asyncio.AsyncSession.
             Async database session to perform query on.
-        execution_id : uuid.UUID | bytes
+        execution_id : uuid.UUID
             ID of the workflow execution
         status : clowmdb.models.WorkflowExecution.WorkflowExecutionStatus, default WorkflowExecutionStatus.CANCELED
             Error status the workflow execution should get
         """
-        with tracer.start_as_current_span("db_cancel_workflow_execution") as span:
-            eid = execution_id.bytes if isinstance(execution_id, UUID) else execution_id
-            stmt = (
-                update(WorkflowExecution)
-                .where(WorkflowExecution._execution_id == eid)
-                .values(status=status.name, end_time=func.UNIX_TIMESTAMP())
-            )
-            span.set_attributes(
-                {"workflow_execution_id": str(execution_id), "status": status.name, "sql_query": str(stmt)}
-            )
+        stmt = (
+            update(WorkflowExecution)
+            .where(WorkflowExecution._execution_id == execution_id.bytes)
+            .values(status=status.name, end_time=func.UNIX_TIMESTAMP())
+        )
+        with tracer.start_as_current_span(
+            "db_cancel_workflow_execution",
+            attributes={"workflow_execution_id": str(execution_id), "status": status.name, "sql_query": str(stmt)},
+        ):
             await db.execute(stmt)
             await db.commit()

     @staticmethod
-    async def update_slurm_job_id(db: AsyncSession, execution_id: Union[bytes, UUID], slurm_job_id: int) -> None:
+    async def update_slurm_job_id(db: AsyncSession, execution_id: UUID, slurm_job_id: int) -> None:
         """
         Update the slurm job ID of a workflow execution in the database.

@@ -204,20 +213,23 @@
         ----------
         db : sqlalchemy.ext.asyncio.AsyncSession.
             Async database session to perform query on.
- execution_id : uuid.UUID | bytes + execution_id : uuid.UUID ID of the workflow execution slurm_job_id : int New slurm job ID """ - with tracer.start_as_current_span("db_update_workflow_execution_slurm_id") as span: - eid = execution_id.bytes if isinstance(execution_id, UUID) else execution_id - stmt = ( - update(WorkflowExecution) - .where(WorkflowExecution._execution_id == eid) - .values(slurm_job_id=slurm_job_id) - ) - span.set_attributes( - {"workflow_execution_id": str(execution_id), "slurm_job_id": slurm_job_id, "sql_query": str(stmt)} - ) + stmt = ( + update(WorkflowExecution) + .where(WorkflowExecution._execution_id == execution_id.bytes) + .values(slurm_job_id=slurm_job_id) + ) + with tracer.start_as_current_span( + "db_update_workflow_execution_slurm_id", + attributes={ + "workflow_execution_id": str(execution_id), + "slurm_job_id": slurm_job_id, + "sql_query": str(stmt), + }, + ): await db.execute(stmt) await db.commit() diff --git a/app/crud/crud_workflow_mode.py b/app/crud/crud_workflow_mode.py index 91e14b5889094a05282467410c22bfddc548775f..1d856046cec434f13b54d02759abae5cd04fde82 100644 --- a/app/crud/crud_workflow_mode.py +++ b/app/crud/crud_workflow_mode.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Optional, Union +from typing import Iterable, List, Optional from uuid import UUID from clowmdb.models import WorkflowMode, workflow_mode_association_table @@ -15,7 +15,7 @@ class CRUDWorkflowMode: @staticmethod async def list_modes( db: AsyncSession, - workflow_version: str, + workflow_version_id: str, ) -> List[WorkflowMode]: """ List all workflow modes of a specific workflow version. @@ -24,7 +24,7 @@ class CRUDWorkflowMode: ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - workflow_version : str + workflow_version_id : str The version id for which the modes should be loaded. Returns @@ -32,19 +32,18 @@ class CRUDWorkflowMode: modes : List[clowmdb.models.WorkflowMode] List of workflow modes. """ - with tracer.start_as_current_span("db_list_workflow_modes") as span: - stmt = ( - select(WorkflowMode) - .join(workflow_mode_association_table) - .where(workflow_mode_association_table.columns.workflow_version_commit_hash == workflow_version) - ) - span.set_attributes({"git_commit_hash": workflow_version, "sql_query": str(stmt)}) + stmt = ( + select(WorkflowMode) + .join(workflow_mode_association_table) + .where(workflow_mode_association_table.columns.workflow_version_commit_hash == workflow_version_id) + ) + with tracer.start_as_current_span( + "db_list_workflow_modes", attributes={"workflow_version_id": workflow_version_id, "sql_query": str(stmt)} + ): return list((await db.scalars(stmt)).all()) @staticmethod - async def get( - db: AsyncSession, mode_id: Union[bytes, UUID], workflow_version: Optional[str] = None - ) -> Optional[WorkflowMode]: + async def get(db: AsyncSession, mode_id: UUID, workflow_version_id: Optional[str] = None) -> Optional[WorkflowMode]: """ Get a specific workflow mode. @@ -52,9 +51,9 @@ class CRUDWorkflowMode: ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - mode_id : UUID | bytes + mode_id : UUID ID of a workflow mode. - workflow_version : str | None, default None + workflow_version_id : str | None, default None Optional workflow version the workflow mode has to be connected to. 
Returns @@ -62,14 +61,14 @@ workflow_mode : clowmdb.models.WorkflowMode | None Requested workflow mode if it exists, None otherwise """ - with tracer.start_as_current_span("db_get_workflow_mode") as span: - mid = mode_id.bytes if isinstance(mode_id, UUID) else mode_id - span.set_attribute("workflow_mode_id", str(mode_id)) - stmt = select(WorkflowMode).where(WorkflowMode._mode_id == mid) - if workflow_version is not None: - span.set_attribute("git_commit_hash", workflow_version) + with tracer.start_as_current_span( + "db_get_workflow_mode", attributes={"workflow_mode_id": str(mode_id)} + ) as span: + stmt = select(WorkflowMode).where(WorkflowMode._mode_id == mode_id.bytes) + if workflow_version_id is not None: + span.set_attribute("workflow_version_id", workflow_version_id) stmt = stmt.join(workflow_mode_association_table).where( - workflow_mode_association_table.columns.workflow_version_commit_hash == workflow_version + workflow_mode_association_table.columns.workflow_version_commit_hash == workflow_version_id ) span.set_attribute("sql_query", str(stmt)) return await db.scalar(stmt) @@ -113,8 +112,9 @@ modes : List[uuid.UUID] IDs of the workflow modes to delete """ - with tracer.start_as_current_span("db_delete_workflow_mode") as span: - stmt = delete(WorkflowMode).where(WorkflowMode._mode_id.in_([uuid.bytes for uuid in modes])) - span.set_attributes({"workflow_mode_ids": [str(m) for m in modes], "sql_query": str(stmt)}) + stmt = delete(WorkflowMode).where(WorkflowMode._mode_id.in_([uuid.bytes for uuid in modes])) + with tracer.start_as_current_span( + "db_delete_workflow_mode", attributes={"workflow_mode_ids": [str(m) for m in modes], "sql_query": str(stmt)} + ): await db.execute(stmt) await db.commit() diff --git a/app/crud/crud_workflow_version.py b/app/crud/crud_workflow_version.py index 83c018c7768ef3281a9aee8ed1c9a75af4ade69c..ee713b552332f0dbc1de66ed4c60a15cdc355527 100644 --- a/app/crud/crud_workflow_version.py +++ b/app/crud/crud_workflow_version.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Sequence, Union +from typing import List, Optional, Sequence from uuid import UUID from clowmdb.models import WorkflowVersion, workflow_mode_association_table @@ -14,8 +14,8 @@ class CRUDWorkflowVersion: @staticmethod async def get( db: AsyncSession, - git_commit_hash: str, - workflow_id: Optional[Union[bytes, UUID]] = None, + workflow_version_id: str, + workflow_id: Optional[UUID] = None, populate_workflow: bool = False, ) -> Optional[WorkflowVersion]: """ @@ -25,9 +25,9 @@ ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - git_commit_hash : str + workflow_version_id : str Git commit hash of the version. - workflow_id : UUID | bytes | None, default None + workflow_id : UUID | None, default None Specify the workflow the version has to belong to. populate_workflow : bool, default False Flag whether to populate the workflow attribute.
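The CRUD hunks above and below all apply one refactoring pattern: the SQLAlchemy statement is now built before the tracing span is opened, so every span attribute, including the rendered SQL, can be passed to start_as_current_span at creation time instead of being attached afterwards through span.set_attributes. A condensed sketch of the pattern, not a verbatim excerpt, assuming the imports already present in these modules (sqlalchemy.delete and an OpenTelemetry tracer):

    # Build the statement first so str(stmt) is available as a span attribute.
    stmt = delete(WorkflowMode).where(WorkflowMode._mode_id.in_([m.bytes for m in modes]))
    with tracer.start_as_current_span(
        "db_delete_workflow_mode",
        attributes={"workflow_mode_ids": [str(m) for m in modes], "sql_query": str(stmt)},
    ):
        await db.execute(stmt)
        await db.commit()

Only the methods that attach attributes conditionally (an optional workflow_id filter, for example) keep the `as span` handle and still call span.set_attribute inside the block.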
@@ -37,28 +37,25 @@ workflow_version : clowmdb.models.WorkflowVersion | None The workflow version with the given git_commit_hash if it exists, None otherwise """ - with tracer.start_as_current_span("db_get_workflow_version") as span: - span.set_attribute("git_commit_hash", git_commit_hash) - + with tracer.start_as_current_span( + "db_get_workflow_version", + attributes={"workflow_version_id": workflow_version_id, "populate_workflow": populate_workflow}, + ) as span: stmt = ( select(WorkflowVersion) - .where(WorkflowVersion.git_commit_hash == git_commit_hash) + .where(WorkflowVersion.git_commit_hash == workflow_version_id) .options(selectinload(WorkflowVersion.workflow_modes)) ) if populate_workflow: - span.set_attribute("populate_workflow", True) stmt = stmt.options(joinedload(WorkflowVersion.workflow)) if workflow_id is not None: span.set_attribute("workflow_id", str(workflow_id)) - wid = workflow_id if isinstance(workflow_id, bytes) else workflow_id.bytes - stmt = stmt.where(WorkflowVersion._workflow_id == wid) + stmt = stmt.where(WorkflowVersion._workflow_id == workflow_id.bytes) span.set_attribute("sql_query", str(stmt)) return await db.scalar(stmt) @staticmethod - async def get_latest( - db: AsyncSession, wid: Union[bytes, UUID], published: bool = True - ) -> Optional[WorkflowVersion]: + async def get_latest(db: AsyncSession, workflow_id: UUID, published: bool = True) -> Optional[WorkflowVersion]: """ Get the latest version of a workflow. @@ -66,7 +63,7 @@ ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - wid : bytes | uuid.UUID + workflow_id : uuid.UUID ID of a workflow published : bool, default = True Get the latest version that is published, otherwise get the latest version overall @@ -76,21 +73,17 @@ workflow_version : clowmdb.models.WorkflowVersion | None The latest workflow version of the given workflow if the workflow exists, None otherwise """ - with tracer.start_as_current_span("db_get_latest_workflow_version") as span: - span.set_attribute("workflow_id", str(wid)) + with tracer.start_as_current_span( + "db_get_latest_workflow_version", attributes={"workflow_id": str(workflow_id), "published": published} + ) as span: stmt = ( select(WorkflowVersion) - .where( - WorkflowVersion._workflow_id == wid.bytes - if isinstance(wid, UUID) - else wid # type: ignore[arg-type] - ) + .where(WorkflowVersion._workflow_id == workflow_id.bytes) .order_by(desc(WorkflowVersion.created_at)) .limit(1) .options(selectinload(WorkflowVersion.workflow_modes)) ) if published: - span.set_attribute("only_published", True) stmt = stmt.where( or_( *[ @@ -104,7 +97,7 @@ @staticmethod async def list( - db: AsyncSession, wid: Union[bytes, UUID], version_status: Optional[List[WorkflowVersion.Status]] = None + db: AsyncSession, workflow_id: UUID, version_status: Optional[List[WorkflowVersion.Status]] = None ) -> Sequence[WorkflowVersion]: """ List all versions of a workflow. @@ -113,8 +106,8 @@ ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - wid : bytes | uuid.UUID - Git commit git_commit_hash of the version. + workflow_id : uuid.UUID + ID of a workflow.
version_status : List[clowmdb.models.WorkflowVersion.Status] | None, default None Filter versions based on the status @@ -123,16 +116,13 @@ versions : List[clowmdb.models.WorkflowVersion] All workflow versions of the given workflow """ - with tracer.start_as_current_span("db_list_workflow_versions") as span: - span.set_attribute("workflow_id", str(wid)) + with tracer.start_as_current_span( + "db_list_workflow_versions", attributes={"workflow_id": str(workflow_id)} + ) as span: stmt = ( select(WorkflowVersion) .options(selectinload(WorkflowVersion.workflow_modes)) - .where( - WorkflowVersion._workflow_id == wid.bytes - if isinstance(wid, UUID) - else wid # type: ignore[arg-type] - ) + .where(WorkflowVersion._workflow_id == workflow_id.bytes) ) if version_status is not None and len(version_status) > 0: span.set_attribute("version_status", [stat.name for stat in version_status]) @@ -146,7 +136,7 @@ db: AsyncSession, git_commit_hash: str, version: str, - wid: Union[bytes, UUID], + workflow_id: UUID, icon_slug: Optional[str] = None, previous_version: Optional[str] = None, modes: Optional[List[UUID]] = None, @@ -162,7 +152,7 @@ Git commit hash of the version. version : str New version in string format - wid : bytes | uuid.UUID + workflow_id : uuid.UUID ID of the corresponding workflow icon_slug : str | None, default None Slug of the icon @@ -176,33 +166,41 @@ workflow_version : clowmdb.models.WorkflowVersion Newly created WorkflowVersion """ - with tracer.start_as_current_span("db_create_workflow_version") as span: - span.set_attributes({"git_commit_version": git_commit_hash, "workflow_id": str(wid)}) + with tracer.start_as_current_span( + "db_create_workflow_version", + attributes={"git_commit_version": git_commit_hash, "workflow_id": str(workflow_id)}, + ) as span: + if previous_version is not None: # pragma: no cover + span.set_attribute("previous_version", previous_version) if modes is None: modes = [] workflow_version = WorkflowVersion( git_commit_hash=git_commit_hash, version=version, - _workflow_id=wid.bytes if isinstance(wid, UUID) else wid, + _workflow_id=workflow_id.bytes, icon_slug=icon_slug, previous_version_hash=previous_version, ) db.add(workflow_version) if len(modes) > 0: span.set_attribute("mode_ids", [str(m) for m in modes]) - await db.commit() - await db.execute( - insert(workflow_mode_association_table), - [ - {"workflow_version_commit_hash": git_commit_hash, "workflow_mode_id": mode_id.bytes} - for mode_id in modes - ], - ) + with tracer.start_as_current_span( + "db_create_workflow_version_connect_modes", + attributes={"workflow_version_id": git_commit_hash, "mode_ids": [str(m) for m in modes]}, + ): + await db.commit() + await db.execute( + insert(workflow_mode_association_table), + [ + {"workflow_version_commit_hash": git_commit_hash, "workflow_mode_id": mode_id.bytes} + for mode_id in modes + ], + ) await db.commit() return workflow_version @staticmethod - async def update_status(db: AsyncSession, git_commit_hash: str, status: WorkflowVersion.Status) -> None: + async def update_status(db: AsyncSession, workflow_version_id: str, status: WorkflowVersion.Status) -> None: """ Update the status of a workflow version. @@ -210,24 +208,25 @@ ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - git_commit_hash : str + workflow_version_id : str Git commit hash of the version.
status : clowmdb.models.WorkflowVersion.Status New status of the workflow version """ - with tracer.start_as_current_span("db_update_workflow_version_status") as span: - span.set_attributes({"status": status.name, "git_commit_version": git_commit_hash}) - stmt = ( - update(WorkflowVersion) - .where(WorkflowVersion.git_commit_hash == git_commit_hash) - .values(status=status.name) - ) - span.set_attribute("sql_query", str(stmt)) + stmt = ( + update(WorkflowVersion) + .where(WorkflowVersion.git_commit_hash == workflow_version_id) + .values(status=status.name) + ) + with tracer.start_as_current_span( + "db_update_workflow_version_status", + attributes={"status": status.name, "workflow_version_id": workflow_version_id, "sql_query": str(stmt)}, + ): await db.execute(stmt) await db.commit() @staticmethod - async def update_icon(db: AsyncSession, git_commit_hash: str, icon_slug: Optional[str] = None) -> None: + async def update_icon(db: AsyncSession, workflow_version_id: str, icon_slug: Optional[str] = None) -> None: """ Update the icon slug for a workflow version. @@ -235,24 +234,24 @@ ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. - git_commit_hash : str + workflow_version_id : str Git commit hash of the version. icon_slug : str | None, default None The new icon slug """ - with tracer.start_as_current_span("db_update_workflow_version_icon") as span: - stmt = ( - update(WorkflowVersion) - .where(WorkflowVersion.git_commit_hash == git_commit_hash) - .values(icon_slug=icon_slug) - ) - span.set_attributes( - { - "git_commit_hash": git_commit_hash, - "icon_slug": icon_slug if icon_slug else "None", - "sql_query": str(stmt), - } - ) + stmt = ( + update(WorkflowVersion) + .where(WorkflowVersion.git_commit_hash == workflow_version_id) + .values(icon_slug=icon_slug) + ) + with tracer.start_as_current_span( + "db_update_workflow_version_icon", + attributes={ + "git_commit_hash": workflow_version_id, + "icon_slug": icon_slug if icon_slug else "None", + "sql_query": str(stmt), + }, + ): await db.execute(stmt) await db.commit() @@ -273,8 +272,9 @@ exists : bool Flag if a version exists that depends on the icon """ - with tracer.start_as_current_span("db_check_workflow_version_icon_exists") as span: - stmt = select(WorkflowVersion).where(WorkflowVersion.icon_slug == icon_slug).limit(1) - span.set_attributes({"icon_slug": icon_slug, "sql_query": str(stmt)}) + stmt = select(WorkflowVersion).where(WorkflowVersion.icon_slug == icon_slug).limit(1) + with tracer.start_as_current_span( + "db_check_workflow_version_icon_exists", attributes={"icon_slug": icon_slug, "sql_query": str(stmt)} + ): version_with_icon = await db.scalar(stmt) return version_with_icon is not None diff --git a/mako_templates/nextflow_command.tmpl b/app/mako_templates/nextflow_command.tmpl similarity index 100% rename from mako_templates/nextflow_command.tmpl rename to app/mako_templates/nextflow_command.tmpl diff --git a/app/schemas/security.py b/app/schemas/security.py index c66952a0eb0806436a134c0c7af96f75a478364c..1c2500eee500bc9a78e41ae55181d24fda702657 100644 --- a/app/schemas/security.py +++ b/app/schemas/security.py @@ -1,7 +1,7 @@ from datetime import datetime from uuid import UUID -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, FieldSerializationInfo, field_serializer class AuthzResponse(BaseModel): @@ -14,6 +14,10 @@ ) result: bool = Field(...,
description="Result of the Authz request") + @field_serializer("decision_id") + def serialize_decision_id(self, decision_id: UUID, _info: FieldSerializationInfo) -> str: + return str(decision_id) + class AuthzRequest(BaseModel): """Schema for a Request to OPA""" diff --git a/app/schemas/workflow.py b/app/schemas/workflow.py index 9ce0dfc063309e4a94fa9e395bd4cf080344e864..35ceb54d8ba640e5489a2d4e8558d5af18a0aba8 100644 --- a/app/schemas/workflow.py +++ b/app/schemas/workflow.py @@ -32,7 +32,7 @@ class _BaseWorkflow(BaseModel): ) @field_serializer("repository_url") - def serialize_dt(self, url: AnyHttpUrl, _info: FieldSerializationInfo) -> str: + def serialize_url(self, url: AnyHttpUrl, _info: FieldSerializationInfo) -> str: return str(url) @@ -66,11 +66,19 @@ class WorkflowIn(_BaseWorkflow): class WorkflowOut(_BaseWorkflow): workflow_id: UUID = Field(..., description="ID of the workflow", examples=["20128c04-e834-40a8-9878-68939ae46423"]) versions: List[WorkflowVersion] = Field(..., description="Versions of the workflow") - developer_id: str = Field( - ..., description="ID of developer of the workflow", examples=["28c5353b8bb34984a8bd4169ba94c606"] + developer_id: Optional[UUID] = Field( + None, description="ID of developer of the workflow", examples=["1d3387f3-95c0-4813-8767-2cad87faeebf"] ) private: bool = Field(default=False, description="Flag if the workflow is hosted in a private git repository") + @field_serializer("workflow_id") + def serialize_workflow_id(self, workflow_id: UUID, _info: FieldSerializationInfo) -> str: + return str(workflow_id) + + @field_serializer("developer_id", when_used="unless-none") + def serialize_developer_id(self, developer_id: UUID, _info: FieldSerializationInfo) -> str: + return str(developer_id) + @staticmethod def from_db_workflow( db_workflow: WorkflowDB, versions: Optional[Sequence[WorkflowVersionDB]] = None, load_modes: bool = True @@ -150,3 +158,7 @@ class WorkflowUpdate(BaseModel): delete_modes: List[UUID] = Field( [], description="Delete modes for the new workflow version.", examples=["2a23a083-b6b9-4681-9ec4-ff4ffbe85d3c"] ) + + @field_serializer("delete_modes") + def serialize_modes(self, delete_modes: List[UUID], _info: FieldSerializationInfo) -> List[str]: + return [str(m) for m in delete_modes] diff --git a/app/schemas/workflow_execution.py b/app/schemas/workflow_execution.py index 350bdc37029b3b273a75c750873ea1dae2f0c3a8..62b655af230b39b8089e60bca6b111fda6a5d231 100644 --- a/app/schemas/workflow_execution.py +++ b/app/schemas/workflow_execution.py @@ -23,12 +23,16 @@ class _BaseWorkflowExecution(BaseModel): max_length=2**16, examples=["Some workflow execution specific notes"], ) - mode: Optional[UUID] = Field( + mode_id: Optional[UUID] = Field( default=None, description="ID of the workflow mode this workflow execution runs in", examples=["2a23a083-b6b9-4681-9ec4-ff4ffbe85d3c"], ) + @field_serializer("mode_id", when_used="unless-none") + def serialize_mode_id(self, mode_id: UUID, _info: FieldSerializationInfo) -> str: + return str(mode_id) + class _WorkflowExecutionInParameters(BaseModel): parameters: Dict[str, Any] = Field(..., description="Parameters for this workflow") @@ -66,8 +70,8 @@ class WorkflowExecutionOut(_BaseWorkflowExecution): execution_id: UUID = Field( ..., description="ID of the workflow execution", examples=["591b6a6e-a1f0-420d-8a20-a7a60704f695"] ) - user_id: str = Field( - ..., description="UID of user who started the workflow", examples=["28c5353b8bb34984a8bd4169ba94c606"] + executor_id: UUID = Field( + ..., 
description="UID of user who started the workflow", examples=["1d3387f3-95c0-4813-8767-2cad87faeebf"] ) start_time: int = Field( ..., @@ -106,20 +110,32 @@ class WorkflowExecutionOut(_BaseWorkflowExecution): examples=["s3://example-bucket/debug/run-591b6a6ea1f0420d8a20a7a60704f695"], ) + @field_serializer("execution_id") + def serialize_execution_id(self, execution_id: UUID, _info: FieldSerializationInfo) -> str: + return str(execution_id) + + @field_serializer("workflow_id", when_used="unless-none") + def serialize_workflow_id(self, workflow_id: UUID, _info: FieldSerializationInfo) -> str: + return str(workflow_id) + + @field_serializer("executor_id") + def serialize_executor_id(self, executor_id: UUID, _info: FieldSerializationInfo) -> str: + return str(executor_id) + @staticmethod def from_db_model( workflow_execution: WorkflowExecution, workflow_id: Optional[UUID] = None ) -> "WorkflowExecutionOut": return WorkflowExecutionOut( execution_id=workflow_execution.execution_id, - user_id=workflow_execution.user_id, + executor_id=workflow_execution.executor_id, start_time=workflow_execution.start_time, end_time=workflow_execution.end_time, status=workflow_execution.status, workflow_version_id=workflow_execution.workflow_version_id, workflow_id=workflow_id, notes=workflow_execution.notes, - mode=workflow_execution.workflow_mode_id, + mode_id=workflow_execution.workflow_mode_id, debug_s3_path=workflow_execution.debug_path, logs_s3_path=workflow_execution.logs_path, provenance_s3_path=workflow_execution.provenance_path, @@ -151,7 +167,7 @@ class DevWorkflowExecutionIn(_WorkflowExecutionInParameters): ) @field_serializer("repository_url") - def serialize_dt(self, url: AnyHttpUrl, _info: FieldSerializationInfo) -> str: + def serialize_url(self, url: AnyHttpUrl, _info: FieldSerializationInfo) -> str: return str(url) @@ -169,14 +185,14 @@ class AnonymizedWorkflowExecution(BaseModel): description="ID of the workflow mode this workflow execution ran in", examples=["2a23a083-b6b9-4681-9ec4-ff4ffbe85d3c"], ) - git_commit_hash: str = Field( + workflow_version_id: str = Field( ..., description="Hash of the git commit", examples=["ba8bcd9294c2c96aedefa1763a84a18077c50c0f"] ) started_at: date = Field( ..., description="Day of the workflow execution", examples=[date(day=1, month=1, year=2023)] ) workflow_id: UUID = Field(..., description="ID of the workflow", examples=["20128c04-e834-40a8-9878-68939ae46423"]) - developer_id: str = Field( + developer_id: UUID = Field( ..., description="ID of developer of the workflow", examples=["28c5353b8bb34984a8bd4169ba94c606"] ) status: WorkflowExecution.WorkflowExecutionStatus = Field( @@ -184,3 +200,19 @@ class AnonymizedWorkflowExecution(BaseModel): description="End status of the workflow execution", examples=[WorkflowExecution.WorkflowExecutionStatus.SUCCESS], ) + + @field_serializer("workflow_execution_id") + def serialize_workflow_execution_id(self, workflow_execution_id: UUID, _info: FieldSerializationInfo) -> str: + return str(workflow_execution_id) + + @field_serializer("workflow_id", when_used="unless-none") + def serialize_workflow_id(self, workflow_id: UUID, _info: FieldSerializationInfo) -> str: + return str(workflow_id) + + @field_serializer("developer_id") + def serialize_developer_id(self, developer_id: UUID, _info: FieldSerializationInfo) -> str: + return str(developer_id) + + @field_serializer("workflow_mode_id", when_used="unless-none") + def serialize_mode_id(self, mode_id: Optional[UUID], _info: FieldSerializationInfo) -> str: + return str(mode_id) 
diff --git a/app/schemas/workflow_mode.py b/app/schemas/workflow_mode.py index 6cd395eba1dab93fa266dc9755cbd09307ff52e1..2762ee85efba78abd865b649bb94bc684f918e19 100644 --- a/app/schemas/workflow_mode.py +++ b/app/schemas/workflow_mode.py @@ -1,6 +1,6 @@ from uuid import UUID -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, FieldSerializationInfo, field_serializer class _BaseWorkflowMode(BaseModel): @@ -22,3 +22,7 @@ class WorkflowModeIn(_BaseWorkflowMode): class WorkflowModeOut(_BaseWorkflowMode): mode_id: UUID = Field(..., description="ID of the workflow mode", examples=["2a23a083-b6b9-4681-9ec4-ff4ffbe85d3c"]) + + @field_serializer("mode_id") + def serialize_mode_id(self, mode_id: UUID, _info: FieldSerializationInfo) -> str: + return str(mode_id) diff --git a/app/schemas/workflow_version.py b/app/schemas/workflow_version.py index 75b0b728453e0cd5696ba808215d7fd9d054dbf7..f3e1cf102d2f097be30431632603d453e37c232a 100644 --- a/app/schemas/workflow_version.py +++ b/app/schemas/workflow_version.py @@ -3,7 +3,7 @@ from typing import List, Optional from uuid import UUID from clowmdb.models import WorkflowVersion as WorkflowVersionDB -from pydantic import AnyHttpUrl, BaseModel, Field +from pydantic import AnyHttpUrl, BaseModel, Field, FieldSerializationInfo, field_serializer from app.core.config import settings @@ -25,7 +25,7 @@ class WorkflowVersion(WorkflowVersionStatus): min_length=5, max_length=10, ) - git_commit_hash: str = Field( + workflow_version_id: str = Field( ..., description="Hash of the git commit", examples=["ba8bcd9294c2c96aedefa1763a84a18077c50c0f"], @@ -49,6 +49,10 @@ class WorkflowVersion(WorkflowVersionStatus): examples=["2a23a083-b6b9-4681-9ec4-ff4ffbe85d3c"], ) + @field_serializer("workflow_id") + def serialize_workflow_id(self, workflow_id: UUID, _info: FieldSerializationInfo) -> str: + return str(workflow_id) + @staticmethod def from_db_version( db_version: WorkflowVersionDB, mode_ids: Optional[List[UUID]] = None, load_modes: bool = False @@ -83,7 +87,7 @@ class WorkflowVersion(WorkflowVersionStatus): return WorkflowVersion( workflow_id=db_version.workflow_id, version=db_version.version, - git_commit_hash=db_version.git_commit_hash, + workflow_version_id=db_version.git_commit_hash, icon_url=icon_url, created_at=db_version.created_at, status=db_version.status, diff --git a/app/tests/api/test_security.py b/app/tests/api/test_security.py index 3f6df23dc5f4bce2c76b44e6321e5b247987cc87..b13a5780ff2ba25687adee1e254edc56e2c0108e 100644 --- a/app/tests/api/test_security.py +++ b/app/tests/api/test_security.py @@ -3,6 +3,8 @@ from fastapi import status from httpx import AsyncClient from sqlalchemy.ext.asyncio import AsyncSession +from app.tests.mocks.mock_opa_service import MockOpaService +from app.tests.utils.cleanup import CleanupList from app.tests.utils.user import UserWithAuthHeader @@ -54,7 +56,7 @@ class TestJWTProtectedRoutes: Random user for testing. 
""" response = await client.get( - self.protected_route, params={"user": random_user.user.uid}, headers=random_user.auth_headers + self.protected_route, params={"developer_id": str(random_user.user.uid)}, headers=random_user.auth_headers ) assert response.status_code == status.HTTP_200_OK @@ -81,13 +83,17 @@ class TestJWTProtectedRoutes: await db.commit() response = await client.get( - self.protected_route, params={"user": random_user.user.uid}, headers=random_user.auth_headers + self.protected_route, params={"developer_id": str(random_user.user.uid)}, headers=random_user.auth_headers ) assert response.status_code == status.HTTP_404_NOT_FOUND @pytest.mark.asyncio async def test_routed_with_insufficient_permissions( - self, client: AsyncClient, random_user: UserWithAuthHeader + self, + client: AsyncClient, + random_user: UserWithAuthHeader, + mock_opa_service: MockOpaService, + cleanup: CleanupList, ) -> None: """ Test with correct authorization header but with insufficient permissions. @@ -99,9 +105,15 @@ class TestJWTProtectedRoutes: random_user : app.tests.utils.user.UserWithAuthHeader Random user for testing. """ + mock_opa_service.send_error = True + + def repair_opa_service() -> None: + mock_opa_service.send_error = False + + cleanup.add_task(repair_opa_service) + response = await client.get( self.protected_route, - params={"raise_opa_error": True}, headers=random_user.auth_headers, ) assert response.status_code == status.HTTP_403_FORBIDDEN diff --git a/app/tests/api/test_workflow.py b/app/tests/api/test_workflow.py index b9ab802164cde6d9e1386a5d82cb7d96213a52ce..2a5794630e4cf3b6c4401c7aa3f6639d37cc0240 100644 --- a/app/tests/api/test_workflow.py +++ b/app/tests/api/test_workflow.py @@ -1,19 +1,23 @@ from datetime import date from io import BytesIO -from uuid import UUID, uuid4 +from uuid import uuid4 import pytest from clowmdb.models import Workflow, WorkflowExecution, WorkflowMode, WorkflowVersion from fastapi import status from httpx import AsyncClient -from sqlalchemy import delete, select +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings -from app.schemas.workflow import WorkflowIn, WorkflowOut, WorkflowUpdate +from app.schemas.workflow import WorkflowIn, WorkflowOut, WorkflowStatistic, WorkflowUpdate +from app.schemas.workflow_execution import AnonymizedWorkflowExecution from app.schemas.workflow_mode import WorkflowModeIn +from app.schemas.workflow_version import WorkflowVersion as WorkflowVersionOut from app.scm import SCM, SCMProvider +from app.tests.mocks import DefaultMockHTTPService from app.tests.mocks.mock_s3_resource import MockS3ServiceResource +from app.tests.utils.cleanup import CleanupList, delete_workflow, delete_workflow_mode from app.tests.utils.user import UserWithAuthHeader from app.tests.utils.utils import random_hex_string, random_lower_string @@ -25,10 +29,7 @@ class _TestWorkflowRoutes: class TestWorkflowRoutesCreate(_TestWorkflowRoutes): @pytest.mark.asyncio async def test_create_workflow_with_github( - self, - db: AsyncSession, - client: AsyncClient, - random_user: UserWithAuthHeader, + self, db: AsyncSession, client: AsyncClient, random_user: UserWithAuthHeader, cleanup: CleanupList ) -> None: """ Exhaustive Test for successfully creating a workflow. @@ -41,6 +42,8 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): HTTP Client to perform the request on. random_user : app.tests.utils.user.UserWithAuthHeader Random user for testing. 
+ cleanup : app.tests.utils.cleanup.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test. """ git_commit_hash = random_hex_string() workflow = WorkflowIn( @@ -51,10 +54,11 @@ ).model_dump() response = await client.post(self.base_path, json=workflow, headers=random_user.auth_headers) assert response.status_code == status.HTTP_201_CREATED - created_workflow = response.json() - assert not created_workflow["private"] + created_workflow = WorkflowOut.model_validate(response.json()) + cleanup.add_task(delete_workflow, db=db, workflow_id=created_workflow.workflow_id) + assert not created_workflow.private - stmt = select(Workflow).where(Workflow._workflow_id == UUID(hex=created_workflow["workflow_id"]).bytes) + stmt = select(Workflow).where(Workflow._workflow_id == created_workflow.workflow_id.bytes) db_workflow = await db.scalar(stmt) assert db_workflow is not None @@ -62,11 +66,6 @@ db_version = await db.scalar(stmt) assert db_version is not None - await db.execute( - delete(Workflow).where(Workflow._workflow_id == UUID(hex=created_workflow["workflow_id"]).bytes) - ) - await db.commit() - @pytest.mark.asyncio async def test_create_workflow_with_gitlab( self, client: AsyncClient, db: AsyncSession, random_user: UserWithAuthHeader, mock_s3_service: MockS3ServiceResource, + cleanup: CleanupList, ) -> None: """ Test for successfully creating a workflow with a Gitlab repository. @@ -88,6 +88,8 @@ Random user for testing. mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource Mock S3 Service to manipulate objects. + cleanup : app.tests.utils.cleanup.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test.
""" workflow = WorkflowIn( git_commit_hash=random_hex_string(), @@ -97,10 +99,11 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): ).model_dump() response = await client.post(self.base_path, json=workflow, headers=random_user.auth_headers) assert response.status_code == status.HTTP_201_CREATED - created_workflow = response.json() - assert not created_workflow["private"] + created_workflow = WorkflowOut.model_validate(response.json()) + cleanup.add_task(delete_workflow, db=db, workflow_id=created_workflow.workflow_id) + assert not created_workflow.private - stmt = select(Workflow).where(Workflow._workflow_id == UUID(hex=created_workflow["workflow_id"]).bytes) + stmt = select(Workflow).where(Workflow._workflow_id == created_workflow.workflow_id.bytes) db_workflow = await db.scalar(stmt) assert db_workflow is not None @@ -108,6 +111,7 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): scm_file_name = SCM.generate_filename(db_workflow.workflow_id) assert scm_file_name in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys() obj = mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(scm_file_name) + cleanup.add_task(obj.delete) with BytesIO() as f: obj.download_fileobj(f) f.seek(0) @@ -121,12 +125,6 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): assert provider.platform == "gitlab" assert provider.server == "https://gitlab.de" - await db.execute( - delete(Workflow).where(Workflow._workflow_id == UUID(hex=created_workflow["workflow_id"]).bytes) - ) - await db.commit() - obj.delete() - @pytest.mark.asyncio async def test_create_workflow_with_private_gitlab( self, @@ -134,6 +132,7 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): db: AsyncSession, random_user: UserWithAuthHeader, mock_s3_service: MockS3ServiceResource, + cleanup: CleanupList, ) -> None: """ Test for successfully creating a workflow with a private Gitlab repository. @@ -148,6 +147,8 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): Random user for testing. mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource Mock S3 Service to manipulate objects. + cleanup : app.tests.utils.utils.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test. 
""" token = random_lower_string(20) workflow = WorkflowIn( @@ -159,11 +160,12 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): ).model_dump() response = await client.post(self.base_path, json=workflow, headers=random_user.auth_headers) assert response.status_code == status.HTTP_201_CREATED - created_workflow = response.json() - assert created_workflow["private"] + created_workflow = WorkflowOut.model_validate(response.json()) + cleanup.add_task(delete_workflow, db=db, workflow_id=created_workflow.workflow_id) + assert created_workflow.private # Check if workflow is created in database - stmt = select(Workflow).where(Workflow._workflow_id == UUID(hex=created_workflow["workflow_id"]).bytes) + stmt = select(Workflow).where(Workflow._workflow_id == created_workflow.workflow_id.bytes) db_workflow = await db.scalar(stmt) assert db_workflow is not None # Check if token is saved @@ -173,6 +175,7 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): scm_file_name = SCM.generate_filename(db_workflow.workflow_id) assert scm_file_name in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys() obj = mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(scm_file_name) + cleanup.add_task(obj.delete) with BytesIO() as f: obj.download_fileobj(f) f.seek(0) @@ -186,11 +189,6 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): assert provider.platform == "gitlab" assert provider.server == "https://gitlab.de" - # Cleanup after test - await db.execute(delete(Workflow).where(Workflow._workflow_id == db_workflow.workflow_id.bytes)) - await db.commit() - obj.delete() - @pytest.mark.asyncio async def test_create_workflow_with_private_github( self, @@ -198,6 +196,7 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): db: AsyncSession, random_user: UserWithAuthHeader, mock_s3_service: MockS3ServiceResource, + cleanup: CleanupList, ) -> None: """ Test for successfully creating a workflow with a private GitHub repository. @@ -212,6 +211,8 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): Random user for testing. mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource Mock S3 Service to manipulate objects. + cleanup : app.tests.utils.utils.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test. 
""" token = random_lower_string(20) workflow = WorkflowIn( @@ -223,11 +224,12 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): ).model_dump() response = await client.post(self.base_path, json=workflow, headers=random_user.auth_headers) assert response.status_code == status.HTTP_201_CREATED - created_workflow = response.json() - assert created_workflow["private"] + created_workflow = WorkflowOut.model_validate(response.json()) + cleanup.add_task(delete_workflow, db=db, workflow_id=created_workflow.workflow_id) + assert created_workflow.private # Check if workflow is created in database - stmt = select(Workflow).where(Workflow._workflow_id == UUID(hex=created_workflow["workflow_id"]).bytes) + stmt = select(Workflow).where(Workflow._workflow_id == created_workflow.workflow_id.bytes) db_workflow = await db.scalar(stmt) assert db_workflow is not None # Check if token is saved @@ -237,6 +239,7 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): scm_file_name = SCM.generate_filename(db_workflow.workflow_id) assert scm_file_name in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys() obj = mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(scm_file_name) + cleanup.add_task(obj.delete) with BytesIO() as f: obj.download_fileobj(f) f.seek(0) @@ -249,17 +252,14 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): assert provider.name == "github" assert provider.user == "example-user" - # Cleanup after test - await db.execute(delete(Workflow).where(Workflow._workflow_id == db_workflow.workflow_id.bytes)) - await db.commit() - obj.delete() - @pytest.mark.asyncio async def test_create_workflow_with_error( self, db: AsyncSession, client: AsyncClient, random_user: UserWithAuthHeader, + mock_default_http_server: DefaultMockHTTPService, + cleanup: CleanupList, ) -> None: """ Test for creating a workflow where the file checks don't pass @@ -272,7 +272,13 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): HTTP Client to perform the request on. random_user : app.tests.utils.user.UserWithAuthHeader Random user for testing. + cleanup : app.tests.utils.utils.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test. + mock_default_http_server : app.tests.mocks.DefaultMockHTTPService + Mock http service for testing """ + mock_default_http_server.send_error = True + cleanup.add_task(mock_default_http_server.reset) workflow = WorkflowIn( git_commit_hash=random_hex_string(), name=random_lower_string(10), @@ -329,7 +335,7 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): Random workflow for testing. """ workflow = WorkflowIn( - git_commit_hash=random_workflow.versions[0].git_commit_hash, + git_commit_hash=random_workflow.versions[0].workflow_version_id, name=random_lower_string(10), short_description=random_lower_string(65), repository_url="https://github.de/example-user/example", @@ -369,6 +375,7 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): client: AsyncClient, random_user: UserWithAuthHeader, mock_s3_service: MockS3ServiceResource, + cleanup: CleanupList, ) -> None: """ Exhaustive Test for successfully creating a workflow with a workflow mode. @@ -383,6 +390,8 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): Random user for testing. mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource Mock S3 Service to manipulate objects. + cleanup : app.tests.utils.utils.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test. 
""" git_commit_hash = random_hex_string() workflow_mode = WorkflowModeIn( @@ -397,9 +406,11 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): ) response = await client.post(self.base_path, json=workflow.model_dump(), headers=random_user.auth_headers) assert response.status_code == status.HTTP_201_CREATED - created_workflow = response.json() - assert len(created_workflow["versions"][0]["modes"]) == 2 - mode_id = UUID(hex=created_workflow["versions"][0]["modes"][0]) + created_workflow = WorkflowOut.model_validate(response.json()) + cleanup.add_task(delete_workflow, db=db, workflow_id=created_workflow.workflow_id) + assert len(created_workflow.versions[0].modes) == 2 + mode_id = created_workflow.versions[0].modes[0] + cleanup.add_task(delete_workflow_mode, db=db, mode_id=mode_id) stmt = select(WorkflowMode).where(WorkflowMode._mode_id == mode_id.bytes) db_mode = await db.scalar(stmt) @@ -415,13 +426,6 @@ class TestWorkflowRoutesCreate(_TestWorkflowRoutes): is None ) - # Clean up after test - await db.execute( - delete(Workflow).where(Workflow._workflow_id == UUID(hex=created_workflow["workflow_id"]).bytes) - ) - await db.execute(delete(WorkflowMode).where(WorkflowMode._mode_id == mode_id.bytes)) - await db.commit() - class TestWorkflowRoutesList(_TestWorkflowRoutes): @pytest.mark.asyncio @@ -500,7 +504,7 @@ class TestWorkflowRoutesList(_TestWorkflowRoutes): response = await client.get( self.base_path, headers=random_second_user.auth_headers, - params={"developer_id": random_user.user.uid, "version_status": WorkflowVersion.Status.CREATED.name}, + params={"developer_id": str(random_user.user.uid), "version_status": WorkflowVersion.Status.CREATED.name}, ) assert response.status_code == status.HTTP_200_OK workflows = response.json() @@ -536,10 +540,10 @@ class TestWorkflowRoutesList(_TestWorkflowRoutes): assert response.status_code == status.HTTP_200_OK executions = response.json() assert len(executions) == 1 - execution = executions[0] - assert execution["workflow_id"] == str(random_workflow.workflow_id) - assert execution["workflow_execution_id"] == str(random_completed_workflow_execution.execution_id) - assert execution["git_commit_hash"] == random_completed_workflow_execution.workflow_version_id + execution = AnonymizedWorkflowExecution(**executions[0]) + assert execution.workflow_id == random_workflow.workflow_id + assert execution.workflow_execution_id == random_completed_workflow_execution.execution_id + assert execution.workflow_version_id == random_completed_workflow_execution.workflow_version_id class TestWorkflowRoutesGet(_TestWorkflowRoutes): @@ -565,10 +569,10 @@ class TestWorkflowRoutesGet(_TestWorkflowRoutes): headers=random_user.auth_headers, ) assert response.status_code == status.HTTP_200_OK - workflow = response.json() - assert workflow["workflow_id"] == str(random_workflow.workflow_id) - assert not workflow["private"] - assert len(workflow["versions"]) > 0 + workflow = WorkflowOut.model_validate(response.json()) + assert workflow.workflow_id == random_workflow.workflow_id + assert not workflow.private + assert len(workflow.versions) > 0 @pytest.mark.asyncio async def test_get_non_existing_workflow(self, client: AsyncClient, random_user: UserWithAuthHeader) -> None: @@ -617,8 +621,9 @@ class TestWorkflowRoutesGet(_TestWorkflowRoutes): assert response.status_code == status.HTTP_200_OK statistics = response.json() assert len(statistics) == 1 - assert statistics[0]["day"] == str(date.today()) - assert statistics[0]["count"] == 1 + stat = 
WorkflowStatistic.model_validate(statistics[0]) + assert stat.day == date.today() + assert stat.count == 1 class TestWorkflowRoutesDelete(_TestWorkflowRoutes): @@ -651,7 +656,7 @@ assert response.status_code == status.HTTP_204_NO_CONTENT icon_slug = str(random_workflow.versions[0].icon_url).split("/")[-1] assert icon_slug not in mock_s3_service.Bucket(settings.ICON_BUCKET).objects.all_keys() - schema_file = random_workflow.versions[0].git_commit_hash + ".json" + schema_file = random_workflow.versions[0].workflow_version_id + ".json" assert schema_file not in mock_s3_service.Bucket(settings.WORKFLOW_BUCKET).objects.all_keys() @pytest.mark.asyncio @@ -683,7 +688,7 @@ assert response.status_code == status.HTTP_204_NO_CONTENT icon_slug = str(random_private_workflow.versions[0].icon_url).split("/")[-1] assert icon_slug not in mock_s3_service.Bucket(settings.ICON_BUCKET).objects.all_keys() - schema_file = random_private_workflow.versions[0].git_commit_hash + ".json" + schema_file = random_private_workflow.versions[0].workflow_version_id + ".json" assert schema_file not in mock_s3_service.Bucket(settings.WORKFLOW_BUCKET).objects.all_keys() scm_file = SCM.generate_filename(random_private_workflow.workflow_id) assert scm_file not in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys() @@ -723,7 +728,7 @@ assert response.status_code == status.HTTP_204_NO_CONTENT icon_slug = str(random_workflow.versions[0].icon_url).split("/")[-1] assert icon_slug not in mock_s3_service.Bucket(settings.ICON_BUCKET).objects.all_keys() - schema_file = f"{random_workflow.versions[0].git_commit_hash}-{random_workflow_mode.mode_id.hex}.json" + schema_file = f"{random_workflow.versions[0].workflow_version_id}-{random_workflow_mode.mode_id.hex}.json" assert schema_file not in mock_s3_service.Bucket(settings.WORKFLOW_BUCKET).objects.all_keys() mode_db = await db.scalar( @@ -762,10 +767,10 @@ headers=random_user.auth_headers, ) assert response.status_code == status.HTTP_201_CREATED - created_version = response.json() - assert created_version["git_commit_hash"] == git_commit_hash - assert created_version["status"] == WorkflowVersion.Status.CREATED - assert created_version["icon_url"] == str(random_workflow.versions[0].icon_url) + created_version = WorkflowVersionOut.model_validate(response.json()) + assert created_version.workflow_version_id == git_commit_hash + assert created_version.status == WorkflowVersion.Status.CREATED + assert created_version.icon_url == random_workflow.versions[0].icon_url stmt = select(WorkflowVersion).where(WorkflowVersion.git_commit_hash == git_commit_hash) db_version = await db.scalar(stmt) @@ -781,6 +786,7 @@ random_workflow: WorkflowOut, random_workflow_mode: WorkflowMode, mock_s3_service: MockS3ServiceResource, + cleanup: CleanupList, ) -> None: """ Test for successfully updating a workflow and adding new modes. @@ -799,6 +805,8 @@ Random workflow mode for testing mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource Mock S3 Service to manipulate objects. + cleanup : app.tests.utils.cleanup.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test.
""" git_commit_hash = random_hex_string() version_update = WorkflowUpdate( @@ -816,20 +824,24 @@ class TestWorkflowRoutesUpdate(_TestWorkflowRoutes): headers=random_user.auth_headers, ) assert response.status_code == status.HTTP_201_CREATED - created_version = response.json() - assert created_version["git_commit_hash"] == git_commit_hash - assert created_version["status"] == WorkflowVersion.Status.CREATED - assert created_version["icon_url"] == str(random_workflow.versions[0].icon_url) - assert created_version["modes"] is not None - assert len(created_version["modes"]) == 2 + created_version = WorkflowVersionOut.model_validate(response.json()) + assert created_version.workflow_version_id == git_commit_hash + assert created_version.status == WorkflowVersion.Status.CREATED + assert created_version.icon_url == random_workflow.versions[0].icon_url + assert created_version.modes is not None + assert len(created_version.modes) == 2 stmt = select(WorkflowVersion).where(WorkflowVersion.git_commit_hash == git_commit_hash) db_version = await db.scalar(stmt) assert db_version is not None assert db_version.status == WorkflowVersion.Status.CREATED - new_mode_id = next((UUID(m) for m in created_version["modes"] if m != str(random_workflow_mode.mode_id)), None) + new_mode_id = next((m for m in created_version.modes if m != random_workflow_mode.mode_id), None) assert new_mode_id is not None + cleanup.add_task(delete_workflow_mode, db=db, mode_id=new_mode_id) + cleanup.add_task( + mock_s3_service.Bucket(settings.WORKFLOW_BUCKET).Object(f"{git_commit_hash}-{new_mode_id.hex}.json").delete + ) assert ( f"{git_commit_hash}-{new_mode_id.hex}.json" in mock_s3_service.Bucket(settings.WORKFLOW_BUCKET).objects.all_keys() @@ -839,11 +851,6 @@ class TestWorkflowRoutesUpdate(_TestWorkflowRoutes): in mock_s3_service.Bucket(settings.WORKFLOW_BUCKET).objects.all_keys() ) - # Clean up after test - await db.execute(delete(WorkflowMode).where(WorkflowMode._mode_id == new_mode_id.bytes)) - await db.commit() - mock_s3_service.Bucket(settings.WORKFLOW_BUCKET).Object(f"{git_commit_hash}-{new_mode_id.hex}.json").delete() - @pytest.mark.asyncio async def test_update_workflow_with_delete_non_existing_modes( self, @@ -866,13 +873,12 @@ class TestWorkflowRoutesUpdate(_TestWorkflowRoutes): random_workflow : app.schemas.workflow.WorkflowOut Random workflow for testing. 
""" - git_commit_hash = random_hex_string() version_update = WorkflowUpdate( - git_commit_hash=git_commit_hash, version=random_lower_string(8), delete_modes=[str(uuid4())] - ).model_dump_json() + git_commit_hash=random_hex_string(), version=random_lower_string(8), delete_modes=[str(uuid4())] + ).model_dump() response = await client.post( "/".join([self.base_path, str(random_workflow.workflow_id), "update"]), - data=version_update, # type: ignore[arg-type] + json=version_update, headers=random_user.auth_headers, ) assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -909,20 +915,19 @@ class TestWorkflowRoutesUpdate(_TestWorkflowRoutes): version_update = WorkflowUpdate( git_commit_hash=git_commit_hash, version=random_lower_string(8), - delete_modes=[str(random_workflow_mode.mode_id)], - ).model_dump_json() + delete_modes=[random_workflow_mode.mode_id], + ) response = await client.post( "/".join([self.base_path, str(random_workflow.workflow_id), "update"]), - data=version_update, # type: ignore[arg-type] + json=version_update.model_dump(), headers=random_user.auth_headers, ) assert response.status_code == status.HTTP_201_CREATED - created_version = response.json() - assert created_version["git_commit_hash"] == git_commit_hash - assert created_version["status"] == WorkflowVersion.Status.CREATED - assert created_version["icon_url"] == str(random_workflow.versions[0].icon_url) - assert created_version["modes"] is not None - assert created_version["modes"] is None or len(created_version["modes"]) == 0 + created_version = WorkflowVersionOut.model_validate(response.json()) + assert created_version.workflow_version_id == git_commit_hash + assert created_version.status == WorkflowVersion.Status.CREATED + assert created_version.icon_url == random_workflow.versions[0].icon_url + assert created_version.modes is None or len(created_version.modes) == 0 stmt = select(WorkflowVersion).where(WorkflowVersion.git_commit_hash == git_commit_hash) db_version = await db.scalar(stmt) @@ -943,6 +948,7 @@ class TestWorkflowRoutesUpdate(_TestWorkflowRoutes): random_workflow: WorkflowOut, random_workflow_mode: WorkflowMode, mock_s3_service: MockS3ServiceResource, + cleanup: CleanupList, ) -> None: """ Test for successfully updating a workflow with adding a new mode and delete an old mode. @@ -961,6 +967,8 @@ class TestWorkflowRoutesUpdate(_TestWorkflowRoutes): Random workflow mode for testing mock_s3_service : app.tests.mocks.mock_s3_resource.MockS3ServiceResource Mock S3 Service to manipulate objects. + cleanup : app.tests.utils.utils.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test. 
""" git_commit_hash = random_hex_string() version_update = WorkflowUpdate( @@ -972,26 +980,30 @@ class TestWorkflowRoutesUpdate(_TestWorkflowRoutes): name=random_lower_string(10), entrypoint=random_lower_string(16), schema_path=random_lower_string() ) ], - ).model_dump_json() + ).model_dump() response = await client.post( "/".join([self.base_path, str(random_workflow.workflow_id), "update"]), - data=version_update, # type: ignore[arg-type] + json=version_update, headers=random_user.auth_headers, ) assert response.status_code == status.HTTP_201_CREATED - created_version = response.json() - assert created_version["git_commit_hash"] == git_commit_hash - assert created_version["status"] == WorkflowVersion.Status.CREATED - assert created_version["icon_url"] == str(random_workflow.versions[0].icon_url) - assert len(created_version["modes"]) == 1 - assert created_version["modes"][0] != str(random_workflow_mode.mode_id) + created_version = WorkflowVersionOut.model_validate(response.json()) + assert created_version.workflow_version_id == git_commit_hash + assert created_version.status == WorkflowVersion.Status.CREATED + assert created_version.icon_url == random_workflow.versions[0].icon_url + assert len(created_version.modes) == 1 + assert created_version.modes[0] != random_workflow_mode.mode_id stmt = select(WorkflowVersion).where(WorkflowVersion.git_commit_hash == git_commit_hash) db_version = await db.scalar(stmt) assert db_version is not None assert db_version.status == WorkflowVersion.Status.CREATED - mode_id = UUID(created_version["modes"][0]) + mode_id = created_version.modes[0] + cleanup.add_task(delete_workflow_mode, db=db, mode_id=mode_id) + cleanup.add_task( + mock_s3_service.Bucket(settings.WORKFLOW_BUCKET).Object(f"{git_commit_hash}-{mode_id.hex}.json").delete + ) assert ( f"{git_commit_hash}-{mode_id.hex}.json" in mock_s3_service.Bucket(settings.WORKFLOW_BUCKET).objects.all_keys() @@ -1001,14 +1013,15 @@ class TestWorkflowRoutesUpdate(_TestWorkflowRoutes): not in mock_s3_service.Bucket(settings.WORKFLOW_BUCKET).objects.all_keys() ) - # Clean up after test - await db.execute(delete(WorkflowMode).where(WorkflowMode._mode_id == mode_id.bytes)) - await db.commit() - mock_s3_service.Bucket(settings.WORKFLOW_BUCKET).Object(f"{git_commit_hash}-{mode_id.hex}.json").delete() - @pytest.mark.asyncio async def test_update_workflow_with_error( - self, db: AsyncSession, client: AsyncClient, random_user: UserWithAuthHeader, random_workflow: WorkflowOut + self, + db: AsyncSession, + client: AsyncClient, + random_user: UserWithAuthHeader, + random_workflow: WorkflowOut, + mock_default_http_server: DefaultMockHTTPService, + cleanup: CleanupList, ) -> None: """ Test for updating a workflow where the file checks don't pass @@ -1023,7 +1036,13 @@ class TestWorkflowRoutesUpdate(_TestWorkflowRoutes): Random user for testing. random_workflow : app.schemas.workflow.WorkflowOut Random workflow for testing. + cleanup : app.tests.utils.utils.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test. + mock_default_http_server : app.tests.mocks.DefaultMockHTTPService + Mock http service for testing """ + mock_default_http_server.send_error = True + cleanup.add_task(mock_default_http_server.reset) version_update = WorkflowUpdate( git_commit_hash=random_hex_string(), version=random_lower_string(8), @@ -1060,7 +1079,7 @@ class TestWorkflowRoutesUpdate(_TestWorkflowRoutes): Random workflow for testing. 
""" version_update = WorkflowUpdate( - git_commit_hash=random_workflow.versions[0].git_commit_hash, + git_commit_hash=random_workflow.versions[0].workflow_version_id, version=random_lower_string(8), ).model_dump() response = await client.post( diff --git a/app/tests/api/test_workflow_credentials.py b/app/tests/api/test_workflow_credentials.py index b61a47132819eca8ce7115345b2b9a3855d81ae7..cffeaf5e80691f49586615ef5ee69fcb50be0887 100644 --- a/app/tests/api/test_workflow_credentials.py +++ b/app/tests/api/test_workflow_credentials.py @@ -9,9 +9,10 @@ from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings -from app.schemas.workflow import WorkflowCredentialsIn, WorkflowOut +from app.schemas.workflow import WorkflowCredentialsIn, WorkflowCredentialsOut, WorkflowOut from app.scm import SCM from app.tests.mocks.mock_s3_resource import MockS3ServiceResource +from app.tests.utils.cleanup import CleanupList from app.tests.utils.user import UserWithAuthHeader from app.tests.utils.utils import random_lower_string @@ -29,6 +30,7 @@ class TestWorkflowCredentialsRoutesUpdate(_TestWorkflowCredentialRoutes): random_user: UserWithAuthHeader, random_workflow: WorkflowOut, mock_s3_service: MockS3ServiceResource, + cleanup: CleanupList, ) -> None: """ Test for updating the credentials on a workflow formerly hosted in a public git repository. @@ -45,6 +47,8 @@ class TestWorkflowCredentialsRoutesUpdate(_TestWorkflowCredentialRoutes): Mock S3 Service to manipulate objects. random_workflow : app.schemas.workflow.WorkflowOut Random workflow for testing. + cleanup : app.tests.utils.utils.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test. """ credentials = WorkflowCredentialsIn(token=random_lower_string(15)) response = await client.put( @@ -61,6 +65,7 @@ class TestWorkflowCredentialsRoutesUpdate(_TestWorkflowCredentialRoutes): assert db_workflow.credentials_token == credentials.token scm_file = mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(SCM.generate_filename(db_workflow.workflow_id)) + cleanup.add_task(scm_file.delete) with BytesIO() as f: scm_file.download_fileobj(f) f.seek(0) @@ -69,9 +74,6 @@ class TestWorkflowCredentialsRoutesUpdate(_TestWorkflowCredentialRoutes): assert len(scm.providers) == 1 assert scm.providers[0].password == credentials.token - # Clean up after test - scm_file.delete() - @pytest.mark.asyncio async def test_update_workflow_credentials_on_private_workflow( self, @@ -367,7 +369,8 @@ class TestWorkflowCredentialsRoutesGet(_TestWorkflowCredentialRoutes): headers=random_user.auth_headers, ) assert response.status_code == status.HTTP_200_OK - assert response.json()["token"] is None + cred = WorkflowCredentialsOut.model_validate(response.json()) + assert cred.token is None @pytest.mark.asyncio async def test_get_workflow_credentials_of_private_workflow( @@ -400,8 +403,8 @@ class TestWorkflowCredentialsRoutesGet(_TestWorkflowCredentialRoutes): stmt = select(Workflow).where(Workflow._workflow_id == random_private_workflow.workflow_id.bytes) db_workflow = await db.scalar(stmt) assert db_workflow is not None - - assert response.json()["token"] == db_workflow.credentials_token + cred = WorkflowCredentialsOut.model_validate(response.json()) + assert cred.token == db_workflow.credentials_token @pytest.mark.asyncio async def test_get_workflow_credentials_as_foreign_user( diff --git a/app/tests/api/test_workflow_execution.py b/app/tests/api/test_workflow_execution.py 
diff --git a/app/tests/api/test_workflow_execution.py b/app/tests/api/test_workflow_execution.py
index 8c4ab5199949b067fd7d101d8ef0be76a4f2c91e..4a4046ef8ddea451e06c56caa754b57b0ec0611d 100644
--- a/app/tests/api/test_workflow_execution.py
+++ b/app/tests/api/test_workflow_execution.py
@@ -1,5 +1,5 @@
 from io import BytesIO
-from uuid import UUID, uuid4
+from uuid import uuid4
 
 import pytest
 from clowmdb.models import Bucket, Workflow, WorkflowExecution, WorkflowMode, WorkflowVersion
@@ -11,12 +11,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from app.core.config import settings
 from app.git_repository import build_repository
 from app.schemas.workflow import WorkflowOut
-from app.schemas.workflow_execution import DevWorkflowExecutionIn, WorkflowExecutionIn
+from app.schemas.workflow_execution import DevWorkflowExecutionIn, WorkflowExecutionIn, WorkflowExecutionOut
 from app.schemas.workflow_mode import WorkflowModeIn
 from app.scm import SCM, SCMProvider
 from app.tests.mocks.mock_s3_resource import MockS3ServiceResource
 from app.tests.mocks.mock_slurm_cluster import MockSlurmCluster
 from app.tests.utils.bucket import add_permission_for_bucket
+from app.tests.utils.cleanup import CleanupList
 from app.tests.utils.user import UserWithAuthHeader
 from app.tests.utils.utils import random_hex_string, random_lower_string
 
@@ -55,25 +56,25 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             Mock Slurm cluster to inspect submitted jobs.
         """
         execution_in = WorkflowExecutionIn(workflow_version_id=random_workflow_version.git_commit_hash, parameters={})
-        response = await client.post(self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump())
+        response = await client.post(
+            self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump(exclude_none=True)
+        )
         assert response.status_code == status.HTTP_201_CREATED
-        execution_response = response.json()
-        assert execution_response["workflow_version_id"] == execution_in.workflow_version_id
-        assert execution_response["user_id"] == random_user.user.uid
-        assert execution_response["status"] == WorkflowExecution.WorkflowExecutionStatus.PENDING
-        assert execution_response["workflow_version_id"] == random_workflow_version.git_commit_hash
-
-        execution_id = UUID(execution_response["execution_id"])
+        execution_response = WorkflowExecutionOut.model_validate(response.json())
+        assert execution_response.workflow_version_id == execution_in.workflow_version_id
+        assert execution_response.executor_id == random_user.user.uid
+        assert execution_response.status == WorkflowExecution.WorkflowExecutionStatus.PENDING
+        assert execution_response.workflow_version_id == random_workflow_version.git_commit_hash
         assert (
-            f"params-{UUID(hex=execution_response['execution_id']).hex }.json"
+            f"params-{execution_response.execution_id.hex}.json"
            in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys()
         )
-        job = mock_slurm_cluster.get_job_by_name(execution_id)
+        job = mock_slurm_cluster.get_job_by_name(execution_response.execution_id)
         assert job is not None
-        assert job["job"]["name"] == execution_id.hex
+        assert job["job"]["name"] == execution_response.execution_id.hex
         assert job["job"]["current_working_directory"] == settings.SLURM_WORKING_DIRECTORY
-        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16]
+        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_response.execution_id.hex[:16]
         assert "NXF_SCM_FILE" not in job["job"]["environment"].keys()
 
         nextflow_script = job["script"]
@@ -116,23 +117,24 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
         """
         execution_in = WorkflowExecutionIn(workflow_version_id=random_workflow_version.git_commit_hash, parameters={})
-        response = await client.post(self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump())
+        response = await client.post(
+            self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump(exclude_none=True)
+        )
         assert response.status_code == status.HTTP_201_CREATED
-        execution_response = response.json()
-        assert execution_response["workflow_version_id"] == execution_in.workflow_version_id
-        assert execution_response["user_id"] == random_user.user.uid
-        assert execution_response["status"] == WorkflowExecution.WorkflowExecutionStatus.PENDING
-        assert execution_response["workflow_version_id"] == random_workflow_version.git_commit_hash
+        execution_response = WorkflowExecutionOut.model_validate(response.json())
+        assert execution_response.workflow_version_id == execution_in.workflow_version_id
+        assert execution_response.executor_id == random_user.user.uid
+        assert execution_response.status == WorkflowExecution.WorkflowExecutionStatus.PENDING
+        assert execution_response.workflow_version_id == random_workflow_version.git_commit_hash
         assert (
-            f"params-{UUID(hex=execution_response['execution_id']).hex }.json"
+            f"params-{execution_response.execution_id.hex}.json"
             in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys()
         )
-        execution_id = UUID(execution_response["execution_id"])
-        job = mock_slurm_cluster.get_job_by_name(execution_id)
+        job = mock_slurm_cluster.get_job_by_name(execution_response.execution_id)
         assert job is not None
-        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16]
+        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_response.execution_id.hex[:16]
 
         nextflow_script = job["script"]
         assert "-hub github" in nextflow_script
@@ -150,6 +152,7 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
         random_workflow_version: WorkflowVersion,
         mock_s3_service: MockS3ServiceResource,
         mock_slurm_cluster: MockSlurmCluster,
+        cleanup: CleanupList,
     ) -> None:
         """
         Test for starting a workflow execution from a public GitLab repository.
@@ -170,8 +173,13 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             Mock S3 Service to manipulate objects.
         mock_slurm_cluster : app.tests.mocks.mock_slurm_cluster.MockSlurmCluster
             Mock Slurm cluster to inspect submitted jobs.
+        cleanup : app.tests.utils.cleanup.CleanupList
+            Cleanup object where (async) functions can be registered that are executed after the test, even if it fails.
         """
-
+        scm_obj = mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(
+            SCM.generate_filename(random_workflow.workflow_id)
+        )
+        cleanup.add_task(scm_obj.delete)
         stmt = (
             update(Workflow)
             .where(Workflow._workflow_id == random_workflow_version.workflow_id.bytes)
@@ -189,29 +197,27 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
         with BytesIO() as f:
             SCM([scm_provider]).serialize(f)
             f.seek(0)
-            scm_obj = mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(
-                SCM.generate_filename(random_workflow.workflow_id)
-            )
             scm_obj.upload_fileobj(f)
 
         execution_in = WorkflowExecutionIn(workflow_version_id=random_workflow_version.git_commit_hash, parameters={})
-        response = await client.post(self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump())
+        response = await client.post(
+            self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump(exclude_none=True)
+        )
         assert response.status_code == status.HTTP_201_CREATED
-        execution_response = response.json()
-        assert execution_response["workflow_version_id"] == execution_in.workflow_version_id
-        assert execution_response["user_id"] == random_user.user.uid
-        assert execution_response["status"] == WorkflowExecution.WorkflowExecutionStatus.PENDING
-        assert execution_response["workflow_version_id"] == random_workflow_version.git_commit_hash
+        execution_response = WorkflowExecutionOut.model_validate(response.json())
+        assert execution_response.workflow_version_id == execution_in.workflow_version_id
+        assert execution_response.executor_id == random_user.user.uid
+        assert execution_response.status == WorkflowExecution.WorkflowExecutionStatus.PENDING
+        assert execution_response.workflow_version_id == random_workflow_version.git_commit_hash
         assert (
-            f"params-{UUID(hex=execution_response['execution_id']).hex }.json"
+            f"params-{execution_response.execution_id.hex}.json"
             in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys()
         )
-        execution_id = UUID(execution_response["execution_id"])
-        job = mock_slurm_cluster.get_job_by_name(execution_id)
+        job = mock_slurm_cluster.get_job_by_name(execution_response.execution_id)
         assert job is not None
-        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16]
+        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_response.execution_id.hex[:16]
 
         nextflow_script = job["script"]
         assert f"-hub {SCMProvider.generate_name(random_workflow.workflow_id)}" in nextflow_script
@@ -219,8 +225,6 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
         assert "export NXF_SCM_FILE" in nextflow_script
         assert f"-revision {random_workflow_version.git_commit_hash}" in nextflow_script
 
-        scm_obj.delete()
-
     @pytest.mark.asyncio
     async def test_start_too_many_workflow_executions(
         self,
@@ -246,7 +250,7 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
         active_execution_counter = 0
         while active_execution_counter < settings.ACTIVE_WORKFLOW_EXECUTION_LIMIT:
             execution = WorkflowExecution(
-                user_id=random_user.user.uid,
+                _executor_id=random_user.user.uid.bytes,
                 workflow_version_id=random_workflow_version.git_commit_hash,
                 slurm_job_id=1,
             )
@@ -254,7 +258,9 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             active_execution_counter += 1
         await db.commit()
         execution_in = WorkflowExecutionIn(workflow_version_id=random_workflow_version.git_commit_hash, parameters={})
-        response = await client.post(self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump())
+        response = await client.post(
+            self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump(exclude_none=True)
+        )
         assert response.status_code == status.HTTP_403_FORBIDDEN
 
     @pytest.mark.asyncio
@@ -275,7 +281,9 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             workflow_version_id=random_hex_string(),
             parameters={},
         )
-        response = await client.post(self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump())
+        response = await client.post(
+            self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump(exclude_none=True)
+        )
         assert response.status_code == status.HTTP_404_NOT_FOUND
 
     @pytest.mark.asyncio
@@ -310,7 +318,9 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             workflow_version_id=random_workflow_version.git_commit_hash,
             parameters={},
         )
-        response = await client.post(self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump())
+        response = await client.post(
+            self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump(exclude_none=True)
+        )
         assert response.status_code == status.HTTP_403_FORBIDDEN
 
     @pytest.mark.asyncio
@@ -336,7 +346,9 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             workflow_version_id=random_workflow_version.git_commit_hash,
             parameters={"dir": "s3://" + random_lower_string()},
         )
-        response = await client.post(self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump())
+        response = await client.post(
+            self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump(exclude_none=True)
+        )
         assert response.status_code == status.HTTP_400_BAD_REQUEST
 
     @pytest.mark.asyncio
@@ -365,7 +377,9 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             workflow_version_id=random_workflow_version.git_commit_hash,
             parameters={"dir": f"s3://{random_bucket.name}"},
         )
-        response = await client.post(self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump())
+        response = await client.post(
+            self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump(exclude_none=True)
+        )
         assert response.status_code == status.HTTP_201_CREATED
 
     @pytest.mark.asyncio
@@ -399,7 +413,7 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             parameters={"dir": f"s3://{random_bucket.name}"},
         )
         response = await client.post(
-            self.base_path, headers=random_second_user.auth_headers, json=execution_in.model_dump()
+            self.base_path, headers=random_second_user.auth_headers, json=execution_in.model_dump(exclude_none=True)
         )
         assert response.status_code == status.HTTP_201_CREATED
 
@@ -430,7 +444,7 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             parameters={"dir": f"s3://{random_bucket.name}"},
         )
         response = await client.post(
-            self.base_path, headers=random_second_user.auth_headers, json=execution_in.model_dump()
+            self.base_path, headers=random_second_user.auth_headers, json=execution_in.model_dump(exclude_none=True)
         )
         assert response.status_code == status.HTTP_400_BAD_REQUEST
 
@@ -464,20 +478,25 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             parameters={},
             logs_s3_path="s3://" + random_bucket.name,
         )
-        response = await client.post(self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump())
+        response = await client.post(
+            self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump(exclude_none=True)
+        )
         assert response.status_code == status.HTTP_201_CREATED
-        assert response.json().get("logs_s3_path", None) is not None
+        execution_response = WorkflowExecutionOut.model_validate(response.json())
+        assert execution_response.logs_s3_path is not None
 
-        execution_id = UUID(response.json()["execution_id"])
-        job = mock_slurm_cluster.get_job_by_name(execution_id)
+        job = mock_slurm_cluster.get_job_by_name(execution_response.execution_id)
         assert job is not None
-        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16]
+        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_response.execution_id.hex[:16]
 
         nextflow_script = job["script"]
         assert f"-with-report s3://{random_bucket.name}/run" in nextflow_script
         assert f"-with-timeline s3://{random_bucket.name}/run" in nextflow_script
         assert f"-revision {random_workflow_version.git_commit_hash}" in nextflow_script
-        assert f"cp $NEXTFLOW_LOG s3://{random_bucket.name}/run-{execution_id.hex}/nextflow.log" in nextflow_script
+        assert (
+            f"cp $NEXTFLOW_LOG s3://{random_bucket.name}/run-{execution_response.execution_id.hex}/nextflow.log"
+            in nextflow_script
+        )
 
     @pytest.mark.asyncio
     async def test_start_workflow_execution_with_good_debug_s3_path(
@@ -509,18 +528,20 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             parameters={},
             debug_s3_path="s3://" + random_bucket.name,
         )
-        response = await client.post(self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump())
+        response = await client.post(
+            self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump(exclude_none=True)
+        )
         assert response.status_code == status.HTTP_201_CREATED
-        assert response.json().get("debug_s3_path", None) is not None
+        execution_response = WorkflowExecutionOut.model_validate(response.json())
+        assert execution_response.debug_s3_path is not None
 
-        execution_id = UUID(response.json()["execution_id"])
-        job = mock_slurm_cluster.get_job_by_name(execution_id)
+        job = mock_slurm_cluster.get_job_by_name(execution_response.execution_id)
         assert job is not None
-        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16]
+        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_response.execution_id.hex[:16]
 
         nextflow_script = job["script"]
         assert (
-            f"cp --include '*/.command*' $NXF_WORK/ s3://{random_bucket.name}/run-{execution_id.hex}/"
+            f"cp --include '*/.command*' $NXF_WORK/ s3://{random_bucket.name}/run-{execution_response.execution_id.hex}/"
             in nextflow_script
         )
 
@@ -554,17 +575,20 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             parameters={},
             provenance_s3_path="s3://" + random_bucket.name,
         )
-        response = await client.post(self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump())
+        response = await client.post(
+            self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump(exclude_none=True)
+        )
         assert response.status_code == status.HTTP_201_CREATED
-        assert response.json().get("provenance_s3_path", None) is not None
+        execution_response = WorkflowExecutionOut.model_validate(response.json())
+        assert execution_response.provenance_s3_path is not None
 
-        execution_id = UUID(response.json()["execution_id"])
-        job = mock_slurm_cluster.get_job_by_name(execution_id)
+        job = mock_slurm_cluster.get_job_by_name(execution_response.execution_id)
         assert job is not None
-        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16]
+        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_response.execution_id.hex[:16]
 
         nextflow_script = job["script"]
         assert (
-            f"cp --include 'nf-prov_*' $NXF_WORK/ s3://{random_bucket.name}/run-{execution_id.hex}/" in nextflow_script
+            f"cp --include 'nf-prov_*' $NXF_WORK/ s3://{random_bucket.name}/run-{execution_response.execution_id.hex}/"
+            in nextflow_script
         )
 
     @pytest.mark.asyncio
@@ -591,7 +615,9 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             parameters={},
             logs_s3_path="s3://" + random_lower_string(),
         )
-        response = await client.post(self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump())
+        response = await client.post(
+            self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump(exclude_none=True)
+        )
         assert response.status_code == status.HTTP_400_BAD_REQUEST
 
     @pytest.mark.asyncio
@@ -618,7 +644,9 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             parameters={},
             debug_s3_path="s3://" + random_lower_string(),
         )
-        response = await client.post(self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump())
+        response = await client.post(
+            self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump(exclude_none=True)
+        )
         assert response.status_code == status.HTTP_400_BAD_REQUEST
 
     @pytest.mark.asyncio
@@ -645,7 +673,9 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             parameters={},
             provenance_s3_path="s3://" + random_lower_string(),
        )
-        response = await client.post(self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump())
+        response = await client.post(
+            self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump(exclude_none=True)
+        )
         assert response.status_code == status.HTTP_400_BAD_REQUEST
 
     @pytest.mark.asyncio
@@ -674,19 +704,18 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
         execution_in = WorkflowExecutionIn(
             workflow_version_id=random_workflow_version.git_commit_hash,
             parameters={},
-            mode=random_workflow_mode.mode_id,
-        ).model_dump_json()
+            mode_id=random_workflow_mode.mode_id,
+        )
         response = await client.post(
-            self.base_path, headers=random_user.auth_headers, data=execution_in  # type: ignore[arg-type]
+            self.base_path, headers=random_user.auth_headers, json=execution_in.model_dump(exclude_none=True)
         )
         assert response.status_code == status.HTTP_201_CREATED
-        response_body = response.json()
-        assert response_body["mode"] == str(random_workflow_mode.mode_id)
+        execution_response = WorkflowExecutionOut.model_validate(response.json())
+        assert execution_response.mode_id == random_workflow_mode.mode_id
 
-        execution_id = UUID(response.json()["execution_id"])
-        job = mock_slurm_cluster.get_job_by_name(execution_id)
+        job = mock_slurm_cluster.get_job_by_name(execution_response.execution_id)
         assert job is not None
-        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16]
+        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_response.execution_id.hex[:16]
 
         nextflow_script = job["script"]
         assert f"-entry {random_workflow_mode.entrypoint}" in nextflow_script
@@ -715,11 +744,9 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             Mock Slurm cluster to inspect submitted jobs.
         """
         execution_in = WorkflowExecutionIn(
-            workflow_version_id=random_workflow_version.git_commit_hash, parameters={}, mode=uuid4()
-        ).model_dump_json()
-        response = await client.post(
-            self.base_path, headers=random_user.auth_headers, data=execution_in  # type: ignore[arg-type]
-        )
+            workflow_version_id=random_workflow_version.git_commit_hash, parameters={}, mode_id=uuid4()
+        ).model_dump()
+        response = await client.post(self.base_path, headers=random_user.auth_headers, json=execution_in)
         assert response.status_code == status.HTTP_404_NOT_FOUND
 
     @pytest.mark.asyncio
@@ -748,7 +775,7 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
         execution_in = WorkflowExecutionIn(
             workflow_version_id=random_workflow_version.git_commit_hash,
             parameters={},
-        ).model_dump()
+        ).model_dump(exclude_none=True)
         response = await client.post(self.base_path, headers=random_user.auth_headers, json=execution_in)
         assert response.status_code == status.HTTP_400_BAD_REQUEST
 
@@ -760,6 +787,8 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
         random_user: UserWithAuthHeader,
         random_workflow_version: WorkflowVersion,
         random_workflow: WorkflowOut,
+        mock_slurm_cluster: MockSlurmCluster,
+        cleanup: CleanupList,
     ) -> None:
         """
         Test for starting a workflow execution from a public GitHub repository.
@@ -776,16 +805,25 @@ class TestWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             Random workflow for testing.
         random_workflow_version : clowmdb.models.WorkflowVersion
             Random workflow version for testing.
+        mock_slurm_cluster : app.tests.mocks.mock_slurm_cluster.MockSlurmCluster
+            Mock Slurm cluster to inspect submitted jobs.
+        cleanup : app.tests.utils.cleanup.CleanupList
+            Cleanup object where (async) functions can be registered that are executed after the test, even if it fails.
         """
         execution_in = WorkflowExecutionIn(workflow_version_id=random_workflow_version.git_commit_hash, parameters={})
+        mock_slurm_cluster.send_error = True
+
+        def repair_slurm_cluster() -> None:
+            mock_slurm_cluster.send_error = False
+
+        cleanup.add_task(repair_slurm_cluster)
         response = await client.post(
             self.base_path,
             headers=random_user.auth_headers,
-            json=execution_in.model_dump(),
-            params={"raise_slurm_error": True},
+            json=execution_in.model_dump(exclude_none=True),
         )
         assert response.status_code == status.HTTP_201_CREATED
-        execution_id = UUID(response.json()["execution_id"])
+        execution_id = WorkflowExecutionOut.model_validate(response.json()).execution_id
 
         stmt = select(WorkflowExecution).where(WorkflowExecution._execution_id == execution_id.bytes)
         execution_db = await db.scalar(stmt)
@@ -829,19 +867,18 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             f"{self.base_path}/arbitrary", headers=random_user.auth_headers, json=execution_in.model_dump()
         )
         assert response.status_code == status.HTTP_201_CREATED
-        execution_response = response.json()
-        assert execution_response["user_id"] == random_user.user.uid
-        assert execution_response["status"] == WorkflowExecution.WorkflowExecutionStatus.PENDING
+        execution_response = WorkflowExecutionOut.model_validate(response.json())
+        assert execution_response.executor_id == random_user.user.uid
+        assert execution_response.status == WorkflowExecution.WorkflowExecutionStatus.PENDING
 
         assert (
-            f"params-{UUID(hex=execution_response['execution_id']).hex }.json"
+            f"params-{execution_response.execution_id.hex}.json"
             in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys()
         )
 
-        execution_id = UUID(execution_response["execution_id"])
-        job = mock_slurm_cluster.get_job_by_name(execution_id)
+        job = mock_slurm_cluster.get_job_by_name(execution_response.execution_id)
         assert job is not None
-        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16]
+        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_response.execution_id.hex[:16]
         assert "NXF_SCM_FILE" not in job["job"]["environment"].keys()
 
         nextflow_script = job["script"]
@@ -880,19 +917,18 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             f"{self.base_path}/arbitrary", headers=random_user.auth_headers, json=execution_in.model_dump()
         )
         assert response.status_code == status.HTTP_201_CREATED
-        execution_response = response.json()
-        assert execution_response["user_id"] == random_user.user.uid
-        assert execution_response["status"] == WorkflowExecution.WorkflowExecutionStatus.PENDING
+        execution_response = WorkflowExecutionOut.model_validate(response.json())
+        assert execution_response.executor_id == random_user.user.uid
+        assert execution_response.status == WorkflowExecution.WorkflowExecutionStatus.PENDING
 
         assert (
-            f"params-{UUID(hex=execution_response['execution_id']).hex }.json"
+            f"params-{execution_response.execution_id.hex}.json"
             in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys()
         )
 
-        execution_id = UUID(execution_response["execution_id"])
-        job = mock_slurm_cluster.get_job_by_name(execution_id)
+        job = mock_slurm_cluster.get_job_by_name(execution_response.execution_id)
         assert job is not None
-        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16]
+        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_response.execution_id.hex[:16]
         assert "NXF_SCM_FILE" not in job["job"]["environment"].keys()
 
         nextflow_script = job["script"]
@@ -909,6 +945,7 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
         random_user: UserWithAuthHeader,
         mock_s3_service: MockS3ServiceResource,
         mock_slurm_cluster: MockSlurmCluster,
+        cleanup: CleanupList,
     ) -> None:
         """
         Test for starting a workflow execution with an arbitrary Gitlab repository.
@@ -923,6 +960,8 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             Mock S3 Service to manipulate objects.
         mock_slurm_cluster : app.tests.mocks.mock_slurm_cluster.MockSlurmCluster
             Mock Slurm cluster to inspect submitted jobs.
+        cleanup : app.tests.utils.cleanup.CleanupList
+            Cleanup object where (async) functions can be registered that are executed after the test, even if it fails.
         """
         execution_in = DevWorkflowExecutionIn(
             git_commit_hash=random_hex_string(), repository_url="https://gitlab.com/example-user/example", parameters={}
@@ -931,18 +970,20 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             f"{self.base_path}/arbitrary", headers=random_user.auth_headers, json=execution_in.model_dump()
         )
         assert response.status_code == status.HTTP_201_CREATED
-        execution_response = response.json()
-        assert execution_response["user_id"] == random_user.user.uid
-        assert execution_response["status"] == WorkflowExecution.WorkflowExecutionStatus.PENDING
+        execution_response = WorkflowExecutionOut.model_validate(response.json())
+        scm_obj = mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(
+            SCM.generate_filename(execution_response.execution_id)
+        )
+        cleanup.add_task(scm_obj.delete)
 
-        execution_id = UUID(execution_response["execution_id"])
+        assert execution_response.executor_id == random_user.user.uid
+        assert execution_response.status == WorkflowExecution.WorkflowExecutionStatus.PENDING
 
         # Check if params file is created
-        params_file_name = f"params-{execution_id.hex}.json"
+        params_file_name = f"params-{execution_response.execution_id.hex}.json"
         assert params_file_name in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys()
+        cleanup.add_task(mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(params_file_name).delete)
 
-        scm_file_name = SCM.generate_filename(execution_id)
-        scm_obj = mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(scm_file_name)
         with BytesIO() as f:
             scm_obj.download_fileobj(f)
             f.seek(0)
@@ -952,25 +993,21 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
         assert len(scm.providers) == 1
         provider = scm.providers[0]
         assert provider.password is None
-        assert provider.name == SCMProvider.generate_name(execution_id)
+        assert provider.name == SCMProvider.generate_name(execution_response.execution_id)
         assert provider.platform == "gitlab"
         assert provider.server == "https://gitlab.com"
 
-        job = mock_slurm_cluster.get_job_by_name(execution_id)
+        job = mock_slurm_cluster.get_job_by_name(execution_response.execution_id)
         assert job is not None
-        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16]
+        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_response.execution_id.hex[:16]
 
         nextflow_script = job["script"]
-        assert f"-hub {SCMProvider.generate_name(execution_id)}" in nextflow_script
+        assert f"-hub {SCMProvider.generate_name(execution_response.execution_id)}" in nextflow_script
         assert "-entry" not in nextflow_script
         assert f"-revision {execution_in.git_commit_hash}" in nextflow_script
         assert f"run {execution_in.repository_url}" in nextflow_script
         assert "export NXF_SCM_FILE" in nextflow_script
 
-        # Clean up after test
-        mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(params_file_name).delete()
-        scm_obj.delete()
-
     @pytest.mark.asyncio
     async def test_start_dev_workflow_execution_from_private_gitlab(
         self,
@@ -978,6 +1015,7 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
         random_user: UserWithAuthHeader,
         mock_s3_service: MockS3ServiceResource,
         mock_slurm_cluster: MockSlurmCluster,
+        cleanup: CleanupList,
     ) -> None:
         """
         Test for starting a workflow execution with an arbitrary private Gitlab repository.
@@ -992,6 +1030,8 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             Mock S3 Service to manipulate objects.
         mock_slurm_cluster : app.tests.mocks.mock_slurm_cluster.MockSlurmCluster
             Mock Slurm cluster to inspect submitted jobs.
+        cleanup : app.tests.utils.cleanup.CleanupList
+            Cleanup object where (async) functions can be registered that are executed after the test, even if it fails.
         """
         token = random_lower_string(15)
         execution_in = DevWorkflowExecutionIn(
@@ -1004,20 +1044,21 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             f"{self.base_path}/arbitrary", headers=random_user.auth_headers, json=execution_in.model_dump()
         )
         assert response.status_code == status.HTTP_201_CREATED
-        execution_response = response.json()
-        assert execution_response["user_id"] == random_user.user.uid
-        assert execution_response["status"] == WorkflowExecution.WorkflowExecutionStatus.PENDING
+        execution_response = WorkflowExecutionOut.model_validate(response.json())
 
-        execution_id = UUID(hex=execution_response["execution_id"])
+        assert execution_response.executor_id == random_user.user.uid
+        assert execution_response.status == WorkflowExecution.WorkflowExecutionStatus.PENDING
 
         # Check if params file is created
-        params_file_name = f"params-{execution_id.hex}.json"
+        params_file_name = f"params-{execution_response.execution_id.hex}.json"
         assert params_file_name in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys()
+        cleanup.add_task(mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(params_file_name).delete)
 
         # Check if SCM file is created
-        scm_file_name = SCM.generate_filename(execution_id)
+        scm_file_name = SCM.generate_filename(execution_response.execution_id)
         assert scm_file_name in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys()
         scm_obj = mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(scm_file_name)
+        cleanup.add_task(scm_obj.delete)
         with BytesIO() as f:
             scm_obj.download_fileobj(f)
             f.seek(0)
@@ -1027,25 +1068,21 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
         assert len(scm.providers) == 1
         provider = scm.providers[0]
         assert provider.password == token
-        assert provider.name == SCMProvider.generate_name(execution_id)
+        assert provider.name == SCMProvider.generate_name(execution_response.execution_id)
         assert provider.platform == "gitlab"
         assert provider.server == "https://gitlab.com"
 
-        job = mock_slurm_cluster.get_job_by_name(execution_id)
+        job = mock_slurm_cluster.get_job_by_name(execution_response.execution_id)
         assert job is not None
-        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16]
+        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_response.execution_id.hex[:16]
 
         nextflow_script = job["script"]
-        assert f"-hub {SCMProvider.generate_name(execution_id)}" in nextflow_script
+        assert f"-hub {SCMProvider.generate_name(execution_response.execution_id)}" in nextflow_script
         assert "-entry" not in nextflow_script
         assert f"-revision {execution_in.git_commit_hash}" in nextflow_script
         assert "export NXF_SCM_FILE" in nextflow_script
         assert f"run {execution_in.repository_url}" in nextflow_script
 
-        # Clean up after test
-        mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(params_file_name).delete()
-        scm_obj.delete()
-
     @pytest.mark.asyncio
     async def test_start_dev_workflow_execution_from_private_github(
         self,
@@ -1053,6 +1090,7 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
         random_user: UserWithAuthHeader,
         mock_s3_service: MockS3ServiceResource,
         mock_slurm_cluster: MockSlurmCluster,
+        cleanup: CleanupList,
     ) -> None:
         """
         Test for starting a workflow execution with an arbitrary private GitHub repository.
@@ -1067,6 +1105,8 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             Mock S3 Service to manipulate objects.
         mock_slurm_cluster : app.tests.mocks.mock_slurm_cluster.MockSlurmCluster
             Mock Slurm cluster to inspect submitted jobs.
+        cleanup : app.tests.utils.cleanup.CleanupList
+            Cleanup object where (async) functions can be registered that are executed after the test, even if it fails.
         """
         token = random_lower_string(15)
         execution_in = DevWorkflowExecutionIn(
@@ -1079,20 +1119,20 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
             f"{self.base_path}/arbitrary", headers=random_user.auth_headers, json=execution_in.model_dump()
         )
         assert response.status_code == status.HTTP_201_CREATED
-        execution_response = response.json()
-        assert execution_response["user_id"] == random_user.user.uid
-        assert execution_response["status"] == WorkflowExecution.WorkflowExecutionStatus.PENDING
-
-        execution_id = UUID(hex=execution_response["execution_id"])
+        execution_response = WorkflowExecutionOut.model_validate(response.json())
+        assert execution_response.executor_id == random_user.user.uid
+        assert execution_response.status == WorkflowExecution.WorkflowExecutionStatus.PENDING
 
         # Check if params file is created
-        params_file_name = f"params-{execution_id.hex}.json"
+        params_file_name = f"params-{execution_response.execution_id.hex}.json"
         assert params_file_name in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys()
+        cleanup.add_task(mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(params_file_name).delete)
 
         # Check if SCM file is created
-        scm_file_name = SCM.generate_filename(execution_id)
+        scm_file_name = SCM.generate_filename(execution_response.execution_id)
         assert scm_file_name in mock_s3_service.Bucket(settings.PARAMS_BUCKET).objects.all_keys()
         scm_obj = mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(scm_file_name)
+        cleanup.add_task(scm_obj.delete)
         with BytesIO() as f:
             scm_obj.download_fileobj(f)
             f.seek(0)
@@ -1105,9 +1145,9 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
         assert provider.name == "github"
         assert provider.user == "example-user"
 
-        job = mock_slurm_cluster.get_job_by_name(execution_id)
+        job = mock_slurm_cluster.get_job_by_name(execution_response.execution_id)
         assert job is not None
-        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_id.hex[:16]
+        assert job["job"]["environment"]["TOWER_WORKSPACE_ID"] == execution_response.execution_id.hex[:16]
 
         nextflow_script = job["script"]
         assert "-hub github" in nextflow_script
@@ -1116,10 +1156,6 @@ class TestDevWorkflowExecutionRoutesCreate(_TestWorkflowExecutionRoutes):
         assert "export NXF_SCM_FILE" in nextflow_script
         assert f"run {execution_in.repository_url}" in nextflow_script
 
-        # Clean up after test
-        mock_s3_service.Bucket(settings.PARAMS_BUCKET).Object(params_file_name).delete()
-        scm_obj.delete()
-
     @pytest.mark.asyncio
     async def test_start_dev_workflow_execution_with_unknown_repository(
         self,
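All of the removed `# Clean up after test` blocks in the file above are replaced by the `cleanup` fixture, which runs its registered tasks in teardown regardless of the test outcome. The diff only shows the fixture and its call sites, not `app/tests/utils/cleanup.py` itself; a minimal implementation consistent with `add_task(func, *args, **kwargs)` for sync and async callables plus `await cleanup_list.empty_queue()` might look like this (an assumption, not the project's actual code):

    import inspect
    from functools import partial
    from typing import Any, Callable, List

    class CleanupList:
        """Collect sync/async callables and execute them later, e.g. after a test."""

        def __init__(self) -> None:
            self.queue: List[partial] = []

        def add_task(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
            self.queue.append(partial(func, *args, **kwargs))  # freeze the call, defer execution

        async def empty_queue(self) -> None:
            while self.queue:
                result = self.queue.pop()()  # LIFO: tear down in reverse registration order
                if inspect.isawaitable(result):  # await coroutines, run plain functions as-is
                    await result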
diff --git a/app/tests/api/test_workflow_mode.py b/app/tests/api/test_workflow_mode.py
index be506d55e25106de8d9cbe1086b212d9508e502d..c1dd6568fcae3ffde77b79322a43f7f709fc4eca 100644
--- a/app/tests/api/test_workflow_mode.py
+++ b/app/tests/api/test_workflow_mode.py
@@ -5,6 +5,7 @@ from clowmdb.models import WorkflowMode
 from fastapi import status
 from httpx import AsyncClient
 
+from app.schemas.workflow_mode import WorkflowModeOut
 from app.tests.utils.user import UserWithAuthHeader
 
@@ -31,9 +32,9 @@ class TestWorkflowModeRoutesGet:
             f"{self._base_path}/{str(random_workflow_mode.mode_id)}", headers=random_user.auth_headers
         )
         assert response.status_code == status.HTTP_200_OK
-        mode = response.json()
+        mode = WorkflowModeOut.model_validate(response.json())
 
-        assert mode["mode_id"] == str(random_workflow_mode.mode_id)
-        assert mode["name"] == random_workflow_mode.name
-        assert mode["entrypoint"] == random_workflow_mode.entrypoint
-        assert mode["schema_path"] == random_workflow_mode.schema_path
+        assert mode.mode_id == random_workflow_mode.mode_id
+        assert mode.name == random_workflow_mode.name
+        assert mode.entrypoint == random_workflow_mode.entrypoint
+        assert mode.schema_path == random_workflow_mode.schema_path
diff --git a/app/tests/api/test_workflow_version.py b/app/tests/api/test_workflow_version.py
index b16cff454902f8dd9a5fa76c87f790fe141100c4..45a806d40e89cec7340ba4728bc138574da769a8 100644
--- a/app/tests/api/test_workflow_version.py
+++ b/app/tests/api/test_workflow_version.py
@@ -12,8 +12,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from app.api.endpoints.workflow_version import DocumentationEnum
 from app.core.config import settings
 from app.schemas.workflow import WorkflowOut
+from app.schemas.workflow_version import WorkflowVersion as WorkflowVersionOut
 from app.schemas.workflow_version import WorkflowVersionStatus
 from app.tests.mocks.mock_s3_resource import MockS3ServiceResource
+from app.tests.utils.cleanup import CleanupList
 from app.tests.utils.user import UserWithAuthHeader
 from app.tests.utils.utils import random_hex_string
 
@@ -45,15 +47,15 @@ class TestWorkflowVersionRoutesGet(_TestWorkflowVersionRoutes):
                     self.base_path,
                     str(random_workflow.workflow_id),
                     "versions",
-                    random_workflow.versions[0].git_commit_hash,
+                    random_workflow.versions[0].workflow_version_id,
                 ]
             ),
             headers=random_user.auth_headers,
         )
         assert response.status_code == status.HTTP_200_OK
-        version = response.json()
-        assert version["workflow_id"] == str(random_workflow.workflow_id)
-        assert version["git_commit_hash"] == random_workflow.versions[0].git_commit_hash
+        version = WorkflowVersionOut.model_validate(response.json())
+        assert version.workflow_id == random_workflow.workflow_id
+        assert version.workflow_version_id == random_workflow.versions[0].workflow_version_id
 
     @pytest.mark.asyncio
     async def test_get_non_existing_version(
@@ -103,9 +105,9 @@ class TestWorkflowVersionRoutesList(_TestWorkflowVersionRoutes):
         assert response.status_code == status.HTTP_200_OK
         versions = response.json()
         assert len(versions) == 1
-        version = versions[0]
-        assert version["workflow_id"] == str(random_workflow.workflow_id)
-        assert version["git_commit_hash"] == random_workflow.versions[0].git_commit_hash
+        version = WorkflowVersionOut.model_validate(versions[0])
+        assert version.workflow_id == random_workflow.workflow_id
+        assert version.workflow_version_id == random_workflow.versions[0].workflow_version_id
 
 
 class TestWorkflowVersionRoutesUpdate(_TestWorkflowVersionRoutes):
@@ -131,7 +133,7 @@ class TestWorkflowVersionRoutesUpdate(_TestWorkflowVersionRoutes):
                     self.base_path,
                     str(random_workflow.workflow_id),
                     "versions",
-                    random_workflow.versions[0].git_commit_hash,
+                    random_workflow.versions[0].workflow_version_id,
                     "status",
                 ]
             ),
@@ -140,9 +142,9 @@ class TestWorkflowVersionRoutesUpdate(_TestWorkflowVersionRoutes):
         )
         assert response.status_code == status.HTTP_200_OK
 
-        version = response.json()
-        assert version["git_commit_hash"] == random_workflow.versions[0].git_commit_hash
-        assert version["status"] == WorkflowVersion.Status.PUBLISHED
+        version = WorkflowVersionOut.model_validate(response.json())
+        assert version.workflow_version_id == random_workflow.versions[0].workflow_version_id
+        assert version.status == WorkflowVersion.Status.PUBLISHED
 
     @pytest.mark.asyncio
     async def test_deprecate_workflow_version(
@@ -166,7 +168,7 @@ class TestWorkflowVersionRoutesUpdate(_TestWorkflowVersionRoutes):
                     self.base_path,
                     str(random_workflow.workflow_id),
                     "versions",
-                    random_workflow.versions[0].git_commit_hash,
+                    random_workflow.versions[0].workflow_version_id,
                     "deprecate",
                 ]
             ),
@@ -174,9 +176,9 @@ class TestWorkflowVersionRoutesUpdate(_TestWorkflowVersionRoutes):
         )
         assert response.status_code == status.HTTP_200_OK
 
-        version = response.json()
-        assert version["git_commit_hash"] == random_workflow.versions[0].git_commit_hash
-        assert version["status"] == WorkflowVersion.Status.DEPRECATED
+        version = WorkflowVersionOut.model_validate(response.json())
+        assert version.workflow_version_id == random_workflow.versions[0].workflow_version_id
+        assert version.status == WorkflowVersion.Status.DEPRECATED
 
     @pytest.mark.asyncio
     async def test_update_non_existing_workflow_version_status(
@@ -228,11 +230,13 @@
 
 class TestWorkflowVersionRoutesGetDocumentation(_TestWorkflowVersionRoutes):
     @pytest.mark.asyncio
+    @pytest.mark.parametrize("document", [d for d in DocumentationEnum])
     async def test_download_workflow_version_documentation(
         self,
         client: AsyncClient,
         random_user: UserWithAuthHeader,
         random_workflow_version: WorkflowVersion,
+        document: DocumentationEnum,
     ) -> None:
         """
         Test downloading all the different documentation file for a workflow version.
@@ -245,29 +249,32 @@ class TestWorkflowVersionRoutesGetDocumentation(_TestWorkflowVersionRoutes):
             Random user for testing.
         random_workflow_version : clowmdb.models.WorkflowVersion
             Random workflow version for testing.
+        document : app.api.endpoints.workflow_version.DocumentationEnum
+            All possible documents as pytest parameter.
         """
-        for document in DocumentationEnum:
-            response = await client.get(
-                "/".join(
-                    [
-                        self.base_path,
-                        str(random_workflow_version.workflow_id),
-                        "versions",
-                        random_workflow_version.git_commit_hash,
-                        "documentation",
-                    ]
-                ),
-                headers=random_user.auth_headers,
-                params={"document": document.value},
-            )
-            assert response.status_code == status.HTTP_200_OK
+        response = await client.get(
+            "/".join(
+                [
+                    self.base_path,
+                    str(random_workflow_version.workflow_id),
+                    "versions",
+                    random_workflow_version.git_commit_hash,
+                    "documentation",
+                ]
+            ),
+            headers=random_user.auth_headers,
+            params={"document": document.value},
+        )
+        assert response.status_code == status.HTTP_200_OK
 
     @pytest.mark.asyncio
+    @pytest.mark.parametrize("document", [d for d in DocumentationEnum])
     async def test_download_workflow_version_documentation_with_non_existing_mode(
         self,
         client: AsyncClient,
         random_user: UserWithAuthHeader,
         random_workflow_version: WorkflowVersion,
+        document: DocumentationEnum,
     ) -> None:
         """
         Test downloading all the different documentation file for a workflow version with a non-existing workflow mode.
@@ -280,30 +287,33 @@ class TestWorkflowVersionRoutesGetDocumentation(_TestWorkflowVersionRoutes):
             Random user for testing.
         random_workflow_version : clowmdb.models.WorkflowVersion
             Random workflow version for testing.
+        document : app.api.endpoints.workflow_version.DocumentationEnum
+            All possible documents as pytest parameter.
         """
-        for document in DocumentationEnum:
-            response = await client.get(
-                "/".join(
-                    [
-                        self.base_path,
-                        str(random_workflow_version.workflow_id),
-                        "versions",
-                        random_workflow_version.git_commit_hash,
-                        "documentation",
-                    ]
-                ),
-                headers=random_user.auth_headers,
-                params={"document": document.value, "mode_id": str(uuid4())},
-            )
-            assert response.status_code == status.HTTP_404_NOT_FOUND
+        response = await client.get(
+            "/".join(
+                [
+                    self.base_path,
+                    str(random_workflow_version.workflow_id),
+                    "versions",
+                    random_workflow_version.git_commit_hash,
+                    "documentation",
+                ]
+            ),
+            headers=random_user.auth_headers,
+            params={"document": document.value, "mode_id": str(uuid4())},
+        )
+        assert response.status_code == status.HTTP_404_NOT_FOUND
 
     @pytest.mark.asyncio
+    @pytest.mark.parametrize("document", [d for d in DocumentationEnum])
     async def test_download_workflow_version_documentation_with_existing_mode(
         self,
         client: AsyncClient,
         random_user: UserWithAuthHeader,
         random_workflow_version: WorkflowVersion,
         random_workflow_mode: WorkflowMode,
+        document: DocumentationEnum,
     ) -> None:
         """
         Test downloading all the different documentation file for a workflow version with a workflow mode.
@@ -318,22 +328,23 @@ class TestWorkflowVersionRoutesGetDocumentation(_TestWorkflowVersionRoutes):
             Random workflow version for testing.
         random_workflow_mode : clowmdb.models.WorkflowMode
             Random workflow mode for testing.
+        document : app.api.endpoints.workflow_version.DocumentationEnum
+            All possible documents as pytest parameter.
         """
-        for document in DocumentationEnum:
-            response = await client.get(
-                "/".join(
-                    [
-                        self.base_path,
-                        str(random_workflow_version.workflow_id),
-                        "versions",
-                        random_workflow_version.git_commit_hash,
-                        "documentation",
-                    ]
-                ),
-                headers=random_user.auth_headers,
-                params={"document": document.value, "mode_id": str(random_workflow_mode.mode_id)},
-            )
-            assert response.status_code == status.HTTP_200_OK
+        response = await client.get(
+            "/".join(
+                [
+                    self.base_path,
+                    str(random_workflow_version.workflow_id),
+                    "versions",
+                    random_workflow_version.git_commit_hash,
+                    "documentation",
+                ]
+            ),
+            headers=random_user.auth_headers,
+            params={"document": document.value, "mode_id": str(random_workflow_mode.mode_id)},
+        )
+        assert response.status_code == status.HTTP_200_OK
 
 
 class TestWorkflowVersionIconRoutes(_TestWorkflowVersionRoutes):
@@ -345,6 +356,7 @@ class TestWorkflowVersionIconRoutes(_TestWorkflowVersionRoutes):
         random_workflow_version: WorkflowVersion,
         mock_s3_service: MockS3ServiceResource,
         db: AsyncSession,
+        cleanup: CleanupList,
     ) -> None:
         """
         Test for uploading a new icon for a workflow version
@@ -361,6 +373,8 @@ class TestWorkflowVersionIconRoutes(_TestWorkflowVersionRoutes):
             Mock S3 Service to manipulate objects.
         db : sqlalchemy.ext.asyncio.AsyncSession.
             Async database session to perform query on.
+        cleanup : app.tests.utils.cleanup.CleanupList
+            Cleanup object where (async) functions can be registered that are executed after the test, even if it fails.
         """
         img_buffer = BytesIO()
         Image.linear_gradient(mode="L").save(img_buffer, "PNG")
@@ -382,6 +396,7 @@ class TestWorkflowVersionIconRoutes(_TestWorkflowVersionRoutes):
         assert response.status_code == status.HTTP_201_CREATED
         icon_url = response.json()["icon_url"]
         icon_slug = icon_url.split("/")[-1]
+        cleanup.add_task(mock_s3_service.Bucket(settings.ICON_BUCKET).Object(icon_slug).delete)
         assert icon_slug in mock_s3_service.Bucket(settings.ICON_BUCKET).objects.all_keys()
         db_version = await db.scalar(
             select(WorkflowVersion).where(WorkflowVersion.git_commit_hash == random_workflow_version.git_commit_hash)
         )
@@ -389,9 +404,6 @@ class TestWorkflowVersionIconRoutes(_TestWorkflowVersionRoutes):
         assert db_version is not None
         assert db_version.icon_slug == icon_slug
 
-        # Clean up
-        mock_s3_service.Bucket(settings.ICON_BUCKET).Object(icon_slug).delete()
-
     @pytest.mark.asyncio
     async def test_upload_new_icon_as_text(
         self,
diff --git a/app/tests/conftest.py b/app/tests/conftest.py
index b184dfb39c03e2e359bd7a0d5b28b93822d5ab28..aa6141dc007839769badf6ae3a2dee9faec5d884 100644
--- a/app/tests/conftest.py
+++ b/app/tests/conftest.py
@@ -1,9 +1,10 @@
 import asyncio
 import time
+from contextlib import asynccontextmanager
 from functools import partial
 from io import BytesIO
 from secrets import token_urlsafe
-from typing import AsyncIterator, Dict, Iterator
+from typing import AsyncIterator, Iterator
 
 import httpx
 import pytest
@@ -17,10 +18,10 @@ from clowmdb.models import (
     WorkflowVersion,
     workflow_mode_association_table,
 )
-from pytrie import SortedStringTrie as Trie
 from sqlalchemy import insert, select, update
 from sqlalchemy.ext.asyncio import AsyncSession
 
+from app.api import dependencies
 from app.api.dependencies import get_db, get_decode_jwt_function, get_httpx_client, get_s3_resource
 from app.core.config import settings
 from app.git_repository import build_repository
@@ -32,6 +33,7 @@ from app.tests.mocks.mock_opa_service import MockOpaService
 from app.tests.mocks.mock_s3_resource import MockS3ServiceResource
 from app.tests.mocks.mock_slurm_cluster import MockSlurmCluster
 from app.tests.utils.bucket import create_random_bucket
+from app.tests.utils.cleanup import CleanupList
 from app.tests.utils.user import UserWithAuthHeader, create_random_user, decode_mock_token, get_authorization_headers
 from app.tests.utils.utils import random_hex_string, random_lower_string
 
@@ -77,54 +79,81 @@ def mock_opa_service() -> Iterator[MockOpaService]:
     mock_opa.reset()
 
 
+@pytest.fixture(scope="session")
+def mock_default_http_server() -> Iterator[DefaultMockHTTPService]:
+    mock_server = DefaultMockHTTPService()
+    yield mock_server
+    mock_server.reset()
+
+
+@pytest_asyncio.fixture(scope="session")
+async def mock_client(
+    mock_opa_service: MockOpaService,
+    mock_slurm_cluster: MockSlurmCluster,
+    mock_default_http_server: DefaultMockHTTPService,
+) -> AsyncIterator[httpx.AsyncClient]:
+    def mock_request_handler(request: httpx.Request) -> httpx.Response:
+        url = str(request.url)
+        handler: MockHTTPService
+        if url.startswith(str(settings.SLURM_ENDPOINT)):
+            handler = mock_slurm_cluster
+        elif url.startswith(str(settings.OPA_URI)):
+            handler = mock_opa_service
+        else:
+            handler = mock_default_http_server
+        return handler.handle_request(request=request)
+
+    async with httpx.AsyncClient(transport=httpx.MockTransport(mock_request_handler)) as http_client:
+        yield http_client
+
+
+@pytest.fixture(autouse=True)
+def monkeypatch_background_connections(
+    monkeypatch: pytest.MonkeyPatch,
+    db: AsyncSession,
+    mock_s3_service: MockS3ServiceResource,
+    mock_client: httpx.AsyncClient,
+) -> None:
+    """
+    Patch the functions to get resources in background tasks with mock resources.
+    """
+
+    @asynccontextmanager
+    async def get_http_client() -> AsyncIterator[httpx.AsyncClient]:
+        yield mock_client
+
+    async def get_patch_db() -> AsyncIterator[AsyncSession]:
+        yield db
+
+    monkeypatch.setattr(dependencies, "get_db", get_patch_db)
+    monkeypatch.setattr(dependencies, "get_s3_resource", lambda: mock_s3_service)
+    monkeypatch.setattr(dependencies, "get_background_http_client", get_http_client)
+
+
+@pytest.fixture(autouse=True)
+def monkeypatch_env(monkeypatch: pytest.MonkeyPatch) -> None:
+    """
+    Set the appropriate ENV variables for all tests
+    """
+    monkeypatch.setenv("OTLP_GRPC_ENDPOINT", "")
+    monkeypatch.setenv("DEV_SYSTEM", "True")
+    monkeypatch.setenv("ACTIVE_WORKFLOW_EXECUTION_LIMIT", "3")
+    monkeypatch.setenv("SLURM_JOB_MONITORING", "NOMONITORING")
+
+
 @pytest_asyncio.fixture(scope="module")
 async def client(
     mock_s3_service: MockS3ServiceResource,
     db: AsyncSession,
     mock_opa_service: MockOpaService,
     mock_slurm_cluster: MockSlurmCluster,
+    mock_client: httpx.AsyncClient,
 ) -> AsyncIterator[httpx.AsyncClient]:
     """
     Fixture for creating a TestClient and perform HTTP Request on it. Overrides several dependencies.
     """
-    endpoints: Dict[str, MockHTTPService] = {
-        str(settings.SLURM_ENDPOINT): mock_slurm_cluster,
-        str(settings.OPA_URI): mock_opa_service,
-    }
-    # data structure to easily find the appropriate mock request based on the URL
-    t = Trie(**endpoints)
-
-    async def get_mock_httpx_client(
-        raise_opa_error: bool = False, raise_slurm_error: bool = False, raise_error: bool = False
-    ) -> AsyncIterator[httpx.AsyncClient]:
-        """
-        FastAPI Dependency to get an async httpx client with mock transport.
-
-        Parameters
-        ----------
-        raise_opa_error : bool
-            Flag to raise an error when querying the OPA service. Query parameter.
-        raise_slurm_error : bool
-            Flag to raise an error when querying the Slurm service. Query parameter.
-        raise_error : bool
-            Flag to raise an error. Query parameter.
-
-        Returns
-        -------
-        client : AsyncIterator[httpx.AsyncClient]
-            Http client with mock transports.
-        """
-        errors = locals()  # catch all error flags in a dict
-
-        def mock_request_handler(request: httpx.Request) -> httpx.Response:
-            url = str(request.url)
-            return t.longest_prefix_value(url, DefaultMockHTTPService()).handle_request(request, **errors)
-
-        async with httpx.AsyncClient(transport=httpx.MockTransport(mock_request_handler)) as http_client:
-            yield http_client
-
-    app.dependency_overrides[get_httpx_client] = get_mock_httpx_client
+    app.dependency_overrides[get_httpx_client] = lambda: mock_client
     app.dependency_overrides[get_s3_resource] = lambda: mock_s3_service
     app.dependency_overrides[get_decode_jwt_function] = lambda: partial(decode_mock_token, secret=jwt_secret)
     app.dependency_overrides[get_db] = lambda: db
@@ -150,9 +179,9 @@ async def random_user(db: AsyncSession, mock_opa_service: MockOpaService) -> Asy
     Create a random user and deletes him afterward.
     """
     user = await create_random_user(db)
-    mock_opa_service.add_user(user.uid, privileged=True)
+    mock_opa_service.add_user(user.lifescience_id, privileged=True)
     yield UserWithAuthHeader(user=user, auth_headers=get_authorization_headers(uid=user.uid, secret=jwt_secret))
-    mock_opa_service.delete_user(user.uid)
+    mock_opa_service.delete_user(user.lifescience_id)
     await db.delete(user)
     await db.commit()
 
@@ -163,9 +192,9 @@ async def random_second_user(db: AsyncSession, mock_opa_service: MockOpaService)
    """
     Create a random second user and deletes him afterward.
     """
     user = await create_random_user(db)
-    mock_opa_service.add_user(user.uid)
+    mock_opa_service.add_user(user.lifescience_id)
     yield UserWithAuthHeader(user=user, auth_headers=get_authorization_headers(uid=user.uid, secret=jwt_secret))
-    mock_opa_service.delete_user(user.uid)
+    mock_opa_service.delete_user(user.lifescience_id)
     await db.delete(user)
     await db.commit()
 
@@ -176,9 +205,9 @@ async def random_third_user(db: AsyncSession, mock_opa_service: MockOpaService)
     """
     Create a random third user and deletes him afterward.
     """
     user = await create_random_user(db)
-    mock_opa_service.add_user(user.uid)
+    mock_opa_service.add_user(user.lifescience_id)
     yield UserWithAuthHeader(user=user, auth_headers=get_authorization_headers(uid=user.uid, secret=jwt_secret))
-    mock_opa_service.delete_user(user.uid)
+    mock_opa_service.delete_user(user.lifescience_id)
     await db.delete(user)
     await db.commit()
 
@@ -209,7 +238,7 @@
         name=random_lower_string(10),
         repository_url="https://github.de/example-user/example",
         short_description=random_lower_string(65),
-        developer_id=random_user.user.uid,
+        _developer_id=random_user.user.uid.bytes,
     )
     db.add(workflow_db)
     await db.commit()
@@ -274,7 +303,9 @@ async def random_workflow_version(db: AsyncSession, random_workflow: WorkflowOut
     """
     Create a random workflow version. Will be deleted, when the workflow is deleted.
     """
-    stmt = select(WorkflowVersion).where(WorkflowVersion.git_commit_hash == random_workflow.versions[0].git_commit_hash)
+    stmt = select(WorkflowVersion).where(
+        WorkflowVersion.git_commit_hash == random_workflow.versions[0].workflow_version_id
+    )
     return await db.scalar(stmt)
 
@@ -290,7 +321,7 @@
     Create a random running workflow execution. Will be deleted, when the user is deleted.
     """
     execution = WorkflowExecution(
-        user_id=random_user.user.uid,
+        _executor_id=random_user.user.uid.bytes,
         workflow_version_id=random_workflow_version.git_commit_hash,
         slurm_job_id=-1,
     )
@@ -357,3 +388,13 @@ async def random_workflow_mode(
     parameter_schema.delete()
     await db.delete(mode)
     await db.commit()
+
+
+@pytest_asyncio.fixture(scope="function")
+async def cleanup(db: AsyncSession) -> AsyncIterator[CleanupList]:
+    """
+    Yields a CleanupList object where (async) functions can be registered that are executed after the test, even if it fails.
+    """
+    cleanup_list = CleanupList()
+    yield cleanup_list
+    await cleanup_list.empty_queue()
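The conftest hunks above collapse the pytrie lookup and the per-request error-flag plumbing into a single session-scoped client whose `httpx.MockTransport` routes each request by URL prefix. Stripped of the project's mock classes, the routing idea looks like this (the endpoint constants are stand-ins for `settings.SLURM_ENDPOINT` and `settings.OPA_URI`):

    import asyncio
    import httpx

    SLURM_ENDPOINT = "http://slurm.example"  # stand-in for settings.SLURM_ENDPOINT
    OPA_URI = "http://opa.example"           # stand-in for settings.OPA_URI

    def mock_request_handler(request: httpx.Request) -> httpx.Response:
        # Dispatch to the matching mock service by URL prefix, with a default fallback.
        url = str(request.url)
        if url.startswith(SLURM_ENDPOINT):
            return httpx.Response(200, json={"service": "slurm"})
        if url.startswith(OPA_URI):
            return httpx.Response(200, json={"service": "opa"})
        return httpx.Response(200, json={"service": "default"})

    async def main() -> None:
        transport = httpx.MockTransport(mock_request_handler)
        async with httpx.AsyncClient(transport=transport) as client:
            response = await client.get(f"{SLURM_ENDPOINT}/job/1")
            assert response.json() == {"service": "slurm"}

    asyncio.run(main())

Because the client is built once per session, per-test failure modes are toggled on the mock objects themselves (see `send_error` above) rather than via query parameters threaded through the dependency.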
""" execution = WorkflowExecution( - user_id=random_user.user.uid, + _executor_id=random_user.user.uid.bytes, workflow_version_id=random_workflow_version.git_commit_hash, slurm_job_id=-1, ) @@ -357,3 +388,13 @@ async def random_workflow_mode( parameter_schema.delete() await db.delete(mode) await db.commit() + + +@pytest_asyncio.fixture(scope="function") +async def cleanup(db: AsyncSession) -> AsyncIterator[CleanupList]: + """ + Yields a Cleanup object where (async) functions can be registered which get executed after a (failed) test + """ + cleanup_list = CleanupList() + yield cleanup_list + await cleanup_list.empty_queue() diff --git a/app/tests/crud/test_bucket.py b/app/tests/crud/test_bucket.py index b1417f97ea601cda2411defbb30017f6733dc0fa..31f86514bc25ec3edcc39cf48b3038abaf66a5bf 100644 --- a/app/tests/crud/test_bucket.py +++ b/app/tests/crud/test_bucket.py @@ -398,7 +398,7 @@ class TestBucketCRUDCheckAccess: assert not check -class TestBucketCRUDCheckBuckeExist: +class TestBucketCRUDCheckBucketExist: @pytest.mark.asyncio async def test_check_bucket_exist( self, diff --git a/app/tests/crud/test_user.py b/app/tests/crud/test_user.py index 7db6677c3e6f660439694f0b03e9b3d02b7f185a..d6c0a3b6c888c6a3b86755fc4df7fd940f0aed54 100644 --- a/app/tests/crud/test_user.py +++ b/app/tests/crud/test_user.py @@ -1,9 +1,10 @@ +from uuid import uuid4 + import pytest from sqlalchemy.ext.asyncio import AsyncSession from app.crud import CRUDUser from app.tests.utils.user import UserWithAuthHeader -from app.tests.utils.utils import random_hex_string class TestUserCRUD: @@ -33,5 +34,5 @@ class TestUserCRUD: db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. """ - user = await CRUDUser.get(db, random_hex_string()) + user = await CRUDUser.get(db, uuid4()) assert user is None diff --git a/app/tests/crud/test_workflow.py b/app/tests/crud/test_workflow.py index 637c2d8002b12564b319c24ab4fac0a91061832c..948e78de51a8bdd276dfe7a00fe00077982ea9b8 100644 --- a/app/tests/crud/test_workflow.py +++ b/app/tests/crud/test_workflow.py @@ -4,13 +4,14 @@ from uuid import uuid4 import pytest from clowmdb.models import Workflow, WorkflowExecution, WorkflowVersion -from sqlalchemy import delete, select +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload from app.crud import CRUDWorkflow from app.schemas.workflow import WorkflowIn, WorkflowOut from app.schemas.workflow_mode import WorkflowModeIn +from app.tests.utils.cleanup import CleanupList, delete_workflow from app.tests.utils.user import UserWithAuthHeader from app.tests.utils.utils import random_hex_string, random_lower_string @@ -180,8 +181,8 @@ class TestWorkflowCRUDGet: assert statistics[0].started_at == date.today() assert statistics[0].workflow_id == random_workflow.workflow_id assert statistics[0].workflow_execution_id == random_completed_workflow_execution.execution_id - assert statistics[0].git_commit_hash == random_completed_workflow_execution.workflow_version_id - assert statistics[0].pseudo_uid != random_completed_workflow_execution.user_id + assert statistics[0].workflow_version_id == random_completed_workflow_execution.workflow_version_id + assert statistics[0].pseudo_uid != random_completed_workflow_execution.executor_id.hex assert statistics[0].developer_id == random_workflow.developer_id @pytest.mark.asyncio @@ -205,8 +206,8 @@ class TestWorkflowCRUDGet: assert statistics[0].started_at == date.today() assert statistics[0].workflow_id == 
random_workflow.workflow_id assert statistics[0].workflow_execution_id == random_completed_workflow_execution.execution_id - assert statistics[0].git_commit_hash == random_completed_workflow_execution.workflow_version_id - assert statistics[0].pseudo_uid != random_completed_workflow_execution.user_id + assert statistics[0].workflow_version_id == random_completed_workflow_execution.workflow_version_id + assert statistics[0].pseudo_uid != random_completed_workflow_execution.executor_id.hex assert statistics[0].developer_id == random_workflow.developer_id @pytest.mark.asyncio @@ -225,7 +226,7 @@ class TestWorkflowCRUDGet: random_completed_workflow_execution : clowmdb.models.WorkflowExecution Random workflow execution for testing. """ - statistics = await CRUDWorkflow.developer_statistics(db, developer_id=random_lower_string(40)) + statistics = await CRUDWorkflow.developer_statistics(db, developer_id=uuid4()) assert len(statistics) == 0 @pytest.mark.asyncio @@ -249,8 +250,8 @@ class TestWorkflowCRUDGet: assert statistics[0].started_at == date.today() assert statistics[0].workflow_id == random_workflow.workflow_id assert statistics[0].workflow_execution_id == random_completed_workflow_execution.execution_id - assert statistics[0].git_commit_hash == random_completed_workflow_execution.workflow_version_id - assert statistics[0].pseudo_uid != random_completed_workflow_execution.user_id + assert statistics[0].workflow_version_id == random_completed_workflow_execution.workflow_version_id + assert statistics[0].pseudo_uid != random_completed_workflow_execution.executor_id.hex assert statistics[0].developer_id == random_workflow.developer_id @pytest.mark.asyncio @@ -293,8 +294,8 @@ class TestWorkflowCRUDGet: assert statistics[0].started_at == date.today() assert statistics[0].workflow_id == random_workflow.workflow_id assert statistics[0].workflow_execution_id == random_completed_workflow_execution.execution_id - assert statistics[0].git_commit_hash == random_completed_workflow_execution.workflow_version_id - assert statistics[0].pseudo_uid != random_completed_workflow_execution.user_id + assert statistics[0].workflow_version_id == random_completed_workflow_execution.workflow_version_id + assert statistics[0].pseudo_uid != random_completed_workflow_execution.executor_id.hex assert statistics[0].developer_id == random_workflow.developer_id @pytest.mark.asyncio @@ -337,8 +338,8 @@ class TestWorkflowCRUDGet: assert statistics[0].started_at == date.today() assert statistics[0].workflow_id == random_workflow.workflow_id assert statistics[0].workflow_execution_id == random_completed_workflow_execution.execution_id - assert statistics[0].git_commit_hash == random_completed_workflow_execution.workflow_version_id - assert statistics[0].pseudo_uid != random_completed_workflow_execution.user_id + assert statistics[0].workflow_version_id == random_completed_workflow_execution.workflow_version_id + assert statistics[0].pseudo_uid != random_completed_workflow_execution.executor_id.hex assert statistics[0].developer_id == random_workflow.developer_id @pytest.mark.asyncio @@ -363,7 +364,9 @@ class TestWorkflowCRUDGet: class TestWorkflowCRUDCreate: @pytest.mark.asyncio - async def test_create_workflow(self, db: AsyncSession, random_user: UserWithAuthHeader) -> None: + async def test_create_workflow( + self, db: AsyncSession, random_user: UserWithAuthHeader, cleanup: CleanupList + ) -> None: """ Test for creating a workflow in CRUD Repository. 
@@ -373,6 +376,8 @@ class TestWorkflowCRUDCreate: Async database session to perform query on. random_user : app.tests.utils.user.UserWithAuthHeader Random user for testing. + cleanup : app.tests.utils.cleanup.CleanupList + Cleanup object with which (async) functions can be registered; they are executed after the test, even if it fails. """ workflow_in = WorkflowIn( git_commit_hash=random_hex_string(), @@ -380,19 +385,19 @@ class TestWorkflowCRUDCreate: short_description=random_lower_string(65), repository_url="https://github.com/example/example", ) - workflow = await CRUDWorkflow.create(db, workflow=workflow_in, developer=random_user.user.uid) + workflow = await CRUDWorkflow.create(db, workflow=workflow_in, developer_id=random_user.user.uid) assert workflow is not None + cleanup.add_task(delete_workflow, db=db, workflow_id=workflow.workflow_id) stmt = select(Workflow).where(Workflow._workflow_id == workflow.workflow_id.bytes) created_workflow = await db.scalar(stmt) assert created_workflow is not None assert created_workflow == workflow - await db.execute(delete(Workflow).where(Workflow._workflow_id == workflow.workflow_id.bytes)) - await db.commit() - @pytest.mark.asyncio - async def test_create_workflow_with_mode(self, db: AsyncSession, random_user: UserWithAuthHeader) -> None: + async def test_create_workflow_with_mode( + self, db: AsyncSession, random_user: UserWithAuthHeader, cleanup: CleanupList + ) -> None: """ Test for creating a workflow with a mode in CRUD Repository. @@ -402,6 +407,8 @@ class TestWorkflowCRUDCreate: Async database session to perform query on. random_user : app.tests.utils.user.UserWithAuthHeader Random user for testing. + cleanup : app.tests.utils.cleanup.CleanupList + Cleanup object with which (async) functions can be registered; they are executed after the test, even if it fails.
""" workflow_in = WorkflowIn( git_commit_hash=random_hex_string(), @@ -414,8 +421,9 @@ class TestWorkflowCRUDCreate: ) ], ) - workflow = await CRUDWorkflow.create(db, workflow=workflow_in, developer=random_user.user.uid) + workflow = await CRUDWorkflow.create(db, workflow=workflow_in, developer_id=random_user.user.uid) assert workflow is not None + cleanup.add_task(delete_workflow, db=db, workflow_id=workflow.workflow_id) stmt = ( select(Workflow) @@ -429,9 +437,6 @@ class TestWorkflowCRUDCreate: assert len(created_workflow.versions) == 1 assert len(created_workflow.versions[0].workflow_modes) == 1 - await db.execute(delete(Workflow).where(Workflow._workflow_id == workflow.workflow_id.bytes)) - await db.commit() - @pytest.mark.asyncio async def test_create_workflow_credentials(self, db: AsyncSession, random_workflow: WorkflowOut) -> None: """ diff --git a/app/tests/crud/test_workflow_execution.py b/app/tests/crud/test_workflow_execution.py index 815eaf8dea2671e9f335c48a141dd074850976f3..c6ad283eb77fdff03c345fb10546084228dce401 100644 --- a/app/tests/crud/test_workflow_execution.py +++ b/app/tests/crud/test_workflow_execution.py @@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.crud import CRUDWorkflowExecution from app.schemas.workflow_execution import DevWorkflowExecutionIn, WorkflowExecutionIn from app.tests.utils.user import UserWithAuthHeader -from app.tests.utils.utils import random_hex_string, random_lower_string +from app.tests.utils.utils import random_hex_string class TestWorkflowExecutionCRUDCreate: @@ -34,7 +34,7 @@ class TestWorkflowExecutionCRUDCreate: ) workflow_execution_db = await CRUDWorkflowExecution.create(db, workflow_execution, random_user.user.uid) assert workflow_execution_db - assert workflow_execution_db.user_id == random_user.user.uid + assert workflow_execution_db.executor_id == random_user.user.uid assert workflow_execution_db.workflow_version_id == random_workflow_version.git_commit_hash assert workflow_execution_db.status == WorkflowExecution.WorkflowExecutionStatus.PENDING @@ -43,7 +43,7 @@ class TestWorkflowExecutionCRUDCreate: ) assert workflow_execution_db - assert workflow_execution_db.user_id == random_user.user.uid + assert workflow_execution_db.executor_id == random_user.user.uid assert workflow_execution_db.workflow_version_id == random_workflow_version.git_commit_hash assert workflow_execution_db.status == WorkflowExecution.WorkflowExecutionStatus.PENDING @@ -66,14 +66,14 @@ class TestWorkflowExecutionCRUDCreate: ) workflow_execution_db = await CRUDWorkflowExecution.create(db, workflow_execution, random_user.user.uid) assert workflow_execution_db - assert workflow_execution_db.user_id == random_user.user.uid + assert workflow_execution_db.executor_id == random_user.user.uid assert workflow_execution_db.status == WorkflowExecution.WorkflowExecutionStatus.PENDING workflow_execution_db = await db.scalar( select(WorkflowExecution).where(WorkflowExecution._execution_id == workflow_execution_db.execution_id.bytes) ) assert workflow_execution_db - assert workflow_execution_db.user_id == random_user.user.uid + assert workflow_execution_db.executor_id == random_user.user.uid assert workflow_execution_db.status == WorkflowExecution.WorkflowExecutionStatus.PENDING @@ -143,10 +143,13 @@ class TestWorkflowExecutionCRUDList: random_running_workflow_execution : clowmdb.models.WorkflowExecution Random workflow execution for testing. 
""" - executions = await CRUDWorkflowExecution.list(db, uid=random_running_workflow_execution.user_id) + executions = await CRUDWorkflowExecution.list(db, executor_id=random_running_workflow_execution.executor_id) assert len(executions) > 0 assert sum(1 for execution in executions if execution == random_running_workflow_execution) == 1 - assert sum(1 for execution in executions if execution.user_id == random_running_workflow_execution.user_id) >= 1 + assert ( + sum(1 for execution in executions if execution.executor_id == random_running_workflow_execution.executor_id) + >= 1 + ) @pytest.mark.asyncio async def test_get_list_workflow_executions_of_non_existing_user( @@ -162,7 +165,7 @@ class TestWorkflowExecutionCRUDList: random_running_workflow_execution : clowmdb.models.WorkflowExecution Random workflow execution for testing. """ - executions = await CRUDWorkflowExecution.list(db, uid=random_lower_string()) + executions = await CRUDWorkflowExecution.list(db, executor_id=uuid4()) assert len(executions) == 0 assert sum(1 for execution in executions if execution == random_running_workflow_execution) == 0 @@ -247,7 +250,7 @@ class TestWorkflowExecutionCRUDLUpdate: random_running_workflow_execution : clowmdb.models.WorkflowExecution Random workflow execution for testing. """ - await CRUDWorkflowExecution.cancel(db, random_running_workflow_execution.execution_id) + await CRUDWorkflowExecution.set_error(db, random_running_workflow_execution.execution_id) stmt = select(WorkflowExecution).where( WorkflowExecution._execution_id == random_running_workflow_execution.execution_id.bytes diff --git a/app/tests/crud/test_workflow_mode.py b/app/tests/crud/test_workflow_mode.py index b9db55a44abb0238f153b744c9ba13298ca1a0ad..2f1f1f3b832c324c87ae21193d1994be9e506371 100644 --- a/app/tests/crud/test_workflow_mode.py +++ b/app/tests/crud/test_workflow_mode.py @@ -1,16 +1,17 @@ import pytest from clowmdb.models import WorkflowMode, WorkflowVersion -from sqlalchemy import delete, select +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.crud import CRUDWorkflowMode from app.schemas.workflow_mode import WorkflowModeIn +from app.tests.utils.cleanup import CleanupList, delete_workflow_mode from app.tests.utils.utils import random_hex_string, random_lower_string class TestWorkflowModeCRUDCreate: @pytest.mark.asyncio - async def test_create_workflow_mode(self, db: AsyncSession) -> None: + async def test_create_workflow_mode(self, db: AsyncSession, cleanup: CleanupList) -> None: """ Test for creating a single workflow mode in CRUD Repository. @@ -18,6 +19,8 @@ class TestWorkflowModeCRUDCreate: ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. + cleanup : app.tests.utils.utils.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test. 
""" workflow_mode_in = WorkflowModeIn( name=random_lower_string(16), @@ -25,15 +28,13 @@ class TestWorkflowModeCRUDCreate: entrypoint=random_lower_string(10), ) modes = await CRUDWorkflowMode.create(db, modes=[workflow_mode_in]) + for m in modes: + cleanup.add_task(delete_workflow_mode, db=db, mode_id=m.mode_id) assert len(modes) == 1 assert modes[0].name == workflow_mode_in.name - # Clean up after test - await db.execute(delete(WorkflowMode).where(WorkflowMode._mode_id == modes[0].mode_id.bytes)) - await db.commit() - @pytest.mark.asyncio - async def test_create_multiple_workflow_mode(self, db: AsyncSession) -> None: + async def test_create_multiple_workflow_mode(self, db: AsyncSession, cleanup: CleanupList) -> None: """ Test for creating multiple workflow modes in CRUD Repository. @@ -41,6 +42,8 @@ class TestWorkflowModeCRUDCreate: ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. + cleanup : app.tests.utils.utils.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test. """ n = 10 modes_in = [ @@ -52,6 +55,8 @@ class TestWorkflowModeCRUDCreate: for _ in range(n) ] modes = await CRUDWorkflowMode.create(db, modes=modes_in) + for m in modes: + cleanup.add_task(delete_workflow_mode, db=db, mode_id=m.mode_id) assert len(modes) == n for mode in modes: @@ -60,10 +65,6 @@ class TestWorkflowModeCRUDCreate: assert db_mode.mode_id == mode.mode_id assert db_mode.name == mode.name - # Clean up after test - await db.execute(delete(WorkflowMode).where(WorkflowMode._mode_id == mode.mode_id.bytes)) - await db.commit() - class TestWorkflowModeCRUDGet: @pytest.mark.asyncio @@ -100,7 +101,7 @@ class TestWorkflowModeCRUDGet: Random workflow version for testing. """ mode = await CRUDWorkflowMode.get( - db, mode_id=random_workflow_mode.mode_id, workflow_version=random_workflow_version.git_commit_hash + db, mode_id=random_workflow_mode.mode_id, workflow_version_id=random_workflow_version.git_commit_hash ) assert mode is not None assert mode.name == random_workflow_mode.name @@ -121,7 +122,7 @@ class TestWorkflowModeCRUDGet: Random workflow mode for testing. """ mode = await CRUDWorkflowMode.get( - db, mode_id=random_workflow_mode.mode_id, workflow_version=random_hex_string() + db, mode_id=random_workflow_mode.mode_id, workflow_version_id=random_hex_string() ) assert mode is None @@ -143,7 +144,7 @@ class TestWorkflowModeCRUDList: random_workflow_version : clowmdb.models.WorkflowMode Random workflow version for testing. """ - modes = await CRUDWorkflowMode.list_modes(db, workflow_version=random_workflow_version.git_commit_hash) + modes = await CRUDWorkflowMode.list_modes(db, workflow_version_id=random_workflow_version.git_commit_hash) assert len(modes) == 1 mode = modes[0] assert mode.name == random_workflow_mode.name @@ -165,13 +166,13 @@ class TestWorkflowModeCRUDList: random_workflow_version : clowmdb.models.WorkflowMode Random workflow version for testing. """ - modes = await CRUDWorkflowMode.list_modes(db, workflow_version=random_hex_string()) + modes = await CRUDWorkflowMode.list_modes(db, workflow_version_id=random_hex_string()) assert len(modes) == 0 class TestWorkflowModeCRUDDelete: @pytest.mark.asyncio - async def test_delete_multiple_workflow_modes(self, db: AsyncSession) -> None: + async def test_delete_multiple_workflow_modes(self, db: AsyncSession, cleanup: CleanupList) -> None: """ Test for deleting multiple workflow modes. 
@@ -179,6 +180,8 @@ class TestWorkflowModeCRUDDelete: ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. + cleanup : app.tests.utils.cleanup.CleanupList + Cleanup object with which (async) functions can be registered; they are executed after the test, even if it fails. """ # Test setup. Create workflow modes n = 10 @@ -189,6 +192,7 @@ ) db.add(mode_db) await db.commit() + cleanup.add_task(delete_workflow_mode, db=db, mode_id=mode_db.mode_id) mode_ids.append(mode_db.mode_id) # Actual test diff --git a/app/tests/crud/test_workflow_version.py b/app/tests/crud/test_workflow_version.py index 2dfddbee25323d282186c0115866658036eadd86..27e8e1bd21dced04aca99c699490fbea966f25af 100644 --- a/app/tests/crud/test_workflow_version.py +++ b/app/tests/crud/test_workflow_version.py @@ -181,8 +181,8 @@ class TestWorkflowVersionCRUDCreate: db, git_commit_hash=random_hex_string(), version="v2.0.0", - wid=random_workflow.workflow_id, - previous_version=random_workflow.versions[-1].git_commit_hash, + workflow_id=random_workflow.workflow_id, + previous_version=random_workflow.versions[-1].workflow_version_id, ) assert workflow_version is not None @@ -216,8 +216,8 @@ db, git_commit_hash=random_hex_string(), version="v2.0.0", - wid=random_workflow.workflow_id, - previous_version=random_workflow.versions[-1].git_commit_hash, + workflow_id=random_workflow.workflow_id, + previous_version=random_workflow.versions[-1].workflow_version_id, modes=[random_workflow_mode.mode_id], ) assert workflow_version is not None @@ -250,7 +250,7 @@ class TestWorkflowVersionCRUDUpdate: Random workflow version for testing. """ await CRUDWorkflowVersion.update_status( - db, git_commit_hash=random_workflow_version.git_commit_hash, status=WorkflowVersion.Status.PUBLISHED + db, workflow_version_id=random_workflow_version.git_commit_hash, status=WorkflowVersion.Status.PUBLISHED ) stmt = select(WorkflowVersion).where(WorkflowVersion.git_commit_hash == random_workflow_version.git_commit_hash) @@ -274,7 +274,7 @@ class TestWorkflowVersionCRUDUpdate: """ new_slug = random_hex_string() await CRUDWorkflowVersion.update_icon( - db, git_commit_hash=random_workflow_version.git_commit_hash, icon_slug=new_slug + db, workflow_version_id=random_workflow_version.git_commit_hash, icon_slug=new_slug ) stmt = select(WorkflowVersion).where(WorkflowVersion.git_commit_hash == random_workflow_version.git_commit_hash) @@ -298,7 +298,7 @@ class TestWorkflowVersionCRUDUpdate: Random workflow version for testing. """ await CRUDWorkflowVersion.update_icon( - db, git_commit_hash=random_workflow_version.git_commit_hash, icon_slug=None + db, workflow_version_id=random_workflow_version.git_commit_hash, icon_slug=None ) stmt = select(WorkflowVersion).where(WorkflowVersion.git_commit_hash == random_workflow_version.git_commit_hash) diff --git a/app/tests/mocks/__init__.py b/app/tests/mocks/__init__.py index 56f370d8a08ba9a3bb5696e22b43c5f1caf4d182..6cf3e24940bdb39ae263ef2854b68ea6a207b94f 100644 --- a/app/tests/mocks/__init__.py +++ b/app/tests/mocks/__init__.py @@ -5,16 +5,21 @@ from httpx import Request, Response class MockHTTPService(ABC): + def __init__(self) -> None: + self.send_error = False + @abstractmethod - def handle_request(self, request: Request, **kwargs: bool) -> Response: + def handle_request(self, request: Request) -> Response: ...
+ def reset(self) -> None: + self.send_error = False + class DefaultMockHTTPService(MockHTTPService): - def handle_request(self, request: Request, **kwargs: bool) -> Response: - raise_error = kwargs.get("raise_error", False) + def handle_request(self, request: Request) -> Response: return Response( - status_code=status.HTTP_404_NOT_FOUND if raise_error else status.HTTP_200_OK, + status_code=status.HTTP_404_NOT_FOUND if self.send_error else status.HTTP_200_OK, json={ # When checking if a file exists in a git repository, the GitHub API expects this in a response "download_url": "https://example.com" diff --git a/app/tests/mocks/mock_opa_service.py b/app/tests/mocks/mock_opa_service.py index d0b41c6750e59931c4ac67ae185cb5fc9d29a773..699811fa469cc7daa023ae0d72dada3d3ed9211d 100644 --- a/app/tests/mocks/mock_opa_service.py +++ b/app/tests/mocks/mock_opa_service.py @@ -1,6 +1,6 @@ import json -from typing import Any, Dict -from uuid import UUID, uuid4 +from typing import Dict +from uuid import uuid4 from fastapi import status from httpx import Request, Response @@ -9,15 +9,6 @@ from app.schemas.security import AuthzRequest, AuthzResponse from app.tests.mocks import MockHTTPService -# Custom JSON encoder to encode python UUID type -class UUIDEncoder(json.JSONEncoder): - def default(self, obj: Any) -> str: - if isinstance(obj, UUID): - # if the obj is uuid, we simply return the string representation - return str(obj) - return json.JSONEncoder.default(self, obj) - - class MockOpaService(MockHTTPService): """ Class to mock the Open Policy Agent service. @@ -25,6 +16,7 @@ class MockOpaService(MockHTTPService): """ def __init__(self) -> None: + super().__init__() self._users: Dict[str, bool] = {} def add_user(self, uid: str, privileged: bool = False) -> None: @@ -56,6 +48,7 @@ class MockOpaService(MockHTTPService): """ Reset the mock service to its initial state. """ + super().reset() self._users = {} def handle_request(self, request: Request, **kwargs: bool) -> Response: @@ -72,16 +65,14 @@ class MockOpaService(MockHTTPService): response : httpx.Response Appropriate response to the received request. 
""" - raise_error = kwargs.get("raise_opa_error", False) - authz_request = AuthzRequest(**json.loads(request.read().decode("utf-8"))["input"]) - if raise_error or authz_request.uid not in self._users: + authz_request = AuthzRequest.model_validate(json.loads(request.read().decode("utf-8"))["input"]) + if self.send_error or authz_request.uid not in self._users: result = False else: result = not MockOpaService.request_admin_permission(authz_request) or self._users[authz_request.uid] return Response( status_code=status.HTTP_200_OK, - text=json.dumps(AuthzResponse(result=result, decision_id=uuid4()).model_dump(), cls=UUIDEncoder), - headers={"Content-Type": "application/json"}, + json=AuthzResponse(result=result, decision_id=uuid4()).model_dump(), ) @staticmethod diff --git a/app/tests/mocks/mock_slurm_cluster.py b/app/tests/mocks/mock_slurm_cluster.py index 4e042ca621d5a000d39f680b9044a020f140a45e..b6dfe6fe4c5a28de9d783bf126872ab11d68a3cd 100644 --- a/app/tests/mocks/mock_slurm_cluster.py +++ b/app/tests/mocks/mock_slurm_cluster.py @@ -21,6 +21,7 @@ class MockSlurmCluster(MockHTTPService): ) def __init__(self, version: str = "v0.0.38") -> None: + super().__init__() self._request_bodies: List[SlurmRequestBody] = [] self._job_states: List[bool] = [] self.base_path = f"slurm/{version}" @@ -40,7 +41,7 @@ class MockSlurmCluster(MockHTTPService): response : httpx.Response Appropriate response to the request """ - if kwargs.get("raise_slurm_error", False): + if self.send_error: return Response(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) # Authorize request error_response = MockSlurmCluster.authorize_request(request.headers) @@ -99,6 +100,7 @@ class MockSlurmCluster(MockHTTPService): """ Resets the mock service to its initial state. """ + super().reset() self._request_bodies = [] def add_workflow_execution(self, job: SlurmRequestBody) -> int: diff --git a/app/tests/utils/bucket.py b/app/tests/utils/bucket.py index 2f04f31e47ce111fc00e9ee78b9df304a94e25ac..dc1cb9022e7bb3adde44fcfa2def5010e549a097 100644 --- a/app/tests/utils/bucket.py +++ b/app/tests/utils/bucket.py @@ -1,5 +1,6 @@ from datetime import datetime from typing import Optional +from uuid import UUID import pytest from clowmdb.models import Bucket, BucketPermission, User @@ -28,7 +29,7 @@ async def create_random_bucket(db: AsyncSession, user: User) -> Bucket: bucket = Bucket( name=random_lower_string(), description=random_lower_string(length=127), - owner_id=user.uid, + _owner_id=user.uid.bytes, ) db.add(bucket) await db.commit() @@ -39,7 +40,7 @@ async def create_random_bucket(db: AsyncSession, user: User) -> Bucket: async def add_permission_for_bucket( db: AsyncSession, bucket_name: str, - uid: str, + uid: UUID, from_: Optional[datetime] = None, to: Optional[datetime] = None, permission: BucketPermission.Permission = BucketPermission.Permission.READWRITE, @@ -66,7 +67,7 @@ async def add_permission_for_bucket( The file prefix for the permission. 
""" perm = BucketPermission( - user_id=uid, + _uid=uid.bytes, bucket_name=bucket_name, from_=round(from_.timestamp()) if from_ is not None else None, to=round(to.timestamp()) if to is not None else None, diff --git a/app/tests/utils/cleanup.py b/app/tests/utils/cleanup.py new file mode 100644 index 0000000000000000000000000000000000000000..6ed6701ff36dc295db202f61b0437d1bfd738cfc --- /dev/null +++ b/app/tests/utils/cleanup.py @@ -0,0 +1,81 @@ +from inspect import iscoroutinefunction +from typing import Any, Awaitable, Callable, Generic, List, ParamSpec, TypeVar +from uuid import UUID + +from clowmdb.models import Workflow, WorkflowMode +from sqlalchemy import delete +from sqlalchemy.ext.asyncio import AsyncSession + +P = ParamSpec("P") +T = TypeVar("T") + + +class Job(Generic[P, T]): + def __init__(self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None: + self.func = func + self.args = args + self.kwargs = kwargs + + @property + def is_async(self) -> bool: + return iscoroutinefunction(self.func) + + def __call__(self) -> T: + return self.func(*self.args, **self.kwargs) + + +class AsyncJob(Job): + def __init__(self, func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> None: + super().__init__(func, *args, **kwargs) + assert iscoroutinefunction(self.func) + + async def __call__(self) -> T: + return await super().__call__() + + +class CleanupList: + """ + Helper object to hold a queue of functions that can be executed later + """ + + def __init__(self) -> None: + self.queue: List[Job] = [] + + def add_task(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None: + """ + Add a (async) function to the queue. + + Parameters + ---------- + func : Callable[P, Any] + Function to register. + args : P.args + Arguments to the function. + kwargs : P.kwargs + Keyword arguments to the function. + """ + if iscoroutinefunction(func): + self.queue.append(AsyncJob(func, *args, **kwargs)) + else: + self.queue.append(Job(func, *args, **kwargs)) + + async def empty_queue(self) -> None: + """ + Empty the queue by executing the registered functions. + """ + while len(self.queue) > 0: + func = self.queue.pop() + if func.is_async: + await func() + else: + func() + + +async def delete_workflow(db: AsyncSession, workflow_id: UUID) -> None: + await db.execute(delete(Workflow).where(Workflow._workflow_id == workflow_id.bytes)) + await db.commit() + + +async def delete_workflow_mode(db: AsyncSession, mode_id: UUID) -> None: + await db.execute(delete(WorkflowMode).where(WorkflowMode._mode_id == mode_id.bytes)) + await db.commit() diff --git a/app/tests/utils/user.py b/app/tests/utils/user.py index ff514b4ca098533bce88bf17f3e6988a32d83343..341f295bfc0768095de61613cfbc1c3dc2da4dd7 100644 --- a/app/tests/utils/user.py +++ b/app/tests/utils/user.py @@ -1,6 +1,7 @@ from dataclasses import dataclass -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta from typing import Dict +from uuid import UUID import pytest from authlib.jose import JsonWebToken @@ -18,7 +19,7 @@ class UserWithAuthHeader: user: User -def get_authorization_headers(uid: str, secret: str = "SuperSecret") -> Dict[str, str]: +def get_authorization_headers(uid: UUID, secret: str = "SuperSecret") -> Dict[str, str]: """ Create a valid JWT and return the correct headers for subsequent requests. @@ -33,7 +34,7 @@ def get_authorization_headers(uid: str, secret: str = "SuperSecret") -> Dict[str headers : Dict[str,str] HTTP Headers to authorize each request. 
""" - to_encode = {"sub": uid, "exp": datetime.utcnow() + timedelta(hours=1)} + to_encode = {"sub": str(uid), "exp": datetime.now(UTC) + timedelta(hours=1)} encoded_jwt = _jwt.encode(header={"alg": "HS256"}, payload=to_encode, key=secret) headers = {"Authorization": f"Bearer {encoded_jwt.decode('utf-8')}"} @@ -83,10 +84,7 @@ async def create_random_user(db: AsyncSession) -> User: user : clowmdb.models.User Newly created user. """ - user = User( - uid=random_hex_string(), - display_name=random_lower_string(), - ) + user = User(display_name=random_lower_string(), lifescience_id=random_hex_string()) db.add(user) await db.commit() return user diff --git a/gunicorn_conf.py b/gunicorn_conf.py new file mode 100644 index 0000000000000000000000000000000000000000..7dd141dfc55f98de00b07daffea9a898677e9df4 --- /dev/null +++ b/gunicorn_conf.py @@ -0,0 +1,67 @@ +import json +import multiprocessing +import os + +workers_per_core_str = os.getenv("WORKERS_PER_CORE", "1") +max_workers_str = os.getenv("MAX_WORKERS") +use_max_workers = None +if max_workers_str: + use_max_workers = int(max_workers_str) +web_concurrency_str = os.getenv("WEB_CONCURRENCY", None) + +host = os.getenv("HOST", "0.0.0.0") +port = os.getenv("PORT", "80") +bind_env = os.getenv("BIND", None) +use_loglevel = os.getenv("LOG_LEVEL", "info") +if bind_env: + use_bind = bind_env +else: + use_bind = f"{host}:{port}" + +cores = multiprocessing.cpu_count() +workers_per_core = float(workers_per_core_str) +default_web_concurrency = workers_per_core * cores +if web_concurrency_str: + web_concurrency = int(web_concurrency_str) + assert web_concurrency > 0 +else: + web_concurrency = max(int(default_web_concurrency), 2) + if use_max_workers: + web_concurrency = min(web_concurrency, use_max_workers) +accesslog_var = os.getenv("ACCESS_LOG", "-") +use_accesslog = accesslog_var or None +errorlog_var = os.getenv("ERROR_LOG", "-") +use_errorlog = errorlog_var or None +graceful_timeout_str = os.getenv("GRACEFUL_TIMEOUT", "120") +timeout_str = os.getenv("TIMEOUT", "120") +keepalive_str = os.getenv("KEEP_ALIVE", "5") + +# Gunicorn config variables +loglevel = use_loglevel +workers = web_concurrency +bind = use_bind +errorlog = use_errorlog +worker_tmp_dir = "/dev/shm" +accesslog = use_accesslog +graceful_timeout = int(graceful_timeout_str) +timeout = int(timeout_str) +keepalive = int(keepalive_str) + + +# For debugging and testing +log_data = { + "loglevel": loglevel, + "workers": workers, + "bind": bind, + "graceful_timeout": graceful_timeout, + "timeout": timeout, + "keepalive": keepalive, + "errorlog": errorlog, + "accesslog": accesslog, + # Additional, non-gunicorn variables + "workers_per_core": workers_per_core, + "use_max_workers": use_max_workers, + "host": host, + "port": port, +} +print(json.dumps(log_data)) diff --git a/requirements-dev.txt b/requirements-dev.txt index 3943c55d6ce8a67b69f665e6fb5630dce42d1c23..de814b3dad75665f8c74a7815e137a0d70d6f6ce 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,15 +2,15 @@ pytest>=7.4.0,<7.5.0 pytest-asyncio>=0.21.0,<0.22.0 pytest-cov>=4.1.0,<4.2.0 -coverage[toml]>=7.3.0,<7.4.0 +coverage[toml]>=7.4.0,<7.5.0 # Linters ruff>=0.1.0,<0.2.0 -black>=23.11.0,<23.12.0 -isort>=5.12.0,<5.13.0 -mypy>=1.7.0,<1.8.0 +black>=23.12.0,<24.1.0 +isort>=5.13.0,<5.14.0 +mypy>=1.8.0,<1.9.0 # stubs for mypy -boto3-stubs-lite[s3]>=1.33.0,<1.34.0 +boto3-stubs-lite[s3]>=1.34.0,<1.35.0 types-requests # Miscellaneous -pre-commit>=3.5.0,<3.6.0 -PyTrie>=0.4.0,<0.5.0 +pre-commit>=3.6.0,<3.7.0 +uvicorn>=0.27.0,<0.28.0 diff 
--git a/requirements.txt b/requirements.txt index 7206a4a79c61097bc8a3a235167c3d313a5ea4fc..dff49a4bd18965635170c304f1a0bbd9ec2fbafa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,31 +1,29 @@ --extra-index-url https://gitlab.ub.uni-bielefeld.de/api/v4/projects/5493/packages/pypi/simple -clowmdb>=2.3.0,<2.4.0 +clowmdb>=3.0.0,<3.1.0 # Webserver packages -anyio>=3.7.0,<4.0.0 -fastapi>=0.104.0,<0.105.0 +fastapi>=0.109.0,<0.110.0 pydantic>=2.5.0,<2.6.0 pydantic-settings>=2.1.0,<2.2.0 -uvicorn>=0.24.0,<0.25.0 python-multipart # Database packages PyMySQL>=1.1.0,<1.2.0 SQLAlchemy>=2.0.0,<2.1.0 aiomysql>=0.2.0,<0.3.0 # Security packages -authlib>=1.2.0,<1.3.0 +authlib>=1.3.0,<1.4.0 # Ceph and S3 packages -boto3>=1.33.0,<1.34.0 +boto3>=1.34.0,<1.35.0 # Miscellaneous tenacity>=8.2.0,<8.3.0 -httpx>=0.25.0,<0.26.0 +httpx>=0.26.0,<0.27.0 itsdangerous jsonschema>=4.0.0,<5.0.0 # template engine mako>=1.3.0,<1.4.0 python-dotenv # Image processing -Pillow>=10.1.0,<10.2.0 +Pillow>=10.2.0,<10.3.0 # Compression with br algorithm brotli-asgi>=1.4.0,<1.5.0 diff --git a/start_service.sh b/start_service.sh deleted file mode 100755 index 2fa8c34e1fec829ac3c9dc245c8ab9836287dd65..0000000000000000000000000000000000000000 --- a/start_service.sh +++ /dev/null @@ -1,6 +0,0 @@ -#! /usr/bin/env bash - -./scripts/prestart.sh - -# Start webserver -uvicorn app.main:app --host 0.0.0.0 --port 8000 --no-server-header diff --git a/start_service_gunicorn.sh b/start_service_gunicorn.sh new file mode 100755 index 0000000000000000000000000000000000000000..1b65c1164b89af0c3f9da44d8135959f614d0c9d --- /dev/null +++ b/start_service_gunicorn.sh @@ -0,0 +1,7 @@ +#! /usr/bin/env sh +set -e + +./prestart.sh + +# Start Gunicorn +exec gunicorn -k uvicorn.workers.UvicornWorker -c /app/gunicorn_conf.py app.main:app diff --git a/start_service_uvicorn.sh b/start_service_uvicorn.sh new file mode 100755 index 0000000000000000000000000000000000000000..392596854a58728ac008b79d7f5d5178d1c90c2b --- /dev/null +++ b/start_service_uvicorn.sh @@ -0,0 +1,7 @@ +#! /usr/bin/env bash +set -e + +./prestart.sh + +# Start webserver +uvicorn app.main:app --host 0.0.0.0 --port "$PORT" --no-server-header
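The tests in this patch hand their database teardown to the new cleanup fixture instead of deleting rows by hand at the end of the test body. A minimal sketch of the pattern, condensed from the patched test_create_workflow (the test name below is illustrative; all other names come from this patch):

import pytest
from sqlalchemy.ext.asyncio import AsyncSession

from app.crud import CRUDWorkflow
from app.schemas.workflow import WorkflowIn
from app.tests.utils.cleanup import CleanupList, delete_workflow
from app.tests.utils.user import UserWithAuthHeader
from app.tests.utils.utils import random_hex_string, random_lower_string


@pytest.mark.asyncio
async def test_workflow_cleanup_sketch(db: AsyncSession, random_user: UserWithAuthHeader, cleanup: CleanupList) -> None:
    workflow_in = WorkflowIn(
        git_commit_hash=random_hex_string(),
        name=random_lower_string(10),
        short_description=random_lower_string(65),
        repository_url="https://github.com/example/example",
    )
    workflow = await CRUDWorkflow.create(db, workflow=workflow_in, developer_id=random_user.user.uid)
    # Register the teardown right after creation: the cleanup fixture awaits
    # empty_queue() once the test is over, so the row is deleted even when a
    # later assertion fails.
    cleanup.add_task(delete_workflow, db=db, workflow_id=workflow.workflow_id)
    assert workflow is not None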
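The reworked mocks drop the per-request raise_error/raise_opa_error keyword arguments in favour of a send_error flag stored on the mock instance, with reset() returning the mock to the happy path between tests. A small self-contained demonstration using the DefaultMockHTTPService from this patch:

from httpx import Request

from app.tests.mocks import DefaultMockHTTPService

service = DefaultMockHTTPService()
request = Request("GET", "https://example.com/health")

assert service.handle_request(request).status_code == 200  # happy path by default

service.send_error = True  # instance flag replaces the old raise_error kwarg
assert service.handle_request(request).status_code == 404

service.reset()  # back to send_error = False
assert service.handle_request(request).status_code == 200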
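The new gunicorn_conf.py sizes the worker pool from the CPU count whenever WEB_CONCURRENCY is not set explicitly. A worked example of that arithmetic (the host parameters here are assumed for illustration, not part of the patch):

# Assumed: an 8-core host, WORKERS_PER_CORE left at its default "1", MAX_WORKERS=4.
cores = 8
workers_per_core = 1.0
default_web_concurrency = workers_per_core * cores       # 8.0
web_concurrency = max(int(default_web_concurrency), 2)   # 8
web_concurrency = min(web_concurrency, 4)                # MAX_WORKERS caps it at 4
# Without MAX_WORKERS the min() step is skipped and 8 workers start; on a
# single-core host the max(..., 2) floor still yields 2 workers.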