From e53078200381b9aa4f7f121b08eb6e5007c3831c Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Tue, 16 Aug 2022 16:21:50 +0200
Subject: [PATCH] implement the unet model

---
 mu_map/models.py      |   0
 mu_map/models/unet.py | 175 ++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 175 insertions(+)
 delete mode 100644 mu_map/models.py
 create mode 100644 mu_map/models/unet.py

diff --git a/mu_map/models.py b/mu_map/models.py
deleted file mode 100644
index e69de29..0000000
diff --git a/mu_map/models/unet.py b/mu_map/models/unet.py
new file mode 100644
index 0000000..e6e0bbc
--- /dev/null
+++ b/mu_map/models/unet.py
@@ -0,0 +1,175 @@
+from typing import Optional, List
+
+import torch
+import torch.nn as nn
+
+
+class TwoConv(nn.Sequential):
+    """
+    Combine two convolutions with ReLU activations as a sequential module.
+    Optionally, batch normalization and dropout can be added.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        batch_norm: bool = True,
+        dropout: Optional[float] = None,
+    ):
+        """
+        Create a sequential module consisting of two convolutions with ReLU activations.
+
+        :param in_channels: number of input channels of the first convolution
+        :param out_channels: number of feature channels computed by both convolutions
+        :param batch_norm: whether batch normalization is applied after each convolution
+        :param dropout: optional dropout probability for a dropout layer between the two convolutions
+        """
+        super().__init__()
+
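+        # first convolution: in_channels -> out_channels; "same" padding keeps the spatial size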
+        self.append(
+            nn.Conv3d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=3,
+                stride=1,
+                padding="same",
+            )
+        )
+        if batch_norm:
+            self.append(nn.BatchNorm3d(num_features=out_channels))
+        self.append(nn.ReLU(inplace=True))
+
+        if dropout is not None:
+            self.append(nn.Dropout3d(p=dropout))
+
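+        # second convolution: keeps the number of channels at out_channels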
+        self.append(
+            nn.Conv3d(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                kernel_size=3,
+                stride=1,
+                padding="same",
+            )
+        )
+        if batch_norm:
+            self.append(nn.BatchNorm3d(num_features=out_channels))
+        self.append(nn.ReLU(inplace=True))
+
+
+class UNet(nn.Module):
+    """
+    A UNet for three-dimensional inputs as used in the paper by Shi et al. (2020).
+    Differences from the original UNet are:
+      * convolutions use "same" padding so that spatial dimensions are preserved,
+      * batch normalization is applied after each convolution, and
+      * dropout is applied to the bottleneck layer.
+    """
+
+    def __init__(
+        self,
+        in_channels: int = 1,
+        out_channels: int = 1,
+        features: List[int] = [64, 128, 256, 512],
+        batch_norm: bool = True,
+        dropout: Optional[float] = 0.15,
+    ):
+        """
+        Initialize the UNet.
+
+        :param in_channels: number of input channels
+        :param out_channels: number of output channels
+        :param features: number of feature channels computed at each level of the encoder/decoder
+        :param batch_norm: whether batch normalization is added after each convolution
+        :param dropout: optional dropout probability applied at the bottleneck layer
+        """
+        super().__init__()
+
+        self.features = features
+
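+        # level indices of the encoder/decoder; the last entry of features is used by the bottleneck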
+        self.layers = list(range(len(features) - 1))
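+        # encoder: each level applies two convolutions and then halves the spatial size by max pooling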
+        for i in self.layers:
+            _in = features[i - 1] if i > 0 else in_channels
+
+            self.add_module(
+                f"down_{i + 1}_conv",
+                TwoConv(
+                    in_channels=_in, out_channels=features[i], batch_norm=batch_norm
+                ),
+            )
+            self.add_module(f"down_{i + 1}_pool", nn.MaxPool3d(kernel_size=2, stride=2))
+
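+        # bottleneck: deepest pair of convolutions with optional dropout for regularization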
+        self.add_module(
+            "bottleneck",
+            TwoConv(
+                in_channels=features[-2],
+                out_channels=features[-1],
+                batch_norm=batch_norm,
+                dropout=dropout,
+            ),
+        )
+
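+        # decoder: transposed convolutions double the spatial size; skip connections are concatenated before each TwoConv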
+        for i in self.layers[::-1]:
+            self.add_module(
+                f"up_{i + 1}_up",
+                nn.ConvTranspose3d(
+                    in_channels=features[i + 1],
+                    out_channels=features[i],
+                    kernel_size=2,
+                    stride=2,
+                ),
+            )
+            self.add_module(
+                f"up_{i + 1}_conv",
+                TwoConv(
+                    in_channels=features[i + 1],
+                    out_channels=features[i],
+                    batch_norm=batch_norm,
+                ),
+            )
+
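+        # final 1x1x1 convolution mapping features[0] channels to the requested output channels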
+        self.add_module(
+            "out_conv",
+            nn.Conv3d(
+                in_channels=features[0],
+                out_channels=out_channels,
+                stride=1,
+                kernel_size=1,
+            ),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        intermediate = []
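+        # encoder: store each level's activation before pooling for the skip connections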
+        for i in range(1, len(self.features)):
+            x = self.get_submodule(f"down_{i}_conv")(x)
+            intermediate.append(x)
+            x = self.get_submodule(f"down_{i}_pool")(x)
+
+        x = self.get_submodule("bottleneck")(x)
+
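+        # decoder: upsample, concatenate the matching encoder activation along the channel axis, then convolve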
+        for i in range(len(self.features) - 1, 0, -1):
+            x = self.get_submodule(f"up_{i}_up")(x)
+            x = torch.cat((x, intermediate[i - 1]), dim=1)
+            x = self.get_submodule(f"up_{i}_conv")(x)
+
+        return self.get_submodule("out_conv")(x)
+
+
+if __name__ == "__main__":
+    net = UNet(features=[64, 128, 256])
+    print(net)
+
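+    # spatial dimensions must be divisible by 2 ** (len(features) - 1) = 4 so that pooling and upsampling line up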
+    _inputs = torch.rand((1, 1, 64, 64, 64))
+    _outputs = net(_inputs)
+
+    print(f"Transform {_inputs.shape} to {_outputs.shape}")
-- 
GitLab