From e53078200381b9aa4f7f121b08eb6e5007c3831c Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Tue, 16 Aug 2022 16:21:50 +0200 Subject: [PATCH] implement the unet model --- mu_map/models.py | 0 mu_map/models/unet.py | 163 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 163 insertions(+) delete mode 100644 mu_map/models.py create mode 100644 mu_map/models/unet.py diff --git a/mu_map/models.py b/mu_map/models.py deleted file mode 100644 index e69de29..0000000 diff --git a/mu_map/models/unet.py b/mu_map/models/unet.py new file mode 100644 index 0000000..e6e0bbc --- /dev/null +++ b/mu_map/models/unet.py @@ -0,0 +1,163 @@ +from typing import Optional, List + +import torch.nn as nn + + +class TwoConv(nn.Sequential): + """ + Combine two convolutions with ReLU activations as a sequential module. + Optionally, batch normalization and dropout can be added. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + batch_norm: bool = True, + dropout: Optional[float] = None, + ): + """ + Create a sequential module consisting of two convolutions with ReLU activations. + + :param in_channels: the number of channels the first convolution has to deal with + :param out_channels: the number of features computed by both convolutions + :param batch_norm: if batch normalization should be applied after each convolution + :param dropout: optional dropout probability used for a dropout layer between both convolutions + """ + super().__init__() + + self.append( + nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding="same", + ) + ) + if batch_norm: + self.append(nn.BatchNorm3d(num_features=out_channels)) + self.append(nn.ReLU(inplace=True)) + + if dropout is not None: + self.append(nn.Dropout3d(p=dropout)) + + self.append( + nn.Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding="same", + ) + ) + if batch_norm: + self.append(nn.BatchNorm3d(num_features=out_channels)) + self.append(nn.ReLU(inplace=True)) + + +class UNet(nn.Module): + """ + Create a UNet for three dimensional inputs as used in the paper by Shi et al. (2020). + Differences to the default UNet are: + * the usage of padding, + * batch normalization is applied after each convolution, and + * dropout is applied to the bottleneck layer. + """ + + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + features: List[int] = [64, 128, 256, 512], + batch_norm: bool = True, + dropout: Optional[float] = 0.15, + ): + """ + Initialize the UNet. + + :param in_channels: number of input channels + :param out_channels: number of output channels + :param features: number of features computed by the convolutions of each layer + :param batch_norm: if batch normalization should be added after each convolution + :param dropout: dropout probability used for dropout at the bottleneck layer + """ + super().__init__() + + self.features = features + + self.layers = list(range(len(features) - 1)) + for i in self.layers: + _in = features[i - 1] if i > 0 else in_channels + + self.add_module( + f"down_{i + 1}_conv", + TwoConv( + in_channels=_in, out_channels=features[i], batch_norm=batch_norm + ), + ) + self.add_module(f"down_{i + 1}_pool", nn.MaxPool3d(kernel_size=2, stride=2)) + + self.add_module( + "bottleneck", + TwoConv( + in_channels=features[-2], out_channels=features[-1], dropout=dropout + ), + ) + + for i in self.layers[::-1]: + self.add_module( + f"up_{i + 1}_up", + nn.ConvTranspose3d( + in_channels=features[i + 1], + out_channels=features[i], + kernel_size=2, + stride=2, + ), + ) + self.add_module( + f"up_{i + 1}_conv", + TwoConv( + in_channels=features[i + 1], + out_channels=features[i], + batch_norm=batch_norm, + ), + ) + + self.add_module( + "out_conv", + nn.Conv3d( + in_channels=features[0], + out_channels=out_channels, + stride=1, + kernel_size=1, + ), + ) + + def forward(self, x): + intermediate = [] + for i in range(1, len(self.features)): + x = self.get_submodule(f"down_{i}_conv")(x) + intermediate.append(x) + x = self.get_submodule(f"down_{i}_pool")(x) + + x = self.get_submodule("bottleneck")(x) + + for i in range(len(self.features) - 1, 0, -1): + x = self.get_submodule(f"up_{i}_up")(x) + x = torch.cat((x, intermediate[i - 1]), dim=1) + x = self.get_submodule(f"up_{i}_conv")(x) + + return self.get_submodule("out_conv")(x) + + +if __name__ == "__main__": + import torch + + net = UNet(features=[64, 128, 256]) + print(net) + + _inputs = torch.rand((1, 1, 64, 64, 64)) + _outputs = net(_inputs) + + print(f"Transform {_inputs.shape} to {_outputs.shape}") -- GitLab