From 52d016221cf712f93c577198925ace51d5a98bca Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Thu, 20 Oct 2022 13:01:23 +0200
Subject: [PATCH] add functions to remove/add bed to remove_bed module

---
 mu_map/data/remove_bed.py | 49 +++++++++++++++++++++++++++++++++------
 1 file changed, 42 insertions(+), 7 deletions(-)

diff --git a/mu_map/data/remove_bed.py b/mu_map/data/remove_bed.py
index b9c6e7b..79910ab 100644
--- a/mu_map/data/remove_bed.py
+++ b/mu_map/data/remove_bed.py
@@ -1,8 +1,11 @@
 import json
 from typing import Dict, List
 
+import cv2 as cv
 import numpy as np
 
+from mu_map.dataset.util import align_images
+
 
 DEFAULT_BED_CONTOURS_FILENAME = "bed_contours.json"
 
@@ -29,25 +32,57 @@ def load_contours(filename: str, as_ndarry: bool = True) -> Dict[int, np.ndarray
     return dict(_map)
 
 
-def scale_points(points: List[List[int]], scale: float):
-    for i in range(len(points)):
-        for j in range(len(points[i])):
-            points[i][j] = round(points[i][j] * scale)
+def remove_bed(mu_map: np.ndarray, bed_contour: np.ndarray):
+    """
+    Remove the bed defined by a contour from all slices.
 
-def remove_bed(mu_map: np.ndarray, contour: np.ndarray):
+    :param mu_map: the mu_map from which the bed is removed.
+    :param bed_contour: the contour describing where the bed is found
+    :return: the mu_map with the bed removed
+    """
     _mu_map = mu_map.copy()
     for i in range(_mu_map.shape[0]):
         mu_map[i] = cv.drawContours(_mu_map[i], [bed_contour], -1, 0.0, -1)
     return _mu_map
 
 
+def add_bed(without_bed: np.ndarray, with_bed: np.ndarray, bed_contour: np.ndarray):
+    """
+    Add the bed to every slice of a mu_map.
+
+    :param without_bed: the mu_map without the bed
+    :param with_bed: the mu_map with the bed
+    :param bed_contour: the contour defining the location of the bed
+    :return: the mu_map with the bed added
+    """
+    with_bed, without_bed = align_images(with_bed, without_bed)
+
+    for _slice in range(with_bed.shape[0]):
+        with_bed_i = with_bed[_slice]
+        without_bed_i = without_bed[_slice]
+
+        cnt_img = np.zeros(without_bed_i.shape, dtype=np.uint8)
+        cnt_img = cv.drawContours(cnt_img, [bed_contour], -1, 255, -1)
+
+        without_bed[_slice] = np.where(cnt_img > 0, with_bed_i, without_bed_i)
+
+    return without_bed
+
+
 if __name__ == "__main__":
+
+    def scale_points(points: List[List[int]], scale: float):
+        """
+        Utility function to scale all points in a list of points.
+        """
+        for i in range(len(points)):
+            for j in range(len(points[i])):
+                points[i][j] = round(points[i][j] * scale)
+
     import argparse
     from enum import Enum
     import os
 
-    import cv2 as cv
-
     from mu_map.data.prepare import headers
     from mu_map.dataset.default import MuMapDataset
     from mu_map.util import to_grayscale, COLOR_BLACK, COLOR_WHITE
-- 
GitLab