from typing import List, Optional

import torch
import torch.nn as nn


class TwoConv(nn.Sequential):
    """
    Combine two convolutions with ReLU activations into a sequential module.
    Optionally, batch normalization and dropout can be added.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        batch_norm: bool = True,
        dropout: Optional[float] = None,
    ):
        """
        Create a sequential module consisting of two convolutions with ReLU activations.

        :param in_channels: number of input channels of the first convolution
        :param out_channels: number of features computed by both convolutions
        :param batch_norm: whether batch normalization is applied after each convolution
        :param dropout: optional dropout probability for a dropout layer between the two convolutions
        """
        super().__init__()
        # First convolution: "same" padding keeps the spatial dimensions unchanged.
        self.append(
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding="same",
            )
        )
        if batch_norm:
            self.append(nn.BatchNorm3d(num_features=out_channels))
        self.append(nn.ReLU(inplace=True))
        # Optional dropout between the two convolutions.
        if dropout is not None:
            self.append(nn.Dropout3d(p=dropout))
        # Second convolution keeps the number of features.
        self.append(
            nn.Conv3d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding="same",
            )
        )
        if batch_norm:
            self.append(nn.BatchNorm3d(num_features=out_channels))
        self.append(nn.ReLU(inplace=True))
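
# Usage sketch (illustration, not part of the original file): because both
# convolutions use "same" padding, a TwoConv block only changes the channel
# dimension, e.g.:
#
#   block = TwoConv(in_channels=1, out_channels=64, dropout=0.1)
#   block(torch.rand(1, 1, 16, 16, 16)).shape  # -> torch.Size([1, 64, 16, 16, 16])
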
class UNet(nn.Module):
    """
    A UNet for three-dimensional inputs as used in the paper by Shi et al. (2020).
    Differences to the default UNet are:
    * the usage of padding,
    * batch normalization is applied after each convolution, and
    * dropout is applied to the bottleneck layer.
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        features: List[int] = [64, 128, 256, 512],
        batch_norm: bool = True,
        dropout: Optional[float] = 0.15,
    ):
        """
        Initialize the UNet.

        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :param features: number of features computed by the convolutions of each layer
        :param batch_norm: whether batch normalization is added after each convolution
        :param dropout: dropout probability used at the bottleneck layer
        """
        super().__init__()
        self.features = features
        self.layers = list(range(len(features) - 1))

        # Encoder: one double convolution followed by max pooling per level.
        for i in self.layers:
            _in = features[i - 1] if i > 0 else in_channels
            self.add_module(
                f"down_{i + 1}_conv",
                TwoConv(
                    in_channels=_in, out_channels=features[i], batch_norm=batch_norm
                ),
            )
            self.add_module(f"down_{i + 1}_pool", nn.MaxPool3d(kernel_size=2, stride=2))

        # Bottleneck: a double convolution with dropout; batch_norm is forwarded
        # so the flag also applies here.
        self.add_module(
            "bottleneck",
            TwoConv(
                in_channels=features[-2],
                out_channels=features[-1],
                batch_norm=batch_norm,
                dropout=dropout,
            ),
        )

        # Decoder: upsample with a transposed convolution, then reduce the
        # concatenated features with a double convolution.
        for i in self.layers[::-1]:
            self.add_module(
                f"up_{i + 1}_up",
                nn.ConvTranspose3d(
                    in_channels=features[i + 1],
                    out_channels=features[i],
                    kernel_size=2,
                    stride=2,
                ),
            )
            self.add_module(
                f"up_{i + 1}_conv",
                TwoConv(
                    in_channels=features[i + 1],
                    out_channels=features[i],
                    batch_norm=batch_norm,
                ),
            )

        # Final 1x1x1 convolution maps to the requested number of output channels.
        self.add_module(
            "out_conv",
            nn.Conv3d(
                in_channels=features[0],
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
            ),
        )

    def forward(self, x):
        # Encoder: remember the pre-pooling activations for the skip connections.
        intermediate = []
        for i in range(1, len(self.features)):
            x = self.get_submodule(f"down_{i}_conv")(x)
            intermediate.append(x)
            x = self.get_submodule(f"down_{i}_pool")(x)
        x = self.get_submodule("bottleneck")(x)
        # Decoder: upsample, concatenate the matching skip connection along the
        # channel dimension, and apply the double convolution.
        for i in range(len(self.features) - 1, 0, -1):
            x = self.get_submodule(f"up_{i}_up")(x)
            x = torch.cat((x, intermediate[i - 1]), dim=1)
            x = self.get_submodule(f"up_{i}_conv")(x)
        return self.get_submodule("out_conv")(x)
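
# Usage sketch (illustration, not part of the original file): the encoder pools
# len(features) - 1 times, so every spatial input dimension must be divisible by
# 2 ** (len(features) - 1); with the default features that is 2 ** 3 = 8, e.g.:
#
#   net = UNet(features=[64, 128, 256, 512])
#   y = net(torch.rand(1, 1, 64, 128, 128))  # -> torch.Size([1, 1, 64, 128, 128])
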

if __name__ == "__main__":
    import time

    net = UNet(features=[64, 128, 256, 512])
    print(net)

    _inputs = torch.rand((1, 1, 64, 128, 128))
    _outputs = net(_inputs)
    print(f"Transform {_inputs.shape} to {_outputs.shape}")

    # Rough throughput benchmark; requires a CUDA-capable GPU.
    device = torch.device("cuda")
    net = net.to(device)
    iterations = 100
    for batch_size in range(128, 129):
        since = time.time()
        for i in range(iterations):
            print(f"batch size {batch_size:>3} - {i + 1:>3}/{iterations}", end="\r")
            # _inputs = torch.rand((batch_size, 1, 64, 128, 128))
            _inputs = torch.rand((batch_size, 1, 32, 32, 32))
            _inputs = _inputs.to(device)
            with torch.no_grad():  # no autograd graph needed for benchmarking
                _outputs = net(_inputs)
        # CUDA kernels launch asynchronously; wait for them before stopping the clock.
        torch.cuda.synchronize()
        _took = (time.time() - since) / iterations
        print(f"Batches of size {batch_size} take {_took:.3f}s on average")