From d31cbb279413ef875c18e9617095e4e75e8aff2d Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Mon, 26 Sep 2022 15:03:25 +0200
Subject: [PATCH] mu map dataset now return torch tensors with channel
 dimension

---
 mu_map/data/datasets.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/mu_map/data/datasets.py b/mu_map/data/datasets.py
index 10563b1..88b3168 100644
--- a/mu_map/data/datasets.py
+++ b/mu_map/data/datasets.py
@@ -5,6 +5,7 @@ import cv2 as cv
 import pandas as pd
 import pydicom
 import numpy as np
+import torch
 from torch.utils.data import Dataset
 
 from mu_map.data.prepare import headers
@@ -86,12 +87,20 @@ class MuMapDataset(Dataset):
                 bed_contour = self.bed_contours[row["id"]]
                 for i in range(mu_map.shape[0]):
                     mu_map[i] = cv.drawContours(mu_map[i], [bed_contour], -1, 0.0, -1)
-            self.mu_maps[_id] = mu_map
 
             recon_file = os.path.join(self.dir_images, row[headers.file_recon_nac_nsc])
             recon = pydicom.dcmread(recon_file).pixel_array
             if self.align:
                 recon = align_images(recon, mu_map)
+
+            mu_map = mu_map.astype(np.float32)
+            mu_map = torch.from_numpy(mu_map)
+            mu_map = mu_map.unsqueeze(dim=0)
+            self.mu_maps[_id] = mu_map
+
+            recon = recon.astype(np.float32)
+            recon = torch.from_numpy(recon)
+            recon = recon.unsqueeze(dim=0)
             self.reconstructions[_id] = recon
         print("Pre-loading images done!")
 
@@ -194,6 +203,8 @@ if __name__ == "__main__":
         im = 0
 
         recon, mu_map = dataset[i]
+        recon = recon.squeeze().numpy()
+        mu_map = mu_map.squeeze().numpy()
         print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)}", end="\r")
 
         cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
@@ -206,8 +217,6 @@ if __name__ == "__main__":
 
             to_show = combine_images((recon, mu_map), (ir, im))
             cv.imshow(wname, to_show)
-
-
             key = cv.waitKey(timeout)
 
             if key == ord("n"):
-- 
GitLab