From ee37b0baa359f316ad83e94038c0640ab0df6657 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20G=C3=B6bel?= <dgoebel@techfak.uni-bielefeld.de> Date: Fri, 24 May 2024 11:53:57 +0200 Subject: [PATCH] Resolve "Invite a user to clowm via email" --- .pre-commit-config.yaml | 2 +- clowm/api/endpoints/login.py | 140 +++++--- clowm/api/endpoints/s3key.py | 2 +- clowm/api/endpoints/users.py | 50 ++- clowm/api/utils.py | 36 +- clowm/crud/crud_user.py | 98 ++++-- clowm/db/types.py | 12 + clowm/models/user.py | 6 +- clowm/smtp/send_email.py | 17 +- clowm/smtp/smtp.py | 2 +- .../smtp/templates/html/invitation.html.tmpl | 6 + .../templates/html/registration.html.tmpl | 5 - .../smtp/templates/plain/invitation.txt.tmpl | 9 + .../templates/plain/registration.txt.tmpl | 5 - clowm/tests/api/test_login.py | 323 +++++++++++++----- clowm/tests/api/test_users.py | 84 ++++- clowm/tests/crud/test_resource_version.py | 240 +++++++++++-- clowm/tests/crud/test_user.py | 106 ++++-- ...eb4bf69e3_add_invitation_token_for_user.py | 34 ++ 19 files changed, 929 insertions(+), 248 deletions(-) create mode 100644 clowm/smtp/templates/html/invitation.html.tmpl delete mode 100644 clowm/smtp/templates/html/registration.html.tmpl create mode 100644 clowm/smtp/templates/plain/invitation.txt.tmpl delete mode 100644 clowm/smtp/templates/plain/registration.txt.tmpl create mode 100644 migrations/versions/c0ceb4bf69e3_add_invitation_token_for_user.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fedd082..042d912 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,7 +15,7 @@ repos: - id: check-merge-conflict - id: check-ast - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: 'v0.4.4' + rev: 'v0.4.5' hooks: - id: ruff args: [ "--fix" ] diff --git a/clowm/api/endpoints/login.py b/clowm/api/endpoints/login.py index 78a52d3..3531b36 100644 --- a/clowm/api/endpoints/login.py +++ b/clowm/api/endpoints/login.py @@ -1,6 +1,8 @@ +import time import urllib import urllib.parse from typing import Annotated +from uuid import UUID from fastapi import APIRouter, BackgroundTasks, Query, Request, status from fastapi.responses import RedirectResponse @@ -14,33 +16,46 @@ from clowm.core.oidc import LoginException, OIDCClient from clowm.crud import CRUDUser from clowm.models import User from clowm.otlp import start_as_current_span_async -from clowm.smtp.send_email import send_first_login_email -from ..background.initialize_users import initialize_user from ..dependencies import DBSession, OIDCClientDep, RGWService +from ..utils import create_rgw_user router = APIRouter(prefix="/auth", tags=["Auth"]) tracer = trace.get_tracer_provider().get_tracer(__name__) NEXT_PATH_KEY = "NEXT" +INVITATION_UID_KEY = "INVITATION_UID" -def build_url(base_url: str, *path: str) -> AnyHttpUrl: - # Returns a list in the structure of urlparse.ParseResult - url_parts = list(urllib.parse.urlparse(base_url)) - url_parts[2] = "/".join(path) - return AnyHttpUrl(urllib.parse.urlunparse(url_parts)) +def oidc_redirect_uri(provider: OIDCClient.OIDCProvider) -> AnyHttpUrl: + return AnyHttpUrl.build( + scheme=settings.ui_uri.scheme, + host=settings.ui_uri.host, # type: ignore[arg-type] + path="/".join([settings.ui_uri.path, settings.api_prefix, router.prefix[1:], "callback", provider.name]).strip( # type: ignore[list-item] + "/" + ), + port=settings.ui_uri.port, + ) @router.get( "/login", status_code=status.HTTP_302_FOUND, - summary="Redirect to LifeScience OIDC Login", + summary="Kickstart the login flow", ) @start_as_current_span_async("api_login_route", tracer=tracer) async def login( request: Request, oidc_client: OIDCClientDep, + db: DBSession, + invitation_token: Annotated[ + str | SkipJsonSchema[None], + Query( + min_length=43, + max_length=43, + description="Unique token to validate an invitation", + ), + ] = None, provider: Annotated[ OIDCClient.OIDCProvider, Query(description="The OIDC provider to use for login") ] = OIDCClient.OIDCProvider.lifescience, @@ -49,7 +64,7 @@ async def login( Query( alias="next", max_length=128, - description="Will be appended to redirect response in the callback route as URL query parameter `next_path`", + description="Will be appended to redirect response in the callback route as URL query parameter `next`", ), ] = None, ) -> RedirectResponse: @@ -64,9 +79,13 @@ async def login( The wrapper around the oidc client. Dependency Injection. request : fastapi.requests.Request Raw request object. + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. Dependency Injection. next_ : str | None - Query parameter that gets stored in the session cookie. - Will be appended to RedirectResponse in the callback route as URL query parameter 'next_path' + Query parameter that gets stored in the session cookie. Query Parameter. + Will be appended to RedirectResponse in the callback route as URL query parameter 'next' + invitation_token : str | None + Token from the invitation email to connect the created account with an identity provider. Query Parameter. Returns ------- @@ -75,10 +94,25 @@ async def login( """ # Clear session to prevent an overflow request.session.clear() + current_span = trace.get_current_span() + current_span.set_attribute("provider", provider.name) if next_: + current_span.set_attribute("next", next_) request.session[NEXT_PATH_KEY] = next_ - redirect_uri = build_url(str(settings.ui_uri), settings.api_prefix, router.prefix[1:], "callback", provider.name) - return await oidc_client.authorize_redirect(request, redirect_uri=redirect_uri, provider=provider) + if invitation_token is not None: + user = await CRUDUser.get_by_invitation_token(invitation_token, db=db) + if user is None: + raise LoginException(error_source="invalid invitation link") + current_span.set_attribute("invitation_uid", str(user.uid)) + if ( + round(time.time()) - (0 if user.invitation_token_created_at is None else user.invitation_token_created_at) + > 84200 + ): + raise LoginException(error_source="expired invitation link") + request.session[INVITATION_UID_KEY] = str(user.uid) + return await oidc_client.authorize_redirect( + request, redirect_uri=oidc_redirect_uri(provider=provider), provider=provider + ) @router.get( @@ -141,54 +175,49 @@ async def login_callback( path : str Redirect path after successful login. """ - redirect_path = "/" + redirect_path = str(settings.ui_uri) current_span = trace.get_current_span() + current_span.set_attribute("provider", provider.name) next_path: str | None = request.session.get(NEXT_PATH_KEY, None) # get return path from session cookie if next_path is not None: current_span.set_attribute("next", next_path) redirect_path += f"?next={urllib.parse.quote_plus(next_path)}" try: user_info = await oidc_client.verify_fetch_userinfo(request=request, provider=provider) - lifescience_id = user_info.sub if isinstance(user_info.sub, str) else user_info.sub[0] - current_span.set_attribute("lifescience_id", lifescience_id) - lifescience_id = lifescience_id.split("@")[0] - - user = await CRUDUser.get_by_lifescience_id(lifescience_id=lifescience_id, db=db) + lifescience_id = (user_info.sub if isinstance(user_info.sub, str) else user_info.sub[0]).split("@")[0] + current_span.set_attributes({"lifescience_id": lifescience_id, "raw_lifescience_id": user_info.sub}) - # if we want to block foreign users and the user is None or has no role, reject this login attempt - if settings.block_foreign_users and (user is None or len(user.roles) == 0): - raise LoginException(error_source="Access denied") - # if user does not exist in system - if user is None: - # try to get user by smtp to get registered but not initialized user - user = await CRUDUser.get_by_email(db=db, email=user_info.email) + invitation_uid: str | None = request.session.get(INVITATION_UID_KEY, None) + # if the user was invited + if invitation_uid is not None: + current_span.set_attribute("invitation_uid", invitation_uid) + user = await CRUDUser.get(uid=UUID(invitation_uid), db=db) + if user is None: + raise LoginException(error_source="unknown user") + elif await CRUDUser.get_by_lifescience_id(lifescience_id=lifescience_id, db=db) is not None: + raise LoginException(error_source="lifescience account already connected to other account") + # update the invited user and initialize him + await CRUDUser.update_invited_user( + user.uid, lifescience_id=lifescience_id, display_name=user_info.name, email=user_info.email, db=db + ) + create_rgw_user(user=user, background_tasks=background_tasks, rgw=rgw) + else: + user = await CRUDUser.get_by_lifescience_id(lifescience_id=lifescience_id, db=db) + # if we want to block foreign users and the user is None or has no role, reject this login attempt + if settings.block_foreign_users: + if user is None: + raise LoginException(error_source="Access denied") + elif len(user.roles) == 0: + raise LoginException(error_source="Access denied") + # if user does not exist in system, create a new user if user is None: - # if user is not registered by admin, create a new user user = await CRUDUser.create( User(lifescience_id=lifescience_id, display_name=user_info.name, email=user_info.email), db=db, ) - else: # if smtp is connected to a user, connect the user to this lifescience account - # if user is already connected to a lifescience account, reject this login attempt - if user.lifescience_id is not None: - raise LoginException("Email already connected to other account") - await CRUDUser.update_registered_user( - uid=user.uid, db=db, lifescience_id=lifescience_id, display_name=user_info.name - ) - with tracer.start_as_current_span( - "rgw_create_user", attributes={"uid": str(user.uid), "display_name": user.display_name} - ): - rgw.create_user( - uid=str(user.uid), - max_buckets=-1, - display_name=user.display_name, - ) - - background_tasks.add_task(initialize_user, user=user) - background_tasks.add_task(send_first_login_email, user=user) - elif user.email != user_info.email: - await CRUDUser.update_email(user.uid, user_info.email, db=db) - + create_rgw_user(user=user, background_tasks=background_tasks, rgw=rgw) + elif user.email != user_info.email: + await CRUDUser.update_email(user.uid, user_info.email, db=db) jwt = create_access_token(str(user.uid)) response.set_cookie( key="bearer", @@ -229,5 +258,18 @@ async def login_callback( }, ) async def logout(response: RedirectResponse) -> str: - response.set_cookie(key="bearer", secure=True, max_age=0, domain=settings.ui_uri.host) + """ + Logout the user from the system by deleting the bearer cookie. + + Parameters + ---------- + response : fastapi.responses.RedirectResponse + Response which will delete the JWT cookie. + + Returns + ------- + path : str + Redirect path after successful logout. + """ + response.delete_cookie(key="bearer", secure=True, domain=settings.ui_uri.host) return str(settings.ui_uri) diff --git a/clowm/api/endpoints/s3key.py b/clowm/api/endpoints/s3key.py index b4106de..8907323 100644 --- a/clowm/api/endpoints/s3key.py +++ b/clowm/api/endpoints/s3key.py @@ -61,7 +61,7 @@ async def get_user_keys( """ trace.get_current_span().set_attribute("uid", str(user.uid)) if current_user.uid != user.uid: - raise HTTPException(status.HTTP_403_FORBIDDEN, detail="Action forbidden.") + raise HTTPException(status.HTTP_403_FORBIDDEN, detail="Action forbidden") authorization(RBACOperation.LIST) return get_s3_keys(rgw, user.uid) diff --git a/clowm/api/endpoints/users.py b/clowm/api/endpoints/users.py index 704133b..ec1fd17 100644 --- a/clowm/api/endpoints/users.py +++ b/clowm/api/endpoints/users.py @@ -1,6 +1,6 @@ from typing import Callable -from fastapi import APIRouter, BackgroundTasks, Depends, Query, status +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, status from opentelemetry import trace from pydantic import TypeAdapter from pydantic.json_schema import SkipJsonSchema @@ -11,7 +11,7 @@ from clowm.crud import CRUDUser from clowm.models import Role, User from clowm.otlp import start_as_current_span_async from clowm.schemas.user import UserIn, UserOut, UserOutExtended, UserRoles -from clowm.smtp.send_email import send_registration_email +from clowm.smtp.send_email import send_invitation_email from ..dependencies import AuthorizationDependency, CurrentUser, DBSession, PathUser @@ -23,11 +23,12 @@ tracer = trace.get_tracer_provider().get_tracer(__name__) @router.post("", response_model=UserOutExtended, status_code=status.HTTP_201_CREATED) +@start_as_current_span_async("api_invite_user", tracer=tracer) async def create_user( db: DBSession, user_in: UserIn, background_tasks: BackgroundTasks, authorization: Authorization ) -> UserOutExtended: """ - Create a new user in the system and notify him. The smtp MUST be the same as the one saved by the OIDC provider. + Create a new user in the system and notify him. Permission `user:create` required. \f @@ -47,15 +48,20 @@ async def create_user( user : clowm.schemas.user.UserOutExtended The newly created user. """ + current_span = trace.get_current_span() + current_span.set_attribute("user_in", user_in.model_dump_json()) authorization(RBACOperation.CREATE) user = await CRUDUser.create( User(display_name=user_in.display_name, email=user_in.email, lifescience_id=None), roles=user_in.roles, db=db ) - background_tasks.add_task(send_registration_email, user=user) + current_span.set_attribute("uid", str(user.uid)) + token = await CRUDUser.create_invitation_token(user.uid, db=db) + background_tasks.add_task(send_invitation_email, user=user, token=token) return UserOutExtended.from_db_user(user) @router.get("/search", response_model=list[UserOut]) +@start_as_current_span_async("api_search_users", tracer=tracer) async def search_users( db: DBSession, authorization: Authorization, @@ -113,6 +119,7 @@ async def get_logged_in_user(current_user: CurrentUser, authorization: Authoriza current_user : clowm.schemas.user.UserOutExtended User associated to used JWT. """ + trace.get_current_span().set_attribute("uid", str(current_user.uid)) authorization(RBACOperation.READ) return UserOutExtended.from_db_user(current_user) @@ -220,6 +227,41 @@ async def update_roles( user : clowm.schemas.user.UserOutExtended User with the updated roles. """ + trace.get_current_span().set_attributes({"uid": str(user.uid), "roles": [role.name for role in roles_body.roles]}) authorization(RBACOperation.UPDATE) await CRUDUser.update_roles(user, roles_body.roles, db=db) return UserOutExtended.from_db_user(user) + + +@router.patch("/{uid}/invitation", response_model=UserOutExtended) +@start_as_current_span_async("api_resend_invitation", tracer=tracer) +async def resend_invitation( + db: DBSession, user: PathUser, authorization: Authorization, background_tasks: BackgroundTasks +) -> UserOutExtended: + """ + Resend the invitation link for an user that has an open invitation. + + Permission `user:create` required. + \f + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. Dependency Injection. + user : clowm.models.User + The user associated with the UID in the path. Dependency Injection. + authorization : Callable[[clowm.core.rbac.RBACOperation], None] + Function to call determines if the current user is authorized for this request. Dependency Injection. + background_tasks : fastapi.BackgroundTasks + Entrypoint for new BackgroundTasks. Provided by FastAPI. + + Returns + ------- + user : clowm.schemas.user.UserOutExtended + User with the updated roles. + """ + authorization(RBACOperation.CREATE) + if user.invitation_token is None: + raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=f"user {user.uid} has not open invitation") + token = await CRUDUser.create_invitation_token(user.uid, db=db) + background_tasks.add_task(send_invitation_email, user=user, token=token) + return UserOutExtended.from_db_user(user) diff --git a/clowm/api/utils.py b/clowm/api/utils.py index 9fcc13c..76c5ae2 100644 --- a/clowm/api/utils.py +++ b/clowm/api/utils.py @@ -9,16 +9,20 @@ from fastapi import BackgroundTasks, HTTPException, UploadFile, status from httpx import AsyncClient from opentelemetry import trace from PIL import Image, UnidentifiedImageError +from rgwadmin import RGWAdmin from sqlalchemy.ext.asyncio import AsyncSession -from clowm.api.background.resources import update_used_resources -from clowm.api.background.s3 import process_and_upload_icon from clowm.core.config import settings from clowm.crud import CRUDResourceVersion, CRUDWorkflowExecution from clowm.git_repository.abstract_repository import GitRepository -from clowm.models import ResourceVersion, WorkflowExecution, WorkflowMode +from clowm.models import ResourceVersion, User, WorkflowExecution, WorkflowMode from clowm.schemas.resource_version import ResourceVersionOut from clowm.schemas.workflow_mode import WorkflowModeIn +from clowm.smtp.send_email import send_first_login_email + +from .background.initialize_users import initialize_user +from .background.resources import update_used_resources +from .background.s3 import process_and_upload_icon if TYPE_CHECKING: from types_aiobotocore_s3.service_resource import ObjectSummary, S3ServiceResource @@ -98,6 +102,32 @@ async def check_repo( ) +def create_rgw_user(user: User, background_tasks: BackgroundTasks, rgw: RGWAdmin) -> None: + """ + Create the user in RGW and initializes him in the background. + + Parameters + ---------- + user : clowm.models.User + User that should be created. + background_tasks : fastapi.BackgroundTasks + Entrypoint for new BackgroundTasks + rgw : rgwadmin.RGWAdmin + RGW admin interface to manage Ceph's object store. + """ + with tracer.start_as_current_span( + "rgw_create_user", attributes={"uid": str(user.uid), "display_name": user.display_name} + ): + rgw.create_user( + uid=str(user.uid), + max_buckets=-1, + display_name=user.display_name, + ) + + background_tasks.add_task(initialize_user, user=user) + background_tasks.add_task(send_first_login_email, user=user) + + async def check_active_workflow_execution_limit(db: AsyncSession, uid: UUID) -> None: """ Check the number of active workflow executions of a usr and raise an HTTP exception if a new one would violate the diff --git a/clowm/crud/crud_user.py b/clowm/crud/crud_user.py index e935a19..ac18a0c 100644 --- a/clowm/crud/crud_user.py +++ b/clowm/crud/crud_user.py @@ -1,3 +1,5 @@ +import secrets +import time from typing import Sequence from uuid import UUID @@ -50,6 +52,57 @@ class CRUDUser: await db.refresh(user, attribute_names=["roles"]) return user + @staticmethod + async def create_invitation_token(uid: UUID, *, db: AsyncSession) -> str: + """ + Create an invitation token and save it for the user. + + Parameters + ---------- + uid : uuid.UUID + UID of a user. + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + + Returns + ------- + token : str + The saved token for the user + """ + token = secrets.token_urlsafe(32) + stmt = ( + update(User) + .where(User.uid == uid) + .values(invitation_token=token, invitation_token_created_at=round(time.time())) + ) + with tracer.start_as_current_span( + "db_create_invitation_token", attributes={"sql_query": str(stmt), "uid": str(uid)} + ): + await db.execute(stmt) + await db.commit() + return token + + @staticmethod + async def get_by_invitation_token(invitation_token: str, *, db: AsyncSession) -> User | None: + """ + Get a user by an invitation token. + + Parameters + ---------- + invitation_token : str + The token to search for. + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + + Returns + ------- + user : clowm.models.User | None + The user belonging to the invitation token. + """ + stmt = select(User).where(User.invitation_token == invitation_token) + with tracer.start_as_current_span("db_get_by_invitation_token", attributes={"sql_query": str(stmt)}): + return await db.scalar(stmt) + @staticmethod async def update_roles(user: User, roles: list[Role.RoleEnum], *, db: AsyncSession) -> None: """ @@ -102,9 +155,11 @@ class CRUDUser: await db.commit() @staticmethod - async def update_registered_user(uid: UUID, lifescience_id: str, display_name: str, *, db: AsyncSession) -> None: + async def update_invited_user( + uid: UUID, lifescience_id: str, display_name: str, email: str | None, *, db: AsyncSession + ) -> None: """ - Update the information of a registered user. + Update the information of an invited user. Parameters ---------- @@ -114,10 +169,22 @@ class CRUDUser: The display name of the user. lifescience_id : str The lifescience id of a user. + email : str | None + The email of the user. db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. """ - stmt = update(User).where(User.uid == uid).values(display_name=display_name, lifescience_id=lifescience_id) + stmt = ( + update(User) + .where(User.uid == uid) + .values( + display_name=display_name, + lifescience_id=lifescience_id, + email=email, + invitation_token=None, + invitation_token_created_at=None, + ) + ) with tracer.start_as_current_span( "db_update_registered_user", attributes={ @@ -193,31 +260,6 @@ class CRUDUser: ): return await db.scalar(stmt) - @staticmethod - async def get_by_email(email: str | None, db: AsyncSession) -> User | None: - """ - Get a user by his smtp. - - Parameters - ---------- - email : str | None, default None - The smtp of a user. - db : sqlalchemy.ext.asyncio.AsyncSession. - Async database session to perform query on. - - Returns - ------- - user : clowm.models.User | None - The user for the given smtp if he exists, None otherwise - """ - stmt = select(User).where(User.email == email) - with tracer.start_as_current_span( - "db_get_user_email", attributes={"sql_query": str(stmt), "smtp": "None" if email is None else email} - ): - if email is None: - return None - return await db.scalar(stmt) - @staticmethod async def list_users( name_substring: str | None = None, roles: list[Role.RoleEnum] | None = None, *, db: AsyncSession diff --git a/clowm/db/types.py b/clowm/db/types.py index e30f9d3..e429ffd 100644 --- a/clowm/db/types.py +++ b/clowm/db/types.py @@ -1,3 +1,4 @@ +import base64 from uuid import UUID import sqlalchemy.types as types @@ -19,3 +20,14 @@ class SQLUUID(types.TypeDecorator): def process_result_value(self, value: bytes | None, dialect: Dialect) -> UUID | None: # type: ignore[override] return None if value is None else UUID(bytes=value) + + +class Token(types.TypeDecorator): + impl = types.BINARY(32) + cache_ok = True + + def process_bind_param(self, value: str | None, dialect: Dialect) -> bytes | None: # type: ignore[override] + return base64.urlsafe_b64decode(value + "==") if value is not None else None + + def process_result_value(self, value: bytes | None, dialect: Dialect) -> str | None: # type: ignore[override] + return None if value is None else base64.urlsafe_b64encode(value).rstrip(b"=").decode("ascii") diff --git a/clowm/models/user.py b/clowm/models/user.py index 79125dd..9f7e4d3 100644 --- a/clowm/models/user.py +++ b/clowm/models/user.py @@ -1,11 +1,11 @@ from typing import TYPE_CHECKING, Any from uuid import UUID -from sqlalchemy import Boolean, String +from sqlalchemy import Boolean, Integer, String from sqlalchemy.orm import Mapped, mapped_column, relationship from clowm.db.base_class import Base, uuid7 -from clowm.db.types import SQLUUID +from clowm.db.types import SQLUUID, Token from .role import Role, RoleIdMapping, UserRoleMapping @@ -51,6 +51,8 @@ class User(Base): roles: Mapped[list["UserRoleMapping"]] = relationship( back_populates="user", cascade="all, delete", passive_deletes=True ) + invitation_token: Mapped[str | None] = mapped_column(Token, nullable=True) + invitation_token_created_at: Mapped[int | None] = mapped_column(Integer, nullable=True) def has_role(self, role: Role.RoleEnum) -> bool: mapping = RoleIdMapping() diff --git a/clowm/smtp/send_email.py b/clowm/smtp/send_email.py index 7622c43..c1a3b04 100644 --- a/clowm/smtp/send_email.py +++ b/clowm/smtp/send_email.py @@ -1,4 +1,7 @@ +from pydantic import AnyHttpUrl + from clowm.api.background.dependencies import get_background_db +from clowm.core.config import settings from clowm.crud import CRUDUser from clowm.models import Resource, ResourceVersion, Role, User, Workflow, WorkflowVersion from clowm.schemas.resource import ResourceOut @@ -20,7 +23,7 @@ async def send_first_login_email(user: User) -> None: send_email(recipients=user, subject="CloWM first login", html_msg=html, plain_msg=plain) -async def send_registration_email(user: User) -> None: +async def send_invitation_email(user: User, token: str) -> None: """ Email a user when an admin register him. @@ -28,9 +31,17 @@ async def send_registration_email(user: User) -> None: ---------- user : clowm.models.User The user who was registered by an admin. + token : str """ - html, plain = EmailTemplates.REGISTRATION.render(user=user) - send_email(recipients=user, subject="CloWM registration", html_msg=html, plain_msg=plain) + invitation_link = AnyHttpUrl.build( + scheme=settings.ui_uri.scheme, + host=settings.ui_uri.host, # type: ignore[arg-type] + port=settings.ui_uri.port, + path=settings.ui_uri.path.strip("/") if settings.ui_uri.path is not None else None, + query=f"invitation_token={token}", + ) + html, plain = EmailTemplates.INVITATION.render(user=user, invitation_link=invitation_link) + send_email(recipients=user, subject="Invitation to CloWM", html_msg=html, plain_msg=plain) async def send_review_request_email(workflow: Workflow, version: WorkflowVersion) -> None: diff --git a/clowm/smtp/smtp.py b/clowm/smtp/smtp.py index 26426e4..e1b5e5a 100644 --- a/clowm/smtp/smtp.py +++ b/clowm/smtp/smtp.py @@ -101,7 +101,7 @@ def smtp_connection(smtp_settings: EmailSettings) -> Generator[smtplib.SMTP, Non @unique class EmailTemplates(StrEnum): - REGISTRATION = "registration" + INVITATION = "invitation" FIRST_LOGIN = "first_login" REVIEW_REQUEST = "review_request" REVIEW_RESPONSE = "review_response" diff --git a/clowm/smtp/templates/html/invitation.html.tmpl b/clowm/smtp/templates/html/invitation.html.tmpl new file mode 100644 index 0000000..00c9a45 --- /dev/null +++ b/clowm/smtp/templates/html/invitation.html.tmpl @@ -0,0 +1,6 @@ +<%inherit file="base.html.tmpl"/> + +<p>Hello ${user.display_name}</p> +<p>the administrator of CloWM created an account for you. Click on the link below and connect your account with one of the available identity providers.</p> +<p><a href=${invitation_link}>${invitation_link}</a></p> +<p>This link will expire in 24 hours.</p> diff --git a/clowm/smtp/templates/html/registration.html.tmpl b/clowm/smtp/templates/html/registration.html.tmpl deleted file mode 100644 index a3aa206..0000000 --- a/clowm/smtp/templates/html/registration.html.tmpl +++ /dev/null @@ -1,5 +0,0 @@ -<%inherit file="base.html.tmpl"/> - -<p>Hello {user.display_name}</p> -<p>the administrator of CloWM registered you under this email. Visit the <a href="${settings.ui_uri}">Website<a> and login to finish your registration -for your account in CloWM.</p> diff --git a/clowm/smtp/templates/plain/invitation.txt.tmpl b/clowm/smtp/templates/plain/invitation.txt.tmpl new file mode 100644 index 0000000..8454a3d --- /dev/null +++ b/clowm/smtp/templates/plain/invitation.txt.tmpl @@ -0,0 +1,9 @@ +<%inherit file="base.txt.tmpl"/> + +Hello ${user.display_name} +the administrator of CloWM created an account for you. Click on the link below and connect your account with one of \ +the available identity providers. + +${invitation_link} + +This link will expire in 24 hours. diff --git a/clowm/smtp/templates/plain/registration.txt.tmpl b/clowm/smtp/templates/plain/registration.txt.tmpl deleted file mode 100644 index e8bb90a..0000000 --- a/clowm/smtp/templates/plain/registration.txt.tmpl +++ /dev/null @@ -1,5 +0,0 @@ -<%inherit file="base.txt.tmpl"/> - -Hello {user.display_name} -the administrator of CloWM registered you under this email. Visit the Website (${settings.ui_uri}) and login to finish your registration -for your account in CloWM. diff --git a/clowm/tests/api/test_login.py b/clowm/tests/api/test_login.py index 4c444dc..ed9a145 100644 --- a/clowm/tests/api/test_login.py +++ b/clowm/tests/api/test_login.py @@ -1,14 +1,20 @@ +import json +import time import urllib.parse +from base64 import b64decode +from secrets import token_urlsafe from typing import TYPE_CHECKING, Any from uuid import UUID import pytest from fastapi import status from httpx import AsyncClient -from sqlalchemy import delete, select +from sqlalchemy import delete, select, update from sqlalchemy.ext.asyncio import AsyncSession +from clowm.api.endpoints.login import INVITATION_UID_KEY, NEXT_PATH_KEY from clowm.core.auth import decode_token +from clowm.core.config import settings from clowm.core.oidc import LoginException, OIDCClient, UserInfo from clowm.models import Bucket, User from clowm.tests.mocks import MockRGWAdmin @@ -20,13 +26,15 @@ else: S3ServiceResource = object -class TestLoginRoute: +class _TestLoginRoute: auth_path = "/auth" + +class TestLoginRouteRedirects(_TestLoginRoute): @pytest.mark.asyncio async def test_login_redirect(self, client: AsyncClient) -> None: """ - Test for the query parameter on the login redirect route. + Test for the query parameter and session cookie on the login redirect route. Parameters ---------- @@ -39,7 +47,110 @@ class TestLoginRoute: follow_redirects=False, ) assert r.status_code == status.HTTP_302_FOUND + assert "login_error=" not in r.headers["location"] + session_cookie = r.cookies.get("session", None) + assert session_cookie is not None + decoded_session_cookie = json.loads(b64decode(session_cookie)) + assert decoded_session_cookie[NEXT_PATH_KEY] == "/dashboard" + + @pytest.mark.asyncio + async def test_successful_invitation_redirect( + self, client: AsyncClient, db: AsyncSession, random_user: UserWithAuthHeader + ) -> None: + """ + Test for the query parameter and session cookie on the login redirect route with an invitation token. + + Parameters + ---------- + client : httpx.AsyncClient + HTTP Client to perform the request on. + random_user : clowm.tests.utils.UserWithAuthHeader + Random user for testing. + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + """ + token = token_urlsafe(32) + await db.execute( + update(User) + .where(User.uid == random_user.user.uid) + .values(invitation_token=token, invitation_token_created_at=round(time.time())) + ) + await db.commit() + r = await client.get( + self.auth_path + "/login", + params={"invitation_token": token, "provider": "lifescience"}, + follow_redirects=False, + ) + assert r.status_code == status.HTTP_302_FOUND + assert "login_error=" not in r.headers["location"] + session_cookie = r.cookies.get("session", None) + assert session_cookie is not None + decoded_session_cookie = json.loads(b64decode(session_cookie)) + assert decoded_session_cookie[INVITATION_UID_KEY] == str(random_user.user.uid) + + @pytest.mark.asyncio + async def test_invitation_redirect_with_wrong_token( + self, client: AsyncClient, db: AsyncSession, random_user: UserWithAuthHeader + ) -> None: + """ + Test login route with an unknown invitation token. + + Parameters + ---------- + client : httpx.AsyncClient + HTTP Client to perform the request on. + random_user : clowm.tests.utils.UserWithAuthHeader + Random user for testing. + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + """ + await db.execute( + update(User) + .where(User.uid == random_user.user.uid) + .values(invitation_token=token_urlsafe(32), invitation_token_created_at=round(time.time())) + ) + await db.commit() + r = await client.get( + self.auth_path + "/login", + params={"invitation_token": token_urlsafe(32), "provider": "lifescience"}, + follow_redirects=False, + ) + assert r.status_code == status.HTTP_302_FOUND + assert "login_error=" in r.headers["location"] + @pytest.mark.asyncio + async def test_invitation_redirect_with_expired_token( + self, client: AsyncClient, db: AsyncSession, random_user: UserWithAuthHeader + ) -> None: + """ + Test login route with an expired invitation token. + + Parameters + ---------- + client : httpx.AsyncClient + HTTP Client to perform the request on. + random_user : clowm.tests.utils.UserWithAuthHeader + Random user for testing. + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + """ + token = token_urlsafe(32) + await db.execute( + update(User) + .where(User.uid == random_user.user.uid) + .values(invitation_token=token, invitation_token_created_at=0) + ) + await db.commit() + r = await client.get( + self.auth_path + "/login", + params={"invitation_token": token, "provider": "lifescience"}, + follow_redirects=False, + ) + assert r.status_code == status.HTTP_302_FOUND + assert "login_error=" in r.headers.get("location", "") + + +class TestLoginRouteCallback(_TestLoginRoute): @pytest.mark.asyncio async def test_successful_login_with_existing_user( self, client: AsyncClient, random_user: UserWithAuthHeader, monkeypatch: pytest.MonkeyPatch @@ -68,17 +179,10 @@ class TestLoginRoute: follow_redirects=False, ) assert response.status_code == status.HTTP_302_FOUND - assert "set-cookie" in response.headers.keys() - cookie_header = response.headers.get("set-cookie") - right_header = None - for t in cookie_header.split(";"): - if t.startswith("bearer"): - right_header = t - break - assert right_header - claim = decode_token(right_header.split("=")[1]) + jwt = response.cookies.get("bearer", None) + assert jwt is not None + claim = decode_token(jwt) assert claim["sub"] == str(random_user.user.uid) - assert response.headers["location"].startswith(f"/?next={urllib.parse.quote_plus('/dashboard')}") @pytest.mark.asyncio async def test_successful_login_with_existing_user_and_different_email( @@ -112,17 +216,11 @@ class TestLoginRoute: follow_redirects=False, ) assert r.status_code == status.HTTP_302_FOUND - assert "set-cookie" in r.headers.keys() - cookie_header = r.headers["set-cookie"] - right_header = None - for t in cookie_header.split(";"): - if t.startswith("bearer"): - right_header = t - break - assert right_header - claim = decode_token(right_header.split("=")[1]) + jwt = r.cookies.get("bearer", None) + assert jwt is not None + claim = decode_token(jwt) assert claim["sub"] == str(random_user.user.uid) - assert r.headers["location"] == "/" + assert r.headers["location"] == str(settings.ui_uri) db_user = await db.scalar(select(User).where(User.uid == random_user.user.uid)) assert db_user @@ -148,26 +246,89 @@ class TestLoginRoute: follow_redirects=False, ) assert r.status_code == status.HTTP_302_FOUND - if "set-cookie" in r.headers.keys(): - assert find_cookie(searched_cookie_name="bearer", cookie_header=r.headers["set-cookie"]) is None + assert r.cookies.get("bearer", None) is None assert "login_error=" in r.headers["location"] @pytest.mark.asyncio - async def test_login_with_registered_user_and_occupied_email_address( + async def test_login_with_invited_unknown_user( self, client: AsyncClient, + mock_rgw_admin: MockRGWAdmin, db: AsyncSession, cleanup: CleanupList, - random_user: UserWithAuthHeader, monkeypatch: pytest.MonkeyPatch, ) -> None: """ - Test for login callback route with an registered user but its smtp is already connected to another user. + Test login callback route with an unknown invited user. Parameters ---------- client : httpx.AsyncClient HTTP Client to perform the request on. + mock_rgw_admin : clowm.tests.mocks.mock_rgw_admin.MockRGWAdmin + Mock RGW admin for Ceph. + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + cleanup : clowm.tests.utils.utils.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test. + """ + # set up user that is registered by admin but has not logged in yet + lifescience_id = random_lower_string() + user = User( + display_name=random_lower_string(), + email=f"{random_lower_string(10)}@example.com", + invitation_token=token_urlsafe(32), + invitation_token_created_at=round(time.time()), + ) + db.add(user) + await db.commit() + assert not user.initialized + + async def mock_userinfo(*args: Any, **kwargs: Any) -> UserInfo: + return UserInfo( + sub=f"{lifescience_id}@lifescience.org", + name=user.display_name, + email=user.email, + ) + + monkeypatch.setattr(OIDCClient, "verify_fetch_userinfo", mock_userinfo) + + pre_request = await client.get( + self.auth_path + "/login", + params={"invitation_token": user.invitation_token, "provider": "lifescience", "next": "/dashboard"}, + follow_redirects=False, + ) + assert pre_request.status_code == status.HTTP_302_FOUND + + await db.execute(delete(User).where(User.uid == user.uid)) + await db.commit() + + r = await client.get( + self.auth_path + "/callback/lifescience", follow_redirects=False, cookies=pre_request.cookies + ) + # Check response and valid/right jwt token + assert r.status_code == status.HTTP_302_FOUND + assert "login_error" in r.headers["location"] + + @pytest.mark.asyncio + async def test_login_with_invited_user_and_taken_lifescience_account( + self, + client: AsyncClient, + mock_rgw_admin: MockRGWAdmin, + db: AsyncSession, + cleanup: CleanupList, + monkeypatch: pytest.MonkeyPatch, + random_user: UserWithAuthHeader, + ) -> None: + """ + Test login callback route with an invited user that tires to use an already connected lifescience account. + + Parameters + ---------- + client : httpx.AsyncClient + HTTP Client to perform the request on. + mock_rgw_admin : clowm.tests.mocks.mock_rgw_admin.MockRGWAdmin + Mock RGW admin for Ceph. db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. cleanup : clowm.tests.utils.utils.CleanupList @@ -176,7 +337,12 @@ class TestLoginRoute: Random user for testing. """ # set up user that is registered by admin but has not logged in yet - user = User(display_name=random_lower_string(), email=random_user.user.email) + user = User( + display_name=random_lower_string(), + email=f"{random_lower_string(10)}@example.com", + invitation_token=token_urlsafe(32), + invitation_token_created_at=round(time.time()), + ) db.add(user) await db.commit() @@ -185,27 +351,33 @@ class TestLoginRoute: await db.commit() cleanup.add_task(delete_user) + assert not user.initialized async def mock_userinfo(*args: Any, **kwargs: Any) -> UserInfo: return UserInfo( - sub=f"{random_lower_string()}@lifescience.org", + sub=f"{random_user.user.lifescience_id}@lifescience.org", name=user.display_name, email=user.email, ) monkeypatch.setattr(OIDCClient, "verify_fetch_userinfo", mock_userinfo) - r = await client.get( - self.auth_path + "/callback/lifescience", + pre_request = await client.get( + self.auth_path + "/login", + params={"invitation_token": user.invitation_token, "provider": "lifescience", "next": "/dashboard"}, follow_redirects=False, ) + assert pre_request.status_code == status.HTTP_302_FOUND + + r = await client.get( + self.auth_path + "/callback/lifescience", follow_redirects=False, cookies=pre_request.cookies + ) + # Check response and valid/right jwt token assert r.status_code == status.HTTP_302_FOUND - if "set-cookie" in r.headers.keys(): - assert find_cookie(searched_cookie_name="bearer", cookie_header=r.headers["set-cookie"]) is None - assert "login_error=" in r.headers["location"] + assert "login_error" in r.headers["location"] @pytest.mark.asyncio - async def test_successful_login_with_registered_user( + async def test_successful_login_with_invited_user( self, client: AsyncClient, mock_rgw_admin: MockRGWAdmin, @@ -215,7 +387,7 @@ class TestLoginRoute: monkeypatch: pytest.MonkeyPatch, ) -> None: """ - Test for login callback route with a registered user. + Test successful login callback route with an invited user. Parameters ---------- @@ -232,7 +404,12 @@ class TestLoginRoute: """ # set up user that is registered by admin but has not logged in yet lifescience_id = random_lower_string() - user = User(display_name=random_lower_string(), email=f"{random_lower_string(10)}@example.com") + user = User( + display_name=random_lower_string(), + email=f"{random_lower_string(10)}@example.com", + invitation_token=token_urlsafe(32), + invitation_token_created_at=round(time.time()), + ) db.add(user) await db.commit() assert not user.initialized @@ -253,27 +430,38 @@ class TestLoginRoute: monkeypatch.setattr(OIDCClient, "verify_fetch_userinfo", mock_userinfo) - r = await client.get( - self.auth_path + "/callback/lifescience", + pre_request = await client.get( + self.auth_path + "/login", + params={"invitation_token": user.invitation_token, "provider": "lifescience", "next": "/dashboard"}, follow_redirects=False, ) + assert pre_request.status_code == status.HTTP_302_FOUND + + r = await client.get( + self.auth_path + "/callback/lifescience", follow_redirects=False, cookies=pre_request.cookies + ) # Check response and valid/right jwt token assert r.status_code == status.HTTP_302_FOUND assert "login_error" not in r.headers["location"] cleanup.add_task(mock_rgw_admin.delete_user, str(user.uid)) - assert "set-cookie" in r.headers.keys() - cookie = find_cookie(searched_cookie_name="bearer", cookie_header=r.headers["set-cookie"]) - assert cookie is not None - jwt = decode_token(cookie.split("=")[1]) - assert jwt.get("sub", None) is not None - assert UUID(jwt["sub"]) == user.uid + + assert f"next={urllib.parse.quote_plus('/dashboard')}" in r.headers["location"] + jwt = r.cookies.get("bearer", None) + assert jwt is not None + claim = decode_token(jwt) + assert claim.get("sub", None) is not None + assert UUID(claim["sub"]) == user.uid # Check that user is created in RGW - assert mock_rgw_admin.get_user(jwt["sub"])["keys"][0]["user"] is not None + assert mock_rgw_admin.get_user(claim["sub"])["keys"][0]["user"] is not None # Check that user is created in DB - await db.refresh(user, attribute_names=["initialized", "lifescience_id"]) + await db.refresh( + user, attribute_names=["initialized", "lifescience_id", "invitation_token", "invitation_token_created_at"] + ) assert user.lifescience_id == lifescience_id assert user.initialized + assert user.invitation_token is None + assert user.invitation_token_created_at is None # Check that upload and download bucket are created db_buckets = (await db.execute(select(Bucket).where(Bucket.owner_id == user.uid))).scalars().all() @@ -338,12 +526,10 @@ class TestLoginRoute: # Check response and valid/right jwt token assert r.status_code == status.HTTP_302_FOUND assert "login_error" not in r.headers["location"] - assert "set-cookie" in r.headers.keys() - cookie = find_cookie(searched_cookie_name="bearer", cookie_header=r.headers["set-cookie"]) - assert cookie is not None - jwt = decode_token(cookie.split("=")[1]) - assert jwt.get("sub", None) is not None - uid = UUID(jwt["sub"]) + jwt = r.cookies.get("bearer", None) + assert jwt is not None + claim = decode_token(jwt) + uid = UUID(claim["sub"]) async def cleanup_db() -> None: await db.execute(delete(Bucket).where(Bucket.owner_id == uid)) @@ -354,7 +540,7 @@ class TestLoginRoute: cleanup.add_task(mock_rgw_admin.delete_user, str(uid)) # Check that user is created in RGW - assert mock_rgw_admin.get_user(jwt["sub"])["keys"][0]["user"] is not None + assert mock_rgw_admin.get_user(str(uid))["keys"][0]["user"] is not None # Check that user is created in DB await db.reset() @@ -400,29 +586,4 @@ class TestLoginRoute: ) # Check response and valid/right jwt token assert r.status_code == status.HTTP_302_FOUND - assert "set-cookie" in r.headers.keys() - cookie = find_cookie(searched_cookie_name="bearer", cookie_header=r.headers["set-cookie"]) - assert cookie is not None - assert len(cookie.split("=")[1].strip('"')) == 0 - - -def find_cookie(searched_cookie_name: str, cookie_header: str) -> str | None: - """ - Find a specific cookie in the set-cookie header of a HTTP response - - Parameters - ---------- - searched_cookie_name : str - Name of the cookie to be searched - cookie_header : str - Cookie string from HTTP header - - Returns - ------- - cookie : str | None - Returns the cookie if it is present, None otherwise - """ - for cookie in cookie_header.split(";"): - if cookie.startswith(searched_cookie_name): - return cookie - return None + assert r.cookies.get("bearer", None) is None diff --git a/clowm/tests/api/test_users.py b/clowm/tests/api/test_users.py index 2c401b2..54164f2 100644 --- a/clowm/tests/api/test_users.py +++ b/clowm/tests/api/test_users.py @@ -1,11 +1,13 @@ import random +import time import uuid +from secrets import token_urlsafe import pytest from fastapi import status from httpx import AsyncClient from pydantic import TypeAdapter -from sqlalchemy import delete +from sqlalchemy import delete, select, update from sqlalchemy.ext.asyncio import AsyncSession from clowm.models import Role, User @@ -200,7 +202,7 @@ class TestUserRoutesUpdate(_TestUserRoutes): self, client: AsyncClient, random_user: UserWithAuthHeader, random_second_user: UserWithAuthHeader ) -> None: """ - Test for updating a users role. + Test for updating a user's role. Parameters ---------- @@ -226,6 +228,74 @@ class TestUserRoutesUpdate(_TestUserRoutes): assert user.roles[0] == Role.RoleEnum.REVIEWER.value or user.roles[1] == Role.RoleEnum.REVIEWER.value assert user.roles[0] == Role.RoleEnum.DEVELOPER.value or user.roles[1] == Role.RoleEnum.DEVELOPER.value + @pytest.mark.asyncio + async def test_resend_invitation_for_normal_user( + self, client: AsyncClient, random_user: UserWithAuthHeader, random_second_user: UserWithAuthHeader + ) -> None: + """ + Test for resending an invitation email with no open invitation. + + Parameters + ---------- + client : httpx.AsyncClient + HTTP Client to perform the request on. + random_user : clowm.test.utils.user.UserWithAuthHeader + Random user for testing. + random_second_user : clowm.test.utils.user.UserWithAuthHeader + Random user that gets is role updated for testing. + """ + + response = await client.patch( + f"{self.base_path}/{str(random_second_user.user.uid)}/invitation", + headers=random_user.auth_headers, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + @pytest.mark.asyncio + async def test_successful_resend_invitation( + self, + client: AsyncClient, + db: AsyncSession, + random_user: UserWithAuthHeader, + random_second_user: UserWithAuthHeader, + ) -> None: + """ + Test for resending an invitation email with an open invitation. + + Parameters + ---------- + client : httpx.AsyncClient + HTTP Client to perform the request on. + random_user : clowm.test.utils.user.UserWithAuthHeader + Random user for testing. + random_second_user : clowm.test.utils.user.UserWithAuthHeader + Random user that gets is role updated for testing. + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + """ + token = token_urlsafe(32) + token_timestamp = round(time.time()) - 100 + await db.execute( + update(User) + .where(User.uid == random_second_user.user.uid) + .values(invitation_token=token, invitation_token_created_at=token_timestamp) + ) + await db.commit() + + response = await client.patch( + f"{self.base_path}/{str(random_second_user.user.uid)}/invitation", + headers=random_user.auth_headers, + ) + assert response.status_code == status.HTTP_200_OK + response_user = UserOutExtended.model_validate_json(response.content) + + user: User | None = await db.scalar(select(User).where(User.uid == response_user.uid)) + assert user is not None + assert user.invitation_token is not None + assert user.invitation_token != token + assert user.invitation_token_created_at is not None + assert user.invitation_token_created_at != token_timestamp + class TestUserRoutesCreate(_TestUserRoutes): @pytest.mark.asyncio @@ -270,6 +340,11 @@ class TestUserRoutesCreate(_TestUserRoutes): assert user.roles[0] == user_in.roles[0] assert user.display_name == user_in.display_name + db_user = await db.scalar(select(User).where(User.uid == user.uid)) + assert db_user is not None + assert db_user.invitation_token is not None + assert db_user.invitation_token_created_at is not None + @pytest.mark.asyncio async def test_create_user_without_roles( self, client: AsyncClient, random_user: UserWithAuthHeader, db: AsyncSession, cleanup: CleanupList @@ -310,3 +385,8 @@ class TestUserRoutesCreate(_TestUserRoutes): assert len(user.roles) == 0 assert user.lifescience_id is None assert user.display_name == user_in.display_name + + db_user = await db.scalar(select(User).where(User.uid == user.uid)) + assert db_user is not None + assert db_user.invitation_token is not None + assert db_user.invitation_token_created_at is not None diff --git a/clowm/tests/crud/test_resource_version.py b/clowm/tests/crud/test_resource_version.py index 63ff188..4e179b0 100644 --- a/clowm/tests/crud/test_resource_version.py +++ b/clowm/tests/crud/test_resource_version.py @@ -1,36 +1,85 @@ from uuid import uuid4 import pytest -from sqlalchemy import select +from sqlalchemy import delete, select +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from clowm.crud import CRUDResourceVersion -from clowm.models import ResourceVersion +from clowm.models import Resource, ResourceVersion +from clowm.tests.utils import CleanupList, random_lower_string class TestResourceVersionCRUDGet: @pytest.mark.asyncio - async def test_get_resource_version( - self, - db: AsyncSession, - random_resource_version: ResourceVersion, + async def test_get_resource_version(self, db: AsyncSession, random_resource_version: ResourceVersion) -> None: + """ + Test for getting an existing resource version from the database + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + random_resource_version : clowm.models.ResourceVersion + Random resource for testing. + """ + resource = await CRUDResourceVersion.get(db, resource_version_id=random_resource_version.resource_version_id) + assert resource is not None + assert resource == random_resource_version + + @pytest.mark.asyncio + async def test_get_resource_version_with_resource_id( + self, db: AsyncSession, random_resource_version: ResourceVersion ) -> None: """ - Test for getting a resource version based on the resource and resource version id. + Test for getting an existing resource version from the database and filter is with the correct resource id Parameters ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. random_resource_version : clowm.models.ResourceVersion - Random resource version for testing. + Random resource for testing. """ - version = await CRUDResourceVersion.get( - resource_id=random_resource_version.resource_id, + resource = await CRUDResourceVersion.get( + db, resource_version_id=random_resource_version.resource_version_id, - db=db, + resource_id=random_resource_version.resource_id, ) - assert version == random_resource_version + assert resource is not None + assert resource == random_resource_version + + @pytest.mark.asyncio + async def test_get_resource_version_with_wrong_resource_id( + self, db: AsyncSession, random_resource_version: ResourceVersion + ) -> None: + """ + Test for getting an existing resource version from the database and filter it with a wrong resource id + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + random_resource_version : clowm.models.ResourceVersion + Random resource for testing. + """ + resource = await CRUDResourceVersion.get( + db, resource_version_id=random_resource_version.resource_id, resource_id=uuid4() + ) + assert resource is None + + @pytest.mark.asyncio + async def test_get_non_existing_resource_version(self, db: AsyncSession) -> None: + """ + Test for getting a non-existing resource version from the database + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + """ + resource = await CRUDResourceVersion.get(db, resource_version_id=uuid4()) + assert resource is None @pytest.mark.asyncio async def test_get_resource_version_latest( @@ -54,42 +103,122 @@ class TestResourceVersionCRUDGet: ) assert version == random_resource_version + +class TestResourceVersionCRUDUpdate: @pytest.mark.asyncio - async def test_get_resource_version_with_wrong_id( + async def test_update_resource_version( self, db: AsyncSession, random_resource_version: ResourceVersion, + resource_state: ResourceVersion.ResourceVersionStatus, ) -> None: """ - Test for getting a non-existing resource version. + Test for updating the resource version status from the CRUD Repository. Parameters ---------- db : sqlalchemy.ext.asyncio.AsyncSession. Async database session to perform query on. random_resource_version : clowm.models.ResourceVersion - Random resource version for testing. + Random resource for testing. + """ + if resource_state == ResourceVersion.ResourceVersionStatus.LATEST: + pytest.skip("Separate test for that") + await CRUDResourceVersion.update_status( + db, + resource_version_id=random_resource_version.resource_version_id, + status=resource_state, + slurm_job_id=10, + ) + + updated_resource_version = await db.scalar( + select(ResourceVersion).where( + ResourceVersion.resource_version_id == random_resource_version.resource_version_id + ) + ) + assert updated_resource_version is not None + assert updated_resource_version == random_resource_version + + assert updated_resource_version.status == resource_state + + @pytest.mark.asyncio + async def test_update_resource_version_to_latest_without_resource_id( + self, db: AsyncSession, random_resource_version: ResourceVersion + ) -> None: """ - assert ( - await CRUDResourceVersion.get( - resource_id=uuid4(), + Test for updating the resource version status from the CRUD Repository. + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + random_resource_version : clowm.models.ResourceVersion + Random resource for testing. + """ + with pytest.raises(ValueError): + await CRUDResourceVersion.update_status( + db, resource_version_id=random_resource_version.resource_version_id, - db=db, + status=ResourceVersion.ResourceVersionStatus.LATEST, ) - is None + + @pytest.mark.asyncio + async def test_update_resource_version_to_latest_with_resource_id( + self, db: AsyncSession, random_resource_version: ResourceVersion + ) -> None: + """ + Test for updating the resource version status from the CRUD Repository. + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + random_resource_version : clowm.models.ResourceVersion + Random resource for testing. + """ + await CRUDResourceVersion.update_status( + db, + resource_version_id=random_resource_version.resource_version_id, + status=ResourceVersion.ResourceVersionStatus.LATEST, + resource_id=random_resource_version.resource_id, + ) + updated_resource_version = await db.scalars( + select(ResourceVersion) + .where(ResourceVersion.resource_id == random_resource_version.resource_id) + .where(ResourceVersion.status == ResourceVersion.ResourceVersionStatus.LATEST.name) ) + assert sum(1 for _ in updated_resource_version) == 1 - assert ( - await CRUDResourceVersion.get( - resource_id=random_resource_version.resource_id, - resource_version_id=uuid4(), - db=db, + @pytest.mark.asyncio + async def test_update_non_existing_resource_version( + self, db: AsyncSession, random_resource_version: ResourceVersion + ) -> None: + """ + Test for updating a non-existing resource version status from the CRUD Repository. + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + random_resource_version : clowm.models.ResourceVersion + Random resource for testing. + """ + await CRUDResourceVersion.update_status( + db, + resource_version_id=uuid4(), + status=ResourceVersion.ResourceVersionStatus.S3_DELETED, + ) + + resource_version = await db.scalar( + select(ResourceVersion).where( + ResourceVersion.resource_version_id == random_resource_version.resource_version_id ) - is None ) + assert resource_version is not None + assert resource_version == random_resource_version + assert resource_version.status == random_resource_version.status -class TestResourceVersionCRUDUpdate: @pytest.mark.asyncio async def test_update_used_resource_version( self, @@ -119,3 +248,60 @@ class TestResourceVersionCRUDUpdate: assert db_version == random_resource_version assert db_version.times_used == 1 assert db_version.last_used_timestamp is not None + + +class TestResourceVersionCRUDCreate: + @pytest.mark.asyncio + async def test_create_resource_version( + self, db: AsyncSession, random_resource: Resource, cleanup: CleanupList + ) -> None: + """ + Test for creating a new resource version with the CRUD Repository. + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + random_resource : clowm.models.Resource + Random resource for testing. + cleanup : clowm.tests.utils.utils.CleanupList + Cleanup object where (async) functions can be registered which get executed after a (failed) test. + """ + release = random_lower_string(8) + + resource_version = await CRUDResourceVersion.create( + db, resource_id=random_resource.resource_id, release=release + ) + + assert resource_version is not None + + async def delete_resource_version() -> None: + await db.execute( + delete(ResourceVersion).where( + ResourceVersion.resource_version_id == resource_version.resource_version_id + ) + ) + await db.commit() + + cleanup.add_task(delete_resource_version) + + created_resource_version = await db.scalar( + select(ResourceVersion).where(ResourceVersion.resource_version_id == resource_version.resource_version_id) + ) + assert created_resource_version is not None + assert created_resource_version == resource_version + + assert resource_version.status == ResourceVersion.ResourceVersionStatus.RESOURCE_REQUESTED + + @pytest.mark.asyncio + async def test_create_resource_version_with_wrong_resource_id(self, db: AsyncSession) -> None: + """ + Test for creating a new resource version with a wrong resource id the CRUD Repository. + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + """ + with pytest.raises(IntegrityError): + await CRUDResourceVersion.create(db, resource_id=uuid4(), release=random_lower_string(8)) diff --git a/clowm/tests/crud/test_user.py b/clowm/tests/crud/test_user.py index e29b24c..79d0dce 100644 --- a/clowm/tests/crud/test_user.py +++ b/clowm/tests/crud/test_user.py @@ -1,8 +1,10 @@ import random +import secrets +import time import uuid import pytest -from sqlalchemy import delete, select +from sqlalchemy import delete, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload @@ -143,9 +145,9 @@ class TestUserCRUDUpdate: ) @pytest.mark.asyncio - async def test_update_registered_user(self, db: AsyncSession, random_second_user: UserWithAuthHeader) -> None: + async def test_update_invited_user(self, db: AsyncSession, random_second_user: UserWithAuthHeader) -> None: """ - Test for marking a user initialized in the User CRUD Repository. + Test for updating an invited user the User CRUD Repository. Parameters ---------- @@ -157,23 +159,27 @@ class TestUserCRUDUpdate: new_name = random_lower_string() new_lifescience_id = random_lower_string() - await CRUDUser.update_registered_user( - random_second_user.user.uid, display_name=new_name, lifescience_id=new_lifescience_id, db=db + await CRUDUser.update_invited_user( + uid=random_second_user.user.uid, + display_name=new_name, + lifescience_id=new_lifescience_id, + email=random_second_user.user.email, + db=db, ) db_user = await db.scalar(select(User).where(User.uid == random_second_user.user.uid)) - - assert db_user + await db.refresh(db_user, attribute_names=["invitation_token", "invitation_token_created_at"]) + assert db_user is not None assert db_user.uid == random_second_user.user.uid assert db_user.display_name == new_name assert db_user.lifescience_id == new_lifescience_id + assert db_user.invitation_token is None + assert db_user.invitation_token_created_at is None - -class TestUserCRUDGet: @pytest.mark.asyncio - async def test_get_user_by_id(self, db: AsyncSession, random_user: UserWithAuthHeader) -> None: + async def test_create_invitation_token(self, db: AsyncSession, random_user: UserWithAuthHeader) -> None: """ - Test for getting a user by id from the User CRUD Repository. + Test creating an invitation token for a user. Parameters ---------- @@ -182,15 +188,21 @@ class TestUserCRUDGet: random_user : clowm.tests.utils.UserWithAuthHeader Random user for testing. """ - user = await CRUDUser.get(random_user.user.uid, db=db) - assert user - assert random_user.user.uid == user.uid - assert random_user.user.display_name == user.display_name + token = await CRUDUser.create_invitation_token(random_user.user.uid, db=db) + + db_user = await db.scalar(select(User).where(User.uid == random_user.user.uid)) + # await db.refresh(db_user, attribute_names=["invitation_token", "invitation_token_created_at"]) + assert db_user is not None + assert db_user.invitation_token == token + assert db_user.invitation_token_created_at is not None + + +class TestUserCRUDGet: @pytest.mark.asyncio - async def test_get_user_by_lifescience_id(self, db: AsyncSession, random_user: UserWithAuthHeader) -> None: + async def test_get_user_by_id(self, db: AsyncSession, random_user: UserWithAuthHeader) -> None: """ - Test for getting a user by lifescience id from the User CRUD Repository. + Test for getting a user by id from the User CRUD Repository. Parameters ---------- @@ -199,16 +211,15 @@ class TestUserCRUDGet: random_user : clowm.tests.utils.UserWithAuthHeader Random user for testing. """ - assert random_user.user.lifescience_id is not None - user = await CRUDUser.get_by_lifescience_id(random_user.user.lifescience_id, db=db) + user = await CRUDUser.get(random_user.user.uid, db=db) assert user assert random_user.user.uid == user.uid assert random_user.user.display_name == user.display_name @pytest.mark.asyncio - async def test_get_user_by_email(self, db: AsyncSession, random_user: UserWithAuthHeader) -> None: + async def test_get_user_by_lifescience_id(self, db: AsyncSession, random_user: UserWithAuthHeader) -> None: """ - Test for getting a user by smtp from the User CRUD Repository. + Test for getting a user by lifescience id from the User CRUD Repository. Parameters ---------- @@ -217,24 +228,11 @@ class TestUserCRUDGet: random_user : clowm.tests.utils.UserWithAuthHeader Random user for testing. """ - user = await CRUDUser.get_by_email(random_user.user.email, db=db) + assert random_user.user.lifescience_id is not None + user = await CRUDUser.get_by_lifescience_id(random_user.user.lifescience_id, db=db) assert user assert random_user.user.uid == user.uid assert random_user.user.display_name == user.display_name - assert user.email == random_user.user.email - - @pytest.mark.asyncio - async def test_get_none_by_email(self, db: AsyncSession) -> None: - """ - Test for getting none by a none smtp. - - Parameters - ---------- - db : sqlalchemy.ext.asyncio.AsyncSession. - Async database session to perform query on. - """ - user = await CRUDUser.get_by_email(None, db=db) - assert user is None @pytest.mark.asyncio async def test_get_unknown_user_by_id( @@ -288,6 +286,42 @@ class TestUserCRUDGet: users = await CRUDUser.list_users(name_substring=2 * random_user.user.display_name, db=db) assert sum(1 for u in users if u.uid == random_user.user.uid) == 0 + @pytest.mark.asyncio + async def test_get_by_invitation_token(self, db: AsyncSession, random_user: UserWithAuthHeader) -> None: + """ + Test for getting a user by an invitation token from the database. + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + random_user : clowm.tests.utils.UserWithAuthHeader + Random user for testing. + """ + token = secrets.token_urlsafe(32) + await db.execute( + update(User) + .where(User.uid == random_user.user.uid) + .values(invitation_token=token, invitation_token_created_at=round(time.time())) + ) + await db.commit() + user = await CRUDUser.get_by_invitation_token(token, db=db) + assert user is not None + assert user.uid == random_user.user.uid + + @pytest.mark.asyncio + async def test_get_by_non_existing_invitation_token(self, db: AsyncSession) -> None: + """ + Test for getting a user by a non-existing invitation token from the database. + + Parameters + ---------- + db : sqlalchemy.ext.asyncio.AsyncSession. + Async database session to perform query on. + """ + user = await CRUDUser.get_by_invitation_token(secrets.token_urlsafe(32), db=db) + assert user is None + class TestUserCRUDList: @pytest.mark.asyncio diff --git a/migrations/versions/c0ceb4bf69e3_add_invitation_token_for_user.py b/migrations/versions/c0ceb4bf69e3_add_invitation_token_for_user.py new file mode 100644 index 0000000..1645fd8 --- /dev/null +++ b/migrations/versions/c0ceb4bf69e3_add_invitation_token_for_user.py @@ -0,0 +1,34 @@ +"""Add invitation token for user + +Revision ID: c0ceb4bf69e3 +Revises: 3a302d32b0ad +Create Date: 2024-05-23 09:25:13.257546 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +import clowm.db.types + +# revision identifiers, used by Alembic. +revision: str = "c0ceb4bf69e3" +down_revision: Union[str, None] = "3a302d32b0ad" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("user", sa.Column("invitation_token", clowm.db.types.Token(length=32), nullable=True)) + op.add_column("user", sa.Column("invitation_token_created_at", sa.Integer(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("user", "invitation_token_created_at") + op.drop_column("user", "invitation_token") + # ### end Alembic commands ### -- GitLab