Skip to content
Snippets Groups Projects
Commit 586df9f8 authored by Daniel Göbel's avatar Daniel Göbel
Browse files

Merge branch 'feature/57-add-monitoring-of-traces' into 'development'

Resolve "Add monitoring of traces based on OpenTelemetry"

Closes #57

See merge request !54
parents fb107dcc e7cc9859
No related branches found
No related tags found
2 merge requests!69Delete dev branch,!54Resolve "Add monitoring of traces based on OpenTelemetry"
Pipeline #38400 passed
Showing
with 558 additions and 271 deletions
......@@ -2,7 +2,7 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: end-of-file-fixer
- id: check-added-large-files
......@@ -21,7 +21,7 @@ repos:
files: app
args: [--check]
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.0.291'
rev: 'v0.0.292'
hooks:
- id: ruff
- repo: https://github.com/PyCQA/isort
......
FROM python:3.11-slim
EXPOSE 8000
ENV OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS="set-cookie"
# dumb-init forwards the kill signal to the python process
RUN apt-get update && apt-get -y install dumb-init curl
......
FROM tiangolo/uvicorn-gunicorn-fastapi:python3.11-slim
EXPOSE 8000
ENV PORT=8000
ENV OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS="set-cookie"
RUN pip install --no-cache-dir httpx[cli]
......
......@@ -42,6 +42,7 @@ This is the Workflow service of the CloWM service.
| `DEV_SYSTEM` | `False` | `<"True"&#x7c;"False">` | Activates an endpoint that allows execution of an workflow from an arbitrary Git Repository.<br>HAS TO BE `False` in PRODUCTION! |
| `OPA_POLICY_PATH` | `/clowm/authz/allow` | URL path | Path to the OPA Policy for Authorization |
| `SLURM_JOB_STATUS_CHECK_INTERVAL` | 30 | integer (seconds) | Interval for checking the slurm jobs status after starting a workflow execution |
| `OTLP_GRPC_ENDPOINT` | unset | <hostname / IP> | OTLP compatible endpoint to send traces via gRPC, e.g. Jaeger |
### Nextflow Variables
......
......@@ -8,12 +8,14 @@ from fastapi import Depends, HTTPException, Path, status
from fastapi.security import HTTPBearer
from fastapi.security.http import HTTPAuthorizationCredentials
from httpx import AsyncClient
from opentelemetry import trace
from sqlalchemy.ext.asyncio import AsyncSession
from app.ceph.rgw import s3_resource
from app.core.config import settings
from app.core.security import decode_token, request_authorization
from app.crud import CRUDUser, CRUDWorkflow, CRUDWorkflowExecution, CRUDWorkflowVersion
from app.otlp import start_as_current_span_async
from app.schemas.security import JWT, AuthzRequest, AuthzResponse
from app.slurm.slurm_rest_client import SlurmClient
......@@ -29,6 +31,9 @@ def get_s3_resource() -> S3ServiceResource: # pragma: no cover
return s3_resource
tracer = trace.get_tracer_provider().get_tracer(__name__)
S3Service = Annotated[S3ServiceResource, Depends(get_s3_resource)]
......@@ -87,6 +92,7 @@ def get_decode_jwt_function() -> Callable[[str], Dict[str, str]]: # pragma: no
return decode_token
@start_as_current_span_async("decode_jwt", tracer=tracer)
async def decode_bearer_token(
token: HTTPAuthorizationCredentials = Depends(bearer_token),
decode: Callable[[str], Dict[str, str]] = Depends(get_decode_jwt_function),
......@@ -157,8 +163,10 @@ class AuthorizationDependency:
"""
async def authorization_wrapper(operation: str) -> AuthzResponse:
params = AuthzRequest(operation=operation, resource=self.resource, uid=token.sub)
return await request_authorization(request_params=params, client=client)
with tracer.start_as_current_span("authorization") as span:
span.set_attributes({"resource": self.resource, "operation": operation})
params = AuthzRequest(operation=operation, resource=self.resource, uid=token.sub)
return await request_authorization(request_params=params, client=client)
return authorization_wrapper
......
......@@ -3,6 +3,7 @@ from uuid import UUID
from clowmdb.models import Workflow, WorkflowMode, WorkflowVersion
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, Response, status
from opentelemetry import trace
from app.api.dependencies import AuthorizationDependency, CurrentUser, CurrentWorkflow, DBSession, HTTPClient, S3Service
from app.api.utils import check_repo, upload_scm_file
......@@ -10,6 +11,7 @@ from app.core.config import settings
from app.crud import CRUDWorkflow, CRUDWorkflowVersion
from app.crud.crud_workflow_mode import CRUDWorkflowMode
from app.git_repository import GitHubRepository, build_repository
from app.otlp import start_as_current_span_async
from app.schemas.workflow import WorkflowIn, WorkflowOut, WorkflowStatistic, WorkflowUpdate
from app.schemas.workflow_version import WorkflowVersion as WorkflowVersionSchema
from app.scm import SCM, Provider
......@@ -19,8 +21,11 @@ workflow_authorization = AuthorizationDependency(resource="workflow")
Authorization = Annotated[Callable[[str], Awaitable[Any]], Depends(workflow_authorization)]
tracer = trace.get_tracer_provider().get_tracer(__name__)
@router.get("", status_code=status.HTTP_200_OK, summary="List workflows")
@start_as_current_span_async("api_workflow_list", tracer=tracer)
async def list_workflows(
db: DBSession,
authorization: Authorization,
......@@ -65,6 +70,13 @@ async def list_workflows(
workflows : List[app.schemas.workflow.WorkflowOut]
Workflows in the system
"""
current_span = trace.get_current_span()
if developer_id is not None:
current_span.set_attribute("developer_id", developer_id)
if name_substring is not None:
current_span.set_attribute("name_substring", name_substring)
if version_status is not None and len(version_status) > 0:
current_span.set_attribute("version_status", [stat.name for stat in version_status])
rbac_operation = "list"
if developer_id is not None and current_user.uid != developer_id:
rbac_operation = "list_filter"
......@@ -84,6 +96,7 @@ async def list_workflows(
@router.post("", status_code=status.HTTP_201_CREATED, summary="Create a new workflow")
@start_as_current_span_async("api_workflow_create", tracer=tracer)
async def create_workflow(
background_tasks: BackgroundTasks,
db: DBSession,
......@@ -172,10 +185,12 @@ async def create_workflow(
obj=s3.Bucket(name=settings.WORKFLOW_BUCKET).Object(key=f"{workflow.git_commit_hash}.json"),
client=client,
)
trace.get_current_span().set_attribute("workflow_id", str(workflow_db.workflow_id))
return WorkflowOut.from_db_workflow(await CRUDWorkflow.get(db, workflow_db.workflow_id))
@router.get("/{wid}", status_code=status.HTTP_200_OK, summary="Get a workflow")
@start_as_current_span_async("api_workflow_get", tracer=tracer)
async def get_workflow(
workflow: CurrentWorkflow,
db: DBSession,
......@@ -208,6 +223,10 @@ async def get_workflow(
workflow : app.schemas.workflow.WorkflowOut
Workflow with existing ID
"""
current_span = trace.get_current_span()
current_span.set_attribute("workflow_id", str(workflow.workflow_id))
if version_status is not None and len(version_status) > 0:
current_span.set_attribute("version_status", [stat.name for stat in version_status])
rbac_operation = "read_any" if workflow.developer_id != current_user.uid and version_status is not None else "read"
await authorization(rbac_operation)
version_stat = (
......@@ -220,6 +239,7 @@ async def get_workflow(
@router.get("/{wid}/statistics", status_code=status.HTTP_200_OK, summary="Get statistics for a workflow")
@start_as_current_span_async("api_workflow_get_statistics", tracer=tracer)
async def get_workflow_statistics(
workflow: CurrentWorkflow, db: DBSession, authorization: Authorization, response: Response
) -> List[WorkflowStatistic]:
......@@ -241,6 +261,7 @@ async def get_workflow_statistics(
-------
statistics : List[app.schema.Workflow.WorkflowStatistic]
"""
trace.get_current_span().set_attribute("workflow_id", str(workflow.workflow_id))
await authorization("read")
# Instruct client to cache response for 1 hour
response.headers["Cache-Control"] = "max-age=3600"
......@@ -248,6 +269,7 @@ async def get_workflow_statistics(
@router.delete("/{wid}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete a workflow")
@start_as_current_span_async("api_workflow_delete", tracer=tracer)
async def delete_workflow(
background_tasks: BackgroundTasks,
workflow: CurrentWorkflow,
......@@ -275,6 +297,7 @@ async def delete_workflow(
current_user : clowmdb.models.User
Current user. Dependency Injection.
"""
trace.get_current_span().set_attribute("workflow_id", str(workflow.workflow_id))
rbac_operation = "delete" if workflow.developer_id == current_user.uid else "delete_any"
await authorization(rbac_operation)
versions = await CRUDWorkflowVersion.list(db, workflow.workflow_id)
......@@ -308,6 +331,7 @@ async def delete_workflow(
@router.post("/{wid}/update", status_code=status.HTTP_201_CREATED, summary="Update a workflow")
@start_as_current_span_async("api_workflow_update", tracer=tracer)
async def update_workflow(
background_tasks: BackgroundTasks,
workflow: CurrentWorkflow,
......@@ -345,6 +369,7 @@ async def update_workflow(
version : app.schemas.workflow_version.WorkflowVersion
The new workflow version
"""
trace.get_current_span().set_attribute("workflow_id", str(workflow.workflow_id))
await authorization("update")
if current_user.uid != workflow.developer_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the developer can update his workflow")
......
from typing import Annotated, Any, Awaitable, Callable
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status
from opentelemetry import trace
from app.api.dependencies import AuthorizationDependency, CurrentUser, CurrentWorkflow, DBSession, HTTPClient, S3Service
from app.api.utils import check_repo, upload_scm_file
......@@ -8,6 +9,7 @@ from app.core.config import settings
from app.crud.crud_workflow import CRUDWorkflow
from app.crud.crud_workflow_version import CRUDWorkflowVersion
from app.git_repository import GitHubRepository, build_repository
from app.otlp import start_as_current_span_async
from app.schemas.workflow import WorkflowCredentialsIn, WorkflowCredentialsOut
from app.scm import SCM, Provider
......@@ -16,8 +18,11 @@ workflow_authorization = AuthorizationDependency(resource="workflow")
Authorization = Annotated[Callable[[str], Awaitable[Any]], Depends(workflow_authorization)]
tracer = trace.get_tracer_provider().get_tracer(__name__)
@router.get("", status_code=status.HTTP_200_OK, summary="Get the credentials of a workflow")
@start_as_current_span_async("api_workflow_credentials_get", tracer=tracer)
async def get_workflow_credentials(
workflow: CurrentWorkflow, current_user: CurrentUser, authorization: Authorization
) -> WorkflowCredentialsOut:
......@@ -39,6 +44,7 @@ async def get_workflow_credentials(
workflow : app.schemas.workflow.WorkflowOut
Workflow with existing ID
"""
trace.get_current_span().set_attribute("workflow_id", str(workflow.workflow_id))
await authorization("update")
if current_user.uid != workflow.developer_id:
raise HTTPException(
......@@ -48,6 +54,7 @@ async def get_workflow_credentials(
@router.put("", status_code=status.HTTP_200_OK, summary="Update the credentials of a workflow")
@start_as_current_span_async("api_workflow_credentials_update", tracer=tracer)
async def update_workflow_credentials(
credentials: WorkflowCredentialsIn,
workflow: CurrentWorkflow,
......@@ -86,6 +93,7 @@ async def update_workflow_credentials(
workflow : app.schemas.workflow.WorkflowOut
Workflow with existing ID
"""
trace.get_current_span().set_attribute("workflow_id", str(workflow.workflow_id))
await authorization("update")
if current_user.uid != workflow.developer_id:
raise HTTPException(
......@@ -109,6 +117,7 @@ async def update_workflow_credentials(
@router.delete("", status_code=status.HTTP_204_NO_CONTENT, summary="Delete the credentials of a workflow")
@start_as_current_span_async("api_workflow_credentials_delete", tracer=tracer)
async def delete_workflow_credentials(
background_tasks: BackgroundTasks,
workflow: CurrentWorkflow,
......@@ -141,11 +150,13 @@ async def delete_workflow_credentials(
workflow : app.schemas.workflow.WorkflowOut
Workflow with existing ID
"""
trace.get_current_span().set_attribute("workflow_id", str(workflow.workflow_id))
rbac_operation = "delete" if workflow.developer_id == current_user.uid else "delete_any"
await authorization(rbac_operation)
repo = build_repository(workflow.repository_url, workflow.versions[0].git_commit_hash)
if isinstance(repo, GitHubRepository):
s3.Bucket(settings.PARAMS_BUCKET).Object(f"{workflow.workflow_id.hex}.scm").delete()
with tracer.start_as_current_span("s3_delete_workflow_execution_parameters"):
s3.Bucket(settings.PARAMS_BUCKET).Object(f"{workflow.workflow_id.hex}.scm").delete()
else:
scm_provider = Provider.from_repo(repo=repo, name=f"repo{workflow.workflow_id.hex}")
background_tasks.add_task(
......
......@@ -5,6 +5,7 @@ from typing import Annotated, Any, Awaitable, Callable, Dict, List, Optional
import jsonschema
from clowmdb.models import WorkflowExecution, WorkflowVersion
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, status
from opentelemetry import trace
from app.api.dependencies import (
AuthorizationDependency,
......@@ -25,6 +26,7 @@ from app.api.utils import (
from app.core.config import settings
from app.crud import CRUDWorkflowExecution, CRUDWorkflowVersion
from app.git_repository import GitHubRepository, build_repository
from app.otlp import start_as_current_span_async
from app.schemas.workflow_execution import DevWorkflowExecutionIn, WorkflowExecutionIn, WorkflowExecutionOut
from app.scm import SCM, Provider
from app.slurm.slurm_rest_client import SlurmClient
......@@ -35,8 +37,11 @@ workflow_authorization = AuthorizationDependency(resource="workflow_execution")
Authorization = Annotated[Callable[[str], Awaitable[Any]], Depends(workflow_authorization)]
CurrentWorkflowExecution = Annotated[WorkflowExecution, Depends(get_current_workflow_execution)]
tracer = trace.get_tracer_provider().get_tracer(__name__)
@router.post("", status_code=status.HTTP_201_CREATED, summary="Start a new workflow execution")
@start_as_current_span_async("api_workflow_execution_start", tracer=tracer)
async def start_workflow(
background_tasks: BackgroundTasks,
workflow_execution_in: WorkflowExecutionIn,
......@@ -73,6 +78,8 @@ async def start_workflow(
execution : clowmdb.models.WorkflowExecution
Created workflow execution from the database
"""
current_span = trace.get_current_span()
current_span.set_attribute("git_commit_hash", workflow_execution_in.workflow_version_id)
# Check if Workflow version exists
workflow_version = await CRUDWorkflowVersion.get(
db, workflow_execution_in.workflow_version_id, populate_workflow=True
......@@ -82,6 +89,7 @@ async def start_workflow(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Workflow version with git commit hash {workflow_execution_in.workflow_version_id} not found",
)
current_span.set_attribute("workflow_id", str(workflow_version.workflow_id))
# Check authorization
rbac_operation = "start" if workflow_version.status == WorkflowVersion.Status.PUBLISHED else "start_unpublished"
await authorization(rbac_operation)
......@@ -95,6 +103,7 @@ async def start_workflow(
# If a workflow mode is specified, check that the mode is associated with the workflow version
workflow_mode = None
if workflow_execution_in.mode is not None:
current_span.set_attribute("workflow_mode_id", str(workflow_execution_in.mode))
workflow_mode = next(
(mode for mode in workflow_version.workflow_modes if mode.mode_id == workflow_execution_in.mode), None
)
......@@ -120,7 +129,8 @@ async def start_workflow(
else f"{workflow_execution_in.workflow_version_id}-{workflow_execution_in.mode.hex}.json"
)
with SpooledTemporaryFile(max_size=512000) as f:
s3.Bucket(settings.WORKFLOW_BUCKET).Object(schema_name).download_fileobj(f)
with tracer.start_as_current_span("s3_download_workflow_parameter_schema"):
s3.Bucket(settings.WORKFLOW_BUCKET).Object(schema_name).download_fileobj(f)
f.seek(0)
nextflow_schema = json.load(f)
try:
......@@ -156,6 +166,7 @@ async def start_workflow(
workflow_entrypoint=workflow_mode.entrypoint if workflow_mode is not None else None,
)
current_span.set_attribute("execution_id", str(execution.execution_id))
return WorkflowExecutionOut.from_db_model(execution, workflow_id=workflow_version.workflow_id)
......@@ -165,6 +176,7 @@ async def start_workflow(
summary="Start a workflow execution with arbitrary git repository",
include_in_schema=settings.DEV_SYSTEM,
)
@start_as_current_span_async("api_workflow_execution_start_arbitrary", tracer=tracer)
async def start_arbitrary_workflow(
background_tasks: BackgroundTasks,
workflow_execution_in: DevWorkflowExecutionIn,
......@@ -207,6 +219,18 @@ async def start_arbitrary_workflow(
"""
if not settings.DEV_SYSTEM: # pragma: no cover
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not available")
current_span = trace.get_current_span()
current_span.set_attribute("repository_url", str(workflow_execution_in.repository_url))
if workflow_execution_in.token is not None:
current_span.set_attribute("private_repository", True)
if workflow_execution_in.mode is not None:
current_span.set_attributes(
{
"workflow_entrypoint": workflow_execution_in.mode.entrypoint,
"workflow_schema_path": workflow_execution_in.mode.schema_path,
}
)
await authorization("create")
await check_active_workflow_execution_limit(db, current_user.uid)
......@@ -276,11 +300,12 @@ async def start_arbitrary_workflow(
scm_file_id=execution.execution_id.hex,
workflow_entrypoint=workflow_execution_in.mode.entrypoint if workflow_execution_in.mode is not None else None,
)
current_span.set_attribute("execution_id", str(execution.execution_id))
return WorkflowExecutionOut.from_db_model(execution)
@router.get("", status_code=status.HTTP_200_OK, summary="Get all workflow executions")
@start_as_current_span_async("api_workflow_execution_list", tracer=tracer)
async def list_workflow_executions(
db: DBSession,
current_user: CurrentUser,
......@@ -325,6 +350,14 @@ async def list_workflow_executions(
executions : List[clowmdb.models.WorkflowExecution]
List of filtered workflow executions.
"""
current_span = trace.get_current_span()
if user_id is not None:
current_span.set_attribute("user_id", user_id)
if execution_status is not None and len(execution_status) > 0:
current_span.set_attribute("execution_status", [stat.name for stat in execution_status])
if workflow_version_id is not None:
current_span.set_attribute("git_commit_hash", workflow_version_id)
rbac_operation = "list" if user_id is not None and user_id == current_user.uid else "list_all"
await authorization(rbac_operation)
executions = await CRUDWorkflowExecution.list(
......@@ -339,6 +372,7 @@ async def list_workflow_executions(
@router.get("/{eid}", status_code=status.HTTP_200_OK, summary="Get a workflow execution")
@start_as_current_span_async("api_workflow_execution_get", tracer=tracer)
async def get_workflow_execution(
workflow_execution: CurrentWorkflowExecution,
current_user: CurrentUser,
......@@ -363,6 +397,7 @@ async def get_workflow_execution(
execution : clowmdb.models.WorkflowExecution
Workflow execution with the given id.
"""
trace.get_current_span().set_attribute("execution_id", str(workflow_execution.execution_id))
rbac_operation = "read" if workflow_execution.user_id == current_user.uid else "read_any"
await authorization(rbac_operation)
return WorkflowExecutionOut.from_db_model(
......@@ -372,6 +407,7 @@ async def get_workflow_execution(
@router.get("/{eid}/params", status_code=status.HTTP_200_OK, summary="Get the parameters of a workflow execution")
@start_as_current_span_async("api_workflow_execution_params_get", tracer=tracer)
async def get_workflow_execution_params(
workflow_execution: CurrentWorkflowExecution,
current_user: CurrentUser,
......@@ -399,6 +435,7 @@ async def get_workflow_execution_params(
execution : clowmdb.models.WorkflowExecution
Workflow execution with the given id.
"""
trace.get_current_span().set_attribute("execution_id", str(workflow_execution.execution_id))
rbac_operation = "read" if workflow_execution.user_id == current_user.uid else "read_any"
await authorization(rbac_operation)
params_file_name = f"params-{workflow_execution.execution_id.hex}.json"
......@@ -409,6 +446,7 @@ async def get_workflow_execution_params(
@router.delete("/{eid}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete a workflow execution")
@start_as_current_span_async("api_workflow_execution_delete", tracer=tracer)
async def delete_workflow_execution(
background_tasks: BackgroundTasks,
db: DBSession,
......@@ -437,6 +475,7 @@ async def delete_workflow_execution(
s3 : boto3_type_annotations.s3.ServiceResource
S3 Service to perform operations on buckets in Ceph. Dependency Injection.
"""
trace.get_current_span().set_attribute("execution_id", str(workflow_execution.execution_id))
rbac_operation = "delete" if workflow_execution.user_id == current_user.uid else "delete_any"
await authorization(rbac_operation)
if workflow_execution.status in [
......@@ -454,6 +493,7 @@ async def delete_workflow_execution(
@router.post("/{eid}/cancel", status_code=status.HTTP_204_NO_CONTENT, summary="Cancel a workflow execution")
@start_as_current_span_async("api_workflow_execution_cancel", tracer=tracer)
async def cancel_workflow_execution(
background_tasks: BackgroundTasks,
db: DBSession,
......@@ -482,6 +522,7 @@ async def cancel_workflow_execution(
slurm_client : app.slurm.slurm_rest_client.SlurmClient
Slurm Rest Client to communicate with Slurm cluster. Dependency Injection.
"""
trace.get_current_span().set_attribute("execution_id", str(workflow_execution.execution_id))
rbac_operation = "cancel" if workflow_execution.user_id == current_user.uid else "cancel_any"
await authorization(rbac_operation)
if workflow_execution.status not in [
......
......@@ -2,9 +2,11 @@ from typing import Annotated, Any, Awaitable, Callable
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Path, status
from opentelemetry import trace
from app.api.dependencies import AuthorizationDependency, DBSession
from app.crud.crud_workflow_mode import CRUDWorkflowMode
from app.otlp import start_as_current_span_async
from app.schemas.workflow_mode import WorkflowModeOut
router = APIRouter(prefix="/workflow_modes", tags=["Workflow Mode"])
......@@ -12,8 +14,11 @@ workflow_authorization = AuthorizationDependency(resource="workflow")
Authorization = Annotated[Callable[[str], Awaitable[Any]], Depends(workflow_authorization)]
tracer = trace.get_tracer_provider().get_tracer(__name__)
@router.get("/{mode_id}", status_code=status.HTTP_200_OK, summary="List workflows")
@router.get("/{mode_id}", status_code=status.HTTP_200_OK, summary="Get workflow mode")
@start_as_current_span_async("api_workflow_mode_get", tracer=tracer)
async def get_workflow_mode(
db: DBSession,
authorization: Authorization,
......@@ -39,6 +44,7 @@ async def get_workflow_mode(
-------
mode : app.schemas.workflow_mode.WorkflowModeOut
"""
trace.get_current_span().set_attribute("workflow_mode_id", str(mode_id))
await authorization("read")
mode = await CRUDWorkflowMode.get(db=db, mode_id=mode_id)
if mode is None:
......
......@@ -5,6 +5,7 @@ from uuid import UUID
from clowmdb.models import WorkflowVersion
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, Path, Query, UploadFile, status
from fastapi.responses import StreamingResponse
from opentelemetry import trace
from app.api.dependencies import (
AuthorizationDependency,
......@@ -19,6 +20,7 @@ from app.api.utils import delete_remote_icon, upload_icon
from app.core.config import settings
from app.crud import CRUDWorkflowVersion
from app.git_repository import build_repository
from app.otlp import start_as_current_span_async
from app.schemas.workflow_version import IconUpdateOut
from app.schemas.workflow_version import WorkflowVersion as WorkflowVersionSchema
from app.schemas.workflow_version import WorkflowVersionStatus
......@@ -28,6 +30,8 @@ workflow_authorization = AuthorizationDependency(resource="workflow")
Authorization = Annotated[Callable[[str], Awaitable[Any]], Depends(workflow_authorization)]
tracer = trace.get_tracer_provider().get_tracer(__name__)
@unique
class DocumentationEnum(Enum):
......@@ -50,6 +54,7 @@ class DocumentationEnum(Enum):
@router.get("", status_code=status.HTTP_200_OK, summary="Get all versions of a workflow")
@start_as_current_span_async("api_workflow_version_list", tracer=tracer)
async def list_workflow_version(
current_user: CurrentUser,
workflow: CurrentWorkflow,
......@@ -82,6 +87,10 @@ async def list_workflow_version(
versions : [app.schemas.workflow_version.WorkflowVersion]
All versions of the workflow
"""
current_span = trace.get_current_span()
current_span.set_attribute("workflow_id", str(workflow.workflow_id))
if version_status is not None and len(version_status) > 0:
current_span.set_attribute("version_status", [stat.name for stat in version_status])
rbac_operation = (
"list_filter" if workflow.developer_id != current_user.uid and version_status is not None else "list"
)
......@@ -102,6 +111,7 @@ async def list_workflow_version(
status_code=status.HTTP_200_OK,
summary="Get a workflow version",
)
@start_as_current_span_async("api_workflow_version_get", tracer=tracer)
async def get_workflow_version(
workflow: CurrentWorkflow,
db: DBSession,
......@@ -137,6 +147,9 @@ async def get_workflow_version(
version : app.schemas.workflow_version.WorkflowVersion
The specified WorkflowVersion
"""
trace.get_current_span().set_attributes(
{"workflow_id": str(workflow.workflow_id), "git_commit_hash": git_commit_hash}
)
rbac_operation = "read"
version = (
await CRUDWorkflowVersion.get_latest(db, workflow.workflow_id)
......@@ -160,6 +173,7 @@ async def get_workflow_version(
@router.patch("/{git_commit_hash}/status", status_code=status.HTTP_200_OK, summary="Update status of workflow version")
@start_as_current_span_async("api_workflow_version_status_update", tracer=tracer)
async def update_workflow_version_status(
version_status: WorkflowVersionStatus,
workflow_version: CurrentWorkflowVersion,
......@@ -186,6 +200,13 @@ async def update_workflow_version_status(
version : clowmdb.models.WorkflowVersion
Version of the workflow with updated status
"""
trace.get_current_span().set_attributes(
{
"workflow_id": str(workflow_version.workflow_id),
"git_commit_hash": workflow_version.git_commit_hash,
"version_status": version_status.status.name,
}
)
await authorization("update_status")
await CRUDWorkflowVersion.update_status(db, workflow_version.git_commit_hash, version_status.status)
workflow_version.status = version_status.status
......@@ -193,6 +214,7 @@ async def update_workflow_version_status(
@router.post("/{git_commit_hash}/deprecate", status_code=status.HTTP_200_OK, summary="Deprecate a workflow version")
@start_as_current_span_async("api_workflow_version_status_update", tracer=tracer)
async def deprecate_workflow_version(
workflow: CurrentWorkflow,
workflow_version: CurrentWorkflowVersion,
......@@ -223,6 +245,9 @@ async def deprecate_workflow_version(
version : clowmdb.models.WorkflowVersion
Version of the workflow with deprecated status
"""
trace.get_current_span().set_attributes(
{"workflow_id": str(workflow_version.workflow_id), "git_commit_hash": workflow_version.git_commit_hash}
)
await authorization("update_status" if current_user.uid != workflow.developer_id else "update")
await CRUDWorkflowVersion.update_status(db, workflow_version.git_commit_hash, WorkflowVersion.Status.DEPRECATED)
workflow_version.status = WorkflowVersion.Status.DEPRECATED
......@@ -235,14 +260,14 @@ async def deprecate_workflow_version(
summary="Fetch documentation for a workflow version",
response_class=StreamingResponse,
)
@start_as_current_span_async("api_workflow_version_get_documentation", tracer=tracer)
async def download_workflow_documentation(
workflow: CurrentWorkflow,
workflow_version: CurrentWorkflowVersion,
authorization: Authorization,
client: HTTPClient,
db: DBSession,
document: DocumentationEnum = Query(
DocumentationEnum.USAGE, description="Specific which type of documentation the client wants to fetch"
DocumentationEnum.USAGE, description="Specify which type of documentation the client wants to fetch"
),
mode_id: Optional[UUID] = Query(default=None, description="Workflow Mode"),
) -> StreamingResponse:
......@@ -263,8 +288,6 @@ async def download_workflow_documentation(
HTTP Client with an open connection. Dependency Injection.
document : DocumentationEnum, default DocumentationEnum.USAGE
Which type of documentation the client wants to fetch
db : sqlalchemy.ext.asyncio.AsyncSession.
Async database session to perform query on. Dependency Injection.
mode_id : UUID | None
Select the workflow mode of the workflow version
......@@ -273,6 +296,16 @@ async def download_workflow_documentation(
response : StreamingResponse
Streams the requested document from the git repository directly to the client
"""
current_span = trace.get_current_span()
current_span.set_attributes(
{
"workflow_id": str(workflow_version.workflow_id),
"git_commit_hash": workflow_version.git_commit_hash,
"document": document.name,
}
)
if mode_id is not None:
current_span.set_attribute("workflow_mode_id", str(mode_id))
await authorization("read")
repo = build_repository(
workflow.repository_url,
......@@ -301,6 +334,7 @@ async def download_workflow_documentation(
status_code=status.HTTP_201_CREATED,
summary="Upload icon for workflow version",
)
@start_as_current_span_async("api_workflow_version_upload_icon", tracer=tracer)
async def upload_workflow_version_icon(
workflow: CurrentWorkflow,
background_tasks: BackgroundTasks,
......@@ -339,11 +373,16 @@ async def upload_workflow_version_icon(
icon_url : str
URL where the icon can be downloaded
"""
current_span = trace.get_current_span()
current_span.set_attributes(
{"workflow_id": str(workflow_version.workflow_id), "git_commit_hash": workflow_version.git_commit_hash}
)
await authorization("update")
if current_user.uid != workflow.developer_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the developer can update his workflow")
old_slug = workflow_version.icon_slug
icon_slug = upload_icon(s3=s3, background_tasks=background_tasks, icon=icon)
current_span.set_attribute("icon_slug", icon_slug)
await CRUDWorkflowVersion.update_icon(db, workflow_version.git_commit_hash, icon_slug)
# Delete old icon if possible
if old_slug is not None:
......@@ -356,6 +395,7 @@ async def upload_workflow_version_icon(
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete icon of workflow version",
)
@start_as_current_span_async("api_workflow_version_delete_icon", tracer=tracer)
async def delete_workflow_version_icon(
workflow: CurrentWorkflow,
workflow_version: CurrentWorkflowVersion,
......@@ -386,6 +426,10 @@ async def delete_workflow_version_icon(
db : sqlalchemy.ext.asyncio.AsyncSession.
Async database session to perform query on. Dependency Injection.
"""
current_span = trace.get_current_span()
current_span.set_attributes(
{"workflow_id": str(workflow_version.workflow_id), "git_commit_hash": workflow_version.git_commit_hash}
)
await authorization("update")
if current_user.uid != workflow.developer_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the developer can update his workflow")
......
......@@ -12,6 +12,7 @@ from clowmdb.models import WorkflowExecution, WorkflowMode
from fastapi import BackgroundTasks, HTTPException, UploadFile, status
from httpx import AsyncClient, ConnectError, ConnectTimeout
from mako.template import Template
from opentelemetry import trace
from PIL import Image, UnidentifiedImageError
from sqlalchemy.ext.asyncio import AsyncSession
......@@ -33,6 +34,8 @@ s3_file_regex = re.compile(
r"s3://(?!(((2(5[0-5]|[0-4]\d)|[01]?\d{1,2})\.){3}(2(5[0-5]|[0-4]\d)|[01]?\d{1,2})$))[a-z\d][a-z\d.-]{1,61}[a-z\d][^\"]*"
)
tracer = trace.get_tracer_provider().get_tracer(__name__)
def upload_icon(s3: S3ServiceResource, background_tasks: BackgroundTasks, icon: UploadFile) -> str:
"""
......@@ -80,10 +83,12 @@ def _process_and_upload_icon(s3: S3ServiceResource, icon_slug: str, icon_buffer:
thumbnail_buffer = BytesIO()
im.save(thumbnail_buffer, "PNG") # save in buffer as PNG image
thumbnail_buffer.seek(0)
# Upload to bucket
s3.Bucket(name=settings.ICON_BUCKET).Object(key=icon_slug).upload_fileobj(
Fileobj=thumbnail_buffer, ExtraArgs={"ContentType": "image/png"}
)
with tracer.start_as_current_span("s3_upload_workflow_version_icon") as span:
span.set_attribute("icon", icon_slug)
# Upload to bucket
s3.Bucket(name=settings.ICON_BUCKET).Object(key=icon_slug).upload_fileobj(
Fileobj=thumbnail_buffer, ExtraArgs={"ContentType": "image/png"}
)
async def delete_remote_icon(s3: S3ServiceResource, db: AsyncSession, icon_slug: str) -> None:
......@@ -101,7 +106,9 @@ async def delete_remote_icon(s3: S3ServiceResource, db: AsyncSession, icon_slug:
"""
# If there are no more Workflow versions that have this icon, delete it in the S3 ICON_BUCKET
if not await CRUDWorkflowVersion.icon_exists(db, icon_slug):
s3.Bucket(name=settings.ICON_BUCKET).Object(key=icon_slug).delete()
with tracer.start_as_current_span("s3_delete_workflow_version_icon") as span:
span.set_attribute("icon", icon_slug)
s3.Bucket(name=settings.ICON_BUCKET).Object(key=icon_slug).delete()
async def check_repo(
......@@ -174,7 +181,9 @@ async def start_workflow_execution(
with SpooledTemporaryFile(max_size=512000) as f:
f.write(json.dumps(parameters).encode("utf-8"))
f.seek(0)
s3.Bucket(name=settings.PARAMS_BUCKET).Object(key=params_file_name).upload_fileobj(f)
with tracer.start_as_current_span("s3_upload_workflow_execution_parameters") as span:
span.set_attribute("workflow_execution_id", str(execution.execution_id))
s3.Bucket(name=settings.PARAMS_BUCKET).Object(key=params_file_name).upload_fileobj(f)
for key in parameters.keys():
if isinstance(parameters[key], str):
# Escape string parameters for bash shell
......@@ -185,7 +194,8 @@ async def start_workflow_execution(
if scm_file_id is not None:
scm_file_name = f"{scm_file_id}.scm"
try:
s3.Bucket(settings.PARAMS_BUCKET).Object(scm_file_name).load()
with tracer.start_as_current_span("s3_check_workflow_execution_parameters"):
s3.Bucket(settings.PARAMS_BUCKET).Object(scm_file_name).load()
scm_file_id = f"repo{scm_file_id}"
except botocore.client.ClientError:
scm_file_id = None
......@@ -240,17 +250,21 @@ async def _monitor_proper_job_execution(
slurm_job_id : int
ID of the slurm job to monitor
"""
previous_span_link = None
while True:
await async_sleep(settings.SLURM_JOB_STATUS_CHECK_INTERVAL)
if await slurm_client.is_job_finished(slurm_job_id):
execution = await CRUDWorkflowExecution.get(db, execution_id=execution_id)
# Check if the execution is marked as finished in the database
if execution is not None and execution.end_time is None:
# Mark job as finished with an error
await CRUDWorkflowExecution.cancel(
db, execution_id=execution_id, status=WorkflowExecution.WorkflowExecutionStatus.ERROR
)
break
with tracer.start_span("monitor_job", links=previous_span_link) as span:
span.set_attributes({"execution_id": str(execution_id), "slurm_job_id": slurm_job_id})
if await slurm_client.is_job_finished(slurm_job_id):
execution = await CRUDWorkflowExecution.get(db, execution_id=execution_id)
# Check if the execution is marked as finished in the database
if execution is not None and execution.end_time is None:
# Mark job as finished with an error
await CRUDWorkflowExecution.cancel(
db, execution_id=execution_id, status=WorkflowExecution.WorkflowExecutionStatus.ERROR
)
break
previous_span_link = [trace.Link(span.get_span_context())]
async def check_active_workflow_execution_limit(db: AsyncSession, uid: str) -> None:
......@@ -361,4 +375,5 @@ def upload_scm_file(s3: S3ServiceResource, scm: SCM, scm_file_id: str) -> None:
with BytesIO() as handle:
scm.serialize(handle)
handle.seek(0)
s3.Bucket(settings.PARAMS_BUCKET).Object(f"{scm_file_id}.scm").upload_fileobj(handle)
with tracer.start_as_current_span("s3_upload_workflow_credentials"):
s3.Bucket(settings.PARAMS_BUCKET).Object(f"{scm_file_id}.scm").upload_fileobj(handle)
......@@ -112,6 +112,9 @@ class Settings(BaseSettings):
description="Interval for checking the slurm jobs status after starting a workflow execution in seconds",
)
DEV_SYSTEM: bool = Field(False, description="Open a endpoint where to execute arbitrary workflows.")
OTLP_GRPC_ENDPOINT: Optional[str] = Field(
None, description="OTLP compatible endpoint to send traces via gRPC, e.g. Jaeger"
)
model_config = SettingsConfigDict(case_sensitive=True, env_file=".env", secrets_dir="/run/secrets", extra="ignore")
......
from typing import Optional
from clowmdb.models import Bucket, BucketPermission
from opentelemetry import trace
from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.otlp import start_as_current_span_async
tracer = trace.get_tracer_provider().get_tracer(__name__)
class CRUDBucket:
@staticmethod
@start_as_current_span_async("db_check_bucket_exists", tracer=tracer)
async def check_bucket_exist(db: AsyncSession, bucket_name: str) -> bool:
"""
Check if the given bucket exists.
......@@ -24,10 +30,12 @@ class CRUDBucket:
Flag if the check was successful.
"""
stmt = select(Bucket).where(Bucket.name == bucket_name)
trace.get_current_span().set_attributes({"bucket_name": bucket_name, "sql_query": str(stmt)})
bucket = await db.scalar(stmt)
return bucket is not None
@staticmethod
@start_as_current_span_async("db_check_bucket_access", tracer=tracer)
async def check_access(db: AsyncSession, bucket_name: str, uid: str, key: Optional[str] = None) -> bool:
"""
Check if the given user has access to the bucket.
......@@ -48,7 +56,9 @@ class CRUDBucket:
check : bool
Flag if the check was successful.
"""
current_span = trace.get_current_span()
stmt = select(Bucket).where(Bucket.name == bucket_name).where(Bucket.owner_id == uid)
current_span.set_attributes({"bucket_name": bucket_name, "sql_query": str(stmt)})
bucket = await db.scalar(stmt)
# If the user is the owner of the bucket -> user has access
if bucket is not None:
......@@ -72,6 +82,7 @@ class CRUDBucket:
)
)
)
current_span.set_attributes({"sql_query": str(stmt)})
permission: Optional[BucketPermission] = await db.scalar(stmt)
# If the user has no active READWRITE Permission for the bucket -> user has no access
......
from typing import Optional
from clowmdb.models import User
from opentelemetry import trace
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
tracer = trace.get_tracer_provider().get_tracer(__name__)
class CRUDUser:
@staticmethod
......@@ -23,5 +26,7 @@ class CRUDUser:
user : clowmdb.models.User | None
The user for the given UID if he exists, None otherwise
"""
stmt = select(User).where(User.uid == uid)
return await db.scalar(stmt)
with tracer.start_as_current_span("db_get_user") as span:
stmt = select(User).where(User.uid == uid)
span.set_attribute("sql_query", str(stmt))
return await db.scalar(stmt)
......@@ -2,6 +2,7 @@ from typing import List, Optional, Union
from uuid import UUID
from clowmdb.models import Workflow, WorkflowExecution, WorkflowVersion
from opentelemetry import trace
from sqlalchemy import Date, cast, delete, func, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
......@@ -10,6 +11,8 @@ from app.crud.crud_workflow_mode import CRUDWorkflowMode
from app.crud.crud_workflow_version import CRUDWorkflowVersion
from app.schemas.workflow import WorkflowIn, WorkflowStatistic
tracer = trace.get_tracer_provider().get_tracer(__name__)
class CRUDWorkflow:
@staticmethod
......@@ -38,18 +41,23 @@ class CRUDWorkflow:
workflows : List[clowmdb.models.Workflow]
List of workflows.
"""
stmt = select(Workflow).options(joinedload(Workflow.versions).selectinload(WorkflowVersion.workflow_modes))
if name_substring is not None:
stmt = stmt.where(Workflow.name.contains(name_substring))
if developer_id is not None:
stmt = stmt.where(Workflow.developer_id == developer_id)
if version_status is not None:
stmt = stmt.options(
joinedload(
Workflow.versions.and_(or_(*[WorkflowVersion.status == status for status in version_status]))
with tracer.start_as_current_span("db_list_workflows") as span:
stmt = select(Workflow).options(joinedload(Workflow.versions).selectinload(WorkflowVersion.workflow_modes))
if name_substring is not None:
span.set_attribute("name_substring", name_substring)
stmt = stmt.where(Workflow.name.contains(name_substring))
if developer_id is not None:
span.set_attribute("uid", developer_id)
stmt = stmt.where(Workflow.developer_id == developer_id)
if version_status is not None and len(version_status) > 0:
span.set_attribute("status", [stat.name for stat in version_status])
stmt = stmt.options(
joinedload(
Workflow.versions.and_(or_(*[WorkflowVersion.status == status for status in version_status]))
)
)
)
return [w for w in (await db.scalars(stmt)).unique().all() if len(w.versions) > 0]
span.set_attribute("sql_query", str(stmt))
return [w for w in (await db.scalars(stmt)).unique().all() if len(w.versions) > 0]
@staticmethod
async def delete(db: AsyncSession, workflow_id: Union[UUID, bytes]) -> None:
......@@ -63,10 +71,12 @@ class CRUDWorkflow:
workflow_id : bytes | uuid.UUID
UID of a workflow
"""
wid = workflow_id.bytes if isinstance(workflow_id, UUID) else workflow_id
stmt = delete(Workflow).where(Workflow._workflow_id == wid)
await db.execute(stmt)
await db.commit()
with tracer.start_as_current_span("db_delete_workflow") as span:
wid = workflow_id.bytes if isinstance(workflow_id, UUID) else workflow_id
stmt = delete(Workflow).where(Workflow._workflow_id == wid)
span.set_attributes({"workflow_id": str(workflow_id), "sql_query": str(stmt)})
await db.execute(stmt)
await db.commit()
@staticmethod
async def update_credentials(
......@@ -84,10 +94,12 @@ class CRUDWorkflow:
token : str | None
Token to save in the database. If None, the token in the database gets deleted
"""
wid = workflow_id.bytes if isinstance(workflow_id, UUID) else workflow_id
stmt = update(Workflow).where(Workflow._workflow_id == wid).values(credentials_token=token)
await db.execute(stmt)
await db.commit()
with tracer.start_as_current_span("db_update_workflow_credentials") as span:
wid = workflow_id.bytes if isinstance(workflow_id, UUID) else workflow_id
stmt = update(Workflow).where(Workflow._workflow_id == wid).values(credentials_token=token)
span.set_attributes({"workflow_id": str(workflow_id), "sql_query": str(stmt), "delete": token is None})
await db.execute(stmt)
await db.commit()
@staticmethod
async def statistics(db: AsyncSession, workflow_id: Union[bytes, UUID]) -> List[WorkflowStatistic]:
......@@ -106,16 +118,18 @@ class CRUDWorkflow:
stat : List[app.schemas.Workflow.WorkflowStatistic]
List of datapoints
"""
wid = workflow_id.bytes if isinstance(workflow_id, UUID) else workflow_id
stmt = (
select(cast(func.FROM_UNIXTIME(WorkflowExecution.start_time), Date).label("day"), func.count())
.select_from(WorkflowExecution)
.join(WorkflowVersion)
.where(WorkflowVersion._workflow_id == wid)
.group_by("day")
.order_by("day")
)
return [WorkflowStatistic(day=row.day, count=row.count) for row in await db.execute(stmt)]
with tracer.start_as_current_span("db_get_workflow_statistics") as span:
wid = workflow_id.bytes if isinstance(workflow_id, UUID) else workflow_id
stmt = (
select(cast(func.FROM_UNIXTIME(WorkflowExecution.start_time), Date).label("day"), func.count())
.select_from(WorkflowExecution)
.join(WorkflowVersion)
.where(WorkflowVersion._workflow_id == wid)
.group_by("day")
.order_by("day")
)
span.set_attributes({"workflow_id": str(workflow_id), "sql_query": str(stmt)})
return [WorkflowStatistic(day=row.day, count=row.count) for row in await db.execute(stmt)]
@staticmethod
async def get(db: AsyncSession, workflow_id: Union[UUID, bytes]) -> Optional[Workflow]:
......@@ -134,13 +148,15 @@ class CRUDWorkflow:
user : clowmdb.models.Workflow | None
The workflow with the given ID if it exists, None otherwise
"""
wid = workflow_id.bytes if isinstance(workflow_id, UUID) else workflow_id
stmt = (
select(Workflow)
.where(Workflow._workflow_id == wid)
.options(joinedload(Workflow.versions).selectinload(WorkflowVersion.workflow_modes))
)
return await db.scalar(stmt)
with tracer.start_as_current_span("db_get_workflow") as span:
wid = workflow_id.bytes if isinstance(workflow_id, UUID) else workflow_id
stmt = (
select(Workflow)
.where(Workflow._workflow_id == wid)
.options(joinedload(Workflow.versions).selectinload(WorkflowVersion.workflow_modes))
)
span.set_attributes({"workflow_id": str(workflow_id), "sql_query": str(stmt)})
return await db.scalar(stmt)
@staticmethod
async def get_by_name(db: AsyncSession, workflow_name: str) -> Optional[Workflow]:
......@@ -159,12 +175,14 @@ class CRUDWorkflow:
user : clowmdb.models.Workflow | None
The workflow with the given name if it exists, None otherwise
"""
stmt = (
select(Workflow)
.where(Workflow.name == workflow_name)
.options(joinedload(Workflow.versions).selectinload(WorkflowVersion.workflow_modes))
)
return await db.scalar(stmt)
with tracer.start_as_current_span("db_get_workflow_by_name") as span:
stmt = (
select(Workflow)
.where(Workflow.name == workflow_name)
.options(joinedload(Workflow.versions).selectinload(WorkflowVersion.workflow_modes))
)
span.set_attributes({"name": workflow_name, "sql_query": str(stmt)})
return await db.scalar(stmt)
@staticmethod
async def create(
......@@ -192,26 +210,28 @@ class CRUDWorkflow:
workflow : clowmdb.models.Workflow
The newly created workflow
"""
workflow_db = Workflow(
name=workflow.name,
repository_url=workflow.repository_url,
short_description=workflow.short_description,
developer_id=developer,
credentials_token=workflow.token,
)
db.add(workflow_db)
await db.commit()
# If there are workflow modes, create them first
modes_db = []
if len(workflow.modes) > 0:
modes_db = await CRUDWorkflowMode.create(db, workflow.modes)
await CRUDWorkflowVersion.create(
db,
git_commit_hash=workflow.git_commit_hash,
version=workflow.initial_version,
wid=workflow_db.workflow_id,
icon_slug=icon_slug,
modes=[mode.mode_id for mode in modes_db],
)
return await CRUDWorkflow.get(db, workflow_db.workflow_id)
with tracer.start_as_current_span("db_create_workflow") as span:
workflow_db = Workflow(
name=workflow.name,
repository_url=workflow.repository_url,
short_description=workflow.short_description,
developer_id=developer,
credentials_token=workflow.token,
)
db.add(workflow_db)
await db.commit()
# If there are workflow modes, create them first
modes_db = []
if len(workflow.modes) > 0:
modes_db = await CRUDWorkflowMode.create(db, workflow.modes)
await CRUDWorkflowVersion.create(
db,
git_commit_hash=workflow.git_commit_hash,
version=workflow.initial_version,
wid=workflow_db.workflow_id,
icon_slug=icon_slug,
modes=[mode.mode_id for mode in modes_db],
)
span.set_attribute("workflow_id", workflow_db.workflow_id)
return await CRUDWorkflow.get(db, workflow_db.workflow_id)
......@@ -2,12 +2,15 @@ from typing import List, Optional, Sequence, Union
from uuid import UUID
from clowmdb.models import WorkflowExecution
from opentelemetry import trace
from sqlalchemy import delete, func, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.schemas.workflow_execution import DevWorkflowExecutionIn, WorkflowExecutionIn
tracer = trace.get_tracer_provider().get_tracer(__name__)
class CRUDWorkflowExecution:
@staticmethod
......@@ -36,22 +39,24 @@ class CRUDWorkflowExecution:
workflow_execution : clowmdb.models.WorkflowExecution
The newly created workflow execution
"""
if isinstance(execution, WorkflowExecutionIn):
workflow_execution = WorkflowExecution(
user_id=owner_id,
workflow_version_id=execution.workflow_version_id,
notes=execution.notes,
slurm_job_id=-1,
_workflow_mode_id=execution.mode.bytes if execution.mode is not None else None,
)
else:
workflow_execution = WorkflowExecution(
user_id=owner_id, workflow_version_id=None, notes=notes, slurm_job_id=-1
)
db.add(workflow_execution)
await db.commit()
await db.refresh(workflow_execution)
return workflow_execution
with tracer.start_as_current_span("db_create_workflow_execution") as span:
if isinstance(execution, WorkflowExecutionIn):
workflow_execution = WorkflowExecution(
user_id=owner_id,
workflow_version_id=execution.workflow_version_id,
notes=execution.notes,
slurm_job_id=-1,
_workflow_mode_id=execution.mode.bytes if execution.mode is not None else None,
)
else:
workflow_execution = WorkflowExecution(
user_id=owner_id, workflow_version_id=None, notes=notes, slurm_job_id=-1
)
db.add(workflow_execution)
await db.commit()
await db.refresh(workflow_execution)
span.set_attribute("workflow_execution_id", str(workflow_execution.execution_id))
return workflow_execution
@staticmethod
async def get(db: AsyncSession, execution_id: Union[bytes, UUID]) -> Optional[WorkflowExecution]:
......@@ -70,14 +75,15 @@ class CRUDWorkflowExecution:
workflow_execution : clowmdb.models.WorkflowExecution
The workflow execution with the given id if it exists, None otherwise
"""
eid = execution_id.bytes if isinstance(execution_id, UUID) else execution_id
stmt = (
select(WorkflowExecution)
.where(WorkflowExecution._execution_id == eid)
.options(joinedload(WorkflowExecution.workflow_version))
)
execution = await db.scalar(stmt)
return execution
with tracer.start_as_current_span("db_get_workflow_execution") as span:
eid = execution_id.bytes if isinstance(execution_id, UUID) else execution_id
stmt = (
select(WorkflowExecution)
.where(WorkflowExecution._execution_id == eid)
.options(joinedload(WorkflowExecution.workflow_version))
)
span.set_attributes({"workflow_execution_id": str(execution_id), "sql_query": str(stmt)})
return await db.scalar(stmt)
@staticmethod
async def list(
......@@ -105,15 +111,20 @@ class CRUDWorkflowExecution:
workflow_executions : List[clowmdb.models.WorkflowExecution]
List of all workflow executions with applied filters.
"""
stmt = select(WorkflowExecution).options(joinedload(WorkflowExecution.workflow_version))
if uid is not None:
stmt = stmt.where(WorkflowExecution.user_id == uid)
if workflow_version_id is not None:
stmt = stmt.where(WorkflowExecution.workflow_version_id == workflow_version_id)
if status_list is not None:
stmt = stmt.where(or_(*[WorkflowExecution.status == status for status in status_list]))
executions = (await db.scalars(stmt)).all()
return executions
with tracer.start_as_current_span("db_list_workflow_executions") as span:
stmt = select(WorkflowExecution).options(joinedload(WorkflowExecution.workflow_version))
if uid is not None:
span.set_attribute("uid", uid)
stmt = stmt.where(WorkflowExecution.user_id == uid)
if workflow_version_id is not None:
span.set_attribute("git_commit_hash", workflow_version_id)
stmt = stmt.where(WorkflowExecution.workflow_version_id == workflow_version_id)
if status_list is not None:
span.set_attribute("status", [stat.name for stat in status_list])
stmt = stmt.where(or_(*[WorkflowExecution.status == status for status in status_list]))
span.set_attribute("sql_query", str(stmt))
executions = (await db.scalars(stmt)).all()
return executions
@staticmethod
async def delete(db: AsyncSession, execution_id: Union[bytes, UUID]) -> None:
......@@ -127,10 +138,12 @@ class CRUDWorkflowExecution:
execution_id : uuid.UUID | bytes
ID of the workflow execution
"""
eid = execution_id.bytes if isinstance(execution_id, UUID) else execution_id
stmt = delete(WorkflowExecution).where(WorkflowExecution._execution_id == eid)
await db.execute(stmt)
await db.commit()
with tracer.start_as_current_span("db_delete_workflow_execution") as span:
eid = execution_id.bytes if isinstance(execution_id, UUID) else execution_id
stmt = delete(WorkflowExecution).where(WorkflowExecution._execution_id == eid)
span.set_attributes({"workflow_execution_id": str(execution_id), "sql_query": str(stmt)})
await db.execute(stmt)
await db.commit()
@staticmethod
async def cancel(
......@@ -150,14 +163,18 @@ class CRUDWorkflowExecution:
status : clowmdb.models.WorkflowExecution.WorkflowExecutionStatus, default WorkflowExecutionStatus.CANCELED
Error status the workflow execution should get
"""
eid = execution_id.bytes if isinstance(execution_id, UUID) else execution_id
stmt = (
update(WorkflowExecution)
.where(WorkflowExecution._execution_id == eid)
.values(status=status.name, end_time=func.UNIX_TIMESTAMP())
)
await db.execute(stmt)
await db.commit()
with tracer.start_as_current_span("db_cancel_workflow_execution") as span:
eid = execution_id.bytes if isinstance(execution_id, UUID) else execution_id
stmt = (
update(WorkflowExecution)
.where(WorkflowExecution._execution_id == eid)
.values(status=status.name, end_time=func.UNIX_TIMESTAMP())
)
span.set_attributes(
{"workflow_execution_id": str(execution_id), "status": status.name, "sql_query": str(stmt)}
)
await db.execute(stmt)
await db.commit()
@staticmethod
async def update_slurm_job_id(db: AsyncSession, execution_id: Union[bytes, UUID], slurm_job_id: int) -> None:
......@@ -173,7 +190,15 @@ class CRUDWorkflowExecution:
slurm_job_id : int
New slurm job ID
"""
eid = execution_id.bytes if isinstance(execution_id, UUID) else execution_id
stmt = update(WorkflowExecution).where(WorkflowExecution._execution_id == eid).values(slurm_job_id=slurm_job_id)
await db.execute(stmt)
await db.commit()
with tracer.start_as_current_span("db_update_workflow_execution_slurm_id") as span:
eid = execution_id.bytes if isinstance(execution_id, UUID) else execution_id
stmt = (
update(WorkflowExecution)
.where(WorkflowExecution._execution_id == eid)
.values(slurm_job_id=slurm_job_id)
)
span.set_attributes(
{"workflow_execution_id": str(execution_id), "slurm_job_id": slurm_job_id, "sql_query": str(stmt)}
)
await db.execute(stmt)
await db.commit()
......@@ -2,11 +2,14 @@ from typing import Iterable, List, Optional, Union
from uuid import UUID
from clowmdb.models import WorkflowMode, workflow_mode_association_table
from opentelemetry import trace
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.schemas.workflow_mode import WorkflowModeIn
tracer = trace.get_tracer_provider().get_tracer(__name__)
class CRUDWorkflowMode:
@staticmethod
......@@ -29,12 +32,14 @@ class CRUDWorkflowMode:
modes : List[clowmdb.models.WorkflowMode]
List of workflow modes.
"""
stmt = (
select(WorkflowMode)
.join(workflow_mode_association_table)
.where(workflow_mode_association_table.columns.workflow_version_commit_hash == workflow_version)
)
return list((await db.scalars(stmt)).all())
with tracer.start_as_current_span("db_list_workflow_modes") as span:
stmt = (
select(WorkflowMode)
.join(workflow_mode_association_table)
.where(workflow_mode_association_table.columns.workflow_version_commit_hash == workflow_version)
)
span.set_attributes({"git_commit_hash": workflow_version, "sql_query": str(stmt)})
return list((await db.scalars(stmt)).all())
@staticmethod
async def get(
......@@ -57,13 +62,17 @@ class CRUDWorkflowMode:
workflows : clowmdb.models.WorkflowMode | None
Requested workflow mode if it exists, None otherwise
"""
mid = mode_id.bytes if isinstance(mode_id, UUID) else mode_id
stmt = select(WorkflowMode).where(WorkflowMode._mode_id == mid)
if workflow_version is not None:
stmt = stmt.join(workflow_mode_association_table).where(
workflow_mode_association_table.columns.workflow_version_commit_hash == workflow_version
)
return await db.scalar(stmt)
with tracer.start_as_current_span("db_get_workflow_mode") as span:
mid = mode_id.bytes if isinstance(mode_id, UUID) else mode_id
span.set_attribute("workflow_mode_id", str(mode_id))
stmt = select(WorkflowMode).where(WorkflowMode._mode_id == mid)
if workflow_version is not None:
span.set_attribute("git_commit_hash", workflow_version)
stmt = stmt.join(workflow_mode_association_table).where(
workflow_mode_association_table.columns.workflow_version_commit_hash == workflow_version
)
span.set_attribute("sql_query", str(stmt))
return await db.scalar(stmt)
@staticmethod
async def create(db: AsyncSession, modes: List[WorkflowModeIn]) -> List[WorkflowMode]:
......@@ -82,13 +91,15 @@ class CRUDWorkflowMode:
modes : List[clowmdb.models.WorkflowMode]
Newly created workflow modes
"""
modes_db = []
for mode in modes:
mode_db = WorkflowMode(name=mode.name, entrypoint=mode.entrypoint, schema_path=mode.schema_path)
db.add(mode_db)
modes_db.append(mode_db)
await db.commit()
return modes_db
with tracer.start_as_current_span("db_create_workflow_mode") as span:
modes_db = []
for mode in modes:
mode_db = WorkflowMode(name=mode.name, entrypoint=mode.entrypoint, schema_path=mode.schema_path)
db.add(mode_db)
modes_db.append(mode_db)
await db.commit()
span.set_attribute("workflow_mode_ids", [str(m.mode_id) for m in modes_db])
return modes_db
@staticmethod
async def delete(db: AsyncSession, modes: Iterable[UUID]) -> None:
......@@ -102,6 +113,8 @@ class CRUDWorkflowMode:
modes : List[uuid.UUID]
ID of workflow modes to delete
"""
stmt = delete(WorkflowMode).where(WorkflowMode._mode_id.in_([uuid.bytes for uuid in modes]))
await db.execute(stmt)
await db.commit()
with tracer.start_as_current_span("db_delete_workflow_mode") as span:
stmt = delete(WorkflowMode).where(WorkflowMode._mode_id.in_([uuid.bytes for uuid in modes]))
span.set_attributes({"workflow_mode_ids": [str(m) for m in modes], "sql_query": str(stmt)})
await db.execute(stmt)
await db.commit()
......@@ -2,10 +2,13 @@ from typing import List, Optional, Sequence, Union
from uuid import UUID
from clowmdb.models import WorkflowVersion, workflow_mode_association_table
from opentelemetry import trace
from sqlalchemy import desc, insert, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload, selectinload
tracer = trace.get_tracer_provider().get_tracer(__name__)
class CRUDWorkflowVersion:
@staticmethod
......@@ -34,20 +37,28 @@ class CRUDWorkflowVersion:
user : clowmdb.models.WorkflowVersion | None
The workflow version with the given git_commit_hash if it exists, None otherwise
"""
stmt = (
select(WorkflowVersion)
.where(WorkflowVersion.git_commit_hash == git_commit_hash)
.options(selectinload(WorkflowVersion.workflow_modes))
)
if populate_workflow:
stmt = stmt.options(joinedload(WorkflowVersion.workflow))
if workflow_id is not None:
wid = workflow_id if isinstance(workflow_id, bytes) else workflow_id.bytes
stmt = stmt.where(WorkflowVersion._workflow_id == wid)
return await db.scalar(stmt)
with tracer.start_as_current_span("db_get_workflow_version") as span:
span.set_attribute("git_commit_hash", git_commit_hash)
stmt = (
select(WorkflowVersion)
.where(WorkflowVersion.git_commit_hash == git_commit_hash)
.options(selectinload(WorkflowVersion.workflow_modes))
)
if populate_workflow:
span.set_attribute("populate_workflow", True)
stmt = stmt.options(joinedload(WorkflowVersion.workflow))
if workflow_id is not None:
span.set_attribute("workflow_id", str(workflow_id))
wid = workflow_id if isinstance(workflow_id, bytes) else workflow_id.bytes
stmt = stmt.where(WorkflowVersion._workflow_id == wid)
span.set_attribute("sql_query", str(stmt))
return await db.scalar(stmt)
@staticmethod
async def get_latest(db: AsyncSession, wid: bytes | UUID, published: bool = True) -> Optional[WorkflowVersion]:
async def get_latest(
db: AsyncSession, wid: Union[bytes, UUID], published: bool = True
) -> Optional[WorkflowVersion]:
"""
Get the latest version of a workflow.
......@@ -65,25 +76,31 @@ class CRUDWorkflowVersion:
user : clowmdb.models.WorkflowVersion | None
The latest workflow version of the given workflow if the workflow exists, None otherwise
"""
stmt = (
select(WorkflowVersion)
.where(
WorkflowVersion._workflow_id == wid.bytes if isinstance(wid, UUID) else wid # type: ignore[arg-type]
)
.order_by(desc(WorkflowVersion.created_at))
.limit(1)
.options(selectinload(WorkflowVersion.workflow_modes))
)
if published:
stmt = stmt.where(
or_(
*[
WorkflowVersion.status == status
for status in [WorkflowVersion.Status.PUBLISHED, WorkflowVersion.Status.DEPRECATED]
]
with tracer.start_as_current_span("db_get_latest_workflow_version") as span:
span.set_attribute("workflow_id", str(wid))
stmt = (
select(WorkflowVersion)
.where(
WorkflowVersion._workflow_id == wid.bytes
if isinstance(wid, UUID)
else wid # type: ignore[arg-type]
)
.order_by(desc(WorkflowVersion.created_at))
.limit(1)
.options(selectinload(WorkflowVersion.workflow_modes))
)
return await db.scalar(stmt)
if published:
span.set_attribute("only_published", True)
stmt = stmt.where(
or_(
*[
WorkflowVersion.status == status
for status in [WorkflowVersion.Status.PUBLISHED, WorkflowVersion.Status.DEPRECATED]
]
)
)
span.set_attribute("sql_query", str(stmt))
return await db.scalar(stmt)
@staticmethod
async def list(
......@@ -106,17 +123,23 @@ class CRUDWorkflowVersion:
user : List[clowmdb.models.WorkflowVersion]
All workflow version of the given workflow
"""
stmt = (
select(WorkflowVersion)
.options(selectinload(WorkflowVersion.workflow_modes))
.where(
WorkflowVersion._workflow_id == wid.bytes if isinstance(wid, UUID) else wid # type: ignore[arg-type]
with tracer.start_as_current_span("db_list_workflow_versions") as span:
span.set_attribute("workflow_id", str(wid))
stmt = (
select(WorkflowVersion)
.options(selectinload(WorkflowVersion.workflow_modes))
.where(
WorkflowVersion._workflow_id == wid.bytes
if isinstance(wid, UUID)
else wid # type: ignore[arg-type]
)
)
)
if version_status is not None:
stmt = stmt.where(or_(*[WorkflowVersion.status == status for status in version_status]))
stmt = stmt.order_by(WorkflowVersion.created_at)
return (await db.scalars(stmt)).unique().all()
if version_status is not None and len(version_status) > 0:
span.set_attribute("version_status", [stat.name for stat in version_status])
stmt = stmt.where(or_(*[WorkflowVersion.status == status for status in version_status]))
stmt = stmt.order_by(WorkflowVersion.created_at)
span.set_attribute("sql_query", str(stmt))
return (await db.scalars(stmt)).unique().all()
@staticmethod
async def create(
......@@ -153,27 +176,30 @@ class CRUDWorkflowVersion:
workflow_version : clowmdb.models.WorkflowVersion
Newly create WorkflowVersion
"""
if modes is None:
modes = []
workflow_version = WorkflowVersion(
git_commit_hash=git_commit_hash,
version=version,
_workflow_id=wid.bytes if isinstance(wid, UUID) else wid,
icon_slug=icon_slug,
previous_version_hash=previous_version,
)
db.add(workflow_version)
if len(modes) > 0:
await db.commit()
await db.execute(
insert(workflow_mode_association_table),
[
{"workflow_version_commit_hash": git_commit_hash, "workflow_mode_id": mode_id.bytes}
for mode_id in modes
],
with tracer.start_as_current_span("db_create_workflow_version") as span:
span.set_attributes({"git_commit_version": git_commit_hash, "workflow_id": str(wid)})
if modes is None:
modes = []
workflow_version = WorkflowVersion(
git_commit_hash=git_commit_hash,
version=version,
_workflow_id=wid.bytes if isinstance(wid, UUID) else wid,
icon_slug=icon_slug,
previous_version_hash=previous_version,
)
await db.commit()
return workflow_version
db.add(workflow_version)
if len(modes) > 0:
span.set_attribute("mode_ids", [str(m) for m in modes])
await db.commit()
await db.execute(
insert(workflow_mode_association_table),
[
{"workflow_version_commit_hash": git_commit_hash, "workflow_mode_id": mode_id.bytes}
for mode_id in modes
],
)
await db.commit()
return workflow_version
@staticmethod
async def update_status(db: AsyncSession, git_commit_hash: str, status: WorkflowVersion.Status) -> None:
......@@ -189,11 +215,16 @@ class CRUDWorkflowVersion:
status : clowmdb.models.WorkflowVersion.Status
New status of the workflow version
"""
stmt = (
update(WorkflowVersion).where(WorkflowVersion.git_commit_hash == git_commit_hash).values(status=status.name)
)
await db.execute(stmt)
await db.commit()
with tracer.start_as_current_span("db_update_workflow_version_status") as span:
span.set_attributes({"status": status.name, "git_commit_version": git_commit_hash})
stmt = (
update(WorkflowVersion)
.where(WorkflowVersion.git_commit_hash == git_commit_hash)
.values(status=status.name)
)
span.set_attribute("sql_query", str(stmt))
await db.execute(stmt)
await db.commit()
@staticmethod
async def update_icon(db: AsyncSession, git_commit_hash: str, icon_slug: Optional[str] = None) -> None:
......@@ -209,13 +240,21 @@ class CRUDWorkflowVersion:
icon_slug : str | None, default None
The new icon slug
"""
stmt = (
update(WorkflowVersion)
.where(WorkflowVersion.git_commit_hash == git_commit_hash)
.values(icon_slug=icon_slug)
)
await db.execute(stmt)
await db.commit()
with tracer.start_as_current_span("db_update_workflow_version_icon") as span:
stmt = (
update(WorkflowVersion)
.where(WorkflowVersion.git_commit_hash == git_commit_hash)
.values(icon_slug=icon_slug)
)
span.set_attributes(
{
"git_commit_hash": git_commit_hash,
"icon_slug": icon_slug if icon_slug else "None",
"sql_query": str(stmt),
}
)
await db.execute(stmt)
await db.commit()
@staticmethod
async def icon_exists(db: AsyncSession, icon_slug: str) -> bool:
......@@ -234,6 +273,8 @@ class CRUDWorkflowVersion:
exists : bool
Flag if a version exists that depends on the icon
"""
stmt = select(WorkflowVersion).where(WorkflowVersion.icon_slug == icon_slug).limit(1)
version_with_icon = await db.scalar(stmt)
return version_with_icon is not None
with tracer.start_as_current_span("db_check_workflow_version_icon_exists") as span:
stmt = select(WorkflowVersion).where(WorkflowVersion.icon_slug == icon_slug).limit(1)
span.set_attributes({"icon_slug": icon_slug, "sql_query": str(stmt)})
version_with_icon = await db.scalar(stmt)
return version_with_icon is not None
......@@ -7,8 +7,11 @@ from typing import TYPE_CHECKING, AsyncIterator, Dict, List, Optional
from fastapi import HTTPException, status
from httpx import USE_CLIENT_DEFAULT, AsyncClient, Auth
from opentelemetry import trace
from pydantic import AnyHttpUrl
tracer = trace.get_tracer_provider().get_tracer(__name__)
if TYPE_CHECKING:
from mypy_boto3_s3.service_resource import Object
else:
......@@ -150,16 +153,18 @@ class GitRepository(ABC):
exist : List[bool]
Flags if the files exist.
"""
tasks = [asyncio.ensure_future(self.check_file_exists(file, client=client)) for file in files]
result = await asyncio.gather(*tasks)
if raise_error:
missing_files = [f for f, exist in zip(files, result) if not exist]
if len(missing_files) > 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The files {', '.join(missing_files)} are missing in the repo {str(self)}",
)
return result
with tracer.start_as_current_span("git_check_files_exists") as span:
span.set_attribute("repository", self.repo_url)
tasks = [asyncio.ensure_future(self.check_file_exists(file, client=client)) for file in files]
result = await asyncio.gather(*tasks)
if raise_error:
missing_files = [f for f, exist in zip(files, result) if not exist]
if len(missing_files) > 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The files {', '.join(missing_files)} are missing in the repo {str(self)}",
)
return result
async def copy_file_to_bucket(self, filepath: str, obj: Object, client: AsyncClient) -> None:
"""
......@@ -174,10 +179,12 @@ class GitRepository(ABC):
client : httpx.AsyncClient
Async HTTP Client with an open connection.
"""
with SpooledTemporaryFile(max_size=512000) as f: # temporary file with 500kB data spooled in memory
await self.download_file(filepath, client=client, file_handle=f)
f.seek(0)
obj.upload_fileobj(f)
with tracer.start_as_current_span("git_copy_file_to_bucket") as span:
span.set_attributes({"repository": self.repo_url, "file": filepath})
with SpooledTemporaryFile(max_size=512000) as f: # temporary file with 500kB data spooled in memory
await self.download_file(filepath, client=client, file_handle=f)
f.seek(0)
obj.upload_fileobj(f)
async def download_file_stream(self, filepath: str, client: AsyncClient) -> AsyncIterator[bytes]:
"""
......@@ -195,14 +202,16 @@ class GitRepository(ABC):
byte_iterator : AsyncIterator[bytes]
Async iterator over the bytes of the file
"""
async with client.stream(
method="GET",
url=str(await self.download_file_url(filepath, client)),
auth=USE_CLIENT_DEFAULT if self.request_auth is None else self.request_auth,
follow_redirects=True,
) as r:
async for chunk in r.aiter_bytes():
yield chunk
with tracer.start_as_current_span("git_stream_file_content") as span:
span.set_attributes({"repository": self.repo_url, "file": filepath})
async with client.stream(
method="GET",
url=str(await self.download_file_url(filepath, client)),
auth=USE_CLIENT_DEFAULT if self.request_auth is None else self.request_auth,
follow_redirects=True,
) as r:
async for chunk in r.aiter_bytes():
yield chunk
async def download_file(self, filepath: str, client: AsyncClient, file_handle: IOBase) -> None:
"""
......@@ -217,5 +226,7 @@ class GitRepository(ABC):
file_handle : IOBase
Write the file into this stream in binary mode.
"""
async for chunk in self.download_file_stream(filepath, client):
file_handle.write(chunk)
with tracer.start_as_current_span("git_download_file") as span:
span.set_attributes({"repository": self.repo_url, "file": filepath})
async for chunk in self.download_file_stream(filepath, client):
file_handle.write(chunk)
......@@ -4,10 +4,13 @@ from urllib.parse import quote, urlparse
from fastapi import status
from httpx import USE_CLIENT_DEFAULT, AsyncClient, BasicAuth
from opentelemetry import trace
from pydantic import AnyHttpUrl
from .abstract_repository import GitRepository
tracer = trace.get_tracer_provider().get_tracer(__name__)
class GitHubRepository(GitRepository):
"""
......@@ -66,11 +69,13 @@ class GitHubRepository(GitRepository):
path="/".join([self.account, self.repository, self.commit, filepath]),
)
# If the repo is private, request a download URL with a token from the GitHub API
response = await client.get(
str(self.check_file_url(filepath)),
auth=USE_CLIENT_DEFAULT if self.request_auth is None else self.request_auth,
headers=self.request_headers,
)
with tracer.start_as_current_span("github_get_download_link") as span:
span.set_attributes({"repository": self.repo_url, "file": filepath})
response = await client.get(
str(self.check_file_url(filepath)),
auth=USE_CLIENT_DEFAULT if self.request_auth is None else self.request_auth,
headers=self.request_headers,
)
assert response.status_code == status.HTTP_200_OK
return AnyHttpUrl(response.json()["download_url"])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment