From 1de9fce85414aa2545456dbf81fa0119dc34f97b Mon Sep 17 00:00:00 2001
From: Patrick Jentsch <p.jentsch@uni-bielefeld.de>
Date: Mon, 18 Jul 2022 17:10:09 +0200
Subject: [PATCH] Switch from Authlib to pyJWT

---
 app/auth/routes.py |   8 ++--
 app/models.py      | 114 ++++++++++++++++++++++++++++-----------------
 requirements.txt   |   2 +-
 3 files changed, 75 insertions(+), 49 deletions(-)

diff --git a/app/auth/routes.py b/app/auth/routes.py
index 045d0baf..6897b088 100644
--- a/app/auth/routes.py
+++ b/app/auth/routes.py
@@ -91,7 +91,7 @@ def register():
             db.session.rollback()
             flash('Internal Server Error', category='error')
             abort(500)
-        token = user.generate_confirmation_token()
+        token = user.generate_confirm_user_token()
         msg = create_message(
             user.email,
             'Confirm Your Account',
@@ -115,7 +115,7 @@ def register():
 def confirm(token):
     if current_user.confirmed:
         return redirect(url_for('main.dashboard'))
-    if current_user.confirm(token):
+    if current_user.confirm_user(token):
         db.session.commit()
         flash('You have confirmed your account')
         return redirect(url_for('main.dashboard'))
@@ -139,7 +139,7 @@ def unconfirmed():
 @bp.route('/confirm')
 @login_required
 def resend_confirmation():
-    token = current_user.generate_confirmation_token()
+    token = current_user.generate_confirm_user_token()
     msg = create_message(
         current_user.email,
         'Confirm Your Account',
@@ -160,7 +160,7 @@ def reset_password_request():
     if form.validate_on_submit():
         user = User.query.filter_by(email=form.email.data.lower()).first()
         if user is not None:
-            token = user.generate_reset_token()
+            token = user.generate_password_reset_token()
             msg = create_message(
                 user.email,
                 'Reset Your Password',
diff --git a/app/models.py b/app/models.py
index 2eff4b21..7e2728bd 100644
--- a/app/models.py
+++ b/app/models.py
@@ -1,7 +1,6 @@
-from app import db, hashids, login, mail, socketio
+from app import db, login, mail, socketio
 from app.converters.vrt import normalize_vrt_file
 from app.email import create_message
-from authlib.jose import jwt, JoseError
 from datetime import datetime, timedelta
 from enum import Enum, IntEnum
 from flask import current_app, url_for
@@ -12,10 +11,10 @@ from tqdm import tqdm
 from werkzeug.security import generate_password_hash, check_password_hash
 import base64
 import json
+import jwt
 import os
 import requests
 import shutil
-import time
 import xml.etree.ElementTree as ET
 import yaml
 
@@ -213,6 +212,29 @@ class Role(HashidMixin, db.Model):
         db.session.commit()
 
 
+class Token(db.Model):
+    __tablename__ = 'tokens'
+    # Primary key
+    id = db.Column(db.Integer, primary_key=True)
+    # Foreign keys
+    user_id = db.Column(db.Integer, db.ForeignKey('users.id'))
+    # Fields
+    access_token = db.Column(db.String(64), nullable=False, index=True)
+    access_expiration = db.Column(db.DateTime, nullable=False)
+    refresh_token = db.Column(db.String(64), nullable=False, index=True)
+    refresh_expiration = db.Column(db.DateTime, nullable=False)
+
+    # def generate(self):
+    #     header = {'alg': 'HS256', 'exp': int(time.time()) + expiration}
+    #     payload = {'confirm': self.hashid}
+    #     return jwt.encode(header, payload, current_app.config['SECRET_KEY'])
+    #     self.access_token = secrets.token_urlsafe()
+    #     self.access_expiration = datetime.utcnow() + \
+    #         timedelta(minutes=current_app.config['ACCESS_TOKEN_MINUTES'])
+    #     self.refresh_token = secrets.token_urlsafe()
+    #     self.refresh_expiration = datetime.utcnow() + \
+    #         timedelta(days=current_app.config['REFRESH_TOKEN_DAYS'])
+
 class User(HashidMixin, UserMixin, db.Model):
     __tablename__ = 'users'
     # Primary key
@@ -292,17 +314,20 @@ class User(HashidMixin, UserMixin, db.Model):
     def can(self, permission):
         return self.role.has_permission(permission)
 
-    def confirm(self, token):
-        # s = TimedJSONWebSignatureSerializer(current_app.config['SECRET_KEY'])
-        # try:
-        #     data = s.loads(token.encode('utf-8'))
-        # except BadSignature:
-        #     return False
+    def confirm_user(self, token):
         try:
-            data = jwt.decode(token, current_app.config['SECRET_KEY'])
-        except JoseError:
+            payload = jwt.decode(
+                token,
+                current_app.config['SECRET_KEY'],
+                algorithms=['HS256'],
+                issuer=current_app.config['SERVER_NAME'],
+                options={'require': ['exp', 'iat', 'iss', 'purpose', 'sub']}
+            )
+        except jwt.PyJWTError:
             return False
-        if data.get('confirm') != self.hashid:
+        if payload.get('purpose') != 'confirm_user':
+            return False
+        if payload.get('sub') != self.id:
             return False
         self.confirmed = True
         db.session.add(self)
@@ -312,22 +337,27 @@ class User(HashidMixin, UserMixin, db.Model):
         shutil.rmtree(self.path, ignore_errors=True)
         db.session.delete(self)
 
-    def generate_confirmation_token(self, expiration=3600):
-        # s = TimedJSONWebSignatureSerializer(
-        #     current_app.config['SECRET_KEY'], expiration)
-        # return s.dumps({'confirm': self.hashid}).decode('utf-8')
-        header = {'alg': 'HS256', 'exp': int(time.time()) + expiration}
-        payload = {'confirm': self.hashid}
-        return jwt.encode(header, payload, current_app.config['SECRET_KEY'])
-
-
-    def generate_reset_token(self, expiration=3600):
-        # s = TimedJSONWebSignatureSerializer(
-        #     current_app.config['SECRET_KEY'], expiration)
-        # return s.dumps({'reset': self.hashid}).decode('utf-8')
-        header = {'alg': 'HS256', 'exp': int(time.time()) + expiration}
-        payload = {'reset': self.hashid}
-        return jwt.encode(header, payload, current_app.config['SECRET_KEY'])
+    def generate_confirm_user_token(self, expiration=3600):
+        utc_now = datetime.utcnow()
+        payload = {
+            'exp': utc_now + timedelta(seconds=expiration),
+            'iat': utc_now,
+            'iss': current_app.config['SERVER_NAME'],
+            'purpose': 'confirm_user',
+            'sub': self.id
+        }
+        return jwt.encode(payload, current_app.config['SECRET_KEY'], algorithm='HS256')
+
+    def generate_password_reset_token(self, expiration=3600):
+        utc_now = datetime.utcnow()
+        payload = {
+            'exp': utc_now + timedelta(seconds=expiration),
+            'iat': utc_now,
+            'iss': current_app.config['SERVER_NAME'],
+            'purpose': 'reset_password',
+            'sub': self.id
+        }
+        return jwt.encode(payload, current_app.config['SECRET_KEY'], algorithm='HS256')
 
     def get_token(self, expires_in=3600):
         now = datetime.utcnow()
@@ -410,25 +440,21 @@ class User(HashidMixin, UserMixin, db.Model):
 
     @staticmethod
     def reset_password(token, new_password):
-        # s = TimedJSONWebSignatureSerializer(current_app.config['SECRET_KEY'])
-        # try:
-        #     data = s.loads(token.encode('utf-8'))
-        # except BadSignature:
-        #     return False
-        # user = User.query.get(data.get('reset'))
-        # if user is None:
-        #     return False
-        # user.password = new_password
-        # db.session.add(user)
-        # return True
         try:
-            data = jwt.decode(token, current_app.config['SECRET_KEY'])
-        except JoseError:
+            payload = jwt.decode(
+                token,
+                current_app.config['SECRET_KEY'],
+                algorithms=['HS256'],
+                issuer=current_app.config['SERVER_NAME'],
+                options={'require': ['exp', 'iat', 'iss', 'purpose', 'sub']}
+            )
+        except jwt.PyJWTError:
+            return False
+        if payload.get('purpose') != 'reset_password':
             return False
-        user_hashid = data.get('reset')
-        if user_hashid is None:
+        user_id = payload.get('sub')
+        if user_id is None:
             return False
-        user_id = hashids.decode(user_hashid)
         user = User.query.get(user_id)
         if user is None:
             return False
diff --git a/requirements.txt b/requirements.txt
index 9a4bc2c2..f2962f70 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,3 @@
-Authlib
 cqi
 docker
 eventlet
@@ -18,6 +17,7 @@ Flask-WTF
 hiredis
 MarkupSafe==2.0.1
 psycopg2
+PyJWT
 pyScss
 python-dotenv
 pyyaml
-- 
GitLab