From 31f91aaecb75189e1c772d9b2b84f98175253bb1 Mon Sep 17 00:00:00 2001
From: Patrick Jentsch <p.jentsch@uni-bielefeld.de>
Date: Tue, 17 May 2022 16:16:31 +0200
Subject: [PATCH] Fix model selection

---
 app/services/forms.py  | 18 ++++++++----------
 app/services/routes.py | 12 +++++++++---
 2 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/app/services/forms.py b/app/services/forms.py
index c35621db..1a6a743c 100644
--- a/app/services/forms.py
+++ b/app/services/forms.py
@@ -69,16 +69,12 @@ class AddTesseractOCRPipelineJobForm(AddJobForm):
             if 'binarization' in service_info['methods']:
                 if 'disabled' in self.binarization.render_kw:
                     del self.binarization.render_kw['disabled']
-        compatible_models = [
-            x for x in TesseractOCRModel.query.filter_by(shared=True).all()
-            if version in x.compatible_service_versions
-        ]
-        compatible_models += [
-            x for x in TesseractOCRModel.query.filter_by(shared=False, user=current_user).all()
-            if version in x.compatible_service_versions
+        models = [
+            x for x in TesseractOCRModel.query.filter().all()
+            if version in x.compatible_service_versions and (x.shared == True or x.user == current_user)
         ]
         self.model.choices = [('', 'Choose your option')]
-        self.model.choices += [(x.hashid, x.title) for x in compatible_models]
+        self.model.choices += [(x.hashid, x.title) for x in models]
         self.model.default = ''
         self.version.choices = [(x, x) for x in service_manifest['versions']]
         self.version.data = version
@@ -115,8 +111,10 @@ class AddTranskribusHTRPipelineJobForm(AddJobForm):
             if 'binarization' in service_info['methods']:
                 if 'disabled' in self.binarization.render_kw:
                     del self.binarization.render_kw['disabled']
-        models = TranskribusHTRModel.query.filter_by(shared=True).all()
-        models += TranskribusHTRModel.query.filter_by(shared=False, user=current_user).all()
+        models = [
+            x for x in TranskribusHTRModel.query.filter().all()
+            if version in x.compatible_service_versions and (x.shared == True or x.user == current_user)
+        ]
         self.model.choices = [('', 'Choose your option')]
         self.model.choices += [(x.hashid, x.transkribus_name) for x in models]
         self.model.default = ''
diff --git a/app/services/routes.py b/app/services/routes.py
index 638ff1cf..d09bda84 100644
--- a/app/services/routes.py
+++ b/app/services/routes.py
@@ -140,7 +140,11 @@ def tesseract_ocr_pipeline():
         db.session.commit()
         flash(f'Job "{job.title}" added', 'job')
         return make_response({'redirect_url': url_for('jobs.job', job_id=job.id)}, 201)  # noqa
-    tesseract_ocr_models = TesseractOCRModel.query.all()
+    tesseract_ocr_models = [
+        x for x in TesseractOCRModel.query.filter().all()
+        if version in x.compatible_service_versions and (x.shared == True or x.user == current_user)
+    ]
+    current_app.logger.warning(tesseract_ocr_models)
     return render_template(
         'services/tesseract_ocr_pipeline.html.j2',
         form=form,
@@ -204,8 +208,10 @@ def transkribus_htr_pipeline():
         db.session.commit()
         flash(f'Job "{job.title}" added', 'job')
         return make_response({'redirect_url': url_for('jobs.job', job_id=job.id)}, 201)  # noqa
-    transkribus_htr_models = TranskribusHTRModel.query.filter_by(shared=True).all()
-    transkribus_htr_models += TranskribusHTRModel.query.filter_by(shared=False, user=current_user).all()
+    transkribus_htr_models = [
+        x for x in TranskribusHTRModel.query.filter().all()
+        if version in x.compatible_service_versions and (x.shared == True or x.user == current_user)
+    ]
     return render_template(
         f'services/transkribus_htr_pipeline.html.j2',
         form=form,
-- 
GitLab