diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a6a07950c9b673cc93736742574512507c0346ef..4a84d0e784db24d1d7aab76813e3b4dc26ff55dc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,7 +15,7 @@ repos: - id: check-merge-conflict - id: check-ast - repo: https://github.com/psf/black - rev: 23.12.1 + rev: 24.1.0 hooks: - id: black files: app diff --git a/app/api/endpoints/workflow.py b/app/api/endpoints/workflow.py index a64c96f8ead482a2486ab3cf6ee044cc9af2034c..497655c8082b430bc6a545da8ba31d5d6e35ca15 100644 --- a/app/api/endpoints/workflow.py +++ b/app/api/endpoints/workflow.py @@ -98,9 +98,11 @@ async def list_workflows( db, name_substring=name_substring, developer_id=developer_id, - version_status=[WorkflowVersion.Status.PUBLISHED, WorkflowVersion.Status.DEPRECATED] - if version_status is None - else version_status, + version_status=( + [WorkflowVersion.Status.PUBLISHED, WorkflowVersion.Status.DEPRECATED] + if version_status is None + else version_status + ), ) return [WorkflowOut.from_db_workflow(workflow, versions=workflow.versions) for workflow in workflows] diff --git a/app/api/endpoints/workflow_version.py b/app/api/endpoints/workflow_version.py index fad9b2df5b6a69f6ef0cc7cac9cea0f0ab941590..68844f7db38aa67699b4bc686c42a98a29cf148d 100644 --- a/app/api/endpoints/workflow_version.py +++ b/app/api/endpoints/workflow_version.py @@ -104,9 +104,11 @@ async def list_workflow_version( versions = await CRUDWorkflowVersion.list( db, workflow.workflow_id, - version_status=version_status - if version_status is not None - else [WorkflowVersion.Status.PUBLISHED, WorkflowVersion.Status.DEPRECATED], + version_status=( + version_status + if version_status is not None + else [WorkflowVersion.Status.PUBLISHED, WorkflowVersion.Status.DEPRECATED] + ), ) return [WorkflowVersionSchema.from_db_version(v, load_modes=True) for v in versions] diff --git a/app/api/middleware/etagmiddleware.py b/app/api/middleware/etagmiddleware.py index 24c99becf0b0703b956176bdd2bff32a39f43bc3..44cfd515f85e4a377f2afa151acb60c28403756f 100644 --- a/app/api/middleware/etagmiddleware.py +++ b/app/api/middleware/etagmiddleware.py @@ -1,8 +1,9 @@ from hashlib import md5 -from typing import Mapping, Optional +from typing import Awaitable, Callable, Mapping, Optional -from fastapi import status +from fastapi import Request, Response, status from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware class HashJSONResponse(JSONResponse): @@ -13,73 +14,16 @@ class HashJSONResponse(JSONResponse): self.headers["ETag"] = md5(self.body).hexdigest() -# from typing import Awaitable, Callable, -# from fastapi import Request, Response -# from starlette.middleware.base import BaseHTTPMiddleware -# from starlette.datastructures import Headers, MutableHeaders -# from starlette.types import ASGIApp, Message, Receive, Scope, Send -# -# class ETagMiddleware(BaseHTTPMiddleware): -# async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: -# print("Before call next", request.method, request.url) -# response = await call_next(request) -# print("after call next", request.method, request.url) -# if request.method == "GET": -# print("GET response", request.method, request.url) -# # Client can ask if the cached data is stale or not -# # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/If-None-Match -# # This saves network bandwidth but the database is still queried -# if ( -# response.headers.get("ETag") is not None -# and request.headers.get("If-None-Match") == response.headers["ETag"] -# ): -# print("Not modified response", request.method, request.url) -# return Response(status_code=status.HTTP_304_NOT_MODIFIED) -# print("Normal response", request.method, request.url) -# return response -# -# -# class ETagMiddleware: -# def __init__(self, app: ASGIApp) -> None: -# self.app = app -# -# async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: -# if scope["type"] == "http": -# headers = Headers(scope=scope) -# etag_hash = headers.get("If-None-Match", "") -# if len(etag_hash) > 0: -# responder = ETagResponser(self.app, etag_hash) -# await responder(scope, receive, send) -# return -# return await self.app(scope, receive, send) -# -# -# class ETagResponser: -# def __init__(self, app: ASGIApp, etag: str = ""): -# self.app = app -# self.initial_message: Message = {} -# self.etag = etag -# self.send: Send = unattached_send -# self.response_fresh = False -# -# async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: -# self.send = send -# await self.app(scope, receive, self.send_response) -# -# async def send_response(self, message: Message) -> None: -# message_type = message["type"] -# if message_type == "http.response.start": -# self.initial_message = message -# headers = Headers(raw=self.initial_message["headers"]) -# self.response_fresh = headers.get("ETag", "") == self.etag -# elif message_type == "http.response.body" and self.response_fresh: -# headers = MutableHeaders(raw=self.initial_message["headers"]) -# del headers["Content-Length"] -# message["status"] = "304" -# message["body"] = b"" -# await self.send(message) -# await self.send(message) -# -# -# async def unattached_send(message: Message) -> None: -# raise RuntimeError("send awaitable not set") +class ETagMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: + response = await call_next(request) + if request.method == "GET": + # Client can ask if the cached data is stale or not + # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/If-None-Match + # This saves network bandwidth but the database is still queried + if ( + response.headers.get("ETag") is not None + and request.headers.get("If-None-Match") == response.headers["ETag"] + ): + return Response(status_code=status.HTTP_304_NOT_MODIFIED) + return response diff --git a/app/crud/crud_workflow_execution.py b/app/crud/crud_workflow_execution.py index b837b196e23913336ff875c5c8526df3f46bee32..9b2da271b27ea7269bb34c28651a45c6c0f2421b 100644 --- a/app/crud/crud_workflow_execution.py +++ b/app/crud/crud_workflow_execution.py @@ -73,15 +73,21 @@ class CRUDWorkflowExecution: update(WorkflowExecution) .where(WorkflowExecution._execution_id == workflow_execution.execution_id.bytes) .values( - logs_path=None - if execution.logs_s3_path is None - else execution.logs_s3_path + f"/run-{workflow_execution.execution_id.hex}", - debug_path=None - if execution.debug_s3_path is None - else execution.debug_s3_path + f"/run-{workflow_execution.execution_id.hex}", - provenance_path=None - if execution.provenance_s3_path is None - else execution.provenance_s3_path + f"/run-{workflow_execution.execution_id.hex}", + logs_path=( + None + if execution.logs_s3_path is None + else execution.logs_s3_path + f"/run-{workflow_execution.execution_id.hex}" + ), + debug_path=( + None + if execution.debug_s3_path is None + else execution.debug_s3_path + f"/run-{workflow_execution.execution_id.hex}" + ), + provenance_path=( + None + if execution.provenance_s3_path is None + else execution.provenance_s3_path + f"/run-{workflow_execution.execution_id.hex}" + ), ) ) await db.commit() diff --git a/app/git_repository/abstract_repository.py b/app/git_repository/abstract_repository.py index 2ac6a395eede851c3390e59e785400f9c6f332aa..e2768a0b483bf0c54c45f60d3e5301eaadf91c6e 100644 --- a/app/git_repository/abstract_repository.py +++ b/app/git_repository/abstract_repository.py @@ -25,8 +25,7 @@ class GitRepository(ABC): @property @abstractmethod - def provider(self) -> str: - ... + def provider(self) -> str: ... @property def token(self) -> Optional[str]: @@ -88,17 +87,14 @@ class GitRepository(ABC): @cached_property @abstractmethod - def request_auth(self) -> Optional[Auth]: - ... + def request_auth(self) -> Optional[Auth]: ... @cached_property @abstractmethod - def request_headers(self) -> Dict[str, str]: - ... + def request_headers(self) -> Dict[str, str]: ... @abstractmethod - def __repr__(self) -> str: - ... + def __repr__(self) -> str: ... def __str__(self) -> str: return repr(self) @@ -144,8 +140,9 @@ class GitRepository(ABC): exist : List[bool] Flags if the files exist. """ - with tracer.start_as_current_span("git_check_files_exists") as span: - span.set_attributes({"repository": self.url, "files": files}) + with tracer.start_as_current_span( + "git_check_files_exists", attributes={"repository": self.url, "files": files} + ) as span: tasks = [asyncio.ensure_future(self.check_file_exists(file, client=client)) for file in files] result = await asyncio.gather(*tasks) if raise_error: @@ -171,8 +168,9 @@ class GitRepository(ABC): client : httpx.AsyncClient Async HTTP Client with an open connection. """ - with tracer.start_as_current_span("git_copy_file_to_bucket") as span: - span.set_attributes({"repository": self.url, "file": filepath}) + with tracer.start_as_current_span( + "git_copy_file_to_bucket", attributes={"repository": self.url, "file": filepath} + ): with SpooledTemporaryFile(max_size=512000) as f: # temporary file with 500kB data spooled in memory await self.download_file(filepath, client=client, file_handle=f) f.seek(0) @@ -195,8 +193,9 @@ class GitRepository(ABC): byte_iterator : AsyncIterator[bytes] Async iterator over the bytes of the file """ - with tracer.start_as_current_span("git_stream_file_content") as span: - span.set_attributes({"repository": self.url, "file": filepath}) + with tracer.start_as_current_span( + "git_stream_file_content", attributes={"repository": self.url, "file": filepath} + ): async with client.stream( method="GET", url=str(await self.download_file_url(filepath, client)), @@ -219,7 +218,6 @@ class GitRepository(ABC): file_handle : IOBase Write the file into this stream in binary mode. """ - with tracer.start_as_current_span("git_download_file") as span: - span.set_attributes({"repository": self.url, "file": filepath}) + with tracer.start_as_current_span("git_download_file", attributes={"repository": self.url, "file": filepath}): async for chunk in self.download_file_stream(filepath, client): file_handle.write(chunk) diff --git a/app/git_repository/github.py b/app/git_repository/github.py index 29497e33a89e0f8b9fd99591c7c487908cc7dc2d..31c385038ae0fd3d7c58c1075848502c6297486c 100644 --- a/app/git_repository/github.py +++ b/app/git_repository/github.py @@ -55,8 +55,9 @@ class GitHubRepository(GitRepository): path="/".join([self.account, self.repository, self.commit, filepath]), ) # If the repo is private, request a download URL with a token from the GitHub API - with tracer.start_as_current_span("github_get_download_link") as span: - span.set_attributes({"repository": self.url, "file": filepath}) + with tracer.start_as_current_span( + "github_get_download_link", attributes={"repository": self.url, "file": filepath} + ): response = await client.get( str(self.check_file_url(filepath)), auth=USE_CLIENT_DEFAULT if self.request_auth is None else self.request_auth, diff --git a/app/main.py b/app/main.py index cbbccba116633e17fc456dab7b7efdd396caa40e..7b14b6b5becb60558d75840414b32d291057ffa1 100644 --- a/app/main.py +++ b/app/main.py @@ -18,7 +18,7 @@ from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.trace import Status, StatusCode from app.api.api import api_router -from app.api.middleware.etagmiddleware import HashJSONResponse +from app.api.middleware.etagmiddleware import ETagMiddleware, HashJSONResponse from app.api.miscellaneous_endpoints import miscellaneous_router from app.core.config import settings @@ -49,7 +49,7 @@ app = FastAPI( "email": "dgoebel@techfak.uni-bielefeld.de", }, generate_unique_id_function=custom_generate_unique_id, - # license_info={"name": "MIT", "url": "https://mit-license.org/"}, + license_info={"name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0"}, root_path=settings.API_PREFIX, openapi_url=None, # create it manually to enable caching on client side default_response_class=HashJSONResponse, # Add ETag header based on MD5 hash of content @@ -87,7 +87,7 @@ FastAPIInstrumentor.instrument_app( # Enable caching based on ETag -# app.add_middleware(ETagMiddleware) +app.add_middleware(ETagMiddleware) # Enable br compression for large responses, fallback gzip app.add_middleware(BrotliMiddleware) diff --git a/app/schemas/workflow_version.py b/app/schemas/workflow_version.py index f3e1cf102d2f097be30431632603d453e37c232a..b27435e08a3f20e627447335fdd17871c50dc912 100644 --- a/app/schemas/workflow_version.py +++ b/app/schemas/workflow_version.py @@ -91,9 +91,11 @@ class WorkflowVersion(WorkflowVersionStatus): icon_url=icon_url, created_at=db_version.created_at, status=db_version.status, - modes=[mode.mode_id for mode in db_version.workflow_modes] - if load_modes and len(db_version.workflow_modes) > 0 - else mode_ids, + modes=( + [mode.mode_id for mode in db_version.workflow_modes] + if load_modes and len(db_version.workflow_modes) > 0 + else mode_ids + ), ) diff --git a/app/tests/mocks/__init__.py b/app/tests/mocks/__init__.py index 6cf3e24940bdb39ae263ef2854b68ea6a207b94f..000f2df1a911ef43a2f45fe4fa5927f021497b2a 100644 --- a/app/tests/mocks/__init__.py +++ b/app/tests/mocks/__init__.py @@ -9,8 +9,7 @@ class MockHTTPService(ABC): self.send_error = False @abstractmethod - def handle_request(self, request: Request) -> Response: - ... + def handle_request(self, request: Request) -> Response: ... def reset(self) -> None: self.send_error = False diff --git a/requirements-dev.txt b/requirements-dev.txt index de814b3dad75665f8c74a7815e137a0d70d6f6ce..df0cce16cee4d17e8a0a16e608f9fdd5f41b808a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,7 +5,7 @@ pytest-cov>=4.1.0,<4.2.0 coverage[toml]>=7.4.0,<7.5.0 # Linters ruff>=0.1.0,<0.2.0 -black>=23.12.0,<24.1.0 +black>=24.1.0,<24.2.0 isort>=5.13.0,<5.14.0 mypy>=1.8.0,<1.9.0 # stubs for mypy