From 4fb5f2f2dcb159144bf1b33c99d5ad24a001d3cc Mon Sep 17 00:00:00 2001
From: Patrick Jentsch <p.jentsch@uni-bielefeld.de>
Date: Mon, 6 Mar 2023 15:02:46 +0100
Subject: [PATCH] Add content negotiation related route decorators

---
 app/decorators.py | 56 ++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 55 insertions(+), 1 deletion(-)

diff --git a/app/decorators.py b/app/decorators.py
index 47e6d749..5fa3d82b 100644
--- a/app/decorators.py
+++ b/app/decorators.py
@@ -1,7 +1,9 @@
-from flask import abort, current_app
+from flask import abort, current_app, request
 from flask_login import current_user
 from functools import wraps
 from threading import Thread
+from typing import List, Union
+from werkzeug.exceptions import NotAcceptable
 from app.models import Permission
 
 
@@ -61,3 +63,55 @@ def background(f):
         thread.start()
         return thread
     return wrapped
+
+
+def consumes(mime_type: str, *_mime_types: str):
+    def decorator(f):
+        @wraps(f)
+        def decorated_function(*args, **kwargs):
+            provided = request.mimetype
+            consumeables = {mime_type, *_mime_types}
+            if provided not in consumeables:
+                raise NotAcceptable()
+            return f(*args, **kwargs)
+        return decorated_function
+    return decorator
+
+
+def produces(mime_type: str, *_mime_types: str):
+    def decorator(f):
+        @wraps(f)
+        def decorated_function(*args, **kwargs):
+            accepted = {*request.accept_mimetypes.values()}
+            produceables = {mime_type, *_mime_types}
+            if len(produceables & accepted) == 0:
+                raise NotAcceptable()
+            return f(*args, **kwargs)
+        return decorated_function
+    return decorator
+
+
+def content_negotiation(
+    produces: Union[str, List[str]],
+    consumes: Union[str, List[str]]
+):
+    def decorator(f):
+        @wraps(f)
+        def decorated_function(*args, **kwargs):
+            provided = request.mimetype
+            if isinstance(consumes, str):
+                consumeables = {consumes}
+            else:
+                consumeables = {*consumes}
+            accepted = {*request.accept_mimetypes.values()}
+            if isinstance(produces, str):
+                produceables = {produces}
+            else:
+                produceables = {*produces}
+            if len(produceables & accepted) == 0:
+                raise NotAcceptable()
+            if provided not in consumeables:
+                raise NotAcceptable()
+            return f(*args, **kwargs)
+        return decorated_function
+    return decorator
-- 
GitLab