From cd45562c7257479ce7c92035de0a83f334be9bf4 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Mon, 26 Sep 2022 15:03:39 +0200
Subject: [PATCH] update patch dataset to mu map dataset changes

---
 mu_map/data/patch_dataset.py | 30 +++++++++++++++++-------------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/mu_map/data/patch_dataset.py b/mu_map/data/patch_dataset.py
index 6faf07a..934ce82 100644
--- a/mu_map/data/patch_dataset.py
+++ b/mu_map/data/patch_dataset.py
@@ -2,6 +2,7 @@ import math
 import random
 
 import numpy as np
+import torch
 
 from mu_map.data.datasets import MuMapDataset
 
@@ -19,8 +20,8 @@ class MuMapPatchDataset(MuMapDataset):
 
     def generate_patches(self):
         for _id in self.reconstructions:
-            recon = self.reconstructions[_id]
-            mu_map = self.mu_maps[_id]
+            recon = self.reconstructions[_id].squeeze()
+            mu_map = self.mu_maps[_id].squeeze()
 
             assert (
                 recon.shape[0] == mu_map.shape[0]
@@ -32,12 +33,11 @@ class MuMapPatchDataset(MuMapDataset):
             y_range = (20, recon.shape[1] - self.patch_size - 20)
             x_range = (20, recon.shape[2] - self.patch_size - 20)
 
-            padding = [(0, 0), (0, 0), (0, 0)]
+            padding = [0, 0, 0, 0, 0, 0, 0, 0]
             if recon.shape[0] < self.patch_size:
                 diff = self.patch_size - recon.shape[0]
-                padding_bef = math.ceil(diff / 2)
-                padding_aft = math.floor(diff / 2)
-                padding[0] = (padding_bef, padding_aft)
+                padding[4] = math.ceil(diff / 2)
+                padding[5] = math.floor(diff / 2)
 
             for j in range(self.patches_per_image):
                 z = random.randint(*z_range)
@@ -52,11 +52,11 @@ class MuMapPatchDataset(MuMapDataset):
         recon = self.reconstructions[_id]
         mu_map = self.mu_maps[_id]
 
-        recon = np.pad(recon, padding, mode="constant", constant_values=0)
-        mu_map = np.pad(mu_map, padding, mode="constant", constant_values=0)
+        recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0)
+        mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0)
 
-        recon = recon[z : z + s, y : y + s, x : x + s]
-        mu_map = mu_map[z : z + s, y : y + s, x : x + s]
+        recon = recon[:, z : z + s, y : y + s, x : x + s]
+        mu_map = mu_map[:, z : z + s, y : y + s, x : x + s]
 
         return recon, mu_map
 
@@ -81,8 +81,7 @@ if __name__ == "__main__":
         s = dataset.patch_size
         _id, _, y, x, padding = patch
 
-        _recon_orig = np.pad(recon_orig, padding, mode="constant", constant_values=0)
-        _recon_orig = _recon_orig[_slice]
+        _recon_orig = recon_orig[_slice]
         _recon_orig = to_grayscale(_recon_orig)
         _recon_orig = grayscale_to_rgb(_recon_orig)
         _recon_orig = cv.rectangle(_recon_orig, (x, y), (x + s, y + s), color=(255, 0, 0), thickness=1)
@@ -107,10 +106,15 @@ if __name__ == "__main__":
         patch = dataset.patches[i]
         _id, z, y, x, padding = patch
         print(
-            f"Patch {str(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - Location [{z:02d}, {y:02d}, {x:02d}] - Padding [{padding[0][0], padding[0][0]}]"
+            f"Patch {str(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - Location [{z:02d}, {y:02d}, {x:02d}] - Padding [{padding[5], padding[6]}]"
         )
         recon, mu_map = dataset[i]
+        recon = recon.squeeze().numpy()
+        mu_map = mu_map.squeeze().numpy()
+
         recon_orig = dataset.reconstructions[_id]
+        recon_orig = torch.nn.functional.pad(recon_orig, padding, mode="constant", value=0)
+        recon_orig = recon_orig.squeeze().numpy()
 
         cv.imshow(wname, create_image(recon, mu_map, recon_orig, patch, _i))
         key = cv.waitKey(100)
-- 
GitLab