Skip to content
Snippets Groups Projects
abstract_repository.py 7.12 KiB
import asyncio
from abc import ABC, abstractmethod
from functools import cached_property
from io import IOBase
from tempfile import SpooledTemporaryFile
from typing import TYPE_CHECKING, AsyncIterator, Dict, List, Optional

from fastapi import HTTPException, status
from httpx import USE_CLIENT_DEFAULT, AsyncClient, Auth
from opentelemetry import trace
from pydantic import AnyHttpUrl

# Module-level tracer; the methods below use it to open a span around each repository operation.
tracer = trace.get_tracer_provider().get_tracer(__name__)

# ``Object`` is only used in type annotations: import the boto3 stub type while type checking and
# fall back to the builtin ``object`` at runtime, so the name always resolves.
if TYPE_CHECKING:
    from mypy_boto3_s3.service_resource import Object
else:
    Object = object


class GitRepository(ABC):
    """
    Abstract base class for a Git repository pinned to a specific commit.

    Subclasses provide the provider-specific parts (URL construction, authentication and request headers).
    On top of those, this class implements checking whether files exist, streaming and downloading files and
    copying a file into an S3 object.
    """

    @property
    @abstractmethod
    def provider(self) -> str:
        ...

    @property
    def token(self) -> Optional[str]:
        """Token given at construction time, or None if no token was supplied."""
        return self._token

    def __init__(self, url: str, git_commit_hash: str, token: Optional[str] = None):
        """
        Initialize Git repository object.

        Parameters
        ----------
        url : str
            URL of the git repository
        git_commit_hash : str
            Hash of the git commit the repository is pinned to
        token : str | None
            Token to access a private git repository
        """
        self.url = url
        # The repository name is the last path segment of the URL. Strip *all* trailing slashes first so that
        # URLs like "https://host/group/repo//" still resolve to "repo" instead of an empty name.
        self.name = url.rstrip("/").rsplit("/", 1)[-1]
        self.commit = git_commit_hash
        self._token = token

    @abstractmethod
    async def download_file_url(self, filepath: str, client: AsyncClient) -> AnyHttpUrl:
        """
        Construct an URL where to download a file from

        Parameters
        ----------
        filepath : str
            Path of a file
        client: httpx.AsyncClient
            HTTP client for requesting a download link, like GitHub API

        Returns
        -------
        url : AnyHttpUrl
            URL where to download the specified file from.
        """
        ...

    @abstractmethod
    def check_file_url(self, filepath: str) -> AnyHttpUrl:
        """
        Construct an URL where to access meta data of the file

        Parameters
        ----------
        filepath : str
            Path of a file

        Returns
        -------
        url : AnyHttpUrl
            URL where the meta data of the specified file can be requested.
        """
        ...

    # NOTE(review): functools.cached_property does not appear to forward ``__isabstractmethod__``, so ABC
    # probably does not force subclasses to override the two properties below - confirm before relying on it.
    @cached_property
    @abstractmethod
    def request_auth(self) -> Optional[Auth]:
        """Auth object sent with requests; when it is None the HTTP client's default auth is used instead."""
        ...

    @cached_property
    @abstractmethod
    def request_headers(self) -> Dict[str, str]:
        """Headers sent along with the existence check request."""
        ...

    @abstractmethod
    def __repr__(self) -> str:
        ...

    def __str__(self) -> str:
        """Return the same text as ``repr(self)``."""
        return repr(self)

    async def check_file_exists(self, filepath: str, client: AsyncClient) -> bool:
        """
        Check if a file exists in the Git Repository

        Sends a HEAD request (following redirects) to the URL returned by ``check_file_url``.

        Parameters
        ----------
        filepath : str
            Path to the file
        client : httpx.AsyncClient
            Async HTTP Client with an open connection

        Returns
        -------
        exist : bool
            True only if the final response has status code 200.
        """
        response = await client.head(
            str(self.check_file_url(filepath)),
            # Fall back to whatever auth is configured on the client if the repository defines none.
            auth=USE_CLIENT_DEFAULT if self.request_auth is None else self.request_auth,
            follow_redirects=True,
            headers=self.request_headers,
        )
        return response.status_code == status.HTTP_200_OK

    async def check_files_exist(self, files: List[str], client: AsyncClient, raise_error: bool = True) -> List[bool]:
        """
        Check if multiple files exists in the Git Repository

        All files are checked concurrently.

        Parameters
        ----------
        files : List[str]
            Paths to the file to check
        client : httpx.AsyncClient
            Async HTTP Client with an open connection
        raise_error : bool, default True
            Raise an HTTPException if any of the files doesn't exist.

        Returns
        -------
        exist : List[bool]
            Flags if the files exist, in the same order as ``files``.

        Raises
        ------
        fastapi.HTTPException
            With status 400 if ``raise_error`` is set and at least one file is missing.
        """
        with tracer.start_as_current_span("git_check_files_exists") as span:
            span.set_attribute("repository", self.url)
            # gather schedules the coroutines itself and keeps the results in input order.
            result = await asyncio.gather(*(self.check_file_exists(file, client=client) for file in files))
            if raise_error:
                missing_files = [f for f, exist in zip(files, result) if not exist]
                if missing_files:
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail=f"The files {', '.join(missing_files)} are missing in the repo {str(self)}",
                    )
            return result

    async def copy_file_to_bucket(self, filepath: str, obj: Object, client: AsyncClient) -> None:
        """
        Copy a file from a git repository to a bucket

        The file is first downloaded into a temporary file and then uploaded from there.

        Parameters
        ----------
        filepath : str
            Path of the file to copy.
        obj : mypy_boto3_s3.service_resource import Object
            S3 object to upload file to.
        client : httpx.AsyncClient
            Async HTTP Client with an open connection.
        """
        import contextvars  # local import: only needed to hand the upload over to a worker thread

        with tracer.start_as_current_span("git_copy_file_to_bucket") as span:
            span.set_attributes({"repository": self.url, "file": filepath})
            with SpooledTemporaryFile(max_size=512000) as f:  # temporary file with 500kB data spooled in memory
                await self.download_file(filepath, client=client, file_handle=f)
                f.seek(0)  # rewind so the upload starts at the beginning of the downloaded data
                # upload_fileobj is a synchronous call; run it in the default executor so the event loop is not
                # blocked for the duration of the upload. Running it inside a copy of the current context keeps
                # the active tracing span visible in the worker thread.
                await asyncio.get_running_loop().run_in_executor(
                    None, contextvars.copy_context().run, obj.upload_fileobj, f
                )

    async def download_file_stream(self, filepath: str, client: AsyncClient) -> AsyncIterator[bytes]:
        """
        Iterate over the stream of bytes of the downloaded file

        Parameters
        ----------
        filepath : str
            Path of the file to copy.
        client : httpx.AsyncClient
            Async HTTP Client with an open connection.

        Returns
        -------
        byte_iterator : AsyncIterator[bytes]
            Async iterator over the bytes of the file
        """
        with tracer.start_as_current_span("git_stream_file_content") as span:
            span.set_attributes({"repository": self.url, "file": filepath})
            # NOTE(review): the response status is not checked here, so the body of an error response would be
            # yielded like regular file content - confirm that callers verify the file exists beforehand.
            async with client.stream(
                method="GET",
                url=str(await self.download_file_url(filepath, client)),
                auth=USE_CLIENT_DEFAULT if self.request_auth is None else self.request_auth,
                follow_redirects=True,
            ) as r:
                async for chunk in r.aiter_bytes():
                    yield chunk

    async def download_file(self, filepath: str, client: AsyncClient, file_handle: IOBase) -> None:
        """
        Download a file from the git repository into a file-like object.

        The stream position of ``file_handle`` is left after the last written byte.

        Parameters
        ----------
        filepath : str
            Path of the file to copy.
        client : httpx.AsyncClient
            Async HTTP Client with an open connection.
        file_handle : IOBase
            Write the file into this stream in binary mode.
        """
        with tracer.start_as_current_span("git_download_file") as span:
            span.set_attributes({"repository": self.url, "file": filepath})
            async for chunk in self.download_file_stream(filepath, client):
                file_handle.write(chunk)