Skip to content
Snippets Groups Projects
main.py 4.71 KiB
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from brotli_asgi import BrotliMiddleware
from fastapi import FastAPI, Request, Response
from fastapi.exception_handlers import http_exception_handler, request_validation_exception_handler
from fastapi.exceptions import RequestValidationError, StarletteHTTPException
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.responses import HTMLResponse
from fastapi.routing import APIRoute
from httpx import AsyncClient
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace import Status, StatusCode

from app.api.api import api_router
from app.api.middleware.etagmiddleware import HashJSONResponse
from app.api.miscellaneous_endpoints import miscellaneous_router
from app.core.config import settings
from app.s3.s3_resource import boto_session

# Free-text service description passed to FastAPI(description=...) below and
# rendered in the generated OpenAPI document / Swagger UI.
description = """
This is the resource service from the CloWM Service.
"""


def custom_generate_unique_id(route: "APIRoute") -> str:
    """Build the OpenAPI operation id for *route*.

    Used as the app's ``generate_unique_id_function``. The id has the form
    ``"<last tag>-<route name>"``. Routes registered without any tags fall
    back to the bare route name instead of failing with ``IndexError``.

    Args:
        route: The route being registered.

    Returns:
        The operation id for the route.
    """
    if not route.tags:
        # No tag to prefix with; indexing tags[-1] would raise IndexError.
        return route.name
    return f"{route.tags[-1]}-{route.name}"


@asynccontextmanager
async def lifespan(fastapi_app: FastAPI) -> AsyncGenerator[None, None]:  # pragma: no cover
    """Provide clients shared by all requests for the lifetime of the app.

    On startup, open one ``httpx.AsyncClient`` and one S3 resource and attach
    them to ``fastapi_app`` as ``requests_client`` and ``s3_resource``. Both
    context managers are exited when the application shuts down.

    Args:
        fastapi_app: The application the clients are attached to.
    """
    gateway_uri = str(settings.OBJECT_GATEWAY_URI)
    # Strip only trailing slashes; slicing off the last character unconditionally
    # would truncate a URI that does not end with "/".
    endpoint_url = gateway_uri.rstrip("/")
    # Create the clients once instead of for every request and attach them to the app
    async with AsyncClient() as client, boto_session.resource(
        service_name="s3",
        endpoint_url=endpoint_url,
        # Certificate verification only applies to https endpoints
        verify=gateway_uri.startswith("https"),
    ) as s3_resource:
        fastapi_app.requests_client = client  # type: ignore[attr-defined]
        fastapi_app.s3_resource = s3_resource  # type: ignore[attr-defined]
        yield


# The FastAPI application. openapi_url=None turns off FastAPI's built-in schema
# (and docs) routes; both are registered manually at the bottom of this file so
# the schema response can be cached by clients.
app = FastAPI(
    title="CloWM Resource Service",
    version="1.0.0",
    description=description,
    contact={
        "name": "Daniel Goebel",
        "url": "https://ekvv.uni-bielefeld.de/pers_publ/publ/PersonDetail.jsp?personId=223066601",
        "email": "dgoebel@techfak.uni-bielefeld.de",
    },
    # Operation ids are built as "<last tag>-<route name>"
    generate_unique_id_function=custom_generate_unique_id,
    # license_info={"name": "MIT", "url": "https://mit-license.org/"},
    root_path=settings.API_PREFIX,
    openapi_url=None,  # create it manually to enable caching on client side
    default_response_class=HashJSONResponse,  # Add ETag header based on MD5 hash of content
    lifespan=lifespan,  # opens and closes the shared HTTP client and S3 resource
)
# When an API prefix is configured, list root_path as the first server in the
# OpenAPI document (presumably so Swagger UI / clients send requests with the prefix).
if settings.API_PREFIX:  # pragma: no cover
    app.servers.insert(0, {"url": app.root_path})

# Export traces via OTLP/gRPC only when a non-empty endpoint is configured.
if settings.OTLP_GRPC_ENDPOINT:  # pragma: no cover
    resource = Resource(attributes={SERVICE_NAME: "clowm-resource-service"})
    provider = TracerProvider(resource=resource)
    # insecure=True: the exporter connects to the collector without TLS
    provider.add_span_processor(
        BatchSpanProcessor(OTLPSpanExporter(endpoint=settings.OTLP_GRPC_ENDPOINT, insecure=True))
    )
    trace.set_tracer_provider(provider)

    def _mark_current_span_failed(exc: Exception) -> None:
        """Set the active span's status to ERROR and record *exc* on it."""
        current_span = trace.get_current_span()
        current_span.set_status(Status(StatusCode.ERROR))
        current_span.record_exception(exc)

    # NOTE(review): every HTTPException (including 4xx client errors) marks the span
    # as ERROR; OTel HTTP semantic conventions leave server-span status unset for 4xx.
    # Confirm this is intended.
    @app.exception_handler(StarletteHTTPException)
    async def trace_http_exception_handler(request: Request, exc: StarletteHTTPException) -> Response:
        """Record the HTTP exception on the active span, then use FastAPI's default handler."""
        _mark_current_span_failed(exc)
        return await http_exception_handler(request, exc)

    @app.exception_handler(RequestValidationError)
    async def trace_validation_exception_handler(request: Request, exc: RequestValidationError) -> Response:
        """Record the validation error on the active span, then use FastAPI's default handler."""
        _mark_current_span_failed(exc)
        return await request_validation_exception_handler(request, exc)


# Instrument the app with OpenTelemetry using the globally configured tracer
# provider (the OTLP provider when one was set above).
# excluded_urls is a comma-separated list of patterns matched against the request
# URL; matching requests are not traced.
# NOTE(review): the patterns appear to be unanchored, so any URL merely containing
# e.g. "health" would also be excluded. Confirm this is intended.
FastAPIInstrumentor.instrument_app(
    app, excluded_urls="health,docs,openapi.json", tracer_provider=trace.get_tracer_provider()
)


# Enable caching based on ETag
# NOTE(review): ETagMiddleware is disabled; ETag headers come from HashJSONResponse,
# the app's default_response_class, instead.
# app.add_middleware(ETagMiddleware)
# Enable br compression for large responses, fallback gzip
app.add_middleware(BrotliMiddleware)

# Include all routes
app.include_router(api_router)
app.include_router(miscellaneous_router)


# Swagger UI is served by a manually registered route because the app was
# created with openapi_url=None, which disables FastAPI's built-in docs pages.
async def swagger_ui_html(req: Request) -> HTMLResponse:
    """Render Swagger UI for the manually served OpenAPI schema."""
    schema_url = app.root_path + "/openapi.json"
    page_title = app.title + " - Swagger UI"
    return get_swagger_ui_html(
        openapi_url=schema_url,
        title=page_title,
        swagger_favicon_url="/favicon.ico",
    )


# Custom route for the OpenAPI schema so that clients can cache it on their side
# (the schema is returned as a HashJSONResponse).
async def openapi(req: Request) -> Response:
    """Return the application's OpenAPI schema wrapped in a HashJSONResponse."""
    schema = app.openapi()
    return HashJSONResponse(schema)


# Register the manually defined Swagger UI and OpenAPI schema endpoints;
# include_in_schema=False keeps them out of the generated OpenAPI document.
app.add_route("/docs", swagger_ui_html, include_in_schema=False)
app.add_route("/openapi.json", openapi, include_in_schema=False)