Commit e5307820 authored by Tamino Huxohl

implement the unet model

parent 253b31d3
from typing import List, Optional

import torch
import torch.nn as nn

class TwoConv(nn.Sequential):
"""
Combine two convolutions with ReLU activations as a sequential module.
Optionally, batch normalization and dropout can be added.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
batch_norm: bool = True,
dropout: Optional[float] = None,
):
"""
Create a sequential module consisting of two convolutions with ReLU activations.
        :param in_channels: number of input channels expected by the first convolution
        :param out_channels: number of features computed by both convolutions
        :param batch_norm: whether batch normalization should be applied after each convolution
        :param dropout: optional dropout probability for a dropout layer between the two convolutions
"""
super().__init__()
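        # first convolution; "same" padding keeps the spatial dimensions unchanged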
self.append(
nn.Conv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding="same",
)
)
if batch_norm:
self.append(nn.BatchNorm3d(num_features=out_channels))
self.append(nn.ReLU(inplace=True))
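        # optional dropout between the two convolutions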
if dropout is not None:
self.append(nn.Dropout3d(p=dropout))
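        # second convolution keeping the number of features constant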
self.append(
nn.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding="same",
)
)
if batch_norm:
self.append(nn.BatchNorm3d(num_features=out_channels))
self.append(nn.ReLU(inplace=True))
class UNet(nn.Module):
"""
    Create a UNet for three-dimensional inputs as used in the paper by Shi et al. (2020).
    Differences to the standard UNet are:
    * padded ("same") convolutions, so spatial dimensions are preserved,
    * batch normalization is applied after each convolution, and
    * dropout is applied at the bottleneck layer.
"""
def __init__(
self,
in_channels: int = 1,
out_channels: int = 1,
features: List[int] = [64, 128, 256, 512],
batch_norm: bool = True,
dropout: Optional[float] = 0.15,
):
"""
Initialize the UNet.
:param in_channels: number of input channels
:param out_channels: number of output channels
        :param features: number of features computed by the convolutions at each resolution level
        :param batch_norm: whether batch normalization should be added after each convolution
        :param dropout: optional dropout probability applied at the bottleneck layer
"""
super().__init__()
self.features = features
self.layers = list(range(len(features) - 1))
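        # encoder: at each level two convolutions compute features[i] channels,
        # followed by max pooling which halves the spatial dimensions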
for i in self.layers:
_in = features[i - 1] if i > 0 else in_channels
self.add_module(
f"down_{i + 1}_conv",
TwoConv(
in_channels=_in, out_channels=features[i], batch_norm=batch_norm
),
)
self.add_module(f"down_{i + 1}_pool", nn.MaxPool3d(kernel_size=2, stride=2))
self.add_module(
"bottleneck",
            TwoConv(
                in_channels=features[-2],
                out_channels=features[-1],
                batch_norm=batch_norm,
                dropout=dropout,
            ),
)
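        # decoder: transposed convolutions double the spatial dimensions again;
        # concatenating the matching encoder feature map yields 2 * features[i]
        # channels (features[i + 1] for the doubling feature sizes used here),
        # which the following two convolutions reduce back to features[i]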
for i in self.layers[::-1]:
self.add_module(
f"up_{i + 1}_up",
nn.ConvTranspose3d(
in_channels=features[i + 1],
out_channels=features[i],
kernel_size=2,
stride=2,
),
)
self.add_module(
f"up_{i + 1}_conv",
TwoConv(
in_channels=features[i + 1],
out_channels=features[i],
batch_norm=batch_norm,
),
)
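        # final 1x1x1 convolution mapping features[0] channels to out_channels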
self.add_module(
"out_conv",
nn.Conv3d(
in_channels=features[0],
out_channels=out_channels,
stride=1,
kernel_size=1,
),
)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
intermediate = []
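        # encoder path: store the feature maps before pooling as skip connections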
for i in range(1, len(self.features)):
x = self.get_submodule(f"down_{i}_conv")(x)
intermediate.append(x)
x = self.get_submodule(f"down_{i}_pool")(x)
x = self.get_submodule("bottleneck")(x)
for i in range(len(self.features) - 1, 0, -1):
x = self.get_submodule(f"up_{i}_up")(x)
x = torch.cat((x, intermediate[i - 1]), dim=1)
x = self.get_submodule(f"up_{i}_conv")(x)
return self.get_submodule("out_conv")(x)
if __name__ == "__main__":
net = UNet(features=[64, 128, 256])
print(net)
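    # quick sanity check: count the trainable parameters of the network
    n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f"UNet with {n_params} trainable parameters")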
_inputs = torch.rand((1, 1, 64, 64, 64))
_outputs = net(_inputs)
print(f"Transform {_inputs.shape} to {_outputs.shape}")