from typing import List, Optional, Sequence

import torch
import torch.nn as nn


class TwoConv(nn.Sequential):
    """
    Combine two convolutions with ReLU activations as a sequential module.
    Optionally, batch normalization and dropout can be added.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        batch_norm: bool = True,
        dropout: Optional[float] = None,
    ):
        """
        Create a sequential module consisting of two convolutions with ReLU activations.

        The resulting layer order is
        ``conv -> [batch norm] -> ReLU -> [dropout] -> conv -> [batch norm] -> ReLU``.
        Both convolutions use a 3x3x3 kernel with stride 1 and ``"same"`` padding,
        so the spatial size of the input is preserved.

        :param in_channels: the number of channels the first convolution has to deal with
        :param out_channels: the number of features computed by both convolutions
        :param batch_norm: if batch normalization should be applied after each convolution
        :param dropout: optional dropout probability used for a dropout layer between both convolutions
        """
        super().__init__()
        self._append_conv_bn_relu(in_channels, out_channels, batch_norm)
        if dropout is not None:
            self.append(nn.Dropout3d(p=dropout))
        self._append_conv_bn_relu(out_channels, out_channels, batch_norm)

    def _append_conv_bn_relu(
        self, in_channels: int, out_channels: int, batch_norm: bool
    ) -> None:
        """Append one ``conv -> [batch norm] -> ReLU`` group to this sequence."""
        self.append(
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding="same",
            )
        )
        if batch_norm:
            self.append(nn.BatchNorm3d(num_features=out_channels))
        self.append(nn.ReLU(inplace=True))


class UNet(nn.Module):
    """
    Create a UNet for three dimensional inputs as used in the paper by Shi et al. (2020).

    Differences to the default UNet are:
      * the usage of padding,
      * batch normalization is applied after each convolution, and
      * dropout is applied to the bottleneck layer.
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        features: Sequence[int] = (64, 128, 256, 512),
        batch_norm: bool = True,
        dropout: Optional[float] = 0.15,
    ):
        """
        Initialize the UNet.

        ``len(features) - 1`` encoder/decoder levels are created; the last entry of
        ``features`` is the number of features computed by the bottleneck.

        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :param features: number of features computed by the convolutions of each layer
                         (at least two entries: one encoder level plus the bottleneck)
        :param batch_norm: if batch normalization should be added after each convolution
                           (including the bottleneck)
        :param dropout: dropout probability used for dropout at the bottleneck layer
        :raises ValueError: if fewer than two feature levels are given
        """
        super().__init__()
        if len(features) < 2:
            raise ValueError(
                "features needs at least two entries (encoder level + bottleneck), "
                f"got {list(features)}"
            )
        # Store a private copy so the module never aliases the caller's list
        # (or a shared default value).
        self.features: List[int] = list(features)
        features = self.features
        self.layers = list(range(len(features) - 1))

        # Encoder: convolutions followed by 2x down-sampling.
        for i in self.layers:
            _in = features[i - 1] if i > 0 else in_channels
            self.add_module(
                f"down_{i + 1}_conv",
                TwoConv(
                    in_channels=_in, out_channels=features[i], batch_norm=batch_norm
                ),
            )
            self.add_module(f"down_{i + 1}_pool", nn.MaxPool3d(kernel_size=2, stride=2))

        self.add_module(
            "bottleneck",
            TwoConv(
                in_channels=features[-2],
                out_channels=features[-1],
                batch_norm=batch_norm,
                dropout=dropout,
            ),
        )

        # Decoder: 2x up-sampling, concatenation with the skip connection,
        # then convolutions.
        for i in reversed(self.layers):
            self.add_module(
                f"up_{i + 1}_up",
                nn.ConvTranspose3d(
                    in_channels=features[i + 1],
                    out_channels=features[i],
                    kernel_size=2,
                    stride=2,
                ),
            )
            self.add_module(
                f"up_{i + 1}_conv",
                TwoConv(
                    # The up-sampled tensor (features[i] channels) is concatenated
                    # with the skip connection (features[i] channels).
                    in_channels=2 * features[i],
                    out_channels=features[i],
                    batch_norm=batch_norm,
                ),
            )

        self.add_module(
            "out_conv",
            nn.Conv3d(
                in_channels=features[0],
                out_channels=out_channels,
                stride=1,
                kernel_size=1,
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the UNet to ``x``.

        Each encoder level halves the spatial size with a ``MaxPool3d(2, 2)``, so
        every spatial dimension of ``x`` should be divisible by
        ``2 ** (len(self.features) - 1)``; otherwise the up-sampled tensors do not
        match the skip connections and the concatenation fails.

        :param x: input tensor, presumably (batch, in_channels, D, H, W) — the
                  convolutions are 3D
        :return: tensor with ``out_channels`` channels and the spatial size of ``x``
        """
        depth = len(self.features) - 1
        skips = []
        for level in range(1, depth + 1):
            x = self.get_submodule(f"down_{level}_conv")(x)
            skips.append(x)
            x = self.get_submodule(f"down_{level}_pool")(x)

        x = self.get_submodule("bottleneck")(x)

        for level in range(depth, 0, -1):
            x = self.get_submodule(f"up_{level}_up")(x)
            x = torch.cat((x, skips[level - 1]), dim=1)
            x = self.get_submodule(f"up_{level}_conv")(x)

        return self.get_submodule("out_conv")(x)


if __name__ == "__main__":
    import time

    net = UNet(features=[64, 128, 256, 512])
    print(net)

    _inputs = torch.rand((1, 1, 64, 128, 128))
    _outputs = net(_inputs)
    print(f"Transform {_inputs.shape} to {_outputs.shape}")

    if not torch.cuda.is_available():
        print("CUDA is not available - skipping the GPU benchmark")
    else:
        device = torch.device("cuda")
        net = net.to(device)

        iterations = 100
        for batch_size in range(128, 129):
            # CUDA kernels run asynchronously: synchronize around the timed
            # region so the measurement covers all launched work.
            torch.cuda.synchronize(device)
            since = time.perf_counter()
            for i in range(iterations):
                print(
                    f"batch size {batch_size:>3} - iteration {i + 1:>3}/{iterations}",
                    end="\r",
                )
                _inputs = torch.rand((batch_size, 1, 32, 32, 32)).to(device)
                _outputs = net(_inputs)
            torch.cuda.synchronize(device)
            _took = (time.perf_counter() - since) / iterations
            print(f"Batches of size {batch_size} take {_took:.3f}s on average")